// manniru — 7/18/2017, 10:24 AM
// WekaDemo.java

/*
WEKA Demo

Demonstrates some of WEKA's main functions.

Copyright (C) 2015 Alina Maria Ciobanu

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program.  If not, see <http://www.gnu.org/licenses/>.
*/

package weka.demo;

import java.awt.Container;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import java.util.Random;

import javax.swing.JFrame;

import weka.attributeSelection.AttributeSelection;
import weka.attributeSelection.InfoGainAttributeEval;
import weka.attributeSelection.Ranker;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.trees.RandomTree;
import weka.clusterers.ClusterEvaluation;
import weka.clusterers.Clusterer;
import weka.clusterers.SimpleKMeans;
import weka.core.Attribute;
import weka.core.Drawable;
import weka.core.FastVector;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.SerializationHelper;
import weka.core.converters.ArffSaver;
import weka.core.converters.CSVLoader;
import weka.core.converters.ConverterUtils.DataSource;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.MathExpression;
import weka.filters.unsupervised.attribute.Remove;
import weka.filters.unsupervised.attribute.StringToNominal;
import weka.gui.treevisualizer.PlaceNode2;
import weka.gui.treevisualizer.TreeVisualizer;

public class WekaDemo {

	/** When true, {@link #log(Exception, String)} also prints stack traces. */
	private static final boolean DEBUG = false;

	public static void main(String[] args) {

		WekaDemo demo = new WekaDemo();
		demo.runFlow();
	}

	/**
	 * Runs demo flow for WEKA's functions: data loading/conversion,
	 * classification, model (de)serialization, evaluation, attribute
	 * ranking and clustering.
	 */
	public void runFlow() {

		// write ARFF file from CSV
		writeArffFileFromCsv("data/weather.csv", "data/weather.arff");

		// load ARFF file
		Instances instances = loadInstancesFromArff("data/weather.arff");

		// convert attribute from string to nominal
		instances = convertStringToNominal(instances, "1");

		// transform numeric attributes to percentages
		instances = applyMathExpression(instances, "1,4-5", "A/100");

		// load ARFF file from CSV for prediction
		Instances dataset = loadInstancesFromCsv("data/dataset.csv");

		// instantiate random tree classifier (K = number of randomly
		// chosen attributes considered at each node)
		RandomTree classifier = new RandomTree();
		classifier.setKValue(2);

		// train random tree classifier
		train(instances, classifier);

		// serialize model
		serialize(classifier, "tree.model");

		// deserialize model
		classifier = (RandomTree) deserialize("tree.model");

		// predict class for new instances
		predict(instances, dataset, classifier);

		// visualize decision tree
		visualizeResults(classifier);

		// run cross-validation
		runCrossValidation(instances, classifier, 5);

		// rank attributes
		rankAttributes(instances);

		// remove class attribute for clustering (Remove uses 1-based indices)
		instances = removeAttributes(instances, String.valueOf(instances.classIndex() + 1));

		// instantiate k-means clusterer
		SimpleKMeans clusterer = new SimpleKMeans();
		try {
			clusterer.setNumClusters(2);
		} catch (Exception e) {
			log(e, "Error while instantiating clusterer.");
		}

		// cluster instances
		cluster(instances, clusterer);

		// remove class attribute for new instances
		dataset = removeAttributes(dataset, String.valueOf(dataset.classIndex() + 1));

		// transform numeric attributes to percentages for new dataset
		dataset = applyMathExpression(dataset, "1,4", "A/100");

		// assign clusters to new instances
		assign(dataset, clusterer);
	}

	/**
	 * Loads instances from a CSV file. The first line of the file is the
	 * header. The last attribute becomes the class attribute if none is set.
	 * 
	 * @param path the location of the input file
	 * @return the loaded dataset, or {@code null} if reading failed
	 */
	public Instances loadInstancesFromCsv(String path) {

		Instances instances = null;
		try {
			CSVLoader loader = new CSVLoader();
			loader.setSource(new File(path));
			instances = loader.getDataSet();

			if (instances.classIndex() == -1)
				instances.setClassIndex(instances.numAttributes() - 1);
		} catch (IOException e) {
			log(e, "Error while reading csv file.");
		}

		return instances;
	}

