Package com.facebook.presto.sql.planner

Source Code of com.facebook.presto.sql.planner.TestEqualityInference

/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner;

import com.facebook.presto.sql.ExpressionUtils;
import com.facebook.presto.sql.tree.ArithmeticExpression;
import com.facebook.presto.sql.tree.ComparisonExpression;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.LongLiteral;
import com.facebook.presto.sql.tree.QualifiedNameReference;
import com.facebook.presto.util.IterableTransformer;
import com.google.common.base.Function;
import com.google.common.base.Preconditions;
import com.google.common.base.Predicate;
import com.google.common.base.Predicates;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import org.testng.Assert;
import org.testng.annotations.Test;

import java.util.Arrays;
import java.util.Set;

import static com.facebook.presto.sql.tree.ComparisonExpression.Type.EQUAL;
import static com.facebook.presto.sql.tree.ComparisonExpression.Type.GREATER_THAN;
import static com.google.common.base.Predicates.not;

public class TestEqualityInference
{
    @Test
    public void testTransitivity()
            throws Exception
    {
        EqualityInference.Builder builder = new EqualityInference.Builder();
        addEquality("a1", "b1", builder);
        addEquality("b1", "c1", builder);
        addEquality("d1", "c1", builder);

        addEquality("a2", "b2", builder);
        addEquality("b2", "a2", builder);
        addEquality("b2", "c2", builder);
        addEquality("d2", "b2", builder);
        addEquality("c2", "d2", builder);

        EqualityInference inference = builder.build();

        Assert.assertEquals(
                inference.rewriteExpression(someExpression("a1", "a2"), matchesSymbols("d1", "d2")),
                someExpression("d1", "d2"));

        Assert.assertEquals(
                inference.rewriteExpression(someExpression("a1", "c1"), matchesSymbols("b1")),
                someExpression("b1", "b1"));

        Assert.assertEquals(
                inference.rewriteExpression(someExpression("a1", "a2"), matchesSymbols("b1", "d2", "c3")),
                someExpression("b1", "d2"));

        // Both starting expressions should canonicalize to the same expression
        Assert.assertEquals(
                inference.getScopedCanonical(nameReference("a2"), matchesSymbols("c2", "d2")),
                inference.getScopedCanonical(nameReference("b2"), matchesSymbols("c2", "d2")));
        Expression canonical = inference.getScopedCanonical(nameReference("a2"), matchesSymbols("c2", "d2"));

        // Given multiple translatable candidates, should choose the canonical
        Assert.assertEquals(
                inference.rewriteExpression(someExpression("a2", "b2"), matchesSymbols("c2", "d2")),
                someExpression(canonical, canonical));
    }

    @Test
    public void testTriviallyRewritable()
            throws Exception
    {
        EqualityInference.Builder builder = new EqualityInference.Builder();
        Expression expression = builder.build()
                .rewriteExpression(someExpression("a1", "a2"), matchesSymbols("a1", "a2"));

        Assert.assertEquals(expression, someExpression("a1", "a2"));
    }

    @Test
    public void testUnrewritable()
            throws Exception
    {
        EqualityInference.Builder builder = new EqualityInference.Builder();
        addEquality("a1", "b1", builder);
        addEquality("a2", "b2", builder);
        EqualityInference inference = builder.build();

        Assert.assertNull(inference.rewriteExpression(someExpression("a1", "a2"), matchesSymbols("b1", "c1")));
        Assert.assertNull(inference.rewriteExpression(someExpression("c1", "c2"), matchesSymbols("a1", "a2")));
    }

    @Test
    public void testParseEqualityExpression()
            throws Exception
    {
        EqualityInference inference = new EqualityInference.Builder()
                .addEquality(equals("a1", "b1"))
                .addEquality(equals("a1", "c1"))
                .addEquality(equals("c1", "a1"))
                .build();

        Expression expression = inference.rewriteExpression(someExpression("a1", "b1"), matchesSymbols("c1"));
        Assert.assertEquals(expression, someExpression("c1", "c1"));
    }

