/*
WEKA Demo
Demonstrates some of WEKA's main functions.
Copyright (C) 2015 Alina Maria Ciobanu
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package weka.demo;
import java.awt.Container;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import java.util.Random;
import javax.swing.JFrame;
import weka.attributeSelection.AttributeSelection;
import weka.attributeSelection.InfoGainAttributeEval;
import weka.attributeSelection.Ranker;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.trees.RandomTree;
import weka.clusterers.ClusterEvaluation;
import weka.clusterers.Clusterer;
import weka.clusterers.SimpleKMeans;
import weka.core.Attribute;
import weka.core.Drawable;
import weka.core.FastVector;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.SerializationHelper;
import weka.core.converters.ArffSaver;
import weka.core.converters.CSVLoader;
import weka.core.converters.ConverterUtils.DataSource;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.MathExpression;
import weka.filters.unsupervised.attribute.Remove;
import weka.filters.unsupervised.attribute.StringToNominal;
import weka.gui.treevisualizer.PlaceNode2;
import weka.gui.treevisualizer.TreeVisualizer;
/**
 * Demonstrates WEKA's main functions end to end: CSV/ARFF conversion,
 * attribute filtering, classification with a {@link RandomTree},
 * model (de)serialization, tree visualization, cross-validation,
 * information-gain attribute ranking and k-means clustering.
 *
 * <p>Uses the WEKA 3.6 API ({@code FastVector}/{@code Instance}); do not
 * replace those types without upgrading the WEKA dependency.
 */
public class WekaDemo {

    /** When {@code true}, stack traces are printed in addition to error messages. */
    private static final boolean DEBUG = false;

    /**
     * Entry point: runs the demo flow.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        WekaDemo demo = new WekaDemo();
        demo.runFlow();
    }

    /**
     * Runs the demo flow for WEKA's functions. Paths are relative to the
     * working directory; the flow expects {@code data/weather.csv} and
     * {@code data/dataset.csv} to exist.
     */
    public void runFlow() {
        // write ARFF file from CSV
        writeArffFileFromCsv("data/weather.csv", "data/weather.arff");
        // load ARFF file
        Instances instances = loadInstancesFromArff("data/weather.arff");
        // convert the first attribute (outlook) from string to nominal
        instances = convertStringToNominal(instances, "1");
        // transform numeric attributes to percentages, skipping attrs 1, 4-5
        instances = applyMathExpression(instances, "1,4-5", "A/100");
        // load dataset from CSV for prediction
        Instances dataset = loadInstancesFromCsv("data/dataset.csv");
        // instantiate random tree classifier; K = 2 attributes per split
        RandomTree classifier = new RandomTree();
        classifier.setKValue(2);
        // train random tree classifier
        train(instances, classifier);
        // round-trip the model through serialization
        serialize(classifier, "tree.model");
        classifier = (RandomTree) deserialize("tree.model");
        // predict class for new instances
        predict(instances, dataset, classifier);
        // visualize decision tree
        visualizeResults(classifier);
        // run 5-fold cross-validation
        runCrossValidation(instances, classifier, 5);
        // rank attributes
        rankAttributes(instances);
        // remove class attribute for clustering (filter uses 1-based indices)
        instances = removeAttributes(instances, String.valueOf(instances.classIndex() + 1));
        // instantiate k-means clusterer with two clusters
        SimpleKMeans clusterer = new SimpleKMeans();
        try {
            clusterer.setNumClusters(2);
        } catch (Exception e) {
            log(e, "Error while instantiating clusterer.");
        }
        // cluster instances
        cluster(instances, clusterer);
        // remove class attribute for new instances
        dataset = removeAttributes(dataset, String.valueOf(dataset.classIndex() + 1));
        // transform numeric attributes to percentages for new dataset
        dataset = applyMathExpression(dataset, "1,4", "A/100");
        // assign clusters to new instances
        assign(dataset, clusterer);
    }

