Package org.infinispan.protostream.impl

Source Code of org.infinispan.protostream.impl.ProtoStreamReaderImpl

package org.infinispan.protostream.impl;

import com.google.protobuf.ByteString;
import com.google.protobuf.CodedInputStream;
import com.google.protobuf.Descriptors;
import com.google.protobuf.MessageLite;
import com.google.protobuf.Parser;
import com.google.protobuf.ProtocolMessageEnum;
import org.infinispan.protostream.BaseMarshaller;
import org.infinispan.protostream.EnumMarshaller;
import org.infinispan.protostream.Message;
import org.infinispan.protostream.MessageMarshaller;
import org.infinispan.protostream.RawProtobufMarshaller;
import org.infinispan.protostream.SerializationContext;

import java.io.IOException;
import java.lang.reflect.Array;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Collection;
import java.util.EnumSet;
import java.util.List;

/**
* @author anistor@redhat.com
*/
public final class ProtoStreamReaderImpl implements MessageMarshaller.ProtoStreamReader {

   private final SerializationContext ctx;

   private static final EnumSet<Descriptors.FieldDescriptor.Type> primitiveTypes = EnumSet.of(
         Descriptors.FieldDescriptor.Type.DOUBLE,
         Descriptors.FieldDescriptor.Type.FLOAT,
         Descriptors.FieldDescriptor.Type.INT64,
         Descriptors.FieldDescriptor.Type.UINT64,
         Descriptors.FieldDescriptor.Type.INT32,
         Descriptors.FieldDescriptor.Type.FIXED64,
         Descriptors.FieldDescriptor.Type.FIXED32,
         Descriptors.FieldDescriptor.Type.BOOL,
         Descriptors.FieldDescriptor.Type.STRING,
         Descriptors.FieldDescriptor.Type.BYTES,
         Descriptors.FieldDescriptor.Type.UINT32,
         Descriptors.FieldDescriptor.Type.SFIXED32,
         Descriptors.FieldDescriptor.Type.SFIXED64,
         Descriptors.FieldDescriptor.Type.SINT32,
         Descriptors.FieldDescriptor.Type.SINT64
   );

   private ReadMessageContext messageContext;

   public ProtoStreamReaderImpl(SerializationContext ctx) {
      this.ctx = ctx;
   }

   public <A> A read(CodedInputStream in, Class<A> clazz) throws IOException {
      if (MessageLite.class.isAssignableFrom(clazz)) {
         try {
            Parser<A> parser = (Parser) clazz.getDeclaredField("PARSER").get(null);
            return parser.parseFrom(messageContext.in);
         } catch (NoSuchFieldException e) {
            throw new IOException(e);
         } catch (IllegalAccessException e) {
            throw new IOException(e);
         }
      }

      BaseMarshaller<A> marshaller = ctx.getMarshaller(clazz);
      Descriptors.Descriptor messageDescriptor = ctx.getMessageDescriptor(marshaller.getTypeName());
      resetContext();
      enterContext(null, messageDescriptor, in);
      A a = marshaller instanceof MessageMarshaller ? ((MessageMarshaller<A>) marshaller).readFrom(this) : ((RawProtobufMarshaller<A>) marshaller).readFrom(ctx, in);
      exitContext();
      return a;
   }

   public <A> A read(CodedInputStream in, MessageMarshaller<A> marshaller) throws IOException {
      Descriptors.Descriptor messageDescriptor = ctx.getMessageDescriptor(marshaller.getTypeName());
      resetContext();
      enterContext(null, messageDescriptor, in);
      A a = marshaller.readFrom(this);
      exitContext();
      return a;
   }

   private void resetContext() {
      messageContext = null;
   }

   private void enterContext(String fieldName, Descriptors.Descriptor messageDescriptor, CodedInputStream in) {
      messageContext = new ReadMessageContext(messageContext, fieldName, messageDescriptor, in);
   }

