Package com.facebook.presto.operator

Source Code of com.facebook.presto.operator.GroupByHash$ChannelBuilder

/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.operator;

import com.facebook.presto.block.BlockBuilder;
import com.facebook.presto.block.BlockCursor;
import com.facebook.presto.block.uncompressed.UncompressedBlock;
import com.facebook.presto.tuple.TupleInfo;
import com.facebook.presto.tuple.TupleInfo.Type;
import com.google.common.collect.ImmutableList;
import com.google.common.primitives.Ints;
import com.google.common.primitives.Longs;
import io.airlift.slice.Slice;
import io.airlift.slice.SliceOutput;
import io.airlift.slice.Slices;
import io.airlift.units.DataSize;
import io.airlift.units.DataSize.Unit;
import it.unimi.dsi.fastutil.ints.IntArrayList;
import it.unimi.dsi.fastutil.longs.Long2IntOpenCustomHashMap;
import it.unimi.dsi.fastutil.longs.LongHash;
import it.unimi.dsi.fastutil.longs.LongHash.Strategy;
import it.unimi.dsi.fastutil.objects.ObjectArrayList;

import java.util.List;

import static com.facebook.presto.operator.SyntheticAddress.decodePosition;
import static com.facebook.presto.operator.SyntheticAddress.decodeSliceIndex;
import static com.facebook.presto.operator.SyntheticAddress.encodeSyntheticAddress;
import static com.facebook.presto.tuple.TupleInfo.SINGLE_LONG;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkState;
import static io.airlift.slice.SizeOf.SIZE_OF_BYTE;
import static io.airlift.slice.SizeOf.SIZE_OF_DOUBLE;
import static io.airlift.slice.SizeOf.SIZE_OF_INT;
import static io.airlift.slice.SizeOf.SIZE_OF_LONG;
import static io.airlift.slice.SizeOf.sizeOf;

public class GroupByHash
{
    // Sentinel address passed to the hash map to mean "the row the current cursors point at"
    // instead of a row already stored in one of the page builders.
    private static final long CURRENT_ROW_ADDRESS = 0xFF_FF_FF_FF_FF_FF_FF_FFL;

    // type of each group-by column
    private final List<Type> types;
    // input page channel of each group-by column; the constructor requires the same length as types
    private final int[] channels;

    // page builder that new group rows are appended to; it is always the last entry of allPages
    private GroupByPageBuilder activePage;

    // every page builder created so far; the list index is the page index encoded in stored-row addresses
    private final List<GroupByPageBuilder> allPages;
    // summed memory size of the page builders that filled up and were replaced as the active page
    private long completedPagesMemorySize;

    // hash/equality strategy that resolves an address to a stored row or to the current cursor row
    private final PageBuilderHashStrategy hashStrategy;
    // maps the address of each stored group row to its group id; a miss returns a negative value
    private final PagePositionToGroupId pagePositionToGroupId;

    // id that will be handed to the next new group
    private int nextGroupId;

    public GroupByHash(List<Type> types, int[] channels, int expectedSize)
    {
        this.types = checkNotNull(types, "types is null");
        this.channels = checkNotNull(channels, "channels is null").clone();
        checkArgument(types.size() == channels.length, "types and channels have different sizes");

        this.allPages = ObjectArrayList.wrap(new GroupByPageBuilder[1024], 0);
        this.activePage = new GroupByPageBuilder(types);
        this.allPages.add(activePage);

        this.hashStrategy = new PageBuilderHashStrategy();
        this.pagePositionToGroupId = new PagePositionToGroupId(expectedSize, hashStrategy);
        this.pagePositionToGroupId.defaultReturnValue(-1);
    }

    public long getEstimatedSize()
    {
        return completedPagesMemorySize + activePage.getMemorySize() + pagePositionToGroupId.getEstimatedSize();
    }

    public List<Type> getTypes()
    {
        return types;
    }