    /**
     * Loads instances from a CSV file. The first line of the file is the
     * header. The last attribute is used as class if none is set.
     *
     * @param path the location of the input file
     * @return the loaded dataset, or {@code null} if reading failed
     */
    public Instances loadInstancesFromCsv(String path) {
        Instances instances = null;
        try {
            CSVLoader loader = new CSVLoader();
            loader.setSource(new File(path));
            instances = loader.getDataSet();
            if (instances.classIndex() == -1)
                instances.setClassIndex(instances.numAttributes() - 1);
        } catch (IOException e) {
            log(e, "Error while reading csv file.");
        }
        return instances;
    }

    /**
     * Loads instances "manually" from a CSV file, declaring the weather
     * schema (outlook, temperature, humidity, windy, play) explicitly.
     * The last attribute ("play") becomes the class attribute.
     *
     * @param path the location of the input file (no header line expected)
     * @return the loaded dataset
     */
    public Instances loadInstancesManuallyFromCsv(String path) {
        FastVector attributes = new FastVector();
        List<String> windy = Arrays.asList("TRUE", "FALSE");
        List<String> play = Arrays.asList("yes", "no");
        // (FastVector) null declares a string attribute in the 3.6 API
        attributes.addElement(new Attribute("outlook", (FastVector) null));
        attributes.addElement(new Attribute("temperature"));
        attributes.addElement(new Attribute("humidity"));
        attributes.addElement(new Attribute("windy", getFastVector(windy)));
        attributes.addElement(new Attribute("play", getFastVector(play)));
        Instances dataset = new Instances("weather", attributes, 14);
        for (String line : getLines(path)) {
            String[] split = line.split(",");
            // nominal values are stored as the index of the label
            double[] values = { dataset.attribute(0).addStringValue(split[0]),
                    Double.parseDouble(split[1]), Double.parseDouble(split[2]),
                    windy.indexOf(split[3]), play.indexOf(split[4]) };
            dataset.add(new Instance(1.0, values));
        }
        dataset.setClassIndex(dataset.numAttributes() - 1);
        return dataset;
    }

    /**
     * Persists instances in ARFF file format.
     *
     * @param instances the dataset to be persisted
     * @param outPath the location of the output file
     */
    public void writeArffFile(Instances instances, String outPath) {
        try {
            ArffSaver saver = new ArffSaver();
            saver.setInstances(instances);
            saver.setFile(new File(outPath));
            saver.writeBatch();
        } catch (IOException e) {
            log(e, "Error while writing data to file.");
        } catch (Exception e) {
            log(e, "Error while transforming data from non-sparse to sparse.");
        }
    }

    /**
     * Reads data from a CSV file and transforms it to ARFF format.
     *
     * @param inPath the location of the input CSV file
     * @param outPath the location of the output ARFF file
     */
    public void writeArffFileFromCsv(String inPath, String outPath) {
        writeArffFile(loadInstancesManuallyFromCsv(inPath), outPath);
    }

    /**
     * Loads instances from an ARFF file. The last attribute is used as
     * class if none is set.
     *
     * @param path the location of the input file
     * @return the loaded dataset, or {@code null} if loading failed
     */
    public Instances loadInstancesFromArff(String path) {
        Instances instances = null;
        try {
            DataSource source = new DataSource(path);
            instances = source.getDataSet();
            if (instances.classIndex() == -1)
                instances.setClassIndex(instances.numAttributes() - 1);
        } catch (Exception e) {
            log(e, "Error while loading ARFF file.");
        }
        return instances;
    }

    /**
     * Converts string attributes to nominal attributes.
     *
     * @param instances the dataset to be transformed
     * @param range the range of considered attributes (1-based, e.g. "1,3-5")
     * @return the updated dataset, or the original one if conversion failed
     */
    public Instances convertStringToNominal(Instances instances, String range) {
        try {
            StringToNominal filter = new StringToNominal();
            filter.setAttributeRange(range);
            filter.setInputFormat(instances);
            instances = Filter.useFilter(instances, filter);
        } catch (Exception e) {
            log(e, "Error while converting attribute from string to nominal.");
        }
        return instances;
    }

