Package edu.cmu.sphinx.linguist.g2p

Source Code of edu.cmu.sphinx.linguist.g2p.G2PConverter

/**
*
* Copyright 1999-2012 Carnegie Mellon University. 
* Portions Copyright 2002 Sun Microsystems, Inc. 
* Portions Copyright 2002 Mitsubishi Electric Research Laboratories.
* All Rights Reserved.  Use is subject to license terms.
*
* See the file "license.terms" for information on usage and
* redistribution of this file, and for a DISCLAIMER OF ALL
* WARRANTIES.
*
*/

package edu.cmu.sphinx.linguist.g2p;

import java.io.IOException;
import java.io.InputStream;
import java.net.URL;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.Queue;
import java.util.regex.Pattern;

import edu.cmu.sphinx.fst.Arc;
import edu.cmu.sphinx.fst.Fst;
import edu.cmu.sphinx.fst.ImmutableFst;
import edu.cmu.sphinx.fst.State;
import edu.cmu.sphinx.fst.operations.ArcSort;
import edu.cmu.sphinx.fst.operations.Compose;
import edu.cmu.sphinx.fst.operations.ILabelCompare;
import edu.cmu.sphinx.fst.operations.NShortestPaths;
import edu.cmu.sphinx.fst.operations.OLabelCompare;
import edu.cmu.sphinx.fst.operations.Project;
import edu.cmu.sphinx.fst.operations.ProjectType;
import edu.cmu.sphinx.fst.operations.RmEpsilon;
import edu.cmu.sphinx.fst.semiring.Semiring;
import edu.cmu.sphinx.fst.semiring.TropicalSemiring;
import edu.cmu.sphinx.fst.utils.Utils;

/**
* The grapheme-to-phoneme (g2p) decoder
*
* @author John Salatas <jsalatas@users.sourceforge.net>
*/
public class G2PConverter {

    // epsilon symbol
    String eps = "<eps>";

    // end sequence symbol
    String se = "</s>";

    // begin sequence symbol
    String sb = "<s>";

    // skip symbol
    String skip = "_";

    // separator symbol
    String tie = "|";

    // set containing sequences to ignore
    HashSet<String> skipSeqs = new HashSet<String>();

    // clusters
    ArrayList<String>[] clusters = null;

    // the g2p model
    ImmutableFst g2pmodel;

    // fst containing the epsilon filter for the compose operation
    Fst epsilonFilter;

    /**
     * Create a decoder by loading the serialized model from a specified URL
     *
     * @param g2pModelUrl
     *            the URL of the serialized model
     * @throws IOException
     * @throws ClassNotFoundException
     */
    public G2PConverter(URL g2pModelUrl) throws IOException {
        try {
            g2pmodel = ImmutableFst.loadModel(g2pModelUrl.openStream());
        } catch (ClassNotFoundException e) {
            throw new IOException("Failed to load the model from " + g2pModelUrl, e);
        }
        init();
    }

    /**
     * Create a decoder by loading the serialized model from a specified
     * filename
     *
     * @param g2pmodel_file
     *            the filename of the serialized model
     */
    public G2PConverter(String g2pmodel_file) {
        g2pmodel = ImmutableFst.loadModel(g2pmodel_file);
        init();
    }

    /**
     * Initialize the decoder
     */
    private void init() {
        skipSeqs.add(eps);
        skipSeqs.add(sb);
        skipSeqs.add(se);
        skipSeqs.add(skip);
        skipSeqs.add("-");
        // keep an augmented copy (for compose)
        Compose.augment(0, g2pmodel, g2pmodel.getSemiring());
        ArcSort.apply(g2pmodel, new ILabelCompare());

        String[] isyms = g2pmodel.getIsyms();

        loadClusters(isyms);

        // get epsilon filter for composition
        epsilonFilter = Compose.getFilter(g2pmodel.getIsyms(),
                g2pmodel.getSemiring());
        ArcSort.apply(epsilonFilter, new ILabelCompare());
    }

    /**
     * Phoneticize a word
     *
     * @param entry
     *            the word to phoneticize transformed to an ArrayList of Strings
     *            (each element hold a single character)
     * @param nbest
     *            the number of distinct pronunciations to return
     * @return the pronunciation(s) of the input word
     */
    public ArrayList<Path> phoneticize(ArrayList<String> entry, int nbest) {
        Fst efst = entryToFSA(entry);
        Semiring s = efst.getSemiring();
        Compose.augment(1, efst, s);
        ArcSort.apply(efst, new OLabelCompare());
        Fst result = Compose.compose(efst, epsilonFilter, s, true);
        ArcSort.apply(result, new OLabelCompare());
        result = Compose.compose(result, g2pmodel, s, true);
        Project.apply(result, ProjectType.OUTPUT);
        if (nbest == 1) {
            result = NShortestPaths.get(result, 1, false);
        } else {
            // Requesting 10 times more best paths than what was asking
            // as there might be several paths resolving to same pronunciation
            // due to epsilon transitions.
            // I really hate cosmological constants :)
            result = NShortestPaths.get(result, nbest * 10, false);
        }
        // result = NShortestPaths.get(result, nbest, false);
        result = RmEpsilon.get(result);
        ArrayList<Path> paths = findAllPaths(result, nbest, skipSeqs,
                tie);

        return paths;
    }

    /**
     * Phoneticize a word
     *
     * @param entry
     *            the word to phoneticize
     * @param nbest
     *            the number of distinct pronunciations to return
     * @return the pronunciation(s) of the input word
     */
    public ArrayList<Path> phoneticize(String word, int nbest) {
        ArrayList<String> entry = new ArrayList<String>(word.length());
        for (int i = 0; i < word.length(); i++) {
            String ch = word.substring(i, i + 1);
            if (Utils.getIndex(g2pmodel.getIsyms(), ch) >= 0) {
                entry.add(ch);
            }
        }
        return phoneticize(entry, nbest);
    }

