Package org.eigenbase.rel.metadata

Source Code of org.eigenbase.rel.metadata.RelMdPredicates$JoinConditionBasedPredicateInference$ExprsItr

/*
// Licensed to the Apache Software Foundation (ASF) under one or more
// contributor license agreements.  See the NOTICE file distributed with
// this work for additional information regarding copyright ownership.
// The ASF licenses this file to you under the Apache License, Version 2.0
// (the "License"); you may not use this file except in compliance with
// the License.  You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
*/
package org.eigenbase.rel.metadata;

import java.util.ArrayList;
import java.util.BitSet;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Set;
import java.util.SortedMap;
import java.util.TreeMap;

import org.eigenbase.rel.AggregateRelBase;
import org.eigenbase.rel.FilterRelBase;
import org.eigenbase.rel.JoinRelBase;
import org.eigenbase.rel.JoinRelType;
import org.eigenbase.rel.ProjectRelBase;
import org.eigenbase.rel.RelNode;
import org.eigenbase.rel.SortRel;
import org.eigenbase.rel.TableAccessRelBase;
import org.eigenbase.rel.UnionRelBase;
import org.eigenbase.relopt.RelOptPredicateList;
import org.eigenbase.relopt.RelOptUtil;
import org.eigenbase.rex.RexBuilder;
import org.eigenbase.rex.RexCall;
import org.eigenbase.rex.RexInputRef;
import org.eigenbase.rex.RexNode;
import org.eigenbase.rex.RexPermuteInputsShuttle;
import org.eigenbase.rex.RexUtil;
import org.eigenbase.rex.RexVisitorImpl;
import org.eigenbase.sql.SqlKind;
import org.eigenbase.util.mapping.Mapping;
import org.eigenbase.util.mapping.MappingType;
import org.eigenbase.util.mapping.Mappings;

import net.hydromatic.linq4j.Linq4j;
import net.hydromatic.linq4j.Ord;
import net.hydromatic.linq4j.function.Predicate1;

import net.hydromatic.optiq.BuiltinMethod;
import net.hydromatic.optiq.util.BitSets;

import com.google.common.base.Function;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import com.google.common.collect.Iterators;
import com.google.common.collect.Lists;

/**
* Utility to infer Predicates that are applicable above a RelNode.
*
* <p>This is currently used by
* {@link org.eigenbase.rel.rules.TransitivePredicatesOnJoinRule} to
* infer <em>Predicates</em> that can be inferred from one side of a Join
* to the other.
*
* <p>The PullUp Strategy is sound but not complete. Here are some of the
* limitations:
* <ol>
*
* <li> For Aggregations we only pull up predicates that only contain
* grouping keys. This can be extended to infer predicates on aggregation
* expressions from expressions on the aggregated columns. For example:
* <pre>
* select a, max(b) from R1 where b > 7 => max(b) > 7 or max(b) is null
* </pre>
*
* <li> For Projections we only look at columns that are projected without
* any function applied. So:
* <pre>
* select a from R1 where a > 7 -> a > 7 is pulledUp from the Projection.
* select a + 1 from R1 where a + 1 > 7 -> a + 1 > 7 is not pulledUp
* </pre>
*
* <li> There are several restrictions on Joins:
*   <ul>
*   <li> We only pullUp inferred predicates for now. Pulling up existing
*   predicates causes an explosion of duplicates. The existing predicates
*   are pushed back down as new predicates. Once we have rules to eliminate
*   duplicate Filter conditions, we should pullUp all predicates.
*
*   <li> For Left Outer: we infer new predicates from the left and set them
*   as applicable on the Right side. No predicates are pulledUp.
*
*   <li> Right Outer Joins are handled in an analogous manner.
*
*   <li> For Full Outer Joins no predicates are pulledUp or inferred.
*   </ul>
* </ol>
*/
public class RelMdPredicates {
  public static final RelMetadataProvider SOURCE = ReflectiveRelMetadataProvider
      .reflectiveSource(BuiltinMethod.PREDICATES.method, new RelMdPredicates());

  private static final List<RexNode> EMPTY_LIST = ImmutableList.of();

