Package org.apache.drill.exec.expr

Source Code of org.apache.drill.exec.expr.ExpressionTreeMaterializer$MaterializeVisitor

/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements.  See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership.  The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License.  You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.drill.exec.expr;

import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import com.google.common.base.Function;
import com.google.common.base.Optional;
import com.google.common.base.Predicate;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;

import org.apache.drill.common.expression.BooleanOperator;
import org.apache.drill.common.expression.CastExpression;
import org.apache.drill.common.expression.ConvertExpression;
import org.apache.drill.common.expression.ErrorCollector;
import org.apache.drill.common.expression.ExpressionPosition;
import org.apache.drill.common.expression.FunctionCall;
import org.apache.drill.common.expression.FunctionHolderExpression;
import org.apache.drill.common.expression.IfExpression;
import org.apache.drill.common.expression.IfExpression.IfCondition;
import org.apache.drill.common.expression.LogicalExpression;
import org.apache.drill.common.expression.NullExpression;
import org.apache.drill.common.expression.SchemaPath;
import org.apache.drill.common.expression.TypedNullConstant;
import org.apache.drill.common.expression.ValueExpressions;
import org.apache.drill.common.expression.ValueExpressions.BooleanExpression;
import org.apache.drill.common.expression.ValueExpressions.DoubleExpression;
import org.apache.drill.common.expression.ValueExpressions.FloatExpression;
import org.apache.drill.common.expression.ValueExpressions.LongExpression;
import org.apache.drill.common.expression.ValueExpressions.DateExpression;
import org.apache.drill.common.expression.ValueExpressions.IntervalYearExpression;
import org.apache.drill.common.expression.ValueExpressions.IntervalDayExpression;
import org.apache.drill.common.expression.ValueExpressions.TimeStampExpression;
import org.apache.drill.common.expression.ValueExpressions.TimeExpression;
import org.apache.drill.common.expression.ValueExpressions.Decimal9Expression;
import org.apache.drill.common.expression.ValueExpressions.Decimal18Expression;
import org.apache.drill.common.expression.ValueExpressions.Decimal28Expression;
import org.apache.drill.common.expression.ValueExpressions.Decimal38Expression;
import org.apache.drill.common.expression.ValueExpressions.IntExpression;
import org.apache.drill.common.expression.ValueExpressions.QuotedString;
import org.apache.drill.common.expression.fn.CastFunctions;
import org.apache.drill.common.expression.visitors.AbstractExprVisitor;
import org.apache.drill.common.expression.visitors.ConditionalExprOptimizer;
import org.apache.drill.common.expression.visitors.ExpressionValidator;
import org.apache.drill.common.types.TypeProtos;
import org.apache.drill.common.types.TypeProtos.DataMode;
import org.apache.drill.common.types.TypeProtos.MajorType;
import org.apache.drill.common.types.TypeProtos.MinorType;
import org.apache.drill.common.types.Types;
import org.apache.drill.exec.expr.annotations.FunctionTemplate;
import org.apache.drill.exec.expr.fn.AbstractFuncHolder;
import org.apache.drill.exec.expr.fn.DrillComplexWriterFuncHolder;
import org.apache.drill.exec.expr.fn.DrillFuncHolder;
import org.apache.drill.exec.expr.fn.FunctionImplementationRegistry;
import org.apache.drill.exec.record.TypedFieldId;
import org.apache.drill.exec.record.VectorAccessible;
import org.apache.drill.exec.resolver.DefaultFunctionResolver;
import org.apache.drill.exec.resolver.FunctionResolver;
import org.apache.drill.exec.resolver.FunctionResolverFactory;

import com.google.common.collect.Lists;
import org.apache.drill.exec.resolver.TypeCastRules;

public class ExpressionTreeMaterializer {

  static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(ExpressionTreeMaterializer.class);

  private ExpressionTreeMaterializer() {
  };

  public static LogicalExpression materialize(LogicalExpression expr, VectorAccessible batch, ErrorCollector errorCollector, FunctionImplementationRegistry registry) {
    return ExpressionTreeMaterializer.materialize(expr, batch, errorCollector, registry, false);
  }