    @Test(expectedExceptions = IllegalArgumentException.class)
    public void testInvalidEqualityExpression1()
            throws Exception
    {
        new EqualityInference.Builder()
                .addEquality(equals("a1", "a1"));
    }

    @Test(expectedExceptions = IllegalArgumentException.class)
    public void testInvalidEqualityExpression2()
            throws Exception
    {
        new EqualityInference.Builder()
                .addEquality(someExpression("a1", "b1"));
    }

    @Test(expectedExceptions = IllegalArgumentException.class)
    public void testInvalidEqualityExpression3()
            throws Exception
    {
        EqualityInference.Builder builder = new EqualityInference.Builder();
        addEquality("a1", "a1", builder);
    }

    @Test
    public void testExtractInferrableEqualities()
            throws Exception
    {
        EqualityInference inference = new EqualityInference.Builder()
                .extractInferenceCandidates(ExpressionUtils.and(equals("a1", "b1"), equals("b1", "c1"), someExpression("c1", "d1")))
                .build();

        // Able to rewrite to c1 due to equalities
        Assert.assertEquals(nameReference("c1"), inference.rewriteExpression(nameReference("a1"), matchesSymbols("c1")));

        // But not be able to rewrite to d1 which is not connected via equality
        Assert.assertNull(inference.rewriteExpression(nameReference("a1"), matchesSymbols("d1")));
    }

    @Test
    public void testEqualityPartitionGeneration()
            throws Exception
    {
        EqualityInference.Builder builder = new EqualityInference.Builder();
        builder.addEquality(nameReference("a1"), nameReference("b1"));
        builder.addEquality(add("a1", "a1"), multiply(nameReference("a1"), number(2)));
        builder.addEquality(nameReference("b1"), nameReference("c1"));
        builder.addEquality(add("a1", "a1"), nameReference("c1"));
        builder.addEquality(add("a1", "b1"), nameReference("c1"));

        EqualityInference inference = builder.build();

        EqualityInference.EqualityPartition emptyScopePartition = inference.generateEqualitiesPartitionedBy(Predicates.<Symbol>alwaysFalse());
        // Cannot generate any scope equalities with no matching symbols
        Assert.assertTrue(emptyScopePartition.getScopeEqualities().isEmpty());
        // All equalities should be represented in the inverse scope
        Assert.assertFalse(emptyScopePartition.getScopeComplementEqualities().isEmpty());
        // There should be no equalities straddling the scope
        Assert.assertTrue(emptyScopePartition.getScopeStraddlingEqualities().isEmpty());

        EqualityInference.EqualityPartition equalityPartition = inference.generateEqualitiesPartitionedBy(matchesSymbols("c1"));

        // There should be equalities in the scope, that only use c1 and are all inferrable equalities
        Assert.assertFalse(equalityPartition.getScopeEqualities().isEmpty());
        Assert.assertTrue(Iterables.all(equalityPartition.getScopeEqualities(), matchesSymbolScope(matchesSymbols("c1"))));
        Assert.assertTrue(Iterables.all(equalityPartition.getScopeEqualities(), EqualityInference.isInferenceCandidate()));

        // There should be equalities in the inverse scope, that never use c1 and are all inferrable equalities
        Assert.assertFalse(equalityPartition.getScopeComplementEqualities().isEmpty());
        Assert.assertTrue(Iterables.all(equalityPartition.getScopeComplementEqualities(), matchesSymbolScope(not(matchesSymbols("c1")))));
        Assert.assertTrue(Iterables.all(equalityPartition.getScopeComplementEqualities(), EqualityInference.isInferenceCandidate()));

        // There should be equalities in the straddling scope, that should use both c1 and not c1 symbols
        Assert.assertFalse(equalityPartition.getScopeStraddlingEqualities().isEmpty());
        Assert.assertTrue(Iterables.any(equalityPartition.getScopeStraddlingEqualities(), matchesStraddlingScope(matchesSymbols("c1"))));
        Assert.assertTrue(Iterables.all(equalityPartition.getScopeStraddlingEqualities(), EqualityInference.isInferenceCandidate()));

        // There should be a "full cover" of all of the equalities used
        // THUS, we should be able to plug the generated equalities back in and get an equivalent set of equalities back the next time around
        EqualityInference newInference = new EqualityInference.Builder()
                .addAllEqualities(equalityPartition.getScopeEqualities())
                .addAllEqualities(equalityPartition.getScopeComplementEqualities())
                .addAllEqualities(equalityPartition.getScopeStraddlingEqualities())
                .build();

        EqualityInference.EqualityPartition newEqualityPartition = newInference.generateEqualitiesPartitionedBy(matchesSymbols("c1"));

        Assert.assertEquals(setCopy(equalityPartition.getScopeEqualities()), setCopy(newEqualityPartition.getScopeEqualities()));
        Assert.assertEquals(setCopy(equalityPartition.getScopeComplementEqualities()), setCopy(newEqualityPartition.getScopeComplementEqualities()));
        Assert.assertEquals(setCopy(equalityPartition.getScopeStraddlingEqualities()), setCopy(newEqualityPartition.getScopeStraddlingEqualities()));
    }

