#include <ttr_environment.hpp>
#include <ttr_support_functions.hpp>
#include <algorithm>
#include <cmath>

// Forward declaration of ttr::site; its definition is at the end of this file.
namespace ttr{
/// Simulate one site with TTR model variant V and write its recorded
/// outputs into `out`.
///
/// Declared here so that code earlier in this file can instantiate it;
/// the definition is at the end of the file.
template<TTRVariant V>
void site(double* out,
          const size_t site_index,
          const arma::cube& timeseries_in,
          const arma::mat& timeinvariant_in,
          const ParameterVector<AlphaParameter>& alpha_par,
          const std::vector<ParameterVector<BetaParameter>>& beta_pars,
          const environment& env);
  // Not yet implemented (kept for reference):
  //  std::tuple<ParameterVector<AlphaParameter>,std::vector<ParameterVector<BetaParameter>>> parameter_copy(const List& parlist, const environment& env);
}

/// Map every element of `v` into its own interval using ttr::reflect.
///
/// @param v  values to bring back into range.
/// @param b  bounds matrix; row i supplies the two bounds for v[i]
///           (column 0 and column 1 are passed as the 2nd and 3rd
///           arguments of ttr::reflect). Must have at least length(v)
///           rows and at least 2 columns.
/// @return   a new vector of the same length as `v`.
/// Raises an R error (Rcpp::stop) if `b` is too small for `v`.
// [[Rcpp::export]]
Rcpp::NumericVector interval_c(
  Rcpp::NumericVector v,
  Rcpp::NumericMatrix b){
  const R_xlen_t n = v.length();
  // NumericMatrix::operator() performs no bounds checking, so an undersized
  // `b` would otherwise read out of bounds.
  if(b.nrow() < n || b.ncol() < 2){
    Rcpp::stop("interval_c: 'b' must have at least length(v) rows and 2 columns");
  }
  Rcpp::NumericVector ret(n);
  for(R_xlen_t i = 0; i < n; ++i){
    ret[i] = ttr::reflect(v[i], b(i, 0), b(i, 1));
  }
  return ret;
}

// [[Rcpp::export]]
Rcpp::NumericVector run_ttr_cpp(
    Rcpp::List ttr_parameters,
    const arma::cube timeseries,
    const arma::mat timeinvariant,
    const Rcpp::List options,
    const Rcpp::NumericVector globals){
  ttr::environment env = ttr::environment(timeseries,
                                          //timeinvariant,
                                          options,
                                          globals);
  auto maybe_a = Rcpp::as<Rcpp::Nullable<Rcpp::NumericVector>>(ttr_parameters["alpha"]);
  auto a = ParameterVector<AlphaParameter>();
  if(maybe_a.isNotNull()){
    a = ParameterVector<AlphaParameter>(Rcpp::as<Rcpp::NumericVector>(ttr_parameters["alpha"]));
  }
  auto b_mat = Rcpp::as<Rcpp::NumericMatrix>(ttr_parameters["beta"]);
  auto b = ParameterVector<BetaParameter>::fromMatrix(b_mat);

  auto res = Rcpp::NumericVector(env.nsites*env.nout*env.nspecies*env.ntimeout);

  switch(env.var){
  case TTRVariant::std:
    for(size_t site = 0; site < env.nsites; site++){
      ttr::site<TTRVariant::std>(res.begin(), site, timeseries, timeinvariant, a, b, env);
    }
    break;
  case TTRVariant::fqr:
    for(size_t site = 0; site < env.nsites; site++){
      ttr::site<TTRVariant::fqr>(res.begin(), site, timeseries, timeinvariant, a, b, env);
    }
    break;
  case TTRVariant::red:
    for(size_t site = 0; site < env.nsites; site++){
      ttr::site<TTRVariant::red>(res.begin(), site, timeseries, timeinvariant, a, b, env);
    }
    break;
  case TTRVariant::oak:
    for(size_t site = 0; site < env.nsites; site++){
      ttr::site<TTRVariant::oak>(res.begin(), site, timeseries, timeinvariant, a, b, env);
    }
    break;
  default:
    Rcpp::Rcerr << "Wrong TTR Variant" << std::endl;
  }

  return res;
}

