Package org.apache.mahout.knn

Source Code of org.apache.mahout.knn.Vectorize20NewsGroups

/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements.  See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License.  You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.mahout.knn;

import com.google.common.base.*;
import com.google.common.collect.HashMultiset;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import com.google.common.collect.Multiset;
import com.google.common.collect.Sets;
import com.google.common.io.Files;
import com.google.common.io.LineProcessor;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.Text;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.function.Functions;
import org.apache.mahout.vectorizer.encoders.FeatureVectorEncoder;
import org.apache.mahout.vectorizer.encoders.StaticWordValueEncoder;

import java.io.File;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import java.util.regex.Pattern;

/**
* Read, tokenize and convert the 20 newsgroups test data to vector form.
* <p/>
* The vectorization is done using a hashed projection to a fixed dimension vector using a
* selectable term weighting.
*
* Command line options are
*
* <ul>
*   <li>weighting code, three characters long.  The first character can be l, s, t or x to indicate
*   term weighting of log, square root, term frequency or no weighting.  The second
*   character can be i or x to indicate IDF weighting or no corpus weighting.  The third character
*   can be c or x to indicate cosine normalization or no normalization.</li>
*   <li>a comma separated list of header lines to use</li>
*   <li>a boolean to indicate whether quoted lines should be retained (true to retain, false to omit)</li>
*   <li>the dimension of the result vector</li>
*   <li>a list of directories containing files to parse</li>
* </ul>
*/
public class Vectorize20NewsGroups {
  private static boolean includeQuotes;
  private static Set<String> legalHeaders;

  public static void main(String[] args) throws IOException {
    String weightingCode = args[0];
    boolean normalize = weightingCode.endsWith("c");

    legalHeaders = Sets.newHashSet();
    Iterables.addAll(legalHeaders, Iterables.transform(Splitter.on(",").trimResults().split(args[1]), new Function<String, String>() {
      @Override
      public String apply(String s) {
        return s.toLowerCase();
      }
    }));

    includeQuotes = Boolean.parseBoolean(args[2]);

    CorpusWeighting cw = CorpusWeighting.parse(weightingCode);
    if (cw.needCorpusWeights()) {
      Multiset<String> wordFrequency = HashMultiset.create();
      Set<String> documents = Sets.newHashSet();
      for (String file : Arrays.asList(args).subList(4, args.length)) {
        recursivelyCount(documents, wordFrequency, new File(file));
      }
      cw.setCorpusCounts(wordFrequency, documents.size());
    }

    int dimension = Integer.parseInt(args[3]);

    Configuration conf = new Configuration();
    SequenceFile.Writer sf = SequenceFile.createWriter(FileSystem.getLocal(conf), conf, new Path("output"), Text.class, VectorWritable.class);
    PrintWriter csv = new PrintWriter("output.csv");
    for (String file : Arrays.asList(args).subList(4, args.length)) {
      recursivelyVectorize(csv, sf, new File(file), cw, normalize, dimension);
    }
    csv.close();
    sf.close();
  }

  private static void recursivelyCount(Set<String> documents, Multiset<String> wordFrequency, File f) throws IOException {
    if (f.isDirectory()) {
      for (File file : f.listFiles()) {
        recursivelyCount(documents, wordFrequency, file);
      }
    } else {
      // count each word once per document regardless of actual count
      documents.add(f.getCanonicalPath());
      wordFrequency.addAll(parse(f).elementSet());
    }
  }

  static void recursivelyVectorize(PrintWriter csv, SequenceFile.Writer sf, File f, CorpusWeighting w, boolean normalize, int dimension) throws IOException {
    if (f.isDirectory()) {
      for (File file : f.listFiles()) {
        recursivelyVectorize(csv, sf, file, w, normalize, dimension);
      }
    } else {
      Vector v = vectorizeFile(f, w, normalize, dimension);
      csv.printf("%s,%s", f.getParentFile().getName(), f.getName());
      for (int i = 0; i < v.size(); i++) {
        csv.printf(",%.5f", v.get(i));
      }
      csv.printf("\n");
      sf.append(new Text(f.getParentFile().getName()), new VectorWritable(v));
    }
  }

  static Vector vectorizeFile(File f, CorpusWeighting w, boolean normalize, int dimension) throws IOException {
    Multiset<String> counts = parse(f);
    return vectorize(counts, w, normalize, dimension);
  }

