/**
* Copyright 2013-2015 Pierre Merienne
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.github.pmerienne.trident.ml.testing.data;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import com.github.pmerienne.trident.ml.core.Instance;
import com.github.pmerienne.trident.ml.core.TextInstance;
import com.github.pmerienne.trident.ml.preprocessing.EnglishTokenizer;
import com.github.pmerienne.trident.ml.preprocessing.TwitterTokenizer;
/**
 * Sample data for tests: CSV-backed datasets that are loaded on first access and cached in static
 * fields, plus generators of random synthetic instances.
 * <p>
 * The lazy initialisation of the cached datasets is not synchronized, and the getters hand out the
 * cached list itself rather than a copy.
 */
public class Datasets {

    private static final File USPS_FILE = new File("src/test/resources/usps.csv");
    private static final File SPAM_FILE = new File("src/test/resources/spam.csv");
    private static final File BIRTHS_FILE = new File("src/test/resources/births.csv");
    private static final File REUTERS_FILE = new File("src/test/resources/reuters.csv");
    private static final File CLUSTERING_FILE = new File("src/test/resources/seeds.csv");
    private static final File TWITTER_FILE = new File("src/test/resources/twitter-sentiment.csv");

    private static List<Instance<Boolean>> SPAM_SAMPLES;
    private static List<Instance<Integer>> USPS_SAMPLES;
    private static List<Instance<Double>> BIRTHS_SAMPLES;
    private static List<TextInstance<Integer>> REUTERS_SAMPLES;
    private static List<TextInstance<Boolean>> TWITTER_SAMPLES;
    private static List<Instance<Integer>> CLUSTERING_SAMPLES;

