Package org.apache.solr.search.grouping.distributed.shardresultserializer

Source Code of org.apache.solr.search.grouping.distributed.shardresultserializer.TopGroupsResultTransformer

package org.apache.solr.search.grouping.distributed.shardresultserializer;

/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements.  See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License.  You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import org.apache.lucene.document.Document;
import org.apache.lucene.document.FieldSelector;
import org.apache.lucene.document.FieldSelectorResult;
import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.grouping.GroupDocs;
import org.apache.lucene.search.grouping.TopGroups;
import org.apache.solr.common.util.NamedList;
import org.apache.solr.handler.component.ResponseBuilder;
import org.apache.solr.handler.component.ShardDoc;
import org.apache.solr.schema.FieldType;
import org.apache.solr.schema.SchemaField;
import org.apache.solr.search.grouping.Command;
import org.apache.solr.search.grouping.distributed.command.QueryCommand;
import org.apache.solr.search.grouping.distributed.command.QueryCommandResult;
import org.apache.solr.search.grouping.distributed.command.TopGroupsFieldCommand;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
* Implementation for transforming {@link TopGroups} and {@link TopDocs} into a {@link NamedList} structure and
* visa versa.
*/
public class TopGroupsResultTransformer implements ShardResultTransformer<List<Command>, Map<String, ?>> {

  private final ResponseBuilder rb;

  public TopGroupsResultTransformer(ResponseBuilder rb) {
    this.rb = rb;
  }

  /**
   * {@inheritDoc}
   */
  public NamedList transform(List<Command> data) throws IOException {
    NamedList<NamedList> result = new NamedList<NamedList>();
    for (Command command : data) {
      NamedList commandResult;
      if (TopGroupsFieldCommand.class.isInstance(command)) {
        TopGroupsFieldCommand fieldCommand = (TopGroupsFieldCommand) command;
        SchemaField groupField = rb.req.getSearcher().getSchema().getField(fieldCommand.getKey());
        commandResult = serializeTopGroups(fieldCommand.result(), groupField);
      } else if (QueryCommand.class.isInstance(command)) {
        QueryCommand queryCommand = (QueryCommand) command;
        commandResult = serializeTopDocs(queryCommand.result());
      } else {
        commandResult = null;
      }

      result.add(command.getKey(), commandResult);
    }
    return result;
  }