    /**
     * Applies a math transformation to numeric attributes.
     *
     * @param instances the dataset to be transformed
     * @param ignored the range of ignored attribute indices (1-based)
     * @param expression the expression to apply; "A" stands for the
     *            attribute value (e.g. "A/100")
     * @return the updated dataset, or the original one if the filter failed
     */
    public Instances applyMathExpression(Instances instances, String ignored,
            String expression) {
        try {
            MathExpression filter = new MathExpression();
            filter.setIgnoreRange(ignored);
            filter.setExpression(expression);
            filter.setInputFormat(instances);
            instances = Filter.useFilter(instances, filter);
        } catch (Exception e) {
            // fixed: previous message was copy-pasted from the
            // string-to-nominal conversion
            log(e, "Error while applying math expression.");
        }
        return instances;
    }

    /**
     * Trains a classifier on the given instances.
     *
     * @param instances the dataset used for training
     * @param classifier the model to be trained
     */
    public void train(Instances instances, Classifier classifier) {
        try {
            classifier.buildClassifier(instances);
        } catch (Exception e) {
            log(e, "Error while training classifier");
        }
    }

    /**
     * Predicts the class for new instances and prints the labels to stdout.
     *
     * @param instances training instances (used here for class attribute info)
     * @param dataset the instances for which the class is predicted
     * @param classifier the model used for predicting the class
     */
    public void predict(Instances instances, Instances dataset,
            Classifier classifier) {
        try {
            for (int i = 0; i < dataset.numInstances(); i++) {
                int prediction = (int) classifier.classifyInstance(dataset
                        .instance(i));
                String label = instances.classAttribute().value(prediction);
                System.out.print(label + " ");
            }
        } catch (Exception e) {
            log(e, "Error while running classifier.");
        }
    }

    /**
     * Runs cross-validation on the given instances and prints the summary,
     * per-class details and confusion matrix.
     *
     * @param instances dataset used for evaluating the classifier
     * @param classifier the model that is evaluated
     * @param folds number of cross-validation folds
     */
    public void runCrossValidation(Instances instances, Classifier classifier,
            int folds) {
        try {
            Evaluation eval = new Evaluation(instances);
            // fixed seed so fold assignment is reproducible
            eval.crossValidateModel(classifier, instances, folds, new Random(1));
            System.out.println(eval.toSummaryString());
            System.out.println(eval.toClassDetailsString());
            System.out.println(eval.toMatrixString());
        } catch (Exception e) {
            // fixed typo: was "Erorr"
            log(e, "Error while running cross-validation.");
        }
    }

    /**
     * Runs a clustering algorithm on the given dataset and prints the
     * evaluation results.
     *
     * @param instances dataset for clustering (must not contain a class attribute)
     * @param clusterer the clusterer used for grouping instances
     */
    public void cluster(Instances instances, Clusterer clusterer) {
        try {
            clusterer.buildClusterer(instances);
            ClusterEvaluation evaluation = new ClusterEvaluation();
            evaluation.setClusterer(clusterer);
            evaluation.evaluateClusterer(instances);
            System.out.println(evaluation.clusterResultsToString());
        } catch (Exception e) {
            log(e, "Error while running clusterer.");
        }
    }

    /**
     * Assigns new instances to clusters and prints "index(cluster)" pairs.
     *
     * @param dataset the instances to be clustered
     * @param clusterer the model used for clustering instances
     */
    public void assign(Instances dataset, Clusterer clusterer) {
        try {
            for (int i = 0; i < dataset.numInstances(); i++) {
                int cluster = clusterer.clusterInstance(dataset.instance(i));
                System.out.print((i + "(" + cluster + ") "));
            }
        } catch (Exception e) {
            log(e, "Error while assigning clusters.");
        }
    }

