Package edu.isi.karma.cleaning

Source Code of edu.isi.karma.cleaning.ProgSynthesis

package edu.isi.karma.cleaning;

import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Vector;

import org.perf4j.StopWatch;
import org.perf4j.log4j.Log4JStopWatch;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import edu.isi.karma.cleaning.Research.ConfigParameters;
import edu.isi.karma.cleaning.features.Feature;
import edu.isi.karma.cleaning.features.RecordClassifier;
import edu.isi.karma.cleaning.features.RecordFeatureSet;

public class ProgSynthesis {
  public static int time_limit = 20;
  Vector<Vector<TNode>> orgVector = new Vector<Vector<TNode>>();
  Vector<Vector<TNode>> tarVector = new Vector<Vector<TNode>>();
  String bestRuleString = "";
  // for tracking the stats
  public long learnspan = 0;
  public long genspan = 0;
  public long ruleNo = 0;
  public HashMap<String, double[]> string2Vector = new HashMap<String, double[]>();
  public PartitionClassifierType classifier;
  public RecordFeatureSet featureSet = null;
  public Vector<Vector<String[]>> constraints = new Vector<Vector<String[]>>();
  public HashMap<String, Boolean> legalParitions = new HashMap<String, Boolean>();
  public MyLogger logger = new MyLogger();
  private static Logger ulogger = LoggerFactory
      .getLogger(ProgSynthesis.class);
  public Program myprog;
  public OptimizePartition partitioner; // A* search for a consistent program
  public ExampleCluster partiCluster;
  public DataPreProcessor dataPreProcessor;
  public Messager msGer;

  public ProgSynthesis() {
  }

  public void inite(Vector<String[]> examples, String[] vocb) {
    for (int i = 0; i < examples.size(); i++) {
      Ruler r = new Ruler();
      r.setNewInput(examples.get(i)[0]);
      orgVector.add(r.vec);
      Ruler r1 = new Ruler();
      r1.setNewInput(examples.get(i)[1]);
      tarVector.add(r1.vec);
    }
    RecordFeatureSet rSet = new RecordFeatureSet();
    rSet.addVocabulary(vocb);
    this.featureSet = rSet;
    RecordClassifier r2 = new RecordClassifier(rSet);
    this.classifier = r2;
    partitioner = new OptimizePartition(this, r2);
  }
 
  public void inite(Vector<String[]> examples, DataPreProcessor dp, Vector<Vector<String[]>> constraints) {
    for (int i = 0; i < examples.size(); i++) {
      Ruler r = new Ruler();
      r.setNewInput(examples.get(i)[0]);
      orgVector.add(r.vec);
      Ruler r1 = new Ruler();
      r1.setNewInput(examples.get(i)[1]);
      tarVector.add(r1.vec);
    }
    featureSet = dp.rfs;
    RecordClassifier r2 = new RecordClassifier(featureSet);
    this.classifier = r2;
    this.constraints = constraints;
    string2Vector = dp.getStandardData();
  }
  public void inite(Vector<String[]> examples, DataPreProcessor dp, Messager msger) {
    for (int i = 0; i < examples.size(); i++) {
      Ruler r = new Ruler();
      r.setNewInput(examples.get(i)[0]);
      orgVector.add(r.vec);
      Ruler r1 = new Ruler();
      r1.setNewInput(examples.get(i)[1]);
      tarVector.add(r1.vec);
    }
    featureSet = dp.rfs;
    RecordClassifier r2 = new RecordClassifier(featureSet);
    this.classifier = r2;
    string2Vector = dp.getStandardData();
    this.dataPreProcessor = dp;
    this.msGer = msger;
  }
  public Vector<Vector<Integer>> generateCrossIndex(Vector<Integer> poss,
      Vector<Vector<Integer>> p, int index) {
    Vector<Vector<Integer>> qVector = new Vector<Vector<Integer>>();
    if (index >= poss.size()) {
      return p;
    }
    int curleng = poss.get(index);
    if (p.size() == 0) {
      for (int i = 0; i < curleng; i++) {
        Vector<Integer> x = new Vector<Integer>();
        x.add(i);
        qVector.add(x);
      }
    } else {
      for (int j = 0; j < p.size(); j++) {
        for (int i = 0; i < curleng; i++) {
          @SuppressWarnings("unchecked")
          Vector<Integer> x = (Vector<Integer>) p.get(j).clone();
          x.add(i);
          qVector.add(x);
        }
      }
    }
    qVector = generateCrossIndex(poss, qVector, index + 1);
    return qVector;
  }