  /**
   * Catch-all rule, used when none of the more specific overloads apply.
   *
   * @param rel relational expression; it is not inspected
   * @return {@link RelOptPredicateList#EMPTY}
   */
  public RelOptPredicateList getPredicates(RelNode rel) {
    return RelOptPredicateList.EMPTY;
  }

  /**
   * Infers predicates for a table scan.
   *
   * @param table table scan; it is not inspected
   * @return {@link RelOptPredicateList#EMPTY}
   */
  public RelOptPredicateList getPredicates(TableAccessRelBase table) {
    return RelOptPredicateList.EMPTY;
  }

  /**
   * Infers predicates for a project.
   *
   * <ol>
   * <li>create a mapping from input to projection. Map only positions that
   * directly reference an input column.
   * <li>Expressions that only contain above columns are retained in the
   * Project's pullExpressions list.
   * <li>For e.g. expression 'a + e = 9' below will not be pulled up because 'e'
   * is not in the projection list.
   *
   * <pre>
   * childPullUpExprs:      {a > 7, b + c < 10, a + e = 9}
   * projectionExprs:       {a, b, c, e / 2}
   * projectionPullupExprs: {a > 7, b + c < 10}
   * </pre>
   *
   * </ol>
   */
  public RelOptPredicateList getPredicates(ProjectRelBase project) {
    RelNode child = project.getChild();
    RelOptPredicateList childInfo =
        RelMetadataQuery.getPulledUpPredicates(child);

    List<RexNode> projectPullUpPredicates = new ArrayList<RexNode>();

    BitSet columnsMapped = new BitSet(child.getRowType().getFieldCount());
    Mapping m = Mappings.create(MappingType.PARTIAL_FUNCTION,
        child.getRowType().getFieldCount(),
        project.getRowType().getFieldCount());

    for (Ord<RexNode> o : Ord.zip(project.getProjects())) {
      if (o.e instanceof RexInputRef) {
        int sIdx = ((RexInputRef) o.e).getIndex();
        m.set(sIdx, o.i);
        columnsMapped.set(sIdx);
      }
    }

    // Go over childPullUpPredicates. If a predicate only contains columns in
    // 'columnsMapped' construct a new predicate based on mapping.
    for (RexNode r : childInfo.pulledUpPredicates) {
      BitSet rCols = RelOptUtil.InputFinder.bits(r);
      if (BitSets.contains(columnsMapped, rCols)) {
        r = r.accept(new RexPermuteInputsShuttle(m, child));
        projectPullUpPredicates.add(r);
      }
    }
    return RelOptPredicateList.of(projectPullUpPredicates);
  }

  /**
   * Add the Filter condition to the pulledPredicates list from the child.
   */
  public RelOptPredicateList getPredicates(FilterRelBase filter) {
    RelNode child = filter.getChild();
    RelOptPredicateList childInfo =
        RelMetadataQuery.getPulledUpPredicates(child);

    return RelOptPredicateList.of(
        Iterables.concat(childInfo.pulledUpPredicates,
            RelOptUtil.conjunctions(filter.getCondition())));
  }

  public RelOptPredicateList getPredicates(JoinRelBase join) {
    RexBuilder rB = join.getCluster().getRexBuilder();
    RelNode left = join.getInput(0);
    RelNode right = join.getInput(1);

    RelOptPredicateList leftInfo =
        RelMetadataQuery.getPulledUpPredicates(left);
    RelOptPredicateList rightInfo =
        RelMetadataQuery.getPulledUpPredicates(right);

    JoinConditionBasedPredicateInference jI =
        new JoinConditionBasedPredicateInference(join,
            RexUtil.composeConjunction(rB, leftInfo.pulledUpPredicates, false),
            RexUtil.composeConjunction(rB, rightInfo.pulledUpPredicates,
                false));

    return jI.inferPredicates(false);
  }