	/**
	 * Loads instances "manually" from a CSV file, building the attribute
	 * definitions in code (string outlook, numeric temperature/humidity,
	 * nominal windy/play). The last attribute ("play") is set as class.
	 * 
	 * NOTE(review): every line is parsed as data — if the CSV has a header
	 * row, Double.parseDouble will throw; confirm data/weather.csv is
	 * header-less.
	 * 
	 * @param path the location of the input file
	 * @return the loaded dataset
	 */
	public Instances loadInstancesManuallyFromCsv(String path) {

		FastVector attributes = new FastVector();

		List<String> windy = Arrays.asList("TRUE", "FALSE");
		List<String> play = Arrays.asList("yes", "no");

		// (FastVector) null declares a string (free-text) attribute
		attributes.addElement(new Attribute("outlook", (FastVector) null));
		attributes.addElement(new Attribute("temperature"));
		attributes.addElement(new Attribute("humidity"));
		attributes.addElement(new Attribute("windy", getFastVector(windy)));
		attributes.addElement(new Attribute("play", getFastVector(play)));

		Instances dataset = new Instances("weather", attributes, 14);

		List<String> lines = getLines(path);

		for (String line : lines) {
			String[] split = line.split(",");
			// nominal values are stored as the index of the label
			double[] values = { dataset.attribute(0).addStringValue(split[0]),
					Double.parseDouble(split[1]), Double.parseDouble(split[2]),
					windy.indexOf(split[3]), play.indexOf(split[4]) };

			dataset.add(new Instance(1.0, values));
		}

		dataset.setClassIndex(dataset.numAttributes() - 1);

		return dataset;
	}

	/**
	 * Persists instances in ARFF file format.
	 * 
	 * @param instances the dataset to be persisted
	 * @param outPath the location of the output file
	 */
	public void writeArffFile(Instances instances, String outPath) {

		try {
			ArffSaver saver = new ArffSaver();
			saver.setInstances(instances);
			saver.setFile(new File(outPath));
			saver.writeBatch();
		} catch (IOException e) {
			log(e, "Error while writing data to file.");
		} catch (Exception e) {
			log(e, "Error while transforming data from non-sparse to sparse.");
		}
	}

	/**
	 * Reads data from a CSV file and transforms it to ARFF format.
	 * 
	 * @param inPath the location of the input CSV file
	 * @param outPath the location of the output ARFF file
	 */
	public void writeArffFileFromCsv(String inPath, String outPath) {

		writeArffFile(loadInstancesManuallyFromCsv(inPath), outPath);
	}

	/**
	 * Loads instances from an ARFF file. The last attribute becomes the
	 * class attribute if none is set.
	 * 
	 * @param path the location of the input file
	 * @return the loaded dataset, or {@code null} if loading failed
	 */
	public Instances loadInstancesFromArff(String path) {

		Instances instances = null;
		try {
			DataSource source = new DataSource(path);
			instances = source.getDataSet();

			if (instances.classIndex() == -1)
				instances.setClassIndex(instances.numAttributes() - 1);
		} catch (Exception e) {
			log(e, "Error while loading ARFF file.");
		}
		return instances;
	}

	/**
	 * Converts string attributes to nominal attributes.
	 * 
	 * @param instances the dataset to be transformed
	 * @param range the range of considered attributes (1-based, e.g. "1,3-5")
	 * @return the updated dataset (or the original on failure)
	 */
	public Instances convertStringToNominal(Instances instances, String range) {

		try {
			StringToNominal filter = new StringToNominal();
			filter.setAttributeRange(range);
			filter.setInputFormat(instances);
			instances = Filter.useFilter(instances, filter);
		} catch (Exception e) {
			log(e, "Error while converting attribute from string to nominal.");
		}
		return instances;
	}

