Package etc.aloe.data

Source Code of etc.aloe.data.MessageSet

/*
* This file is part of ALOE.
*
* ALOE is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.

* ALOE is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
* GNU General Public License for more details.

* You should have received a copy of the GNU General Public License
* along with ALOE.  If not, see <http://www.gnu.org/licenses/>.
*
* Copyright (c) 2012 SCCL, University of Washington (http://depts.washington.edu/sccl)
*/
package etc.aloe.data;

import com.csvreader.CsvReader;
import com.csvreader.CsvWriter;
import etc.aloe.processes.Loading;
import etc.aloe.processes.Saving;
import java.io.IOException;
import java.io.InputStream;
import java.io.InvalidObjectException;
import java.io.OutputStream;
import java.nio.charset.Charset;
import java.text.DateFormat;
import java.text.ParseException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Date;
import java.util.List;

/**
* MessageSet contains messages.
*
* @author Michael Brooks <mjbrooks@uw.edu>
*/
public class MessageSet implements Loading, Saving {

    private List<Message> messages = new ArrayList<Message>();
    private static final int ID_COLUMN = 0;
    private static final int TIME_COLUMN = 1;
    private static final int PARTICIPANT_COLUMN = 2;
    private static final int MESSAGE_COLUMN = 3;
    private static final int TRUTH_COLUMN = 4;
    private static final int PREDICTION_COLUMN = 5;
    private static final int SEGMENT_COLUMN = 6;
    private static final int CONFIDENCE_COLUMN = 7;
    private static final int MIN_INPUT_COLUMNS = 4;
    private static final int NUM_OUTPUT_COLUMNS = 8;
    private static final String ID_COLUMN_NAME = "id";
    private static final String TIME_COLUMN_NAME = "time";
    private static final String PARTICIPANT_COLUMN_NAME = "participant";
    private static final String MESSAGE_COLUMN_NAME = "message";
    private static final String TRUTH_COLUMN_NAME = "truth";
    private static final String PREDICTION_COLUMN_NAME = "predicted";
    private static final String SEGMENT_COLUMN_NAME = "segment";
    private static final String CONFIDENCE_COLUMN_NAME = "confidence";
   
    private DateFormat dateFormat;
    private Charset charset = Charset.forName("UTF-8");

    /**
     * Add a message to the set.
     *
     * @param message
     */
    public void add(Message message) {
        messages.add(message);
    }
   
    /**
     * Add a bunch of messages to the message set.
     * @param messages
     */
    public void addAll(List<Message> messages) {
        this.messages.addAll(messages);
    }

    /**
     * Get the underlying list of messages.
     *
     * @return
     */
    public List<Message> getMessages() {
        return messages;
    }

    /**
     * Makes sure that the csv headers contain the minimum required fields.
     *
     * @param csvReader
     * @throws InvalidObjectException
     */
    private void validateCSVHeaders(CsvReader csvReader) throws InvalidObjectException {

        String[] headers = null;
        try {
            if (!csvReader.readHeaders()) {
                throw new InvalidObjectException("CSV must contain headers in the first row");
            }
            headers = csvReader.getHeaders();
        } catch (IOException e) {
            throw new InvalidObjectException(e.getMessage());
        }

        if (headers.length < MIN_INPUT_COLUMNS) {
            throw new InvalidObjectException("CSV must contain at least " + (MIN_INPUT_COLUMNS) + " columns");
        }

        if (headers.length > NUM_OUTPUT_COLUMNS) {
            throw new InvalidObjectException("CSV must contain no more than " + (NUM_OUTPUT_COLUMNS) + " columns");
        }

        List<String> headerList = Arrays.asList(headers);
        if (!headerList.contains(ID_COLUMN_NAME)) {
            throw new InvalidObjectException("'" + ID_COLUMN_NAME + "' column must be present.");
        }

        if (!headerList.contains(TIME_COLUMN_NAME)) {
            throw new InvalidObjectException("'" + TIME_COLUMN_NAME + "' column must be present.");
        }

        if (!headerList.contains(PARTICIPANT_COLUMN_NAME)) {
            throw new InvalidObjectException("'" + PARTICIPANT_COLUMN_NAME + "' column must be present.");
        }

        if (!headerList.contains(MESSAGE_COLUMN_NAME)) {
            throw new InvalidObjectException("'" + MESSAGE_COLUMN_NAME + "' column must be present.");
        }
    }