    public GroupByIdBlock getGroupIds(Page page)
    {
        int positionCount = page.getPositionCount();

        int groupIdBlockSize = SINGLE_LONG.getFixedSize() * positionCount;
        BlockBuilder blockBuilder = new BlockBuilder(SINGLE_LONG, groupIdBlockSize, Slices.allocate(groupIdBlockSize).getOutput());

        // open cursors for group blocks
        BlockCursor[] cursors = new BlockCursor[channels.length];
        for (int i = 0; i < channels.length; i++) {
            cursors[i] = page.getBlock(channels[i]).cursor();
        }

        // use cursors in hash strategy to provide value for "current" row
        hashStrategy.setCurrentRow(cursors);

        for (int position = 0; position < positionCount; position++) {
            for (BlockCursor cursor : cursors) {
                checkState(cursor.advanceNextPosition());
            }

            int groupId = pagePositionToGroupId.get(CURRENT_ROW_ADDRESS);
            if (groupId < 0) {
                groupId = addNewGroup(cursors);
            }
            blockBuilder.append(groupId);
        }
        UncompressedBlock block = blockBuilder.build();
        return new GroupByIdBlock(nextGroupId, block);
    }

    private int addNewGroup(BlockCursor... row)
    {
        int pageIndex = allPages.size() - 1;
        if (!activePage.append(row)) {
            // record the active page memory size
            completedPagesMemorySize += activePage.getMemorySize();

            activePage = new GroupByPageBuilder(types);
            allPages.add(activePage);
            pageIndex++;

            // TODO make the page builder allocation guarantee enough space to hold at least the first row.
            checkState(activePage.append(row), "Could not add row to empty page builder");
        }

        // record group id in hash
        int groupId = nextGroupId++;
        long address = encodeSyntheticAddress(pageIndex, activePage.getPositionCount() - 1);
        pagePositionToGroupId.put(address, groupId);

        return groupId;
    }

    public Long2IntOpenCustomHashMap getPagePositionToGroupId()
    {
        return pagePositionToGroupId;
    }

    public void appendValuesTo(long pagePosition, BlockBuilder[] builders)
    {
        GroupByPageBuilder page = allPages.get(decodeSliceIndex(pagePosition));
        page.appendValuesTo(decodePosition(pagePosition), builders);
    }

    private class PageBuilderHashStrategy
            implements Strategy
    {
        private BlockCursor[] currentRow;

        public void setCurrentRow(BlockCursor[] currentRow)
        {
            this.currentRow = currentRow;
        }

        @Override
        public int hashCode(long sliceAddress)
        {
            if (sliceAddress == CURRENT_ROW_ADDRESS) {
                return hashCurrentRow();
            }
            else {
                return hashPosition(sliceAddress);
            }
        }

        private int hashPosition(long sliceAddress)
        {
            int sliceIndex = decodeSliceIndex(sliceAddress);
            int position = decodePosition(sliceAddress);
            return allPages.get(sliceIndex).hashCode(position);
        }

        private int hashCurrentRow()
        {
            int result = 0;
            for (int channel = 0; channel < types.size(); channel++) {
                Type type = types.get(channel);
                BlockCursor cursor = currentRow[channel];
                result = addToHashCode(result, valueHashCode(type, cursor.getRawSlice(), cursor.getRawOffset()));
            }
            return result;
        }

        @Override
        public boolean equals(long leftSliceAddress, long rightSliceAddress)
        {
            // current row always equals itself
            if (leftSliceAddress == CURRENT_ROW_ADDRESS && rightSliceAddress == CURRENT_ROW_ADDRESS) {
                return true;
            }

            // current row == position
            if (leftSliceAddress == CURRENT_ROW_ADDRESS) {
                return positionEqualsCurrentRow(decodeSliceIndex(rightSliceAddress), decodePosition(rightSliceAddress));
            }

            // position == current row
            if (rightSliceAddress == CURRENT_ROW_ADDRESS) {
                return positionEqualsCurrentRow(decodeSliceIndex(leftSliceAddress), decodePosition(leftSliceAddress));
            }

            // position == position
            return positionEqualsPosition(
                    decodeSliceIndex(leftSliceAddress), decodePosition(leftSliceAddress),
                    decodeSliceIndex(rightSliceAddress), decodePosition(rightSliceAddress));
        }

        private boolean positionEqualsCurrentRow(int sliceIndex, int position)
        {
            return allPages.get(sliceIndex).equals(position, currentRow);
        }

        private boolean positionEqualsPosition(int leftSliceIndex, int leftPosition, int rightSliceIndex, int rightPosition)
        {
            return allPages.get(leftSliceIndex).equals(leftPosition, allPages.get(rightSliceIndex), rightPosition);
        }
    }

