Package com.ipeirotis.gal.engine

Source Code of com.ipeirotis.gal.engine.Engine

/*******************************************************************************
* Copyright 2012 Panos Ipeirotis
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*   http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
******************************************************************************/
package com.ipeirotis.gal.engine;

import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.Set;

import com.ipeirotis.gal.Helper;
import com.ipeirotis.gal.algorithms.DawidSkene;
import com.ipeirotis.gal.core.AssignedLabel;
import com.ipeirotis.gal.core.Category;
import com.ipeirotis.gal.core.CorrectLabel;
import com.ipeirotis.gal.core.MisclassificationCost;
import com.ipeirotis.gal.engine.rpt.CategoryPriorsReport;
import com.ipeirotis.gal.engine.rpt.ConfusionMatrixReport;
import com.ipeirotis.gal.engine.rpt.ObjectResultReport;
import com.ipeirotis.gal.engine.rpt.Report;
import com.ipeirotis.gal.engine.rpt.ReportingContext;
import com.ipeirotis.gal.engine.rpt.SummaryReport;
import com.ipeirotis.gal.engine.rpt.WorkerQualityReport;

public class Engine {
  private Set<Category> categories;

  private DawidSkene ds;

  private Set<MisclassificationCost> costs;

  private Set<AssignedLabel> labels;

  private Set<CorrectLabel> correct;

  private Set<CorrectLabel> evaluation;

  private EngineContext ctx;

  private ReportingContext rptCtx;

  private Set<Report> reports = new LinkedHashSet<Report>();

  public Engine(EngineContext ctx) {
    this.ctx = ctx;
    this.rptCtx = new ReportingContext(this);
    this.reports.addAll(Arrays.asList(new WorkerQualityReport(),
        new ObjectResultReport()));
  }
 
  public EngineContext getEngineContext() {
    return ctx;
  }

  public Set<Category> getCategories() {
    return categories;
  }

  public void setCategories(Set<Category> categories) {
    this.categories = categories;
  }

  public DawidSkene getDs() {
    return ds;
  }

  public void setDawidSkene(DawidSkene ds) {
    this.ds = ds;
  }

  public Set<MisclassificationCost> getCosts() {
    return costs;
  }

  public void setCosts(Set<MisclassificationCost> costs) {
    this.costs = costs;
  }

  public Set<AssignedLabel> getLabels() {
    return labels;
  }

  public void setLabels(Set<AssignedLabel> labels) {
    this.labels = labels;
  }

  public Set<CorrectLabel> getCorrect() {
    return correct;
  }

  public void setCorrect(Set<CorrectLabel> correct) {
    this.correct = correct;
  }

  public Set<CorrectLabel> getEvaluation() {
    return evaluation;
  }

  public void setEvaluation(Set<CorrectLabel> evaluation) {
    this.evaluation = evaluation;
  }

  public void execute() {
    setCategories(loadCategories(ctx.getCategoriesFile()));

    setDawidSkene(new DawidSkene(getCategories()));
   
    if (getDs().fixedPriors() == true)
      println("Using fixed priors.");
    else
      println("Using data-inferred priors.");

    if (ctx.hasCosts()) {
      setCosts(loadCosts(ctx.getCostFile()));

      for (MisclassificationCost mcc : getCosts()) {
        getDs().addMisclassificationCost(mcc);
      }
    }

    setLabels(loadWorkerAssignedLabels(ctx.getInputFile()));

    int al = 0;

    for (AssignedLabel l : getLabels()) {
      if (++al % 1000 == 0)
        print(".");
      getDs().addAssignedLabel(l);
    }
    println("%d worker-assigned labels loaded.", getLabels().size());

    if (ctx.hasGoldFile()) {
      setCorrect(loadGoldLabels(ctx.getGoldFile()));

      int cl = 0;
      for (CorrectLabel l : getCorrect()) {
        if (++cl % 1000 == 0)
          print(".");
        getDs().addCorrectLabel(l);
      }
      println("%d correct labels loaded.", getCorrect().size());
    }

    if (ctx.hasEvaluations()) {
      setEvaluation(loadEvaluationLabels(ctx.getEvaluationFile()));
      int el = 0;
      for (CorrectLabel l : getEvaluation()) {
        if (++el % 1000 == 0)
          print(".");
        getDs().addEvaluationLabel(l);
      }
      println(getEvaluation().size() + " evaluation labels loaded.");
    }

    // We compute the evaluation-based confusion matrix for the workers
    getDs().evaluateWorkers();

    println("");
    println("Running the Dawid&Skene algorithm");
    double epsilon = ctx.getEpsilon();
    int maxIterations = ctx.getNumIterations();
    double ll = getDs().getLogLikelihood();
    println("Initial Log-likelihood: %3.6f", ll);
    ll = getDs().estimate(maxIterations, epsilon);
    println("Final Log-likelihood: %3.6f", ll);
    println("Done\n");

    if (ctx.hasEvaluateResultsAgainstFile()) {
      rptCtx.setExpectedEvaluation(loadGoldLabels(ctx
          .getEvaluateResultsAgainstFile()));
    }

    executeReports();
  }

  private void executeReports() {
    if (!ctx.isDryRun()) {
      reports.add(new CategoryPriorsReport());
    }
   
    reports.add(new SummaryReport());
    reports.add(new ConfusionMatrixReport());
   
    try {
      File outputDir = new File("results");
     
      if (! outputDir.exists())
        outputDir.mkdir();
     
     
      for (Report report : reports) {
        report.execute(rptCtx);
      }
    } catch (IOException exc) {
      throw new RuntimeException(exc);
    }
  }

