Package com.heatonresearch.aifh.examples.capstone.model.milestone1

Source Code of com.heatonresearch.aifh.examples.capstone.model.milestone1.NormalizeTitanic

/*
* Artificial Intelligence for Humans
* Volume 2: Nature Inspired Algorithms
* Java Version
* http://www.aifh.org
* http://www.jeffheaton.com
*
* Code repository:
* https://github.com/jeffheaton/aifh
*
* Copyright 2014 by Jeff Heaton
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* For more information on Heaton Research copyrights, licenses
* and trademarks visit:
* http://www.heatonresearch.com/copyright
*/
package com.heatonresearch.aifh.examples.capstone.model.milestone1;

import au.com.bytecode.opencsv.CSVReader;
import au.com.bytecode.opencsv.CSVWriter;
import com.heatonresearch.aifh.examples.capstone.model.TitanicConfig;
import com.heatonresearch.aifh.examples.util.FormatNumeric;
import com.heatonresearch.aifh.general.data.BasicData;

import java.io.*;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;

/**
* This capstone project shows how to apply some of the techniques in this book to data science.  The data set used
* is the Kaggle titanic data set.  You can find that data set here.
* <p/>
* http://www.kaggle.com/c/titanic-gettingStarted
* <p/>
* There are three parts to this assignment.
* <p/>
* Part 1: Obtain and normalize data, extrapolate features
* Part 2: Cross validate and select model hyperparameters
* Part 3: Build Kaggle submission file
* <p/>
* This is part 1 of the project.
* <p/>
* [age,sex-male,pclass,sibsp,parch,fare,embarked-c,embarked-q,embarked-s,name-mil,name-nobility,name-dr,name-clergy]
*/
public class NormalizeTitanic {

