/*
 *  Copyright 2007-2015 The OpenMx Project
 *
 *  Licensed under the Apache License, Version 2.0 (the "License");
 *  you may not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *       http://www.apache.org/licenses/LICENSE-2.0
 *
 *   Unless required by applicable law or agreed to in writing, software
 *   distributed under the License is distributed on an "AS IS" BASIS,
 *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 */

#include <stdio.h>
#include <sys/types.h>
#include <errno.h>

#define R_NO_REMAP
#include <R.h>
#include <Rinternals.h>
#include <R_ext/Rdynload.h>
#include <R_ext/BLAS.h>
#include <R_ext/Lapack.h>

#include "omxDefines.h"
#include "glue.h"
#include "omxState.h"
#include "omxMatrix.h"
#include "omxAlgebra.h"
#include "omxFitFunction.h"
#include "omxExpectation.h"
#include "omxNPSOLSpecific.h"
#include "omxImportFrontendState.h"
#include "omxExportBackendState.h"
#include "Compute.h"
#include "dmvnorm.h"
#include "npsolswitch.h"
#include "omxCsolnp.h"

/*
 * Set the "class" attribute of the R object `list` to "data.frame".
 *
 * Only the class attribute is touched; any other attributes a data.frame
 * needs (names, row.names, ...) are the caller's responsibility.
 */
void markAsDataFrame(SEXP list)
{
	SEXP classes;
	Rf_protect(classes = Rf_allocVector(STRSXP, 1));
	// Rf_mkChar allocates, so `classes` must stay protected until it has
	// been attached to `list`.
	SET_STRING_ELT(classes, 0, Rf_mkChar("data.frame"));
	Rf_setAttrib(list, R_ClassSymbol, classes);
	// `classes` is now reachable through `list` (assumes the caller keeps
	// `list` protected), so balance the protect above rather than leaking
	// a protect-stack slot on every call.
	Rf_unprotect(1);
}

/*
 * .Call entry point: matrix logarithm of a square double matrix, computed
 * by logm_eigen(). The dimnames of `x` are copied onto the result.
 *
 * Raises "invalid argument" unless `x` is a double-precision matrix, and
 * "non-square matrix" if its row and column counts differ. A 0x0 input
 * yields a length-0 double vector.
 */
static SEXP do_logm_eigen(SEXP x)
{
    // REAL() is only valid for REALSXP; integer/logical matrices pass
    // Rf_isNumeric but must be rejected before their data are read.
    if (!Rf_isNumeric(x) || !Rf_isMatrix(x) || TYPEOF(x) != REALSXP)
        Rf_error("invalid argument");

    SEXP dims = Rf_getAttrib(x, R_DimSymbol);
    int n = INTEGER(dims)[0];   // rows
    int m = INTEGER(dims)[1];   // columns
    if (n != m) Rf_error("non-square matrix");
    if (n == 0) return(Rf_allocVector(REALSXP, 0));

    double *rx = REAL(x);

    SEXP z;
    ScopedProtect p1(z, Rf_allocMatrix(REALSXP, n, n));
    double *rz = REAL(z);

    logm_eigen(n, rx, rz);

    Rf_setAttrib(z, R_DimNamesSymbol, Rf_getAttrib(x, R_DimNamesSymbol));

    return z;
}

/*
 * .Call entry point: matrix exponential of a square double matrix, computed
 * by expm_eigen(). The dimnames of `x` are copied onto the result.
 *
 * Raises "invalid argument" unless `x` is a double-precision matrix, and
 * "non-square matrix" if its row and column counts differ. A 0x0 input
 * yields a length-0 double vector.
 */
static SEXP do_expm_eigen(SEXP x)
{
    // REAL() is only valid for REALSXP; integer/logical matrices pass
    // Rf_isNumeric but must be rejected before their data are read.
    if (!Rf_isNumeric(x) || !Rf_isMatrix(x) || TYPEOF(x) != REALSXP)
        Rf_error("invalid argument");

    SEXP dims = Rf_getAttrib(x, R_DimSymbol);
    int n = INTEGER(dims)[0];   // rows
    int m = INTEGER(dims)[1];   // columns
    if (n != m) Rf_error("non-square matrix");
    if (n == 0) return(Rf_allocVector(REALSXP, 0));

    double *rx = REAL(x);

    // The guard must be a named object: an unnamed `ScopedProtect(...)`
    // temporary is destroyed at the end of the statement, unprotecting z
    // before Rf_setAttrib (which may allocate) runs below.
    SEXP z;
    ScopedProtect p1(z, Rf_allocMatrix(REALSXP, n, n));
    double *rz = REAL(z);

    expm_eigen(n, rx, rz);

    Rf_setAttrib(z, R_DimNamesSymbol, Rf_getAttrib(x, R_DimNamesSymbol));

    return z;
}