/// Return the variable names of the type whose name is given in `x`.
///
/// Accepted names: "TimeSeries", "Global", "TimeInvariant", "Biotic",
/// "AlphaParameter", "BetaParameter", "TTRVariant", "PhotoType".
/// Each forwards to the corresponding varnames<T>() specialisation.
/// Any other name raises an R error ("Wrong variable type given").
// [[Rcpp::export]]
Rcpp::StringVector varnames(std::string x){
  const std::string& type_name = x;

  if(type_name == "TimeSeries")     return varnames<TimeSeries>();
  if(type_name == "Global")         return varnames<Global>();
  if(type_name == "TimeInvariant")  return varnames<TimeInvariant>();
  if(type_name == "Biotic")         return varnames<Biotic>();
  if(type_name == "AlphaParameter") return varnames<AlphaParameter>();
  if(type_name == "BetaParameter")  return varnames<BetaParameter>();
  if(type_name == "TTRVariant")     return varnames<TTRVariant>();
  if(type_name == "PhotoType")      return varnames<PhotoType>();

  // unknown type name
  Rcpp::stop("Wrong variable type given");
  return "";
}

namespace ttr{
/// Simulate the TTR model (variant V) for a single site and write the
/// requested output steps into `out`.
///
/// @param out              start of the output buffer shared by all sites.
///                         This site writes from offset
///                         enum_count<Biotic>() * nspecies * ntimeout * site.
///                         Each recorded step occupies
///                         enum_count<Biotic>() * nspecies doubles: species
///                         after species, with that species' Biotic
///                         variables contiguous.
/// @param site             site index: third index into `timeseries_in`,
///                         column of `timeinvariant_in`.
/// @param timeseries_in    forcing data indexed (variable, time, site); the
///                         time index wraps around as step % env.ndata.
/// @param timeinvariant_in per-site constants indexed (variable, site); only
///                         TimeInvariant::nsoil is read.
/// @param alpha_par        alpha parameters (currently not read here).
/// @param beta_pars        one beta parameter vector per species.
/// @param env              run configuration (steps, species, output steps,
///                         globals, process-error settings, photo types).
///
/// Output is recorded whenever the current 0-based step reaches
/// env.steps[k] - 1. env.steps appears to hold 1-based step numbers, to be
/// confirmed with the caller. The simulation stops once the last requested
/// step has been written. Calls `stop` if a species has an unknown photo type.
template <TTRVariant V>
void site(double* out,
          size_t const site,
          const arma::cube& timeseries_in,
          const arma::mat& timeinvariant_in,
          const ParameterVector<AlphaParameter>& alpha_par,
          const std::vector<ParameterVector<BetaParameter>>& beta_pars,
          const environment& env){

  // No output step requested: nothing to record for this site, and the read
  // of env.steps[0] below would be out of bounds.
  if(env.steps.length() == 0){
    return;
  }

  //the global constants
  auto const& globals = env.globals;

  // first output slot belonging to this site
  double * output = out + enum_count<Biotic>()*env.nspecies*env.ntimeout*site;

  //iteration variables over step and species (captured by the lambdas below)
  size_t step = 0;
  size_t species = 0;

  //index into env.steps of the next output step to record
  R_xlen_t next_step = 0;
  //0-based model step at which that output is recorded.
  //NOTE(review): a step value of 0 wraps to SIZE_MAX here, which suppresses
  //all further output for this site.
  size_t ind_out = static_cast<size_t>(env.steps[next_step]) - 1;

  //double-buffered state: one row per State, one column per species
  Mat<double> frame_a(enum_count<State>(), env.nspecies, fill::zeros);
  Mat<double> frame_b(enum_count<State>(), env.nspecies, fill::zeros);
  auto present = &frame_a;
  auto future = &frame_b;

  //doubles written per recorded step (all Biotic variables of all species)
  const auto out_step_size = enum_count<Biotic>() * env.nspecies;

  size_t ndata = env.ndata;

  //accessors bound to the current `species` / `step`. The buffer pointers
  //are captured by reference so the accessors follow the swap at the end of
  //each step.
  auto present_b =
    [&present, &species](const State x) -> double& {
        return present->at(enum_value<State>(x),species);};
  auto future_b =
    [&future, &species](const State x) -> double& {
        return future->at(enum_value<State>(x), species);};
  auto present_ts =
    [&step, site, &timeseries_in, ndata](const TimeSeries x) -> double {
        return timeseries_in.at(enum_value<TimeSeries>(x),step % ndata,site);};
  auto write =
    [&output, &species](const Biotic b, double v) {
      output[species * enum_count<Biotic>() + enum_value(b)] = v;
    };

  const double nsoil = timeinvariant_in(enum_value(TimeInvariant::nsoil),site);

  for(species = 0; species < env.nspecies; species++){

    //get parameter vector for species
    ParameterVector<BetaParameter> const& beta_par = beta_pars[species];

    //initial size: the per-species iM parameter if non-zero, else the default
    double init_size;
    if(beta_par[BetaParameter::iM] != 0.0){
      init_size = beta_par[BetaParameter::iM];
    }
    else{
      init_size = env.initial_mass;
    }
    //the variables to set initially and the initial value
    const std::pair<State,double> initial_set[] = {
      {State::Ms, 1 * init_size},
      {State::Mr, 1 * init_size},
      {State::Cs, 0.05 * init_size},
      {State::Cr, 0.05 * init_size},
      {State::Ns, 0.01 * init_size},
      {State::Nr, 0.01 * init_size}
    };
    //set initial size
    for(auto const& var : initial_set){
      present_b(var.first) = var.second;
    }
  }
  //globals
  const double KM = globals[Global::KM];
  const double KA = globals[Global::KA];
  const double Jc = globals[Global::Jc];
  const double Jn = globals[Global::Jn];
  const double RHOc = globals[Global::RHOc];
  const double RHOn = globals[Global::RHOn];
  const double Fc = globals[Global::Fc];
  const double Fn = globals[Global::Fn];

  //calculation of the model
  for(step = 0; step < env.nsteps; step++){
    const double swc = present_ts(TimeSeries::swc);
    const double tcur = present_ts(TimeSeries::tcur);
    const double tnur = present_ts(TimeSeries::tnur);
    const double tgrowth = present_ts(TimeSeries::tgrowth);
    const double tloss = present_ts(TimeSeries::tloss);
    const double ppfd = present_ts(TimeSeries::ppfd);

    for(species = 0; species < env.nspecies; species++){
      //select the photosynthesis forcing series for this species
      TimeSeries photo = TimeSeries::C3p;
      switch(env.photos[species]){
      case PhotoType::c3 :
        photo = TimeSeries::C3p;
        break;
      case PhotoType::c4 :
        photo = TimeSeries::C4p;
        break;
      default:
        stop("Wrong photo type");
      }
      double pht = present_ts(photo);

      auto const& beta_par = beta_pars[species];

      double Ms = present_b(State::Ms);
      double Mr = present_b(State::Mr);
      double Ns = present_b(State::Ns);
      double Nr = present_b(State::Nr);
      double Cs = present_b(State::Cs);
      double Cr = present_b(State::Cr);

      double cur;
      double nur;
      double g;
      double loss;

      // process error (zero unless env.pe is set)
      double pe_Ms = 0;
      double pe_Mr = 0;
      double pe_Ns = 0;
      double pe_Nr = 0;
      double pe_Cs = 0;
      double pe_Cr = 0;

      if(env.pe){
        pe_Ms = f_pe(beta_par[BetaParameter::pe_Ms],env.pe_scale);
        pe_Mr = f_pe(beta_par[BetaParameter::pe_Mr],env.pe_scale);
        pe_Ns = f_pe(beta_par[BetaParameter::pe_Ns],env.pe_scale);
        pe_Nr = f_pe(beta_par[BetaParameter::pe_Nr],env.pe_scale);
        pe_Cs = f_pe(beta_par[BetaParameter::pe_Cs],env.pe_scale);
        pe_Cr = f_pe(beta_par[BetaParameter::pe_Cr],env.pe_scale);
      }

      //variant-specific rate functions, selected at compile time
      if constexpr(TTRVariant::oak == V){
        loss = get_m_oak(tloss, swc, globals,beta_par);
        cur = get_CUR_RED(swc,Ns,Ms,pht,globals,beta_par);
        nur = get_NUR_RED(tnur,nsoil,swc,globals,beta_par);
        g = get_g_RED(tgrowth,swc,globals,beta_par);
      }
      else if constexpr(TTRVariant::std == V){
        loss = get_m_std(tloss, globals, beta_par);
        cur = get_CUR_STD(tcur,ppfd,swc,Ns,Ms,globals,beta_par);
        nur = get_NUR_STD(tnur,nsoil,swc,globals,beta_par);
        g = get_g_STD(tgrowth,swc,globals,beta_par);
      }
      else if constexpr(TTRVariant::red == V){
        loss = get_m_std(tloss, globals, beta_par);
        cur = get_CUR_RED(swc,Ns,Ms,pht,globals,beta_par);
        nur = get_NUR_RED(tnur,nsoil,swc,globals,beta_par);
        g = get_g_RED(tgrowth,swc,globals,beta_par);
      }
      else if constexpr(TTRVariant::fqr == V){
        loss = get_m_std(tloss, globals, beta_par);
        cur = get_CUR_FQR(swc,Ns,Ms,pht,globals,beta_par);
        nur = get_NUR_STD(tnur,nsoil,swc,globals,beta_par);
        g = get_g_STD(tgrowth,swc,globals,beta_par);
      }
      else {
        not_implemented<V>();
      }

      const double Mr_loss = (loss * Mr) / (1.0 + KM / Mr);
      const double Ms_loss = (loss * Ms) / (1.0 + KM / Ms);

      // transport resistance and loss
      //SIMPLIFIED: eliminated Q
      const double RN = Ms * Mr / (RHOn * (Mr + Ms));
      const double RC = Ms * Mr / (RHOc * (Mr + Ms));
      const double TAUc = F_TAUc(Cs,Cr,Ms,Mr,RC);
      const double TAUn = F_TAUn(Ns,Nr,Ms,Mr,RN);

      // uptake and growth parameters
      const double Uc = F_Uc(cur,Ms,KA,Cs,Jc);
      const double Un = F_Un(nur,Mr,KA,Nr,Jn);
      const double Gs = F_Gs(g,Ms,Cs,Ns);
      const double Gr = F_Gr(g,Mr,Cr,Nr);

      //record the state at the start of this step and the rates used
      if(step >= ind_out){
        write(Biotic::Ms, Ms);
        write(Biotic::Mr, Mr);
        write(Biotic::Cs, Cs);
        write(Biotic::Cr, Cr);
        write(Biotic::Ns, Ns);
        write(Biotic::Nr, Nr);
        write(Biotic::TAUc, TAUc);
        write(Biotic::TAUn, TAUn);
        write(Biotic::Uc, Uc);
        write(Biotic::Un, Un);
        write(Biotic::Gs, Gs);
        write(Biotic::Gr, Gr);
        write(Biotic::Mr_loss, Mr_loss);
        write(Biotic::Ms_loss, Ms_loss);
        write(Biotic::CUR, cur);
        write(Biotic::NUR, nur);
        write(Biotic::g, g);
        write(Biotic::loss, loss);
      }

      // update state variables (clamped at zero)
      future_b(State::Ms) = std::fmax(0.0, Ms + Gs - Ms_loss + pe_Ms);
      future_b(State::Mr) = std::fmax(0.0, Mr + Gr - Mr_loss + pe_Mr);
      future_b(State::Cs) = std::fmax(0.0, Cs + F_dCs_dt(Uc,Fc,Gs,TAUc) + pe_Cs);
      future_b(State::Cr) = std::fmax(0.0, Cr + F_dCr_dt(Fc,Gr,TAUc) + pe_Cr);
      future_b(State::Ns) = std::fmax(0.0, Ns + F_dNs_dt(Fn,Gs,TAUn) + pe_Ns);
      future_b(State::Nr) = std::fmax(0.0, Nr + F_dNr_dt(Un,Fn,Gr,TAUn) + pe_Nr);
    }

    //advance to the next requested output step; stop once all are written
    if(step >= ind_out){
      next_step ++ ;
      if(next_step >= env.steps.length())
        break;
      ind_out = static_cast<size_t>(env.steps[next_step]) - 1;
      output += out_step_size;
    }

    //swap the buffers
    std::swap(present, future);
    future->zeros();
  }
}

}