    private static class GroupByPageBuilder
    {
        private final List<ChannelBuilder> channels;
        private int positionCount;
        private boolean full;

        public GroupByPageBuilder(List<Type> types)
        {
            ImmutableList.Builder<ChannelBuilder> builder = ImmutableList.builder();
            for (Type type : types) {
                builder.add(new ChannelBuilder(type));
            }
            channels = builder.build();
        }

        public int getPositionCount()
        {
            return positionCount;
        }

        public long getMemorySize()
        {
            long memorySize = 0;
            for (ChannelBuilder channel : channels) {
                memorySize += channel.getMemorySize();
            }
            return memorySize;
        }

        private boolean append(BlockCursor... row)
        {
            // don't add row if already full
            if (full) {
                return false;
            }

            // append to each channel
            for (int channel = 0; channel < row.length; channel++) {
                if (!channels.get(channel).append(row[channel])) {
                    // This early return will result in uneven channels, but this is not
                    // a problem since the position count is not incremented.  This means
                    // that although some channels have "garbage" on the end, these values
                    // will never be read since the position is not valid.
                    full = true;
                    return false;
                }
            }
            positionCount++;
            return true;
        }

        public void appendValuesTo(int position, BlockBuilder[] builders)
        {
            for (int i = 0; i < channels.size(); i++) {
                ChannelBuilder channel = channels.get(i);
                channel.appendTo(position, builders[i]);
            }
        }

        public int hashCode(int position)
        {
            int result = 0;
            for (ChannelBuilder channel : channels) {
                result = addToHashCode(result, channel.hashCode(position));
            }
            return result;
        }

        public boolean equals(int thisPosition, GroupByPageBuilder that, int thatPosition)
        {
            for (int i = 0; i < channels.size(); i++) {
                ChannelBuilder thisBlock = this.channels.get(i);
                ChannelBuilder thatBlock = that.channels.get(i);
                if (!thisBlock.equals(thisPosition, thatBlock, thatPosition)) {
                    return false;
                }
            }
            return true;
        }

        public boolean equals(int position, BlockCursor... row)
        {
            for (int i = 0; i < channels.size(); i++) {
                ChannelBuilder thisBlock = this.channels.get(i);
                if (!thisBlock.equals(position, row[i])) {
                    return false;
                }
            }
            return true;
        }
    }

    private static class ChannelBuilder
    {
        public static final DataSize DEFAULT_MAX_BLOCK_SIZE = new DataSize(64, Unit.KILOBYTE);

        private final Type type;
        private final SliceOutput sliceOutput;
        private final Slice slice;
        private final IntArrayList positionOffsets;

        public ChannelBuilder(Type type)
        {
            checkNotNull(type, "type is null");

            this.type = type;
            this.slice = Slices.allocate(Ints.checkedCast(DEFAULT_MAX_BLOCK_SIZE.toBytes()));
            this.sliceOutput = slice.getOutput();
            this.positionOffsets = new IntArrayList(1024);
        }

        public long getMemorySize()
        {
            return slice.length() + sizeOf(positionOffsets.elements());
        }

