Package joshua.decoder.ff.lm.distributed_lm

Source Code of joshua.decoder.ff.lm.distributed_lm.LMServer

/* This file is part of the Joshua Machine Translation System.
*
* Joshua is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2.1
* of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this library; if not, write to the Free
* Software Foundation, Inc., 59 Temple Place, Suite 330, Boston,
* MA 02111-1307 USA
*/
package joshua.decoder.ff.lm.distributed_lm;

import joshua.corpus.vocab.BuildinSymbol;
import joshua.corpus.vocab.SrilmSymbol;
import joshua.corpus.vocab.SymbolTable;
import joshua.decoder.ff.lm.NGramLanguageModel;
import joshua.decoder.ff.lm.buildin_lm.LMGrammarJAVA;
import joshua.decoder.ff.lm.srilm.LMGrammarSRILM;
import joshua.util.io.LineReader;
import joshua.util.Regex;

import java.io.IOException;
//import java.net.InetAddress;
import java.net.ServerSocket;
import java.net.Socket;
//import java.net.UnknownHostException;
import java.util.HashMap;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;

/**
* This class implements a language model (LM) server that
* (1) loads the LM file,
* (2) listens for connection requests, and
* (3) serves requests for LM probabilities.
*
* @author Zhifei Li, <zhifei.work@gmail.com>
* @version $LastChangedDate: 2009-05-19 20:43:54 -0500 (Tue, 19 May 2009) $
*/
public class LMServer {
 
  private static final Logger logger = Logger.getLogger(LMServer.class.getName());
 
  //common options
  public static int port = 9800;
  static boolean use_srilm = true;
  public static boolean use_left_euqivalent_state = false;
  public static boolean use_right_euqivalent_state = false;
  static int g_lm_order = 3;
  static double lm_ceiling_cost = 100;//TODO: make sure LMGrammar is using this number
  static String remote_symbol_tbl = null;
 
  //lm specific
  static String lm_file              = null;
  static Double interpolation_weight = null;//the interpolation weight of this lm
  static String g_host_name          = null;
 
  //pointer
  static NGramLanguageModel p_lm;
  static HashMap<String,String> request_cache = new HashMap<String,String>();//cmd with result
  static int cache_size_limit = 3000000;
 
  //  stat
  static int g_n_request   = 0;
  static int g_n_cache_hit = 0;
 
  static SymbolTable p_symbolTable;
 
 
  public static void main(String[] args) throws IOException {
    if (args.length != 1) {
      System.err.println("Usage: java LMServer config_file");
     
      if (logger.isLoggable(Level.FINE)) {
        logger.fine("num of args is "+ args.length);
        for (int i = 0; i < args.length; i++) {
          logger.fine("arg is: " + args[i]);
        }
      }
     
      System.exit(1);
    }
    String config_file = args[0].trim();
    read_config_file(config_file);
   
    ServerSocket serverSocket = null;
    LMServer server = new LMServer();
   
    //p_lm.write_vocab_map_srilm(remote_symbol_tbl);
    //####write host infomation
    //String hostname=LMServer.findHostName();//this one is not stable, sometimes throw exception
    //String hostname="unknown";
   
   
    //### begin loop
    try {
      serverSocket = new ServerSocket(port);
      if (null == serverSocket) {
        throw new IOException("server socket is null");
      }
      init_lm_grammar();
     
      logger.info("finished lm reading, wait for connection");
     
      // serverSocket = new ServerSocket(0);//0 means any free port
      // port = serverSocket.getLocalPort();
      while (true) {
        Socket socket = serverSocket.accept();
        logger.info("accept a connection from client");
        ClientHandler handler = new ClientHandler(socket,server);
        handler.start();
      }
    } catch (IOException ioe) {
      logger.severe("cannot create serversocket at port or connection fail");
      ioe.printStackTrace();
    } finally {
      try {
        if (null != serverSocket) serverSocket.close();
      } catch(IOException ioe) {
        ioe.printStackTrace();
      }
    }
  }
 
 
  // BUG: duplicates initializeLanguageModel and initializeSymbolTable in JoshuaDecoder, needs unifying
  public static void init_lm_grammar() throws IOException {
    if (use_srilm) {
      if (use_left_euqivalent_state || use_right_euqivalent_state) {
        throw new IllegalArgumentException("when using local srilm, we cannot use suffix stuff");
      }
      p_symbolTable = new SrilmSymbol(remote_symbol_tbl, g_lm_order);
      p_lm = new LMGrammarSRILM((SrilmSymbol)p_symbolTable, g_lm_order, lm_file);
     
    } else {
      //p_lm = new LMGrammar_JAVA(g_lm_order, lm_file, use_left_euqivalent_state);
      //big bug: should load the consistent symbol files
      p_symbolTable = new BuildinSymbol(remote_symbol_tbl);
      p_lm = new LMGrammarJAVA((BuildinSymbol)p_symbolTable, g_lm_order, lm_file, use_left_euqivalent_state, use_right_euqivalent_state);
    }
  }
 
 
 
