Package org.apache.mahout.classifier.sgd

Examples of org.apache.mahout.classifier.sgd.AdaptiveLogisticRegression


    }

    Dictionary newsGroups = new Dictionary();

    encoder.setProbes(2);
    AdaptiveLogisticRegression learningAlgorithm = new AdaptiveLogisticRegression(20, FEATURES, new L1());
    learningAlgorithm.setInterval(800);
    learningAlgorithm.setAveragingWindow(500);

    List<File> files = Lists.newArrayList();
    File[] directories = base.listFiles();
    Arrays.sort(directories, Ordering.usingToString());
    for (File newsgroup : directories) {
      if (newsgroup.isDirectory()) {
        newsGroups.intern(newsgroup.getName());
        files.addAll(Arrays.asList(newsgroup.listFiles()));
      }
    }
    Collections.shuffle(files);
    System.out.printf("%d training files\n", files.size());
    System.out.printf("%s\n", Arrays.asList(directories));

    double averageLL = 0;
    double averageCorrect = 0;

    int k = 0;
    double step = 0;
    int[] bumps = {1, 2, 5};
    for (File file : files) {
      String ng = file.getParentFile().getName();
      int actual = newsGroups.intern(ng);

      Vector v = encodeFeatureVector(file);
      learningAlgorithm.train(actual, v);

      k++;

      int bump = bumps[(int) Math.floor(step) % bumps.length];
      int scale = (int) Math.pow(10, Math.floor(step / bumps.length));
      State<AdaptiveLogisticRegression.Wrapper, CrossFoldLearner> best = learningAlgorithm.getBest();
      double maxBeta;
      double nonZeros;
      double positive;
      double norm;

      double lambda = 0;
      double mu = 0;

      if (best != null) {
        CrossFoldLearner state = best.getPayload().getLearner();
        averageCorrect = state.percentCorrect();
        averageLL = state.logLikelihood();

        OnlineLogisticRegression model = state.getModels().get(0);
        // finish off pending regularization
        model.close();
       
        Matrix beta = model.getBeta();
        maxBeta = beta.aggregate(Functions.MAX, Functions.ABS);
        nonZeros = beta.aggregate(Functions.PLUS, new DoubleFunction() {
          @Override
          public double apply(double v) {
            return Math.abs(v) > 1.0e-6 ? 1 : 0;
          }
        });
        positive = beta.aggregate(Functions.PLUS, new DoubleFunction() {
          @Override
          public double apply(double v) {
            return v > 0 ? 1 : 0;
          }
        });
        norm = beta.aggregate(Functions.PLUS, Functions.ABS);

        lambda = learningAlgorithm.getBest().getMappedParams()[0];
        mu = learningAlgorithm.getBest().getMappedParams()[1];
      } else {
        maxBeta = 0;
        nonZeros = 0;
        positive = 0;
        norm = 0;
      }
      if (k % (bump * scale) == 0) {
        if (learningAlgorithm.getBest() != null) {
          ModelSerializer.writeBinary("/tmp/news-group-" + k + ".model",
            learningAlgorithm.getBest().getPayload().getLearner().getModels().get(0));
        }

        step += 0.25;
        System.out.printf("%.2f\t%.2f\t%.2f\t%.2f\t%.8g\t%.8g\t", maxBeta, nonZeros, positive, norm, lambda, mu);
        System.out.printf("%d\t%.3f\t%.2f\t%s\n",
          k, averageLL, averageCorrect * 100, LEAK_LABELS[leakType % 3]);
      }
    }
    learningAlgorithm.close();
    dissect(newsGroups, learningAlgorithm, files);
    System.out.println("exiting main");

    ModelSerializer.writeBinary("/tmp/news-group.model",
                                learningAlgorithm.getBest().getPayload().getLearner().getModels().get(0));
  }
View Full Code Here

TOP

Related Classes of org.apache.mahout.classifier.sgd.AdaptiveLogisticRegression

Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark originally of Sun Microsystems, Inc., now owned by Oracle Corporation. Contact coftware#gmail.com.