// Package uk.ac.cam.ha293.tweetlabel.topics
//
// Source code of uk.ac.cam.ha293.tweetlabel.topics.LightweightLLDA

package uk.ac.cam.ha293.tweetlabel.topics;

import java.io.FileInputStream;

import javax.jdo.annotations.IdGeneratorStrategy;
import javax.jdo.annotations.Key;
import javax.jdo.annotations.Persistent;
import javax.jdo.annotations.PrimaryKey;

import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.PrintWriter;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;

import uk.ac.cam.ha293.tweetlabel.twitter.SimpleProfile;
import uk.ac.cam.ha293.tweetlabel.twitter.SimpleTweet;
import uk.ac.cam.ha293.tweetlabel.types.Document;
import uk.ac.cam.ha293.tweetlabel.util.Tools;
import javax.jdo.annotations.PersistenceCapable;

@PersistenceCapable
public class LightweightLLDA implements Serializable {
 
  private static final long serialVersionUID = -7178136504829779391L;
  @PrimaryKey
    @Persistent(valueStrategy = IdGeneratorStrategy.IDENTITY)
    private Key key;
  transient private Map<String, Integer> wordIDs;
  @Persistent
  private ArrayList<String> topicLookup;
  @Persistent
  private int numTopics;
  @Persistent
  private int numWords;
  @Persistent
  private int[][] wordTopicAssignments;
  @Persistent
  private int[] numWordsAssignedToTopic;
 
  public LightweightLLDA(Map<String,Integer> wordIDs, ArrayList<String> topicLookup, int numTopics, int numWords, int[][] wordTopicAssignments, int[] numWordsAssignedToTopic) {
    this.wordIDs = wordIDs;
    this.topicLookup = topicLookup;
    this.numTopics = numTopics;
    this.numWords = numWords;
    this.wordTopicAssignments = wordTopicAssignments;
    this.numWordsAssignedToTopic = numWordsAssignedToTopic;
  }

  public Map<String,Double> inferTopicDistribution(SimpleProfile sp, int burnIn, int sampling, double alpha, double beta) {
      //Get FV from SP
      Document d = sp.asDocument();
    String[] tokens = d.getDocumentString().split("\\s+");
    ArrayList<Integer> fv = new ArrayList<Integer>();

    for(int i=0; i<tokens.length; i++) {
      if(wordIDs.containsKey(tokens[i])) {
        fv.add(wordIDs.get(tokens[i]));
      }
    }
   
    //Run Gibbs Sampler again
    int[] z = new int[fv.size()];
    int[] zCounts = new int[numTopics];
    for(int n=0; n<z.length; n++) {
      //Random topic assignments
      z[n] = (int)(Math.random()*numTopics);
      zCounts[z[n]]++;
    }
    double[] thetam = new double[numTopics];
    for(int i=0; i<burnIn + sampling; i++) {
      System.out.print(".");
      for(int n=0; n<fv.size(); n++) {
        numWordsAssignedToTopic[z[n]]--;
        wordTopicAssignments[fv.get(n)][z[n]]--;
        //Cumulative multinomial sampling
        double[] p = new double[numTopics];
        for(int k=0; k<numTopics; k++) {
            p[k] = (wordTopicAssignments[fv.get(n)][k] + beta) / (numWordsAssignedToTopic[k] + numWords * beta) * (zCounts[k] + alpha) / (fv.size() + numTopics * alpha);
        }
       
        //Cumulative part - could be combined into above part cleverly
        for(int k=1; k<numTopics; k++) {
          p[k] += p[k-1];
        }
       
        //Sampling part - scaled because we haven't normalised
        double topicThreshold = Math.random() * p[numTopics-1];
        int sampledTopicID = 0;
        for(sampledTopicID=0; sampledTopicID<numTopics; sampledTopicID++) {
          if(topicThreshold < p[sampledTopicID]) {
            break;
          }
        }
       
        //Maybe a fix needed by faulty double arithmetic
        if(sampledTopicID >= numTopics) {
          sampledTopicID = numTopics-1;
        }
       
        zCounts[z[n]]--;
        z[n] = sampledTopicID;
        numWordsAssignedToTopic[z[n]]++;
        wordTopicAssignments[fv.get(n)][z[n]]++;
        zCounts[z[n]]++;
      }
     
      if(i >= burnIn) {
        for(int topicID=0; topicID<numTopics; topicID++) {
          thetam[topicID] += (zCounts[topicID] + alpha) / (fv.size() + numTopics * alpha);
        }
      }
    }
    System.out.println();
    //normalise theta and store
    Map<String,Double> results = new HashMap<String,Double>();
    for(int k=0; k<numTopics; k++) {
      thetam[k] /= sampling;
      results.put(topicLookup.get(k),thetam[k]);
    }
   
    return Tools.sortMapByValueDesc(results);
    }
 
    public void save(String name) {
    try {
      String filename = "models/lightweight/"+name+".model";
      FileOutputStream fileOut = new FileOutputStream(filename);
      ObjectOutputStream objectOut = new ObjectOutputStream(fileOut);
      objectOut.writeObject(this);
      objectOut.close();
    } catch (FileNotFoundException e) {
      System.out.println("Couldn't save LLDA topic model "+name);
      e.printStackTrace();
    } catch (IOException e) {
      System.out.println("Couldn't save LLDA topic model "+name);
      e.printStackTrace();     
    }     
    }
   
