/*
* This file is part of ALOE.
*
* ALOE is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
* ALOE is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
* You should have received a copy of the GNU General Public License
* along with ALOE. If not, see <http://www.gnu.org/licenses/>.
*
* Copyright (c) 2012 SCCL, University of Washington (http://depts.washington.edu/sccl)
*/
package etc.aloe.data;
import com.csvreader.CsvReader;
import com.csvreader.CsvWriter;
import etc.aloe.processes.Loading;
import etc.aloe.processes.Saving;
import java.io.IOException;
import java.io.InputStream;
import java.io.InvalidObjectException;
import java.io.OutputStream;
import java.io.PushbackInputStream;
import java.nio.charset.Charset;
import java.text.DateFormat;
import java.text.ParseException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Date;
import java.util.List;
import java.util.Locale;
/**
 * MessageSet contains messages.
 * <p>
 * The set is an ordered list of {@link Message} objects that can be read from
 * and written to CSV. On input the columns are located by header name; the
 * {@code id}, {@code time}, {@code participant} and {@code message} columns
 * are required, while {@code truth}, {@code predicted} and {@code segment} are
 * read when present. On output all eight columns are written in a fixed order.
 * <p>
 * A date format must be supplied through {@link #setDateFormat(DateFormat)}
 * before calling {@link #load(InputStream)} or {@link #save(OutputStream)}.
 * <p>
 * NOTE(review): the {@code confidence} column is written by
 * {@link #save(OutputStream)} but is not read back by
 * {@link #load(InputStream)}.
 *
 * @author Michael Brooks <mjbrooks@uw.edu>
 */
public class MessageSet implements Loading, Saving {

    private List<Message> messages = new ArrayList<Message>();

    // Positions of the columns in the CSV written by save().
    private static final int ID_COLUMN = 0;
    private static final int TIME_COLUMN = 1;
    private static final int PARTICIPANT_COLUMN = 2;
    private static final int MESSAGE_COLUMN = 3;
    private static final int TRUTH_COLUMN = 4;
    private static final int PREDICTION_COLUMN = 5;
    private static final int SEGMENT_COLUMN = 6;
    private static final int CONFIDENCE_COLUMN = 7;

    // Bounds on the number of header columns accepted by load().
    private static final int MIN_INPUT_COLUMNS = 4;
    private static final int NUM_OUTPUT_COLUMNS = 8;

    // Header names, used both to look columns up on input and as the header row on output.
    private static final String ID_COLUMN_NAME = "id";
    private static final String TIME_COLUMN_NAME = "time";
    private static final String PARTICIPANT_COLUMN_NAME = "participant";
    private static final String MESSAGE_COLUMN_NAME = "message";
    private static final String TRUTH_COLUMN_NAME = "truth";
    private static final String PREDICTION_COLUMN_NAME = "predicted";
    private static final String SEGMENT_COLUMN_NAME = "segment";
    private static final String CONFIDENCE_COLUMN_NAME = "confidence";

    // Written at the start of saved files (for Excel) and skipped again when loading.
    private static final String BYTE_ORDER_MARK = "\ufeff";

    private DateFormat dateFormat;
    private Charset charset = Charset.forName("UTF-8");

    /**
     * Add a message to the set.
     *
     * @param message the message to append
     */
    public void add(Message message) {
        messages.add(message);
    }

    /**
     * Add a bunch of messages to the message set.
     *
     * @param messages the messages to append, in order
     */
    public void addAll(List<Message> messages) {
        this.messages.addAll(messages);
    }

    /**
     * Get the underlying list of messages.
     *
     * @return the internal list itself (not a copy), so changes made to it
     * are reflected in this set
     */
    public List<Message> getMessages() {
        return messages;
    }

    /**
     * Makes sure that the csv headers contain the minimum required fields.
     *
     * @param csvReader the reader, positioned before the header row
     * @throws InvalidObjectException if the header row is missing or cannot
     * be read, if it has fewer than {@value #MIN_INPUT_COLUMNS} or more than
     * {@value #NUM_OUTPUT_COLUMNS} columns, or if a required column is absent
     */
    private void validateCSVHeaders(CsvReader csvReader) throws InvalidObjectException {
        String[] headers;
        try {
            if (!csvReader.readHeaders()) {
                throw new InvalidObjectException("CSV must contain headers in the first row");
            }
            headers = csvReader.getHeaders();
        } catch (InvalidObjectException e) {
            // Already the right type and message; do not re-wrap it.
            throw e;
        } catch (IOException e) {
            throw invalidObject(e.getMessage(), e);
        }

        if (headers.length < MIN_INPUT_COLUMNS) {
            throw new InvalidObjectException("CSV must contain at least " + (MIN_INPUT_COLUMNS) + " columns");
        }
        if (headers.length > NUM_OUTPUT_COLUMNS) {
            throw new InvalidObjectException("CSV must contain no more than " + (NUM_OUTPUT_COLUMNS) + " columns");
        }

        List<String> headerList = Arrays.asList(headers);
        requireHeader(headerList, ID_COLUMN_NAME);
        requireHeader(headerList, TIME_COLUMN_NAME);
        requireHeader(headerList, PARTICIPANT_COLUMN_NAME);
        requireHeader(headerList, MESSAGE_COLUMN_NAME);
    }

