# ------------------------------------------------------------------------------
# Generated by 'pre-generate/generate-steps.R': do not edit by hand.
# ------------------------------------------------------------------------------

#' @title Tent Template Functions Vectorization of Persistent Homology
#' 
#' @description The function `step_vpd_tent_template_functions()` creates
#'   a _specification_ of a recipe step that will convert
#'   a list-column of 3-column matrices of persistence data
#'   to a list-column of 1-row matrices of vectorizations.
#' 

#' 
#' @template step-vpd-details
#' 
#' @section Engine:
#' 
#' The tent template functions vectorization deploys
#' [TDAvec::computeTemplateFunction()].
#' See there for definitions and references.
#' 
#' @section Tuning Parameters:
#' 
#' This step has 4 tuning parameters:
#' \itemize{
#'   \item `hom_degree`: Homological degree (type: integer, default: `0L`)
#'   \item `tent_size`: Discretization grid increment (type: double, default: `NULL`)
#'   \item `num_bins`: Discretization grid bins (type: integer, default: `10L`)
#'   \item `tent_shift`: Discretization grid shift (type: double, default: `NULL`)
#' }
#' 
#' @param hom_degree
#'   The homological degree of the features to be transformed.
#' @param tent_size
#'   The length of the increment used to discretize tent template functions.
#'   If `NULL` (the default), it is estimated from the training data
#'   when the recipe is prepped.
#' @param num_bins
#'   The number of bins along each axis in the discretization grid.
#' @param tent_shift
#'   The vertical shift applied to the discretization grid.
#'   If `NULL` (the default), it is estimated from the training data
#'   when the recipe is prepped.

#' @import recipes
#' @inheritParams recipes::step_pca
#' @inherit recipes::step_pca return
#' @example inst/examples/zzz-ex-step-vpd-tent-template-functions.R

#' @export
step_vpd_tent_template_functions <- function(
    recipe,
    ...,
    role = "predictor",
    trained = FALSE,
    hom_degree = 0L,
    tent_size = NULL,
    num_bins = 10L,
    tent_shift = NULL,
    columns = NULL,
    keep_original_cols = TRUE,
    skip = FALSE,
    id = rand_id("vpd_tent_template_functions")
) {
  # Run the recipes package check on the packages this step declares.
  recipes_pkg_check(required_pkgs.step_vpd_tent_template_functions())

  # Capture the column selectors unevaluated; `prep()` resolves them later.
  selectors <- rlang::enquos(...)

  # Build the step specification, then append it to the recipe.
  new_step <- step_vpd_tent_template_functions_new(
    terms = selectors,
    trained = trained,
    role = role,
    hom_degree = hom_degree,
    tent_size = tent_size,
    num_bins = num_bins,
    tent_shift = tent_shift,
    columns = columns,
    keep_original_cols = keep_original_cols,
    skip = skip,
    id = id
  )
  add_step(recipe, new_step)
}

step_vpd_tent_template_functions_new <- function(
    terms,
    role,
    trained,
    hom_degree,
    tent_size,
    num_bins,
    tent_shift,
    columns,
    keep_original_cols,
    skip,
    id
) {
  # Low-level constructor: forwards every field unchanged to `step()` under
  # the "vpd_tent_template_functions" subclass. Field order is kept fixed so
  # the resulting step object has a stable layout.
  step(
    subclass = "vpd_tent_template_functions",
    # generic step fields
    terms = terms, role = role, trained = trained,
    # vectorization parameters
    hom_degree = hom_degree, tent_size = tent_size,
    num_bins = num_bins, tent_shift = tent_shift,
    # bookkeeping
    columns = columns, keep_original_cols = keep_original_cols,
    skip = skip, id = id
  )
}

#' @export
prep.step_vpd_tent_template_functions <- function(x, training, info = NULL, ...) {
  # Train the step: resolve the selected columns, validate them, and fill in
  # any of `tent_shift` / `tent_size` that were left `NULL` with values derived
  # from the training data. Returns a new step object with `trained = TRUE`.
  col_names <- recipes_eval_select(x$terms, training, info)
  check_pd_list(training[, col_names, drop = FALSE])
  for (col_name in col_names) class(training[[col_name]]) <- "list"

  need_shift <- is.null(x$tent_shift)
  need_size <- is.null(x$tent_size)

  # For each selected column, apply `range_fun` to every diagram and take the
  # range of the finite results. `vapply()` guarantees a 2-row matrix (row 1 =
  # minimum, row 2 = maximum; one named column per selected column), even when
  # no columns were selected.
  finite_ranges <- function(range_fun) {
    vapply(
      training[, col_names, drop = FALSE],
      function(l) {
        # NOTE(review): the return shape of the package helpers `pers_range()`
        # and `birth_range()` is not visible here, so the inner `sapply()` is
        # kept as is.
        val <- sapply(l, range_fun, hom_degree = x$hom_degree, simplify = TRUE)
        range(val[is.finite(val)])
      },
      FUN.VALUE = numeric(2L)
    )
  }

  # The `pers_range()` summary feeds both defaults; the `birth_range()`
  # summary is only needed for `tent_size`.
  if (need_shift || need_size) {
    x_pers_ranges <- finite_ranges(pers_range)
  }
  if (need_size) {
    x_birth_ranges <- finite_ranges(birth_range)
  }

  # Default shift: half of the per-column minimum from `pers_range()`.
  if (need_shift) {
    x$tent_shift <- x_pers_ranges[1L, ] / 2
  }
  # Default size: the larger of the `birth_range()` maximum and the shifted
  # `pers_range()` maximum, divided by the number of bins.
  if (need_size) {
    x$tent_size <- pmax(
      x_birth_ranges[2L, ],
      x_pers_ranges[2L, ] - x$tent_shift
    ) / x$num_bins
  }

  step_vpd_tent_template_functions_new(
    terms = col_names,
    role = x$role,
    trained = TRUE,
    hom_degree = x$hom_degree,
    tent_size = x$tent_size,
    num_bins = x$num_bins,
    tent_shift = x$tent_shift,
    columns = col_names,
    keep_original_cols = get_keep_original_cols(x),
    skip = x$skip,
    id = x$id
  )
}

