/*  -*- c++ -*-  (for Emacs)
 *
 *  linearrecognisertrainer.cpp
 *  Digest
 * 
 *  Created by Aidan Lane on Mon Jul 11 2005.
 *  Copyright (c) 2005-2006 Optimisation and Constraint Solving Group,
 *  Monash University. All rights reserved.
 *
 *  This program is free software; you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published by
 *  the Free Software Foundation; either version 2 of the License, or
 *  (at your option) any later version.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with this program; if not, write to the Free Software
 *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
 */

#include "linearrecognisertrainer.h"

#include <QFile>

#include <QDebug> // TODO: remove me!
#include <iostream> // TODO: remove me!

#include "linearrecogniser.h"
#include "invertmatrix.h"

//using namespace std; // TODO: remove me


typedef double FeatureCalcT;  /*!< Used for internal calculations, not results. */


// TODO: remove the following!
/*
template <class Type>
void dumpMatrixValues( const Type* matrix, int order )  {
  Q_ASSERT( matrix != 0 );
  for (int i=0; i < order; i++)
    {
      cout << "i=" << i << ": ";
    for (int j=0; j<order; j++)
      {
        cout << QString::number( matrix[i*order+j] ).toAscii().constData() << " ";
      }
    cout << endl;
    }
};
*/



/*!
 * Constructs a linear recogniser trainer.
 *
 * \a jvm, \a digestDbModel and \a parent are passed straight through to
 * AbstractRecogniserTrainer. The feature count starts at zero.
 */
LinearRecogniserTrainer::LinearRecogniserTrainer(
    JavaVM* jvm,
    DigestDbModel* digestDbModel,
    QObject* parent )
  : AbstractRecogniserTrainer( jvm, digestDbModel, parent )
  , c_numFeatures( 0 )
{
}


/*!
 * Resets the trainer before a new training run: discards any previously
 * collected per-class samples and computed weights, and records the number
 * of features (the size of featureKeys()).
 *
 * Always returns true.
 */
bool LinearRecogniserTrainer::prepareForTraining()
{
  // Throw away whatever an earlier run left behind.
  m_classSampleResults.clear();
  m_classWeights.clear();

  // Every feature vector handed to examineSample() must have this length.
  c_numFeatures = featureKeys().size();

  return true;
}


/*!
 */
bool LinearRecogniserTrainer::examineSample( const DGestureRecord& sample,
					      const QVector<FeatureResultT>& featureVec )
{
  foreach ( int classId, sample.classes )
    {
      // Ensure that the given class ID has a record in m_classSampleResults
      if ( ! m_classSampleResults.contains(classId) )
	m_classSampleResults[classId] = QList< QVector<FeatureResultT> >();

      Q_ASSERT( featureVec.size() == c_numFeatures ); // needs to be consistent
      
      m_classSampleResults[classId].append( featureVec );
    }

  return true;
}


/*!
 * See: Rubine, D. Specifying Gestures by Sample. Computer Graphics 25, 4
 *      (July 1991), 329-337.
 *
 * Assumes that there is at least one sample for each class.
 */