  // BUG: this is duplicating code in JoshuaConfiguration, needs unifying
  public static void read_config_file(String config_file)
  throws IOException {
   
    LineReader configReader = new LineReader(config_file);
    try { for (String line : configReader) {
      //line = line.trim().toLowerCase();
      line = line.trim();
      if (Regex.commentOrEmptyLine.matches(line)) continue;
     
      if (line.indexOf("=") != -1) { //parameters
        String[] fds = Regex.equalsWithSpaces.split(line);
        if (fds.length != 2) {
          throw new IllegalArgumentException("Wrong config line: " + line);
        }
        if ("lm_file".equals(fds[0])) {
          lm_file = fds[1].trim();
          if (logger.isLoggable(Level.FINE))
            logger.fine(String.format("lm file: %s", lm_file));
         
        } else if ("use_srilm".equals(fds[0])) {
          use_srilm = Boolean.valueOf(fds[1]);
          if (logger.isLoggable(Level.FINE))
            logger.fine(String.format("use_srilm: %s", use_srilm));
         
        } else if ("lm_ceiling_cost".equals(fds[0])) {
          lm_ceiling_cost = Double.parseDouble(fds[1]);
          if (logger.isLoggable(Level.FINE))
            logger.fine(String.format("lm_ceiling_cost: %s", lm_ceiling_cost));
         
        } else if ("use_left_euqivalent_state".equals(fds[0])) {
          use_left_euqivalent_state = Boolean.valueOf(fds[1]);
          if (logger.isLoggable(Level.FINE))
            logger.fine(String.format("use_left_euqivalent_state: %s", use_left_euqivalent_state));
         
        } else if ("use_right_euqivalent_state".equals(fds[0])) {
          use_right_euqivalent_state = Boolean.valueOf(fds[1]);
          if (logger.isLoggable(Level.FINE))
            logger.fine(String.format("use_right_euqivalent_state: %s", use_right_euqivalent_state));
         
        } else if ("order".equals(fds[0])) {
          g_lm_order = Integer.parseInt(fds[1]);
          if (logger.isLoggable(Level.FINE))
            logger.fine(String.format("g_lm_order: %s", g_lm_order));
         
        } else if ("remote_lm_server_port".equals(fds[0])) {
          port = Integer.parseInt(fds[1]);
          if (logger.isLoggable(Level.FINE))
            logger.fine(String.format("remote_lm_server_port: %s", port));
         
        } else if ("remote_symbol_tbl".equals(fds[0])) {
          remote_symbol_tbl = fds[1];
          if (logger.isLoggable(Level.FINE))
            logger.fine(String.format("remote_symbol_tbl: %s", remote_symbol_tbl));
         
        } else if ("hostname".equals(fds[0])) {
          g_host_name = fds[1].trim();
          if (logger.isLoggable(Level.FINE))
            logger.fine(String.format("host name is: %s", g_host_name));
         
        } else if ("interpolation_weight".equals(fds[0])) {
          interpolation_weight = Double.parseDouble(fds[1]);
          if (logger.isLoggable(Level.FINE))
            logger.fine(String.format("interpolation_weightt: %s", interpolation_weight));
         
        } else {
          logger.warning("LMServer doesn't use config line: " + line);
          //System.exit(1);
        }
      }
    } } finally { configReader.close(); }
  }
 
 
  // used by server to process diffent Client
  public static class ClientHandler extends Thread {
    public static class DecodedStructure {
      String cmd;
      int    num;
      int[]  wrds;
    }
   