	/**
	 * Applies a math transformation to numeric attributes.
	 * 
	 * @param instances the dataset to be transformed
	 * @param ignored the range of ignored attribute indices (1-based)
	 * @param expression the expression to apply, with A standing for the
	 *        attribute value (e.g. "A/100")
	 * @return the updated dataset (or the original on failure)
	 */
	public Instances applyMathExpression(Instances instances, String ignored,
			String expression) {

		try {
			MathExpression filter = new MathExpression();
			filter.setIgnoreRange(ignored);
			filter.setExpression(expression);
			filter.setInputFormat(instances);
			instances = Filter.useFilter(instances, filter);
		} catch (Exception e) {
			log(e, "Error while applying math expression.");
		}
		return instances;
	}

	/**
	 * Trains a classifier on the given instances.
	 * 
	 * @param instances the dataset used for training
	 * @param classifier the model to be trained
	 */
	public void train(Instances instances, Classifier classifier) {

		try {
			classifier.buildClassifier(instances);
		} catch (Exception e) {
			log(e, "Error while training classifier.");
		}
	}

	/**
	 * Predicts the class for new instances and prints the predicted labels
	 * to standard output.
	 * 
	 * @param instances training instances (used here for class attribute info)
	 * @param dataset the instances for which the class is predicted
	 * @param classifier the model used for predicting the class
	 */
	public void predict(Instances instances, Instances dataset,
			Classifier classifier) {

		try {
			for (int i = 0; i < dataset.numInstances(); i++) {
				// classifyInstance returns the index of the predicted label
				int prediction = (int) classifier.classifyInstance(dataset
						.instance(i));
				String label = instances.classAttribute().value(prediction);
				System.out.print(label + " ");
			}
		} catch (Exception e) {
			log(e, "Error while running classifier.");
		}
	}

	/**
	 * Runs cross-validation on the given instances and prints summary,
	 * per-class details and the confusion matrix.
	 * 
	 * @param instances dataset used for evaluating the classifier
	 * @param classifier the model that is evaluated
	 * @param folds number of cross-validation folds
	 */
	public void runCrossValidation(Instances instances, Classifier classifier,
			int folds) {

		try {
			Evaluation eval = new Evaluation(instances);
			// fixed seed so fold splits are reproducible across runs
			eval.crossValidateModel(classifier, instances, folds, new Random(1));

			System.out.println(eval.toSummaryString());
			System.out.println(eval.toClassDetailsString());
			System.out.println(eval.toMatrixString());

		} catch (Exception e) {
			log(e, "Error while running cross-validation.");
		}
	}

	/**
	 * Runs a clustering algorithm on the given dataset and prints the
	 * evaluation results.
	 * 
	 * @param instances dataset for clustering (must not contain a class
	 *        attribute)
	 * @param clusterer the clusterer used for grouping instances
	 */
	public void cluster(Instances instances, Clusterer clusterer) {

		try {
			clusterer.buildClusterer(instances);

			ClusterEvaluation evaluation = new ClusterEvaluation();
			evaluation.setClusterer(clusterer);
			evaluation.evaluateClusterer(instances);
			System.out.println(evaluation.clusterResultsToString());
		} catch (Exception e) {
			log(e, "Error while running clusterer.");
		}
	}

	/**
	 * Assigns new instances to clusters and prints "index(cluster)" pairs.
	 * 
	 * @param dataset the instances to be clustered
	 * @param clusterer the model used for clustering instances
	 */
	public void assign(Instances dataset, Clusterer clusterer) {
		try {
			for (int i = 0; i < dataset.numInstances(); i++) {
				int cluster = clusterer.clusterInstance(dataset.instance(i));
				System.out.print(i + "(" + cluster + ") ");
			}
		} catch (Exception e) {
			log(e, "Error while assigning clusters.");
		}
	}