  /**
   * Infers predicates for an AggregateRel.
   *
   * <p>Pulls up predicates that only contains references to columns in the
   * GroupSet. For e.g.
   *
   * <pre>
   * childPullUpExprs : { a > 7, b + c < 10, a + e = 9}
   * groupSet         : { a, b}
   * pulledUpExprs    : { a > 7}
   * </pre>
   */
  public RelOptPredicateList getPredicates(AggregateRelBase agg) {
    RelNode child = agg.getChild();
    RelOptPredicateList childInfo =
        RelMetadataQuery.getPulledUpPredicates(child);

    List<RexNode> aggPullUpPredicates = new ArrayList<RexNode>();

    BitSet groupKeys = agg.getGroupSet();
    Mapping m = Mappings.create(MappingType.PARTIAL_FUNCTION,
        child.getRowType().getFieldCount(), agg.getRowType().getFieldCount());

    int i = 0;
    for (int j : BitSets.toIter(groupKeys)) {
      m.set(j, i++);
    }

    for (RexNode r : childInfo.pulledUpPredicates) {
      BitSet rCols = RelOptUtil.InputFinder.bits(r);
      if (BitSets.contains(groupKeys, rCols)) {
        r = r.accept(new RexPermuteInputsShuttle(m, child));
        aggPullUpPredicates.add(r);
      }
    }
    return RelOptPredicateList.of(aggPullUpPredicates);
  }

  /**
   * Infers predicates for a UnionRelBase.
   *
   * <p>The pulled up expression is a disjunction of its children's predicates.
   */
  public RelOptPredicateList getPredicates(UnionRelBase union) {
    RexBuilder rB = union.getCluster().getRexBuilder();
    List<RexNode> orList = Lists.newArrayList();
    for (RelNode input : union.getInputs()) {
      RelOptPredicateList info = RelMetadataQuery.getPulledUpPredicates(input);
      if (!info.pulledUpPredicates.isEmpty()) {
        orList.addAll(
            RelOptUtil.disjunctions(
                RexUtil.composeConjunction(rB, info.pulledUpPredicates,
                    false)));
      }
    }

    if (orList.isEmpty()) {
      return RelOptPredicateList.EMPTY;
    }
    return RelOptPredicateList.of(
        RelOptUtil.conjunctions(RexUtil.composeDisjunction(rB, orList, false)));
  }

  /**
   * Infers predicates for a SortRel.
   */
  public RelOptPredicateList getPredicates(SortRel sort) {
    RelNode child = sort.getInput(0);
    return RelMetadataQuery.getPulledUpPredicates(child);
  }

  /**
   * Utility to infer predicates from one side of the join that apply on the
   * other side. Contract is: - initialize with a {@link JoinRelBase} and
   * optional predicates applicable on its left and right subtrees. - you can
   * then ask it for equivalentPredicate(s) given a predicate.
   * <p>
   * So for:
   * <ol>
   * <li>'<code>R1(x) join R2(y) on x = y</code>' a call for
   * equivalent predicates on '<code>x > 7</code>' will return
   * '<code>[y > 7]</code>'
   * <li>'<code>R1(x) join R2(y) on x = y join R3(z) on y = z</code>' a call for
   * equivalent predicates on the second join for '<code>x > 7</code>' will
   * return '<code>[y > 7, z > 7]</code>'
   * </ol>
   */
  static class JoinConditionBasedPredicateInference {
    final JoinRelBase joinRel;
    final int nSysFields;
    final int nFieldsLeft;
    final int nFieldsRight;
    final BitSet leftFieldsBitSet;
    final BitSet rightFieldsBitSet;
    final BitSet allFieldsBitSet;
    SortedMap<Integer, BitSet> equivalence;
    final Map<String, BitSet> exprFields;
    final Set<String> allExprsDigests;
    final Set<String> equalityPredicates;
    final RexNode leftChildPredicates;
    final RexNode rightChildPredicates;