    @Override
    public boolean load(InputStream source) throws InvalidObjectException {
        if (dateFormat == null) {
            throw new IllegalStateException("No date format provided.");
        }

        CsvReader csvReader = new CsvReader(source, charset);

        try {
            validateCSVHeaders(csvReader);

            int lineNumber = 1;
            int numLabeled = 0;
            while (csvReader.readRecord()) {
                lineNumber++;

                String idText = csvReader.get(ID_COLUMN_NAME);
                String messageText = csvReader.get(MESSAGE_COLUMN_NAME);
                String participant = csvReader.get(PARTICIPANT_COLUMN_NAME);
                String timeText = csvReader.get(TIME_COLUMN_NAME);
                String truthText = csvReader.get(TRUTH_COLUMN_NAME).toLowerCase();
                String predictionText = csvReader.get(PREDICTION_COLUMN_NAME).toLowerCase();
                String segmentIdText = csvReader.get(SEGMENT_COLUMN_NAME).toLowerCase();

                int id = -1;
                try {
                    id = Integer.parseInt(idText);
                } catch (NumberFormatException e) {
                    throw new InvalidObjectException("Invalid value '" + idText + "' for '" + ID_COLUMN_NAME + "' on line " + lineNumber);
                }

                Date time = null;
                try {
                    time = getDateFormat().parse(timeText);
                } catch (ParseException e) {
                    throw new InvalidObjectException("Invalid value '" + timeText + "' for '" + TIME_COLUMN_NAME + "' on line " + lineNumber);
                }

                Boolean truth = null;
                if (truthText.equals("1") || truthText.equals("true")) {
                    truth = true;
                    numLabeled++;
                } else if (truthText.equals("0") || truthText.equals("false")) {
                    truth = false;
                    numLabeled++;
                } else if (truthText.equals("")) {
                    truth = null;
                } else {
                    throw new InvalidObjectException("Invalid value '" + truthText + "' for '" + TRUTH_COLUMN_NAME + "' on line " + lineNumber);
                }

                Boolean prediction = null;
                if (predictionText.equals("1") || predictionText.equals("true")) {
                    prediction = true;
                } else if (predictionText.equals("0") || predictionText.equals("false")) {
                    prediction = false;
                } else if (predictionText.equals("")) {
                    prediction = null;
                } else {
                    throw new InvalidObjectException("Invalid value '" + predictionText + "' for '" + PREDICTION_COLUMN_NAME + "' on line " + lineNumber);
                }

                int segment = -1;
                if (segmentIdText.equals("")) {
                    segment = -1;
                } else {
                    try {
                        segment = Integer.parseInt(segmentIdText);
                    } catch (NumberFormatException e) {
                        throw new InvalidObjectException("Invalid value '" + segmentIdText + "' for '" + SEGMENT_COLUMN_NAME + "' on line " + lineNumber);
                    }
                }

                Message message = new Message(id, time, participant, messageText, truth, prediction, segment);
                this.add(message);
            }

            System.out.println("Loaded " + this.size() + " raw messages (" + numLabeled + " labeled).");

        } catch (IOException ex) {
            throw new InvalidObjectException(ex.getMessage());
        }

        return true;
    }

    @Override
    public boolean save(OutputStream destination) throws IOException {
        if (dateFormat == null) {
            throw new IllegalStateException("No date format provided.");
        }
       
        //Add a BOM for Excel
        destination.write(charset.encode("\ufeff").array());
       
        CsvWriter out = new CsvWriter(destination, ',', charset);
       
        String[] row = new String[NUM_OUTPUT_COLUMNS];
        row[ID_COLUMN] = ID_COLUMN_NAME;
        row[PARTICIPANT_COLUMN] = PARTICIPANT_COLUMN_NAME;
        row[TIME_COLUMN] = TIME_COLUMN_NAME;
        row[MESSAGE_COLUMN] = MESSAGE_COLUMN_NAME;
        row[TRUTH_COLUMN] = TRUTH_COLUMN_NAME;
        row[PREDICTION_COLUMN] = PREDICTION_COLUMN_NAME;
        row[CONFIDENCE_COLUMN] = CONFIDENCE_COLUMN_NAME;
        row[SEGMENT_COLUMN] = SEGMENT_COLUMN_NAME;

        out.writeRecord(row);

        for (Message message : messages) {
            row[ID_COLUMN] = Integer.toString(message.getId());
            row[PARTICIPANT_COLUMN] = message.getParticipant();
            row[TIME_COLUMN] = dateFormat.format(message.getTimestamp());
            row[MESSAGE_COLUMN] = message.getMessage();
            row[TRUTH_COLUMN] = message.hasTrueLabel() ? message.getTrueLabel().toString() : null;
            row[PREDICTION_COLUMN] = message.hasPredictedLabel() ? message.getPredictedLabel().toString() : null;
            row[CONFIDENCE_COLUMN] = message.hasPredictionConfidence() ? message.getPredictionConfidence().toString() : null;
            row[SEGMENT_COLUMN] = message.hasSegmentId() ? Integer.toString(message.getSegmentId()) : null;

            out.writeRecord(row);
        }

        out.flush();
        return true;
    }

    public DateFormat getDateFormat() {
        return this.dateFormat;
    }

    /**
     * Set the date format used to import/export timestamps.
     *
     * @param dateFormat
     */
    public void setDateFormat(DateFormat dateFormat) {
        this.dateFormat = dateFormat;
    }

    /**
     * Get the size of the message set.
     *
     * @return
     */
    public int size() {
        return this.messages.size();
    }

    /**
     * Get the ith message.
     *
     * @param i
     * @return
     */
    public Message get(int i) {
        return this.messages.get(i);
    }
}
TOP

Related Classes of etc.aloe.data.MessageSet

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by Oracle Inc. Contact coftware#gmail.com.