	/**
	 * Displays the decision tree classifier in a Swing window.
	 * 
	 * NOTE(review): EXIT_ON_CLOSE terminates the whole JVM when the window
	 * is closed, cutting the demo flow short — confirm this is intended.
	 * 
	 * @param classifier the decision tree represented graphically
	 */
	public void visualizeResults(Drawable classifier) {

		try {
			TreeVisualizer visualizer = new TreeVisualizer(null,
					classifier.graph(), new PlaceNode2());

			JFrame frame = new JFrame("Weka Demo: RandomTree");
			Container contentPane = frame.getContentPane();

			contentPane.add(visualizer);

			frame.setSize(800, 500);
			frame.setVisible(true);
			frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);

			visualizer.fitToScreen();
		} catch (Exception e) {
			log(e, "Error while displaying classification tree.");
		}
	}

	/**
	 * Ranks attributes based on information gain and prints
	 * "name: score" lines in decreasing order of merit.
	 * 
	 * @param instances dataset used for ranking the attributes
	 */
	public void rankAttributes(Instances instances) {

		AttributeSelection selection = new AttributeSelection();
		InfoGainAttributeEval eval = new InfoGainAttributeEval();
		Ranker ranker = new Ranker();
		try {
			selection.setEvaluator(eval);
			selection.setSearch(ranker);
			selection.SelectAttributes(instances);
			// each row is { attribute index, merit score }
			double[][] rank = selection.rankedAttributes();
			for (int i = 0; i < rank.length; i++)
				System.out.println(instances.attribute((int) rank[i][0]).name()
						+ ": " + rank[i][1]);
		} catch (Exception e) {
			log(e, "Error while ranking attributes.");
		}
	}

	/**
	 * Removes attributes from a dataset.
	 * 
	 * @param instances dataset from which attributes are removed
	 * @param indices the 1-based indices of the attributes to remove
	 * @return updated dataset, or {@code null} if filtering failed
	 */
	public Instances removeAttributes(Instances instances, String indices) {

		Instances filtered = null;
		try {
			Remove filter = new Remove();
			filter.setAttributeIndices(indices);
			filter.setInputFormat(instances);
			filtered = Filter.useFilter(instances, filter);
		} catch (Exception e) {
			log(e, "Error while removing class attribute.");
		}

		return filtered;
	}

	/**
	 * Serializes a model to disk.
	 * 
	 * @param classifier the trained classifier to be serialized
	 * @param path the location of the output file
	 */
	public void serialize(Classifier classifier, String path) {

		try {
			SerializationHelper.write(path, classifier);
		} catch (Exception e) {
			log(e, "Error while serializing classifier.");
		}
	}

	/**
	 * Deserializes a model from disk.
	 * 
	 * @param path the location of the serialized model file
	 * @return the trained classifier, or {@code null} if loading failed
	 */
	public Classifier deserialize(String path) {

		Classifier classifier = null;
		try {
			classifier = (Classifier) SerializationHelper.read(path);
		} catch (Exception e) {
			log(e, "Error while deserializing classifier.");
		}

		return classifier;
	}

	/**
	 * Creates a vector with possible values for a nominal attribute.
	 * 
	 * @param elements the possible values for an attribute
	 * @return vector with possible values for a nominal attribute
	 */
	private FastVector getFastVector(List<String> elements) {

		FastVector vector = new FastVector(elements.size());
		for (String element : elements)
			vector.addElement(element);
		return vector;
	}

	/**
	 * Returns the lines of a UTF-8 encoded text file.
	 * 
	 * @param path the location of the file
	 * @return a list with the lines of the file (empty on read error)
	 */
	private List<String> getLines(String path) {

		List<String> lines = new LinkedList<>();

		// try-with-resources guarantees the reader is closed even if
		// an exception is thrown mid-read
		try (BufferedReader in = new BufferedReader(new InputStreamReader(
				new FileInputStream(new File(path)),
				StandardCharsets.UTF_8))) {

			String line;
			while ((line = in.readLine()) != null)
				lines.add(line);
		} catch (IOException e) {
			log(e, "Error while reading data from file.");
		}

		return lines;
	}

	/**
	 * Displays error messages and, when {@link #DEBUG} is enabled,
	 * stack traces.
	 * 
	 * @param exception the exception that occurred
	 * @param message the error message
	 */
	private void log(Exception exception, String message) {

		System.out.println(message);
		if (DEBUG)
			exception.printStackTrace();
	}
}