Package edu.cmu.sphinx.linguist.acoustic.tiedstate.trainer

Source Code of edu.cmu.sphinx.linguist.acoustic.tiedstate.trainer.Sphinx4Loader

/*
* Copyright 1999-2002 Carnegie Mellon University.
* Portions Copyright 2002 Sun Microsystems, Inc.
* Portions Copyright 2002 Mitsubishi Electric Research Laboratories.
* All Rights Reserved.  Use is subject to license terms.
*
* See the file "license.terms" for information on usage and
* redistribution of this file, and for a DISCLAIMER OF ALL
* WARRANTIES.
*
*/

package edu.cmu.sphinx.linguist.acoustic.tiedstate.trainer;

import edu.cmu.sphinx.linguist.acoustic.*;
import edu.cmu.sphinx.linguist.acoustic.tiedstate.*;
import static edu.cmu.sphinx.linguist.acoustic.tiedstate.Pool.Feature.*;
import edu.cmu.sphinx.util.ExtendedStreamTokenizer;
import edu.cmu.sphinx.util.LogMath;
import edu.cmu.sphinx.util.StreamFactory;
import edu.cmu.sphinx.util.props.PropertyException;
import edu.cmu.sphinx.util.props.PropertySheet;
import edu.cmu.sphinx.util.props.S4Integer;

import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.util.Arrays;
import java.util.Map;
import java.util.logging.Level;

/**
* an acoustic model loader that loads sphinx3 ascii data
* <p/>
* Mixture weights and transition probabilities are maintained in logMath log base,
*/
class Sphinx4Loader extends Sphinx3Loader {

    protected final static String TMAT_FILE_VERSION = "4.0";

    @S4Integer(defaultValue = 10)
    public final static String MAX_MODEL_SIZE = "maxStatePerModel";
    private int maxModelSize;


    @Override
    public void newProperties(PropertySheet ps) throws PropertyException {
        super.newProperties(ps);

        maxModelSize = ps.getInt(MAX_MODEL_SIZE);
    }