  public static LogicalExpression materialize(LogicalExpression expr, VectorAccessible batch, ErrorCollector errorCollector, FunctionImplementationRegistry registry,
      boolean allowComplexWriterExpr) {
    LogicalExpression out =  expr.accept(new MaterializeVisitor(batch, errorCollector, allowComplexWriterExpr), registry);
   
    if (!errorCollector.hasErrors()) {
      out = out.accept(ConditionalExprOptimizer.INSTANCE, null);
    }
   
    if(out instanceof NullExpression){
      return new TypedNullConstant(Types.optional(MinorType.INT));
    }else{
      return out;
    }
  }

  private static class MaterializeVisitor extends AbstractExprVisitor<LogicalExpression, FunctionImplementationRegistry, RuntimeException> {
    private ExpressionValidator validator = new ExpressionValidator();
    private final ErrorCollector errorCollector;
    private final VectorAccessible batch;
    private final boolean allowComplexWriter;

    public MaterializeVisitor(VectorAccessible batch, ErrorCollector errorCollector, boolean allowComplexWriter) {
      this.batch = batch;
      this.errorCollector = errorCollector;
      this.allowComplexWriter = allowComplexWriter;
    }

    private LogicalExpression validateNewExpr(LogicalExpression newExpr) {
      newExpr.accept(validator, errorCollector);
      return newExpr;
    }

    @Override
    public LogicalExpression visitUnknown(LogicalExpression e, FunctionImplementationRegistry registry)
      throws RuntimeException {
      return e;
    }

    @Override
    public LogicalExpression visitFunctionHolderExpression(FunctionHolderExpression holder, FunctionImplementationRegistry value) throws RuntimeException {
      // a function holder is already materialized, no need to rematerialize.  generally this won't be used unless we materialize a partial tree and rematerialize the whole tree.
      return holder;
    }

    @Override
    public LogicalExpression visitBooleanOperator(BooleanOperator op, FunctionImplementationRegistry registry) {
      List<LogicalExpression> args = Lists.newArrayList();
      for (int i = 0; i < op.args.size(); ++i) {
        LogicalExpression newExpr = op.args.get(i).accept(this, registry);
        assert newExpr != null : String.format("Materialization of %s return a null expression.", op.args.get(i));
        args.add(newExpr);
      }

      //replace with a new function call, since its argument could be changed.
      return new BooleanOperator(op.getName(), args, op.getPosition());     
    }