  public double getCompScore(int i, int j, Vector<Partition> pars) {
    Partition p = pars.get(i).mergewith(pars.get(j));
    if (p == null || !isLegalPartition(p)) {
      return -Double.MAX_VALUE;
    }
    int validCnt = 0;
    for (int x = 0; x < pars.size(); x++) {
      if (x == i || x == j) {
        continue;
      }
      Partition q = p.mergewith(pars.get(x));
      if (q != null && isLegalPartition(p)) {
        validCnt++;
      }
    }
    return validCnt;
  }

  public Vector<Partition> initePartitions() {
    Vector<Partition> pars = new Vector<Partition>();
    // inite partition for each example
    for (int i = 0; i < orgVector.size(); i++) {
      Vector<Vector<TNode>> ovt = new Vector<Vector<TNode>>();
      Vector<Vector<TNode>> tvt = new Vector<Vector<TNode>>();
      ovt.add(this.orgVector.get(i));
      tvt.add(this.tarVector.get(i));
      Partition pt = new Partition(ovt, tvt);
      pars.add(pt);
    }
    return pars;
  }

  public boolean isLegalPartition(Partition p) {
    String key = p.getHashKey();
    if (legalParitions.containsKey(key)) {
      return legalParitions.get(key);
    }
    // test whether its subset fails
    for (String k : legalParitions.keySet()) {
      if (!legalParitions.get(k)) // false
      {
        if (key.indexOf(k) != -1) {
          return false;
        }
      }
    }
    Vector<Partition> xPar = new Vector<Partition>();
    xPar.add(p);
    Collection<ProgramRule> cpr = this.producePrograms(xPar);
    if (cpr == null || cpr.size() == 0) {
      legalParitions.put(key, false);
      return false;
    } else {
      legalParitions.put(key, true);
      return true;
    }
  }

  public void mergePartitions(Vector<Partition> pars) {
    double maxScore = 0;
    int[] pos = { -1, -1 };
    for (int i = 0; i < pars.size(); i++) {
      for (int j = i + 1; j < pars.size(); j++) {
        double s = getCompScore(i, j, pars);
        if (s < 0) {
          continue;
        }
        if (s >= maxScore) {
          pos[0] = i;
          pos[1] = j;
          maxScore = s;
        }
      }
    }
    if (pos[0] != -1 && pos[1] != -1)
      UpdatePartitions(pos[0], pos[1], pars);
  }

  public void UpdatePartitions(int i, int j, Vector<Partition> pars) {
    Partition p = pars.get(i).mergewith(pars.get(j));
    pars.set(i, p);
    pars.remove(j);
  }

  public String getBestRule() {
    return this.bestRuleString;
  }

  public double[] getFeatureArray(String s) {
    Collection<Feature> cfeat = featureSet.computeFeatures(s, "");
    Feature[] x = cfeat.toArray(new Feature[cfeat.size()]);
    double[] res = new double[x.length];
    for (int i = 0; i < x.length; i++) {
      res[i] = x[i].getScore();
    }
    return res;
  }

  public Vector<Partition> ProducePartitions(boolean condense) {
    Vector<Partition> pars = this.initePartitions();
    partiCluster = new ExampleCluster(this, pars, string2Vector);
    partiCluster.updateConstraints(msGer.getConstraints());
    partiCluster.updateWeights(msGer.weights);
    if (condense) {
      pars = partiCluster.cluster_weigthEuclidean(pars);
    }
    this.ruleNo = pars.size();
    /*
     * int size = pars.size(); while (condense) {
     * this.mergePartitions(pars); if (size == pars.size()) { break; } else
     * { size = pars.size(); } }
     */
    return pars;
  }
 
