Package com.facebook.presto.operator.aggregation

Source Code of com.facebook.presto.operator.aggregation.AbstractTestApproximateCountDistinct

/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.operator.aggregation;

import com.facebook.presto.block.Block;
import com.facebook.presto.block.BlockAssertions;
import com.facebook.presto.block.BlockBuilder;
import com.facebook.presto.block.BlockCursor;
import com.facebook.presto.operator.AggregationOperator.Aggregator;
import com.facebook.presto.operator.Page;
import com.facebook.presto.sql.planner.plan.AggregationNode;
import com.facebook.presto.sql.tree.Input;
import com.facebook.presto.tuple.TupleInfo;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import io.airlift.slice.Slice;
import org.apache.commons.math.stat.descriptive.DescriptiveStatistics;
import org.testng.annotations.Test;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.ThreadLocalRandom;

import static com.facebook.presto.operator.AggregationFunctionDefinition.aggregation;
import static com.facebook.presto.operator.AggregationOperator.createAggregator;
import static io.airlift.testing.Assertions.assertLessThan;
import static org.testng.Assert.assertEquals;

public abstract class AbstractTestApproximateCountDistinct
{
    public abstract ApproximateCountDistinctAggregation getAggregationFunction();

    public abstract TupleInfo.Type getValueType();

    public abstract Object randomValue();

    @Test
    public void testNoPositions()
            throws Exception
    {
        assertCount(ImmutableList.<Object>of(), 0);
    }

    @Test
    public void testSinglePosition()
            throws Exception
    {
        assertCount(ImmutableList.<Object>of(randomValue()), 1);
    }

    @Test
    public void testAllPositionsNull()
            throws Exception
    {
        assertCount(Collections.<Object>nCopies(100, null), 0);
    }

    @Test
    public void testMixedNullsAndNonNulls()
            throws Exception
    {
        List<Object> baseline = createRandomSample(10000, 15000);

        List<Object> mixed = new ArrayList<>(baseline);
        mixed.addAll(Collections.<Long>nCopies(baseline.size(), null));
        Collections.shuffle(mixed);

        assertCount(mixed, estimateCount(baseline, 0));
    }

    @Test
    public void testMultiplePositions()
            throws Exception
    {
        DescriptiveStatistics stats = new DescriptiveStatistics();

        for (int i = 0; i < 500; ++i) {
            int uniques = ThreadLocalRandom.current().nextInt(20000) + 1;

            List<Object> values = createRandomSample(uniques, (int) (uniques * 1.5));

            long actual = estimateCount(values, 0);
            double error = (actual - uniques) * 1.0 / uniques;

            stats.addValue(error);
        }

        assertLessThan(stats.getMean(), 1e-2);
        assertLessThan(Math.abs(stats.getStandardDeviation() - getAggregationFunction().getStandardError()), 1e-2);
    }

    @Test
    public void testMultiplePositionsPartial()
            throws Exception
    {
        for (int i = 0; i < 100; ++i) {
            int uniques = ThreadLocalRandom.current().nextInt(20000) + 1;
            List<Object> values = createRandomSample(uniques, (int) (uniques * 1.5));
            assertEquals(estimateCountPartial(values, 0), estimateCount(values, 0));
        }
    }

    private void assertCount(List<Object> values, long expectedCount)
    {
        assertCount(values, expectedCount, 0);
        assertCount(values, expectedCount, 1);
    }

    private void assertCount(List<Object> values, long expectedCount, int field)
    {
        assertEquals(estimateCount(values, field), expectedCount);
        assertEquals(estimateCountVectorized(values, field), expectedCount);
        assertEquals(estimateCountPartial(values, field), expectedCount);
    }

    private long estimateCount(List<Object> values, int field)
    {
        Aggregator aggregator = createAggregator(aggregation(getAggregationFunction(), new Input(0, field)), AggregationNode.Step.SINGLE);

        if (!values.isEmpty()) {
            BlockCursor cursor = createBlock(values, field + 1).cursor();
            while (cursor.advanceNextPosition()) {
                aggregator.addValue(cursor);
            }
        }

        return (long) BlockAssertions.toValues(aggregator.getResult()).get(0).get(0);
    }

    private long estimateCountVectorized(List<Object> values, int field)
    {
        Aggregator aggregator = createAggregator(aggregation(getAggregationFunction(), new Input(0, field)), AggregationNode.Step.SINGLE);

        if (!values.isEmpty()) {
            aggregator.addValue(new Page(createBlock(values, field + 1)));
        }

        return (long) BlockAssertions.toValues(aggregator.getResult()).get(0).get(0);
    }

    private long estimateCountPartial(List<Object> values, int field)
    {
        int size = Math.min(values.size(), Math.max(values.size() / 2, 1));

        Block first = aggregatePartial(values.subList(0, size), field);
        Block second = aggregatePartial(values.subList(size, values.size()), field);

        Aggregator aggregator = createAggregator(aggregation(getAggregationFunction(), new Input(0, field)), AggregationNode.Step.FINAL);

        BlockCursor cursor = first.cursor();
        while (cursor.advanceNextPosition()) {
            aggregator.addValue(cursor);
        }

        cursor = second.cursor();
        while (cursor.advanceNextPosition()) {
            aggregator.addValue(cursor);
        }

        return (long) BlockAssertions.toValues(aggregator.getResult()).get(0).get(0);
    }

    private Block aggregatePartial(List<Object> values, int field)
    {
        Aggregator aggregator = createAggregator(aggregation(getAggregationFunction(), new Input(0, field)), AggregationNode.Step.PARTIAL);

        if (!values.isEmpty()) {
            BlockCursor cursor = createBlock(values, field + 1).cursor();
            while (cursor.advanceNextPosition()) {
                aggregator.addValue(cursor);
            }
        }

        return aggregator.getResult();
    }

    /**
     * Produce a block with the given number of fields and the values in the last field. All other fields are set to null.
     */
    private Block createBlock(List<Object> values, int numberOfFields)
    {
        TupleInfo.Type[] types = new TupleInfo.Type[numberOfFields];
        Arrays.fill(types, getValueType());
        types[numberOfFields - 1] = getValueType();

        BlockBuilder blockBuilder = new BlockBuilder(new TupleInfo(types));

        for (Object value : values) {
            for (int i = 0; i < numberOfFields - 1; i++) {
                blockBuilder.appendNull();
            }

            if (value == null) {
                blockBuilder.appendNull();
            }
            else {
                switch (getValueType()) {
                    case FIXED_INT_64:
                        blockBuilder.append((Long) value);
                        break;
                    case VARIABLE_BINARY:
                        blockBuilder.append((Slice) value);
                        break;
                    case DOUBLE:
                        blockBuilder.append((Double) value);
                        break;
                    default:
                        throw new UnsupportedOperationException("not yet implemented");
                }
            }
        }

        return blockBuilder.build();
    }

    private List<Object> createRandomSample(int uniques, int total)
    {
        Preconditions.checkArgument(uniques <= total, "uniques (%s) must be <= total (%s)", uniques, total);

        List<Object> result = new ArrayList<>(total);
        result.addAll(makeRandomSet(uniques));

        Random random = ThreadLocalRandom.current();
        while (result.size() < total) {
            int index = random.nextInt(result.size());
            result.add(result.get(index));
        }

        return result;
    }

    private Set<Object> makeRandomSet(int count)
    {
        Set<Object> result = new HashSet<>();
        while (result.size() < count) {
            result.add(randomValue());
        }

        return result;
    }
}
TOP

Related Classes of com.facebook.presto.operator.aggregation.AbstractTestApproximateCountDistinct

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.