// Package edu.cmu.sphinx.fst.operations

// Source Code of edu.cmu.sphinx.fst.operations.NShortestPaths

/**
*
* Copyright 1999-2012 Carnegie Mellon University. 
* Portions Copyright 2002 Sun Microsystems, Inc. 
* Portions Copyright 2002 Mitsubishi Electric Research Laboratories.
* All Rights Reserved.  Use is subject to license terms.
*
* See the file "license.terms" for information on usage and
* redistribution of this file, and for a DISCLAIMER OF ALL
* WARRANTIES.
*
*/

package edu.cmu.sphinx.fst.operations;

import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.PriorityQueue;

import edu.cmu.sphinx.fst.Arc;
import edu.cmu.sphinx.fst.Fst;
import edu.cmu.sphinx.fst.State;
import edu.cmu.sphinx.fst.utils.Pair;
import edu.cmu.sphinx.fst.semiring.Semiring;

/**
* N-shortest paths operation.
*
* See: M. Mohri, M. Riley,
* "An Efficient Algorithm for the n-best-strings problem", Proceedings of the
* International Conference on Spoken Language Processing 2002 (ICSLP '02).
*
* See: M. Mohri,
* "Semiring Framework and Algorithms for Shortest-Distance Problems", Journal
* of Automata, Languages and Combinatorics, 7(3), pp. 321-350, 2002.
*
* @author John Salatas <jsalatas@users.sourceforge.net>
*
*/
public class NShortestPaths {
    private NShortestPaths() {
    }

    /**
     * Calculates the shortest distances from each state to the final
     *
     * @param fst
     *            the fst to calculate the shortest distances
     * @return the array containing the shortest distances
     */
    public static float[] shortestDistance(Fst fst) {

        Fst reversed = Reverse.get(fst);

        float[] d = new float[reversed.getNumStates()];
        float[] r = new float[reversed.getNumStates()];

        Semiring semiring = reversed.getSemiring();

        Arrays.fill(d, semiring.zero());
        Arrays.fill(r, semiring.zero());

        LinkedHashSet<State> queue = new LinkedHashSet<State>();

        queue.add(reversed.getStart());

        d[reversed.getStart().getId()] = semiring.one();
        r[reversed.getStart().getId()] = semiring.one();

        while (!queue.isEmpty()) {
            State q = queue.iterator().next();
            queue.remove(q);

            float rnew = r[q.getId()];
            r[q.getId()] = semiring.zero();

            for (int i = 0; i < q.getNumArcs(); i++) {
                Arc a = q.getArc(i);
                State nextState = a.getNextState();
                float dnext = d[a.getNextState().getId()];
                float dnextnew = semiring.plus(dnext,
                        semiring.times(rnew, a.getWeight()));
                if (dnext != dnextnew) {
                    d[a.getNextState().getId()] = dnextnew;
                    r[a.getNextState().getId()] = semiring.plus(r[a
                            .getNextState().getId()], semiring.times(rnew,
                            a.getWeight()));
                    if (!queue.contains(nextState.getId())) {
                        queue.add(nextState);
                    }
                }
            }
        }
        return d;
    }

    /**
     * Calculates the n-best shortest path from the initial to the final state.
     *
     * @param fst
     *            the fst to calculate the nbest shortest paths
     * @param n
     *            number of best paths to return
     * @param determinize
     *            if true the input fst will bwe determinized prior the
     *            operation
     * @return an fst containing the n-best shortest paths
     */
    public static Fst get(Fst fst, int n, boolean determinize) {
        if (fst == null) {
            return null;
        }

        if (fst.getSemiring() == null) {
            return null;
        }
        Fst fstdet = fst;
        if (determinize) {
            fstdet = Determinize.get(fst);
        }
        final Semiring semiring = fstdet.getSemiring();
        Fst res = new Fst(semiring);
        res.setIsyms(fstdet.getIsyms());
        res.setOsyms(fstdet.getOsyms());

        final float[] d = shortestDistance(fstdet);

        ExtendFinal.apply(fstdet);

        int[] r = new int[fstdet.getNumStates()];

        PriorityQueue<Pair<State, Float>> queue = new PriorityQueue<Pair<State, Float>>(
                10, new Comparator<Pair<State, Float>>() {

                    public int compare(Pair<State, Float> o1,
                            Pair<State, Float> o2) {
                        float previous = o1.getRight();
                        float d1 = d[o1.getLeft().getId()];

                        float next = o2.getRight();
                        float d2 = d[o2.getLeft().getId()];

                        float a1 = semiring.times(next, d2);
                        float a2 = semiring.times(previous, d1);

                        if (semiring.naturalLess(a1, a2))
                            return 1;

                        if (a1 == a2)
                            return 0;

                        return -1;
                    }
                });

        HashMap<Pair<State, Float>, Pair<State, Float>> previous = new HashMap<Pair<State, Float>, Pair<State, Float>>(
                fst.getNumStates());
        HashMap<Pair<State, Float>, State> stateMap = new HashMap<Pair<State, Float>, State>(
                fst.getNumStates());

        State start = fstdet.getStart();
        Pair<State, Float> item = new Pair<State, Float>(start, semiring.one());
        queue.add(item);
        previous.put(item, null);

        while (!queue.isEmpty()) {
            Pair<State, Float> pair = queue.remove();
            State p = pair.getLeft();
            Float c = pair.getRight();

            State s = new State(p.getFinalWeight());
            res.addState(s);
            stateMap.put(pair, s);
            if (previous.get(pair) == null) {
                // this is the start state
                res.setStart(s);
            } else {
                // add the incoming arc from previous to current
                State previouState = stateMap.get(previous.get(pair));
                State previousOldState = previous.get(pair).getLeft();
                for (int j = 0; j < previousOldState.getNumArcs(); j++) {
                    Arc a = previousOldState.getArc(j);
                    if (a.getNextState().equals(p)) {
                        previouState.addArc(new Arc(a.getIlabel(), a
                                .getOlabel(), a.getWeight(), s));
                    }
                }
            }

            Integer stateIndex = p.getId();
            r[stateIndex]++;

            if ((r[stateIndex] == n) && (p.getFinalWeight() != semiring.zero())) {
                break;
            }

            if (r[stateIndex] <= n) {
                for (int j = 0; j < p.getNumArcs(); j++) {
                    Arc a = p.getArc(j);
                    float cnew = semiring.times(c, a.getWeight());
                    Pair<State, Float> next = new Pair<State, Float>(
                            a.getNextState(), cnew);
                    previous.put(next, pair);
                    queue.add(next);
                }
            }
        }

        return res;
    }
}
// TOP

// Related Classes of edu.cmu.sphinx.fst.operations.NShortestPaths

// TOP
// Copyright © 2018 www.massapi.com. All rights reserved.
// All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.