    /**
     * Displays the decision tree classifier in a Swing frame.
     *
     * @param classifier the decision tree represented graphically
     */
    public void visualizeResults(Drawable classifier) {
        try {
            TreeVisualizer visualizer = new TreeVisualizer(null,
                    classifier.graph(), new PlaceNode2());
            JFrame frame = new JFrame("Weka Demo: RandomTree");
            Container contentPane = frame.getContentPane();
            contentPane.add(visualizer);
            frame.setSize(800, 500);
            frame.setVisible(true);
            frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
            visualizer.fitToScreen();
        } catch (Exception e) {
            log(e, "Error while displaying classification tree.");
        }
    }

    /**
     * Ranks attributes based on information gain and prints them in
     * decreasing order of merit.
     *
     * @param instances dataset used for ranking the attributes
     */
    public void rankAttributes(Instances instances) {
        AttributeSelection selection = new AttributeSelection();
        InfoGainAttributeEval eval = new InfoGainAttributeEval();
        Ranker ranker = new Ranker();
        try {
            selection.setEvaluator(eval);
            selection.setSearch(ranker);
            selection.SelectAttributes(instances);
            // rank[i][0] is the attribute index, rank[i][1] its merit
            double[][] rank = selection.rankedAttributes();
            for (int i = 0; i < rank.length; i++)
                System.out.println(instances.attribute((int) rank[i][0]).name()
                        + ": " + rank[i][1]);
        } catch (Exception e) {
            log(e, "Error while ranking attributes.");
        }
    }

    /**
     * Removes attributes from a dataset.
     *
     * @param instances dataset from which attributes are removed
     * @param indices the range of attribute indices to remove (1-based)
     * @return updated dataset, or {@code null} if filtering failed
     */
    public Instances removeAttributes(Instances instances, String indices) {
        Instances filtered = null;
        try {
            Remove filter = new Remove();
            filter.setAttributeIndices(indices);
            filter.setInputFormat(instances);
            filtered = Filter.useFilter(instances, filter);
        } catch (Exception e) {
            log(e, "Error while removing class attribute.");
        }
        return filtered;
    }

    /**
     * Serializes a model.
     *
     * @param classifier the trained classifier to be serialized
     * @param path the location of the output file
     */
    public void serialize(Classifier classifier, String path) {
        try {
            SerializationHelper.write(path, classifier);
        } catch (Exception e) {
            log(e, "Error while serializing classifier.");
        }
    }

    /**
     * Deserializes a model.
     *
     * @param path the location of the serialized model file
     * @return the trained classifier, or {@code null} if reading failed
     */
    public Classifier deserialize(String path) {
        Classifier classifier = null;
        try {
            classifier = (Classifier) SerializationHelper.read(path);
        } catch (Exception e) {
            log(e, "Error while deserializing classifier");
        }
        return classifier;
    }

    /**
     * Creates a vector with the possible values for a nominal attribute.
     *
     * @param elements the possible values for an attribute
     * @return vector with possible values for a nominal attribute
     */
    private FastVector getFastVector(List<String> elements) {
        FastVector vector = new FastVector(elements.size());
        for (String element : elements)
            vector.addElement(element);
        return vector;
    }

    /**
     * Returns the lines of a UTF-8 encoded file.
     *
     * @param path the location of the file
     * @return a list with the lines of the file; empty if reading failed
     */
    private List<String> getLines(String path) {
        List<String> lines = new LinkedList<>();
        // try-with-resources guarantees the reader is closed even when
        // readLine throws (the original leaked it on failure)
        try (BufferedReader in = new BufferedReader(new InputStreamReader(
                new FileInputStream(new File(path)), StandardCharsets.UTF_8))) {
            String line;
            while ((line = in.readLine()) != null)
                lines.add(line);
        } catch (Exception e) {
            log(e, "Error while reading data from file.");
        }
        return lines;
    }

    /**
     * Displays error messages and, when {@link #DEBUG} is enabled, stack traces.
     *
     * @param exception the exception that occurred
     * @param message the error message
     */
    private void log(Exception exception, String message) {
        System.out.println(message);
        if (DEBUG)
            exception.printStackTrace();
    }
}