    public JoinConditionBasedPredicateInference(JoinRelBase joinRel,
        RexNode lPreds, RexNode rPreds) {
      super();
      this.joinRel = joinRel;
      nFieldsLeft = joinRel.getLeft().getRowType().getFieldList().size();
      nFieldsRight = joinRel.getRight().getRowType().getFieldList().size();
      nSysFields = joinRel.getSystemFieldList().size();
      leftFieldsBitSet = BitSets.range(nSysFields, nSysFields + nFieldsLeft);
      rightFieldsBitSet = BitSets.range(nSysFields + nFieldsLeft,
          nSysFields + nFieldsLeft + nFieldsRight);
      allFieldsBitSet = BitSets.range(0,
          nSysFields + nFieldsLeft + nFieldsRight);

      exprFields = new HashMap<String, BitSet>();
      allExprsDigests = new HashSet<String>();

      if (lPreds == null) {
        leftChildPredicates = null;
      } else {
        Mappings.TargetMapping leftMapping = Mappings.createShiftMapping(
            nSysFields + nFieldsLeft, nSysFields, 0, nFieldsLeft);
        leftChildPredicates = lPreds.accept(
            new RexPermuteInputsShuttle(leftMapping, joinRel.getInput(0)));

        for (RexNode r : RelOptUtil.conjunctions(leftChildPredicates)) {
          exprFields.put(r.toString(), RelOptUtil.InputFinder.bits(r));
          allExprsDigests.add(r.toString());
        }
      }
      if (rPreds == null) {
        rightChildPredicates = null;
      } else {
        Mappings.TargetMapping rightMapping = Mappings.createShiftMapping(
            nSysFields + nFieldsLeft + nFieldsRight,
            nSysFields + nFieldsLeft, 0, nFieldsRight);
        rightChildPredicates = rPreds.accept(
            new RexPermuteInputsShuttle(rightMapping, joinRel.getInput(1)));

        for (RexNode r : RelOptUtil.conjunctions(rightChildPredicates)) {
          exprFields.put(r.toString(), RelOptUtil.InputFinder.bits(r));
          allExprsDigests.add(r.toString());
        }
      }

      equivalence = new TreeMap<Integer, BitSet>();
      equalityPredicates = new HashSet<String>();
      for (int i = 0; i < nSysFields + nFieldsLeft + nFieldsRight; i++) {
        equivalence.put(i, BitSets.of(i));
      }

      // Only process equivalences found in the join conditions. Processing
      // Equivalences from the left or right side infer predicates that are
      // already present in the Tree below the join.
      RexBuilder rexBuilder = joinRel.getCluster().getRexBuilder();
      List<RexNode> exprs =
          RelOptUtil.conjunctions(
              compose(rexBuilder, ImmutableList.of(joinRel.getCondition())));

      final EquivalenceFinder eF = new EquivalenceFinder();
      new ArrayList<Void>(Lists.transform(exprs, new Function<RexNode, Void>() {
        public Void apply(RexNode input) {
          return input.accept(eF);
        }
      }));

      equivalence = BitSets.closure(equivalence);
    }

    /**
     * The PullUp Strategy is sound but not complete.
     * <ol>
     * <li>We only pullUp inferred predicates for now. Pulling up existing
     * predicates causes an explosion of duplicates. The existing predicates are
     * pushed back down as new predicates. Once we have rules to eliminate
     * duplicate Filter conditions, we should pullUp all predicates.
     * <li>For Left Outer: we infer new predicates from the left and set them as
     * applicable on the Right side. No predicates are pulledUp.
     * <li>Right Outer Joins are handled in an analogous manner.
     * <li>For Full Outer Joins no predicates are pulledUp or inferred.
     * </ol>
     */
    public RelOptPredicateList inferPredicates(
        boolean includeEqualityInference) {
      List<RexNode> inferredPredicates = new ArrayList<RexNode>();
      Set<String> allExprsDigests = new HashSet<String>(this.allExprsDigests);
      final JoinRelType joinType = joinRel.getJoinType();
      switch (joinType) {
      case INNER:
      case LEFT:
        infer(leftChildPredicates, allExprsDigests, inferredPredicates,
            includeEqualityInference,
            joinType == JoinRelType.LEFT ? rightFieldsBitSet
                : allFieldsBitSet);
        break;
      }
      switch (joinType) {
      case INNER:
      case RIGHT:
        infer(rightChildPredicates, allExprsDigests, inferredPredicates,
            includeEqualityInference,
            joinType == JoinRelType.RIGHT ? leftFieldsBitSet
                : allFieldsBitSet);
        break;
      }

      Mappings.TargetMapping rightMapping = Mappings.createShiftMapping(
          nSysFields + nFieldsLeft + nFieldsRight,
          0, nSysFields + nFieldsLeft, nFieldsRight);
      final RexPermuteInputsShuttle rightPermute =
          new RexPermuteInputsShuttle(rightMapping, joinRel);
      Mappings.TargetMapping leftMapping = Mappings.createShiftMapping(
          nSysFields + nFieldsLeft, 0, nSysFields, nFieldsLeft);
      final RexPermuteInputsShuttle leftPermute =
          new RexPermuteInputsShuttle(leftMapping, joinRel);
      List<RexNode> leftInferredPredicates = new ArrayList<RexNode>();
      List<RexNode> rightInferredPredicates = new ArrayList<RexNode>();

      for (RexNode iP : inferredPredicates) {
        BitSet iPBitSet = RelOptUtil.InputFinder.bits(iP);
        if (BitSets.contains(leftFieldsBitSet, iPBitSet)) {
          leftInferredPredicates.add(iP.accept(leftPermute));
        } else if (BitSets.contains(rightFieldsBitSet, iPBitSet)) {
          rightInferredPredicates.add(iP.accept(rightPermute));
        }
      }

      switch (joinType) {
      case INNER:
        return RelOptPredicateList.of(
            Iterables.concat(RelOptUtil.conjunctions(leftChildPredicates),
                RelOptUtil.conjunctions(rightChildPredicates),
                RelOptUtil.conjunctions(joinRel.getCondition()),
                inferredPredicates),
            leftInferredPredicates, rightInferredPredicates);
      case LEFT:
        return RelOptPredicateList.of(
            RelOptUtil.conjunctions(leftChildPredicates),
            leftInferredPredicates, rightInferredPredicates);
      case RIGHT:
        return RelOptPredicateList.of(
            RelOptUtil.conjunctions(rightChildPredicates),
            inferredPredicates, EMPTY_LIST);
      default:
        assert inferredPredicates.size() == 0;
        return RelOptPredicateList.EMPTY;
      }
    }