    private LogicalExpression addCastExpression(LogicalExpression fromExpr, MajorType toType, FunctionImplementationRegistry registry) {
      String castFuncName = CastFunctions.getCastFunc(toType.getMinorType());
      List<LogicalExpression> castArgs = Lists.newArrayList();
      castArgs.add(fromExpr)//input_expr

      if (!Types.isFixedWidthType(toType)) {

              /* We are implicitly casting to VARCHAR so we don't have a max length,
               * using an arbitrary value. We trim down the size of the stored bytes
               * to the actual size so this size doesn't really matter.
               */
        castArgs.add(new ValueExpressions.LongExpression(65536, null));
      }
      else if (toType.getMinorType().name().startsWith("DECIMAL")) {
        // Add the scale and precision to the arguments of the implicit cast
        castArgs.add(new ValueExpressions.LongExpression(fromExpr.getMajorType().getPrecision(), null));
        castArgs.add(new ValueExpressions.LongExpression(fromExpr.getMajorType().getScale(), null));
      }

      FunctionCall castCall = new FunctionCall(castFuncName, castArgs, ExpressionPosition.UNKNOWN);
      FunctionResolver resolver = FunctionResolverFactory.getResolver(castCall);
      DrillFuncHolder matchedCastFuncHolder = registry.findDrillFunction(resolver, castCall);

      if (matchedCastFuncHolder == null) {
        logFunctionResolutionError(errorCollector, castCall);
        return NullExpression.INSTANCE;
      }

      return matchedCastFuncHolder.getExpr(castFuncName, castArgs, ExpressionPosition.UNKNOWN);

    }
    @Override
    public LogicalExpression visitFunctionCall(FunctionCall call, FunctionImplementationRegistry registry) {
      List<LogicalExpression> args = Lists.newArrayList();
      for (int i = 0; i < call.args.size(); ++i) {
        LogicalExpression newExpr = call.args.get(i).accept(this, registry);
        assert newExpr != null : String.format("Materialization of %s return a null expression.", call.args.get(i));
        args.add(newExpr);
      }

      //replace with a new function call, since its argument could be changed.
      call = new FunctionCall(call.getName(), args, call.getPosition());

      FunctionResolver resolver = FunctionResolverFactory.getResolver(call);
      DrillFuncHolder matchedFuncHolder = registry.findDrillFunction(resolver, call);

      if (matchedFuncHolder instanceof DrillComplexWriterFuncHolder && ! allowComplexWriter) {
        errorCollector.addGeneralError(call.getPosition(), "Only ProjectRecordBatch could have complex writer function. You are using complex writer function " + call.getName() + " in a non-project operation!");
      }
     
      //new arg lists, possible with implicit cast inserted.
      List<LogicalExpression> argsWithCast = Lists.newArrayList();

      if (matchedFuncHolder!=null) {
        //Compare parm type against arg type. Insert cast on top of arg, whenever necessary.
        for (int i = 0; i < call.args.size(); ++i) {

          LogicalExpression currentArg = call.args.get(i);

          TypeProtos.MajorType parmType = matchedFuncHolder.getParmMajorType(i);

          //Case 1: If  1) the argument is NullExpression
          //            2) the parameter of matchedFuncHolder allows null input, or func's null_handling is NULL_IF_NULL (means null and non-null are exchangable).
          //        then replace NullExpression with a TypedNullConstant
          if (currentArg.equals(NullExpression.INSTANCE) &&
            ( parmType.getMode().equals(TypeProtos.DataMode.OPTIONAL) ||
              matchedFuncHolder.getNullHandling() == FunctionTemplate.NullHandling.NULL_IF_NULL)) {
            argsWithCast.add(new TypedNullConstant(parmType));
          } else if (Types.softEquals(parmType, currentArg.getMajorType(), matchedFuncHolder.getNullHandling() == FunctionTemplate.NullHandling.NULL_IF_NULL) ||
                     matchedFuncHolder.isFieldReader(i)) {
            //Case 2: argument and parameter matches, or parameter is FieldReader.  Do nothing.
            argsWithCast.add(currentArg);
          } else {
            //Case 3: insert cast if param type is different from arg type.
            argsWithCast.add(addCastExpression(call.args.get(i), parmType, registry));
          }
        }

        return matchedFuncHolder.getExpr(call.getName(), argsWithCast, call.getPosition());
      }

      // as no drill func is found, search for a non-Drill function.
      AbstractFuncHolder matchedNonDrillFuncHolder = registry.findNonDrillFunction(call);
      if (matchedNonDrillFuncHolder != null) {
        // Insert implicit cast function holder expressions if required
        List<LogicalExpression> extArgsWithCast = Lists.newArrayList();

        for (int i = 0; i < call.args.size(); ++i) {
          LogicalExpression currentArg = call.args.get(i);
          TypeProtos.MajorType parmType = matchedNonDrillFuncHolder.getParmMajorType(i);

          if (Types.softEquals(parmType, currentArg.getMajorType(), true)) {
            extArgsWithCast.add(currentArg);
          } else {
            // Insert cast if param type is different from arg type.
            extArgsWithCast.add(addCastExpression(call.args.get(i), parmType, registry));
          }
        }

        return matchedNonDrillFuncHolder.getExpr(call.getName(), extArgsWithCast, call.getPosition());
      }

      logFunctionResolutionError(errorCollector, call);
      return NullExpression.INSTANCE;
    }

    private void logFunctionResolutionError(ErrorCollector errorCollector, FunctionCall call) {
      // add error to collector
      StringBuilder sb = new StringBuilder();
      sb.append("Missing function implementation: ");
      sb.append("[");
      sb.append(call.getName());
      sb.append("(");
      boolean first = true;
      for(LogicalExpression e : call.args) {
        TypeProtos.MajorType mt = e.getMajorType();
        if(first){
          first = false;
        } else {
          sb.append(", ");
        }
        sb.append(mt.getMinorType().name());
        sb.append("-");
        sb.append(mt.getMode().name());
      }
      sb.append(")");
      sb.append("]");

      errorCollector.addGeneralError(call.getPosition(), sb.toString());
    }