/* .Call entry point: TRUE/FALSE according to the compile-time HAS_NPSOL flag. */
static SEXP has_NPSOL()
{
	const int builtWithNpsol = HAS_NPSOL;
	return Rf_ScalarLogical(builtWithNpsol);
}

/* .Call entry point: TRUE when this library was compiled with OpenMP (_OPENMP defined). */
static SEXP has_openmp()
{
#ifdef _OPENMP
	return Rf_ScalarLogical(1);
#else
	return Rf_ScalarLogical(0);
#endif
}

/* .Call entry point: write the string value of `Rstr` via mxLog and return TRUE. */
static SEXP testMxLog(SEXP Rstr)
{
	SEXP charValue = Rf_asChar(Rstr);
	const char *text = CHAR(charValue);
	mxLog("%s", text);
	return Rf_ScalarLogical(1);
}

// Counter behind untitledNumber(); zeroed by untitledNumberReset().
static int untitledCounter = 0;

/* .Call entry point: reset the counter to zero and return TRUE. */
static SEXP untitledNumberReset()
{
	untitledCounter = 0;
	return Rf_ScalarLogical(1);
}

/* .Call entry point: advance the counter and return its new value
 * (the first call after a reset returns 1). */
static SEXP untitledNumber()
{
	untitledCounter += 1;
	return Rf_ScalarInteger(untitledCounter);
}

/* Raise an R error whose message is `str`. The text goes through "%s" so
 * any '%' it contains is not treated as a format directive. */
void string_to_try_Rf_error( const std::string& str )
{
	const char *message = str.c_str();
	Rf_error("%s", message);
}

/* Turn a caught C++ exception into an R error carrying ex.what(). */
void exception_to_try_Rf_error( const std::exception& ex )
{
	const std::string description(ex.what());
	string_to_try_Rf_error(description);
}

/*
 * Build a named R list (VECSXP) from this list of (name, value) pairs,
 * preserving order. Raises an R error if any name or value is NULL.
 *
 * Both the names vector and the result are pushed onto the protect stack
 * and not popped here; callers in this file appear to rely on an enclosing
 * omxManageProtectInsanity to unwind them (confirm before changing).
 */
SEXP MxRList::asR()
{
	// detect duplicate keys? TODO
	const int count = size();
	SEXP names = Rf_protect(Rf_allocVector(STRSXP, count));
	SEXP ans = Rf_protect(Rf_allocVector(VECSXP, count));
	for (int ix = 0; ix < count; ++ix) {
		const char *key = (*this)[ix].first;
		SEXP value = (*this)[ix].second;
		if (key == NULL || value == NULL) {
			Rf_error("Attempt to return NULL pointer to R");
		}
		SET_STRING_ELT(names, ix, Rf_mkChar(key));
		SET_VECTOR_ELT(ans, ix, value);
	}
	Rf_namesgets(ans, names);
	return ans;
}

/*
 * Parse a user-supplied yes/no option value and store it in *out.
 *
 * Accepts "Yes"/"No" (compared with matchCaseInsensitive), or a string
 * starting with a digit whose leading integer value is 0 or 1. On success
 * *out is set to 1 or 0. Otherwise a warning naming `key` is issued and
 * *out is left unchanged.
 */
static void
friendlyStringToLogical(const char *key, const char *str, int *out)
{
	int newVal;
	if (matchCaseInsensitive(str, "Yes")) {
		newVal = 1;
	} else if (matchCaseInsensitive(str, "No")) {
		newVal = 0;
	} else {
		// isdigit() is undefined for negative char values (e.g. non-ASCII
		// bytes where char is signed), so widen through unsigned char.
		const bool startsWithDigit = isdigit((unsigned char) str[0]) != 0;
		// Parse once; strtol (unlike atoi) has defined behavior on overflow,
		// and it yields the same value as atoi for in-range input.
		const long parsed = startsWithDigit ? strtol(str, NULL, 10) : -1;
		if (parsed != 0 && parsed != 1) {
			Rf_warning("Expecting 'Yes' or 'No' for '%s' but got '%s', ignoring", key, str);
			return;
		}
		newVal = (int) parsed;
	}
	if(OMX_DEBUG) { mxLog("%s=%d", key, newVal); }
	*out = newVal;
}