    /**
     * Analyze and generate stats for titanic data.
     *
     * @param stats    The stats for titanic.
     * @param filename The file to analyze.
     * @return The passenger count.
     * @throws IOException Errors reading file.
     */
    public static int analyze(TitanicStats stats, File filename) throws IOException {
        int count = 0;
        Map<String, Integer> headerMap = new HashMap<String, Integer>();

        InputStream istream = new FileInputStream(filename);
        CSVReader reader = new CSVReader(new InputStreamReader(istream));

        String[] header = reader.readNext();
        for (int i = 0; i < header.length; i++) {
            headerMap.put(header[i].toLowerCase(), i);
        }

        int ageIndex = headerMap.get("age");
        int nameIndex = headerMap.get("name");
        int sexIndex = headerMap.get("sex");
        int indexEmbarked = headerMap.get("embarked");
        int indexFare = headerMap.get("fare");
        int indexPclass = headerMap.get("pclass");

        int survivedIndex = -1;

        // test data does not have survived
        if (headerMap.containsKey("survived")) {
            survivedIndex = headerMap.get("survived");
        }

        String[] nextLine;

        while ((nextLine = reader.readNext()) != null) {
            count++;
            String name = nextLine[nameIndex];
            String ageStr = nextLine[ageIndex];
            String sexStr = nextLine[sexIndex];
            String embarkedStr = nextLine[indexEmbarked];

            // test data does not have survived, do not use survived boolean if using test data!
            boolean survived = false;
            if (survivedIndex != -1) {
                String survivedStr = nextLine[survivedIndex];
                survived = survivedStr.equals("1");
            }

            if (indexEmbarked != -1) {
                embarkedStr = nextLine[indexEmbarked];
            }

            // calculate average fare per class
            String strFare = nextLine[indexFare];
            if (strFare.length() > 0) {
                double fare = Double.parseDouble(strFare);
                String pclass = nextLine[indexPclass];
                if (pclass.equals("1")) {
                    stats.getMeanFare1().update(fare);
                } else if (pclass.equals("2")) {
                    stats.getMeanFare2().update(fare);
                } else if (pclass.equals("3")) {
                    stats.getMeanFare3().update(fare);
                }
            }


            boolean isMale = sexStr.equalsIgnoreCase("male");
            double age;

            // Only compute survival stats on training data
            if (survivedIndex != -1) {
                if (embarkedStr.equals("Q")) {
                    stats.getEmbarkedQ().update(isMale, survived);
                } else if (embarkedStr.equals("S")) {
                    stats.getEmbarkedS().update(isMale, survived);
                } else if (embarkedStr.equals("C")) {
                    stats.getEmbarkedC().update(isMale, survived);
                }
            }

            stats.getEmbarkedHisto().update(embarkedStr);

            // Only compute survival stats on training data.
            if (survivedIndex != -1) {
                stats.getSurvivalTotal().update(isMale, survived);
            }

            if (survivedIndex != -1) {
                if (name.contains("Master.")) {
                    stats.getSurvivalMaster().update(isMale, survived);
                } else if (name.contains("Mr.")) {
                    stats.getSurvivalMr().update(isMale, survived);
                } else if (name.contains("Miss.") || name.contains("Mlle.")) {
                    stats.getSurvivalMiss().update(isMale, survived);
                } else if (name.contains("Mrs.") || name.contains("Mme.")) {
                    stats.getSurvivalMrs().update(isMale, survived);
                } else if (name.contains("Col.") || name.contains("Capt.") || name.contains("Major.")) {
                    stats.getSurvivalMilitary().update(isMale, survived);
                } else if (name.contains("Countess.") || name.contains("Lady.") || name.contains("Sir.") || name.contains("Don.") || name.contains("Dona.") || name.contains("Jonkheer.")) {
                    stats.getSurvivalNobility().update(isMale, survived);
                } else if (name.contains("Dr.")) {
                    stats.getSurvivalDr().update(isMale, survived);
                } else if (name.contains("Rev.")) {
                    stats.getSurvivalClergy().update(isMale, survived);
                }
            }

            if (ageStr.length() > 0) {
                age = Double.parseDouble(ageStr);

                // Update general mean age for male/female
                if (isMale) {
                    stats.getMeanMale().update(age);
                } else {
                    stats.getMeanFemale().update(age);
                }

                // Update the total average age
                stats.getMeanTotal().update(age);

                if (name.contains("Master.")) {
                    stats.getMeanMaster().update(age);
                    // Only compute survival stats on training data.
                    if (survivedIndex != -1) {
                        stats.getSurvivalMaster().update(isMale, survived);
                    }
                } else if (name.contains("Mr.")) {
                    stats.getMeanMr().update(age);
                    // Only compute survival stats on training data.
                    if (survivedIndex != -1) {
                        stats.getSurvivalMr().update(isMale, survived);
                    }
                } else if (name.contains("Miss.") || name.contains("Mlle.")) {
                    stats.getMeanMiss().update(age);
                    // Only compute survival stats on training data.
                    if (survivedIndex != -1) {
                        stats.getSurvivalMiss().update(isMale, survived);
                    }
                } else if (name.contains("Mrs.") || name.contains("Mme.")) {
                    stats.getMeanMrs().update(age);
                    // Only compute survival stats on training data.
                    if (survivedIndex != -1) {
                        stats.getSurvivalMrs().update(isMale, survived);
                    }
                } else if (name.contains("Col.") || name.contains("Capt.") || name.contains("Major.")) {
                    stats.getMeanMilitary().update(age);
                    // Only compute survival stats on training data.
                    if (survivedIndex != -1) {
                        stats.getSurvivalMilitary().update(isMale, survived);
                    }
                } else if (name.contains("Countess.") || name.contains("Lady.") || name.contains("Sir.") || name.contains("Don.") || name.contains("Dona.") || name.contains("Jonkheer.")) {
                    stats.getMeanNobility().update(age);
                    // Only compute survival stats on training data.
                    if (survivedIndex != -1) {
                        stats.getSurvivalNobility().update(isMale, survived);
                    }
                } else if (name.contains("Dr.")) {
                    stats.getMeanDr().update(age);
                    // Only compute survival stats on training data.
                    if (survivedIndex != -1) {
                        stats.getSurvivalDr().update(isMale, survived);
                    }
                } else if (name.contains("Rev.")) {
                    stats.getMeanClergy().update(age);
                    // Only compute survival stats on training data.
                    if (survivedIndex != -1) {
                        stats.getSurvivalClergy().update(isMale, survived);
                    }
                }
            }
        }
        return count;
    }

    /**
     * Normalize to a range.
     *
     * @param x              The value to normalize.
     * @param dataLow        The low end of the range of the data.
     * @param dataHigh       The high end of the range of the data.
     * @param normalizedLow  The normalized low end of the range of data.
     * @param normalizedHigh The normalized high end of the range of data.
     * @return The normalized value.
     */
    public static double rangeNormalize(double x, double dataLow, double dataHigh, double normalizedLow, double normalizedHigh) {
        return ((x - dataLow)
                / (dataHigh - dataLow))
                * (normalizedHigh - normalizedLow) + normalizedLow;
    }

