#include <cassert>
#include <cstdio>
#include <cstdlib>
#include <iostream>
#include <new>
#include <stdint.h>
#include <string>

#include "recognizerProxy.h"
#include "nnsl_mcst_recogniser.h"

using namespace std;

/////////// JNI utility functions //////////////

// Reports an unrecoverable error and terminates the process.
// Prints "Fatal Error: <msg>" on stderr, then exits with status 1; this function never returns.
void fatalError(const char * msg)
{
	const int failureStatus = 1;
	fprintf(stderr, "Fatal Error: %s\n", msg);
	exit(failureStatus);
}

// Looks up a Java class by its slash-separated name (e.g. "librecognizer/recognizerProxy")
// through JNIEnv::FindClass and returns the resulting class reference.
// If the lookup fails the program is terminated via fatalError, so callers never see NULL.
jclass tryFindClass(JNIEnv * env, const char * name)
{
	assert(env);
	jclass found = env->FindClass(name);
	if (found == NULL)
	{
		std::string message("Couldn't find the class ");
		message += name;
		fatalError(message.c_str());
	}
	return found;
}

// Resolves the ID of an instance method of `cls` from its name and JNI type signature
// (e.g. "(I)V") via JNIEnv::GetMethodID.
// If the method cannot be found the program is terminated via fatalError.
jmethodID tryGetMethodID(JNIEnv * env, jclass cls, const char * name, const char * sig)
{
	assert(env);
	jmethodID method = env->GetMethodID(cls, name, sig);
	if (method != NULL)
		return method;

	// lookup failed: build a diagnostic naming both the method and its signature
	std::string message("Couldn't find method \"");
	message += name;
	message += "\" with sig \"";
	message += sig;
	message += "\"";
	fatalError(message.c_str());
	return method;
}

// Resolves the ID of an instance field (member) of `cls` from its name and JNI type
// signature (e.g. "J" for a Java long) via JNIEnv::GetFieldID.
// If the field cannot be found the program is terminated via fatalError.
jfieldID tryGetFieldID(JNIEnv * env, jclass cls, const char * name, const char * sig)
{
	assert(env);
	jfieldID field = env->GetFieldID(cls, name, sig);
	if (field != NULL)
		return field;

	// lookup failed: build a diagnostic naming both the field and its signature
	std::string message("Couldn't find field \"");
	message += name;
	message += "\" with sig \"";
	message += sig;
	message += "\"";
	fatalError(message.c_str());
	return field;
}

// Function returns a pointer to the recognizer owned by a Java proxy object.
// The pointer is stored on the Java side in the long field "recognizerPtr" of
// librecognizer/recognizerProxy (written by initCppSide).
// Terminates the program via fatalError if the class or field cannot be resolved, or if the
// stored value is zero (e.g. initCppSide has not been called on this object).
nnsl_mcst_recogniser* getRecognizerPointer(JNIEnv * env, jobject obj)
{
	assert(env);

	// get the class and field IDs
	jclass recognizerProxy_cls = tryFindClass(env, "librecognizer/recognizerProxy");
	jfieldID recognizerProxy_pointer_fid = tryGetFieldID(env, recognizerProxy_cls, "recognizerPtr", "J");
	// The field ID does not depend on this local class reference, so release it now instead of
	// leaving one reference behind per call until the native method returns.
	env->DeleteLocalRef(recognizerProxy_cls);

	// extract the pointer
	const jlong pointer = env->GetLongField(obj, recognizerProxy_pointer_fid);
	if (!pointer)
		fatalError("Couldn't extract recognizer pointer!");

	// jlong is always 64 bits; narrowing through intptr_t keeps the integer-to-pointer
	// conversion well-defined (and warning-free) on 32-bit targets too.
	return reinterpret_cast<nnsl_mcst_recogniser*>(static_cast<intptr_t>(pointer));
}

/////////// Native functions ///////////////

// Function creates a new recognizer and stores a pointer to it in the Java proxy object's
// long field "recognizerPtr", so that later native calls can retrieve it with
// getRecognizerPointer.
// NOTE(review): any value already stored in the field is overwritten without being freed, so
// calling this twice on one object leaks the first recognizer. Nothing in this section deletes
// the recognizer; confirm it is released elsewhere (e.g. a dispose/finalize native method).
JNIEXPORT void JNICALL Java_librecognizer_recognizerProxy_initCppSide(JNIEnv *env, jobject obj)
{
	assert(env);

	// Create the new recognizer. A C++ exception must not propagate through the JVM's
	// native frames, so use the non-throwing allocation form and report failure explicitly.
	nnsl_mcst_recogniser* recognizer = new (std::nothrow) nnsl_mcst_recogniser;
	if (!recognizer)
		fatalError("Couldn't allocate the recognizer!");

	// store the pointer on the Java side
	jclass recognizerProxy_cls = tryFindClass(env, "librecognizer/recognizerProxy");
	jfieldID recognizerProxy_pointer_fid = tryGetFieldID(env, recognizerProxy_cls, "recognizerPtr", "J");
	env->DeleteLocalRef(recognizerProxy_cls);

	// Widen through intptr_t so the pointer-to-integer conversion is well-defined on
	// both 32- and 64-bit targets.
	env->SetLongField(obj, recognizerProxy_pointer_fid,
					  static_cast<jlong>(reinterpret_cast<intptr_t>(recognizer)));
}