// TODO: make member of omxGlobal class
/*
 * Apply the front-end's backend options to Global and to the out-params.
 *
 * `options` must be a named list. Each value is read as a string via
 * Rf_asChar, except "Intervals", which is read with Rf_asLogical. Names
 * are matched case-insensitively, and unrecognised names are ignored.
 * `*numThreads` is written only in OpenMP builds and is raised to 1 (with
 * a warning) if the requested value is below 1. Raises an R error if a
 * non-empty `options` has no names.
 */
static void readOpts(SEXP options, int *numThreads, int *analyticGradients)
{
	int numOptions = Rf_length(options);
	SEXP optionNames;
	Rf_protect(optionNames = Rf_getAttrib(options, R_NamesSymbol));
	// An unnamed list has R_NilValue names; STRING_ELT on it would crash.
	if (numOptions > 0 && !Rf_isString(optionNames)) {
		Rf_error("options must be a named list");
	}
	for(int i = 0; i < numOptions; i++) {
		const char *nextOptionName = CHAR(STRING_ELT(optionNames, i));
		const char *nextOptionValue = CHAR(Rf_asChar(VECTOR_ELT(options, i)));
		if(matchCaseInsensitive(nextOptionName, "Analytic Gradients")) {
			friendlyStringToLogical(nextOptionName, nextOptionValue, analyticGradients);
		} else if(matchCaseInsensitive(nextOptionName, "loglikelihoodScale")) {
			Global->llScale = atof(nextOptionValue);
		} else if(matchCaseInsensitive(nextOptionName, "debug protect stack")) {
			friendlyStringToLogical(nextOptionName, nextOptionValue, &Global->debugProtectStack);
		} else if(matchCaseInsensitive(nextOptionName, "Number of Threads")) {
#ifdef _OPENMP
			*numThreads = atoi(nextOptionValue);
			if (*numThreads < 1) {
				Rf_warning("Computation will be too slow with %d threads; using 1 thread instead", *numThreads);
				*numThreads = 1;
			}
#endif
		} else if(matchCaseInsensitive(nextOptionName, "mvnMaxPointsA")) {
			Global->maxptsa = atof(nextOptionValue);
		} else if(matchCaseInsensitive(nextOptionName, "mvnMaxPointsB")) {
			Global->maxptsb = atof(nextOptionValue);
		} else if(matchCaseInsensitive(nextOptionName, "mvnMaxPointsC")) {
			Global->maxptsc = atof(nextOptionValue);
		} else if(matchCaseInsensitive(nextOptionName, "mvnAbsEps")) {
			Global->absEps = atof(nextOptionValue);
		} else if(matchCaseInsensitive(nextOptionName, "mvnRelEps")) {
			Global->relEps = atof(nextOptionValue);
		} else if(matchCaseInsensitive(nextOptionName, "maxStackDepth")) {
			Global->maxStackDepth = atoi(nextOptionValue);
		} else if (matchCaseInsensitive(nextOptionName, "Feasibility tolerance")) {
			Global->feasibilityTolerance = atof(nextOptionValue);
		} else if (matchCaseInsensitive(nextOptionName, "Optimality tolerance")) {
			Global->optimalityTolerance = atof(nextOptionValue);
		} else if (matchCaseInsensitive(nextOptionName, "Major iterations")) {
			Global->majorIterations = atoi(nextOptionValue);
		} else if (matchCaseInsensitive(nextOptionName, "Intervals")) {
			Global->intervals = Rf_asLogical(VECTOR_ELT(options, i));
		} else if (matchCaseInsensitive(nextOptionName, "Major iteration_CSOLNP")) {
			CSOLNPOpt_majIter(nextOptionValue);
		} else if (matchCaseInsensitive(nextOptionName, "Minor iteration_CSOLNP")) {
			CSOLNPOpt_minIter(nextOptionValue);
		} else if (matchCaseInsensitive(nextOptionName, "Function precision_CSOLNP")) {
			CSOLNPOpt_FuncPrecision(nextOptionValue);
		} else {
			// ignore
		}
	}
}