   private void exitContext() {
      //todo on context exit need to validate that all required fields were seen in the stream (even if not read from api)

      //todo check now that we have seen in the stream all required fields (ie. the app that wrote the message did not break the protocol)

      // check that all required fields were read
      for (Descriptors.FieldDescriptor fd : messageContext.getFieldDescriptors().values()) {
         if (fd.isRequired() && !messageContext.getSeenFields().contains(fd.getNumber())) {
            throw new IllegalStateException("Field " + fd.getName() + " is declared as required but was never read by the MessageMarshaller");
         }
      }
      if (messageContext.getFieldDescriptors().size() != messageContext.getSeenFields().size()) {
         //todo log that not all declared fields were processed by the marshaller
      }
      messageContext = messageContext.getParentContext();
   }

   private Object readPrimitive(String fieldName, Descriptors.FieldDescriptor.JavaType javaType) throws IOException {
      Descriptors.FieldDescriptor fd = messageContext.getFieldByName(fieldName);
      if (fd.getType() == Descriptors.FieldDescriptor.Type.ENUM
            || fd.getType() == Descriptors.FieldDescriptor.Type.GROUP
            || fd.getType() == Descriptors.FieldDescriptor.Type.MESSAGE) {
         throw new IllegalArgumentException("Declared field type is not a primitive : " + fd.getName());
      }
      if (fd.getJavaType() != javaType) {
         throw new IllegalArgumentException("Declared field type is not of the expected type : " + fd.getName());
      }
      checkNonRepeatedField(fd);
      int expectedTag = WireFormat.makeTag(fd.getNumber(), fd.getLiteType().getWireType());
      Object o = messageContext.unknownFieldSet.consumeTag(expectedTag);
      if (o != null) {
         if (javaType == Descriptors.FieldDescriptor.JavaType.STRING) {
            o = ((ByteString) o).toStringUtf8();
         }
         return o;
      }

      while (true) {
         int tag = messageContext.in.readTag();
         if (tag == 0) {
            break;
         }
         if (tag == expectedTag) {
            switch (fd.getType()) {
               case DOUBLE:
                  return messageContext.in.readDouble();
               case FLOAT:
                  return messageContext.in.readFloat();
               case BOOL:
                  return messageContext.in.readBool();
               case STRING:
                  return messageContext.in.readString();
               case BYTES:
                  return messageContext.in.readBytes().toByteArray();
               case INT32:
                  return messageContext.in.readInt32();
               case SFIXED32:
                  return messageContext.in.readSFixed32();
               case FIXED32:
                  return messageContext.in.readFixed32();
               case UINT32:
                  return messageContext.in.readUInt32();
               case SINT32:
                  return messageContext.in.readSInt32();
               case INT64:
                  return messageContext.in.readInt64();
               case UINT64:
                  return messageContext.in.readUInt64();
               case FIXED64:
                  return messageContext.in.readFixed64();
               case SFIXED64:
                  return messageContext.in.readSFixed64();
               case SINT64:
                  return messageContext.in.readSInt64();
               default:
                  throw new IOException("Unexpected field type : " + fd.getType());
            }
         }
         messageContext.unknownFieldSet.mergeFieldFrom(tag, messageContext.in);
      }

      if (fd.isRequired()) {
         throw new IllegalStateException("Field " + fd.getName() + " is required but is not present in the stream");
      }

      if (fd.hasDefaultValue()) {
         return fd.getDefaultValue();
      }

      return null;
   }

   @Override
   public Integer readInt(String fieldName) throws IOException {
      Object o = readPrimitive(fieldName, Descriptors.FieldDescriptor.JavaType.INT);
      if (o == null) return null;
      return o instanceof Integer ? (Integer) o : ((Number) o).intValue()//todo [anistor] hack!
   }

   @Override
   public Long readLong(String fieldName) throws IOException {
      return (Long) readPrimitive(fieldName, Descriptors.FieldDescriptor.JavaType.LONG);
   }

