Package com.facebook.presto.operator

Source Code of com.facebook.presto.operator.HashAggregationOperator$VariableWidthAggregator

/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.operator;

import com.facebook.presto.block.Block;
import com.facebook.presto.block.BlockBuilder;
import com.facebook.presto.block.BlockCursor;
import com.facebook.presto.block.uncompressed.UncompressedBlock;
import com.facebook.presto.operator.aggregation.AggregationFunction;
import com.facebook.presto.operator.aggregation.FixedWidthAggregationFunction;
import com.facebook.presto.operator.aggregation.VariableWidthAggregationFunction;
import com.facebook.presto.sql.planner.plan.AggregationNode.Step;
import com.facebook.presto.sql.tree.Input;
import com.facebook.presto.tuple.TupleInfo;
import com.google.common.base.Function;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import com.google.common.collect.Iterators;
import com.google.common.util.concurrent.ListenableFuture;
import io.airlift.slice.SizeOf;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
import it.unimi.dsi.fastutil.longs.Long2IntOpenCustomHashMap;
import it.unimi.dsi.fastutil.longs.LongHash.Strategy;
import it.unimi.dsi.fastutil.objects.ObjectArrayList;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

import static com.facebook.presto.operator.SyntheticAddress.decodeSliceIndex;
import static com.facebook.presto.operator.SyntheticAddress.decodeSliceOffset;
import static com.facebook.presto.operator.SyntheticAddress.encodeSyntheticAddress;
import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkState;