// Function loads a trained model from file into this object's recognizer.
// Returns JNI_TRUE if recognizer->readModelFile reports success (any truthy result) and
// JNI_FALSE otherwise. Also returns JNI_FALSE, without touching the recognizer, if `filename`
// is null or its characters cannot be obtained (in that case the JVM has an
// OutOfMemoryError pending, per the JNI spec for GetStringUTFChars).
JNIEXPORT jboolean JNICALL Java_librecognizer_recognizerProxy_readModelFile(JNIEnv * env,
																			jobject obj,
																			jstring filename)
{
	assert(env);
	// retrieve a pointer to the recognizer
	nnsl_mcst_recogniser * recognizer = getRecognizerPointer(env, obj);

	// GetStringUTFChars must not be called with a null jstring.
	if (!filename)
		return JNI_FALSE;

	// Copy the (modified UTF-8) characters into a C++ string and hand the JVM buffer back
	// immediately; GetStringUTFChars returns NULL if the copy could not be allocated.
	const char * local_chars = env->GetStringUTFChars(filename, NULL);
	if (!local_chars)
		return JNI_FALSE;
	const std::string filenameStr(local_chars);
	env->ReleaseStringUTFChars(filename, local_chars);

	std::cout << "Loading model file: " << filenameStr << std::endl;
	// Map the result explicitly: jboolean is 8 bits wide, so a plain static_cast of a wider
	// non-zero value (readModelFile's return type is not visible here) could truncate to false.
	return recognizer->readModelFile(filenameStr) ? JNI_TRUE : JNI_FALSE;
}

// Classifies a serialized stroke and returns the single most likely class ID, as reported by
// the recognizer's classifyWithArray. The float array is only read, never modified.
// Terminates via fatalError if the array elements cannot be obtained from the JVM.
JNIEXPORT jint JNICALL Java_librecognizer_recognizerProxy_classifySample(JNIEnv * env,
																		 jobject obj,
																		 jfloatArray serializedStroke)
{
	assert(env);
	nnsl_mcst_recogniser * const classifier = getRecognizerPointer(env, obj);

	// The element count is passed along so the recognizer can rebuild the stroke
	// (an StlStroke, per the original comments) from the flat data.
	const jsize elementCount = env->GetArrayLength(serializedStroke);
	jfloat * const elements = env->GetFloatArrayElements(serializedStroke, NULL);
	if (elements == NULL)
		fatalError("Couldn't extract the array of serialized stroke data!");

	const jint label = static_cast<jint>(
		classifier->classifyWithArray(static_cast<float*>(elements),
									  static_cast<int>(elementCount)));

	// The data was only read, so discard any native copy instead of writing it back.
	env->ReleaseFloatArrayElements(serializedStroke, elements, JNI_ABORT);
	return label;
}

// Function probabilistically classifies a serialized stroke. The recognizer's serialized
// mapping of class IDs to prediction probabilities is copied into a newly allocated Java float
// array, which is returned. If the recognizer returns no buffer or a negative length, an
// empty array is returned. Returns NULL if the Java array itself could not be allocated (the
// JVM then has an OutOfMemoryError pending).
// NOTE(review): ownership of the buffer returned by classifyProbWithArray is not visible here.
// If it is heap-allocated per call it leaks on every call; if the recognizer owns it, this is
// fine. Confirm against nnsl_mcst_recogniser and free it here if required.
JNIEXPORT jfloatArray JNICALL Java_librecognizer_recognizerProxy_classifySampleProb(JNIEnv * env,
																					jobject obj,
																					jfloatArray serializedStroke)
{
	assert(env);
	// retrieve a pointer to the recognizer
	nnsl_mcst_recogniser * recognizer = getRecognizerPointer(env, obj);
	// find the length of the serialized stroke data so that it can be reconstructed into
	// an StlStroke
	jsize serializedStrokeLength = env->GetArrayLength(serializedStroke);
	jfloat * strokeData = env->GetFloatArrayElements(serializedStroke, NULL);
	// check that the array contents are available
	if (!strokeData)
		fatalError("Couldn't extract the array of serialized stroke data!");

	// storage for the length of the serialized map (written by the recognizer)
	int serializedMapLength = 0;
	// classify the stroke
	float * probArray = recognizer->classifyProbWithArray(static_cast<float*>(strokeData),
														  static_cast<int>(serializedStrokeLength),
														  &serializedMapLength);

	// Guard the copy below: SetFloatArrayRegion from a NULL source would crash, and a negative
	// length is not a valid array size. Either case yields an empty result array.
	if (!probArray || serializedMapLength < 0)
		serializedMapLength = 0;

	// allocate a new Java array to hold the serialized map
	jfloatArray serializedMap = env->NewFloatArray(static_cast<jsize>(serializedMapLength));
	// if the Java array was allocated and there is data, copy it into the array
	if (serializedMap && serializedMapLength > 0)
		env->SetFloatArrayRegion(serializedMap, 0, static_cast<jsize>(serializedMapLength),
								 static_cast<jfloat*>(probArray));

	// release serialized stroke data on the native side; it was only read, so no copy-back
	env->ReleaseFloatArrayElements(serializedStroke, strokeData, JNI_ABORT);

	return serializedMap;
}