  /**
   * {@inheritDoc}
   */
  public Map<String, ?> transformToNative(NamedList<NamedList> shardResponse, Sort groupSort, Sort sortWithinGroup, String shard) {
    Map<String, Object> result = new HashMap<String, Object>();

    for (Map.Entry<String, NamedList> entry : shardResponse) {
      String key = entry.getKey();
      NamedList commandResult = entry.getValue();
      Integer totalGroupedHitCount = (Integer) commandResult.get("totalGroupedHitCount");
      Integer totalHits = (Integer) commandResult.get("totalHits");
      if (totalHits != null) {
        Integer matches = (Integer) commandResult.get("matches");
        Float maxScore = (Float) commandResult.get("maxScore");
        if (maxScore == null) {
          maxScore = Float.NaN;
        }
    Double sum = (Double) commandResult.get("sum");
    Double max = (Double) commandResult.get("max");
        Double min = (Double) commandResult.get("min");
        if (sum == null) { sum = Double.valueOf(0.0f); }
        if (max == null) { max = Double.NEGATIVE_INFINITY; }
        if (min == null) { min = Double.POSITIVE_INFINITY; }

        @SuppressWarnings("unchecked")
        List<NamedList<Object>> documents = (List<NamedList<Object>>) commandResult.get("documents");
        ScoreDoc[] scoreDocs = new ScoreDoc[documents.size()];
        int j = 0;
        for (NamedList<Object> document : documents) {
          Object uniqueId = document.get("id").toString();
          Float score = (Float) document.get("score");
          if (score == null) {
            score = Float.NaN;
          }
          Object[] sortValues = ((List) document.get("sortValues")).toArray();
          scoreDocs[j++] = new ShardDoc(score, sortValues, uniqueId, shard);
        }
        result.put(key, new QueryCommandResult(new TopDocs(totalHits, scoreDocs, maxScore, sum, max, min), matches));
        continue;
      }

      Integer totalHitCount = (Integer) commandResult.get("totalHitCount");
      Integer totalGroupCount = (Integer) commandResult.get("totalGroupCount");

      List<GroupDocs<String>> groupDocs = new ArrayList<GroupDocs<String>>();
      for (int i = totalGroupCount == null ? 2 : 3; i < commandResult.size(); i++) {
        String groupValue = commandResult.getName(i);
        @SuppressWarnings("unchecked")
        NamedList<Object> groupResult = (NamedList<Object>) commandResult.getVal(i);
        Integer totalGroupHits = (Integer) groupResult.get("totalHits");
        Float maxScore = (Float) groupResult.get("maxScore");
        if (maxScore == null) {
          maxScore = Float.NaN;
        }
    Double sum = (Double) groupResult.get("sum");
    Double max = (Double) groupResult.get("max");
        Double min = (Double) groupResult.get("min");
        if (sum == null) { sum = Double.valueOf(0.0f); }
        if (max == null) { max = Double.NEGATIVE_INFINITY; }
        if (min == null) { min = Double.POSITIVE_INFINITY; }

        @SuppressWarnings("unchecked")
        List<NamedList<Object>> documents = (List<NamedList<Object>>) groupResult.get("documents");
        ScoreDoc[] scoreDocs = new ScoreDoc[documents.size()];
        int j = 0;
        for (NamedList<Object> document : documents) {
          Object uniqueId = document.get("id").toString();
          Float score = (Float) document.get("score");
          if (score == null) {
            score = Float.NaN;
          }
          //System.out.println("######### transformToNative:"+document.get("sortValues"));
          List sortValuesList = (List) document.get("sortValues");
          Object[] sortValues = sortValuesList==null?new Object[0]:sortValuesList.toArray();
          scoreDocs[j++] = new ShardDoc(score, sortValues, uniqueId, shard);
        }

        String groupValueRef = groupValue != null ? groupValue : null;
        groupDocs.add(new GroupDocs<String>(maxScore, totalGroupHits, scoreDocs, groupValueRef, null, sum, max, min));
      }

      @SuppressWarnings("unchecked")
      GroupDocs<String>[] groupDocsArr = groupDocs.toArray(new GroupDocs[groupDocs.size()]);
      TopGroups<String> topGroups = new TopGroups<String>(
        groupSort.getSort(), sortWithinGroup.getSort(), totalHitCount, totalGroupedHitCount, groupDocsArr
      );
      if (totalGroupCount != null) {
        topGroups = new TopGroups<String>(topGroups, totalGroupCount);
      }

      result.put(key, topGroups);
    }

    return result;
  }

  protected NamedList serializeTopGroups(TopGroups<String> data, SchemaField groupField) throws IOException {
    NamedList<Object> result = new NamedList<Object>();
    result.add("totalGroupedHitCount", data.totalGroupedHitCount);
    result.add("totalHitCount", data.totalHitCount);
    if (data.totalGroupCount != null) {
      result.add("totalGroupCount", data.totalGroupCount);
    }
    SchemaField uniqueField = rb.req.getSearcher().getSchema().getUniqueKeyField();
    for (GroupDocs<String> searchGroup : data.groups) {
      NamedList<Object> groupResult = new NamedList<Object>();
      groupResult.add("totalHits", searchGroup.totalHits);
      if (!Float.isNaN(searchGroup.maxScore)) {
        groupResult.add("maxScore", searchGroup.maxScore);
      }
    groupResult.add("sum", searchGroup.sum);
      groupResult.add("max", searchGroup.max);
      groupResult.add("min", searchGroup.min);
      List<NamedList<Object>> documents = new ArrayList<NamedList<Object>>();
      for (int i = 0; i < searchGroup.scoreDocs.length; i++) {
        NamedList<Object> document = new NamedList<Object>();
        documents.add(document);

        Document doc = retrieveDocument(uniqueField, searchGroup.scoreDocs[i].doc);
        document.add("id", uniqueField.getType().toObject(doc.getFieldable(uniqueField.getName())));
        if (!Float.isNaN(searchGroup.scoreDocs[i].score))  {
          document.add("score", searchGroup.scoreDocs[i].score);
        }
        if (!(searchGroup.scoreDocs[i] instanceof FieldDoc)) {
          continue;
        }

        FieldDoc fieldDoc = (FieldDoc) searchGroup.scoreDocs[i];
        Object[] convertedSortValues  = new Object[fieldDoc.fields.length];
        for (int j = 0; j < fieldDoc.fields.length; j++) {
          Object sortValue  = fieldDoc.fields[j];
          Sort sortWithinGroup = rb.getGroupingSpec().getSortWithinGroup();
          SchemaField field = sortWithinGroup.getSort()[j].getField() != null ? rb.req.getSearcher().getSchema().getFieldOrNull(sortWithinGroup.getSort()[j].getField()) : null;
          if (field != null) {
            FieldType fieldType = field.getType();
            if (sortValue instanceof String) {
              sortValue = fieldType.toObject(field.createField(fieldType.indexedToReadable((String) sortValue), 0.0f));
            }
          }
          convertedSortValues[j] = sortValue;
        }
        //System.out.println("######### serializeTopGroups:"+convertedSortValues.length);
        document.add("sortValues", convertedSortValues);
      }
      groupResult.add("documents", documents);
      String groupValue = searchGroup.groupValue != null ? groupField.getType().indexedToReadable(searchGroup.groupValue): null;
      result.add(groupValue, groupResult);
    }

    return result;
  }