        public boolean equals(int position, ChannelBuilder rightBuilder, int rightPosition)
        {
            checkArgument(position >= 0 && position < positionOffsets.size());
            checkArgument(rightPosition >= 0 && rightPosition < rightBuilder.positionOffsets.size());

            Slice leftSlice = slice;
            int leftOffset = positionOffsets.getInt(position);

            Slice rightSlice = rightBuilder.slice;
            int rightOffset = rightBuilder.positionOffsets.getInt(rightPosition);

            return valueEquals(type, leftSlice, leftOffset, rightSlice, rightOffset);
        }

        public boolean equals(int position, BlockCursor cursor)
        {
            checkArgument(position >= 0 && position < positionOffsets.size());

            int offset = positionOffsets.getInt(position);

            Slice rightSlice = cursor.getRawSlice();
            int rightOffset = cursor.getRawOffset();
            return valueEquals(type, slice, offset, rightSlice, rightOffset);
        }

        public void appendTo(int position, BlockBuilder builder)
        {
            checkArgument(position >= 0 && position < positionOffsets.size());

            int offset = positionOffsets.getInt(position);

            if (slice.getByte(offset) != 0) {
                builder.appendNull();
            }
            else if (type == Type.FIXED_INT_64) {
                builder.append(slice.getLong(offset + SIZE_OF_BYTE));
            }
            else if (type == Type.DOUBLE) {
                builder.append(slice.getDouble(offset + SIZE_OF_BYTE));
            }
            else if (type == Type.BOOLEAN) {
                builder.append(slice.getByte(offset + SIZE_OF_BYTE) != 0);
            }
            else if (type == Type.VARIABLE_BINARY) {
                int sliceLength = getVariableBinaryLength(slice, offset);
                builder.append(slice.slice(offset + SIZE_OF_BYTE + SIZE_OF_INT, sliceLength));
            }
            else {
                throw new IllegalArgumentException("Unsupported type " + type);
            }
        }

        public int hashCode(int position)
        {
            checkArgument(position >= 0 && position < positionOffsets.size());
            return valueHashCode(type, slice, positionOffsets.getInt(position));
        }

        public boolean append(BlockCursor cursor)
        {
            // the extra BYTE here is for the null flag
            int writableBytes = sliceOutput.writableBytes() - SIZE_OF_BYTE;

            boolean isNull = cursor.isNull();

            if (type == Type.FIXED_INT_64) {
                if (writableBytes < SIZE_OF_LONG) {
                    return false;
                }

                positionOffsets.add(sliceOutput.size());
                sliceOutput.writeByte(isNull ? 1 : 0);
                sliceOutput.appendLong(isNull ? 0 : cursor.getLong());
            }
            else if (type == Type.DOUBLE) {
                if (writableBytes < SIZE_OF_DOUBLE) {
                    return false;
                }

                positionOffsets.add(sliceOutput.size());
                sliceOutput.writeByte(isNull ? 1 : 0);
                sliceOutput.appendDouble(isNull ? 0 : cursor.getDouble());
            }
            else if (type == Type.BOOLEAN) {
                if (writableBytes < SIZE_OF_BYTE) {
                    return false;
                }

                positionOffsets.add(sliceOutput.size());
                sliceOutput.writeByte(isNull ? 1 : 0);
                sliceOutput.writeByte(!isNull && cursor.getBoolean() ? 1 : 0);
            }
            else if (type == Type.VARIABLE_BINARY) {
                int sliceLength = isNull ? 0 : getVariableBinaryLength(cursor.getRawSlice(), cursor.getRawOffset());
                if (writableBytes < SIZE_OF_INT + sliceLength) {
                    return false;
                }

                int startingOffset = sliceOutput.size();
                positionOffsets.add(startingOffset);
                sliceOutput.writeByte(isNull ? 1 : 0);
                sliceOutput.appendInt(sliceLength + SIZE_OF_BYTE + SIZE_OF_INT);
                if (!isNull) {
                    sliceOutput.writeBytes(cursor.getRawSlice(), cursor.getRawOffset() + SIZE_OF_BYTE + SIZE_OF_INT, sliceLength);
                }
            }
            else {
                throw new IllegalArgumentException("Unsupported type " + type);
            }
            return true;
        }