    /**
     * Returns the left input's predicates as stored by the constructor, that
     * is, after they were passed through the left shift mapping.
     *
     * @return the stored left predicates, or null if none were supplied
     */
    public RexNode left() {
      return leftChildPredicates;
    }

    /**
     * Returns the right input's predicates as stored by the constructor, that
     * is, after they were passed through the right shift mapping.
     *
     * @return the stored right predicates, or null if none were supplied
     */
    public RexNode right() {
      return rightChildPredicates;
    }

    private void infer(RexNode predicates, Set<String> allExprsDigests,
        List<RexNode> inferedPredicates, boolean includeEqualityInference,
        BitSet inferringFields) {
      for (RexNode r : RelOptUtil.conjunctions(predicates)) {
        if (!includeEqualityInference
            && equalityPredicates.contains(r.toString())) {
          continue;
        }
        for (Mapping m : mappings(r)) {
          RexNode tr = r.accept(
              new RexPermuteInputsShuttle(m, joinRel.getInput(0),
                  joinRel.getInput(1)));
          if (BitSets.contains(inferringFields, RelOptUtil.InputFinder.bits(tr))
              && !allExprsDigests.contains(tr.toString())
              && !isAlwaysTrue(tr)) {
            inferedPredicates.add(tr);
            allExprsDigests.add(tr.toString());
          }
        }
      }
    }

    Iterable<Mapping> mappings(final RexNode predicate) {
      return new Iterable<Mapping>() {
        public Iterator<Mapping> iterator() {
          BitSet fields = exprFields.get(predicate.toString());
          if (fields.cardinality() == 0) {
            return Iterators.emptyIterator();
          }
          return new ExprsItr(fields);
        }
      };
    }

    private void equivalent(int p1, int p2) {
      BitSet b = equivalence.get(p1);
      b.set(p2);

      b = equivalence.get(p2);
      b.set(p1);
    }

    RexNode compose(RexBuilder rexBuilder, Iterable<RexNode> exprs) {
      exprs = Linq4j.asEnumerable(exprs).where(new Predicate1<RexNode>() {
        public boolean apply(RexNode expr) {
          return expr != null;
        }
      });
      return RexUtil.composeConjunction(rexBuilder, exprs, false);
    }

    /**
     * Find expressions of the form 'col_x = col_y'.
     */
    class EquivalenceFinder extends RexVisitorImpl<Void> {
      protected EquivalenceFinder() {
        super(true);
      }

