/*  -*- c++ -*-  (for Emacs)
 *
 *  multiclass_svm.h
 *  Digest
 * 
 *  Created by Adrian Bickerstaffe on Wed Jan 18 2006.
 *  Copyright (c) 2005-2006 Optimisation and Constraint Solving Group,
 *  Monash University. All rights reserved.
 *
 *  This program is free software; you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published by
 *  the Free Software Foundation; either version 2 of the License, or
 *  (at your option) any later version.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with this program; if not, write to the Free Software
 *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
 */

#ifndef MULTICLASS_SVM_H
#define MULTICLASS_SVM_H

#include <iosfwd>
#include <list>
#include <map>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "svm.h"

using namespace std;

// Base class holding the feature-scaling state shared by svm_dataset and
// multiclass_classifier (both derive from it).
//
// The stream types appear only as reference parameters, so the forward
// declarations from <iosfwd> are all this header needs; they are spelled
// with an explicit std:: so the class does not depend on a using-directive.
class feature_scaler
{
public:
	feature_scaler();

	// Reports the scaling range through the two output references
	// (presumably low first, then high -- confirm against the definition).
	void get_scale_range(double &, double &) const;

protected:
	int num_features;
	double scale_range_low;
	double scale_range_high;
	std::vector<int> targets;
	std::vector<double> lower_limits;
	std::vector<double> upper_limits;

	// Scales one feature value in place through the reference argument.
	// NOTE(review): the two by-value doubles are presumably that feature's
	// lower and upper limits -- confirm the order in the definition.
	void scale_single_feature(double &, double, double) const;

	// Write / read / print the scaling details held by this object.
	void write_feature_details(std::ofstream &) const;
	void read_feature_details(std::ifstream &);
	void print_scaling_details() const;
};

//////////////////////////////

// A dataset for an SVM.  Samples (vectors of doubles) are kept in a map
// keyed by an int (see `data`); the scaling state comes from feature_scaler.
class svm_dataset : public feature_scaler
{
public:
	svm_dataset();
	// NOTE(review): the int is presumably the number of features -- confirm.
	svm_dataset(int);
	~svm_dataset();

	// Add a single sample or a batch of samples.  The int is presumably the
	// target (class label) and the bool presumably reports success --
	// confirm against the definitions.
	bool add_sample(int, const std::vector<double> &);
	bool add_samples(int, const std::vector<std::vector<double> > &);
	// Returns, by value, the samples associated with the given int
	// (presumably a target label).
	std::vector<std::vector<double> > get_samples(int) const;

	// Accessors for the feature count.
	bool set_num_features(int);
	int get_num_features() const;

	// Feature scaling controls.  The limits are reported through the two
	// output vectors (presumably lower limits, then upper limits).
	void set_scale_range(double, double);
	void scale();
	void get_feature_limits(std::vector<double> &, std::vector<double> &) const;

	// Size queries.
	int size() const;
	int num_unique_targets() const;

	std::vector<int> get_targets() const;
	// Produces a libsvm svm_problem from this dataset.
	// NOTE(review): the role and ownership of the double* argument cannot
	// be determined from this header -- check the definition.
	struct svm_problem to_libsvm_prob(double*) const;

	// Splits the dataset into the supplied pair.  NOTE(review): the float
	// is presumably the split fraction -- confirm.
	bool split(float, std::pair<svm_dataset, svm_dataset> &);
	// Persistence; the string is presumably a file name.
	bool save(const std::string &);
	bool load(const std::string &);

	void print() const;
	void clear();

private:
	long int num_samples;
	std::map<int, std::vector<std::vector<double> > > data;

	void copy_feature_details(const svm_dataset &);
	void find_feature_limits(std::vector<double> &, std::vector<double> &);
	std::vector<double> collect_featurevals(int);
};

////////////////////////////////

// A generic SVM; for the most part a wrapper around libsvm functions.
//
// NOTE(review): `model` is a raw pointer and the class declares a destructor
// but no copy constructor or copy assignment.  Confirm in the implementation
// that instances are never copied (or that copying is safe).
class generic_svm
{
public:
	generic_svm();
	generic_svm(const svm_dataset &);
	~generic_svm();

	// Supplies the training problem from a dataset, then trains on it.
	void set_problem(const svm_dataset &);
	bool train();

	// Persistence; the C string is presumably a file path.
	bool load(const char*);
	bool save(const char*) const;
	void clear();

protected:
	struct svm_parameter param;		// set by constructor
	struct svm_problem prob;		// set by set_problem()
	struct svm_model *model;
	bool model_setup;
	bool problem_setup;

private:
	void setup_defaults();

	// Kernel-parameter search helpers (grid handling and cross-validation,
	// judging by their names); see the implementation for the procedure.
	void search_kernel_params();
	void setup_parameter_grid(std::map<double, double> &,
	                          std::map<double, double> &);