  static Vector vectorize(Multiset<String> doc, CorpusWeighting w, boolean normalize, int dimension) {
    Vector v = new RandomAccessSparseVector(dimension);
    FeatureVectorEncoder encoder = new StaticWordValueEncoder("text");
    for (String word : doc.elementSet()) {
      encoder.addToVector(word, w.weight(word, doc.count(word)), v);
    }
    if (normalize) {
      return v.assign(Functions.div(v.norm(2)));
    } else {
      return v;
    }
  }

  static Multiset<String> parse(File f) throws IOException {
    return Files.readLines(f, Charsets.UTF_8, new LineProcessor<Multiset<String>>() {
      private boolean readingHeaders = true;
      private Splitter header = Splitter.on(":").limit(2);
      private Splitter words = Splitter.on(CharMatcher.forPredicate(new Predicate<Character>() {
        @Override
        public boolean apply(Character ch) {
          return !Character.isLetterOrDigit(ch) && ch != '.' && ch != '/' && ch != ':';
        }
      })).omitEmptyStrings().trimResults();

      private Pattern quotedLine = Pattern.compile("(^In article .*)|(^> .*)|(.*writes:$)|(^\\|>)");

      private Multiset<String> counts = HashMultiset.create();

      @Override
      public boolean processLine(String line) throws IOException {
        if (readingHeaders && line.length() == 0) {
          readingHeaders = false;
        }

        if (readingHeaders) {
          Iterator<String> i = header.split(line).iterator();
          String head = i.next().toLowerCase();
          if (legalHeaders.contains(head)) {
            addText(counts, i.next());
          }
        } else {
          boolean quote = quotedLine.matcher(line).matches();
          if (includeQuotes || !quote) {
            addText(counts, line);
          }
        }
        return true;
      }

      @Override
      public Multiset<String> getResult() {
        return counts;
      }

      private void addText(Multiset<String> v, String line) {
        for (String word : words.split(line)) {
          v.add(word.toLowerCase());
        }
      }
    });
  }

  private static abstract class CorpusWeighting {
    static Map<String, CorpusWeighting> corpusWeights = ImmutableMap.of("i", new Idf(), "x", new Unit());

    static CorpusWeighting parse(String code) {
      CorpusWeighting cw = corpusWeights.get(code.substring(1, 2));
      TermWeighting tw = TermWeighting.parse(code.substring(0, 1));
      cw.setTermWeighting(tw);
      return cw;
    }

    TermWeighting termWeighting;

    public void setTermWeighting(TermWeighting termWeighting) {
      this.termWeighting = termWeighting;
    }

    abstract double weight(String word, int count);

    abstract boolean needCorpusWeights();

    public void setCorpusCounts(Multiset<String> corpusCounts, int corpusSize) {
      throw new UnsupportedOperationException("Can't add counts to a Unit weighting");
    }
  }

  private static class Idf extends CorpusWeighting {
    Multiset<String> documentFrequency;
    int corpusSize;

    @Override
    double weight(String word, int count) {
      return termWeighting.termFrequencyWeight(count) * Math.log((corpusSize + 1) / (documentFrequency.count(word) + 1));
    }

    @Override
    boolean needCorpusWeights() {
      return true;
    }

    @Override
    public void setCorpusCounts(Multiset<String> corpusCounts, int corpusSize) {
      this.documentFrequency = corpusCounts;
      this.corpusSize = corpusSize;
    }
  }

  private static class Unit extends CorpusWeighting {
    @Override
    double weight(String word, int count) {
      return termWeighting.termFrequencyWeight(count);
    }

    @Override
    boolean needCorpusWeights() {
      return false;
    }

  }

  private static abstract class TermWeighting {
    abstract double termFrequencyWeight(int count);

    static final TermWeighting log = new TermWeighting() {
      @Override
      double termFrequencyWeight(int count) {
        return Math.log(count + 1);
      }
    };
    static final TermWeighting linear = new TermWeighting() {
      @Override
      double termFrequencyWeight(int count) {
        return count;
      }
    };
    static final TermWeighting root = new TermWeighting() {
      @Override
      double termFrequencyWeight(int count) {
        return Math.sqrt(count);
      }
    };
    static final TermWeighting unit = new TermWeighting() {
      @Override
      double termFrequencyWeight(int count) {
        return 1;
      }
    };

    static Map<String, TermWeighting> termWeights = ImmutableMap.of("l", TermWeighting.log, "s", TermWeighting.root, "t", TermWeighting.linear, "x", TermWeighting.unit);

    static final TermWeighting parse(String code) {
      return termWeights.get(code);
    }


  }

}
TOP

Related Classes of org.apache.mahout.knn.Vectorize20NewsGroups

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.