/* Main functions */
/*
 * Evaluate one algebra operator (number INTEGER(algNum)[0]) on the matrices
 * in `matList` and return the result as an R double matrix.
 *
 * A fresh Global and omxState are created for the call and torn down
 * before returning. If Global recorded errors (getBads), they are raised as
 * an R error after that cleanup.
 */
SEXP omxCallAlgebra2(SEXP matList, SEXP algNum, SEXP options) {

	omxManageProtectInsanity protectManager;

	if(OMX_DEBUG) { mxLog("-----------------------------------------------------------------------");}
	if(OMX_DEBUG) { mxLog("Explicit call to algebra %d.", INTEGER(algNum)[0]);}

	int algebraNum = INTEGER(algNum)[0];
	SEXP ans, nextMat;

	FitContext::setRFitFunction(NULL);
	Global = new omxGlobal;

	omxState *globalState = new omxState;

	readOpts(options, &Global->numThreads, &Global->analyticGradients);

	/* Retrieve All Matrices From the MatList */
	const int numMats = Rf_length(matList);

	if(OMX_DEBUG) { mxLog("Processing %d matrix(ces).", numMats);}

	std::vector<omxMatrix *> args(numMats);
	for(int k = 0; k < numMats; k++) {
		Rf_protect(nextMat = VECTOR_ELT(matList, k));	// This is the matrix + populations
		args[k] = omxNewMatrixFromRPrimitive(nextMat, globalState, 1, - k - 1);
		globalState->matrixList.push_back(args[k]);
		if(OMX_DEBUG) {
			mxLog("Matrix[%d] initialized (%d x %d)",
				k, globalState->matrixList[k]->rows, globalState->matrixList[k]->cols);
		}
	}

	omxMatrix *algebra = omxNewAlgebraFromOperatorAndArgs(algebraNum, args.data(), numMats, globalState);

	if(algebra==NULL) {
		Rf_error("Failed to build algebra");
	}

	if(OMX_DEBUG) {mxLog("Completed Algebras and Matrices.  Beginning Initial Compute.");}

	omxRecompute(algebra, NULL);

	Rf_protect(ans = Rf_allocMatrix(REALSXP, algebra->rows, algebra->cols));
	double *out = REAL(ans);
	// Copy into R's column-major layout.
	for(int l = 0; l < algebra->rows; l++)
		for(int j = 0; j < algebra->cols; j++)
			out[j * algebra->rows + l] = omxMatrixElement(algebra, l, j);

	if(OMX_DEBUG) { mxLog("All Algebras complete."); }

	// Copy any error text into R-managed memory before Global is deleted:
	// who owns the buffer returned by getBads() is not visible here, and it
	// must still be valid when passed to Rf_error below.
	SEXP badsMsg = R_NilValue;
	const char *bads = Global->getBads();
	if (bads) Rf_protect(badsMsg = Rf_mkChar(bads));

	omxFreeMatrix(algebra);
	delete globalState;
	delete Global;

	// Never use the error text itself as a format string: it may contain '%'.
	if (badsMsg != R_NilValue) Rf_error("%s", CHAR(badsMsg));

	return ans;
}

/*
 * .Call entry point "callAlgebra": runs omxCallAlgebra2 and converts any
 * C++ exception escaping it into an R error.
 */
SEXP omxCallAlgebra(SEXP matList, SEXP algNum, SEXP options)
{
	try {
		return omxCallAlgebra2(matList, algNum, options);
	} catch( const std::exception& ex ) {
		exception_to_try_Rf_error( ex );
	} catch(...) {
		string_to_try_Rf_error( "c++ exception (unknown reason)" );
	}
	// Not reached: both handlers raise an R error, which longjmps out. The
	// helpers are not declared noreturn, though, so without this the
	// function would flow off the end of a non-void function.
	return R_NilValue;
}

/*
 * Translate an internal bound to its user-facing value. A value equal to
 * the sentinel `inf` (NEG_INF/INF at the call sites) becomes NA_REAL;
 * every other value, including NaN, is returned unchanged.
 */
static double internalToUserBound(double val, double inf)
{
	return (val != inf) ? val : NA_REAL;
}

/*
 * Run a complete backend computation and return the results as a named R list.
 *
 * A fresh Global and omxState are built from the front-end's data,
 * expectation, matrix, free-variable, algebra, compute, constraint,
 * interval and checkpoint specifications. The top-level compute step (the
 * first entry of Global->computeList, if any) is then run, and fit,
 * estimates, bounds, standard errors, status and error text are collected
 * into the result.
 *
 * Errors found during setup are raised as R errors. Errors raised during
 * the compute are reported in the result's "error" and "status" entries.
 */