    /**
     * Returns the spam samples, loading and shuffling them on the first call.
     * <p>
     * If reading the file fails, the stack trace is printed and whatever was loaded so far (nothing
     * when the file could not be opened) is returned; the load is not retried on later calls.
     *
     * @return the cached list of spam samples
     */
    public static List<Instance<Boolean>> getSpamSamples() {
        if (SPAM_SAMPLES == null) {
            try {
                loadSPAMData();
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
        return SPAM_SAMPLES;
    }

    /**
     * Returns the USPS samples, loading and shuffling them on the first call.
     * <p>
     * If reading the file fails, the stack trace is printed and whatever was loaded so far (nothing
     * when the file could not be opened) is returned; the load is not retried on later calls.
     *
     * @return the cached list of USPS samples
     */
    public static List<Instance<Integer>> getUSPSSamples() {
        if (USPS_SAMPLES == null) {
            try {
                loadUSPSData();
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
        return USPS_SAMPLES;
    }

    /**
     * Returns the births samples, loading and shuffling them on the first call.
     * <p>
     * If reading the file fails, the stack trace is printed and whatever was loaded so far (nothing
     * when the file could not be opened) is returned; the load is not retried on later calls.
     *
     * @return the cached list of births samples
     */
    public static List<Instance<Double>> getBIRTHSSamples() {
        if (BIRTHS_SAMPLES == null) {
            try {
                loadBirthsData();
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
        return BIRTHS_SAMPLES;
    }

    /**
     * Returns the Reuters text samples, loading and shuffling them on the first call.
     * <p>
     * If reading the file fails, the stack trace is printed and whatever was loaded so far (nothing
     * when the file could not be opened) is returned; the load is not retried on later calls.
     *
     * @return the cached list of Reuters samples
     */
    public static List<TextInstance<Integer>> getReutersSamples() {
        if (REUTERS_SAMPLES == null) {
            try {
                loadReutersData();
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
        return REUTERS_SAMPLES;
    }

    /**
     * Returns the Twitter text samples, loading and shuffling them on the first call.
     * <p>
     * If reading the file fails, the stack trace is printed and whatever was loaded so far (nothing
     * when the file could not be opened) is returned; the load is not retried on later calls.
     *
     * @return the cached list of Twitter samples
     */
    public static List<TextInstance<Boolean>> getTwitterSamples() {
        if (TWITTER_SAMPLES == null) {
            try {
                loadTwitterData();
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
        return TWITTER_SAMPLES;
    }

    /**
     * Returns the clustering samples, loading and shuffling them on the first call.
     * <p>
     * If reading the file fails, the stack trace is printed and whatever was loaded so far (nothing
     * when the file could not be opened) is returned; the load is not retried on later calls.
     *
     * @return the cached list of clustering samples
     */
    public static List<Instance<Integer>> getClusteringSamples() {
        if (CLUSTERING_SAMPLES == null) {
            try {
                loadClusteringData();
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
        return CLUSTERING_SAMPLES;
    }

    /**
     * Opens a buffered reader on the given file. The text is decoded with the platform default
     * charset because no charset is passed to the {@link InputStreamReader}.
     * <p>
     * Closing the returned reader also closes the underlying file stream, so callers only have one
     * resource to release.
     *
     * @param file the file to read
     * @return a reader positioned at the start of the file
     * @throws IOException if the file cannot be opened
     */
    private static BufferedReader openReader(File file) throws IOException {
        return new BufferedReader(new InputStreamReader(new FileInputStream(file)));
    }

    /**
     * Replaces the cached USPS samples with the content of the USPS file, then shuffles them.
     * Lines are split on spaces: the first token is the integer label (stored minus one) and every
     * other token is split on ':' with its second part parsed as the feature value. Lines that fail
     * to parse are reported on {@code System.err} and skipped.
     *
     * @throws IOException if the file cannot be opened or read
     */
    private static void loadUSPSData() throws IOException {
        USPS_SAMPLES = new ArrayList<Instance<Integer>>();
        BufferedReader reader = openReader(USPS_FILE);
        try {
            String line;
            while ((line = reader.readLine()) != null) {
                try {
                    String[] values = line.split(" ");
                    Integer label = Integer.parseInt(values[0]) - 1;
                    double[] features = new double[values.length - 1];
                    for (int i = 1; i < values.length; i++) {
                        features[i - 1] = Double.parseDouble(values[i].split(":")[1]);
                    }
                    USPS_SAMPLES.add(new Instance<Integer>(label, features));
                } catch (Exception ex) {
                    System.err.println("Skipped USPS sample : " + line);
                }
            }
            Collections.shuffle(USPS_SAMPLES);
        } finally {
            reader.close();
        }
    }

    /**
     * Replaces the cached spam samples with the content of the spam file, then shuffles them.
     * Lines are split on ';': the last column is the label ({@code true} when it equals "1") and
     * all preceding columns are the features. Lines that fail to parse are reported on
     * {@code System.err} and skipped.
     *
     * @throws IOException if the file cannot be opened or read
     */
    private static void loadSPAMData() throws IOException {
        SPAM_SAMPLES = new ArrayList<Instance<Boolean>>();
        BufferedReader reader = openReader(SPAM_FILE);
        try {
            String line;
            while ((line = reader.readLine()) != null) {
                try {
                    String[] values = line.split(";");
                    Boolean label = "1".equals(values[values.length - 1]);
                    double[] features = new double[values.length - 1];
                    for (int i = 0; i < values.length - 1; i++) {
                        features[i] = Double.parseDouble(values[i]);
                    }
                    SPAM_SAMPLES.add(new Instance<Boolean>(label, features));
                } catch (Exception ex) {
                    System.err.println("Skipped SPAM sample : " + line);
                }
            }
            Collections.shuffle(SPAM_SAMPLES);
        } finally {
            reader.close();
        }
    }

    /**
     * Replaces the cached births samples with the content of the births file, then shuffles them.
     * Lines are split on ';': the last column is the label, the first column is ignored and the
     * columns in between are the features. Lines that fail to parse are reported on
     * {@code System.err} and skipped.
     *
     * @throws IOException if the file cannot be opened or read
     */
    private static void loadBirthsData() throws IOException {
        BIRTHS_SAMPLES = new ArrayList<Instance<Double>>();
        BufferedReader reader = openReader(BIRTHS_FILE);
        try {
            String line;
            while ((line = reader.readLine()) != null) {
                try {
                    String[] values = line.split(";");
                    Double label = Double.parseDouble(values[values.length - 1]);
                    // NOTE(review): the array has one more slot than the loop fills, so the last
                    // feature is always 0.0. Kept as is because shrinking it would change the
                    // feature dimension seen by callers - confirm whether this is intended.
                    double[] features = new double[values.length - 1];
                    for (int i = 1; i < values.length - 1; i++) {
                        features[i - 1] = Double.parseDouble(values[i]);
                    }
                    BIRTHS_SAMPLES.add(new Instance<Double>(label, features));
                } catch (Exception ex) {
                    System.err.println("Skipped BIRTHS sample : " + line);
                }
            }
            Collections.shuffle(BIRTHS_SAMPLES);
        } finally {
            reader.close();
        }
    }

    /**
     * Replaces the cached Reuters samples with the content of the Reuters file, then shuffles them.
     * The text before the first ',' is the topic; topics are numbered in order of first appearance
     * and that number is the label. The sample text runs from the first " - " (separator included)
     * up to, but not including, the last character of the line. Lines that fail to parse, including
     * lines without " - ", are reported on {@code System.err} and skipped.
     *
     * @throws IOException if the file cannot be opened or read
     */
    protected static void loadReutersData() throws IOException {
        REUTERS_SAMPLES = new ArrayList<TextInstance<Integer>>();
        EnglishTokenizer tokenizer = new EnglishTokenizer();
        Map<String, Integer> topics = new HashMap<String, Integer>();
        BufferedReader reader = openReader(REUTERS_FILE);
        try {
            String line;
            while ((line = reader.readLine()) != null) {
                try {
                    // Get class index
                    String topic = line.split(",")[0];
                    if (!topics.containsKey(topic)) {
                        topics.put(topic, topics.size());
                    }
                    Integer classIndex = topics.get(topic);
                    // Get text; a missing separator yields -1 and makes substring throw,
                    // which sends the line to the skip branch below.
                    int startIndex = line.indexOf(" - ");
                    String text = line.substring(startIndex, line.length() - 1);
                    REUTERS_SAMPLES.add(new TextInstance<Integer>(classIndex, tokenizer.tokenize(text)));
                } catch (Exception ex) {
                    System.err.println("Skipped Reuters sample because it can't be parsed : " + line);
                }
            }
            Collections.shuffle(REUTERS_SAMPLES);
        } finally {
            reader.close();
        }
    }

    /**
     * Replaces the cached Twitter samples with the content of the Twitter file, then shuffles them.
     * The text before the first ',' decides the label ({@code false} when it equals "0",
     * {@code true} otherwise) and everything after that comma is the tweet text. Lines that fail to
     * parse are reported on {@code System.err} and skipped.
     *
     * @throws IOException if the file cannot be opened or read
     */
    protected static void loadTwitterData() throws IOException {
        TWITTER_SAMPLES = new ArrayList<TextInstance<Boolean>>();
        TwitterTokenizer tokenizer = new TwitterTokenizer(2, 2);
        BufferedReader reader = openReader(TWITTER_FILE);
        try {
            String line;
            while ((line = reader.readLine()) != null) {
                try {
                    String[] values = line.split(",");
                    Boolean label = !values[0].equals("0");
                    String text = line.substring(line.indexOf(",") + 1);
                    TWITTER_SAMPLES.add(new TextInstance<Boolean>(label, tokenizer.tokenize(text)));
                } catch (Exception ex) {
                    System.err.println("Skipped twitter sample because it can't be parsed : " + line);
                }
            }
            Collections.shuffle(TWITTER_SAMPLES);
        } finally {
            reader.close();
        }
    }

    /**
     * Replaces the cached clustering samples with the content of the clustering file, then shuffles
     * them once after all lines have been read. Lines are split on ';': the label is read from the
     * eighth column (index 7) and every column except the last is a feature. Lines that fail to
     * parse are reported on {@code System.err} and skipped.
     *
     * @throws IOException if the file cannot be opened or read
     */
    protected static void loadClusteringData() throws IOException {
        CLUSTERING_SAMPLES = new ArrayList<Instance<Integer>>();
        BufferedReader reader = openReader(CLUSTERING_FILE);
        try {
            String line;
            while ((line = reader.readLine()) != null) {
                try {
                    String[] values = line.split(";");
                    Integer label = Integer.parseInt(values[7]);
                    double[] features = new double[values.length - 1];
                    for (int i = 0; i < values.length - 1; i++) {
                        features[i] = Double.parseDouble(values[i]);
                    }
                    CLUSTERING_SAMPLES.add(new Instance<Integer>(label, features));
                } catch (Exception ex) {
                    System.err.println("Skipped clustering sample : " + line);
                }
            }
            // Shuffle once, after the whole file has been read, like the other loaders.
            Collections.shuffle(CLUSTERING_SAMPLES);
        } finally {
            reader.close();
        }
    }

    /**
     * Generates random three-feature instances labelled with a cluster index.
     * For a label {@code k} drawn uniformly from {@code [0, nbCluster)} the features are
     * {@code k + 1.25 * u1}, {@code -k + 1.25 * u2} and {@code u3}, each {@code u} being an
     * independent uniform draw from {@code [0, 1)}.
     *
     * @param nbCluster number of distinct labels; must be positive
     * @param nbInstances number of instances to generate
     * @return a new list of generated instances
     * @throws IllegalArgumentException if {@code nbCluster} is not positive and at least one
     *         instance is requested (raised by {@link Random#nextInt(int)})
     */
    public static List<Instance<Integer>> generateDataForClusterization(int nbCluster, int nbInstances) {
        Random random = new Random();
        List<Instance<Integer>> samples = new ArrayList<Instance<Integer>>();
        for (int i = 0; i < nbInstances; i++) {
            int label = random.nextInt(nbCluster);
            double[] features = new double[] { label + random.nextDouble() * 1.25, -label + random.nextDouble() * 1.25, random.nextDouble() };
            samples.add(new Instance<Integer>(label, features));
        }
        return samples;
    }

    /**
     * Generates random NAND instances. Each instance has three features: a constant 1.0 followed by
     * the two random boolean inputs encoded as 1.0 (true) or -1.0 (false). The label is the NAND of
     * the two inputs.
     *
     * @param nb number of instances to generate
     * @return a new list of generated instances
     */
    public static List<Instance<Boolean>> generatedNandInstances(int nb) {
        Random random = new Random();
        List<Instance<Boolean>> samples = new ArrayList<Instance<Boolean>>();
        for (int i = 0; i < nb; i++) {
            boolean first = random.nextBoolean();
            boolean second = random.nextBoolean();
            boolean label = !(first && second);
            double[] features = new double[] { 1.0, first ? 1.0 : -1.0, second ? 1.0 : -1.0 };
            samples.add(new Instance<Boolean>(label, features));
        }
        return samples;
    }

    /**
     * Generates random instances for binary classification. Each instance has
     * {@code featureSize + 1} features: feature {@code j} is the signed label (+1.0 or -1.0,
     * negated for odd {@code j}) plus uniform noise in {@code [-0.5, 0.5)}, and the last feature is
     * a constant 1.0.
     *
     * @param size number of instances to generate
     * @param featureSize number of noisy features per instance, not counting the trailing constant
     * @return a new list of generated instances
     */
    public static List<Instance<Boolean>> generateDataForClassification(int size, int featureSize) {
        Random random = new Random();
        List<Instance<Boolean>> samples = new ArrayList<Instance<Boolean>>();
        for (int i = 0; i < size; i++) {
            double label = random.nextDouble() > 0.5 ? 1.0 : -1.0;
            double[] features = new double[featureSize + 1];
            for (int j = 0; j < featureSize; j++) {
                features[j] = (j % 2 == 0 ? 1.0 : -1.0) * label + random.nextDouble() - 0.5;
            }
            features[featureSize] = 1.0;
            samples.add(new Instance<Boolean>(label > 0, features));
        }
        return samples;
    }

    /**
     * Generates random three-feature instances for binary classification whose classes overlap.
     * The first feature is a constant 1.0; each of the other two is a uniform draw from
     * {@code [0, 1)}, negated when the label is {@code true}, plus half of a standard Gaussian draw.
     *
     * @param size number of instances to generate
     * @return a new list of generated instances
     */
    public static List<Instance<Boolean>> generateNonSeparatableDataForClassification(int size) {
        Random random = new Random();
        List<Instance<Boolean>> samples = new ArrayList<Instance<Boolean>>();
        for (int i = 0; i < size; i++) {
            boolean label = random.nextDouble() > 0.5;
            double sign = label ? -1.0 : 1.0;
            double[] features = new double[3];
            features[0] = 1.0;
            features[1] = sign * random.nextDouble() + random.nextGaussian() / 2;
            features[2] = sign * random.nextDouble() + random.nextGaussian() / 2;
            samples.add(new Instance<Boolean>(label, features));
        }
        return samples;
    }

    /**
     * Generates random instances for multi-class classification. For a label {@code k} drawn
     * uniformly from {@code [0, nbClasses)}, feature {@code j} is 1.0 when {@code j} is a multiple
     * of {@code k + 1} and -1.0 otherwise, plus uniform noise in {@code [-0.5, 0.5)}.
     *
     * @param size number of instances to generate
     * @param featureSize number of features per instance
     * @param nbClasses number of distinct labels; must be positive
     * @return a new list of generated instances
     * @throws IllegalArgumentException if {@code nbClasses} is not positive and at least one
     *         instance is requested (raised by {@link Random#nextInt(int)})
     */
    public static List<Instance<Integer>> generateDataForMultiLabelClassification(int size, int featureSize, int nbClasses) {
        Random random = new Random();
        List<Instance<Integer>> samples = new ArrayList<Instance<Integer>>();
        for (int i = 0; i < size; i++) {
            int label = random.nextInt(nbClasses);
            double[] features = new double[featureSize];
            for (int j = 0; j < featureSize; j++) {
                features[j] = (j % (label + 1) == 0 ? 1.0 : -1.0) + random.nextDouble() - 0.5;
            }
            samples.add(new Instance<Integer>(label, features));
        }
        return samples;
    }

    /**
     * Generates random instances for regression. One random factor per feature is drawn up front;
     * for every instance, feature {@code j} is a uniform draw from {@code [0, 1)} (negated for odd
     * {@code j}) and the label is the sum of each feature multiplied by its factor, without noise.
     *
     * @param size number of instances to generate
     * @param featureSize number of features per instance
     * @return a new list of generated instances
     */
    public static List<Instance<Double>> generateDataForRegression(int size, int featureSize) {
        List<Instance<Double>> samples = new ArrayList<Instance<Double>>();
        Random random = new Random();
        double[] factors = new double[featureSize];
        for (int i = 0; i < featureSize; i++) {
            factors[i] = random.nextDouble() * (1 + random.nextInt(2));
        }
        for (int i = 0; i < size; i++) {
            double label = 0.0;
            double[] features = new double[featureSize];
            for (int j = 0; j < featureSize; j++) {
                double feature = (j % 2 == 0 ? 1.0 : -1.0) * random.nextDouble();
                features[j] = feature;
                label += factors[j] * feature;
            }
            samples.add(new Instance<Double>(label, features));
        }
        return samples;
    }
}