    /**
     * Fails if the given column name is not among the headers.
     */
    private static void requireHeader(List<String> headerList, String columnName) throws InvalidObjectException {
        if (!headerList.contains(columnName)) {
            throw new InvalidObjectException("'" + columnName + "' column must be present.");
        }
    }

    /**
     * Builds an InvalidObjectException that keeps the original failure as
     * its cause. initCause is used because the (String, Throwable)
     * constructor is not available on older Java versions.
     */
    private static InvalidObjectException invalidObject(String message, Throwable cause) {
        InvalidObjectException result = new InvalidObjectException(message);
        result.initCause(cause);
        return result;
    }

    /**
     * Returns a stream over the same data as source, minus a leading byte
     * order mark if one is present. save() writes such a mark and the
     * charset decoder does not remove it, so without this step the first
     * header name would not match.
     */
    private InputStream skipByteOrderMark(InputStream source) throws IOException {
        byte[] bom = BYTE_ORDER_MARK.getBytes(charset);
        PushbackInputStream in = new PushbackInputStream(source, bom.length);

        byte[] head = new byte[bom.length];
        int count = 0;
        while (count < head.length) {
            int n = in.read(head, count, head.length - count);
            if (n < 0) {
                break;
            }
            count += n;
        }

        boolean hasBom = count == bom.length && Arrays.equals(head, bom);
        if (!hasBom && count > 0) {
            // Not a byte order mark: give the bytes back so nothing is lost.
            in.unread(head, 0, count);
        }
        return in;
    }

    /**
     * Parses an optional boolean label. "1"/"true" and "0"/"false" are
     * accepted; an empty string means there is no label.
     *
     * @return the parsed value, or null when text is empty
     * @throws InvalidObjectException for any other text
     */
    private static Boolean parseLabel(String text, String columnName, int lineNumber) throws InvalidObjectException {
        if (text.equals("1") || text.equals("true")) {
            return Boolean.TRUE;
        }
        if (text.equals("0") || text.equals("false")) {
            return Boolean.FALSE;
        }
        if (text.equals("")) {
            return null;
        }
        throw new InvalidObjectException("Invalid value '" + text + "' for '" + columnName + "' on line " + lineNumber);
    }

    /**
     * Reads messages from CSV and appends them to this set.
     * <p>
     * Messages are appended as each row is parsed, so rows read before a
     * failing row remain in the set when an exception is thrown. A summary
     * line is printed to standard output on success.
     *
     * @param source the CSV data; a leading byte order mark is skipped
     * @return always true
     * @throws InvalidObjectException if the headers are invalid, a value
     * cannot be parsed, or the underlying stream fails
     * @throws IllegalStateException if no date format has been set
     */
    @Override
    public boolean load(InputStream source) throws InvalidObjectException {
        if (dateFormat == null) {
            throw new IllegalStateException("No date format provided.");
        }
        if (source == null) {
            throw new NullPointerException("source");
        }

        try {
            CsvReader csvReader = new CsvReader(skipByteOrderMark(source), charset);
            validateCSVHeaders(csvReader);

            // The header row counts as line 1.
            int lineNumber = 1;
            int numLabeled = 0;

            while (csvReader.readRecord()) {
                lineNumber++;

                String idText = csvReader.get(ID_COLUMN_NAME);
                String messageText = csvReader.get(MESSAGE_COLUMN_NAME);
                String participant = csvReader.get(PARTICIPANT_COLUMN_NAME);
                String timeText = csvReader.get(TIME_COLUMN_NAME);
                String truthText = csvReader.get(TRUTH_COLUMN_NAME).toLowerCase(Locale.ROOT);
                String predictionText = csvReader.get(PREDICTION_COLUMN_NAME).toLowerCase(Locale.ROOT);
                String segmentIdText = csvReader.get(SEGMENT_COLUMN_NAME).toLowerCase(Locale.ROOT);

                int id;
                try {
                    id = Integer.parseInt(idText);
                } catch (NumberFormatException e) {
                    throw invalidObject("Invalid value '" + idText + "' for '" + ID_COLUMN_NAME + "' on line " + lineNumber, e);
                }

                Date time;
                try {
                    time = getDateFormat().parse(timeText);
                } catch (ParseException e) {
                    throw invalidObject("Invalid value '" + timeText + "' for '" + TIME_COLUMN_NAME + "' on line " + lineNumber, e);
                }

                Boolean truth = parseLabel(truthText, TRUTH_COLUMN_NAME, lineNumber);
                if (truth != null) {
                    numLabeled++;
                }

                Boolean prediction = parseLabel(predictionText, PREDICTION_COLUMN_NAME, lineNumber);

                // -1 marks a message with no segment id in the file.
                int segment = -1;
                if (!segmentIdText.equals("")) {
                    try {
                        segment = Integer.parseInt(segmentIdText);
                    } catch (NumberFormatException e) {
                        throw invalidObject("Invalid value '" + segmentIdText + "' for '" + SEGMENT_COLUMN_NAME + "' on line " + lineNumber, e);
                    }
                }

                Message message = new Message(id, time, participant, messageText, truth, prediction, segment);
                this.add(message);
            }

            System.out.println("Loaded " + this.size() + " raw messages (" + numLabeled + " labeled).");
        } catch (InvalidObjectException ex) {
            // Already the right type and message; do not re-wrap it.
            throw ex;
        } catch (IOException ex) {
            throw invalidObject(ex.getMessage(), ex);
        }

        return true;
    }

