/*
* Carrot2 project.
*
* Copyright (C) 2002-2014, Dawid Weiss, Stanisław Osiński.
* All rights reserved.
*
* Refer to the full license file "carrot2.LICENSE"
* in the root folder of the repository checkout or at:
* http://www.carrot2.org/carrot2.LICENSE
*/
package org.carrot2.clustering.lingo;
import org.carrot2.text.vsm.TermDocumentMatrixReducer;
import org.carrot2.text.vsm.TfTermWeighting;
import org.junit.Before;
import org.junit.Test;
/**
* Test cases for label building in {@link ClusterBuilder}.
*/
public class ClusterLabelBuilderTest extends LingoProcessingComponentTestBase
{
/** Label builder under tests */
private ClusterBuilder clusterBuilder;
@Before
public void setUpClusterLabelBuilder()
{
clusterBuilder = new ClusterBuilder();
clusterBuilder.labelAssigner = new SimpleLabelAssigner();
reducer = new TermDocumentMatrixReducer();
}
@Test
public void testEmpty()
{
buildModelAndCheck(new int [0]);
}
@Test
public void testNoPhrases()
{
createDocuments("", "aa . bb", "", "bb . cc", "", "cc . aa");
final int [] expectedFeatureIndex = new int []
{
0, 1, 2
};
desiredClusterCountBase = 30;
buildModelAndCheck(expectedFeatureIndex);
}
@Test
public void testSinglePhraseNoSingleWords()
{
createDocuments("aa bb", "aa bb", "aa bb", "aa bb");
final int [] expectedFeatureIndex = new int []
{
2
};
desiredClusterCountBase = 10;
buildModelAndCheck(expectedFeatureIndex);
}
@Test
public void testSinglePhraseSingleWords()
{
createDocuments("aa bb", "aa bb", "cc", "cc", "aa bb", "aa bb . cc");
clusterBuilder.phraseLabelBoost = 0.5;
final int [] expectedFeatureIndex = new int []
{
2, 3
};
desiredClusterCountBase = 15;
buildModelAndCheck(expectedFeatureIndex);
}
@Test
public void testQueryWordsRemoval()
{
createDocuments("query word . aa", "query word . aa", "query . word",
"query . word . aa");
clusterBuilder.phraseLabelBoost = 0.5;
final int [] expectedFeatureIndex = new int []
{
0
};
desiredClusterCountBase = 10;
createPreprocessingContext("query word");
buildModelAndCheck(expectedFeatureIndex);
}
@Test
public void testExternalFeatureScores()
{
createDocuments("aa bb", "aa bb", "cc", "cc", "cc", "cc", "aa bb", "aa bb", "dd",
"dd", "dd", "dd", "ee ff", "ee ff", "ee ff", "ee ff");
clusterBuilder.phraseLabelBoost = 0.5;
desiredClusterCountBase = 15;
final int [] expectedFeatureIndex = new int []
{
6, 7, 2, 3
};
buildModelAndCheck(expectedFeatureIndex);
// Make a copy of feature indices
final int [] featureIndex = lingoContext.preprocessingContext.allLabels.featureIndex;
for (int i = 0; i < featureIndex.length; i++)
{
clusterBuilder.featureScorer = new OneLabelFeatureScorer(i, 2);
check(new int []
{
featureIndex[i], featureIndex[i], featureIndex[i], featureIndex[i]
});
}
}
private static class OneLabelFeatureScorer implements IFeatureScorer
{
private int labelIndex;
private double score;
OneLabelFeatureScorer(int labelIndex, double score)
{
this.labelIndex = labelIndex;
this.score = score;
}
public double [] getFeatureScores(LingoProcessingContext lingoContext)
{
final double [] scores = new double [lingoContext.preprocessingContext.allLabels.featureIndex.length];
scores[labelIndex] = score;
return scores;
}
}
private void buildModelAndCheck(int [] expectedFeatureIndex)
{
buildLingoModel();
check(expectedFeatureIndex);
}
private void check(int [] expectedFeatureIndex)
{
clusterBuilder.buildLabels(lingoContext, new TfTermWeighting());
assertThat(lingoContext.clusterLabelFeatureIndex).as("clusterLabelFeatureIndex")
.containsOnly(expectedFeatureIndex);
}
}