    /**
     * Materializes an IF expression (one condition, one "then" expression, one "else" expression).
     *
     * <p>Processing order:
     * <ol>
     *   <li>materialize the else expression, then the condition, then the "then" expression;</li>
     *   <li>if the then/else minor types differ and neither is NULL, insert an implicit cast on
     *       one branch, chosen via {@code TypeCastRules.getLeastRestrictiveType};</li>
     *   <li>if either branch is a {@link NullExpression}, replace it with a
     *       {@link TypedNullConstant} typed after the first branch whose minor type is not NULL;</li>
     *   <li>if the resulting IF expression is OPTIONAL, wrap any non-OPTIONAL branch in a
     *       {@code convertToNullable<MinorType>} function;</li>
     *   <li>validate the rebuilt IF expression and return it.</li>
     * </ol>
     */
    @Override
    public LogicalExpression visitIfExpression(IfExpression ifExpr, FunctionImplementationRegistry registry) {
      IfExpression.IfCondition conditions = ifExpr.ifCondition;
      // The else branch is visited first, then the condition, then the "then" expression.
      LogicalExpression newElseExpr = ifExpr.elseExpression.accept(this, registry);

      LogicalExpression newCondition = conditions.condition.accept(this, registry);
      LogicalExpression newExpr = conditions.expression.accept(this, registry);
      conditions = new IfExpression.IfCondition(newCondition, newExpr);

      MinorType thenType = conditions.expression.getMajorType().getMinorType();
      MinorType elseType = newElseExpr.getMajorType().getMinorType();

      // Check if we need a cast: only when the branch types differ and neither branch is NULL-typed.
      if (thenType != elseType && !(thenType == MinorType.NULL || elseType == MinorType.NULL)) {

        MinorType leastRestrictive = TypeCastRules.getLeastRestrictiveType((Arrays.asList(thenType, elseType)));
        if (leastRestrictive != thenType) {
          // Implicitly cast the then expression (to the else expression's major type)
          conditions = new IfExpression.IfCondition(newCondition,
              addCastExpression(conditions.expression, newElseExpr.getMajorType(), registry));
        } else if (leastRestrictive != elseType) {
          // Implicitly cast the else expression (to the then expression's major type)
          newElseExpr = addCastExpression(newElseExpr, conditions.expression.getMajorType(), registry);
        } else {
          // Reaching here would mean leastRestrictive equals both thenType and elseType,
          // which contradicts thenType != elseType checked above.
          assert false: "Incorrect least restrictive type computed, leastRestrictive: " +
              leastRestrictive.toString() + " thenType: " + thenType.toString() + " elseType: " + elseType;
        }
      }

      // Resolve NullExpression into TypedNullConstant by visiting all conditions
      // We need to do this because we want to give the correct MajorType to the Null constant
      List<LogicalExpression> allExpressions = Lists.newArrayList();
      allExpressions.add(conditions.expression);
      allExpressions.add(newElseExpr);

      boolean containsNullExpr = Iterables.any(allExpressions, new Predicate<LogicalExpression>() {
        @Override
        public boolean apply(LogicalExpression input) {
          return input instanceof NullExpression;
        }
      });

      if (containsNullExpr) {
        // Find the first branch whose minor type is not NULL; its type is given to the null branch.
        Optional<LogicalExpression> nonNullExpr = Iterables.tryFind(allExpressions,
          new Predicate<LogicalExpression>() {
            @Override
            public boolean apply(LogicalExpression input) {
              return !input.getMajorType().getMinorType().equals(TypeProtos.MinorType.NULL);
            }
          }
        );

        // If no such branch exists, the NullExpression branches are left as they are.
        if(nonNullExpr.isPresent()) {
          MajorType type = nonNullExpr.get().getMajorType();
          conditions = new IfExpression.IfCondition(conditions.condition, rewriteNullExpression(conditions.expression, type));

          newElseExpr = rewriteNullExpression(newElseExpr, type);
        }
      }

      // If the type of the IF expression is nullable, apply a convertToNullable*Holder function for "THEN"/"ELSE"
      // expressions whose type is not nullable.
      // A temporary IF expression is built here only to ask for its resulting data mode.
      if (IfExpression.newBuilder().setElse(newElseExpr).setIfCondition(conditions).build().getMajorType().getMode()
          == DataMode.OPTIONAL) {
          IfExpression.IfCondition condition = conditions;
          if (condition.expression.getMajorType().getMode() != DataMode.OPTIONAL) {
            conditions = new IfExpression.IfCondition(condition.condition, getConvertToNullableExpr(ImmutableList.of(condition.expression),
                                                      condition.expression.getMajorType().getMinorType(), registry));
         }

        if (newElseExpr.getMajorType().getMode() != DataMode.OPTIONAL) {
          newElseExpr = getConvertToNullableExpr(ImmutableList.of(newElseExpr),
              newElseExpr.getMajorType().getMinorType(), registry);
        }
      }

      // The final IF expression is run through the validator before being returned.
      return validateNewExpr(IfExpression.newBuilder().setElse(newElseExpr).setIfCondition(conditions).build());
    }

