#include <cmath>
#include <Rcpp.h>
using namespace Rcpp;

//
// Auxiliary functions related to hypvolgmm
//
// Written by Luca Scrucca
//

//' @title Latin Hypercube Sampling
//'
//' @description
//' Generates a Latin Hypercube Sampling (LHS) design matrix over the hypercube.
//' Each dimension is divided into \code{n} equal-width strata; a random
//' permutation (drawn independently for every dimension) assigns one stratum
//' to each sample, and the point is placed uniformly at random inside it.
//' An error is raised if \code{n} or \code{d} is not positive, if the bounds
//' do not have length \code{d}, or if any bound is missing or not finite.
//'
//' @param n Integer. The number of sample points in the hypercube.
//' @param d Integer. The dimension of the hypercube.
//' @param lbound Numeric vector of length \code{d} specifying the 
//'   lower bounds for each dimension of the d-dimensional hypercube.
//'   If all values are \code{NA} (the default), all lower bounds are 0.
//' @param ubound Numeric vector of length \code{d} specifying the 
//'   upper bounds for each dimension of the d-dimensional hypercube.
//'   If all values are \code{NA} (the default), all upper bounds are 1.
//'
//' @return A \code{n x d} matrix containing the LHS design. Each 
//'   element is scaled to the range defined by \code{lbound} and 
//'   \code{ubound}.
//'
//' @references McKay, M.D., Beckman, R.J., Conover, W.J. (1979) A comparison of 
//'   three methods for selecting values of input variables in the 
//'   analysis of output from a computer code. Technometrics. 21(2), 
//'   239–245 (reprinted in 2000: Technometrics 42(1), 55–61).
//' @references Owen, A. B. (1992b) A central limit theorem for Latin hypercube
//'   sampling. JRSS Series B 54, 541-551.
//' @references Stein, M. (1987) Large sample properties of simulations using 
//'   Latin hypercube sampling. Technometrics 29, 143-151.
//'   
//' @examples
//' x = hypcube_lhs(100, 2)
//' plot(x, xlim = c(0,1), ylim = c(0,1))
//' rug(x[,1]); rug(x[,2], side = 2)
//' x = hypcube_lhs(100, 2, lbound = c(-5,1), ubound = c(10,3))
//' plot(x, xlim = c(-5,10), ylim = c(1,3))
//' rug(x[,1]); rug(x[,2], side = 2)
//' @export
// [[Rcpp::export]]

NumericMatrix hypcube_lhs(int n, int d, 
                          NumericVector lbound = NA_REAL, 
                          NumericVector ubound = NA_REAL) 
{
  if (n < 1 || d < 1) 
    stop("n and d must be positive integers");
  // An all-NA bound vector (the default) selects the unit hypercube.
  if (all(is_na(lbound))) 
    lbound = rep(0.0, d);
  if (all(is_na(ubound))) 
    ubound = rep(1.0, d);
  if (lbound.size() != d || ubound.size() != d)
    stop("lbound and ubound must have length equal to d");
  // Partially missing or infinite bounds would otherwise pass silently and
  // yield NA/NaN/Inf columns in the design.
  for (int j = 0; j < d; j++)
  {
    if (!std::isfinite(lbound[j]) || !std::isfinite(ubound[j]))
      stop("lbound and ubound must contain only finite values");
  }

  NumericMatrix X(n, d);
  // Stratum indices 0, ..., n-1, permuted independently for each column.
  IntegerVector nseq = Rcpp::seq(0, n-1);
  
  for (int j = 0; j < d; j++) 
  {
    // The permutation gives each row its own stratum; adding uniform jitter
    // places the point inside that stratum, and dividing by n maps it onto
    // the unit interval before scaling to [lbound[j], ubound[j]].
    NumericVector samples = as<NumericVector>(Rcpp::sample(nseq, n, false));
    NumericVector u = runif(n);
    X(_, j) = lbound[j] + (ubound[j] - lbound[j]) * (samples + u) / n;
  }
  
  return X;
}