	void find_best_parameters(const std::map<double, double> &,
	                          const std::map<double, double> &,
	                          double &, double &);

	std::vector<double> find_new_powers(const std::map<double, double> &,
	                                    const double &);
	void generate_new_points(std::map<double, double> &,
	                         const std::vector<double> &);

	void update_grid(std::map<double, double> &, std::map<double, double> &,
	                 const double &, const double &);
	void print_grid(const std::map<double, double> &,
	                const std::map<double, double> &);

	double cross_validate(double, double);
};

// Generic SVM classifier that performs nonprobabilistic prediction.
class nonprobabilistic_svm : public generic_svm
{
public:
	nonprobabilistic_svm();
	nonprobabilistic_svm(const svm_dataset &);

	// Classifies one feature vector and returns the result as a double
	// (presumably the predicted label -- confirm against the definition).
	double classify_sample(const std::vector<double> &) const;
};

// Generic SVM classifier that performs probabilistic prediction.
class probabilistic_svm : public generic_svm
{
public:
	probabilistic_svm();
	probabilistic_svm(const svm_dataset &);

	// Classifies one feature vector.  The double reference is an output
	// (presumably the probability attached to the prediction -- confirm).
	double classify_sample(const std::vector<double> &, double &) const;
};

////////////

// Abstract base for the multi-class SVM classifiers.  It carries the feature
// scaling data (via feature_scaler) and the member functions that every
// multi-class classifier has in common.
class multiclass_classifier : public feature_scaler
{
public:
	enum sample_type { scaled, unscaled };

	multiclass_classifier();
	virtual ~multiclass_classifier() {}

	// Minimal interface every SVM classifier has to implement.
	virtual bool train(const svm_dataset &) = 0;
	virtual double classify_sample(const std::vector<double> &,
	                               const sample_type &) = 0;

	// Classification that uses the leaf distributions inferred during
	// training; the map is a non-const reference (presumably an output).
	double classify_with_leaf(const std::vector<double> &,
	                          const sample_type &,
	                          std::map<int, double> &);

	// Deliberately public so that scaled features can also be obtained
	// for debugging purposes.
	std::vector<double> scale_sample(const std::vector<double> &) const;

	virtual bool save(const std::string &) = 0;
	virtual bool load(const std::string &) = 0;

	virtual void print() = 0;
	virtual void clear() = 0;

protected:
	bool model_setup;
	std::map<int, std::map<int, double> > leaf_probabilities;

	// Build, write, read and print the leaf distributions.
	void setup_leaf_probabilities(const svm_dataset &);
	void write_leaf_dists(std::ofstream &) const;
	void read_leaf_dists(std::ifstream &);
	void print_leaf_dists() const;

	// Returns a dataset derived from the given one and two sets of ints
	// (presumably sets of target labels -- confirm).
	svm_dataset get_data_subset(const svm_dataset &,
	                            const std::set<int> &,
	                            const std::set<int> &) const;

	// NOTE(review): the meaning and order of the two strings is not
	// visible here -- check the definitions before calling.
	bool create_archive(std::string, std::string);
	bool extract_archive(std::string, std::string);
};

// Class to represent a Directed Acyclic Graph (DAG) arrangement of SVMs for
// multi-class (k > 2) classification.
class dag_svm : public multiclass_classifier
{
public:
	dag_svm(void);
	~dag_svm(void);
	
	bool train(const svm_dataset &);
	double classify_sample(const vector<double> &,
						   const sample_type &);	   	
							 
	bool save(const string &);
	bool load(const string &);

	void print(void);
	void clear(void);
	
private:
	map<vector<int>, nonprobabilistic_svm*> dag_nodes;
	vector<vector<int> > node_labels;

	void setup_node_labels(void);
	void clear_dag_nodes(void);
};

/////////////////////////

// One edge of a Minimal Cost Spanning Tree (MCST): a pair of end-point
// vertices together with an associated weight.
class mcst_edge
{
public:
	mcst_edge();
	mcst_edge(int, int);
	mcst_edge(int, int, long double);

	// End-point accessors; the getter reports through its references.
	void set_vertices(int, int);
	void get_vertices(int &, int &) const;

	// Weight accessors.
	void set_weight(long double);
	long double get_weight() const;

	void clear();
	void print() const;

private:
	int first_vertex;
	int second_vertex;
	long double weight;
};

// Less-than comparison used to sort MCST edges according to their weights.
bool operator<(const mcst_edge &lhs, const mcst_edge &rhs);

// A single node of an MCST.  Every node holds a binary SVM classifier that
// separates two sets of target classes.
//
// NOTE(review): the children and the classifier are raw pointers and no copy
// operations are declared; confirm ownership and that nodes are not copied.
class classifier_node
{
public:
	enum direction { left, right };

	classifier_node();
	~classifier_node();