public class HashAggregationOperator
        implements Operator
{
    public static class HashAggregationOperatorFactory
            implements OperatorFactory
    {
        private final int operatorId;
        private final TupleInfo groupByTupleInfo;
        private final int groupByChannel;
        private final Step step;
        private final List<AggregationFunctionDefinition> functionDefinitions;
        private final int expectedGroups;
        private final List<TupleInfo> tupleInfos;
        private boolean closed;

        public HashAggregationOperatorFactory(
                int operatorId,
                TupleInfo groupByTupleInfo,
                int groupByChannel,
                Step step,
                List<AggregationFunctionDefinition> functionDefinitions,
                int expectedGroups)
        {
            this.operatorId = operatorId;
            this.groupByTupleInfo = groupByTupleInfo;
            this.groupByChannel = groupByChannel;
            this.step = step;
            this.functionDefinitions = functionDefinitions;
            this.expectedGroups = expectedGroups;

            this.tupleInfos = toTupleInfos(groupByTupleInfo, step, functionDefinitions);
        }

        @Override
        public List<TupleInfo> getTupleInfos()
        {
            return tupleInfos;
        }

        @Override
        public Operator createOperator(DriverContext driverContext)
        {
            checkState(!closed, "Factory is already closed");

            OperatorContext operatorContext = driverContext.addOperatorContext(operatorId, HashAggregationOperator.class.getSimpleName());
            return new HashAggregationOperator(
                    operatorContext,
                    groupByTupleInfo,
                    groupByChannel,
                    step,
                    functionDefinitions,
                    expectedGroups
            );
        }

        @Override
        public void close()
        {
            closed = true;
        }
    }

    private static final int LOOKUP_SLICE_INDEX = 0xFF_FF_FF_FF;

    private final OperatorContext operatorContext;
    private final TupleInfo groupByTupleInfo;
    private final int groupByChannel;
    private final Step step;
    private final List<AggregationFunctionDefinition> functionDefinitions;
    private final int expectedGroups;

    private final List<TupleInfo> tupleInfos;
    private final HashMemoryManager memoryManager;

    private GroupByHashAggregationBuilder aggregationBuilder;
    private Iterator<Page> outputIterator;
    private boolean finishing;

    public HashAggregationOperator(
            OperatorContext operatorContext,
            TupleInfo groupByTupleInfo,
            int groupByChannel,
            Step step,
            List<AggregationFunctionDefinition> functionDefinitions,
            int expectedGroups)
    {
        this.operatorContext = checkNotNull(operatorContext, "operatorContext is null");
        Preconditions.checkArgument(groupByChannel >= 0, "groupByChannel is negative");
        Preconditions.checkNotNull(step, "step is null");
        Preconditions.checkNotNull(functionDefinitions, "functionDefinitions is null");
        Preconditions.checkNotNull(operatorContext, "operatorContext is null");

        this.groupByTupleInfo = groupByTupleInfo;
        this.groupByChannel = groupByChannel;
        this.functionDefinitions = ImmutableList.copyOf(functionDefinitions);
        this.step = step;
        this.expectedGroups = expectedGroups;
        this.memoryManager = new HashMemoryManager(operatorContext);

        this.tupleInfos = toTupleInfos(groupByTupleInfo, step, functionDefinitions);
    }

    @Override
    public OperatorContext getOperatorContext()
    {
        return operatorContext;
    }

    @Override
    public List<TupleInfo> getTupleInfos()
    {
        return tupleInfos;
    }

    @Override
    public void finish()
    {
        finishing = true;
    }

    @Override
    public boolean isFinished()
    {
        return finishing && aggregationBuilder == null && (outputIterator == null || !outputIterator.hasNext());
    }

    @Override
    public ListenableFuture<?> isBlocked()
    {
        return NOT_BLOCKED;
    }

    @Override
    public boolean needsInput()
    {
        return !finishing && outputIterator == null && (aggregationBuilder == null || !aggregationBuilder.isFull());
    }

    @Override
    public void addInput(Page page)
    {
        checkState(!finishing, "Operator is already finishing");
        checkNotNull(page, "page is null");
        if (aggregationBuilder == null) {
            aggregationBuilder = new GroupByHashAggregationBuilder(
                    functionDefinitions,
                    step,
                    expectedGroups,
                    groupByChannel,
                    groupByTupleInfo,
                    memoryManager);

            // assume initial aggregationBuilder is not full
        }
        else {
            checkState(!aggregationBuilder.isFull(), "Aggregation buffer is full");
        }
        aggregationBuilder.processPage(page);
    }

    @Override
    public Page getOutput()
    {
        if (outputIterator == null || !outputIterator.hasNext()) {
            // no data
            if (aggregationBuilder == null) {
                return null;
            }

            // only flush if we are finishing or the aggregation builder is full
            if (!finishing && !aggregationBuilder.isFull()) {
                return null;
            }

            // Only partial aggregation can flush early. Also, check that we are not flushing tiny bits at a time
            checkState(finishing || step == Step.PARTIAL, "Task exceeded max memory size of %s", memoryManager.getMaxMemorySize());

            outputIterator = aggregationBuilder.build();
            aggregationBuilder = null;

            if (!outputIterator.hasNext()) {
                return null;
            }
        }

        return outputIterator.next();
    }

    private static List<TupleInfo> toTupleInfos(TupleInfo groupByTupleInfo, Step step, List<AggregationFunctionDefinition> functionDefinitions)
    {
        ImmutableList.Builder<TupleInfo> tupleInfos = ImmutableList.builder();
        tupleInfos.add(groupByTupleInfo);
        for (AggregationFunctionDefinition functionDefinition : functionDefinitions) {
            if (step != Step.PARTIAL) {
                tupleInfos.add(functionDefinition.getFunction().getFinalTupleInfo());
            }
            else {
                tupleInfos.add(functionDefinition.getFunction().getIntermediateTupleInfo());
            }
        }
        return tupleInfos.build();
    }

    private static class GroupByHashAggregationBuilder
    {
        private final List<Aggregator> aggregates;
        private final SliceHashStrategy hashStrategy;
        private final Long2IntOpenCustomHashMap addressToGroupId;
        private final List<UncompressedBlock> groupByBlocks = new ArrayList<>();
        private final int groupByChannel;
        private final TupleInfo groupByTupleInfo;
        private final HashMemoryManager memoryManager;

        private BlockBuilder blockBuilder;
        private int nextGroupId;

        private GroupByHashAggregationBuilder(
                List<AggregationFunctionDefinition> functionDefinitions,
                Step step,
                int expectedGroups,
                int groupByChannel,
                TupleInfo groupByTupleInfo,
                HashMemoryManager memoryManager)
        {
            this.groupByChannel = groupByChannel;
            this.groupByTupleInfo = groupByTupleInfo;
            this.memoryManager = memoryManager;

            // wrapper each function with an aggregator
            ImmutableList.Builder<Aggregator> builder = ImmutableList.builder();
            for (AggregationFunctionDefinition functionDefinition : checkNotNull(functionDefinitions, "functionDefinitions is null")) {
                builder.add(createAggregator(functionDefinition, step, expectedGroups));
            }
            aggregates = builder.build();

            // create hash table
            hashStrategy = new SliceHashStrategy(groupByTupleInfo);
            addressToGroupId = new Long2IntOpenCustomHashMap(expectedGroups, hashStrategy);

            // initialize hash table
            addressToGroupId.defaultReturnValue(-1);
            Slice slice = Slices.allocate((int) BlockBuilder.DEFAULT_MAX_BLOCK_SIZE.toBytes());
            hashStrategy.addSlice(slice);

            // Group by keys are packed into new blocks
            blockBuilder = new BlockBuilder(groupByTupleInfo, slice.length(), slice.getOutput());
        }

        private void processPage(Page page)
        {
            // open cursors
            Block[] blocks = page.getBlocks();
            BlockCursor[] cursors = new BlockCursor[blocks.length];
            for (int i = 0; i < blocks.length; i++) {
                cursors[i] = blocks[i].cursor();
            }

            Slice groupBySlice = ((UncompressedBlock) blocks[groupByChannel]).getSlice();
            hashStrategy.setLookupSlice(groupBySlice);

            // process row at a time
            int rows = page.getPositionCount();
            for (int position = 0; position < rows; position++) {
                for (BlockCursor cursor : cursors) {
                    checkState(cursor.advanceNextPosition());
                }

                int groupId = putIfAbsent(groupBySlice, cursors);

                // process the row
                processRow(cursors, groupId);
            }

            // verify all cursors are complete
            for (BlockCursor cursor : cursors) {
                checkState(!cursor.advanceNextPosition());
            }
        }

        private int putIfAbsent(Slice groupBySlice, BlockCursor[] cursors)
        {
            // lookup the group id (row number of the key)
            int rawOffset = cursors[groupByChannel].getRawOffset();
            int groupId = addressToGroupId.get(encodeSyntheticAddress(LOOKUP_SLICE_INDEX, rawOffset));
            if (groupId < 0) {
                groupId = addNewGroup(groupBySlice, rawOffset);
            }
            return groupId;
        }

        private int addNewGroup(Slice groupBySlice, int rawOffset)
        {
            // copy group by tuple (key) to hash
            int length = groupByTupleInfo.size(groupBySlice, rawOffset);
            if (blockBuilder.writableBytes() < length) {
                UncompressedBlock block = blockBuilder.build();
                groupByBlocks.add(block);
                Slice slice = Slices.allocate(Math.max((int) BlockBuilder.DEFAULT_MAX_BLOCK_SIZE.toBytes(), length));
                blockBuilder = new BlockBuilder(groupByTupleInfo, slice.length(), slice.getOutput());
                hashStrategy.addSlice(slice);
            }
            int groupByValueRawOffset = blockBuilder.size();
            blockBuilder.appendTuple(groupBySlice, rawOffset, length);

            // record group id in hash
            int groupId = nextGroupId++;
            addressToGroupId.put(encodeSyntheticAddress(groupByBlocks.size(), groupByValueRawOffset), groupId);

            // initialize the aggregates
            initializeRow(groupId);

            return groupId;
        }

        private void initializeRow(int groupId)
        {
            for (Aggregator aggregate : aggregates) {
                aggregate.initialize(groupId);
            }
        }

        private void processRow(BlockCursor[] cursors, int groupId)
        {
            for (Aggregator aggregate : aggregates) {
                aggregate.addValue(cursors, groupId);
            }
        }

        public boolean isFull()
        {
            long memorySize = hashStrategy.getEstimatedSize();
            for (Aggregator aggregate : aggregates) {
                memorySize += aggregate.getEstimatedSize();
            }
            return memoryManager.canUse(memorySize);
        }

        public Iterator<Page> build()
        {
            // add the last block if it is not empty
            if (!blockBuilder.isEmpty()) {
                UncompressedBlock block = blockBuilder.build();
                groupByBlocks.add(block);
            }

            return Iterators.transform(groupByBlocks.iterator(), new Function<UncompressedBlock, Page>()
            {
                private int currentPosition = 0;

                @Override
                public Page apply(UncompressedBlock groupByBlock)
                {
                    // build  the page channel at at time
                    Block[] blocks = new Block[aggregates.size() + 1];
                    blocks[0] = groupByBlock;
                    int pagePositionCount = groupByBlock.getPositionCount();
                    for (int channel = 1; channel < aggregates.size() + 1; channel++) {
                        Aggregator aggregator = aggregates.get(channel - 1);
                        // todo there is no need to eval for intermediates since buffer is already in block form
                        BlockBuilder blockBuilder = new BlockBuilder(aggregator.getTupleInfo());
                        for (int position = 0; position < pagePositionCount; position++) {
                            aggregator.evaluate(currentPosition + position, blockBuilder);
                        }
                        blocks[channel] = blockBuilder.build();
                    }

                    Page page = new Page(blocks);
                    currentPosition += pagePositionCount;
                    return page;
                }
            });
        }
    }

    public static class HashMemoryManager
    {
        private final OperatorContext operatorContext;
        private long currentMemoryReservation;

        public HashMemoryManager(OperatorContext operatorContext)
        {
            this.operatorContext = operatorContext;
        }

        public boolean canUse(long memorySize)
        {
            // remove the pre-allocated memory from this size
            memorySize -= operatorContext.getOperatorPreAllocatedMemory().toBytes();

            long delta = memorySize - currentMemoryReservation;
            if (delta <= 0) {
                return false;
            }

            if (!operatorContext.reserveMemory(delta)) {
                return true;
            }

            // reservation worked, record the reservation
            currentMemoryReservation = Math.max(currentMemoryReservation, memorySize);
            return false;
        }

        public Object getMaxMemorySize()
        {
            return operatorContext.getMaxMemorySize();
        }
    }

    @SuppressWarnings("rawtypes")
    private static Aggregator createAggregator(AggregationFunctionDefinition functionDefinition, Step step, int expectedGroups)
    {
        AggregationFunction function = functionDefinition.getFunction();
        if (function instanceof VariableWidthAggregationFunction) {
            return new VariableWidthAggregator((VariableWidthAggregationFunction) functionDefinition.getFunction(), functionDefinition.getInputs(), step, expectedGroups);
        }
        else {
            Input input = null;
            if (!functionDefinition.getInputs().isEmpty()) {
                input = Iterables.getOnlyElement(functionDefinition.getInputs());
            }

            return new FixedWidthAggregator((FixedWidthAggregationFunction) functionDefinition.getFunction(), input, step);
        }
    }

    private interface Aggregator
    {
        long getEstimatedSize();

        TupleInfo getTupleInfo();

        void initialize(int position);

        void addValue(BlockCursor[] cursors, int position);

        void evaluate(int position, BlockBuilder output);
    }

    private static class FixedWidthAggregator
            implements Aggregator
    {
        private final FixedWidthAggregationFunction function;
        private final Input input;
        private final Step step;
        private final int fixedWidthSize;
        private final int sliceSize;
        private final List<Slice> slices = new ArrayList<>();
        private int currentMaxPosition;

        private FixedWidthAggregator(FixedWidthAggregationFunction function, Input input, Step step)
        {
            this.function = function;
            this.input = input;
            this.step = step;
            this.fixedWidthSize = this.function.getFixedSize();
            this.sliceSize = (int) (BlockBuilder.DEFAULT_MAX_BLOCK_SIZE.toBytes() / fixedWidthSize) * fixedWidthSize;
            Slice slice = Slices.allocate(sliceSize);
            slices.add(slice);
            currentMaxPosition = sliceSize / fixedWidthSize;
        }

        @Override
        public long getEstimatedSize()
        {
            return slices.size() * sliceSize;
        }

        @Override
        public TupleInfo getTupleInfo()
        {
            // if this is a partial, the output is an intermediate value
            if (step == Step.PARTIAL) {
                return function.getIntermediateTupleInfo();
            }
            else {
                return function.getFinalTupleInfo();
            }
        }

        @Override
        public void initialize(int position)
        {
            // add more slices if necessary
            while (position >= currentMaxPosition) {
                Slice slice = Slices.allocate(sliceSize);
                slices.add(slice);
                currentMaxPosition += sliceSize / fixedWidthSize;
            }

            int globalOffset = position * fixedWidthSize;

            int sliceIndex = globalOffset / sliceSize; // todo do this with shifts?
            Slice slice = slices.get(sliceIndex);
            int sliceOffset = globalOffset - (sliceIndex * sliceSize);
            function.initialize(slice, sliceOffset);
        }

        @Override
        public void addValue(BlockCursor[] cursors, int position)
        {
            BlockCursor cursor;
            int field = -1;
            if (input != null) {
                cursor = cursors[input.getChannel()];
                field = input.getField();
            }
            else {
                cursor = null;
            }

            int globalOffset = position * fixedWidthSize;

            int sliceIndex = globalOffset / sliceSize; // todo do this with shifts?
            Slice slice = slices.get(sliceIndex);
            int sliceOffset = globalOffset - (sliceIndex * sliceSize);

            // if this is a final aggregation, the input is an intermediate value
            if (step == Step.FINAL) {
                function.addIntermediate(cursor, field, slice, sliceOffset);
            }
            else {
                function.addInput(cursor, field, slice, sliceOffset);
            }
        }

        @Override
        public void evaluate(int position, BlockBuilder output)
        {
            int offset = position * fixedWidthSize;

            int sliceIndex = offset / sliceSize; // todo do this with shifts
            Slice slice = slices.get(sliceIndex);
            int sliceOffset = offset - (sliceIndex * sliceSize);

            // if this is a partial, the output is an intermediate value
            if (step == Step.PARTIAL) {
                function.evaluateIntermediate(slice, sliceOffset, output);
            }
            else {
                function.evaluateFinal(slice, sliceOffset, output);
            }
        }
    }

    private static class VariableWidthAggregator<T>
            implements Aggregator
    {
        private final VariableWidthAggregationFunction<T> function;
        private final List<Input> inputs;
        private final Step step;
        private final ObjectArrayList<T> intermediateValues;
        private long totalElementSizeInBytes;

        private final BlockCursor[] blockCursors;
        private final int[] fields;

        private VariableWidthAggregator(VariableWidthAggregationFunction<T> function, List<Input> inputs, Step step, int expectedGroups)
        {
            this.function = function;
            this.inputs = inputs;
            this.step = step;
            this.intermediateValues = new ObjectArrayList<>(expectedGroups);

            this.blockCursors = new BlockCursor[inputs.size()];
            this.fields = new int[inputs.size()];

            for (int i = 0; i < fields.length; i++) {
                fields[i] = inputs.get(i).getField();
            }
        }

        @Override
        public long getEstimatedSize()
        {
            return SizeOf.sizeOf(intermediateValues.elements()) + totalElementSizeInBytes;
        }

        @Override
        public TupleInfo getTupleInfo()
        {
            // if this is a partial, the output is an intermediate value
            if (step == Step.PARTIAL) {
                return function.getIntermediateTupleInfo();
            }
            else {
                return function.getFinalTupleInfo();
            }
        }

        @Override
        public void initialize(int position)
        {
            Preconditions.checkState(position == intermediateValues.size(), "expected array to grow by 1");
            intermediateValues.add(function.initialize());
        }

        @Override
        public void addValue(BlockCursor[] cursors, int position)
        {
            for (int i = 0; i < blockCursors.length; i++) {
                blockCursors[i] = cursors[inputs.get(i).getChannel()];
            }

            // if this is a final aggregation, the input is an intermediate value
            T oldValue = intermediateValues.get(position);
            long oldSize = 0;
            if (oldValue != null) {
                oldSize = function.estimateSizeInBytes(oldValue);
            }

            T newValue;
            if (step == Step.FINAL) {
                newValue = function.addIntermediate(blockCursors, fields, oldValue);
            }
            else {
                newValue = function.addInput(blockCursors, fields, oldValue);
            }
            intermediateValues.set(position, newValue);

            long newSize = 0;
            if (newValue != null) {
                newSize = function.estimateSizeInBytes(newValue);
            }
            totalElementSizeInBytes += newSize - oldSize;
        }

        @Override
        public void evaluate(int position, BlockBuilder output)
        {
            T value = intermediateValues.get(position);
            // if this is a partial, the output is an intermediate value
            if (step == Step.PARTIAL) {
                function.evaluateIntermediate(value, output);
            }
            else {
                function.evaluateFinal(value, output);
            }
        }
    }

    public static class SliceHashStrategy
            implements Strategy
    {
        private final TupleInfo tupleInfo;
        private final List<Slice> slices;
        private Slice lookupSlice;
        private long memorySize;

        public SliceHashStrategy(TupleInfo tupleInfo)
        {
            this.tupleInfo = tupleInfo;
            this.slices = ObjectArrayList.wrap(new Slice[1024], 0);
        }

        public long getEstimatedSize()
        {
            return memorySize;
        }

        public void setLookupSlice(Slice lookupSlice)
        {
            this.lookupSlice = lookupSlice;
        }

        public void addSlice(Slice slice)
        {
            memorySize += slice.length();
            slices.add(slice);
        }

        @Override
        public int hashCode(long sliceAddress)
        {
            Slice slice = getSliceForSyntheticAddress(sliceAddress);
            int offset = (int) sliceAddress;
            int length = tupleInfo.size(slice, offset);
            int hashCode = slice.hashCode(offset, length);
            return hashCode;
        }

        @Override
        public boolean equals(long leftSliceAddress, long rightSliceAddress)
        {
            Slice leftSlice = getSliceForSyntheticAddress(leftSliceAddress);
            int leftOffset = decodeSliceOffset(leftSliceAddress);
            int leftLength = tupleInfo.size(leftSlice, leftOffset);

            Slice rightSlice = getSliceForSyntheticAddress(rightSliceAddress);
            int rightOffset = decodeSliceOffset(rightSliceAddress);
            int rightLength = tupleInfo.size(rightSlice, rightOffset);

            return leftSlice.equals(leftOffset, leftLength, rightSlice, rightOffset, rightLength);
        }

        private Slice getSliceForSyntheticAddress(long sliceAddress)
        {
            int sliceIndex = decodeSliceIndex(sliceAddress);
            Slice slice;
            if (sliceIndex == LOOKUP_SLICE_INDEX) {
                slice = lookupSlice;
            }
            else {
                slice = slices.get(sliceIndex);
            }
            return slice;
        }
    }
}
TOP

Related Classes of com.facebook.presto.operator.HashAggregationOperator$VariableWidthAggregator

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.