    public static List<BasicData> normalize(TitanicStats stats, File filename, List<String> ids,
                                            double inputLow, double inputHigh,
                                            double predictSurvive, double predictPerish) throws IOException {
        List<BasicData> result = new ArrayList<BasicData>();

        Map<String, Integer> headerMap = new HashMap<String, Integer>();

        InputStream istream = new FileInputStream(filename);
        CSVReader reader = new CSVReader(new InputStreamReader(istream));

        String[] header = reader.readNext();
        for (int i = 0; i < header.length; i++) {
            headerMap.put(header[i].toLowerCase(), i);
        }

        int ageIndex = headerMap.get("age");
        int nameIndex = headerMap.get("name");
        int sexIndex = headerMap.get("sex");
        int indexEmbarked = headerMap.get("embarked");
        int indexPclass = headerMap.get("pclass");
        int indexSibsp = headerMap.get("sibsp");
        int indexParch = headerMap.get("parch");
        int indexFare = headerMap.get("fare");
        int indexId = headerMap.get("passengerid");
        int survivedIndex = -1;

        // test data does not have survived
        if (headerMap.containsKey("survived")) {
            survivedIndex = headerMap.get("survived");
        }

        String[] nextLine;

        while ((nextLine = reader.readNext()) != null) {
            BasicData data = new BasicData(TitanicConfig.InputFeatureCount, 1);

            String name = nextLine[nameIndex];
            String sex = nextLine[sexIndex];
            String embarked = nextLine[indexEmbarked];
            String id = nextLine[indexId];

            // Add record the passenger id, if requested
            if (ids != null) {
                ids.add(id);
            }

            boolean isMale = sex.equalsIgnoreCase("male");


            // age
            double age;

            // do we have an age for this person?
            if (nextLine[ageIndex].length() == 0) {
                // age is missing, interpolate using name
                if (name.contains("Master.")) {
                    age = stats.getMeanMaster().calculate();
                } else if (name.contains("Mr.")) {
                    age = stats.getMeanMr().calculate();
                } else if (name.contains("Miss.") || name.contains("Mlle.")) {
                    age = stats.getMeanMiss().calculate();
                } else if (name.contains("Mrs.") || name.contains("Mme.")) {
                    age = stats.getMeanMrs().calculate();
                } else if (name.contains("Col.") || name.contains("Capt.") || name.contains("Major.")) {
                    age = stats.getMeanMiss().calculate();
                } else if (name.contains("Countess.") || name.contains("Lady.") || name.contains("Sir.") || name.contains("Don.") || name.contains("Dona.") || name.contains("Jonkheer.")) {
                    age = stats.getMeanNobility().calculate();
                } else if (name.contains("Dr.")) {
                    age = stats.getMeanDr().calculate();
                } else if (name.contains("Rev.")) {
                    age = stats.getMeanClergy().calculate();
                } else {
                    if (isMale) {
                        age = stats.getMeanMale().calculate();
                    } else {
                        age = stats.getMeanFemale().calculate();
                    }
                }
            } else {
                age = Double.parseDouble(nextLine[ageIndex]);

            }
            data.getInput()[0] = rangeNormalize(age, 0, 100, inputLow, inputHigh);

            // sex-male
            data.getInput()[1] = isMale ? inputHigh : inputLow;

            // pclass
            double pclass = Double.parseDouble(nextLine[indexPclass]);
            data.getInput()[2] = rangeNormalize(pclass, 1, 3, inputLow, inputHigh);

            // sibsp
            double sibsp = Double.parseDouble(nextLine[indexSibsp]);
            data.getInput()[3] = rangeNormalize(sibsp, 0, 10, inputLow, inputHigh);

            // parch
            double parch = Double.parseDouble(nextLine[indexParch]);
            data.getInput()[4] = rangeNormalize(parch, 0, 10, inputLow, inputHigh);

            // fare
            String strFare = nextLine[indexFare];
            double fare;

            if (strFare.length() == 0) {
                if (((int) pclass) == 1) {
                    fare = stats.getMeanFare1().calculate();
                } else if (((int) pclass) == 2) {
                    fare = stats.getMeanFare2().calculate();
                } else if (((int) pclass) == 3) {
                    fare = stats.getMeanFare3().calculate();
                } else {
                    // should not happen, we would have a class other than 1,2,3.
                    // however, if that DID happen, use the median class (2).
                    fare = stats.getMeanFare2().calculate();
                }
            } else {
                fare = Double.parseDouble(nextLine[indexFare]);
            }
            data.getInput()[5] = rangeNormalize(fare, 0, 500, inputLow, inputHigh);

            // embarked-c
            data.getInput()[6] = embarked.trim().equalsIgnoreCase("c") ? inputHigh : inputLow;

            // embarked-q
            data.getInput()[7] = embarked.trim().equalsIgnoreCase("q") ? inputHigh : inputLow;

            // embarked-s
            data.getInput()[8] = embarked.trim().equalsIgnoreCase("s") ? inputHigh : inputLow;

            // name-mil
            data.getInput()[9] = (name.contains("Col.") || name.contains("Capt.") || name.contains("Major.")) ? inputHigh : inputLow;

            // name-nobility
            data.getInput()[10] = (name.contains("Countess.") || name.contains("Lady.") || name.contains("Sir.") || name.contains("Don.") || name.contains("Dona.") || name.contains("Jonkheer.")) ? inputHigh : inputLow;

            // name-dr
            data.getInput()[11] = (name.contains("Dr.")) ? inputHigh : inputLow;


            // name-clergy
            data.getInput()[12] = (name.contains("Rev.")) ? inputHigh : inputLow;

            // add the new row
            result.add(data);

            // add survived, if it exists
            if (survivedIndex != -1) {
                int survived = Integer.parseInt(nextLine[survivedIndex]);
                data.getIdeal()[0] = (survived == 1) ? predictSurvive : predictPerish;
            }

        }

        return result;
    }