  public Collection<ProgramRule> producePrograms(Vector<Partition> pars) {
    Program prog = new Program(pars, this.classifier,this.dataPreProcessor);
    this.myprog = prog;
    HashSet<ProgramRule> rules = new HashSet<ProgramRule>();
    int prog_cnt = 1;
    int i = 0;
    while (i < prog_cnt) {
      ProgramRule r = prog.toProgram1();
      if (r == null)
        return null;
      String xString = "";
      boolean findRule = true;
      while ((xString = this.validRule(r, pars)) != "GOOD" && findRule) {
        if (xString.compareTo("NO_CLASIF") == 0) {
          return null; // indistinguishable classes.
        }
        for (Partition p : prog.partitions) {
          if (p.label.compareTo(xString) == 0) {
            String newRule = p.toProgram();
            if (ConfigParameters.debug == 1)
              System.out.println("updated Rule: " + p.label
                  + ": " + newRule);
            if (newRule.contains("null")) {
              findRule = false;
              break;
            }
            r.updateClassworker(xString, newRule);
          }
        }
      }
      if (findRule)
      {
        rules.add(r);
      }
      i++;
    }
    return rules;
  }
  public Collection<ProgramRule> adaptive_main()
  {
    StopWatch stopWatch0 = new Log4JStopWatch("adaptive_main");
    long t1 = System.currentTimeMillis();
    StopWatch stopWatch = new Log4JStopWatch("adaptive_producePartition");
    Vector<Partition> par = this.adaptive_producePartition();
    stopWatch.stop();
    StopWatch stopWatch1 = new Log4JStopWatch("adaptive_produceProgram");
    Collection<ProgramRule> cpr = this.adaptive_produceProgram(par);
    stopWatch1.stop();
    Traces.AllSegs.clear();
    //record the learning time
    this.learnspan = System.currentTimeMillis()-t1;
    stopWatch0.stop();
    return cpr;
  }
  public Collection<ProgramRule> adaptive_produceProgram(Vector<Partition> pars)
  {
   
    Program prog = new Program(pars, this.classifier,this.dataPreProcessor);
    this.myprog = prog;
    HashSet<ProgramRule> rules = new HashSet<ProgramRule>();
    int prog_cnt = 1;
    int i = 0;
    while (i < prog_cnt) {
      ProgramRule r = prog.toProgram2(msGer);
      if (r == null)
        return null;
      String xString = "";
      boolean findRule = true;
      while ((xString = this.validRule(r, pars)) != "GOOD" && findRule) {
        if (xString.compareTo("NO_CLASIF") == 0) {
          return null; // indistinguishable classes.
        }
        for (Partition p : prog.partitions) {
          if (p.label.compareTo(xString) == 0) {
            String newRule = p.toProgram();
            if (ConfigParameters.debug == 1)
              System.out.println("updated Rule: " + p.label
                  + ": " + newRule);
            if (newRule.contains("null")) {
              findRule = false;
              break;
            }
            r.updateClassworker(xString, newRule);
          }
        }
      }
      if (findRule)
      {
        rules.add(r);
      }
      i++;
    }
    return rules;
  }
  public Vector<Partition> adaptive_producePartition()
  { 
    Vector<Partition> pars = this.initePartitions();
    partiCluster = new ExampleCluster(this, pars, string2Vector);
    partiCluster.updateConstraints(msGer.getConstraints());
    partiCluster.updateWeights(msGer.weights);
    //update the program space and hypothesis space
    pars = partiCluster.adaptive_cluster_weightEuclidean(pars);
    this.ruleNo = pars.size();
    return pars;
  }
  public Collection<ProgramRule> run_main() {
    long t1 = System.currentTimeMillis();
    StopWatch stopWatch0 = new Log4JStopWatch("main");
    StopWatch stopWatch = new Log4JStopWatch("producePartition");
    Vector<Partition> vp = this.ProducePartitions(true);
    stopWatch.stop();
    StopWatch stopWatch1 = new Log4JStopWatch("producePrograms");
    Collection<ProgramRule> cpr = this.producePrograms(vp);
    stopWatch1.stop();
    Traces.AllSegs.clear();
    stopWatch0.stop();
    this.learnspan = System.currentTimeMillis()-t1;
    return cpr;
  }

  public String validRule(ProgramRule p, Vector<Partition> vp) {
    for (Partition px : vp) {
      for (int i = 0; i < px.orgNodes.size(); i++) {
        String s1 = UtilTools.print(px.orgNodes.get(i));
        String labelString = p.getClassForValue(s1);
        if (labelString.compareTo(px.label) != 0) {
          //logger.error(vp.toString());
          ulogger.error("classification error on examples " + s1
              + "as: " + labelString);
          // return "NO_CLASIF";
        }
        InterpreterType worker = p.getWorkerForClass(px.label);
        String s2 = "";
        try {
          s2 = new String(worker.execute(s1).getBytes(), "UTF-8");
        } catch (Exception e) {
          return "DECODE_ERROR";
        }
        String s3 = UtilTools.print(px.tarNodes.get(i));
        if (s3.compareTo(s2) != 0) {
          return labelString;
        }
      }
    }
    return "GOOD";
  }
}
TOP

Related Classes of edu.isi.karma.cleaning.ProgSynthesis

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.