  protected NamedList serializeTopDocs(QueryCommandResult result) throws IOException {
    NamedList<Object> queryResult = new NamedList<Object>();
    queryResult.add("matches", result.getMatches());
    queryResult.add("totalHits", result.getTopDocs().totalHits);
    if (rb.getGroupingSpec().isNeedScore()) {
      queryResult.add("maxScore", result.getTopDocs().getMaxScore());
    }
  queryResult.add("sum", result.getTopDocs().getSum());
    queryResult.add("max", result.getTopDocs().getMax());
    queryResult.add("min", result.getTopDocs().getMin());
 
    List<NamedList> documents = new ArrayList<NamedList>();
    queryResult.add("documents", documents);

    SchemaField uniqueField = rb.req.getSearcher().getSchema().getUniqueKeyField();
    for (ScoreDoc scoreDoc : result.getTopDocs().scoreDocs) {
      NamedList<Object> document = new NamedList<Object>();
      documents.add(document);

      Document doc = retrieveDocument(uniqueField, scoreDoc.doc);
      document.add("id", uniqueField.getType().toObject(doc.getFieldable(uniqueField.getName())));
      if (rb.getGroupingSpec().isNeedScore())  {
        document.add("score", scoreDoc.score);
      }
      if (!FieldDoc.class.isInstance(scoreDoc)) {
        continue;
      }

      FieldDoc fieldDoc = (FieldDoc) scoreDoc;
      Object[] convertedSortValues  = new Object[fieldDoc.fields.length];
      for (int j = 0; j < fieldDoc.fields.length; j++) {
        Object sortValue  = fieldDoc.fields[j];
        Sort groupSort = rb.getGroupingSpec().getGroupSort();
        SchemaField field = groupSort.getSort()[j].getField() != null ? rb.req.getSearcher().getSchema().getFieldOrNull(groupSort.getSort()[j].getField()) : null;
        if (field != null) {
          FieldType fieldType = field.getType();
          if (sortValue instanceof String) {
            sortValue = fieldType.toObject(field.createField(fieldType.indexedToReadable((String) sortValue), 0.0f));
          }
        }
        convertedSortValues[j] = sortValue;
      }
      //System.out.println("######### serializeTopDocs:"+convertedSortValues.length);
      document.add("sortValues", convertedSortValues);
    }

    return queryResult;
  }

  private Document retrieveDocument(final SchemaField uniqueField, int doc) throws IOException {
    FieldSelector fieldSelectorVisitor = new FieldSelector() {

      public FieldSelectorResult accept(String fieldName) {
        if (uniqueField.getName().equals(fieldName)) {
          return FieldSelectorResult.LOAD_AND_BREAK;
        }
        return FieldSelectorResult.NO_LOAD;
      }
    };
    return rb.req.getSearcher().doc(doc, fieldSelectorVisitor);
  }

}
TOP

Related Classes of org.apache.solr.search.grouping.distributed.shardresultserializer.TopGroupsResultTransformer

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by Oracle Inc. Contact coftware#gmail.com.