//' @title Simple Monte Carlo Sampling
//'
//' @description
//' This function generates a Simple Monte Carlo (SMC) random design 
//' matrix over the hypercube. Every element is drawn independently and
//' uniformly over the range of its dimension.
//' An error is raised if \code{n} or \code{d} is not positive, if the bounds
//' do not have length \code{d}, or if any bound is missing or not finite.
//'
//' @param n Integer. The number of sample points in the hypercube.
//' @param d Integer. The dimension of the hypercube.
//' @param lbound Numeric vector of length \code{d} specifying the 
//'   lower bounds for each dimension of the d-dimensional hypercube.
//'   If all values are \code{NA} (the default), all lower bounds are 0.
//' @param ubound Numeric vector of length \code{d} specifying the 
//'   upper bounds for each dimension of the d-dimensional hypercube.
//'   If all values are \code{NA} (the default), all upper bounds are 1.
//'
//' @return A \code{n x d} matrix containing the SMC design. Each 
//'   element is scaled to the range defined by \code{lbound} and 
//'   \code{ubound}.
//'
//' @examples
//' x = hypcube_smc(100, 2)
//' plot(x, xlim = c(0,1), ylim = c(0,1))
//' rug(x[,1]); rug(x[,2], side = 2)
//' x = hypcube_smc(100, 2, lbound = c(-5,1), ubound = c(10,3))
//' plot(x, xlim = c(-5,10), ylim = c(1,3))
//' rug(x[,1]); rug(x[,2], side = 2)
//'
//' @export
// [[Rcpp::export]]

NumericMatrix hypcube_smc(int n, int d, 
                          NumericVector lbound = NA_REAL, 
                          NumericVector ubound = NA_REAL) 
{
  if (n < 1 || d < 1) 
    stop("n and d must be positive integers");
  // An all-NA bound vector (the default) selects the unit hypercube.
  if (all(is_na(lbound))) 
    lbound = rep(0.0, d);
  if (all(is_na(ubound))) 
    ubound = rep(1.0, d);
  if (lbound.size() != d || ubound.size() != d)
    stop("lbound and ubound must have length equal to d");
  // Partially missing or infinite bounds would otherwise pass silently and
  // yield NA/NaN/Inf columns in the design.
  for (int j = 0; j < d; j++)
  {
    if (!std::isfinite(lbound[j]) || !std::isfinite(ubound[j]))
      stop("lbound and ubound must contain only finite values");
  }

  NumericMatrix X(n, d);

  for (int j = 0; j < d; j++) 
  {
    // Independent uniform draws, affinely mapped onto [lbound[j], ubound[j]].
    NumericVector u = runif(n);
    X(_, j) = lbound[j] + (ubound[j] - lbound[j]) * u;
  }
  
  return X;
}

// Flag the rows of x that fall inside an axis-aligned box.
//
// x: matrix of points, one point per row (n rows, p columns).
// r: bounds matrix; row 0 holds the lower bound and row 1 the upper bound
//    for each of the p columns of x. Only those two rows and the first p
//    columns are read.
//
// Returns a logical vector of length n that is TRUE for row i when
// r(0, j) <= x(i, j) <= r(1, j) holds for every column j (bounds are
// inclusive). Comparisons involving NA/NaN evaluate to false, so a missing
// coordinate does not by itself mark a row as outside.
// [[Rcpp::export]]
LogicalVector inside_range(NumericMatrix x, NumericMatrix r) 
{
  const int n = x.nrow();
  const int p = x.ncol();
  // r(1, j) is read for every column of x: reject bounds matrices that are
  // too small instead of reading past the end of r.
  if (p > 0 && (r.nrow() < 2 || r.ncol() < p))
    stop("r must have at least 2 rows and at least ncol(x) columns");

  LogicalVector inside(n, true);
  for (int i = 0; i < n; i++) 
  {
    for (int j = 0; j < p; j++) 
    {
      if (x(i, j) < r(0, j) || x(i, j) > r(1, j)) 
      {
        // One violated coordinate suffices; skip the remaining columns.
        inside[i] = false;
        break;
      }
    }
  }
  return inside;
}

// Numerically stable, normalized negative exponentials.
//
// Returns out[i] = exp(min(x) - x[i]): exp(-x[i]) rescaled so that the
// entry at the minimum of x equals 1. Every exponent is <= 0, so the result
// cannot overflow. Equivalently, out = exp(max(x) - x) / exp(max(x) - min(x)).
//
// If x contains NA/NaN, every element of the result is NA/NaN. With a
// finite min(x), an element equal to +Inf maps to 0. An empty input yields
// an empty vector.
// [[Rcpp::export]]
NumericVector stable_exp_neg_diff(NumericVector x) 
{
  const R_xlen_t n = x.size();
  NumericVector out(n);
  if (n == 0)
    return out;

  // Shifting directly by min(x) is algebraically the same as computing
  // logw = max(x) - x and then subtracting max(logw), but it needs one pass
  // and no temporary vector, and it avoids Inf - Inf = NaN when x has +Inf.
  const double xmin = Rcpp::min(x);
  for (R_xlen_t i = 0; i < n; i++)
    out[i] = std::exp(xmin - x[i]);

  // To recover the unnormalized values exp(max(x) - x), multiply by
  // exp(max(x) - min(x)); note that this may overflow.
  return out;
}