    @Test
    public void testMultipleEqualitySetsPredicateGeneration()
            throws Exception
    {
        EqualityInference.Builder builder = new EqualityInference.Builder();
        addEquality("a1", "b1", builder);
        addEquality("b1", "c1", builder);
        addEquality("c1", "d1", builder);

        addEquality("a2", "b2", builder);
        addEquality("b2", "c2", builder);
        addEquality("c2", "d2", builder);

        EqualityInference inference = builder.build();

        // Generating equalities for disjoint groups
        EqualityInference.EqualityPartition equalityPartition = inference.generateEqualitiesPartitionedBy(symbolBeginsWith("a", "b"));

        // There should be equalities in the scope, that only use a* and b* symbols and are all inferrable equalities
        Assert.assertFalse(equalityPartition.getScopeEqualities().isEmpty());
        Assert.assertTrue(Iterables.all(equalityPartition.getScopeEqualities(), matchesSymbolScope(symbolBeginsWith("a", "b"))));
        Assert.assertTrue(Iterables.all(equalityPartition.getScopeEqualities(), EqualityInference.isInferenceCandidate()));

        // There should be equalities in the inverse scope, that never use a* and b* symbols and are all inferrable equalities
        Assert.assertFalse(equalityPartition.getScopeComplementEqualities().isEmpty());
        Assert.assertTrue(Iterables.all(equalityPartition.getScopeComplementEqualities(), matchesSymbolScope(not(symbolBeginsWith("a", "b")))));
        Assert.assertTrue(Iterables.all(equalityPartition.getScopeComplementEqualities(), EqualityInference.isInferenceCandidate()));

        // There should be equalities in the straddling scope, that should use both c1 and not c1 symbols
        Assert.assertFalse(equalityPartition.getScopeStraddlingEqualities().isEmpty());
        Assert.assertTrue(Iterables.any(equalityPartition.getScopeStraddlingEqualities(), matchesStraddlingScope(symbolBeginsWith("a", "b"))));
        Assert.assertTrue(Iterables.all(equalityPartition.getScopeStraddlingEqualities(), EqualityInference.isInferenceCandidate()));

        // Again, there should be a "full cover" of all of the equalities used
        // THUS, we should be able to plug the generated equalities back in and get an equivalent set of equalities back the next time around
        EqualityInference newInference = new EqualityInference.Builder()
                .addAllEqualities(equalityPartition.getScopeEqualities())
                .addAllEqualities(equalityPartition.getScopeComplementEqualities())
                .addAllEqualities(equalityPartition.getScopeStraddlingEqualities())
                .build();

        EqualityInference.EqualityPartition newEqualityPartition = newInference.generateEqualitiesPartitionedBy(symbolBeginsWith("a", "b"));

        Assert.assertEquals(setCopy(equalityPartition.getScopeEqualities()), setCopy(newEqualityPartition.getScopeEqualities()));
        Assert.assertEquals(setCopy(equalityPartition.getScopeComplementEqualities()), setCopy(newEqualityPartition.getScopeComplementEqualities()));
        Assert.assertEquals(setCopy(equalityPartition.getScopeStraddlingEqualities()), setCopy(newEqualityPartition.getScopeStraddlingEqualities()));
    }