    /**
     * Transforms an input spelling/pronunciation into an equivalent FSA, adding
     * extra arcs as needed to accommodate clusters.
     *
     * @param entry
     *            the input vector
     * @return the created fst
     */
    private Fst entryToFSA(ArrayList<String> entry) {
        TropicalSemiring ts = new TropicalSemiring();
        Fst efst = new Fst(ts);

        State s = new State(ts.zero());
        efst.addState(s);
        efst.setStart(s);

        // Build the basic FSA
        for (int i = 0; i < entry.size() + 1; i++) {
            s = new State(ts.zero());
            efst.addState(s);
            if (i >= 1) {
                int symIndex = Utils.getIndex(g2pmodel.getIsyms(),
                        entry.get(i - 1));
                efst.getState(i).addArc(new Arc(symIndex, symIndex, 0.f, s));
            } else if (i == 0) {
                int symIndex = Utils.getIndex(g2pmodel.getIsyms(), sb);
                efst.getStart().addArc(new Arc(symIndex, symIndex, 0.f, s));
            }

            if (i == entry.size()) {
                State s1 = new State(ts.zero());
                efst.addState(s1);
                int symIndex = Utils.getIndex(g2pmodel.getIsyms(), se);
                s.addArc(new Arc(symIndex, symIndex, 0.f, s1));
                s1.setFinalWeight(0.f);
            }
        }

        // Add any cluster arcs
        for (int value = 0; value < clusters.length; value++) {
            ArrayList<String> cluster = clusters[value];
            if (cluster != null) {
                int start = 0;
                int k = 0;
                while (k != -1) {
                    k = Utils.search(entry, cluster, start);
                    if (k != -1) {
                        State from = efst.getState(start + k + 1);
                        from.addArc(new Arc(value, value, 0.f, efst
                                .getState(start + k + cluster.size() + 1)));
                        start = start + k + cluster.size();
                    }
                }
            }
        }

        efst.setIsyms(g2pmodel.getIsyms());
        efst.setOsyms(g2pmodel.getIsyms());

        return efst;
    }

    /**
     * Finds nbest paths in an Fst returned by NShortestPaths operation
     *
     * @param fst
     *            the input fst
     * @param nbest
     *            the number of paths to return
     * @param skipSeqs
     *            the sequences to ignore
     * @param tie
     *            the separator symbol
     * @return the paths
     */
    @SuppressWarnings("unchecked")
    private ArrayList<Path> findAllPaths(Fst fst, int nbest,
            HashSet<String> skipSeqs, String tie) {
        Semiring semiring = fst.getSemiring();

        // ArrayList<Path> finalPaths = new ArrayList<Path>();
        HashMap<String, Path> finalPaths = new HashMap<String, Path>();
        HashMap<State, Path> paths = new HashMap<State, Path>();
        Queue<State> queue = new LinkedList<State>();
        Path p = new Path(fst.getSemiring());
        p.setCost(semiring.one());
        paths.put(fst.getStart(), p);

        queue.add(fst.getStart());

        String[] osyms = fst.getOsyms();
        while (!queue.isEmpty()) {
            State s = queue.remove();
            Path currentPath = paths.get(s);

            if (s.getFinalWeight() != semiring.zero()) {
                String pathString = currentPath.getPath().toString();
                if (finalPaths.containsKey(pathString)) {
                    // path already exist. update its cost
                    Path old = finalPaths.get(pathString);
                    if (old.getCost() > currentPath.getCost()) {
                        finalPaths.put(pathString, currentPath);
                    }
                } else {
                    finalPaths.put(pathString, currentPath);
                }
            }

            int numArcs = s.getNumArcs();
            for (int j = 0; j < numArcs; j++) {
                Arc a = s.getArc(j);
                p = new Path(fst.getSemiring());
                Path cur = paths.get(s);
                p.setCost(cur.getCost());
                p.setPath((ArrayList<String>) cur.getPath().clone());

                String sym = osyms[a.getOlabel()];

                String[] symsArray = sym.split("\\" + tie);

                for (int i = 0; i < symsArray.length; i++) {
                    String phone = symsArray[i];
                    if (!skipSeqs.contains(phone)) {
                        p.getPath().add(phone);
                    }
                }
                p.setCost(semiring.times(p.getCost(), a.getWeight()));
                State nextState = a.getNextState();
                paths.put(nextState, p);
                if (!queue.contains(nextState)) {
                    queue.add(nextState);
                }
            }
        }

        ArrayList<Path> res = new ArrayList<Path>();
        for (Path path : finalPaths.values()) {
            res.add(path);
        }

        Collections.sort(res, new PathComparator());
        int numPaths = res.size();
        for (int i = nbest; i < numPaths; i++) {
            res.remove(res.size() - 1);
        }

        return res;
    }

    /**
     * Initialize clusters
     */
    @SuppressWarnings("unchecked")
    private void loadClusters(String[] syms) {
        clusters = new ArrayList[syms.length];
        for (int i = 0; i < syms.length; i++) {
            clusters[i] = null;
        }
        for (int i = 2; i < syms.length; i++) {
            String sym = syms[i];
            if (sym.contains(tie)) {
                String split[] = sym.split(Pattern.quote(tie));
                ArrayList<String> cluster = new ArrayList<String>(
                        Arrays.asList(split));
                clusters[i] = cluster;
            }
        }
    }
}
TOP

Related Classes of edu.cmu.sphinx.linguist.g2p.G2PConverter

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by ORACLE Inc. Contact coftware#gmail.com.