    /**
     * Loads the sphinx4 density file, a set of density arrays are created and placed in the given pool.
     *
     * {@inheritDoc}
     */
    @Override
    protected void loadHMMPool(boolean useCDUnits, InputStream inputStream, String path)
            throws IOException {
        ExtendedStreamTokenizer est = new ExtendedStreamTokenizer(inputStream, '#', false);

        logger.info("Loading HMM file from: ");
        logger.info(path);

        est.expectString(MODEL_VERSION);

        int numBase = est.getInt("numBase");
        est.expectString("n_base");

        int numTri = est.getInt("numTri");
        est.expectString("n_tri");

        est.getInt("numStateMap");
        est.expectString("n_state_map");

        int numTiedState = est.getInt("numTiedState");
        est.expectString("n_tied_state");

        int numContextIndependentTiedState = est.getInt("numContextIndependentTiedState");
        est.expectString("n_tied_ci_state");

        int numTiedTransitionMatrices = est.getInt("numTiedTransitionMatrices");
        est.expectString("n_tied_tmat");

        int[] maxStid = new int[maxModelSize];

        HMMManager hmmManager = super.getHmmManager();
        Pool<float[][]> matrixPool = super.getMatrixPool();
        Pool<float[]> mixtureWeightsPool = super.getMixtureWeightsPool();
        Map<String, Unit> contextIndependentUnits = super.getContextIndependentUnits();

        assert numTiedState == mixtureWeightsPool.getFeature(NUM_SENONES, 0);
        assert numTiedTransitionMatrices == matrixPool.size();

        // Load the base phones
        for (int i = 0; i < numBase; i++) {
            String left = est.getString();
            String right = est.getString();
            String position = est.getString();
            int tmat = est.getInt("tmat");

            // Read all state ID in the line...
            for (int j = 0; ; j++) {
                String str = est.getString();

                // ... until we reach a "N"

                if (!str.equals("N")) {
                    int id = Integer.parseInt(str);
                    try {
                        maxStid[j] = id;
                    } catch (ArrayIndexOutOfBoundsException aie) {
                        throw new Error("Use a larger value for " +
                                "maxStatePerModel");
                    }
                    assert maxStid[j] >= 0 &&
                            maxStid[j] < numContextIndependentTiedState;
                } else {
                    break;
                }
            }

            assert left.equals("-");
            assert right.equals("-");
            assert position.equals("-");
            assert tmat < numTiedTransitionMatrices;

//          int[] stid = Arrays.copyOf(maxStid, numStates);
//          Unit unit = Unit.getUnit(name, attribute.equals(FILLER));
//          contextIndependentUnits.put(unit.getName(), unit);
//
//          if (logger.isLoggable(Level.FINE)) {
//              logger.fine("Loaded " + unit);
//          }
//           The first filler
//          if (unit.isFiller() && unit.getName().equals(SILENCE_CIPHONE)) {
//                unit = Unit.SILENCE;
//          }
//          float[][] transitionMatrix = matrixPool.get(tmat);
//          SenoneSequence ss = getSenoneSequence(stid);
//          HMM hmm = new SenoneHMM(unit, ss, transitionMatrix, HMMPosition.lookup(position));
//          hmmManager.put(hmm);
        }

        // Load the context dependent phones.  If the useCDUnits
        // property is false, the CD phones will not be created, but
        // the values still need to be read in from the file.

        String lastUnitName = "";
        Unit lastUnit = null;
        int[] lastStid = null;
        SenoneSequence lastSenoneSequence = null;

        for (int i = 0; i < numTri; i++) {
            String name = est.getString();
            String left = est.getString();
            String right = est.getString();
            String position = est.getString();
            String attribute = est.getString();
            int tmat = est.getInt("tmat");

            int numStates = 0;

            // Read all state ID in the line...
            for (int j = 0; ; j++) {
                String str = est.getString();

                // ... until we reach a "N"

                if (!str.equals("N")) {
                    int id = Integer.parseInt(str);
                    maxStid[j] = id;
                    assert maxStid[j] >= numContextIndependentTiedState &&
                            maxStid[j] < numTiedState;
                } else {
                    numStates = j;
                    break;
                }
            }
            int[] stid = Arrays.copyOf(maxStid, numStates);

            assert !left.equals("-");
            assert !right.equals("-");
            assert !position.equals("-");
            assert attribute.equals("n/a");
            assert tmat < numTiedTransitionMatrices;

            if (useCDUnits) {
                Unit unit = null;
                String unitName = (name + ' ' + left + ' ' + right);

                if (unitName.equals(lastUnitName)) {
                    unit = lastUnit;
                } else {
                    Unit[] leftContext = new Unit[1];
                    leftContext[0] = contextIndependentUnits.get(left);

                    Unit[] rightContext = new Unit[1];
                    rightContext[0] = contextIndependentUnits.get(right);

                    //Context context = LeftRightContext.get(leftContext, rightContext);
                    //unit = Unit.getUnit(name, false, context);
                }
                lastUnitName = unitName;
                lastUnit = unit;

                if (logger.isLoggable(Level.FINE)) {
                    logger.fine("Loaded " + unit);
                }

                float[][] transitionMatrix = matrixPool.get(tmat);

                SenoneSequence ss = lastSenoneSequence;
                if (ss == null || !sameSenoneSequence(stid, lastStid)) {
                    ss = getSenoneSequence(stid);
                }
                lastSenoneSequence = ss;
                lastStid = stid;

                HMM hmm = new SenoneHMM(unit, ss, transitionMatrix, HMMPosition.lookup(position));
                hmmManager.put(hmm);
            }
        }

        est.close();
    }

    @Override
    protected Pool<float[][]> loadTransitionMatrices(String path)
            throws FileNotFoundException, IOException {

        String location = "";
        InputStream inputStream = StreamFactory.getInputStream(location, path);

        LogMath logMath = LogMath.getLogMath();
        logger.info("Loading transition matrices from: ");
        logger.info(path);

        Pool<float[][]> pool = new Pool<float[][]>(path);
        ExtendedStreamTokenizer est = new ExtendedStreamTokenizer(inputStream, '#', false);

        est.expectString("tmat");
        int numMatrices = est.getInt("numMatrices");
        est.expectString("X");
        // numStates = est.getInt("numStates");

        for (int i = 0; i < numMatrices; i++) {
            est.expectString("tmat");
            est.expectString("[" + i + ']');
            est.expectString("nstate");
            // Number of emitting states + 1, final non-emitting state
            int numStates = est.getInt("numStates") + 1;
            float[][] tmat = new float[numStates][numStates];

            for (int j = 0; j < numStates; j++) {
                for (int k = 0; k < numStates; k++) {

                    // the last row is just zeros, so we just do
                    // the first (numStates - 1) rows

                    if (j < numStates - 1) {
                         if (k == j || k == j + 1) {
                             tmat[j][k] = est.getFloat("tmat value");
                         }
                    }

                    tmat[j][k] = logMath.linearToLog(tmat[j][k]);

                    if (logger.isLoggable(Level.FINE)) {
                        logger.fine("tmat j " + j + " k "
                                + k + " tm " + tmat[j][k]);
                    }
                }
            }
            pool.put(i, tmat);
        }
        est.close();
        return pool;
    }
}
TOP

Related Classes of edu.cmu.sphinx.linguist.acoustic.tiedstate.trainer.Sphinx4Loader

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by ORACLE Inc. Contact coftware#gmail.com.