   @Override
   public Float readFloat(String fieldName) throws IOException {
      return (Float) readPrimitive(fieldName, Descriptors.FieldDescriptor.JavaType.FLOAT);
   }

   @Override
   public Double readDouble(String fieldName) throws IOException {
      return (Double) readPrimitive(fieldName, Descriptors.FieldDescriptor.JavaType.DOUBLE);
   }

   @Override
   public Boolean readBoolean(String fieldName) throws IOException {
      return (Boolean) readPrimitive(fieldName, Descriptors.FieldDescriptor.JavaType.BOOLEAN);
   }

   @Override
   public String readString(String fieldName) throws IOException {
      return (String) readPrimitive(fieldName, Descriptors.FieldDescriptor.JavaType.STRING);
   }

   @Override
   public byte[] readBytes(String fieldName) throws IOException {
      return (byte[]) readPrimitive(fieldName, Descriptors.FieldDescriptor.JavaType.BYTE_STRING);
   }

   @Override
   public <A> A readObject(String fieldName, Class<? extends A> clazz) throws IOException {
      Descriptors.FieldDescriptor fd = messageContext.getFieldByName(fieldName);
      checkNonRepeatedField(fd);

      if (fd.getType() == Descriptors.FieldDescriptor.Type.ENUM) {
         if (ProtocolMessageEnum.class.isAssignableFrom(clazz)) {
            return (A) readProtocolMessageEnum(fd, (Class<ProtocolMessageEnum>) clazz);
         } else {
            return (A) readEnum(fd, (Class<Enum>) clazz);
         }
      }

      //todo validate type is compatible with readObject
      int expectedTag = WireFormat.makeTag(fd.getNumber(), fd.getLiteType().getWireType());
      Object o = messageContext.unknownFieldSet.consumeTag(expectedTag);
      if (o != null) {
         ByteString byteString = (ByteString) o;
         CodedInputStream codedInputStream = byteString.newCodedInput();
         return readObject(fd, clazz, codedInputStream, byteString.size());
      }

      while (true) {
         int tag = messageContext.in.readTag();
         if (tag == 0) {
            break;
         }
         if (tag == expectedTag) {
            return readObject(fd, clazz, messageContext.in, -1);
         }
         messageContext.unknownFieldSet.mergeFieldFrom(tag, messageContext.in);
      }

      return null;
   }

   private <A extends ProtocolMessageEnum> A readProtocolMessageEnum(Descriptors.FieldDescriptor fd, Class<A> clazz) throws IOException {
      assert fd.getLiteType().getWireType() == WireFormat.WIRETYPE_VARINT;

      int expectedTag = WireFormat.makeTag(fd.getNumber(), WireFormat.WIRETYPE_VARINT);
      int enumValue;
      Object o = messageContext.unknownFieldSet.consumeTag(expectedTag);
      if (o != null) {
         enumValue = ((Long) o).intValue();      //todo why is this a Long and not an Integer?
      } else {
         while (true) {
            int tag = messageContext.in.readTag();
            if (tag == 0) {
               return null;
            }
            if (tag == expectedTag) {
               enumValue = messageContext.in.readEnum();
               break;
            }
            messageContext.unknownFieldSet.mergeFieldFrom(tag, messageContext.in);
         }
      }

      A decoded = decodeProtocolMessageEnum(clazz, enumValue);

      // the enum value was not recognized by the decoder so rather than discarding it we add it to the unknown
      if (decoded == null) {
         messageContext.unknownFieldSet.mergeVarintField(expectedTag, enumValue);
      }

      return decoded;
   }

   private <A extends ProtocolMessageEnum> A decodeProtocolMessageEnum(Class<A> clazz, int enumValue) throws IOException {
      try {
         Method valueOf = clazz.getDeclaredMethod("valueOf", int.class);
         return (A) valueOf.invoke(null, enumValue);
      } catch (NoSuchMethodException e) {
         throw new IOException(e);
      } catch (IllegalAccessException e) {
         throw new IOException(e);
      } catch (InvocationTargetException e) {
         throw new IOException(e);
      }
   }