SEXP omxBackend2(SEXP constraints, SEXP matList,
		 SEXP varList, SEXP algList, SEXP expectList, SEXP computeList,
		 SEXP data, SEXP intervalList, SEXP checkpointList, SEXP options,
		 SEXP defvars)
{
	SEXP nextLoc;

	/* Sanity Check and Parse Inputs */
	/* TODO: Need to find a way to account for nullness in these.  For now, all checking is done on the front-end. */
//	if(!isVector(matList)) Rf_error ("matList must be a list");
//	if(!isVector(algList)) Rf_error ("algList must be a list");

	omxManageProtectInsanity protectManager;

	FitContext::setRFitFunction(NULL);
	Global = new omxGlobal;

	/* Create new omxState for current state storage and initialize it. */
	omxState *globalState = new omxState;

	readOpts(options, &Global->numThreads, &Global->analyticGradients);
#if HAS_NPSOL
	omxSetNPSOLOpts(options);
#endif

	if (Global->debugProtectStack) mxLog("Protect depth at line %d: %d", __LINE__, protectManager.getDepth());
	globalState->omxProcessMxDataEntities(data, defvars);

	if (Global->debugProtectStack) mxLog("Protect depth at line %d: %d", __LINE__, protectManager.getDepth());
	globalState->omxProcessMxExpectationEntities(expectList);

	if (Global->debugProtectStack) mxLog("Protect depth at line %d: %d", __LINE__, protectManager.getDepth());
	globalState->omxProcessMxMatrixEntities(matList);

	if (Global->debugProtectStack) mxLog("Protect depth at line %d: %d", __LINE__, protectManager.getDepth());
	std::vector<double> startingValues;
	omxProcessFreeVarList(varList, &startingValues);
	FitContext *fc = new FitContext(globalState, startingValues);
	Global->fc = fc;
	fc->copyParamToModelClean();

	if (Global->debugProtectStack) mxLog("Protect depth at line %d: %d", __LINE__, protectManager.getDepth());
	globalState->omxProcessMxAlgebraEntities(algList);

	/* Process Matrix and Algebra Population Function */
	/*
	  Each matrix is a list containing a matrix and the other matrices/algebras that are
	  populated into it at each iteration.  The first element is already processed, above.
	  The rest of the list will be processed here.
	*/
	if (Global->debugProtectStack) mxLog("Protect depth at line %d: %d", __LINE__, protectManager.getDepth());
	for(int j = 0; j < Rf_length(matList); j++) {
		Rf_protect(nextLoc = VECTOR_ELT(matList, j));		// This is the matrix + populations
		globalState->matrixList[j]->omxProcessMatrixPopulationList(nextLoc);
	}

	if (Global->debugProtectStack) mxLog("Protect depth at line %d: %d", __LINE__, protectManager.getDepth());
	omxInitialMatrixAlgebraCompute(globalState, NULL);

	if (Global->debugProtectStack) mxLog("Protect depth at line %d: %d", __LINE__, protectManager.getDepth());
	globalState->omxCompleteMxExpectationEntities();

	for (int dx=0; dx < (int) globalState->dataList.size(); ++dx) {
		globalState->dataList[dx]->connectDynamicData();
	}

	if (Global->debugProtectStack) mxLog("Protect depth at line %d: %d", __LINE__, protectManager.getDepth());
	globalState->omxCompleteMxFitFunction(algList);

	if (Global->debugProtectStack) mxLog("Protect depth at line %d: %d", __LINE__, protectManager.getDepth());
	Global->omxProcessMxComputeEntities(computeList, globalState);

	// Nothing depends on constraints so we can process them last.
	if (Global->debugProtectStack) mxLog("Protect depth at line %d: %d", __LINE__, protectManager.getDepth());
	globalState->omxProcessConstraints(constraints, fc);

	if (isErrorRaised()) {
		// The error text must not be used as a format string: it may contain '%'.
		Rf_error("%s", Global->getBads());
	}

	globalState->loadDefinitionVariables(true);

	globalState->setWantStage(FF_COMPUTE_FIT);

	omxCompute *topCompute = NULL;
	if (Global->computeList.size()) topCompute = Global->computeList[0];

	if (Global->debugProtectStack) mxLog("Protect depth at line %d: %d", __LINE__, protectManager.getDepth());
	Global->omxProcessConfidenceIntervals(intervalList, globalState);

	omxProcessCheckpointOptions(checkpointList);

	Global->cacheDependencies(globalState);

	if (Global->debugProtectStack) mxLog("Protect depth at line %d: %d", __LINE__, protectManager.getDepth());
	if (protectManager.getDepth() > Global->maxStackDepth) {
		Rf_error("Protection stack too large; report this problem to the OpenMx forum");
	}

	if (topCompute && !isErrorRaised()) {
		topCompute->compute(fc);

		if ((fc->wanted & FF_COMPUTE_FIT) && !std::isfinite(fc->fit) &&
		    fc->inform != INFORM_STARTING_VALUES_INFEASIBLE) {
			std::string diag = fc->getIterationError();
			omxRaiseErrorf("fit is not finite (%s)", diag.c_str());
		}
	}

	SEXP evaluations;
	Rf_protect(evaluations = Rf_allocVector(REALSXP,1));

	REAL(evaluations)[0] = Global->computeCount;

	MxRList result;

	if (Global->debugProtectStack) mxLog("Protect depth at line %d: %d", __LINE__, protectManager.getDepth());
	globalState->omxExportResults(&result);

	if (topCompute && !isErrorRaised()) {
		LocalComputeResult cResult;
		topCompute->collectResults(fc, &cResult, &result);

		if (cResult.size()) {
			SEXP computes;
			Rf_protect(computes = Rf_allocVector(VECSXP, cResult.size() * 2));
			for (size_t cx=0; cx < cResult.size(); ++cx) {
				std::pair<int, MxRList*> &c1 = cResult[cx];
				SET_VECTOR_ELT(computes, cx*2, Rf_ScalarInteger(c1.first));
				SET_VECTOR_ELT(computes, cx*2+1, c1.second->asR());
				delete c1.second;
			}
			result.add("computes", computes);
		}

		if (fc->wanted & FF_COMPUTE_FIT) {
			result.add("fit", Rf_ScalarReal(fc->fit));
			if (fc->fitUnits) {
				SEXP units;
				Rf_protect(units = Rf_allocVector(STRSXP, 1));
				SET_STRING_ELT(units, 0, Rf_mkChar(fitUnitsToName(fc->fitUnits)));
				result.add("fitUnits", units);
			}
			result.add("Minus2LogLikelihood", Rf_ScalarReal(fc->fit));
		}
		if (fc->wanted & FF_COMPUTE_BESTFIT) {
			result.add("minimum", Rf_ScalarReal(fc->fit));
		}

		FreeVarGroup *varGroup = Global->findVarGroup(FREEVARGROUP_ALL);
		int numFree = int(varGroup->vars.size());
		if (numFree) {
			SEXP estimate;
			Rf_protect(estimate = Rf_allocVector(REALSXP, numFree));
			memcpy(REAL(estimate), fc->est, sizeof(double)*numFree);
			result.add("estimate", estimate);

			if (Global->boundsUpdated) {
				MxRList bret;
				// Protect each vector as soon as it is allocated so that
				// allocating the second one cannot collect the first. Whether
				// MxRList::add protects its argument is not visible here.
				SEXP Rlb;
				Rf_protect(Rlb = Rf_allocVector(REALSXP, numFree));
				bret.add("l", Rlb);
				SEXP Rub;
				Rf_protect(Rub = Rf_allocVector(REALSXP, numFree));
				bret.add("u", Rub);
				double *lb = REAL(Rlb);
				double *ub = REAL(Rub);
				for(int px = 0; px < numFree; px++) {
					lb[px] = internalToUserBound(varGroup->vars[px]->lbound, NEG_INF);
					ub[px] = internalToUserBound(varGroup->vars[px]->ubound, INF);
				}
				result.add("bounds", bret.asR());
			}
			if (fc->stderrs) {
				SEXP stdErrors;
				Rf_protect(stdErrors = Rf_allocMatrix(REALSXP, numFree, 1));
				memcpy(REAL(stdErrors), fc->stderrs, sizeof(double) * numFree);
				result.add("standardErrors", stdErrors);
			}
			if (fc->wanted & (FF_COMPUTE_HESSIAN | FF_COMPUTE_IHESSIAN)) {
				result.add("infoDefinite", Rf_ScalarLogical(fc->infoDefinite));
				result.add("conditionNumber", Rf_ScalarReal(fc->infoCondNum));
			}
		}
	}

	if (Global->debugProtectStack) mxLog("Protect depth at line %d: %d", __LINE__, protectManager.getDepth());
	MxRList backwardCompatStatus;
	backwardCompatStatus.add("code", Rf_ScalarInteger(fc->inform));
	backwardCompatStatus.add("status", Rf_ScalarInteger(-isErrorRaised()));

	if (isErrorRaised()) {
		SEXP msg;
		Rf_protect(msg = Rf_allocVector(STRSXP, 1));
		SET_STRING_ELT(msg, 0, Rf_mkChar(Global->getBads()));
		result.add("error", msg);
		backwardCompatStatus.add("statusMsg", msg);
	}

	result.add("status", backwardCompatStatus.asR());
	result.add("iterations", Rf_ScalarInteger(fc->iterations));
	result.add("evaluations", evaluations);

	// Data are not modified and not copied. The same memory
	// is shared across all instances of state.
	// NOTE: This may need to change for MxDataDynamic
	for(size_t dx = 0; dx < globalState->dataList.size(); dx++) {
		omxFreeData(globalState->dataList[dx]);
	}

	if (Global->debugProtectStack) mxLog("Protect depth at line %d: %d", __LINE__, protectManager.getDepth());
	delete Global;

	return result.asR();
}

