#' SVEM Significance Test with Mixture Support
#'
#' Performs a whole-model significance test using the SVEM framework and allows
#' the user to specify mixture factor groups.  Mixture factors are sets of
#' continuous variables that are constrained to sum to a constant (the
#' mixture total) and have optional lower and upper bounds.  When mixture
#' groups are supplied, the grid of evaluation points is generated by
#' sampling Dirichlet variates over the mixture simplex rather than by
#' independently sampling each continuous predictor.  Non-mixture
#' continuous predictors are sampled via a maximin Latin hypercube over
#' their observed ranges, and categorical predictors are sampled from
#' their observed levels.  The remainder of the algorithm is the standard
#' SVEM whole-model test: standardized predictions are computed on the grid,
#' SVEM is refit on permutations of the response, and a Mahalanobis distance
#' is calculated for each original-data fit and each permutation fit.
#'
#' @param formula A formula specifying the model to be tested.
#' @param data A data frame containing the variables in the model.
#' @param mixture_groups Optional list describing one or more mixture factor
#'   groups.  Each element of the list should be a list with components
#'   `vars` (character vector of column names), `lower` (numeric vector of
#'   lower bounds of the same length as `vars`), `upper` (numeric vector
#'   of upper bounds of the same length), and `total` (scalar specifying the
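#'   sum of the mixture variables).  Every mixture variable must be listed in
#'   the `vars` of its group; continuous predictors not listed in any group
#'   are sampled independently.  No variable can appear in more than one
#'   mixture group.  If `lower`, `upper`, or `total` is omitted it defaults
#'   to 0, 1, or 1 respectively.  Defaults to `NULL` (no mixtures).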
#' @param nPoint Number of random points in the factor space (default: 2000).
#' @param nSVEM Number of SVEM fits on the original data (default: 5).
#' @param nPerm Number of SVEM fits on permuted responses for the reference
#'   distribution (default: 125).
#' @param percent Percentage of variance to capture in the SVD (default: 85).
#' @param nBoot Number of bootstrap iterations within each SVEM fit (default: 200).
#' @param glmnet_alpha The alpha parameter(s) for glmnet (default: `c(1)`).
#' @param weight_scheme Weighting scheme for SVEM (default: "SVEM").
#' @param objective Objective function for SVEM ("wAIC" or "wSSE", default: "wAIC").
#' @param verbose Logical; if `TRUE`, displays progress messages (default: `TRUE`).
#' @param ... Additional arguments passed to `SVEMnet()` and then to `glmnet()`.
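#' @return A list of class `svem_significance_test` with components
#'   `p_value` (the median of `p_values`), `p_values` (one upper-tail
#'   p-value per original-data fit under the fitted SHASHo reference
#'   distribution), `d_Y` (Mahalanobis distances of the original-data fits),
#'   `d_pi_Y` (distances of the permutation fits), `distribution_fit` (the
#'   `gamlss` SHASHo fit to `d_pi_Y`), and `data_d` (a long-format data
#'   frame of all distances with columns `D`, `Source_Type`, and `Response`).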
#' @details
#' In addition to ordinary continuous and categorical predictors, the user
#' may specify mixture factor groups.  In a mixture group, the specified
#' variables are jointly sampled from a Dirichlet distribution so that
#' their values sum to the specified `total`.  Lower and upper bounds can
#' be supplied to shift and scale the mixture simplex.  Feasibility is
#' checked (`sum(lower) <= total <= sum(upper)`), and samples are generated
#' as `lower + (total - sum(lower)) * w` for Dirichlet weights `w`, with
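#' rejection of any draws violating the upper bounds.  This guarantees the
#' correct total while respecting all bounds; an error is raised if enough
#' feasible points cannot be found (e.g., when the upper bounds leave very
#' little feasible volume).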
#'
#' If no mixture groups are supplied, every continuous predictor is sampled
#' by the maximin Latin hypercube and the test reduces to the standard SVEM
#' whole-model significance test.
#'
#'
#' @seealso `SVEMnet()`
#' @importFrom stats rgamma
#' @examples
#' \donttest{
#'   # Construct a small data set with a three-component mixture (A, B, C)
#'   # Each has distinct lower/upper bounds and they sum to 1
#'   set.seed(123)
#'   n <- 30
#'
#'   # Helper used only for generating training data in this example
#'   sample_trunc_dirichlet <- function(n, lower, upper, total) {
#'     k <- length(lower)
#'     min_sum <- sum(lower); max_sum <- sum(upper)
#'     stopifnot(total >= min_sum, total <= max_sum)
#'     avail <- total - min_sum
#'     out <- matrix(NA_real_, n, k)
#'     i <- 1L
#'     while (i <= n) {
#'       g <- rgamma(k, 1, 1)
#'       w <- g / sum(g)
#'       x <- lower + avail * w
#'       if (all(x <= upper + 1e-12)) {
#'         out[i, ] <- x
#'         i <- i + 1L
#'       }
#'     }
#'     out
#'   }
#'
#'   # Three mixture components with distinct bounds; sum to 1
#'   lower <- c(0.10, 0.20, 0.05)  # for A, B, C
#'   upper <- c(0.60, 0.70, 0.50)
#'   total <- 1.0
#'   ABC   <- sample_trunc_dirichlet(n, lower, upper, total)
#'   A <- ABC[, 1]; B <- ABC[, 2]; C <- ABC[, 3]
#'
#'   # Additional predictors
#'   X <- runif(n)
#'   F <- factor(sample(c("red", "blue"), n, replace = TRUE))
#'
#'   # Response
#'   y <- 2 + 3*A + 1.5*B + 1.2*C + 0.5*X + 1*(F == "red") + rnorm(n, sd = 0.3)
#'   dat <- data.frame(y = y, A = A, B = B, C = C, X = X, F = F)
#'
#'   # Specify the mixture group for A, B, C
#'   mix_spec <- list(
#'     list(
#'       vars  = c("A", "B", "C"),
#'       lower = c(0.10, 0.20, 0.05),
#'       upper = c(0.60, 0.70, 0.50),
#'       total = 1.0
#'     )
#'   )
#'
#'   # Run the whole-model significance test on this mixture model
#'   test_res <- svem_significance_test(
#'     y ~ A + B + C + X + F,
#'     data           = dat,
#'     mixture_groups = mix_spec,
#'     nPoint         = 200,
#'     nSVEM          = 3,
#'     nPerm          = 50,
#'     nBoot          = 100,
#'     glmnet_alpha   = c(1),
#'     weight_scheme  = "SVEM",
#'     objective      = "wAIC",
#'     verbose        = FALSE
#'   )
#'
#'   print(test_res)
#'   plot(test_res)
#' }
#' @export
svem_significance_test <- function(formula, data, mixture_groups = NULL,
                                   nPoint = 2000, nSVEM = 5, nPerm = 125,
                                   percent = 85, nBoot = 200,
                                   glmnet_alpha = c(1),
                                   weight_scheme = c("SVEM"),
                                   objective = c("wAIC", "wSSE"),
                                   verbose = TRUE, ...) {
  objective <- match.arg(objective)
  weight_scheme <- match.arg(weight_scheme)
  formula <- stats::as.formula(formula)
  data <- as.data.frame(data)

  # Response ------------------------------------------------------------------
  if (length(formula) != 3L) {
    stop("'formula' must have a response on its left-hand side.")
  }
  response_expr <- formula[[2L]]
  response_var <- all.vars(response_expr)
  if (length(response_var) != 1L) {
    stop("The response must involve exactly one variable.")
  }
  response_name <- if (is.name(response_expr)) {
    as.character(response_expr)
  } else {
    paste(deparse(response_expr), collapse = " ")
  }

  # model.frame() drops incomplete rows (default na.action).  Drop the same
  # rows from `data` so the permuted response always has nrow(data) values
  # and SVEMnet sees the same observations in every fit.
  mf <- stats::model.frame(formula, data)
  dropped_rows <- attr(mf, "na.action")
  if (!is.null(dropped_rows)) {
    data <- data[-as.integer(dropped_rows), , drop = FALSE]
  }
  y <- stats::model.response(mf)
  y_mean <- mean(y)

  # Values shuffled for each permutation fit.  A bare response is permuted
  # as-is; for a transformed response (e.g. log(y)) the raw column is
  # permuted so the formula re-applies the transformation consistently.
  if (is.name(response_expr)) {
    response_values <- y
  } else if (response_var %in% names(data)) {
    response_values <- data[[response_var]]
  } else {
    stop("Response variable '", response_var, "' must be a column of ",
         "'data' when the response is transformed.")
  }

  # Predictor classification ----------------------------------------------------
  predictor_vars <- all.vars(
    stats::delete.response(stats::terms(formula, data = data))
  )
  # is.factor() is also TRUE for ordered factors, whose class() has length 2
  # and therefore cannot be matched reliably with `%in%`.
  is_categorical <- vapply(
    data[predictor_vars],
    function(col) is.factor(col) || is.character(col),
    logical(1)
  )
  continuous_vars <- predictor_vars[!is_categorical]
  categorical_vars <- predictor_vars[is_categorical]

  # Mixture groups -------------------------------------------------------------
  # Accept a single group spec (a list containing `vars`) as shorthand for a
  # list holding one group.
  if (is.list(mixture_groups) && "vars" %in% names(mixture_groups)) {
    mixture_groups <- list(mixture_groups)
  }
  mixture_vars <- as.character(
    unlist(lapply(mixture_groups, function(grp) grp$vars))
  )
  dup_vars <- unique(mixture_vars[duplicated(mixture_vars)])
  if (length(dup_vars) > 0) {
    stop("Mixture variables appear in multiple groups: ",
         paste(dup_vars, collapse = ", "))
  }
  missing_vars <- setdiff(mixture_vars, names(data))
  if (length(missing_vars) > 0) {
    stop("Mixture variables not found in data: ",
         paste(missing_vars, collapse = ", "))
  }
  categorical_mix <- intersect(mixture_vars, categorical_vars)
  if (length(categorical_mix) > 0) {
    stop("Mixture variables must be numeric, not categorical: ",
         paste(categorical_mix, collapse = ", "))
  }
  nonmix_continuous_vars <- setdiff(continuous_vars, mixture_vars)

  # Draw n points x = lower + (total - sum(lower)) * w with w ~ Dirichlet(alpha),
  # rejecting rows that exceed `upper`; errors if max_tries rounds of
  # oversampled proposals cannot fill n rows.
  .sample_trunc_dirichlet <- function(n, lower, upper, total,
                                      alpha = NULL, oversample = 4L,
                                      max_tries = 10000L) {
    k <- length(lower)
    if (length(upper) != k) stop("upper must have the same length as lower.")
    if (is.null(alpha)) alpha <- rep(1, k)

    min_sum <- sum(lower); max_sum <- sum(upper)
    if (total < min_sum - 1e-12 || total > max_sum + 1e-12) {
      stop("Infeasible mixture constraints: need sum(lower) <= total <= sum(upper).")
    }

    avail <- total - min_sum
    if (avail <= 1e-12) {
      # No free mass to distribute: the only feasible point is `lower`.
      return(matrix(rep(lower, each = n), nrow = n))
    }

    res <- matrix(NA_real_, nrow = n, ncol = k)
    filled <- 0L; tries <- 0L

    while (filled < n && tries < max_tries) {
      m <- max(oversample * (n - filled), 1L)
      # byrow = TRUE lines up column j with shape alpha[j] under recycling.
      g <- matrix(stats::rgamma(m * k, shape = alpha, rate = 1),
                  ncol = k, byrow = TRUE)
      W <- g / rowSums(g)
      cand <- matrix(lower, nrow = m, ncol = k, byrow = TRUE) + avail * W
      ok <- cand <= matrix(upper, nrow = m, ncol = k, byrow = TRUE)
      ok <- rowSums(ok) == k
      if (any(ok)) {
        keep <- which(ok)
        take <- min(length(keep), n - filled)
        res[(filled + 1):(filled + take), ] <-
          cand[keep[seq_len(take)], , drop = FALSE]
        filled <- filled + take
      }
      tries <- tries + 1L
    }

    if (filled < n) {
      stop("Could not sample enough feasible mixture points within max_tries. ",
           "Try relaxing upper bounds or increasing 'oversample'/'max_tries'.")
    }
    res
  }

  # Evaluation grid --------------------------------------------------------------
  # Columns are added in the order: non-mixture continuous, mixture,
  # categorical.
  grid_cols <- list()

  # Non-mixture continuous: maximin LHS on [0, 1], rescaled to observed ranges.
  if (length(nonmix_continuous_vars) > 0) {
    lhs_unit <- as.matrix(
      lhs::maximinLHS(nPoint, length(nonmix_continuous_vars))
    )
    for (i in seq_along(nonmix_continuous_vars)) {
      v <- nonmix_continuous_vars[i]
      rng <- range(data[[v]], na.rm = TRUE)
      grid_cols[[v]] <- lhs_unit[, i] * (rng[2] - rng[1]) + rng[1]
    }
  }

  # Mixture groups: truncated-Dirichlet draws over each group's simplex.
  for (grp in mixture_groups) {
    vars <- grp$vars
    k <- length(vars)
    lower <- if (!is.null(grp$lower)) grp$lower else rep(0, k)
    upper <- if (!is.null(grp$upper)) grp$upper else rep(1, k)
    total <- if (!is.null(grp$total)) grp$total else 1

    if (length(lower) != k || length(upper) != k) {
      stop("lower and upper must each have length equal to the number of mixture variables (",
           paste(vars, collapse = ","), ").")
    }

    vals <- .sample_trunc_dirichlet(nPoint, lower, upper, total)
    for (j in seq_len(k)) grid_cols[[vars[j]]] <- vals[, j]
  }

  # Categorical: sample observed levels only.
  for (v in categorical_vars) {
    x <- data[[v]]
    if (is.factor(x)) {
      # Keep the full training level set and orderedness so the grid factor
      # matches the type the model was fitted with.
      grid_cols[[v]] <- factor(
        sample(levels(base::droplevels(x)), nPoint, replace = TRUE),
        levels = levels(x),
        ordered = is.ordered(x)
      )
    } else {
      obs_lev <- sort(unique(as.character(x)))
      grid_cols[[v]] <- factor(
        sample(obs_lev, nPoint, replace = TRUE),
        levels = obs_lev
      )
    }
  }

  if (length(grid_cols) == 0) stop("No predictors provided.")
  # optional = TRUE disables check.names so non-syntactic names survive.
  T_data <- as.data.frame(grid_cols, optional = TRUE)

  # SVEM fits ---------------------------------------------------------------------
  # Fit SVEM to `fit_data` and return the standardized grid predictions
  # (f_hat - y_mean) / se, or NULL (after a message) when SVEMnet errors.
  fit_standardized <- function(fit_data, stage) {
    model <- tryCatch(
      SVEMnet(formula, data = fit_data, nBoot = nBoot,
              glmnet_alpha = glmnet_alpha, weight_scheme = weight_scheme,
              objective = objective, ...),
      error = function(e) {
        message("Error in SVEMnet during ", stage, " fitting: ",
                conditionMessage(e))
        NULL
      }
    )
    if (is.null(model)) return(NULL)
    pred <- predict(model, newdata = T_data, debias = FALSE, se.fit = TRUE)
    se <- pred$se.fit
    se[se == 0] <- 1e-6  # avoid division by zero
    (pred$fit - y_mean) / se
  }

  M_Y <- matrix(NA_real_, nrow = nSVEM, ncol = nPoint)
  if (verbose) message("Fitting SVEM models to original data with mixture handling...")
  for (i in seq_len(nSVEM)) {
    h_Y <- fit_standardized(data, "SVEM")
    if (!is.null(h_Y)) M_Y[i, ] <- h_Y
  }

  M_pi_Y <- matrix(NA_real_, nrow = nPerm, ncol = nPoint)
  if (verbose) message("Starting permutation testing...")
  start_time_perm <- Sys.time()
  for (j in seq_len(nPerm)) {
    data_perm <- data
    # sample.int() avoids sample()'s 1:x behaviour for a length-one numeric.
    data_perm[[response_var]] <-
      response_values[sample.int(length(response_values))]
    h_piY <- fit_standardized(data_perm, "permutation")
    if (is.null(h_piY)) next
    M_pi_Y[j, ] <- h_piY

    if (verbose && (j %% 10 == 0 || j == nPerm)) {
      elapsed_secs <- as.numeric(Sys.time() - start_time_perm, units = "secs")
      remaining_secs <- (elapsed_secs / j) * nPerm - elapsed_secs
      remaining_time_formatted <- sprintf("%02d:%02d:%02d",
                                          floor(remaining_secs / 3600),
                                          floor((remaining_secs %% 3600) / 60),
                                          floor(remaining_secs %% 60))
      message(sprintf("Permutation %d/%d completed. Estimated time remaining: %s",
                      j, nPerm, remaining_time_formatted))
    }
  }

  # Drop failed fits (rows left as NA).
  M_Y <- M_Y[stats::complete.cases(M_Y), , drop = FALSE]
  M_pi_Y <- M_pi_Y[stats::complete.cases(M_pi_Y), , drop = FALSE]
  if (nrow(M_Y) == 0) stop("All SVEM fits on the original data failed.")
  if (nrow(M_pi_Y) == 0) stop("All SVEM fits on permuted data failed.")
  if (nrow(M_pi_Y) < 2) {
    stop("At least two successful permutation fits are required to ",
         "standardize the reference distribution.")
  }

  # Standardize both matrices by the permutation column means / sds ------------
  col_means_M_pi_Y <- colMeans(M_pi_Y)
  col_sds_M_pi_Y <- apply(M_pi_Y, 2, stats::sd)
  col_sds_M_pi_Y[col_sds_M_pi_Y == 0] <- 1e-6

  tilde_M_pi_Y <- scale(M_pi_Y, center = col_means_M_pi_Y,
                        scale = col_sds_M_pi_Y)
  tilde_M_Y <- sweep(sweep(M_Y, 2, col_means_M_pi_Y, "-"),
                     2, col_sds_M_pi_Y, "/")

  # Principal subspace capturing `percent` of the permutation variance --------
  svd_res <- svd(tilde_M_pi_Y)
  evalues_all <- svd_res$d^2
  if (!(sum(evalues_all) > 0)) {
    stop("Permutation predictions have no variance; cannot compute distances.")
  }
  evalues_all <- evalues_all / sum(evalues_all) * ncol(tilde_M_pi_Y)
  cumsum_evalues <- cumsum(evalues_all) / sum(evalues_all) * 100
  k_idx <- which(cumsum_evalues >= percent)[1]
  if (is.na(k_idx)) k_idx <- length(evalues_all)
  keep <- seq_len(k_idx)
  evalues <- evalues_all[keep]
  evectors <- svd_res$v[, keep, drop = FALSE]

  # Mahalanobis distance in the retained subspace: sqrt(sum_j s_j^2 / lambda_j).
  # Dividing column-wise avoids diag(), which treats a length-one argument as
  # a dimension and breaks when only one component is kept.
  subspace_distance <- function(M) {
    scores <- M %*% evectors
    sqrt(rowSums(sweep(scores^2, 2, evalues, "/")))
  }
  d_pi_Y <- subspace_distance(tilde_M_pi_Y)
  d_Y <- subspace_distance(tilde_M_Y)

  # Reference distribution and p-values -----------------------------------------
  suppressMessages({
    distribution_fit <- tryCatch({
      gamlss::gamlss(
        d_pi_Y ~ 1,
        family = gamlss.dist::SHASHo(mu.link = "identity", sigma.link = "log",
                                     nu.link = "identity", tau.link = "log"),
        control = gamlss::gamlss.control(n.cyc = 1000, trace = FALSE)
      )
    }, error = function(e) {
      message("Error in fitting SHASHo distribution: ", e$message)
      NULL
    })
  })
  if (is.null(distribution_fit)) stop("Failed to fit SHASHo distribution.")

  # sigma and tau are fitted on the log link, so back-transform with exp().
  mu <- as.numeric(stats::coef(distribution_fit, what = "mu"))
  sigma <- exp(as.numeric(stats::coef(distribution_fit, what = "sigma")))
  nu <- as.numeric(stats::coef(distribution_fit, what = "nu"))
  tau <- exp(as.numeric(stats::coef(distribution_fit, what = "tau")))

  # Upper-tail probability of each original-data distance.
  p_values <- 1 - gamlss.dist::pSHASHo(d_Y, mu = mu, sigma = sigma,
                                       nu = nu, tau = tau)
  p_value <- stats::median(p_values)

  data_d <- data.frame(
    D = c(d_Y, d_pi_Y),
    Source_Type = c(rep("Original", length(d_Y)),
                    rep("Permutation", length(d_pi_Y))),
    Response = response_name
  )

  results_list <- list(
    p_value = p_value,
    p_values = p_values,
    d_Y = d_Y,
    d_pi_Y = d_pi_Y,
    distribution_fit = distribution_fit,
    data_d = data_d
  )
  class(results_list) <- "svem_significance_test"
  results_list
}