    public static LightweightLLDA load(String path) {
    try {
      String filename = path;
      FileInputStream fileIn = new FileInputStream(filename);
      ObjectInputStream objectIn = new ObjectInputStream(fileIn);
      LightweightLLDA model = (LightweightLLDA)objectIn.readObject();
      objectIn.close();
      System.out.println("Loaded LLDA topic model "+path);
      return model;
    } catch (FileNotFoundException e) {
      System.out.println("Couldn't load LLDA topic model "+path);
      e.printStackTrace();
    } catch (IOException e) {
      System.out.println("Couldn't load LLDA topic model "+path);
      e.printStackTrace();     
    } catch (ClassNotFoundException e) {
      System.out.println("Couldn't load LLDA topic model "+path);
      e.printStackTrace();     
    }
    return null;
    }
   
    public void loadWordIDs(String path) {
    try {
      String filename = path;
      FileInputStream fileIn = new FileInputStream(filename);
      ObjectInputStream objectIn = new ObjectInputStream(fileIn);
      wordIDs = (Map<String,Integer>)objectIn.readObject();
      objectIn.close();
      System.out.println("Loaded wordIDs");
    } catch (FileNotFoundException e) {
      System.out.println("Couldn't load wordIDS");
      e.printStackTrace();
    } catch (IOException e) {
      System.out.println("Couldn't load wordIDS");
      e.printStackTrace();     
    } catch (ClassNotFoundException e) {
      System.out.println("Couldn't load wordIDS");
      e.printStackTrace();     
    }
    }
   
    public void printStats() {
      System.out.println("num words = "+wordIDs.size()+" and "+numWords);
    System.out.println("num topics = "+topicLookup.size()+" and "+numTopics);
    }
   
    public void saveWordIDs(String name) {
    try {
      String filename = "models/lightweight/"+name+".wordids";
      FileOutputStream fileOut = new FileOutputStream(filename);
      ObjectOutputStream objectOut = new ObjectOutputStream(fileOut);
      objectOut.writeObject(wordIDs);
      objectOut.close();
    } catch (FileNotFoundException e) {
      System.out.println("Couldn't save wordIDs");
      e.printStackTrace();
    } catch (IOException e) {
      System.out.println("Couldn't save wordIDs");
      e.printStackTrace();     
    }     
    }
   
    public void helpMe() {
      System.out.print("{");
      for(int n=0; n<wordTopicAssignments.length; n++) {
        System.out.print("{");
        for(int k=0; k<wordTopicAssignments[n].length; k++) {
          System.out.print(wordTopicAssignments[n][k]);
          if(k < wordTopicAssignments[n].length-1) System.out.print(",");
        }
        System.out.print("}");
        if(n < wordTopicAssignments.length - 1) System.out.print(",");
      }
    }
   
    public void saveCSVComponents() {
      //save wordTopicAssignments
    try {
        //save wordTopicAssignments
      {String filename = "calaisWordTopicAssignments.csv";
      FileOutputStream fileOut = new FileOutputStream(filename);
      PrintWriter writeOut = new PrintWriter(fileOut);
      for(int v = 0; v<wordTopicAssignments.length; v++) {
        for(int k=0; k<wordTopicAssignments[v].length; k++) {
          writeOut.print(wordTopicAssignments[v][k]);
          if(k<wordTopicAssignments[v].length-1) writeOut.print(",");
        }
        writeOut.print("\n");
      }
      writeOut.close();
      fileOut.close();}
     
      //save wordIDs
      wordIDs = Tools.sortMapByValue(wordIDs);
      {String filename = "alchemyWordIDs.csv";
      FileOutputStream fileOut = new FileOutputStream(filename);
      PrintWriter writeOut = new PrintWriter(fileOut);
      for(String word : wordIDs.keySet()) {
        writeOut.print(wordIDs.get(word)+","+word+"\n");
      }
      writeOut.close();
      fileOut.close();}
     
      //save wordIDs
      {String filename = "alchemyTopics.csv";
      FileOutputStream fileOut = new FileOutputStream(filename);
      PrintWriter writeOut = new PrintWriter(fileOut);
      for(int i=0; i<topicLookup.size(); i++) {
        writeOut.print(i+","+topicLookup.get(i)+"\n");
      }
      writeOut.close();
      fileOut.close();}
     
      //save numWordsAssignedToTopic
      {String filename = "alchemyNumWordsAssignedToTopic.csv";
      FileOutputStream fileOut = new FileOutputStream(filename);
      PrintWriter writeOut = new PrintWriter(fileOut);
      for(int i=0; i<numWordsAssignedToTopic.length; i++) {
        writeOut.print(i+","+numWordsAssignedToTopic[i]+"\n");
      }
      writeOut.close();
      fileOut.close();}   
     
    } catch (IOException e) {
      e.printStackTrace();
    }
    }
   
    public void readOut() {
      System.out.println(numWords+","+numTopics);
    }

}
// TOP
//
// Related classes of uk.ac.cam.ha293.tweetlabel.topics.LightweightLLDA
//
// TOP
// Copyright © 2018 www.massapi.com. All rights reserved.
// All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by ORACLE Inc. Contact coftware#gmail.com.