Package com.twitter.elephantbird.util

Source Code of com.twitter.elephantbird.util.Protobufs

package com.twitter.elephantbird.util;

import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.List;
import java.util.Map;

import com.google.common.base.Function;
import com.google.common.base.Predicates;
import com.google.common.collect.Lists;
import com.google.protobuf.DescriptorProtos.DescriptorProto;
import com.google.protobuf.Descriptors.Descriptor;
import com.google.protobuf.Descriptors.DescriptorValidationException;
import com.google.protobuf.Descriptors.EnumValueDescriptor;
import com.google.protobuf.Descriptors.FieldDescriptor;
import com.google.protobuf.Descriptors.FieldDescriptor.Type;
import com.google.protobuf.Message.Builder;
import com.google.protobuf.ByteString;
import com.google.protobuf.CodedInputStream;
import com.google.protobuf.CodedOutputStream;
import com.google.protobuf.DescriptorProtos;
import com.google.protobuf.Descriptors;
import com.google.protobuf.DynamicMessage;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Message;
import com.google.protobuf.ProtocolMessageEnum;
import com.google.protobuf.UninitializedMessageException;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.Text;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class Protobufs {
  // Class-wide logger; the reflection and parsing helpers below use it to report failures.
  private static final Logger LOG = LoggerFactory.getLogger(Protobufs.class);

  // Fixed 16-byte marker sequence. It is not referenced anywhere else in this class.
  // NOTE(review): presumably written/scanned by block-oriented readers and writers elsewhere
  // to find a known-good position in a stream -- confirm against the callers.
  public static final byte[] KNOWN_GOOD_POSITION_MARKER = new byte[] { 0x29, (byte)0xd8, (byte)0xd5, 0x06, 0x58,
                                                                         (byte)0xcd, 0x4c, 0x29, (byte)0xb2,
                                                                         (byte)0xbc, 0x57, (byte)0x99, 0x21, 0x71,
                                                                         (byte)0xbd, (byte)0xff };

  // The newline character as a single byte, and the same byte wrapped in a one-element array.
  public static final byte NEWLINE_UTF8_BYTE = '\n';
  public static final byte[] NEWLINE_UTF8_BYTES = new byte[]{NEWLINE_UTF8_BYTE};

  // Sentinel translation value: a field whose (translated) name equals this string is dropped
  // by getFieldTransformerFor().
  public static final String IGNORE_KEY = "IGNORE";

  // Prefix of the Configuration key under which setClassConf()/getTypeRef() store the Protobuf
  // class; the name of the "generic" class is appended to form the full key.
  private static final String CLASS_CONF_PREFIX = "elephantbird.protobuf.class.for.";

  /**
   * Returns Protobuf class. The class name could be either normal name or
   * its canonical name (canonical name does not have a $ for inner classes).
   */
  public static Class<? extends Message> getProtobufClass(String protoClassName) {
    return getProtobufClass(null, protoClassName);
  }

  private static Class<? extends Message> getProtobufClass(Configuration conf, String protoClassName) {
    // Try both normal name and canonical name of the class.
    Class<?> protoClass = null;
    try {
      if (conf == null) {
        protoClass = Class.forName(protoClassName);
      } else {
        protoClass = conf.getClassByName(protoClassName);
      }
    } catch (ClassNotFoundException e) {
      // the class name might be canonical name.
      protoClass = getInnerProtobufClass(protoClassName);
    }

    return protoClass.asSubclass(Message.class);
  }

  /**
   * For a configured protoClass, should the message be dynamic or is it a pre-generated Message class? If protoClass is
   * null or set to DynamicMessage.class, then the configurer intends for a dynamically generated protobuf to be used.
   */
  public static boolean useDynamicProtoMessage(Class<?> protoClass) {
    return protoClass == null || protoClass.getCanonicalName().equals(DynamicMessage.class.getCanonicalName());
  }

  public static Class<? extends Message> getInnerProtobufClass(String canonicalClassName) {
    // is an inner class and is not visible from the outside.  We have to instantiate
    String parentClass = canonicalClassName.substring(0, canonicalClassName.lastIndexOf("."));
    String subclass = canonicalClassName.substring(canonicalClassName.lastIndexOf(".") + 1);
    return getInnerClass(parentClass, subclass);
  }

  public static Class<? extends Message> getInnerClass(String canonicalParentName, String subclassName) {
    try {
      Class<?> outerClass = Class.forName(canonicalParentName);
      for (Class<?> innerClass: outerClass.getDeclaredClasses()) {
        if (innerClass.getSimpleName().equals(subclassName)) {
          return innerClass.asSubclass(Message.class);
        }
      }
    } catch (ClassNotFoundException e) {
      LOG.error("Could not find class with parent " + canonicalParentName + " and inner class " + subclassName, e);
      throw new IllegalArgumentException(e);
    }
    return null;
  }

  public static Message.Builder getMessageBuilder(Class<? extends Message> protoClass) {
    try {
      Method newBuilder = protoClass.getMethod("newBuilder", new Class[] {});
      return (Message.Builder) newBuilder.invoke(null, new Object[] {});
    } catch (NoSuchMethodException e) {
      LOG.error("Could not find method newBuilder in class " + protoClass, e);
      throw new IllegalArgumentException(e);
    } catch (IllegalAccessException e) {
      LOG.error("Could not access method newBuilder in class " + protoClass, e);
      throw new IllegalArgumentException(e);
    } catch (InvocationTargetException e) {
      LOG.error("Error invoking method newBuilder in class " + protoClass, e);
      throw new IllegalArgumentException(e);
    }
  }

  public static Descriptor getMessageDescriptor(Class<? extends Message> protoClass) {
    try {
      Method getDescriptor = protoClass.getMethod("getDescriptor", new Class[] {});
      return (Descriptor)getDescriptor.invoke(null, new Object[] {});
    } catch (NoSuchMethodException e) {
      LOG.error("Could not find method getDescriptor in class " + protoClass, e);
      throw new IllegalArgumentException(e);
    } catch (IllegalAccessException e) {
      LOG.error("Could not access method getDescriptor in class " + protoClass, e);
      throw new IllegalArgumentException(e);
    } catch (InvocationTargetException e) {
      LOG.error("Error invoking method getDescriptor in class " + protoClass, e);
    }

    return null;
  }

  public static List<String> getMessageFieldNames(Class<? extends Message> protoClass) {
    return Lists.transform(getMessageDescriptor(protoClass).getFields(), new Function<FieldDescriptor, String>() {
      public String apply(FieldDescriptor f) {
        return f.getName();
      }
    });
  }

  /**
   * Returns a Message {@link Descriptor} for a dynamically generated
   * DescriptorProto.
   *
   * @param descProto
   * @throws DescriptorValidationException
   */
  public static Descriptor makeMessageDescriptor(DescriptorProto descProto)
                                      throws DescriptorValidationException {

    DescriptorProtos.FileDescriptorProto fileDescP =
      DescriptorProtos.FileDescriptorProto.newBuilder().addMessageType(descProto).build();

    Descriptors.FileDescriptor[] fileDescs = new Descriptors.FileDescriptor[0];
    Descriptors.FileDescriptor dynamicDescriptor = Descriptors.FileDescriptor.buildFrom(fileDescP, fileDescs);
    return dynamicDescriptor.findMessageTypeByName(descProto.getName());
  }

   // Translates from protobuf field names to other field names, stripping out any that have value IGNORE.
   // i.e. if you have a protobuf with field names a, b, c and your fieldNameTranslations map looks like
   // { "a" => "A", "c" => "IGNORE" }, then you'll end up with A, b
  public static List<String> getMessageFieldNames(Class<? extends Message> protoClass, Map<String, String> fieldNameTranslations) {
    return getMessageFieldNames(getMessageDescriptor(protoClass), fieldNameTranslations);
  }

   public static List<String> getMessageFieldNames(Descriptor descriptor, Map<String, String> fieldNameTranslations) {
     Function<FieldDescriptor, String> fieldTransformer = getFieldTransformerFor(fieldNameTranslations);
     return ListHelper.filter(Lists.transform(descriptor.getFields(), fieldTransformer), Predicates.<String>notNull());
   }

   public static Function<FieldDescriptor, String> getFieldTransformerFor(final Map<String, String> fieldNameTranslations) {
     return new Function<FieldDescriptor, String>() {
       public String apply(FieldDescriptor f) {
         String name = f.getName();
         if (fieldNameTranslations != null && fieldNameTranslations.containsKey(name)) {
           name = fieldNameTranslations.get(name);
         }
         if (IGNORE_KEY.equals(name)) {
           name = null;
         }
         return name;
       }
     };
   }


  @SuppressWarnings("unchecked")
  public static <M extends Message> M parseFrom(Class<M> protoClass, byte[] messageBytes) {
    try {
      Method parseFrom = protoClass.getMethod("parseFrom", new Class[] { byte[].class });
      return (M)parseFrom.invoke(null, new Object[] { messageBytes });
    } catch (NoSuchMethodException e) {
      LOG.error("Could not find method parseFrom in class " + protoClass, e);
      throw new IllegalArgumentException(e);
    } catch (IllegalAccessException e) {
      LOG.error("Could not access method parseFrom in class " + protoClass, e);
      throw new IllegalArgumentException(e);
    } catch (InvocationTargetException e) {
      LOG.error("Error invoking method parseFrom in class " + protoClass, e);
    }

    return null;
  }

  public static DynamicMessage parseDynamicFrom(Class<? extends Message> protoClass, byte[] messageBytes) {
    try {
      return DynamicMessage.parseFrom(getMessageDescriptor(protoClass), messageBytes);
    } catch (InvalidProtocolBufferException e) {
      LOG.error("Protocol buffer parsing error in parseDynamicFrom", e);
    } catch(UninitializedMessageException ume) {
      LOG.error("Uninitialized Message error in parseDynamicFrom " + protoClass.getName(), ume);
    }

    return null;
  }

  public static Message instantiateFromClassName(String canonicalClassName) {
    Class<? extends Message> protoClass = getInnerProtobufClass(canonicalClassName);
    Message.Builder builder = getMessageBuilder(protoClass);
    return builder.build();
  }

  public static <M extends Message> M instantiateFromClass(Class<M> protoClass) {
    return Protobufs.parseFrom(protoClass, new byte[] {});
  }

  public static Message addField(Message m, String name, Object value) {
    Message.Builder builder = m.toBuilder();
    setFieldByName(builder, name, value);
    return builder.build();
  }

  public static void setFieldByName(Message.Builder builder, String name, Object value) {
    FieldDescriptor fieldDescriptor = builder.getDescriptorForType().findFieldByName(name);
    if (value == null) {
      builder.clearField(fieldDescriptor);
    } else {
      builder.setField(fieldDescriptor, value);
    }
  }

  public static boolean isFieldSetByName(Message message, String name) {
    return message.hasField(message.getDescriptorForType().findFieldByName(name));
  }

  public static Object getFieldByName(Message message, String name) {
    return message.getField(message.getDescriptorForType().findFieldByName(name));
  }

  public static boolean hasFieldByName(Message message, String name) {
    return message.getDescriptorForType().findFieldByName(name) != null;
  }

  public static Type getTypeByName(Message message, String name) {
    return message.getDescriptorForType().findFieldByName(name).getType();
  }

  /**
   * Returns typeref for a Protobuf class
   */
  public static<M extends Message> TypeRef<M> getTypeRef(String protoClassName) {
    return new TypeRef<M>(getProtobufClass(protoClassName)){};
  }

  /**
   * Returns TypeRef for the Protobuf class that was set using setClass(jobConf);
   */
  public static<M extends Message> TypeRef<M> getTypeRef(Configuration jobConf, Class<?> genericClass) {
    String className = jobConf.get(CLASS_CONF_PREFIX + genericClass.getName());
    if (className == null) {
      throw new RuntimeException(CLASS_CONF_PREFIX + genericClass.getName() + " is not set");
    }
    return new TypeRef<M>(getProtobufClass(jobConf, className)){};
  }

  public static void setClassConf(Configuration jobConf, Class<?> genericClass,
      Class<? extends Message> protoClass) {
    HadoopUtils.setClassConf(jobConf,
           CLASS_CONF_PREFIX + genericClass.getName(),
           protoClass);
  }

  public static Text toText(Message message) {
    return new Text(message.toByteArray());
  }

  public static <B extends Message.Builder> B mergeFromText(B builder, Text bytes)
                                    throws InvalidProtocolBufferException {
    @SuppressWarnings("unchecked")
    B b = (B) builder.mergeFrom(bytes.getBytes());
    return b;
  }

  /**
   * Serializes a single field. All the native fields are serialized using
   * "NoTag" methods in {@link CodedOutputStream}
   * e.g. <code>writeInt32NoTag()</code>. The field index is not written.
   *
   * @param output
   * @param fd
   * @param value
   * @throws IOException
   */
  public static void writeFieldNoTag(CodedOutputStream   output,
                                     FieldDescriptor     fd,
                                     Object              value)
                                     throws IOException {
    if (value == null) {
      return;
    }

    if (fd.isRepeated()) {
      @SuppressWarnings("unchecked")
      List<Object> values = (List<Object>) value;
      for(Object obj : values) {
        writeSingleFieldNoTag(output, fd, obj);
      }
    } else {
      writeSingleFieldNoTag(output, fd, value);
    }
  }

  private static void writeSingleFieldNoTag(CodedOutputStream   output,
                                            FieldDescriptor     fd,
                                            Object              value)
                                            throws IOException {
    switch (fd.getType()) {
    case DOUBLE:
      output.writeDoubleNoTag((Double) value);    break;
    case FLOAT:
      output.writeFloatNoTag((Float) value);      break;
    case INT64:
    case UINT64:
      output.writeInt64NoTag((Long) value);       break;
    case INT32:
      output.writeInt32NoTag((Integer) value);    break;
    case FIXED64:
      output.writeFixed64NoTag((Long) value);     break;
    case FIXED32:
      output.writeFixed32NoTag((Integer) value)break;
    case BOOL:
      output.writeBoolNoTag((Boolean) value);     break;
    case STRING:
      output.writeStringNoTag((String) value);    break;
    case GROUP:
    case MESSAGE:
      output.writeMessageNoTag((Message) value)break;
    case BYTES:
      output.writeBytesNoTag((ByteString) value); break;
    case UINT32:
      output.writeUInt32NoTag((Integer) value);   break;
    case ENUM:
      output.writeEnumNoTag(((ProtocolMessageEnum) value).getNumber()); break;
    case SFIXED32:
      output.writeSFixed32NoTag((Integer) value); break;
    case SFIXED64:
      output.writeSFixed64NoTag((Long) value);    break;
    case SINT32:
      output.writeSInt32NoTag((Integer) value);   break;
    case SINT64:
      output.writeSInt64NoTag((Integer) value);   break;

    default:
      throw new IllegalArgumentException("Unknown type " + fd.getType()
                                         + " for " + fd.getFullName());
    }
  }

  /**
   *  Wrapper around {@link #readFieldNoTag(CodedInputStream, FieldDescriptor, Builder)}. <br>
   *  same as <br> <code>
   *  builder.setField(fd, readFieldNoTag(input, fd, builder));  </code>
   */
  public static void setFieldValue(CodedInputStream input,
                              FieldDescriptor  fd,
                              Builder          builder)
                              throws IOException {
    builder.setField(fd, readFieldNoTag(input, fd, builder));
  }

  /**
   * Deserializer for protobuf objects written with
   * {@link #writeFieldNoTag(CodedOutputStream, FieldDescriptor, Object)}.
   *
   * @param input
   * @param fd
   * @param enclosingBuilder required to create a builder when the field
   *                         is a message
   * @throws IOException
   */
  public static Object readFieldNoTag(CodedInputStream input,
                                      FieldDescriptor  fd,
                                      Builder          enclosingBuilder)
                                      throws IOException {
    if (!fd.isRepeated()) {
      return readSingleFieldNoTag(input, fd, enclosingBuilder);
    }

    // repeated field

    List<Object> values = Lists.newArrayList();
    while (!input.isAtEnd()) {
      values.add(readSingleFieldNoTag(input, fd, enclosingBuilder));
    }
    return values;
  }

  private static Object readSingleFieldNoTag(CodedInputStream input,
                                             FieldDescriptor  fd,
                                             Builder          enclosingBuilder)
                                             throws IOException {
    switch (fd.getType()) {
    case DOUBLE:
      return input.readDouble();
    case FLOAT:
      return input.readFloat();
    case INT64:
    case UINT64:
      return input.readInt64();
    case INT32:
      return input.readInt32();
    case FIXED64:
      return input.readFixed64();
    case FIXED32:
      return input.readFixed32();
    case BOOL:
      return input.readBool();
    case STRING:
      return input.readString();
    case GROUP:
    case MESSAGE:
      Builder fieldBuilder = enclosingBuilder.newBuilderForField(fd);
      input.readMessage(fieldBuilder, null);
      return fieldBuilder.build();
    case BYTES:
      return input.readBytes();
    case UINT32:
      return input.readUInt32();
    case ENUM:
      EnumValueDescriptor eVal = fd.getEnumType().findValueByNumber(input.readEnum());
      // ideally if a given enum does not exist, we should search
      // unknown fields. but we don't have access to that here. return default
      return eVal != null ? eVal : fd.getDefaultValue();
    case SFIXED32:
      return input.readSFixed32();
    case SFIXED64:
      return input.readSFixed64();
    case SINT32:
      return input.readSInt32();
    case SINT64:
      return input.readSInt64();

    default:
      throw new IllegalArgumentException("Unknown type " + fd.getType()
          + " for " + fd.getFullName());
    }
  }

}
TOP

Related Classes of com.twitter.elephantbird.util.Protobufs

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by Oracle Inc. Contact coftware#gmail.com.