    private LogicalExpression getConvertToNullableExpr(List<LogicalExpression> args, MinorType minorType,
        FunctionImplementationRegistry registry) {
      String funcName = "convertToNullable" + minorType.toString();
      FunctionCall funcCall = new FunctionCall(funcName, args, ExpressionPosition.UNKNOWN);
      FunctionResolver resolver = FunctionResolverFactory.getResolver(funcCall);

      DrillFuncHolder matchedConvertToNullableFuncHolder = registry.findDrillFunction(resolver, funcCall);

      if (matchedConvertToNullableFuncHolder == null) {
        logFunctionResolutionError(errorCollector, funcCall);
        return NullExpression.INSTANCE;
      }

      return matchedConvertToNullableFuncHolder.getExpr(funcName, args, ExpressionPosition.UNKNOWN);
    }

    private LogicalExpression rewriteNullExpression(LogicalExpression expr, MajorType type) {
      if(expr instanceof NullExpression) {
        return new TypedNullConstant(type);
      } else {
        return expr;
      }
    }

    @Override
    public LogicalExpression visitSchemaPath(SchemaPath path, FunctionImplementationRegistry value) {
//      logger.debug("Visiting schema path {}", path);
      TypedFieldId tfId = batch.getValueVectorId(path);
      if (tfId == null) {
        logger.warn("Unable to find value vector of path {}, returning null instance.", path);
        return NullExpression.INSTANCE;
      } else {
        ValueVectorReadExpression e = new ValueVectorReadExpression(tfId);
        return e;
      }
    }

    @Override
    public LogicalExpression visitIntConstant(IntExpression intExpr, FunctionImplementationRegistry value) {
      return intExpr;
    }

    @Override
    public LogicalExpression visitFloatConstant(FloatExpression fExpr, FunctionImplementationRegistry value) {
      return fExpr;
    }

    @Override
    public LogicalExpression visitLongConstant(LongExpression intExpr, FunctionImplementationRegistry registry) {
      return intExpr;
    }

    @Override
    public LogicalExpression visitDateConstant(DateExpression intExpr, FunctionImplementationRegistry registry) {
      return intExpr;
    }

    @Override
    public LogicalExpression visitTimeConstant(TimeExpression intExpr, FunctionImplementationRegistry registry) {
      return intExpr;
    }

    @Override
    public LogicalExpression visitTimeStampConstant(TimeStampExpression intExpr, FunctionImplementationRegistry registry) {
      return intExpr;
    }

    @Override
    public LogicalExpression visitNullConstant(TypedNullConstant nullConstant, FunctionImplementationRegistry value) throws RuntimeException {
      return nullConstant;
    }

    @Override
    public LogicalExpression visitIntervalYearConstant(IntervalYearExpression intExpr, FunctionImplementationRegistry registry) {
      return intExpr;
    }

    @Override
    public LogicalExpression visitIntervalDayConstant(IntervalDayExpression intExpr, FunctionImplementationRegistry registry) {
      return intExpr;
    }

    @Override
    public LogicalExpression visitDecimal9Constant(Decimal9Expression decExpr, FunctionImplementationRegistry registry) {
      return decExpr;
    }

    @Override
    public LogicalExpression visitDecimal18Constant(Decimal18Expression decExpr, FunctionImplementationRegistry registry) {
      return decExpr;
    }

    @Override
    public LogicalExpression visitDecimal28Constant(Decimal28Expression decExpr, FunctionImplementationRegistry registry) {
      return decExpr;
    }

    @Override
    public LogicalExpression visitDecimal38Constant(Decimal38Expression decExpr, FunctionImplementationRegistry registry) {
      return decExpr;
    }

    @Override
    public LogicalExpression visitDoubleConstant(DoubleExpression dExpr, FunctionImplementationRegistry registry) {
      return dExpr;
    }