      @Override
      public Void visitCall(RexCall call) {
        if (call.getOperator().getKind() == SqlKind.EQUALS) {
          int lPos = pos(call.getOperands().get(0));
          int rPos = pos(call.getOperands().get(1));
          if (lPos != -1 && rPos != -1) {
            JoinConditionBasedPredicateInference.this.equivalent(lPos, rPos);
            JoinConditionBasedPredicateInference.this.equalityPredicates
                .add(call.toString());
          }
        }
        return null;
      }
    }

    /**
     * Given an expression returns all the possible substitutions.
     *
     * <p>For example, for an expression 'a + b + c' and the following
     * equivalences: <pre>
     * a : {a, b}
     * b : {a, b}
     * c : {c, e}
     * </pre>
     *
     * <p>The following Mappings will be returned:
     * <pre>
     * {a->a, b->a, c->c}
     * {a->a, b->a, c->e}
     * {a->a, b->b, c->c}
     * {a->a, b->b, c->e}
     * {a->b, b->a, c->c}
     * {a->b, b->a, c->e}
     * {a->b, b->b, c->c}
     * {a->b, b->b, c->e}
     * </pre>
     *
     * <p>which imply the following inferences:
     * <pre>
     * a + a + c
     * a + a + e
     * a + b + c
     * a + b + e
     * b + a + c
     * b + a + e
     * b + b + c
     * b + b + e
     * </pre>
     */
    class ExprsItr implements Iterator<Mapping> {
      final int[] columns;
      final BitSet[] columnSets;
      final int[] iterationIdx;
      Mapping nextMapping;
      boolean firstCall;

      ExprsItr(BitSet fields) {
        nextMapping = null;
        columns = new int[fields.cardinality()];
        columnSets = new BitSet[fields.cardinality()];
        iterationIdx = new int[fields.cardinality()];
        for (int j = 0, i = fields.nextSetBit(0); i >= 0; i = fields
            .nextSetBit(i + 1), j++) {
          columns[j] = i;
          columnSets[j] = equivalence.get(i);
          iterationIdx[j] = 0;
        }
        firstCall = true;
      }

      public boolean hasNext() {
        if (firstCall) {
          initializeMapping();
          firstCall = false;
        } else {
          computeNextMapping(iterationIdx.length - 1);
        }
        return nextMapping != null;
      }

      public Mapping next() {
        return nextMapping;
      }

      public void remove() {
        throw new UnsupportedOperationException();
      }

      private void computeNextMapping(int level) {
        int t = columnSets[level].nextSetBit(iterationIdx[level]);
        if (t < 0) {
          if (level == 0) {
            nextMapping = null;
          } else {
            iterationIdx[level] = 0;
            computeNextMapping(level - 1);
          }
        } else {
          nextMapping.set(columns[level], t);
          iterationIdx[level] = t + 1;
        }
      }

      private void initializeMapping() {
        nextMapping = Mappings.create(MappingType.PARTIAL_FUNCTION,
            nSysFields + nFieldsLeft + nFieldsRight,
            nSysFields + nFieldsLeft + nFieldsRight);
        for (int i = 0; i < columnSets.length; i++) {
          BitSet c = columnSets[i];
          int t = c.nextSetBit(iterationIdx[i]);
          if (t < 0) {
            nextMapping = null;
            return;
          }
          nextMapping.set(columns[i], t);
          iterationIdx[i] = t + 1;
        }
      }
    }

    private int pos(RexNode expr) {
      if (expr instanceof RexInputRef) {
        return ((RexInputRef) expr).getIndex();
      }
      return -1;
    }

    private boolean isAlwaysTrue(RexNode predicate) {
      if (predicate instanceof RexCall) {
        RexCall c = (RexCall) predicate;
        if (c.getOperator().getKind() == SqlKind.EQUALS) {
          int lPos = pos(c.getOperands().get(0));
          int rPos = pos(c.getOperands().get(1));
          return lPos != -1 && lPos == rPos;
        }
      }
      return predicate.isAlwaysTrue();
    }
  }
}

// End RelMdPredicates.java
TOP

Related Classes of org.eigenbase.rel.metadata.RelMdPredicates$JoinConditionBasedPredicateInference$ExprsItr

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.