   private <A extends Enum<A>> A readEnum(Descriptors.FieldDescriptor fd, Class<A> clazz) throws IOException {
      assert fd.getLiteType().getWireType() == WireFormat.WIRETYPE_VARINT;

      int expectedTag = WireFormat.makeTag(fd.getNumber(), WireFormat.WIRETYPE_VARINT);
      int enumValue;
      Object o = messageContext.unknownFieldSet.consumeTag(expectedTag);
      if (o != null) {
         enumValue = ((Long) o).intValue();      //todo why is this a Long and not an Integer?
      } else {
         while (true) {
            int tag = messageContext.in.readTag();
            if (tag == 0) {
               return null;
            }
            if (tag == expectedTag) {
               enumValue = messageContext.in.readEnum();
               break;
            }
            messageContext.unknownFieldSet.mergeFieldFrom(tag, messageContext.in);
         }
      }

      EnumMarshaller<A> enumMarshaller = (EnumMarshaller<A>) ctx.getMarshaller(clazz);
      A decoded = enumMarshaller.decode(enumValue);

      // the enum value was not recognized by the decoder so rather than discarding it we add it to the unknown
      if (decoded == null) {
         messageContext.unknownFieldSet.mergeVarintField(expectedTag, enumValue);
      }

      return decoded;
   }

   private <A> A readObject(Descriptors.FieldDescriptor fd, Class<A> clazz, CodedInputStream in, int length) throws IOException {
      BaseMarshaller<A> marshaller = ctx.getMarshaller(clazz);
      enterContext(fd.getName(), fd.getMessageType(), in);
      A a;
      if (fd.getType() == Descriptors.FieldDescriptor.Type.GROUP) {
         a = unmarshall(marshaller, in);
         in.checkLastTagWas(WireFormat.makeTag(fd.getNumber(), WireFormat.WIRETYPE_END_GROUP));
      } else if (fd.getType() == Descriptors.FieldDescriptor.Type.MESSAGE) {
         if (length < 0) {
            length = in.readRawVarint32();
         }
         int oldLimit = in.pushLimit(length);
         a = unmarshall(marshaller, in);
         in.checkLastTagWas(0);
         in.popLimit(oldLimit);
      } else {
         throw new IllegalArgumentException("Declared field type is not a message or an enum : " + fd.getName());
      }
      exitContext();
      return a;
   }

   private <A> A unmarshall(BaseMarshaller<A> marshaller, CodedInputStream in) throws IOException {
      A a = marshaller instanceof MessageMarshaller ? ((MessageMarshaller<A>) marshaller).readFrom(this) : ((RawProtobufMarshaller<A>) marshaller).readFrom(ctx, in);
      messageContext.unknownFieldSet.mergeFrom(messageContext.in);
      if (a instanceof Message && !messageContext.unknownFieldSet.isEmpty()) {
         ((Message) a).setUnknownFieldSet(messageContext.unknownFieldSet);
      }
      return a;
   }

   @Override
   public <A, C extends Collection<? super A>> C readCollection(String fieldName, C collection, Class<? extends A> clazz) throws IOException {
      Descriptors.FieldDescriptor fd = messageContext.getFieldByName(fieldName);
      checkRepeatedField(fd);

      if (primitiveTypes.contains(fd.getType())) {
         readPrimitiveCollection(fd, (Collection<Object>) collection, clazz);
         return collection;
      }

      //todo validate type is compatible with readObject
      int expectedTag = WireFormat.makeTag(fd.getNumber(), fd.getLiteType().getWireType());

      while (true) {
         Object o = messageContext.unknownFieldSet.consumeTag(expectedTag);
         if (o == null) {
            break;
         }
         ByteString byteString = (ByteString) o;
         CodedInputStream codedInputStream = byteString.newCodedInput();
         collection.add(readObject(fd, clazz, codedInputStream, byteString.size()));
      }

      while (true) {
         int tag = messageContext.in.readTag();
         if (tag == 0) {
            break;
         }
         if (tag == expectedTag) {
            collection.add(readObject(fd, clazz, messageContext.in, -1));
         } else {
            messageContext.unknownFieldSet.mergeFieldFrom(tag, messageContext.in);
         }
      }
      return collection;
   }