    @Test
    public void testSubExpressionRewrites()
            throws Exception
    {
        EqualityInference.Builder builder = new EqualityInference.Builder();
        builder.addEquality(nameReference("a1"), add("b", "c")); // a1 = b + c
        builder.addEquality(nameReference("a2"), multiply(nameReference("b"), add("b", "c"))); // a2 = b * (b + c)
        builder.addEquality(nameReference("a3"), multiply(nameReference("a1"), add("b", "c"))); // a3 = a1 * (b + c)
        EqualityInference inference = builder.build();

        // Expression (b + c) should get entirely rewritten as a1
        Assert.assertEquals(inference.rewriteExpression(add("b", "c"), symbolBeginsWith("a")), nameReference("a1"));

        // Only the sub-expression (b + c) should get rewritten in terms of a*
        Assert.assertEquals(inference.rewriteExpression(multiply(nameReference("ax"), add("b", "c")), symbolBeginsWith("a")), multiply(nameReference("ax"), nameReference("a1")));

        // To be compliant, could rewrite either the whole expression, or just the sub-expression. Rewriting larger expressions are preferred
        Assert.assertEquals(inference.rewriteExpression(multiply(nameReference("a1"), add("b", "c")), symbolBeginsWith("a")), nameReference("a3"));
    }

    @Test
    public void testConstantEqualities()
            throws Exception
    {
        EqualityInference.Builder builder = new EqualityInference.Builder();
        addEquality("a1", "b1", builder);
        addEquality("b1", "c1", builder);
        builder.addEquality(nameReference("c1"), number(1));
        EqualityInference inference = builder.build();

        // Should always prefer a constant if available (constant is part of all scopes)
        Assert.assertEquals(inference.rewriteExpression(nameReference("a1"), matchesSymbols("a1", "b1")), number(1));

        // All scope equalities should utilize the constant if possible
        EqualityInference.EqualityPartition equalityPartition = inference.generateEqualitiesPartitionedBy(matchesSymbols("a1", "b1"));
        Assert.assertEquals(equalitiesAsSets(equalityPartition.getScopeEqualities()),
                set(set(nameReference("a1"), number(1)), set(nameReference("b1"), number(1))));
        Assert.assertEquals(equalitiesAsSets(equalityPartition.getScopeComplementEqualities()),
                set(set(nameReference("c1"), number(1))));

        // There should be no scope straddling equalities as the full set of equalities should be already represented by the scope and inverse scope
        Assert.assertTrue(equalityPartition.getScopeStraddlingEqualities().isEmpty());
    }

    private static Predicate<Expression> matchesSymbolScope(final Predicate<Symbol> symbolScope)
    {
        return new Predicate<Expression>()
        {
            @Override
            public boolean apply(Expression expression)
            {
                return Iterables.all(DependencyExtractor.extractUnique(expression), symbolScope);
            }
        };
    }

    private static Predicate<Expression> matchesStraddlingScope(final Predicate<Symbol> symbolScope)
    {
        return new Predicate<Expression>()
        {
            @Override
            public boolean apply(Expression expression)
            {
                Set<Symbol> symbols = DependencyExtractor.extractUnique(expression);
                return Iterables.any(symbols, symbolScope) && Iterables.any(symbols, not(symbolScope));
            }
        };
    }