bool LinearRecogniserTrainer::finalizeTraining()
{
  // TODO: OPTIMISE ME!!!
  QHash< int, QVector<FeatureCalcT> > classFeatureMeans;

  // Step 1. Build classFeatureMeans.
  QHashIterator< int, QList< QVector<FeatureResultT> > > c(m_classSampleResults);
  while ( c.hasNext() )
    {
      c.next();
      // For each feature, get the mean of their results across all samples.
      const int numClassSamples = c.value().size(); // size != c_numSamples, as it's per class
      Q_ASSERT( numClassSamples > 0 ); // as we divide by it.
      QVector<FeatureCalcT> means( c_numFeatures, 0.0 );
      for ( int i=0; i < c_numFeatures; ++i ) {
	for ( int e=0; e < numClassSamples; ++e ) { 
	  Q_ASSERT( c.value().at(e).size() == c_numFeatures );
	  means[i] += c.value().at(e).at(i); // note: lookup with at() - faster
	}
	means[i] /= numClassSamples; // POTENTIAL FOR DIVISION-BY-ZERO
      }
      classFeatureMeans.insert( c.key(), means );
  }

  // Step 2. Compute sample estimate of the covariance matrix for each class.
  //         NOTE: following Rubine, we delay the 1 / (|class samples| - 1) until step 3.
  // Step 3. Compute estimate of the COMMON covariance matrix.
  // TODO: FIND A BETTER AND CLEANER MATRIX LIBRRAY!!!
  FeatureCalcT* commonCovarianceMatrix = new FeatureCalcT[c_numFeatures * c_numFeatures];
  for ( int m=0; m < c_numFeatures*c_numFeatures; ++m )
    commonCovarianceMatrix[m] = 0.0;
  int numClasses = m_classSampleResults.size();
  c.toFront(); // go back to the front of the hash
  while ( c.hasNext() )
    {
      c.next();
      int classId = c.key();

      // Fill-out UPPER-RIGHT half of the common covariance matrix with the sums of the class covariance matrices
      Q_ASSERT( m_classSampleResults.contains(classId) );
      Q_ASSERT( classFeatureMeans.contains(classId) );
      const QList< QVector<FeatureResultT> >& sampleResults = m_classSampleResults.value(classId);
      const QVector<FeatureCalcT>& featureMeans = classFeatureMeans.value(classId);
      const int numClassSamples = sampleResults.size(); // size != c_numSamples, as it's per class
      for ( int i=0; i < c_numFeatures; ++i ) {
	const FeatureCalcT iMean = featureMeans.at(i);
	for ( int j=i; j < c_numFeatures; ++j ) { // start at 'i', as matrix is symmetrical along diag
	  const FeatureCalcT jMean = featureMeans.at(j);
	  FeatureCalcT v = 0.0;
	  for ( int e=0; e < numClassSamples; ++e )
	    v += ( (sampleResults.at(e).at(i) - iMean)
		   * (sampleResults.at(e).at(j) - jMean) );
	  commonCovarianceMatrix[i*c_numFeatures + j] += v;
	}
      }
    }

  // Scale all values of the matrix by the constant factor
  // *AND* fill-out bottom-left half of matrix
  // Note: We perform a single division now, so that we only have to multiply later
  FeatureCalcT factor = 1.0 / (FeatureCalcT)(-numClasses + gestureIds().size());
  for ( int i=0; i < c_numFeatures; ++i ) {
    for ( int j=i; j < c_numFeatures; ++j ) { // start at 'i', as matrix is symmetrical along diag
      /* Note: Don't check for multiple writes when i==j.
       *       It would most likely take longer to test if i==j
       *       |features|x|features| times than it would to update
       *       the matrix once more |features| times.
       */
      const FeatureCalcT v = commonCovarianceMatrix[i*c_numFeatures + j] * factor;
      commonCovarianceMatrix[i*c_numFeatures + j] = v; // top-right half
      commonCovarianceMatrix[j*c_numFeatures + i] = v; // bottom-left half
    }
  }

  // Step 4. Invert matrix
  FeatureCalcT* invertedMatrix = new FeatureCalcT[c_numFeatures * c_numFeatures];
  if ( !invert<FeatureCalcT>(commonCovarianceMatrix, invertedMatrix, c_numFeatures) ) {
    Q_ASSERT( !"invert error" ); // TODO: remove me!
    return false;
  }

  // We've finished with commonCovarianceMatrix, we're using invertedMatrix now
  delete[] commonCovarianceMatrix;


  // Step 5. Build m_classWeights : compute weight estimates
  c.toFront(); // go back to the front of the hash
  while ( c.hasNext() )
    {
      c.next();
      int classId = c.key();
      const QVector<FeatureCalcT>& featureMeans = classFeatureMeans.value(classId);

      QVector<WeightT> w( c_numFeatures+1, 0.0 ); // YES, F+1 !
      // w_1 to w_F
      for ( int j=0; j < c_numFeatures; ++j )
	for ( int i=0; i < c_numFeatures; ++i )
	  w[j+1] // offset by 1, as w[0] is special. see further down.
	    += ( invertedMatrix[i*c_numFeatures + j]
		 * featureMeans.at(i) );
      // w_0 (dependant w_1 to w_F)
      for ( int i=0; i < c_numFeatures; ++i )
	w[0] += w.at(i+1) * featureMeans.at(i); // using w_1 to w_F -> w.at(i+1)
      w[0] *= -0.5;

      m_classWeights.insert( classId, w );
    }


  // CLEANUP
  delete[] invertedMatrix;
  m_classSampleResults.clear();
  c_numFeatures = 0;
  

  return true;
}


bool LinearRecogniserTrainer::writeModelFile( const QString& fileName )
{
  // TODO: remove me:
  qDebug() << outputFilePath();
  qDebug() << outputPath();
  // TODO: OPTIMISE ME! (for starters, remove the redundant ","s and "\n"s)
  QFile file( fileName );
  if ( ! file.open(QIODevice::WriteOnly | QIODevice::Text) )
    return false;

  QTextStream out( &file );
  QHashIterator< int, QVector<WeightT> > ci(m_classWeights);
  while ( ci.hasNext() )
    {
      ci.next();
      out << ci.key() << "=";
      QVectorIterator<WeightT> wi(ci.value());
      while ( wi.hasNext() )
	out << wi.next() << ",";
      out << "\n";
    }

  file.close();

  return true;
}