        public UncompressedBlock build()
        {
            checkState(!positionOffsets.isEmpty(), "Cannot build an empty block");

            return new UncompressedBlock(positionOffsets.size(), new TupleInfo(type), sliceOutput.slice());
        }
    }

    private static int addToHashCode(int result, int hashCode)
    {
        result = 31 * result + hashCode;
        return result;
    }

    private static int valueHashCode(Type type, Slice slice, int offset)
    {
        boolean isNull = slice.getByte(offset) != 0;
        if (isNull) {
            return 0;
        }

        if (type == Type.FIXED_INT_64) {
            return Longs.hashCode(slice.getLong(offset + SIZE_OF_BYTE));
        }
        else if (type == Type.DOUBLE) {
            long longValue = Double.doubleToLongBits(slice.getDouble(offset + SIZE_OF_BYTE));
            return Longs.hashCode(longValue);
        }
        else if (type == Type.BOOLEAN) {
            return slice.getByte(offset + SIZE_OF_BYTE) != 0 ? 1 : 0;
        }
        else if (type == Type.VARIABLE_BINARY) {
            int sliceLength = getVariableBinaryLength(slice, offset);
            return slice.hashCode(offset + SIZE_OF_BYTE + SIZE_OF_INT, sliceLength);
        }
        else {
            throw new IllegalArgumentException("Unsupported type " + type);
        }
    }

    private static int getVariableBinaryLength(Slice slice, int offset)
    {
        // INT here is the length and the BYTE is the null flag
        return slice.getInt(offset + SIZE_OF_BYTE) - SIZE_OF_INT - SIZE_OF_BYTE;
    }

    private static boolean valueEquals(Type type, Slice leftSlice, int leftOffset, Slice rightSlice, int rightOffset)
    {
        // check if null flags are the same
        boolean leftIsNull = leftSlice.getByte(leftOffset) != 0;
        boolean rightIsNull = rightSlice.getByte(rightOffset) != 0;
        if (leftIsNull != rightIsNull) {
            return false;
        }

        // if values are both null, they are equal
        if (leftIsNull) {
            return true;
        }

        if (type == Type.FIXED_INT_64 || type == Type.DOUBLE) {
            long leftValue = leftSlice.getLong(leftOffset + SIZE_OF_BYTE);
            long rightValue = rightSlice.getLong(rightOffset + SIZE_OF_BYTE);
            return leftValue == rightValue;
        }
        else if (type == Type.BOOLEAN) {
            boolean leftValue = leftSlice.getByte(leftOffset + SIZE_OF_BYTE) != 0;
            boolean rightValue = rightSlice.getByte(rightOffset + SIZE_OF_BYTE) != 0;
            return leftValue == rightValue;
        }
        else if (type == Type.VARIABLE_BINARY) {
            int leftLength = getVariableBinaryLength(leftSlice, leftOffset);
            int rightLength = getVariableBinaryLength(rightSlice, rightOffset);
            return leftSlice.equals(leftOffset + SIZE_OF_BYTE + SIZE_OF_INT, leftLength,
                    rightSlice, rightOffset + SIZE_OF_BYTE + SIZE_OF_INT, rightLength);
        }
        else {
            throw new IllegalArgumentException("Unsupported type " + type);
        }
    }

    private static class PagePositionToGroupId
            extends Long2IntOpenCustomHashMap
    {
        private PagePositionToGroupId(int expected, LongHash.Strategy strategy)
        {
            super(expected, strategy);
            defaultReturnValue(-1);
        }

        public long getEstimatedSize()
        {
            return sizeOf(this.key) + sizeOf(this.value) + sizeOf(this.used);
        }
    }
}
TOP

Related Classes of com.facebook.presto.operator.GroupByHash$ChannelBuilder

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by Oracle Inc. Contact coftware#gmail.com.