Package uk.ac.cam.ha293.tweetlabel.eval

Source Code of uk.ac.cam.ha293.tweetlabel.eval.Diversity

package uk.ac.cam.ha293.tweetlabel.eval;

import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import uk.ac.cam.ha293.tweetlabel.classify.FullAlchemyClassification;
import uk.ac.cam.ha293.tweetlabel.classify.FullCalaisClassification;
import uk.ac.cam.ha293.tweetlabel.topics.FullLDAClassification;
import uk.ac.cam.ha293.tweetlabel.topics.FullLLDAClassification;
import uk.ac.cam.ha293.tweetlabel.util.Tools;

public class Diversity {
 
  private static Set<Double> diversitySet(String topicType, long uid) {
    Set<Double> valueSet = new HashSet<Double>();
    if(topicType.equals("alchemy")) {
      FullAlchemyClassification c = new FullAlchemyClassification(uid);
      for(String cat : c.getCategorySet()) {
        valueSet.add(c.getScore(cat));
      }
    } else if(topicType.equals("calais")) {
      FullCalaisClassification c = new FullCalaisClassification(uid);
      for(String cat : c.getCategorySet()) {
        valueSet.add(c.getScore(cat));
      }
    } else if(topicType.equals("textwise")) {
      FullCalaisClassification c = new FullCalaisClassification(uid);
      for(String cat : c.getCategorySet()) {
        valueSet.add(c.getScore(cat));
      }
    }
    return valueSet;
  }
 
  public static Set<Double> diversitySet(String topicType, double alpha, long uid) {
    Set<Double> valueSet = new HashSet<Double>();
    if(topicType.equals("lda")) {
      FullLDAClassification c = new FullLDAClassification(uid,1000,100,0,alpha);
      for(String cat : c.getCategorySet()) {
        valueSet.add(c.getScore(cat));
      }
    } else if(topicType.equals("alchemy")) {
      FullLLDAClassification c = new FullLLDAClassification("alchemy",alpha,uid);
      for(String cat : c.getCategorySet()) {
        valueSet.add(c.getScore(cat));
      }
    } else if(topicType.equals("calais")) {
      FullLLDAClassification c = new FullLLDAClassification("calais",alpha,uid);
      for(String cat : c.getCategorySet()) {
        valueSet.add(c.getScore(cat));
      }
    } else if(topicType.equals("textwise")) {
      FullLLDAClassification c = new FullLLDAClassification("textwise",alpha,uid);
      for(String cat : c.getCategorySet()) {
        valueSet.add(c.getScore(cat));
      }
    }
    return valueSet;
  }
 
  public static double simpson(String topicType, long uid) {
    return simpson(diversitySet(topicType, uid));
  }
 
  public static double simpson(String topicType, double alpha, long uid) {
    return simpson(diversitySet(topicType,alpha,uid));
  }
 
  public static double simpson(Set<Double> values) {
    double result = 0.0;
    for(Double value : values) {
      result += value*value;
    }
    return result;
  }
 
  public static double shannon(String topicType, long uid) {
    return shannon(diversitySet(topicType, uid));
  }
 
  public static double shannon(String topicType, double alpha, long uid) {
    return shannon(diversitySet(topicType,alpha,uid));
  }
 
  public static double shannon(Set<Double> values) {
    double sum = 0.0;
    for(Double value : values) {
      sum += value * Math.log(value);
    }
    return -1.0*sum;
  }
 
  //saves baseline and inferred API - no LDA
  public static void saveDiversities() {
    String[] topicTypes = {"alchemy","calais","textwise"};
    double[] alphas = {0.25,0.5,0.75,1.0,1.25,1.5,1.75,2.0};
    try {
      for(String topicType : topicTypes) {
        System.out.println("Diversities for "+topicType);
        //get the baselines
        String fileName = "diversities/baseline-"+topicType;
        FileOutputStream simpsonFileOut = new FileOutputStream(fileName+"-simpson.csv");
        FileOutputStream shannonFileOut = new FileOutputStream(fileName+"-shannon.csv");
        PrintWriter simpsonWriteOut = new PrintWriter(simpsonFileOut);
        PrintWriter shannonWriteOut = new PrintWriter(shannonFileOut);
        simpsonWriteOut.println("\"uid\",\"diversity\"");
        for(long uid : Tools.getCSVUserIDs()) {
          simpsonWriteOut.println(uid+","+Diversity.simpson(topicType, uid));
        }
        simpsonWriteOut.close();
        shannonWriteOut.println("\"uid\",\"diversity\"");
        for(long uid : Tools.getCSVUserIDs()) {
          shannonWriteOut.println(uid+","+Diversity.shannon(topicType, uid));
        }
        shannonWriteOut.close();
        for(double alpha : alphas) {
          System.out.println("Simpson Diversities for alpha "+alpha);
          fileName = "diversities/llda-"+topicType+"-"+alpha;
          FileOutputStream lldaSimpsonFileOut = new FileOutputStream(fileName+"-simpson.csv");
          PrintWriter lldaSimpsonWriteOut = new PrintWriter(lldaSimpsonFileOut);
          lldaSimpsonWriteOut.println("\"uid\",\"diversity\"");
          for(long uid : Tools.getCSVUserIDs()) {
            lldaSimpsonWriteOut.println(uid+","+Diversity.simpson(topicType, alpha, uid));
          }
          lldaSimpsonWriteOut.close();
        }
        for(double alpha : alphas) {
          System.out.println("Shannon Diversities for alpha "+alpha);
          fileName = "diversities/llda-"+topicType+"-"+alpha;
          FileOutputStream lldaShannonFileOut = new FileOutputStream(fileName+"-shannon.csv");
          PrintWriter lldaShannonWriteOut = new PrintWriter(lldaShannonFileOut);
          lldaShannonWriteOut.println("\"uid\",\"diversity\"");
          for(long uid : Tools.getCSVUserIDs()) {
            lldaShannonWriteOut.println(uid+","+Diversity.shannon(topicType, alpha, uid));
          }
          lldaShannonWriteOut.close();
        }
      }
    } catch (IOException e) {
      e.printStackTrace();
    }
  }
 
  public static Map<Long,Double> loadDiversities(String file) {
    Map<Long,Double> result = new HashMap<Long,Double>();
    String nextLine = "";
    String[] split = new String[2];
    try {
      FileInputStream fileIn = new FileInputStream(file+".csv");
      BufferedReader buffer = new BufferedReader(new InputStreamReader(fileIn));
      buffer.readLine(); //skip past the CSV descriptor
      while(true) {
        nextLine = buffer.readLine();
        //If nextLine is null, we still have to save the final profile!
        if(nextLine == null) {
          break;
        }
        split = nextLine.split(",");
        result.put(Long.parseLong(split[0]), Double.parseDouble(split[1]));
      }
    } catch (Exception e) {
      System.out.println(nextLine);
      System.out.println(split[0]);
      e.printStackTrace();
    }
    return result;
  }
 
  public static Map<Long,Double> loadDiversities(String topicType, String diversityType) {
    String file = "diversities/baseline-"+topicType+"-"+diversityType;
    return loadDiversities(file);
  }
 
  public static Map<Long,Double> loadDiversities(String topicType, double alpha, String diversityType) {
    String file = "diversities/llda-"+topicType+"-"+alpha+"-"+diversityType;
    return loadDiversities(file);
  }

}
TOP

Related Classes of uk.ac.cam.ha293.tweetlabel.eval.Diversity

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by ORACLE Inc. Contact: coftware#gmail.com.