/*
* Copyright 2010 JBoss Inc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.drools.planner.benchmark.statistic;
import java.awt.image.BufferedImage;
import java.io.OutputStream;
import java.util.List;
import java.util.Map;
import java.util.HashMap;
import java.util.ArrayList;
import java.util.Collections;
import java.io.File;
import java.io.Writer;
import java.io.OutputStreamWriter;
import java.io.FileOutputStream;
import java.io.IOException;
import javax.imageio.ImageIO;
import org.drools.planner.core.Solver;
import org.drools.planner.core.score.Score;
import org.apache.commons.io.IOUtils;
import org.drools.planner.core.score.definition.ScoreDefinition;
import org.jfree.chart.JFreeChart;
import org.jfree.chart.axis.NumberAxis;
import org.jfree.chart.plot.PlotOrientation;
import org.jfree.chart.plot.XYPlot;
import org.jfree.chart.renderer.xy.XYItemRenderer;
import org.jfree.chart.renderer.xy.XYStepRenderer;
import org.jfree.data.xy.XYSeries;
import org.jfree.data.xy.XYSeriesCollection;
public class BestScoreStatistic implements SolverStatistic {
private List<String> configNameList = new ArrayList<String>();
// key is the configName
private Map<String, BestScoreStatisticListener> bestScoreStatisticListenerMap
= new HashMap<String, BestScoreStatisticListener>();
private ScoreDefinition scoreDefinition = null;
public void addListener(Solver solver, String configName) {
if (configNameList.contains(configName)) {
throw new IllegalArgumentException("Cannot add a listener with the same configName (" + configName
+ ") twice.");
}
configNameList.add(configName);
BestScoreStatisticListener bestScoreStatisticListener = new BestScoreStatisticListener();
solver.addEventListener(bestScoreStatisticListener);
bestScoreStatisticListenerMap.put(configName, bestScoreStatisticListener);
if (scoreDefinition == null) {
scoreDefinition = solver.getScoreDefinition();
} else {
if (!scoreDefinition.getClass().equals(solver.getScoreDefinition().getClass())) {
throw new IllegalStateException("The scoreDefinition (" + solver.getScoreDefinition()
+ ") should be of the same class as the other scoreDefinition (" + scoreDefinition + ")");
}
}
}
public void removeListener(Solver solver, String configName) {
BestScoreStatisticListener bestScoreStatisticListener = bestScoreStatisticListenerMap.get(configName);
solver.removeEventListener(bestScoreStatisticListener);
}
public CharSequence writeStatistic(File solverStatisticFilesDirectory, String baseName) {
StringBuilder htmlFragment = new StringBuilder();
htmlFragment.append(writeCsvStatistic(solverStatisticFilesDirectory, baseName));
htmlFragment.append(writeGraphStatistic(solverStatisticFilesDirectory, baseName));
return htmlFragment;
}
private List<TimeToBestScoresLine> extractTimeToBestScoresLineList() {
Map<Long, TimeToBestScoresLine> timeToBestScoresLineMap = new HashMap<Long, TimeToBestScoresLine>();
for (Map.Entry<String, BestScoreStatisticListener> listenerEntry : bestScoreStatisticListenerMap.entrySet()) {
String configName = listenerEntry.getKey();
List<BestScoreStatisticPoint> statisticPointList = listenerEntry.getValue()
.getBestScoreStatisticPointList();
for (BestScoreStatisticPoint statisticPoint : statisticPointList) {
long timeMillisSpend = statisticPoint.getTimeMillisSpend();
TimeToBestScoresLine line = timeToBestScoresLineMap.get(timeMillisSpend);
if (line == null) {
line = new TimeToBestScoresLine(timeMillisSpend);
timeToBestScoresLineMap.put(timeMillisSpend, line);
}
line.getConfigNameToScoreMap().put(configName, statisticPoint.getScore());
}
}
List<TimeToBestScoresLine> timeToBestScoresLineList
= new ArrayList<TimeToBestScoresLine>(timeToBestScoresLineMap.values());
Collections.sort(timeToBestScoresLineList);
return timeToBestScoresLineList;
}
protected class TimeToBestScoresLine implements Comparable<TimeToBestScoresLine> {
private long timeMillisSpend;
private Map<String, Score> configNameToScoreMap;
public TimeToBestScoresLine(long timeMillisSpend) {
this.timeMillisSpend = timeMillisSpend;
configNameToScoreMap = new HashMap<String, Score>();
}
public long getTimeMillisSpend() {
return timeMillisSpend;
}
public Map<String, Score> getConfigNameToScoreMap() {
return configNameToScoreMap;
}
public int compareTo(TimeToBestScoresLine other) {
return timeMillisSpend < other.timeMillisSpend ? -1 : (timeMillisSpend > other.timeMillisSpend ? 1 : 0);
}
}
private CharSequence writeCsvStatistic(File solverStatisticFilesDirectory, String baseName) {
List<TimeToBestScoresLine> timeToBestScoresLineList = extractTimeToBestScoresLineList();
File csvStatisticFile = new File(solverStatisticFilesDirectory, baseName + "Statistic.csv");
Writer writer = null;
try {
writer = new OutputStreamWriter(new FileOutputStream(csvStatisticFile), "utf-8");
writer.append("\"TimeMillisSpend\"");
for (String configName : configNameList) {
writer.append(",\"").append(configName.replaceAll("\\\"", "\\\"")).append("\"");
}
writer.append("\n");
for (TimeToBestScoresLine timeToBestScoresLine : timeToBestScoresLineList) {
writer.write(Long.toString(timeToBestScoresLine.getTimeMillisSpend()));
for (String configName : configNameList) {
writer.append(",");
Score score = timeToBestScoresLine.getConfigNameToScoreMap().get(configName);
if (score != null) {
Double scoreGraphValue = scoreDefinition.translateScoreToGraphValue(score);
if (scoreGraphValue != null) {
writer.append(scoreGraphValue.toString());
}
}
}
writer.append("\n");
}
} catch (IOException e) {
throw new IllegalArgumentException("Problem writing csvStatisticFile: " + csvStatisticFile, e);
} finally {
IOUtils.closeQuietly(writer);
}
return " <p><a href=\"" + csvStatisticFile.getName() + "\">CVS file</a></p>\n";
}
private CharSequence writeGraphStatistic(File solverStatisticFilesDirectory, String baseName) {
XYSeriesCollection seriesCollection = new XYSeriesCollection();
for (Map.Entry<String, BestScoreStatisticListener> listenerEntry : bestScoreStatisticListenerMap.entrySet()) {
String configName = listenerEntry.getKey();
XYSeries configSeries = new XYSeries(configName);
List<BestScoreStatisticPoint> statisticPointList = listenerEntry.getValue()
.getBestScoreStatisticPointList();
for (BestScoreStatisticPoint statisticPoint : statisticPointList) {
long timeMillisSpend = statisticPoint.getTimeMillisSpend();
Score score = statisticPoint.getScore();
Double scoreGraphValue = scoreDefinition.translateScoreToGraphValue(score);
if (scoreGraphValue != null) {
configSeries.add(timeMillisSpend, scoreGraphValue);
}
}
seriesCollection.addSeries(configSeries);
}
NumberAxis xAxis = new NumberAxis("Time millis spend");
xAxis.setNumberFormatOverride(new MillisecondsSpendNumberFormat());
NumberAxis yAxis = new NumberAxis("Score");
yAxis.setAutoRangeIncludesZero(false);
XYItemRenderer renderer = new XYStepRenderer();
XYPlot plot = new XYPlot(seriesCollection, xAxis, yAxis, renderer);
plot.setOrientation(PlotOrientation.VERTICAL);
JFreeChart chart = new JFreeChart(baseName + " best score statistic",
JFreeChart.DEFAULT_TITLE_FONT, plot, true);
BufferedImage chartImage = chart.createBufferedImage(1024, 768);
File graphStatisticFile = new File(solverStatisticFilesDirectory, baseName + "Statistic.png");
OutputStream out = null;
try {
out = new FileOutputStream(graphStatisticFile);
ImageIO.write(chartImage, "png", out);
} catch (IOException e) {
throw new IllegalArgumentException("Problem writing graphStatisticFile: " + graphStatisticFile, e);
} finally {
IOUtils.closeQuietly(out);
}
return " <img src=\"" + graphStatisticFile.getName() + "\"/>\n";
}
}