	// Class-set bookkeeping.
	void add_class_set(const std::set<int> &);
	std::vector<std::set<int> > get_class_sets() const;
	std::set<int> get_flattened_classes() const;

	int num_class_sets() const;
	bool is_singleton() const;

	// Child linkage, addressed by direction.
	void set_child(direction, classifier_node*);
	bool has_child(direction) const;
	classifier_node* get_child(direction) const;

	bool train(const svm_dataset &);

	// Classifies a sample and returns a node pointer (presumably the child
	// to descend into); the double reference is an output -- confirm its
	// meaning in the definition.
	classifier_node* classify_sample(const std::vector<double> &, double &);

	// Stream-based persistence.  NOTE(review): the roles of the string,
	// the list of strings and the int are not visible from this header.
	bool write(std::ofstream &, const std::string &,
	           std::list<std::string> &, const int &);
	bool read(std::ifstream &, const std::string &,
	          std::list<std::string> &);

	void clear();
	void print() const;

private:
	std::vector<std::set<int> > class_sets;
	classifier_node *left_child;
	classifier_node *right_child;
	nonprobabilistic_svm *classifier;

	std::string form_filename(const int &);
};

///////////////////////////

// Class to represent a Minimal Cost Spanning Tree (MCST) of SVMs for
// multi-class classification.
class hierarchical_classifier : public multiclass_classifier
{
public:
	enum distance_type{centroid, median};
	
	hierarchical_classifier(void);	
	hierarchical_classifier(const distance_type &);
	~hierarchical_classifier(void);
	
	virtual bool train(const svm_dataset &) = 0;
	double classify_sample(const vector<double> &,
						   const sample_type &);
	
	bool save(const string &);
	bool load(const string &);

	void print(void);
	void clear(void);
	
protected:
	classifier_node *root;
	distance_type representative;

	double class_distance(vector<vector<double> > &, vector<vector<double> > &);
	double euclidean_distance(const vector<double> &, const vector<double> &);
	vector<double> calc_centroid(const vector<vector<double> > &);
	vector<double> calc_median(const vector<vector<double> > &);

	void merge_sets(vector<set<int> > &, int, int);

	classifier_node* find_child(const vector<classifier_node*> &, const set<int> &);

	void print_set(const set<int> &);
	void print_sets(const vector<set<int> > &);
	
private:
	void create_tree_linkage(vector<classifier_node*> &);
	void free_tree_nodes(classifier_node*);
	void print_tree_nodes(classifier_node*);
	void flatten_tree(vector<classifier_node*> &, classifier_node*);
};

///////////////////////////

class nnsl_classifier : public hierarchical_classifier
{
public:
	nnsl_classifier(void);
	nnsl_classifier(const distance_type &);
	
	bool train(const svm_dataset &);
	
private:
	bool joins_subtrees(const mcst_edge &, const vector<set<int> > &,  int &, int &);
	bool sets_disjoint(const set<int> &, const set<int> &);
	
	vector<mcst_edge> calc_graph_edges(const svm_dataset &);
};

///////////////////////////

// Hierarchical classifier that supplies train() but leaves the choice of
// which sets to merge to derived classes via find_best_merger(), so this
// class is still abstract.
class non_mcst_classifier : public hierarchical_classifier
{
public:
	non_mcst_classifier();
	non_mcst_classifier(const distance_type &);

	virtual bool train(const svm_dataset &);

protected:
	// Selects a merger among the given class sets; the result is reported
	// through the two int references (presumably set indices -- confirm).
	virtual void find_best_merger(const std::vector<std::set<int> > &,
	                              const svm_dataset &,
	                              int &, int &) = 0;

	// Returns samples from the dataset for the given set of ints
	// (presumably target labels).
	std::vector<std::vector<double> > set_to_samples(const svm_dataset &,
	                                                 const std::set<int> &);
};

///////////////////////////

class nnal_classifier : public non_mcst_classifier
{
public:
	nnal_classifier(void);
	nnal_classifier(const distance_type &);
	
private:
	void find_best_merger(const vector<set<int> > &,
						  const svm_dataset &, int &, int &);
};

///////////////////////////

// Concrete non_mcst_classifier; provides its own find_best_merger() plus a
// maximum_distance() helper.
class nncl_classifier : public non_mcst_classifier
{
public:
	nncl_classifier();
	nncl_classifier(const distance_type &);

private:
	// Implements the pure virtual declared in non_mcst_classifier.
	virtual void find_best_merger(const std::vector<std::set<int> > &,
	                              const svm_dataset &,
	                              int &, int &);

	// Distance between two sets of ints (presumably class labels) computed
	// from the dataset; see the definition for the exact measure.
	double maximum_distance(const svm_dataset &,
	                        const std::set<int> &,
	                        const std::set<int> &);
};

#endif	// !MULTICLASS_SVM_H