    @Override
    public LogicalExpression visitBooleanConstant(BooleanExpression e, FunctionImplementationRegistry registry) {
      return e;
    }

    @Override
    public LogicalExpression visitQuotedStringConstant(QuotedString e, FunctionImplementationRegistry registry) {
      return e;
    }

    @Override
    public LogicalExpression visitConvertExpression(ConvertExpression e, FunctionImplementationRegistry value) {
      String convertFunctionName = e.getConvertFunction() + e.getEncodingType();

      List<LogicalExpression> newArgs = Lists.newArrayList();
      newArgs.add(e.getInput())//input_expr

      FunctionCall fc = new FunctionCall(convertFunctionName, newArgs, e.getPosition());
      return fc.accept(this, value);
    }

    @Override
    public LogicalExpression visitCastExpression(CastExpression e, FunctionImplementationRegistry value){

      // if the cast is pointless, remove it.
      LogicalExpression input = e.getInput().accept(this,  value);

      MajorType newMajor = e.getMajorType();
      MinorType newMinor = input.getMajorType().getMinorType();

      if(castEqual(e.getPosition(), newMajor, input.getMajorType())) return input; // don't do pointless cast.

      if(newMinor == MinorType.LATE){
        // if the type still isn't fully bound, leave as cast expression.
        return new CastExpression(input, e.getMajorType(), e.getPosition());
      } else if (newMinor == MinorType.NULL) {
        // if input is a NULL expression, remove cast expression and return a TypedNullConstant directly. 
        return new TypedNullConstant(Types.optional(e.getMajorType().getMinorType()));
      } else {
        // if the type is fully bound, convert to functioncall and materialze the function.
        MajorType type = e.getMajorType();

        // Get the cast function name from the map
        String castFuncWithType = CastFunctions.getCastFunc(type.getMinorType());

        List<LogicalExpression> newArgs = Lists.newArrayList();
        newArgs.add(e.getInput())//input_expr

        //VarLen type
        if (!Types.isFixedWidthType(type)) {
          newArgs.add(new ValueExpressions.LongExpression(type.getWidth(), null));
        } else if (type.getMinorType().name().startsWith("DECIMAL")) {
            newArgs.add(new ValueExpressions.LongExpression(type.getPrecision(), null));
            newArgs.add(new ValueExpressions.LongExpression(type.getScale(), null));
        }
        FunctionCall fc = new FunctionCall(castFuncWithType, newArgs, e.getPosition());
        return fc.accept(this, value);
      }
    }

    private boolean castEqual(ExpressionPosition pos, MajorType from, MajorType to){
      if(!from.getMinorType().equals(to.getMinorType())) return false;
      switch(from.getMinorType()){
      case FLOAT4:
      case FLOAT8:
      case INT:
      case BIGINT:
      case BIT:
      case TINYINT:
      case SMALLINT:
      case UINT1:
      case UINT2:
      case UINT4:
      case UINT8:
      case TIME:
      case TIMESTAMP:
      case TIMESTAMPTZ:
      case DATE:
      case INTERVAL:
      case INTERVALDAY:
      case INTERVALYEAR:
        // nothing else matters.
        return true;
      case DECIMAL9:
      case DECIMAL18:
      case DECIMAL28DENSE:
      case DECIMAL28SPARSE:
      case DECIMAL38DENSE:
      case DECIMAL38SPARSE:
        if (to.getScale() == from.getScale() && to.getPrecision() == from.getPrecision()) {
          return true;
        }
        return false;

      case FIXED16CHAR:
      case FIXEDBINARY:
      case FIXEDCHAR:
        // width always matters
        this.errorCollector.addGeneralError(pos, "Casting fixed width types are not yet supported..");
        return false;

      case VAR16CHAR:
      case VARBINARY:
      case VARCHAR:
        if(to.getWidth() < from.getWidth() && to.getWidth() > 0){
          this.errorCollector.addGeneralError(pos, "Casting from a longer variable length type to a shorter variable length type is not currently supported.");
          return false;
        }else{
          return true;
        }

      default:
        errorCollector.addGeneralError(pos, String.format("Casting rules are unknown for type %s.", from));
        return false;
      }
    }
  }
}
TOP

Related Classes of org.apache.drill.exec.expr.ExpressionTreeMaterializer$MaterializeVisitor

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by Oracle Inc. Contact coftware#gmail.com.