Package org.apache.mahout.utils.vectors.lucene

Source Code of org.apache.mahout.utils.vectors.lucene.Driver

/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements.  See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License.  You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.mahout.utils.vectors.lucene;

import org.apache.commons.cli2.CommandLine;
import org.apache.commons.cli2.Group;
import org.apache.commons.cli2.Option;
import org.apache.commons.cli2.OptionException;
import org.apache.commons.cli2.builder.ArgumentBuilder;
import org.apache.commons.cli2.builder.DefaultOptionBuilder;
import org.apache.commons.cli2.builder.GroupBuilder;
import org.apache.commons.cli2.commandline.Parser;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.SequenceFile;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.FSDirectory;
import org.apache.mahout.matrix.SparseVector;
import org.apache.mahout.common.CommandLineUtil;
import org.apache.mahout.utils.vectors.TF;
import org.apache.mahout.utils.vectors.TFIDF;
import org.apache.mahout.utils.vectors.TermInfo;
import org.apache.mahout.utils.vectors.Weight;
import org.apache.mahout.utils.vectors.io.JWriterTermInfoWriter;
import org.apache.mahout.utils.vectors.io.JWriterVectorWriter;
import org.apache.mahout.utils.vectors.io.SequenceFileVectorWriter;
import org.apache.mahout.utils.vectors.io.VectorWriter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.BufferedWriter;
import java.io.File;
import java.io.FileOutputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.nio.charset.Charset;

public class Driver {
  private static final Logger log = LoggerFactory.getLogger(Driver.class);

  private Driver() {
  }