    LMServer               parent;
    private Socket         socket;
    private BufferedReader in;
    private PrintWriter    out;
   
   
    public ClientHandler(Socket sock, LMServer pa) throws IOException {
      parent = pa;
      socket = sock;
      in = new BufferedReader(
        new InputStreamReader(socket.getInputStream()));
      out = new PrintWriter(
        new OutputStreamWriter(socket.getOutputStream()));
    }
   
   
    public void run() {
      String line_in;
      String line_out;
      try {
        while ((line_in = in.readLine()) != null) {
          //TODO block read
          //System.out.println("coming in: " + line);
          //line_out = process_request(line_in);
          line_out = process_request_no_cache(line_in);
         
          out.println(line_out);
          out.flush();
        }
      } catch(IOException ioe) {
        ioe.printStackTrace();
      } finally {
        try {
          in.close();
          out.close();
          socket.close();
        } catch(IOException ioe) {
          ioe.printStackTrace();
        }
      }
    }
   
   
    private String process_request_no_cache(String packet) {
      //search cache
      g_n_request++;
      String cmd_res = process_request_helper(packet);
      if (logger.isLoggable(Level.FINE) && g_n_request % 50000 == 0) {
        logger.fine("n_requests: " + g_n_request);
      }
      return cmd_res;
    }
   
   
   
    //This is the funciton that application specific
    private String process_request_helper(String line) {
      DecodedStructure ds = decode_packet(line);
     
      if ("prob".equals(ds.cmd)) {
        return get_prob(ds);
      } else if ("prob_bow".equals(ds.cmd)) {
        return get_prob_backoff_state(ds);
      } else if ("equiv_left".equals(ds.cmd)) {
        return get_left_equiv_state(ds);
      } else if ("equiv_right".equals(ds.cmd)) {
        return get_right_equiv_state(ds);
      } else {
        logger.severe("error : Wrong request line: " + line);
        //System.exit(1);
        return "";
      }
    }
   
   
    // format: prob order wrds
    private String get_prob(DecodedStructure ds) {
      return Double.toString(p_lm.ngramLogProbability(ds.wrds, ds.num));
    }
   
   
    // format: prob order wrds
    private String get_prob_backoff_state(DecodedStructure ds) {
      throw new RuntimeException("call get_prob_backoff_state in lmserver, must exit");
     
      /*Double res = p_lm.get_prob_backoff_state(ds.wrds, ds.num, ds.num);
      return res.toString();*/
    }
   
   
    // format: prob order wrds
    private String get_left_equiv_state(DecodedStructure ds) {
      throw new RuntimeException("call get_left_equiv_state in lmserver, must exit");
    }
   
   
    // format: prob order wrds
    private String get_right_equiv_state(DecodedStructure ds) {
      throw new RuntimeException("call get_right_equiv_state in lmserver, must exit");
    }
   
   
    private DecodedStructure decode_packet(String packet) {
      String[] fds         = Regex.spaces.split(packet);
      DecodedStructure res = new DecodedStructure();
      res.cmd              = fds[0].trim();
      res.num              = Integer.parseInt(fds[1]);
      int[] wrds           = new int[fds.length-2];
     
      for (int i = 2; i < fds.length; i++) {
        wrds[i-2] = Integer.parseInt(fds[i]);
      }
      res.wrds = wrds;
      return res;
    }
  }
}
TOP

Related Classes of joshua.decoder.ff.lm.distributed_lm.LMServer

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.