/*
 * .Call entry point "backend": runs omxBackend2 and converts any C++
 * exception escaping it into an R error.
 */
static SEXP omxBackend(SEXP constraints, SEXP matList,
		SEXP varList, SEXP algList, SEXP expectList, SEXP computeList,
		SEXP data, SEXP intervalList, SEXP checkpointList, SEXP options,
		SEXP defvars)
{
	try {
		return omxBackend2(constraints, matList,
				   varList, algList, expectList, computeList,
				   data, intervalList, checkpointList, options, defvars);
	} catch( const std::exception& ex ) {
		exception_to_try_Rf_error( ex );
	} catch(...) {
		string_to_try_Rf_error( "c++ exception (unknown reason)" );
	}
	// Not reached: both handlers raise an R error, which longjmps out. The
	// helpers are not declared noreturn, though, so without this the
	// function would flow off the end of a non-void function.
	return R_NilValue;
}

/*
 * .Call routines registered with R in R_init_OpenMx. Each entry gives the
 * R-visible name, the native entry point, and its argument count. The
 * table ends with an all-NULL sentinel entry.
 */
static R_CallMethodDef callMethods[] = {
	{"backend",               (DL_FUNC) &omxBackend,            11},
	{"callAlgebra",           (DL_FUNC) &omxCallAlgebra,         3},
	{"findIdenticalRowsData", (DL_FUNC) &findIdenticalRowsData,  5},
	{"Dmvnorm_wrapper",       (DL_FUNC) &dmvnorm_wrapper,        3},
	{"hasNPSOL_wrapper",      (DL_FUNC) &has_NPSOL,              0},
	{"sparseInvert_wrapper",  (DL_FUNC) &sparseInvert_wrapper,   1},
	{"hasOpenMP_wrapper",     (DL_FUNC) &has_openmp,             0},
	{"do_logm_eigen",         (DL_FUNC) &do_logm_eigen,          1},
	{"do_expm_eigen",         (DL_FUNC) &do_expm_eigen,          1},
	{"Log_wrapper",           (DL_FUNC) &testMxLog,              1},
	{"untitledNumberReset",   (DL_FUNC) &untitledNumberReset,    0},
	{"untitledNumber",        (DL_FUNC) &untitledNumber,         0},
	{NULL,                    NULL,                              0}
};

#ifdef  __cplusplus
extern "C" {
#endif

void R_init_OpenMx(DllInfo *info) {
	R_registerRoutines(info, NULL, callMethods, NULL, NULL);

	// There is no code that will change behavior whether openmp
	// is set for nested or not. I'm just keeping this in case it
	// makes a difference with older versions of openmp. 2012-12-24 JNP
#if defined(_OPENMP) && _OPENMP <= 200505
	omp_set_nested(0);
#endif
}

void R_unload_OpenMx(DllInfo *) {
	// keep this stub in case we need it
}

#ifdef  __cplusplus
}
#endif