  public static void main(String[] args) throws IOException {
    DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
    ArgumentBuilder abuilder = new ArgumentBuilder();
    GroupBuilder gbuilder = new GroupBuilder();

    Option inputOpt = obuilder.withLongName("dir").withRequired(true).withArgument(
            abuilder.withName("dir").withMinimum(1).withMaximum(1).create()).
            withDescription("The Lucene directory").withShortName("d").create();

    Option outputOpt = obuilder.withLongName("output").withRequired(true).withArgument(
            abuilder.withName("output").withMinimum(1).withMaximum(1).create()).
            withDescription("The output file").withShortName("o").create();

    Option fieldOpt = obuilder.withLongName("field").withRequired(true).withArgument(
            abuilder.withName("field").withMinimum(1).withMaximum(1).create()).
            withDescription("The field in the index").withShortName("f").create();

    Option idFieldOpt = obuilder.withLongName("idField").withRequired(false).withArgument(
            abuilder.withName("idField").withMinimum(1).withMaximum(1).create()).
            withDescription("The field in the index containing the index.  If null, then the Lucene internal doc " +
                    "id is used which is prone to error if the underlying index changes").withShortName("i").create();

    Option dictOutOpt = obuilder.withLongName("dictOut").withRequired(true).withArgument(
            abuilder.withName("dictOut").withMinimum(1).withMaximum(1).create()).
            withDescription("The output of the dictionary").withShortName("t").create();

    Option weightOpt = obuilder.withLongName("weight").withRequired(false).withArgument(
            abuilder.withName("weight").withMinimum(1).withMaximum(1).create()).
            withDescription("The kind of weight to use. Currently TF or TFIDF").withShortName("w").create();

    Option delimiterOpt = obuilder.withLongName("delimiter").withRequired(false).withArgument(
            abuilder.withName("delimiter").withMinimum(1).withMaximum(1).create()).
            withDescription("The delimiter for outputing the dictionary").withShortName("l").create();
    Option powerOpt = obuilder.withLongName("norm").withRequired(false).withArgument(
            abuilder.withName("norm").withMinimum(1).withMaximum(1).create()).
            withDescription("The norm to use, expressed as either a double or \"INF\" if you want to use the Infinite norm.  " +
                    "Must be greater or equal to 0.  The default is not to normalize").withShortName("n").create();
    Option maxOpt = obuilder.withLongName("max").withRequired(false).withArgument(
            abuilder.withName("max").withMinimum(1).withMaximum(1).create()).
            withDescription("The maximum number of vectors to output.  If not specified, then it will loop over all docs").withShortName("m").create();

    Option outWriterOpt = obuilder.withLongName("outputWriter").withRequired(false).withArgument(
            abuilder.withName("outputWriter").withMinimum(1).withMaximum(1).create()).
            withDescription("The VectorWriter to use, either seq (SequenceFileVectorWriter - default) or file (Writes to a File using JSON format)").withShortName("e").create();
    Option minDFOpt = obuilder.withLongName("minDF").withRequired(false).withArgument(
            abuilder.withName("minDF").withMinimum(1).withMaximum(1).create()).
            withDescription("The minimum document frequency.  Default is 1").withShortName("md").create();
    Option maxDFPercentOpt = obuilder.withLongName("maxDFPercent").withRequired(false).withArgument(
            abuilder.withName("maxDFPercent").withMinimum(1).withMaximum(1).create()).
            withDescription("The max percentage of docs for the DF.  Can be used to remove really high frequency terms.  Expressed as an integer between 0 and 100. Default is 99.").withShortName("x").create();
    Option helpOpt = obuilder.withLongName("help").
            withDescription("Print out help").withShortName("h").create();
    Group group = gbuilder.withName("Options").withOption(inputOpt).withOption(idFieldOpt).withOption(outputOpt).withOption(delimiterOpt)
            .withOption(helpOpt).withOption(fieldOpt).withOption(maxOpt).withOption(dictOutOpt).withOption(powerOpt).withOption(outWriterOpt).withOption(maxDFPercentOpt)
            .withOption(weightOpt).withOption(minDFOpt).create();
    try {
      Parser parser = new Parser();
      parser.setGroup(group);
      CommandLine cmdLine = parser.parse(args);

      if (cmdLine.hasOption(helpOpt)) {

        CommandLineUtil.printHelp(group);
        return;
      }
      //Springify all this
      if (cmdLine.hasOption(inputOpt)) {//Lucene case
        File file = new File(cmdLine.getValue(inputOpt).toString());
        if (file.exists() && file.isDirectory()) {
          long maxDocs = Long.MAX_VALUE;
          if (cmdLine.hasOption(maxOpt)) {
            maxDocs = Long.parseLong(cmdLine.getValue(maxOpt).toString());
          }
          if (maxDocs < 0) {
            throw new IllegalArgumentException("maxDocs must be >= 0");
          }
          Directory dir = FSDirectory.open(file);
          IndexReader reader = IndexReader.open(dir, true);
          Weight weight;
          if (cmdLine.hasOption(weightOpt)) {
            String wString = cmdLine.getValue(weightOpt).toString();
            if (wString.equalsIgnoreCase("tf")) {
              weight = new TF();
            } else if (wString.equalsIgnoreCase("tfidf")) {
              weight = new TFIDF();
            } else {
              throw new OptionException(weightOpt);
            }
          } else {
            weight = new TFIDF();
          }
          String field = cmdLine.getValue(fieldOpt).toString();
          int minDf = 1;
          if (cmdLine.hasOption(minDFOpt)) {
            minDf = Integer.parseInt(cmdLine.getValue(minDFOpt).toString());
          }
          int maxDFPercent = 99;
          if (cmdLine.hasOption(maxDFPercentOpt)) {
            maxDFPercent = Integer.parseInt(cmdLine.getValue(maxDFPercentOpt).toString());
          }
          TermInfo termInfo = new CachedTermInfo(reader, field, minDf, maxDFPercent);
          VectorMapper mapper = new TFDFMapper(reader, weight, termInfo);
          double norm = LuceneIterable.NO_NORMALIZING;
          if (cmdLine.hasOption(powerOpt)) {
            String power = cmdLine.getValue(powerOpt).toString();
            if (power.equals("INF")) {
              norm = Double.POSITIVE_INFINITY;
            } else {
              norm = Double.parseDouble(power);
            }
          }
          String idField = null;
          if (cmdLine.hasOption(idFieldOpt)) {
            idField = cmdLine.getValue(idFieldOpt).toString();
          }
          LuceneIterable iterable;
          if (norm == LuceneIterable.NO_NORMALIZING) {
            iterable = new LuceneIterable(reader, idField, field, mapper, LuceneIterable.NO_NORMALIZING);
          } else {
            iterable = new LuceneIterable(reader, idField, field, mapper, norm);
          }
          String outFile = cmdLine.getValue(outputOpt).toString();
          log.info("Output File: " + outFile);

          VectorWriter vectorWriter;
          if (cmdLine.hasOption(outWriterOpt)) {
            String outWriter = cmdLine.getValue(outWriterOpt).toString();
            if (outWriter.equals("file")) {
              BufferedWriter writer = new BufferedWriter(new FileWriter(outFile));
              vectorWriter = new JWriterVectorWriter(writer);
            } else {
              vectorWriter = getSeqFileWriter(outFile);
            }
          } else {
            vectorWriter = getSeqFileWriter(outFile);
          }

          long numDocs = vectorWriter.write(iterable, maxDocs);
          vectorWriter.close();
          log.info("Wrote: " + numDocs + " vectors");

          String delimiter = cmdLine.hasOption(delimiterOpt) ? cmdLine.getValue(delimiterOpt).toString() : "\t";
          File dictOutFile = new File(cmdLine.getValue(dictOutOpt).toString());
          log.info("Dictionary Output file: " + dictOutFile);
          BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(dictOutFile), Charset.forName("UTF8")));
          JWriterTermInfoWriter tiWriter = new JWriterTermInfoWriter(writer, delimiter, field);
          tiWriter.write(termInfo);
          tiWriter.close();
          writer.close();
        }
      }

    } catch (OptionException e) {
      log.error("Exception", e);
      CommandLineUtil.printHelp(group);
    }
  }

  private static VectorWriter getSeqFileWriter(String outFile) throws IOException {
    Path path = new Path(outFile);
    Configuration conf = new Configuration();
    FileSystem fs = FileSystem.get(conf);
    //TODO: Make this parameter driven
    SequenceFile.Writer seqWriter = SequenceFile.createWriter(fs, conf, path, LongWritable.class, SparseVector.class);

    return new SequenceFileVectorWriter(seqWriter);
  }


}
TOP

Related Classes of org.apache.mahout.utils.vectors.lucene.Driver

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.