    private static void addEquality(String symbol1, String symbol2, EqualityInference.Builder builder)
    {
        builder.addEquality(nameReference(symbol1), nameReference(symbol2));
    }

    private static Expression someExpression(String symbol1, String symbol2)
    {
        return someExpression(nameReference(symbol1), nameReference(symbol2));
    }

    private static Expression someExpression(Expression expression1, Expression expression2)
    {
        return new ComparisonExpression(GREATER_THAN, expression1, expression2);
    }

    private static Expression add(String symbol1, String symbol2)
    {
        return add(nameReference(symbol1), nameReference(symbol2));
    }

    private static Expression add(Expression expression1, Expression expression2)
    {
        return new ArithmeticExpression(ArithmeticExpression.Type.ADD, expression1, expression2);
    }

    private static Expression multiply(String symbol1, String symbol2)
    {
        return multiply(nameReference(symbol1), nameReference(symbol2));
    }

    private static Expression multiply(Expression expression1, Expression expression2)
    {
        return new ArithmeticExpression(ArithmeticExpression.Type.MULTIPLY, expression1, expression2);
    }

    private static Expression equals(String symbol1, String symbol2)
    {
        return equals(nameReference(symbol1), nameReference(symbol2));
    }

    private static Expression equals(Expression expression1, Expression expression2)
    {
        return new ComparisonExpression(EQUAL, expression1, expression2);
    }

    private static QualifiedNameReference nameReference(String symbol)
    {
        return new QualifiedNameReference(new Symbol(symbol).toQualifiedName());
    }

    private static LongLiteral number(long number)
    {
        return new LongLiteral(String.valueOf(number));
    }

    private static Predicate<Symbol> matchesSymbols(String... symbols)
    {
        return matchesSymbols(Arrays.asList(symbols));
    }

    private static Predicate<Symbol> matchesSymbols(Iterable<String> symbols)
    {
        final Set<Symbol> symbolSet = IterableTransformer.<String>on(symbols)
                .transform(new Function<String, Symbol>()
                {
                    @Override
                    public Symbol apply(String symbol)
                    {
                        return new Symbol(symbol);
                    }
                }).set();
        return Predicates.in(symbolSet);
    }

    private static Predicate<Symbol> symbolBeginsWith(String... prefixes)
    {
        return symbolBeginsWith(Arrays.asList(prefixes));
    }

    private static Predicate<Symbol> symbolBeginsWith(final Iterable<String> prefixes)
    {
        return new Predicate<Symbol>()
        {
            @Override
            public boolean apply(Symbol symbol)
            {
                for (String prefix : prefixes) {
                    if (symbol.getName().startsWith(prefix)) {
                        return true;
                    }
                }
                return false;
            }
        };
    }

    private static Set<Set<Expression>> equalitiesAsSets(Iterable<Expression> expressions)
    {
        ImmutableSet.Builder<Set<Expression>> builder = ImmutableSet.builder();
        for (Expression expression : expressions) {
            builder.add(equalityAsSet(expression));
        }
        return builder.build();
    }

    private static Set<Expression> equalityAsSet(Expression expression)
    {
        Preconditions.checkArgument(expression instanceof ComparisonExpression);
        ComparisonExpression comparisonExpression = (ComparisonExpression) expression;
        Preconditions.checkArgument(comparisonExpression.getType() == EQUAL);
        return ImmutableSet.of(comparisonExpression.getLeft(), comparisonExpression.getRight());
    }

    private static <E> Set<E> set(E... elements)
    {
        return setCopy(Arrays.asList(elements));
    }

    private static <E> Set<E> setCopy(Iterable<E> elements)
    {
        return ImmutableSet.copyOf(elements);
    }
}
TOP

Related Classes of com.facebook.presto.sql.planner.TestEqualityInference

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.