Package com.nexr.rhive.hive.udf

Source Code of com.nexr.rhive.hive.udf.RUDF

/**
* Copyright 2011 NexR
*  
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*      http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.nexr.rhive.hive.udf;

import java.io.ByteArrayOutputStream;
import java.io.PrintStream;
import java.util.HashSet;
import java.util.Set;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters.Converter;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.rosuda.REngine.REXP;
import org.rosuda.REngine.REXPDouble;
import org.rosuda.REngine.REXPInteger;
import org.rosuda.REngine.REXPString;
import org.rosuda.REngine.Rserve.RConnection;

import com.nexr.rhive.hive.HiveVariations;
import com.nexr.rhive.util.EnvUtils;

/**
 * RUDF: a generic Hive UDF that loads an exported R function and evaluates it
 * against the row's arguments through a local Rserve connection, converting
 * the R result back to a Hive writable type.
 */
@Description(name = "R", value = "_FUNC_(export-name,arg1,arg2,...,return-type) - Returns the result of an R scalar function")
public class RUDF extends GenericUDF {
  private static final Configuration conf = new Configuration();
  private static final Set<String> funcSet = new HashSet<String>();
  private static final String NULL = "";
  private static final int STRING_TYPE = 1;
  private static final int NUMBER_TYPE = 0;
  private static RConnection rconnection;

  private transient Converter[] converters;
  private transient int[] types;

  @Override
  public Object evaluate(DeferredObject[] arguments) throws HiveException {

    String function_name = converters[0].convert(arguments[0].get()).toString();

    loadRObjects(function_name);

    StringBuilder argument = new StringBuilder();

    for (int i = 1; i < (arguments.length - 1); i++) {

      Object value = converters[i].convert(arguments[i].get());

      if (value == null) {
        argument.append("NULL");
      } else {
        if (types[i] == STRING_TYPE) {
          // quote string values so they form valid R string literals
          argument.append("\"" + value + "\"");
        } else {
          argument.append(value);
        }
      }

      if (i < (arguments.length - 2))
        argument.append(",");
    }

    REXP rdata = null;
    try {
      rdata = getConnection().eval(function_name + "(" + argument.toString() + ")");
    } catch (Exception e) {
      ByteArrayOutputStream output = new ByteArrayOutputStream();
      e.printStackTrace(new PrintStream(output));
      throw new HiveException(new String(output.toByteArray())
          + " -- failed to eval: " + function_name + "("
          + argument.toString() + ")");
    }

    if (rdata != null) {
      try {
        if (rdata instanceof REXPInteger) {
          return new IntWritable(rdata.asInteger());
        } else if (rdata instanceof REXPString) {
          return new Text(rdata.asString());
        } else if (rdata instanceof REXPDouble) {
          return new DoubleWritable(rdata.asDouble());
        } else {
          throw new HiveException(
              "only integer, string and double return types are supported");
        }
      } catch (Exception e) {
        ByteArrayOutputStream output = new ByteArrayOutputStream();
        e.printStackTrace(new PrintStream(output));
        throw new HiveException(new String(output.toByteArray()));
      }
    }

    return null;
  }

  @Override
  public String getDisplayString(String[] children) {
    StringBuilder sb = new StringBuilder();
    sb.append("Rfunction(");
    for (int i = 0; i < children.length; i++) {
      sb.append(children[i]);
      if (i + 1 != children.length) {
        sb.append(",");
      }
    }
    sb.append(")");
    return sb.toString();
  }

  @Override
  public ObjectInspector initialize(ObjectInspector[] arguments)
      throws UDFArgumentException {

    GenericUDFUtils.ReturnObjectInspectorResolver returnOIResolver;
    returnOIResolver = new GenericUDFUtils.ReturnObjectInspectorResolver(
        true);

    for (int i = 0; i < arguments.length; i++) {
      if (!returnOIResolver.update(arguments[i])) {
        throw new UDFArgumentTypeException(i, "Argument type \""
            + arguments[i].getTypeName()
            + "\" is different from preceding arguments. "
            + "Previous type was \""
            + arguments[i - 1].getTypeName() + "\"");
      }
    }

    converters = new Converter[arguments.length];
    types = new int[arguments.length];

    ObjectInspector returnOI = returnOIResolver.get();
    if (returnOI == null) {
      returnOI = PrimitiveObjectInspectorFactory
          .getPrimitiveJavaObjectInspector(PrimitiveCategory.STRING);
    }
    for (int i = 0; i < arguments.length; i++) {
      converters[i] = ObjectInspectorConverters.getConverter(
          arguments[i], returnOI);
      if (arguments[i].getCategory() == Category.PRIMITIVE
          && ((PrimitiveObjectInspector) arguments[i])
              .getPrimitiveCategory() == PrimitiveCategory.STRING)
        types[i] = STRING_TYPE;
      else
        types[i] = NUMBER_TYPE;
    }

    String typeName = arguments[arguments.length - 1].getTypeName();

    try {
      if (typeName.equals(HiveVariations.getFieldValue(HiveVariations.serdeConstants, "INT_TYPE_NAME"))) {
        return PrimitiveObjectInspectorFactory.writableIntObjectInspector;
      } else if (typeName.equals(HiveVariations.getFieldValue(HiveVariations.serdeConstants, "DOUBLE_TYPE_NAME"))) {
        return PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
      } else if (typeName.equals(HiveVariations.getFieldValue(HiveVariations.serdeConstants, "STRING_TYPE_NAME"))) {
        return PrimitiveObjectInspectorFactory.writableStringObjectInspector;
      } else
        throw new IllegalArgumentException("unsupported return type: " + typeName);
    } catch (Exception e) {
      throw new UDFArgumentException(e);
    }
  }

  private void loadExportedRScript(String export_name) throws HiveException {
    if (!funcSet.contains(export_name)) {

      try {
        REXP rhive_data = getConnection().eval(
            "Sys.getenv('RHIVE_DATA')");
        String srhive_data = null;

        if (rhive_data != null) {
          srhive_data = rhive_data.asString();
        }

        if (srhive_data == null || srhive_data == ""
            || srhive_data.length() == 0) {

          getConnection().eval(
              "load(file=paste('/tmp','/" + export_name + ".Rdata',sep=''))");
        } else {

          getConnection().eval(
              "load(file=paste(Sys.getenv('RHIVE_DATA'),'/" + export_name + ".Rdata',sep=''))");
        }

      } catch (Exception e) {

        ByteArrayOutputStream output = new ByteArrayOutputStream();
        e.printStackTrace(new PrintStream(output));
        throw new HiveException(new String(output.toByteArray()));
      }

      funcSet.add(export_name);
    }
  }

  private void loadRObjects(String name) throws HiveException {
    if (!funcSet.contains(name)) {
      try {
        FileSystem fs = FileSystem.get(conf);
       
        boolean srcDel = false;
        Path src = UDFUtils.getPath(name);
        Path dst = getLocalPath(name);
        fs.copyToLocalFile(srcDel, src, dst);
       
        String dataFilePath = dst.toString();
        getConnection().eval(String.format("load(file=\"%s\")", dataFilePath));

      } catch (Exception e) {
        throw new HiveException(e);
      }
       
      funcSet.add(name);
    }
  }

  private Path getLocalPath(String name) {
    String tempDirectory = EnvUtils.getTempDirectory();
    return new Path(tempDirectory, UDFUtils.getFileName(name));
  }
 
  private RConnection getConnection() throws UDFArgumentException {
    if (rconnection == null || !rconnection.isConnected()) {
      try {
        rconnection = new RConnection("127.0.0.1");
      } catch (Exception e) {
        ByteArrayOutputStream output = new ByteArrayOutputStream();
        e.printStackTrace(new PrintStream(output));
        throw new UDFArgumentException(new String(output.toByteArray()));
      }
    }

    return rconnection;
  }
}
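
For reference, the round trip that evaluate() performs can be reproduced outside Hive with the Rserve Java client directly. The sketch below is illustrative only: it assumes an Rserve daemon listening on 127.0.0.1 (the same default getConnection() uses), defines a toy R function in place of the exported .Rdata object that loadRObjects() would copy out of HDFS, builds the call string the way evaluate() does (quoting string arguments, leaving numbers bare), and maps the REXP result back to a Java value. The class name RUDFEvalSketch and the R function scale2 are hypothetical.

import org.rosuda.REngine.REXP;
import org.rosuda.REngine.REXPDouble;
import org.rosuda.REngine.REXPInteger;
import org.rosuda.REngine.REXPString;
import org.rosuda.REngine.Rserve.RConnection;

public class RUDFEvalSketch {
  public static void main(String[] args) throws Exception {
    // Assumes a local Rserve daemon, as RUDF.getConnection() does.
    RConnection rconnection = new RConnection("127.0.0.1");
    try {
      // Stand-in for the exported function that loadRObjects() would load
      // from an .Rdata file copied out of HDFS.
      rconnection.voidEval("scale2 <- function(x, label) paste(label, x * 2)");

      // Build the call string the same way evaluate() does:
      // string arguments are quoted, numeric arguments are not.
      String functionName = "scale2";
      String argument = "21" + "," + "\"doubled:\"";
      REXP rdata = rconnection.eval(functionName + "(" + argument + ")");

      // Map the R result back to a Java value, mirroring the REXP checks
      // in evaluate().
      if (rdata instanceof REXPInteger) {
        System.out.println(rdata.asInteger());
      } else if (rdata instanceof REXPString) {
        System.out.println(rdata.asString());   // prints "doubled: 42"
      } else if (rdata instanceof REXPDouble) {
        System.out.println(rdata.asDouble());
      }
    } finally {
      rconnection.close();
    }
  }
}

Once the jar is on Hive's classpath, CREATE TEMPORARY FUNCTION R AS 'com.nexr.rhive.hive.udf.RUDF' registers the UDF so the same evaluation runs per row, with the final argument naming the return type that initialize() maps to a writable object inspector.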