  /**
   * @param correctfile
   * @return
   */
  private Set<CorrectLabel> loadGoldLabels(String correctfile) {
    // We load the "gold" cases (if any)
    println("");
    println("Loading file with correct labels. ");
    String[] lines_correct = Helper.readFile(correctfile).split("\n");
    println("File contained %d entries.", lines_correct.length);
    Set<CorrectLabel> correct = loadCorrectLabels(lines_correct);
    return correct;
  }

  /**
   * @param evalfile
   * @return
   */
  private Set<CorrectLabel> loadEvaluationLabels(String evalfile) {

    // We load the "gold" cases (if any)
    println("");
    println("Loading file with evaluation labels. ");
    String[] lines_correct = Helper.readFile(evalfile).split("\n");
    println("File contained %d entries.", lines_correct.length);
    Set<CorrectLabel> correct = loadEvaluationLabels(lines_correct);
    return correct;
  }

  public Set<AssignedLabel> loadAssignedLabels(String[] lines) {

    Set<AssignedLabel> labels = new HashSet<AssignedLabel>();
    int cnt = 1;
    for (String line : lines) {
      String[] entries = line.split("\t");
      if (entries.length != 3) {
        throw new IllegalArgumentException(
            "Error while loading from assigned labels file (line #"
                + cnt + "): " + line);
      }
      cnt++;

      String workername = entries[0];
      String objectname = entries[1];
      String categoryname = entries[2];

      AssignedLabel al = new AssignedLabel(workername, objectname,
          categoryname);
      labels.add(al);
    }
    return labels;
  }

  public Set<Category> loadCategories(String[] lines) {

    Set<Category> categories = new HashSet<Category>();
    for (String line : lines) {
      // First we check if we have fixed priors or not
      // If we have fixed priors, we have a TAB character
      // after the name of each category, followed by the prior value
      String[] l = line.split("\t");
      if (l.length == 1) {
        Category c = new Category(line);
        categories.add(c);
      } else if (l.length == 2) {
        String name = l[0];
        Double prior = new Double(l[1]);
        Category c = new Category(name);
        c.setPrior(prior);
        categories.add(c);
      }
    }
    return categories;
  }

  public Set<MisclassificationCost> loadClassificationCost(String[] lines) {

    Set<MisclassificationCost> labels = new HashSet<MisclassificationCost>();
    int cnt = 1;
    for (String line : lines) {
      String[] entries = line.split("\t");
      if (entries.length != 3) {
        throw new IllegalArgumentException(
            "Error while loading from assigned labels file (line "
                + cnt + "):" + line);
      }
      cnt++;

      String from = entries[0];
      String to = entries[1];
      Double cost = Double.parseDouble(entries[2]);

      MisclassificationCost mcc = new MisclassificationCost(from, to,
          cost);
      labels.add(mcc);
    }
    return labels;
  }

  public Set<CorrectLabel> loadCorrectLabels(String[] lines) {

    Set<CorrectLabel> labels = new HashSet<CorrectLabel>();
    int cnt = 1;
    for (String line : lines) {
      String[] entries = line.split("\t");
      if (entries.length != 2) {
        throw new IllegalArgumentException(
            "Error while loading from correct labels file (line "
                + cnt + "):" + line);
      }
      cnt++;

      String objectname = entries[0];
      String categoryname = entries[1];

      CorrectLabel cl = new CorrectLabel(objectname, categoryname);
      labels.add(cl);
    }
    return labels;
  }

  public Set<CorrectLabel> loadEvaluationLabels(String[] lines) {

    Set<CorrectLabel> labels = new HashSet<CorrectLabel>();
    for (String line : lines) {
      String[] entries = line.split("\t");
      if (entries.length != 2) {
        // evaluation file is optional
        break;
      }

      String objectname = entries[0];
      String categoryname = entries[1];

      CorrectLabel cl = new CorrectLabel(objectname, categoryname);
      labels.add(cl);
    }
    return labels;
  }

  /**
   * @param inputfile
   * @return
   */
  private Set<AssignedLabel> loadWorkerAssignedLabels(String inputfile) {

    // We load the labels assigned by the workers on the different objects
    println("");
    println("Loading file with assigned labels. ");
    String[] lines_input = Helper.readFile(inputfile).split("\n");
    println("File contains " + lines_input.length + " entries.");
    Set<AssignedLabel> labels = loadAssignedLabels(lines_input);
    return labels;
  }

  /**
   * @param costfile
   * @return
   */
  private Set<MisclassificationCost> loadCosts(String costfile) {

    // We load the cost file. The file should have exactly n^2 lines
    // where n is the number of categories.
    println("");
    println("Loading cost file.");
    String[] lines_cost = Helper.readFile(costfile).split("\n");
    // assert (lines_cost.length == categories.size() * categories.size());
    println("File contains " + lines_cost.length + " entries.");
    Set<MisclassificationCost> costs = loadClassificationCost(lines_cost);
    return costs;
  }

  /**
   * @param categoriesfile
   * @return
   */
  private Set<Category> loadCategories(String categoriesfile) {
    println("");
    println("Loading categories file.");
    String[] lines_categories = Helper.readFile(categoriesfile).split("\n");
    println("File contains " + lines_categories.length + " categories.");
    Set<Category> categories = loadCategories(lines_categories);
    return categories;
  }

  public void println(String mask, Object... args) {
    print(mask + "\n", args);
  }

  public void print(String mask, Object... args) {
    if (!ctx.isVerbose())
      return;

    String message;

    if (args.length > 0) {
      message = String.format(mask, args);
    } else {
      // without format arguments, print the mask/string as-is
      message = mask;
    }

    System.out.println(message);
  }
}
TOP

Related Classes of com.ipeirotis.gal.engine.Engine

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc., owned by Oracle. Contact coftware#gmail.com.