   private void readPrimitiveCollection(Descriptors.FieldDescriptor fd, Collection<? super Object> collection, Class clazz) throws IOException {
      int expectedTag = WireFormat.makeTag(fd.getNumber(), fd.getLiteType().getWireType());

      while (true) {
         Object o = messageContext.unknownFieldSet.consumeTag(expectedTag);
         if (o == null) {
            break;
         }
         collection.add(o);   //todo check that (o.getClass() == clazz)
      }

      while (true) {
         int tag = messageContext.in.readTag();
         if (tag == 0) {
            break;
         }
         if (tag == expectedTag) {
            switch (fd.getType()) {
               case DOUBLE:
                  collection.add(messageContext.in.readDouble());
                  break;
               case FLOAT:
                  collection.add(messageContext.in.readFloat());
                  break;
               case BOOL:
                  collection.add(messageContext.in.readBool());
                  break;
               case STRING:
                  collection.add(messageContext.in.readString());
                  break;
               case BYTES:
                  collection.add(messageContext.in.readBytes().toByteArray());
                  break;
               case INT64:
                  collection.add(messageContext.in.readInt64());
                  break;
               case UINT64:
                  collection.add(messageContext.in.readUInt64());
                  break;
               case FIXED64:
                  collection.add(messageContext.in.readFixed64());
                  break;
               case SFIXED64:
                  collection.add(messageContext.in.readSFixed64());
                  break;
               case SINT64:
                  collection.add(messageContext.in.readSInt64());
                  break;
               case INT32:
                  collection.add(messageContext.in.readInt32());
                  break;
               case FIXED32:
                  collection.add(messageContext.in.readFixed32());
                  break;
               case UINT32:
                  collection.add(messageContext.in.readUInt32());
                  break;
               case SFIXED32:
                  collection.add(messageContext.in.readSFixed32());
                  break;
               case SINT32:
                  collection.add(messageContext.in.readSInt32());
                  break;
               default:
                  throw new IllegalStateException("Unexpected field type : " + fd.getType());
            }
         } else {
            messageContext.unknownFieldSet.mergeFieldFrom(tag, messageContext.in);
         }
      }
   }

   @Override
   public <A> A[] readArray(String fieldName, Class<? extends A> clazz) throws IOException {
      List<A> list = readCollection(fieldName, new ArrayList<A>(), clazz);
      return list.toArray((A[]) Array.newInstance(clazz, list.size()));
   }

   private void checkRepeatedField(Descriptors.FieldDescriptor fd) {
      if (!fd.isRepeated()) {
         throw new IllegalArgumentException("This field is not repeated and cannot be read with one of the methods intended for collections or arrays: " + fd.getName());
      }
      if (!messageContext.getSeenFields().add(fd.getNumber())) {
         throw new IllegalArgumentException("Cannot read a field twice : " + fd.getName());
      }
   }

   private void checkNonRepeatedField(Descriptors.FieldDescriptor fd) {
      if (fd.isRepeated()) {
         throw new IllegalArgumentException("A repeated field should be read with one of the methods intended for collections or arrays: " + fd.getName());
      }
      if (!messageContext.getSeenFields().add(fd.getNumber())) {
         throw new IllegalArgumentException("Cannot read a field twice : " + fd.getName());
      }
   }
}
TOP

Related Classes of org.infinispan.protostream.impl.ProtoStreamReaderImpl

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.