    /**
     * Writes all messages to the destination as CSV.
     * <p>
     * The output starts with a byte order mark, followed by a header row and
     * one row per message. Absent labels, confidences and segment ids are
     * written as null values. The writer is flushed but not closed, so the
     * destination stream is left open.
     *
     * @param destination the stream to write to
     * @return always true
     * @throws IOException if writing fails
     * @throws IllegalStateException if no date format has been set
     */
    @Override
    public boolean save(OutputStream destination) throws IOException {
        if (dateFormat == null) {
            throw new IllegalStateException("No date format provided.");
        }

        //Add a BOM for Excel
        // getBytes returns exactly the encoded bytes, unlike ByteBuffer.array(),
        // which exposes the whole backing array.
        destination.write(BYTE_ORDER_MARK.getBytes(charset));

        CsvWriter out = new CsvWriter(destination, ',', charset);

        String[] row = new String[NUM_OUTPUT_COLUMNS];
        row[ID_COLUMN] = ID_COLUMN_NAME;
        row[TIME_COLUMN] = TIME_COLUMN_NAME;
        row[PARTICIPANT_COLUMN] = PARTICIPANT_COLUMN_NAME;
        row[MESSAGE_COLUMN] = MESSAGE_COLUMN_NAME;
        row[TRUTH_COLUMN] = TRUTH_COLUMN_NAME;
        row[PREDICTION_COLUMN] = PREDICTION_COLUMN_NAME;
        row[SEGMENT_COLUMN] = SEGMENT_COLUMN_NAME;
        row[CONFIDENCE_COLUMN] = CONFIDENCE_COLUMN_NAME;
        out.writeRecord(row);

        // The same array is reused; every cell is overwritten for each message.
        for (Message message : messages) {
            row[ID_COLUMN] = Integer.toString(message.getId());
            row[TIME_COLUMN] = dateFormat.format(message.getTimestamp());
            row[PARTICIPANT_COLUMN] = message.getParticipant();
            row[MESSAGE_COLUMN] = message.getMessage();
            row[TRUTH_COLUMN] = message.hasTrueLabel() ? message.getTrueLabel().toString() : null;
            row[PREDICTION_COLUMN] = message.hasPredictedLabel() ? message.getPredictedLabel().toString() : null;
            row[SEGMENT_COLUMN] = message.hasSegmentId() ? Integer.toString(message.getSegmentId()) : null;
            row[CONFIDENCE_COLUMN] = message.hasPredictionConfidence() ? message.getPredictionConfidence().toString() : null;
            out.writeRecord(row);
        }

        out.flush();
        return true;
    }

    /**
     * Get the date format used to import/export timestamps.
     *
     * @return the current date format, or null if none has been set
     */
    public DateFormat getDateFormat() {
        return this.dateFormat;
    }

    /**
     * Set the date format used to import/export timestamps.
     *
     * @param dateFormat the format used to parse and print the time column
     */
    public void setDateFormat(DateFormat dateFormat) {
        this.dateFormat = dateFormat;
    }

    /**
     * Get the size of the message set.
     *
     * @return the number of messages in the set
     */
    public int size() {
        return this.messages.size();
    }

    /**
     * Get the ith message.
     *
     * @param i zero-based index into the set
     * @return the message at that index
     */
    public Message get(int i) {
        return this.messages.get(i);
    }
}