Package com.facebook.presto.operator.aggregation

Source Code of com.facebook.presto.operator.aggregation.CountAggregation$CountAccumulatorFactory

/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.operator.aggregation;

import com.facebook.presto.operator.GroupByIdBlock;
import com.facebook.presto.operator.Page;
import com.facebook.presto.spi.block.Block;
import com.facebook.presto.spi.block.BlockBuilder;
import com.facebook.presto.spi.block.BlockBuilderStatus;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.util.array.LongBigArray;
import com.google.common.base.Optional;
import com.google.common.collect.ImmutableList;

import java.util.List;

import static com.facebook.presto.spi.type.BigintType.BIGINT;
import static com.facebook.presto.spi.type.BooleanType.BOOLEAN;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;

public class CountAggregation
        implements InternalAggregationFunction
{
    public static final CountAggregation COUNT = new CountAggregation();

    @Override
    public String name()
    {
        return "count";
    }

    @Override
    public List<Type> getParameterTypes()
    {
        return ImmutableList.of();
    }

    @Override
    public Type getFinalType()
    {
        return BIGINT;
    }

    @Override
    public Type getIntermediateType()
    {
        return BIGINT;
    }

    @Override
    public boolean isDecomposable()
    {
        return true;
    }

    @Override
    public boolean isApproximate()
    {
        return false;
    }

    @Override
    public AccumulatorFactory bind(List<Integer> inputChannels, Optional<Integer> maskChannel, Optional<Integer> sampleWeightChannel, double confidence)
    {
        checkArgument(confidence == 1.0 && !sampleWeightChannel.isPresent(), "approximate not supported for count(*)");
        return new CountAccumulatorFactory(inputChannels, maskChannel);
    }

    public static class CountAccumulatorFactory
            implements AccumulatorFactory
    {
        private final List<Integer> inputChannels;
        private final Optional<Integer> maskChannel;

        public CountAccumulatorFactory(List<Integer> inputChannels, Optional<Integer> maskChannel)
        {
            this.inputChannels = ImmutableList.copyOf(checkNotNull(inputChannels, "inputChannels is null"));
            this.maskChannel = checkNotNull(maskChannel, "maskChannel is null");
        }

        @Override
        public CountGroupedAccumulator createGroupedAccumulator()
        {
            return new CountGroupedAccumulator(maskChannel);
        }

        @Override
        public GroupedAccumulator createGroupedIntermediateAccumulator()
        {
            return new CountGroupedAccumulator(Optional.<Integer>absent());
        }

        public static class CountGroupedAccumulator
                implements GroupedAccumulator
        {
            private final LongBigArray counts;
            private final Optional<Integer> maskChannel;

            public CountGroupedAccumulator(Optional<Integer> maskChannel)
            {
                this.counts = new LongBigArray();
                this.maskChannel = maskChannel;
            }

            @Override
            public long getEstimatedSize()
            {
                return counts.sizeOf();
            }

            @Override
            public Type getFinalType()
            {
                return BIGINT;
            }

            @Override
            public Type getIntermediateType()
            {
                return BIGINT;
            }

            @Override
            public void addInput(GroupByIdBlock groupIdsBlock, Page page)
            {
                counts.ensureCapacity(groupIdsBlock.getGroupCount());
                Block masks = maskChannel.isPresent() ? page.getBlock(maskChannel.get()) : null;

                for (int position = 0; position < groupIdsBlock.getPositionCount(); position++) {
                    long groupId = groupIdsBlock.getGroupId(position);
                    if (masks == null || BOOLEAN.getBoolean(masks, position)) {
                        counts.increment(groupId);
                    }
                }
            }

            @Override
            public void addIntermediate(GroupByIdBlock groupIdsBlock, Block intermediates)
            {
                counts.ensureCapacity(groupIdsBlock.getGroupCount());

                for (int position = 0; position < groupIdsBlock.getPositionCount(); position++) {
                    long groupId = groupIdsBlock.getGroupId(position);
                    counts.add(groupId, BIGINT.getLong(intermediates, position));
                }
            }

            @Override
            public void evaluateIntermediate(int groupId, BlockBuilder output)
            {
                evaluateFinal(groupId, output);
            }

            @Override
            public void evaluateFinal(int groupId, BlockBuilder output)
            {
                long value = counts.get((long) groupId);
                BIGINT.writeLong(output, value);
            }
        }

        @Override
        public List<Integer> getInputChannels()
        {
            return inputChannels;
        }

        @Override
        public CountAccumulator createAccumulator()
        {
            return new CountAccumulator(maskChannel);
        }

        @Override
        public CountAccumulator createIntermediateAccumulator()
        {
            return new CountAccumulator(Optional.<Integer>absent());
        }

        public static class CountAccumulator
                implements Accumulator
        {
            private long count;
            private final Optional<Integer> maskChannel;

            public CountAccumulator(Optional<Integer> maskChannel)
            {
                this.maskChannel = maskChannel;
            }

            @Override
            public long getEstimatedSize()
            {
                return 0;
            }

            @Override
            public Type getFinalType()
            {
                return BIGINT;
            }

            @Override
            public Type getIntermediateType()
            {
                return BIGINT;
            }

            @Override
            public void addInput(Page page)
            {
                if (!maskChannel.isPresent()) {
                    count += page.getPositionCount();
                }
                else {
                    Block masks = page.getBlock(maskChannel.get());
                    for (int position = 0; position < page.getPositionCount(); position++) {
                        if (masks == null || BOOLEAN.getBoolean(masks, position)) {
                            count++;
                        }
                    }
                }
            }

            @Override
            public void addIntermediate(Block intermediates)
            {
                for (int position = 0; position < intermediates.getPositionCount(); position++) {
                    count += BIGINT.getLong(intermediates, position);
                }
            }

            @Override
            public final Block evaluateIntermediate()
            {
                return evaluateFinal();
            }

            @Override
            public final Block evaluateFinal()
            {
                BlockBuilder out = getFinalType().createBlockBuilder(new BlockBuilderStatus());

                BIGINT.writeLong(out, count);

                return out.build();
            }
        }
    }
}
TOP

Related Classes of com.facebook.presto.operator.aggregation.CountAggregation$CountAccumulatorFactory

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.