    /**
     * The main method.
     *
     * @param args The arguments.
     */
    public static void main(String[] args) {
        String filename;

        if (args.length != 1) {
            filename = System.getProperty("FILENAME");
            if( filename==null ) {
                System.out.println("Please call this program with a single parameter that specifies your data directory.\n" +
                        "If you are calling with gradle, consider:\n" +
                "gradle runCapstoneTitanic1 -Pdata_path=[path to your data directory]\n");
                System.exit(0);
            }
        } else {
            filename = args[0];
        }

        File dataPath = new File(filename);
        File trainingPath = new File(dataPath, TitanicConfig.TrainingFilename);
        File testPath = new File(dataPath, TitanicConfig.TestFilename);
        File normalizePath = new File(dataPath, TitanicConfig.NormDumpFilename);


        try {
            TitanicStats stats = new TitanicStats();
            analyze(stats, trainingPath);
            analyze(stats, testPath);
            stats.dump();

            List<String> ids = new ArrayList<String>();
            List<BasicData> training = normalize(stats, trainingPath, ids,
                    TitanicConfig.InputNormalizeLow,
                    TitanicConfig.InputNormalizeHigh,
                    TitanicConfig.PredictSurvive,
                    TitanicConfig.PredictPerish);

            // Write out the normalized file, mainly so that you can examine it.
            // This file is not actually used by the program.
            FileOutputStream fos = new FileOutputStream(normalizePath);
            CSVWriter csv = new CSVWriter(new OutputStreamWriter(fos));

            csv.writeNext(new String[]{
                    "id",
                    "age", "sex-male", "pclass", "sibsp", "parch", "fare",
                    "embarked-c", "embarked-q", "embarked-s", "name-mil", "name-nobility", "name-dr", "name-clergy"
            });

            int idx = 0;
            for (BasicData data : training) {
                String[] line = {
                        ids.get(idx++),
                        FormatNumeric.formatDouble(data.getInput()[0], 5),
                        FormatNumeric.formatDouble(data.getInput()[1], 5),
                        FormatNumeric.formatDouble(data.getInput()[2], 5),
                        FormatNumeric.formatDouble(data.getInput()[3], 5),
                        FormatNumeric.formatDouble(data.getInput()[4], 5),
                        FormatNumeric.formatDouble(data.getInput()[5], 5),
                        FormatNumeric.formatDouble(data.getInput()[6], 5),
                        FormatNumeric.formatDouble(data.getInput()[7], 5),
                        FormatNumeric.formatDouble(data.getInput()[8], 5),
                        FormatNumeric.formatDouble(data.getInput()[9], 5),
                        FormatNumeric.formatDouble(data.getInput()[10], 5),
                        FormatNumeric.formatDouble(data.getInput()[11], 5),
                        FormatNumeric.formatDouble(data.getInput()[12], 5),
                        FormatNumeric.formatDouble(data.getIdeal()[0], 5)

                };

                csv.writeNext(line);
            }

            csv.close();
            fos.close();

        } catch (IOException ex) {
            ex.printStackTrace();
        }


    }
}
TOP

Related Classes of com.heatonresearch.aifh.examples.capstone.model.milestone1.NormalizeTitanic

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.