Package edu.cmu.sphinx.linguist.acoustic.tiedstate.kaldi

Source Code of edu.cmu.sphinx.linguist.acoustic.tiedstate.kaldi.Triple

package edu.cmu.sphinx.linguist.acoustic.tiedstate.kaldi;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import edu.cmu.sphinx.util.LogMath;

/**
 * Describes one state of a phone topology: its identifier, its pdf class
 * and the identifiers of the states it has transitions to.
 */
final class HmmState {

    private final int id;
    private final int pdfClass;
    private final List<Integer> transitions;

    /**
     * Creates a state description.
     *
     * @param id          identifier of the state
     * @param pdfClass    pdf class of the state
     * @param transitions identifiers of the target states; the collection
     *                    is copied, so later changes to the argument are
     *                    not reflected in this object
     */
    public HmmState(int id, int pdfClass, Collection<Integer> transitions) {
        this.id = id;
        this.pdfClass = pdfClass;
        this.transitions = new ArrayList<Integer>(transitions);
    }

    /**
     * @return identifier of the state
     */
    public int getId() {
        return id;
    }

    /**
     * @return pdf class of the state
     */
    public int getPdfClass() {
        return pdfClass;
    }

    /**
     * Returns the target states of the outgoing transitions.
     *
     * @return the list held by this object (not a copy), in the iteration
     *         order of the collection passed to the constructor
     */
    public List<Integer> getTransitions() {
        return transitions;
    }

    /**
     * @return number of outgoing transitions
     */
    public int size() {
        return transitions.size();
    }

    @Override
    public String toString() {
        // The label used to be misspelled as "HmmSate".
        return String.format("HmmState {%d, %d, %s}",
                             id, pdfClass, transitions);
    }
}

/**
 * Value object combining a phone, an HMM state and a pdf identifier.
 *
 * <p>Instances are used as hash map keys, so all fields are final and
 * {@link #equals(Object)} / {@link #hashCode()} are defined over the three
 * components.
 */
final class Triple {

    private final int phone;
    private final int hmmState;
    private final int pdf;

    /**
     * Creates a triple.
     *
     * @param phone    phone identifier
     * @param hmmState HMM state identifier
     * @param pdf      pdf identifier
     */
    public Triple(int phone, int hmmState, int pdf) {
        this.phone = phone;
        this.hmmState = hmmState;
        this.pdf = pdf;
    }

    /**
     * Two triples are equal when all three components are equal.
     * Returns {@code false} for {@code null} and for objects of any
     * other type.
     */
    @Override
    public boolean equals(Object object) {
        if (this == object)
            return true;
        if (!(object instanceof Triple))
            return false;

        Triple other = (Triple) object;

        return phone    == other.phone &&
               hmmState == other.hmmState &&
               pdf      == other.pdf;
    }

    @Override
    public int hashCode() {
        return 31 * (31 * phone + hmmState) + pdf;
    }

    @Override
    public String toString() {
        return String.format("Triple {%d, %d, %d}", phone, hmmState, pdf);
    }
}

/**
* Represents transition model of a Kaldi acoustic model.
*/
public class TransitionModel {

    private Map<Integer, List<HmmState>> phoneStates;
    private Map<Triple, Integer> transitionStates;
    private float[] logProbabilities;

    /**
     * Loads transition model using provided parser.
     *
     * @param parser parser
     */
    public TransitionModel(KaldiTextParser parser) {
        parser.expectToken("<TransitionModel>");
        parseTopology(parser);

        parser.expectToken("<Triples>");
        transitionStates = new HashMap<Triple, Integer>();
        int numTriples = parser.getInt();
        int transitionId = 1;

        for (int i = 0; i < numTriples; ++i) {
            int phone = parser.getInt();
            int hmmState = parser.getInt();
            int pdf = parser.getInt();
            Triple triple = new Triple(phone, hmmState, pdf);
            transitionStates.put(triple, transitionId);
            transitionId +=
                phoneStates.get(phone).get(hmmState).getTransitions().size();
        }

        parser.expectToken("</Triples>");
        parser.expectToken("<LogProbs>");
        logProbabilities = parser.getFloatArray();
        parser.expectToken("</LogProbs>");
        parser.expectToken("</TransitionModel>");

        LogMath logMath = LogMath.getLogMath();
        for (int i = 0; i < logProbabilities.length; ++i)
            logProbabilities[i] = logMath.lnToLog(logProbabilities[i]);
    }

    private void parseTopology(KaldiTextParser parser) {
        parser.expectToken("<Topology>");

        phoneStates = new HashMap<Integer, List<HmmState>>();
        String token;

        while ("<TopologyEntry>".equals(token = parser.getToken())) {
            parser.assertToken("<TopologyEntry>", token);
            parser.expectToken("<ForPhones>");

            List<Integer> phones = new ArrayList<Integer>();
            while (!"</ForPhones>".equals(token = parser.getToken()))
                phones.add(Integer.parseInt(token));

            List<HmmState> states = new ArrayList<HmmState>(3);
            while ("<State>".equals(token = parser.getToken())) {
                // Skip state number.
                int id = parser.getInt();
                token = parser.getToken();

                if ("<PdfClass>".equals(token)) {
                    int pdfClass = parser.getInt();
                    List<Integer> transitions = new ArrayList<Integer>();
                    while ("<Transition>".equals(token = parser.getToken())) {
                        transitions.add(parser.getInt());
                        // Skip initial probability.
                        parser.getToken();
                    }

                    parser.assertToken("</State>", token);
                    states.add(new HmmState(id, pdfClass, transitions));
                }
            }

            for (Integer id : phones)
                phoneStates.put(id, states);
        }

        parser.assertToken("</Topology>", token);
    }

    /**
     * Returns transition matrix for the given context.
     *
     * @param phone central phone in the context
     * @param pdfs  array of pdf identifiers of the context units
     *
     * @return
     * 4 by 4 matrix where cell i,j contains probability in {@link LogMath}
     * domain of transition from state i to state j
     */
    public float[][] getTransitionMatrix(int phone, int[] pdfs) {
        // TODO: use variable size
        float[][] transitionMatrix = new float[4][4];
        Arrays.fill(transitionMatrix[3], LogMath.LOG_ZERO);

        for (HmmState state : phoneStates.get(phone)) {
            int stateId = state.getId();
            Arrays.fill(transitionMatrix[stateId], LogMath.LOG_ZERO);
            Triple triple = new Triple(phone, stateId, pdfs[stateId]);
            int i = transitionStates.get(triple);

            for (Integer j : state.getTransitions())
                transitionMatrix[stateId][j] = logProbabilities[i++];
        }

        return transitionMatrix;
    }
}
TOP

Related Classes of edu.cmu.sphinx.linguist.acoustic.tiedstate.kaldi.Triple

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.