% Generated by roxygen2: do not edit by hand
% Please edit documentation in R/bijectors.R
\name{tfb_ffjord}
\alias{tfb_ffjord}
\title{Implements a continuous normalizing flow X->Y defined via an ODE.}
\usage{
tfb_ffjord(state_time_derivative_fn, ode_solve_fn = NULL,
  trace_augmentation_fn = tfp$bijectors$ffjord$trace_jacobian_hutchinson,
  initial_time = 0, final_time = 1, validate_args = FALSE,
  dtype = tf$float32, name = "ffjord")
}
\arguments{
\item{state_time_derivative_fn}{\code{function} taking arguments \code{time}
(a scalar representing time) and \code{state} (a Tensor representing the
state at the given \code{time}) and returning the time derivative of \code{state} at
the given \code{time}.}

\item{ode_solve_fn}{\code{function} taking arguments \code{ode_fn} (same as
\code{state_time_derivative_fn} above), \code{initial_time} (a scalar representing
the initial time of integration), \code{initial_state} (a Tensor of floating
dtype representing the initial state) and \code{solution_times} (a 1D Tensor of
floating dtype representing the times at which to obtain the solution), and
returning a Tensor of shape \code{[time_axis, initial_state$shape]}. The bijector
passes \code{[final_time]} as the \code{solution_times} argument and
\code{state_time_derivative_fn} as the \code{ode_fn} argument.
If \code{NULL}, a Dormand-Prince solver from \code{tfp$math$ode} is used.
Default value: NULL}

\item{trace_augmentation_fn}{\code{function} taking arguments \code{ode_fn}
(a \code{function}, same as \code{state_time_derivative_fn} above),
\code{state_shape} (TensorShape of the state) and \code{dtype} (same as the dtype of
the state), and returning a \code{function} taking arguments \code{time}
(a scalar representing the time at which the function is evaluated) and
\code{state} (a Tensor representing the state at the given \code{time}) that computes
the pair (\code{ode_fn(time, state)}, \code{jacobian_trace_estimation}).
\code{jacobian_trace_estimation} should be an estimate of the trace of the Jacobian of
\code{ode_fn} with respect to \code{state}. \code{state_time_derivative_fn} will be
passed as the \code{ode_fn} argument.
Default value: tfp$bijectors$ffjord$trace_jacobian_hutchinson}

\item{initial_time}{Scalar float representing the time to which the \code{x} value of the
bijector corresponds. Passed as \code{initial_time} to \code{ode_solve_fn}.
For the default solver, this can be a \code{float} or a floating scalar \code{Tensor}.
Default value: 0.}

\item{final_time}{Scalar float representing the time to which the \code{y} value of the
bijector corresponds. Passed as \code{solution_times} to \code{ode_solve_fn}.
For the default solver, this can be a \code{float} or a floating scalar \code{Tensor}.
Default value: 1.}

\item{validate_args}{Logical, default FALSE. Whether to validate input with asserts. If \code{validate_args} is
FALSE and the inputs are invalid, correct behavior is not guaranteed.}

\item{dtype}{\code{tf$DType} to prefer when converting args to \code{Tensor}s. Otherwise,
a common dtype inferred from the args is used, finally falling
back to float32.}

\item{name}{Name prefixed to Ops created by this bijector.}
}
\value{
a bijector instance.
}
\description{
This bijector implements a continuous dynamics transformation
parameterized by a differential equation, whose initial and terminal
conditions correspond to the domain (X) and the image (Y), respectively.
The defining equations are given under Details.
}
\details{
\preformatted{d/dt[state(t)] = state_time_derivative_fn(t, state(t))
state(initial_time) = X
state(final_time) = Y
}

For this transformation, the value of \code{log_det_jacobian} follows a second
differential equation, which reduces its computation to integrating the trace of
the Jacobian along the trajectory:
\preformatted{state_time_derivative = state_time_derivative_fn(t, state(t))
d/dt[log_det_jac(t)] = Tr(jacobian(state_time_derivative, state(t)))
log_det_jac(initial_time) = 0
}
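
The base-R sketch below (plain R, no TensorFlow) illustrates the augmented
system: it appends \code{log_det_jac}, starting from 0 at the initial time, to the
state vector and integrates both with a hypothetical fixed-step RK4 helper,
\code{rk4_solve()}. It assumes linear dynamics f(t, x) = A x, whose Jacobian is A
everywhere, so the exact answer is log_det_jac(final_time) = Tr(A) (final_time - initial_time).
It is an illustration of the equations only, not how \code{tfb_ffjord} is implemented.
\preformatted{# Sketch: integrate state and log_det_jac jointly (hypothetical helpers).
A <- matrix(c(-0.5, 0.3,
               0.1, 0.2), nrow = 2, byrow = TRUE)
f <- function(time, x) as.vector(crossprod(t(A), x))  # A x (time-independent)
trace_jac <- function(time, x) sum(diag(A))           # exact Tr(jacobian) = Tr(A)

# Augmented derivative: the last entry of z is log_det_jac.
augmented_f <- function(time, z) {
  x <- z[-length(z)]
  c(f(time, x), trace_jac(time, x))
}

# Classic fixed-step fourth-order Runge-Kutta integrator.
rk4_solve <- function(g, t0, t1, z0, n_steps = 200) {
  h <- (t1 - t0) / n_steps
  z <- z0
  time <- t0
  for (i in seq_len(n_steps)) {
    k1 <- g(time, z)
    k2 <- g(time + h / 2, z + h / 2 * k1)
    k3 <- g(time + h / 2, z + h / 2 * k2)
    k4 <- g(time + h, z + h * k3)
    z <- z + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
    time <- time + h
  }
  z
}

x0 <- c(1, -2)
z1 <- rk4_solve(augmented_f, t0 = 0, t1 = 1, z0 = c(x0, 0))  # log_det_jac(0) = 0
y <- z1[1:2]            # state(final_time)
log_det_jac <- z1[3]    # log_det_jac(final_time), equal to Tr(A) = -0.3
}

By Jacobi's formula, det(exp(A)) = exp(Tr(A)), so this value is also the
log-absolute-determinant of the flow map y = exp(A) x from t = 0 to t = 1.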

The FFJORD constructor takes two function arguments, \code{ode_solve_fn} and
\code{trace_augmentation_fn}, that customize integration of the
differential equation and trace estimation, respectively.

Differential equation integration is performed by a call to \code{ode_solve_fn}.

Custom \code{ode_solve_fn} must accept the following arguments:
\itemize{
\item ode_fn(time, state): Differential equation to be solved.
\item initial_time: Scalar float or floating Tensor representing the initial time.
\item initial_state: Floating Tensor representing the initial state.
\item solution_times: 1D floating Tensor of solution times.
}

It must return a Tensor of shape \code{[solution_times$shape, initial_state$shape]}
representing state values evaluated at \code{solution_times}. In addition,
\code{ode_solve_fn} must support nested structures. For more details, see the
interface of \code{tfp$math$ode$Solver$solve()}.
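
A minimal base-R sketch of this calling convention follows. It assumes R numeric
vectors in place of Tensors and ignores nested structures, so it cannot be passed
to \code{tfb_ffjord} as is. The hypothetical \code{euler_ode_solve_fn()} takes the same
four arguments and returns a matrix with one row per solution time, using
fixed-step explicit Euler integration.
\preformatted{# Hypothetical explicit-Euler solver following the ode_solve_fn argument order.
# Assumes numeric-vector states and increasing solution_times >= initial_time.
euler_ode_solve_fn <- function(ode_fn, initial_time, initial_state,
                               solution_times, step_size = 1e-3) {
  state <- initial_state
  time <- initial_time
  # One row per solution time, one column per state component.
  out <- matrix(NA_real_, nrow = length(solution_times),
                ncol = length(initial_state))
  for (i in seq_along(solution_times)) {
    target <- solution_times[i]
    n_steps <- max(1, ceiling((target - time) / step_size))
    h <- (target - time) / n_steps
    for (j in seq_len(n_steps)) {
      state <- state + h * ode_fn(time, state)
      time <- time + h
    }
    time <- target  # avoid accumulating rounding error in the clock
    out[i, ] <- state
  }
  out
}

# Usage: dx/dt = -x with x(0) = (1, 2), evaluated at t = 0.5 and t = 1.
sol <- euler_ode_solve_fn(function(time, state) -state,
                          initial_time = 0, initial_state = c(1, 2),
                          solution_times = c(0.5, 1))
}

The result has shape \code{[time_axis, state]}, matching the layout described above.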

Trace estimation is computed simultaneously with \code{state_time_derivative}
using \code{augmented_state_time_derivative_fn}, which is generated by
\code{trace_augmentation_fn}. \code{trace_augmentation_fn} takes
\code{state_time_derivative_fn}, \code{state$shape} and \code{state$dtype} as arguments and
returns an \code{augmented_state_time_derivative_fn} callable that computes both
\code{state_time_derivative} and the unreduced \code{trace_estimation}.
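
As its name suggests, the default \code{trace_jacobian_hutchinson} relies on the
Hutchinson estimator Tr(J) = E[t(v) J v], which holds for any random probe v with
zero mean and identity covariance. The base-R sketch below shows the idea; it uses
Rademacher probes (an assumption, other probe distributions also work) and a
finite-difference Jacobian-vector product in place of automatic differentiation.
\code{hutchinson_trace()} is hypothetical and not part of tfprobability.
\preformatted{# Hutchinson trace estimator sketch (base R, no TensorFlow).
hutchinson_trace <- function(f, x, n_samples = 1, eps = 1e-5) {
  estimates <- vapply(seq_len(n_samples), function(i) {
    v <- sample(c(-1, 1), length(x), replace = TRUE)      # Rademacher probe
    jvp <- (f(x + eps * v) - f(x - eps * v)) / (2 * eps)  # approximates J v
    sum(v * jvp)                                          # t(v) J v
  }, numeric(1))
  mean(estimates)
}

A <- matrix(c(-0.5, 0.3,
               0.1, 0.2), nrow = 2, byrow = TRUE)
f <- function(x) as.vector(crossprod(t(A), x))  # linear map with Jacobian A
set.seed(1)
est <- hutchinson_trace(f, x = c(1, -2), n_samples = 10000)  # close to Tr(A) = -0.3
}

A single probe (\code{n_samples = 1}) already gives an unbiased but noisy estimate,
which is why the augmented function can return the unreduced trace estimate and
let the ODE integration average over the trajectory.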

Custom \code{ode_solve_fn} and \code{trace_augmentation_fn} examples follow. In both,
\code{custom_solver_fn}, \code{solver_args} and \code{trace_jac_fn} are placeholders
to be supplied by the user.
\preformatted{# custom_solver_fn: `function(f, t_initial, t_solutions, y_initial, ...)`
# solver_args: list of additional arguments to pass to custom_solver_fn.
ode_solve_fn <- function(ode_fn, initial_time, initial_state, solution_times) {
  # Reorder the arguments to match custom_solver_fn's signature.
  do.call(custom_solver_fn,
          c(list(ode_fn, initial_time, solution_times, initial_state), solver_args))
}
ffjord <- tfb_ffjord(state_time_derivative_fn, ode_solve_fn = ode_solve_fn)
}
\preformatted{# state_time_derivative_fn: `function(time, state)`
# trace_jac_fn: `function(time, state)` unreduced Jacobian trace function
trace_augmentation_fn <- function(ode_fn, state_shape, dtype) {
  augmented_ode_fn <- function(time, state) {
    # Return the state derivative together with the trace estimate.
    list(ode_fn(time, state), trace_jac_fn(time, state))
  }
  augmented_ode_fn
}
ffjord <- tfb_ffjord(state_time_derivative_fn, trace_augmentation_fn = trace_augmentation_fn)
}

For more details on FFJORD and continuous normalizing flows, see Chen et al. (2018) and Grathwohl et al. (2018).
}
\section{References}{

\itemize{
\item Chen, T. Q., Rubanova, Y., Bettencourt, J., & Duvenaud, D. K. (2018). Neural ordinary differential equations. In Advances in Neural Information Processing Systems (pp. 6571-6583).
\item \href{http://arxiv.org/abs/1810.01367}{Grathwohl, W., Chen, R. T. Q., Bettencourt, J., Sutskever, I., & Duvenaud, D. (2018). FFJORD: Free-form continuous dynamics for scalable reversible generative models. arXiv preprint arXiv:1810.01367.}
}
}

\seealso{
For usage examples, see \code{\link[=tfb_forward]{tfb_forward()}}, \code{\link[=tfb_inverse]{tfb_inverse()}}, \code{\link[=tfb_inverse_log_det_jacobian]{tfb_inverse_log_det_jacobian()}}.

Other bijectors: \code{\link{tfb_absolute_value}},
  \code{\link{tfb_affine_linear_operator}},
  \code{\link{tfb_affine_scalar}},
  \code{\link{tfb_affine}},
  \code{\link{tfb_batch_normalization}},
  \code{\link{tfb_blockwise}}, \code{\link{tfb_chain}},
  \code{\link{tfb_cholesky_outer_product}},
  \code{\link{tfb_cholesky_to_inv_cholesky}},
  \code{\link{tfb_correlation_cholesky}},
  \code{\link{tfb_cumsum}},
  \code{\link{tfb_discrete_cosine_transform}},
  \code{\link{tfb_expm1}}, \code{\link{tfb_exp}},
  \code{\link{tfb_fill_scale_tri_l}},
  \code{\link{tfb_fill_triangular}},
  \code{\link{tfb_gumbel_cdf}}, \code{\link{tfb_gumbel}},
  \code{\link{tfb_identity}}, \code{\link{tfb_inline}},
  \code{\link{tfb_invert}},
  \code{\link{tfb_iterated_sigmoid_centered}},
  \code{\link{tfb_kumaraswamy_cdf}},
  \code{\link{tfb_kumaraswamy}},
  \code{\link{tfb_masked_autoregressive_default_template}},
  \code{\link{tfb_masked_autoregressive_flow}},
  \code{\link{tfb_masked_dense}},
  \code{\link{tfb_matrix_inverse_tri_l}},
  \code{\link{tfb_matvec_lu}},
  \code{\link{tfb_normal_cdf}}, \code{\link{tfb_ordered}},
  \code{\link{tfb_pad}}, \code{\link{tfb_permute}},
  \code{\link{tfb_power_transform}},
  \code{\link{tfb_rational_quadratic_spline}},
  \code{\link{tfb_real_nvp_default_template}},
  \code{\link{tfb_real_nvp}}, \code{\link{tfb_reciprocal}},
  \code{\link{tfb_reshape}},
  \code{\link{tfb_scale_matvec_diag}},
  \code{\link{tfb_scale_matvec_linear_operator}},
  \code{\link{tfb_scale_matvec_lu}},
  \code{\link{tfb_scale_matvec_tri_l}},
  \code{\link{tfb_scale_tri_l}}, \code{\link{tfb_scale}},
  \code{\link{tfb_shift}}, \code{\link{tfb_sigmoid}},
  \code{\link{tfb_sinh_arcsinh}},
  \code{\link{tfb_softmax_centered}},
  \code{\link{tfb_softplus}}, \code{\link{tfb_softsign}},
  \code{\link{tfb_square}}, \code{\link{tfb_tanh}},
  \code{\link{tfb_transform_diagonal}},
  \code{\link{tfb_transpose}},
  \code{\link{tfb_weibull_cdf}}, \code{\link{tfb_weibull}}
}
\concept{bijectors}
