Package edu.cmu.sphinx.decoder.adaptation

Source Code of edu.cmu.sphinx.decoder.adaptation.Stats

package edu.cmu.sphinx.decoder.adaptation;

import edu.cmu.sphinx.api.SpeechResult;
import edu.cmu.sphinx.decoder.search.Token;
import edu.cmu.sphinx.frontend.FloatData;
import edu.cmu.sphinx.linguist.HMMSearchState;
import edu.cmu.sphinx.linguist.SearchState;
import edu.cmu.sphinx.linguist.acoustic.tiedstate.Loader;
import edu.cmu.sphinx.linguist.acoustic.tiedstate.Sphinx3Loader;
import edu.cmu.sphinx.util.LogMath;

/**
* This class is used for estimating a MLLR transform for each cluster of data.
* The clustering must be previously performed using
* ClusteredDensityFileData.java
*
* @author Bogdan Petcu
*/
public class Stats {

  private ClusteredDensityFileData means;
  private double[][][][][] regLs;
  private double[][][][] regRs;
  private int nrOfClusters;
  private Sphinx3Loader loader;
  private float varFlor;
  private LogMath logMath = LogMath.getLogMath();;

  public Stats(Loader loader, ClusteredDensityFileData means) {
    this.loader = (Sphinx3Loader) loader;
    this.nrOfClusters = means.getNumberOfClusters();
    this.means = means;
    this.varFlor = (float) 1e-5;
    this.invertVariances();
    this.init();
  }

  private void init() {
    int len = loader.getVectorLength()[0];
    this.regLs = new double[nrOfClusters][][][][];
    this.regRs = new double[nrOfClusters][][][];

    for (int i = 0; i < nrOfClusters; i++) {
      this.regLs[i] = new double[loader.getNumStreams()][][][];
      this.regRs[i] = new double[loader.getNumStreams()][][];

      for (int j = 0; j < loader.getNumStreams(); j++) {
        len = loader.getVectorLength()[j];
        this.regLs[i][j] = new double[len][len + 1][len + 1];
        this.regRs[i][j] = new double[len][len + 1];
      }
    }
  }

  public ClusteredDensityFileData getClusteredData() {
    return this.means;
  }

  public double[][][][][] getRegLs() {
    return regLs;
  }

  public double[][][][] getRegRs() {
    return regRs;
  }

  /**
   * Used for inverting variances.
   */
  private void invertVariances() {

    for (int i = 0; i < loader.getNumStates(); i++) {
      for (int k = 0; k < loader.getNumGaussiansPerState(); k++) {
        for (int l = 0; l < loader.getVectorLength()[0]; l++) {
          if (loader.getVariancePool().get(
              i * loader.getNumGaussiansPerState() + k)[l] <= 0.) {
            this.loader.getVariancePool().get(i
                * loader.getNumGaussiansPerState() + k)[l] = (float) 0.5;
          } else if (loader.getVariancePool().get(
              i * loader.getNumGaussiansPerState() + k)[l] < varFlor) {
            this.loader.getVariancePool().get(i
                * loader.getNumGaussiansPerState() + k)[l] = (float) (1. / varFlor);
          } else {
            this.loader.getVariancePool().get(i
                * loader.getNumGaussiansPerState() + k)[l] = (float) (1. / loader
                .getVariancePool().get(
                    i * loader.getNumGaussiansPerState()
                        + k)[l]);
          }
        }
      }
    }
  }

  /**
   * Computes posterior values for the each component.
   *
   * @param componentScores
   *            from which the posterior values are computed.
   * @return posterior values for all components.
   */
  private float[] computePosterios(float[] componentScores) {
    float max;
    float[] posteriors = componentScores;

    max = posteriors[0];

    for (int i = 1; i < componentScores.length; i++) {
      if (posteriors[i] > max) {
        max = posteriors[i];
      }
    }

    for (int i = 0; i < componentScores.length; i++) {
      posteriors[i] = (float) logMath.logToLinear(posteriors[i] - max);
    }

    return posteriors;
  }

  /**
   * This method is used for directly collect and use counts. The counts are
   * collected and stored separately for each cluster.
   *
   * @param result
   *            Result object to collect counts from.
   */
  public void collect(SpeechResult result) throws Exception {
    Token token = result.getResult().getBestToken();
    float[] componentScore, featureVector, posteriors, tmean;
    float dnom, wtMeanVar, wtDcountVar, wtDcountVarMean, mean;
    int mId, len, cluster;

    if (token == null)
      throw new Exception("Best token not found!");

    do {
      FloatData feature = (FloatData) token.getData();
      SearchState ss = token.getSearchState();

      if (!(ss instanceof HMMSearchState && ss.isEmitting())) {
        token = token.getPredecessor();
        continue;
      }

      componentScore = token.calculateComponentScore(feature);
      featureVector = FloatData.toFloatData(feature).getValues();
      mId = (int) ((HMMSearchState)token.getSearchState()).getHMMState().getMixtureId();
      posteriors = this.computePosterios(componentScore);
      len = loader.getVectorLength()[0];

      for (int i = 0; i < componentScore.length; i++) {
        cluster = means.getClassIndex(mId
            * loader.getNumGaussiansPerState() + i);
        dnom = posteriors[i];
        if (dnom > 0.) {
          tmean = loader.getMeansPool().get(
              mId * loader.getNumGaussiansPerState() + i);

          for (int j = 0; j < featureVector.length; j++) {
            mean = posteriors[i] * featureVector[j];
            wtMeanVar = mean
                * loader.getVariancePool().get(mId
                    * loader.getNumGaussiansPerState()
                    + i)[j];
            wtDcountVar = dnom
                * loader.getVariancePool().get(mId
                    * loader.getNumGaussiansPerState()
                    + i)[j];

            for (int p = 0; p < featureVector.length; p++) {
              wtDcountVarMean = wtDcountVar * tmean[p];

              for (int q = p; q < featureVector.length; q++) {
                regLs[cluster][0][j][p][q] += wtDcountVarMean
                    * tmean[q];
              }
              regLs[cluster][0][j][p][len] += wtDcountVarMean;
              regRs[cluster][0][j][p] += wtMeanVar * tmean[p];
            }
            regLs[cluster][0][j][len][len] += wtDcountVar;
            regRs[cluster][0][j][len] += wtMeanVar;

          }
        }
      }

      token = token.getPredecessor();
    } while (token != null);
  }

  /**
   * Fill lower part of Legetter's set of G matrices.
   */
  public void fillRegLowerPart() {
    for (int i = 0; i < this.nrOfClusters; i++) {
      for (int j = 0; j < loader.getNumStreams(); j++) {
        for (int l = 0; l < loader.getVectorLength()[j]; l++) {
          for (int p = 0; p <= loader.getVectorLength()[j]; p++) {
            for (int q = p + 1; q <= loader.getVectorLength()[j]; q++) {
              regLs[i][j][l][q][p] = regLs[i][j][l][p][q];
            }
          }
        }
      }
    }
  }

    public Transform createTransform() {
        Transform transform = new Transform(loader, nrOfClusters);
        transform.update(this);
        return transform;
    }

}
TOP

Related Classes of edu.cmu.sphinx.decoder.adaptation.Stats

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.