#include "scalablebayesm.h"

#include <cmath>
#include <vector>

#include <RcppArmadillo.h>

using namespace Rcpp;
using namespace arma;

double lllinear(arma::vec const& beta, arma::vec const& y, arma::mat const& X, double const& sigmasq);


// Per-unit independence Metropolis sampler for heterogeneous linear models.
//
// For every unit i (element i of Data$regdata, each holding a response `y`
// and a design matrix `X`), the proposal at iteration `rep` is row `rep` of
// `betadraws`, shared by all units. Given a proposed beta, a variance value is
// drawn as (SSR + nu * ssq(i)) / chisq(nu + n_i), where SSR is the residual sum
// of squares. The log-likelihood is then evaluated with lllinear(). The proposal
// (beta together with its variance draw) is accepted with probability
// min(1, exp(logp - logc)), i.e. only the lllinear() values enter the ratio.
// All chains start at the column means of `betadraws`.
//
// Arguments:
//   Data      list with element `regdata`: a list of N lists, each with `y`, `X`.
//   betadraws R x k matrix; row r is the proposal used at iteration r.
//   Mcmc      list with integer element `keep` (thinning interval, must be >= 1).
//   nu        degrees-of-freedom term used in the variance draw.
//   ssq       per-unit scale values; must have at least N elements.
//
// Returns a list with
//   betadraw  N x k x floor(R/keep) cube of retained beta draws (unit per row);
//   taudraw   N x floor(R/keep) matrix of the matching variance draws
//             (the value passed to lllinear() as `sigmasq`).
//
// Note: the random-number call sequence (one rchisq per unit at start-up, then
// per unit per iteration one rchisq followed by one runif) is fixed, so results
// are reproducible for a given R seed.
//[[Rcpp::export]]
Rcpp::List rheteroLinearIndepMetrop_rcpp_loop(Rcpp::List const& Data, arma::mat const& betadraws,
                                               Rcpp::List const& Mcmc, int nu, arma::vec ssq)
{
  const int keep = Rcpp::as<int>(Mcmc["keep"]);
  const int R = static_cast<int>(betadraws.n_rows);
  const int k = static_cast<int>(betadraws.n_cols);

  if (keep < 1) {
    Rcpp::stop("Mcmc$keep must be a positive integer");
  }
  if (R < 1) {
    Rcpp::stop("betadraws must contain at least one draw (row)");
  }

  Rcpp::List regdata = Data["regdata"];
  const int N = regdata.length();

  if (static_cast<int>(ssq.n_elem) < N) {
    Rcpp::stop("ssq must have one element per unit in Data$regdata");
  }

  // Convert each unit's data once, outside the MCMC loop, instead of
  // re-converting the R objects on every iteration.
  std::vector<arma::vec> ys(N);
  std::vector<arma::mat> Xs(N);
  for (int i = 0; i < N; i++) {
    Rcpp::List regdatai = regdata[i];
    ys[i] = Rcpp::as<arma::vec>(regdatai["y"]);
    Xs[i] = Rcpp::as<arma::mat>(regdatai["X"]);
  }

  // Chain state per unit: current beta (row i), variance, and log-likelihood.
  const arma::rowvec betainit = arma::mean(betadraws, 0);
  const arma::vec betainitcol = betainit.t();
  arma::mat betac = arma::repmat(betainit, N, 1);
  arma::vec tauc(N);
  arma::vec logc(N);

  for (int i = 0; i < N; i++) {
    const arma::vec& y = ys[i];
    const arma::mat& X = Xs[i];

    const double s = arma::sum(arma::square(y - X * betainitcol));
    tauc(i) = (s + nu * ssq(i)) / Rcpp::rchisq(1, nu + y.n_elem)[0];
    logc(i) = lllinear(betainitcol, y, X, tauc(i));
  }

  // Integer division already floors; this is the number of retained draws.
  const int nkeep = R / keep;
  arma::cube betadraw(N, k, nkeep);
  arma::mat taudraw(N, nkeep);

  for (int rep = 0; rep < R; rep++) {
    const arma::rowvec betap = betadraws.row(rep);
    const arma::vec betapcol = betap.t();  // shared proposal, as a column vector

    for (int i = 0; i < N; i++) {
      const arma::vec& y = ys[i];
      const arma::mat& X = Xs[i];

      const double s = arma::sum(arma::square(y - X * betapcol));
      const double taup = (s + nu * ssq(i)) / Rcpp::rchisq(1, nu + y.n_elem)[0];
      const double logp = lllinear(betapcol, y, X, taup);

      // Written as an explicit branch (not std::min) so that a NaN log-ratio
      // yields ratio = NaN and therefore a rejection.
      const double logdiff = logp - logc(i);
      const double ratio = (logdiff > 0) ? 1.0 : std::exp(logdiff);

      if (Rcpp::runif(1)[0] < ratio) {
        betac.row(i) = betap;
        tauc(i) = taup;
        logc(i) = logp;
      }
    }

    // Store the post-sweep state every `keep` iterations (0-based slot).
    if ((rep + 1) % keep == 0) {
      const int slot = (rep + 1) / keep - 1;
      betadraw.slice(slot) = betac;
      taudraw.col(slot) = tauc;
    }
  }

  return Rcpp::List::create(
      Rcpp::Named("betadraw") = betadraw,
      Rcpp::Named("taudraw") = taudraw);
}