#' @export
bake.step_vpd_tent_template_functions <- function(object, new_data, ...) {
  # Apply the trained step: vectorize every selected list-column with
  # `TDAvec::computeTemplateFunction()` and bind the resulting feature columns
  # (prefixed `<column>_tf_`) to `new_data`.
  col_names <- names(object$columns)
  check_new_data(col_names, object, new_data)
  for (col_name in col_names) class(new_data[[col_name]]) <- "list"

  # Nothing to vectorize: return the data unchanged.
  if (nrow(new_data) == 0L || length(col_names) == 0L) return(new_data)

  # `prep()` stores data-derived `tent_size` / `tent_shift` as one value per
  # selected column (named by column), so pick the entry for the current
  # column. Anything else (e.g. a user-supplied scalar) is passed through as is.
  col_param <- function(param, col_name) {
    if (length(param) > 1L && col_name %in% names(param)) {
      param[[col_name]]
    } else {
      param
    }
  }

  vph_data <- tibble::tibble(.rows = nrow(new_data))
  for (col_name in col_names) {
    tent_size <- col_param(object$tent_size, col_name)
    tent_shift <- col_param(object$tent_shift, col_name)
    # One 1-row data frame of named features per diagram.
    vph_data[[paste(col_name, "tf", sep = "_")]] <- purrr::map(
      new_data[[col_name]],
      function(pd) {
        v <- TDAvec::computeTemplateFunction(
          as.matrix(pd),
          homDim = object$hom_degree,
          delta = tent_size,
          d = object$num_bins,
          epsilon = tent_shift
        )
        # Name the features before dropping the engine output's attributes.
        vn <- vpd_suffix(v)
        v <- as.vector(v)
        names(v) <- vn
        as.data.frame(matrix(
          v, nrow = 1L, dimnames = list(NULL, names(v))
        ))
      }
    )
  }

  # Spread each list-column of 1-row data frames into ordinary columns.
  vph_col_names <- paste(col_names, "tf", sep = "_")
  vph_data <- tidyr::unnest(
    vph_data,
    cols = tidyr::all_of(vph_col_names),
    names_sep = "_"
  )

  check_name(vph_data, new_data, object)
  new_data <- vctrs::vec_cbind(new_data, vph_data)
  new_data <- remove_original_cols(new_data, object, col_names)
  new_data
}

#' @export
print.step_vpd_tent_template_functions <- function(
    x, width = max(20, options()$width - 35), ...
) {
  # Print a one-line summary of the step and return it invisibly.
  title <- "tent template functions of "

  # For a trained step, hand over the column names resolved by `prep()` so
  # they are listed in the summary; `names(NULL)` is `NULL` before training,
  # in which case the selectors in `x$terms` are supplied as before.
  print_step(
    untr_obj = x$terms,
    tr_obj = names(x$columns),
    trained = x$trained,
    title = title,
    width = width
  )
  invisible(x)
}

#' @rdname required_pkgs.tdarec
#' @export
required_pkgs.step_vpd_tent_template_functions <- function(x, ...) {
  # Packages needed by this step: the vectorization engine and 'tdarec'.
  engine_pkg <- "TDAvec"
  step_pkg <- "tdarec"
  c(engine_pkg, step_pkg)
}

#' @rdname step_vpd_tent_template_functions
#' @usage NULL
#' @export
tidy.step_vpd_tent_template_functions <- function(x, ...) {
  # One row per term: the resolved column names once trained, otherwise the
  # selectors rendered as character strings.
  term_names <- if (is_trained(x)) {
    unname(x$columns)
  } else {
    sel2char(x$terms)
  }

  # `value` is a placeholder column filled with missing doubles.
  res <- tibble::tibble(
    terms = term_names,
    value = rep(NA_real_, length(term_names))
  )
  res$id <- x$id
  res
}

#' @rdname tunable_tdavec
#' @export
tunable.step_vpd_tent_template_functions <- function(x, ...) {
  # Describe one tunable parameter: the 'tdarec' function of the given name
  # and the range to tune over.
  dial <- function(fun, range) {
    list(pkg = "tdarec", fun = fun, range = range)
  }

  param_info <- list(
    dial("hom_degree", c(0L, unknown())),
    dial("tent_size", c(unknown(), unknown())),
    dial("num_bins", c(2L, 20L)),
    dial("tent_shift", c(unknown(), unknown()))
  )

  tibble::tibble(
    name = c("hom_degree", "tent_size", "num_bins", "tent_shift"),
    call_info = param_info,
    source = "recipe",
    component = "step_vpd_tent_template_functions",
    component_id = x$id
  )
}

