/*
 * Decompiled with CFR 0.152.
 */
package dr.inference.operators.hmc;

import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.hmc.ReversibleHMCProvider;
import dr.inference.model.Likelihood;
import dr.inference.model.Parameter;
import dr.inference.operators.AdaptationMode;
import dr.inference.operators.GeneralOperator;
import dr.inference.operators.GibbsOperator;
import dr.inference.operators.hmc.HamiltonianMonteCarloOperator;
import dr.inference.operators.hmc.MassPreconditioner;
import dr.inference.operators.hmc.StepSize;
import dr.math.MathUtils;
import dr.math.matrixAlgebra.WrappedVector;
import dr.util.Transform;
import java.util.Arrays;

/**
 * Operator that proposes a new state by growing a binary "trajectory tree" of
 * position/momentum pairs in a randomly chosen direction, doubling the tree at each
 * height until the stop criterion fails or the maximum height is exceeded.
 *
 * <p>All position/momentum integration is delegated to the supplied
 * {@link ReversibleHMCProvider}; this class only organises the tree construction,
 * the slice-based node counting and the selection of the returned sample.
 */
public class OldNoUTurnOperator
extends HamiltonianMonteCarloOperator
implements GeneralOperator,
GibbsOperator {
    /** Dimension reported by the gradient provider at construction time. */
    private final int dim;
    /** Tuning constants used by the tree builder and the initial step-size search. */
    private final Options options;
    /** Created lazily on the first call to {@link #doOperation(Likelihood)}. */
    private StepSize stepSizeInformation;
    /** Performs the reversible position/momentum updates and momentum draws. */
    private final ReversibleHMCProvider reversibleHMCProvider;
    /** When {@code true}, the step-size object is updated after every transition. Always {@code false} here. */
    private final boolean autoStepsize;

    /**
     * Creates the operator. Every argument except {@code reversibleHMCProvider} is forwarded
     * unchanged to the {@link HamiltonianMonteCarloOperator} constructor.
     *
     * @param reversibleHMCProvider provider used for momentum draws, reversible updates,
     *                              kinetic energy and parameter get/set
     */
    public OldNoUTurnOperator(AdaptationMode adaptationMode, double d, GradientWrtParameterProvider gradientWrtParameterProvider, Parameter parameter, Transform transform, Parameter parameter2, HamiltonianMonteCarloOperator.Options options, MassPreconditioner.Type type, ReversibleHMCProvider reversibleHMCProvider) {
        super(adaptationMode, d, gradientWrtParameterProvider, parameter, transform, parameter2, options, type);
        this.dim = this.gradientProvider.getDimension();
        // Note: this is the nested OldNoUTurnOperator.Options, not the superclass options argument.
        this.options = new Options();
        this.autoStepsize = false;
        this.reversibleHMCProvider = reversibleHMCProvider;
    }

    /**
     * @return {@link HamiltonianMonteCarloOperator.InstabilityHandler#IGNORE}
     */
    @Override
    protected HamiltonianMonteCarloOperator.InstabilityHandler getDefaultInstabilityHandler() {
        return HamiltonianMonteCarloOperator.InstabilityHandler.IGNORE;
    }

    @Override
    public String getOperatorName() {
        return "No-UTurn-Sampler operator";
    }

    /**
     * Runs one transition: optionally checks the gradient, lazily initialises the step size,
     * builds the trajectory tree and writes the selected position back through the provider.
     *
     * @param likelihood passed to the gradient check only
     * @return always {@code 0.0}
     */
    @Override
    public double doOperation(Likelihood likelihood) {
        if (this.shouldCheckGradient()) {
            this.checkGradient(likelihood);
        }
        double[] initialPosition = this.reversibleHMCProvider.getInitialPosition();
        if (this.stepSizeInformation == null) {
            this.stepSizeInformation = this.findReasonableStepSize(initialPosition, this.stepSize);
        }
        double[] endPosition = this.takeOneStep(this.getCount() + 1L, initialPosition);
        this.reversibleHMCProvider.setParameter(endPosition);
        return 0.0;
    }

    /**
     * Builds the trajectory tree from {@code initialPosition} and returns the selected position.
     *
     * @param iteration       iteration index handed to the step-size update (only used when auto step-size is on)
     * @param initialPosition starting position; not modified
     * @return the last accepted candidate, or a copy of {@code initialPosition} if none was accepted
     */
    private double[] takeOneStep(long iteration, double[] initialPosition) {
        double[] endPosition = Arrays.copyOf(initialPosition, initialPosition.length);
        WrappedVector initialMomentum = OldNoUTurnOperator.mask(this.reversibleHMCProvider.drawMomentum(), this.mask);
        double initialJointDensity = this.getJointProbability(this.gradientProvider, initialMomentum);
        // Log of a slice variable drawn uniformly under the initial joint density.
        double logSliceU = Math.log(MathUtils.nextDouble()) + initialJointDensity;
        TreeState trajectoryTree = new TreeState(initialPosition, initialMomentum.getBuffer(), 1, true);

        int height = 0;
        while (trajectoryTree.flagContinue) {
            double[] candidate = this.updateTrajectoryTree(trajectoryTree, height, logSliceU, initialJointDensity);
            if (candidate != null) {
                endPosition = candidate;
            }
            ++height;
            if (height > this.options.maxHeight) {
                trajectoryTree.flagContinue = false;
            }
        }

        if (this.autoStepsize) {
            this.stepSizeInformation.update(iteration, trajectoryTree.cumAcceptProb, trajectoryTree.numAcceptProbStates);
        }
        return endPosition;
    }

    /**
     * Doubles the trajectory tree once in a random direction and merges the new subtree in.
     *
     * @return the new subtree's sample if it was accepted, otherwise {@code null}
     */
    private double[] updateTrajectoryTree(TreeState trajectoryTree, int height, double logSliceU, double initialJointDensity) {
        double[] endPosition = null;
        int direction = MathUtils.nextDouble() < 0.5 ? -1 : 1;
        TreeState nextTree = this.buildTree(
                trajectoryTree.getPosition(direction),
                trajectoryTree.getMomentum(direction),
                direction, logSliceU, height,
                this.stepSizeInformation.getStepSize(),
                initialJointDensity);

        // The uniform is only drawn when the subtree is still valid, so the random stream is
        // consumed in the same order regardless of how the condition is written.
        if (nextTree.flagContinue) {
            double uniform = MathUtils.nextDouble();
            double acceptProb = (double) nextTree.numNodes / (double) trajectoryTree.numNodes;
            if (uniform < acceptProb) {
                endPosition = nextTree.getSample();
            }
        }
        trajectoryTree.mergeNextTree(nextTree, direction);
        return endPosition;
    }

    /**
     * Builds a subtree of the given height starting from the supplied edge state.
     * Height 0 is a single integration step; larger heights recurse.
     */
    private TreeState buildTree(double[] position, double[] momentum, int direction, double logSliceU, int height, double stepSize, double initialJointDensity) {
        if (height == 0) {
            return this.buildBaseCase(position, momentum, direction, logSliceU, stepSize, initialJointDensity);
        }
        return this.buildRecursiveCase(position, momentum, direction, logSliceU, height, stepSize, initialJointDensity);
    }

    /**
     * Takes one reversible update from copies of the inputs and wraps the result as a one-node tree.
     * The provider's parameter is reset to {@code inPosition} before returning.
     */
    private TreeState buildBaseCase(double[] inPosition, double[] inMomentum, int direction, double logSliceU, double stepSize, double initialJointDensity) {
        WrappedVector.Raw position = new WrappedVector.Raw(Arrays.copyOf(inPosition, inPosition.length));
        WrappedVector.Raw momentum = new WrappedVector.Raw(Arrays.copyOf(inMomentum, inMomentum.length));
        this.reversibleHMCProvider.setParameter(position.getBuffer());
        this.reversibleHMCProvider.reversiblePositionMomentumUpdate(position, momentum, direction, stepSize);

        double jointDensity = this.getJointProbability(this.gradientProvider, momentum);
        // The node counts only if the slice variable lies under its density.
        int numNodes = logSliceU <= jointDensity ? 1 : 0;
        // Keep growing unless the density fell more than logProbErrorTol below the slice.
        boolean flagContinue = logSliceU < this.options.logProbErrorTol + jointDensity;

        double acceptProb = Math.min(1.0, Math.exp(jointDensity - initialJointDensity));
        if (Double.isNaN(acceptProb)) {
            // Math.min propagates NaN; an undefined density must not poison the
            // accumulated acceptance statistic, so count it as a rejected step.
            acceptProb = 0.0;
        }

        this.reversibleHMCProvider.setParameter(inPosition);
        return new TreeState(position.getBuffer(), momentum.getBuffer(), numNodes, flagContinue, acceptProb, 1);
    }

    /**
     * Builds a subtree of {@code height} from two consecutive subtrees of {@code height - 1}.
     * The second half is only built when the first half allows continuation.
     */
    private TreeState buildRecursiveCase(double[] inPosition, double[] inMomentum, int direction, double logSliceU, int height, double stepSize, double initialJointDensity) {
        TreeState subtree = this.buildTree(inPosition, inMomentum, direction, logSliceU, height - 1, stepSize, initialJointDensity);
        if (subtree.flagContinue) {
            // Use the step size handed to this call for both halves of the doubling.
            TreeState nextSubtree = this.buildTree(
                    subtree.getPosition(direction),
                    subtree.getMomentum(direction),
                    direction, logSliceU, height - 1, stepSize, initialJointDensity);
            subtree.mergeNextTree(nextSubtree, direction);
        }
        return subtree;
    }

    /**
     * Returns a step size: the supplied one if it is non-zero, otherwise one found by repeatedly
     * doubling or halving a starting value of 0.1.
     *
     * <p>NOTE(review): each loop iteration integrates with the step size from before the scaling
     * and reuses the same position/momentum objects rather than restarting from the initial
     * state — confirm this is the intended search.
     *
     * @param initialPosition position to start from; the provider is reset to it on normal return
     * @param forcedInitialStepSize a non-zero value is used as-is
     * @throws RuntimeException if no step size is found within {@code options.findMax} iterations
     */
    private StepSize findReasonableStepSize(double[] initialPosition, double forcedInitialStepSize) {
        if (forcedInitialStepSize != 0.0) {
            return new StepSize(forcedInitialStepSize);
        }
        double stepSize = 0.1;
        WrappedVector momentum = this.preconditioning.drawInitialMomentum();
        int count = 1;
        WrappedVector.Raw position = new WrappedVector.Raw(Arrays.copyOf(initialPosition, this.dim));

        double probBefore = this.getJointProbability(this.gradientProvider, momentum);
        this.reversibleHMCProvider.reversiblePositionMomentumUpdate(position, momentum, 1, stepSize);
        double probAfter = this.getJointProbability(this.gradientProvider, momentum);

        // +1: acceptance ratio above 1/2, so grow the step; -1: shrink it.
        double a = probAfter - probBefore > Math.log(0.5) ? 1.0 : -1.0;
        double probRatio = Math.exp(probAfter - probBefore);

        while (Math.pow(probRatio, a) > Math.pow(2.0, -a)) {
            probBefore = probAfter;
            this.reversibleHMCProvider.reversiblePositionMomentumUpdate(position, momentum, 1, stepSize);
            probAfter = this.getJointProbability(this.gradientProvider, momentum);
            probRatio = Math.exp(probAfter - probBefore);
            stepSize = Math.pow(2.0, a) * stepSize;
            ++count;
            if (count > this.options.findMax) {
                throw new RuntimeException("Cannot find a reasonable step-size in " + this.options.findMax + " iterations");
            }
        }
        this.reversibleHMCProvider.setParameter(initialPosition);
        return new StepSize(stepSize);
    }

    /** Evaluates the continuation criterion on the two edges of {@code state}. */
    private static boolean computeStopCriterion(boolean flagContinue, TreeState state) {
        return OldNoUTurnOperator.computeStopCriterion(flagContinue, state.getPosition(1), state.getPosition(-1), state.getMomentum(1), state.getMomentum(-1));
    }

    /**
     * Despite the name, returns {@code true} when the tree may keep growing: the incoming flag is
     * set and the displacement between the two edges has a non-negative dot product with the
     * momentum at each edge.
     */
    private static boolean computeStopCriterion(boolean flagContinue, double[] positionPlus, double[] positionMinus, double[] momentumPlus, double[] momentumMinus) {
        double[] positionDifference = OldNoUTurnOperator.subtractArray(positionPlus, positionMinus);
        return flagContinue
                && OldNoUTurnOperator.getDotProduct(positionDifference, momentumMinus) >= 0.0
                && OldNoUTurnOperator.getDotProduct(positionDifference, momentumPlus) >= 0.0;
    }

    /** Dot product of two equal-length arrays. */
    private static double getDotProduct(double[] x, double[] y) {
        assert (x.length == y.length);
        double total = 0.0;
        for (int i = 0; i < x.length; ++i) {
            total += x[i] * y[i];
        }
        return total;
    }

    /** Element-wise {@code a - b} into a new array. */
    private static double[] subtractArray(double[] a, double[] b) {
        assert (a.length == b.length);
        double[] difference = new double[a.length];
        for (int i = 0; i < a.length; ++i) {
            difference[i] = a[i] - b[i];
        }
        return difference;
    }

    /**
     * Log-likelihood of the gradient provider's likelihood (at whatever state the model currently
     * holds) minus the kinetic energy of {@code momentum} minus the parameter log-Jacobian.
     */
    private double getJointProbability(GradientWrtParameterProvider gradientProvider, WrappedVector momentum) {
        assert (gradientProvider != null);
        assert (momentum != null);
        return gradientProvider.getLikelihood().getLogLikelihood()
                - this.reversibleHMCProvider.getKineticEnergy(momentum)
                - this.reversibleHMCProvider.getParameterLogJacobian();
    }

    /**
     * State of a (sub)tree: position and momentum at the backward edge (index 0, direction -1),
     * the current sample (index 1, direction 0) and the forward edge (index 2, direction +1),
     * plus node count, continuation flag and accumulated acceptance statistics.
     *
     * <p>Static because it uses no state of the enclosing operator.
     */
    private static class TreeState {
        private final double[][] position = new double[3][];
        private final double[][] momentum = new double[3][];
        private int numNodes;
        private boolean flagContinue;
        private double cumAcceptProb;
        private int numAcceptProbStates;

        /** Creates a state with zeroed acceptance statistics. */
        private TreeState(double[] position, double[] momentum, int numNodes, boolean flagContinue) {
            this(position, momentum, numNodes, flagContinue, 0.0, 0);
        }

        /** All three slots initially reference the same position and momentum arrays. */
        private TreeState(double[] position, double[] momentum, int numNodes, boolean flagContinue, double cumAcceptProb, int numAcceptProbStates) {
            for (int i = 0; i < 3; ++i) {
                this.position[i] = position;
                this.momentum[i] = momentum;
            }
            this.numNodes = numNodes;
            this.flagContinue = flagContinue;
            this.cumAcceptProb = cumAcceptProb;
            this.numAcceptProbStates = numAcceptProbStates;
        }

        private double[] getPosition(int direction) {
            return this.position[this.getIndex(direction)];
        }

        private double[] getMomentum(int direction) {
            return this.momentum[this.getIndex(direction)];
        }

        private double[] getSample() {
            return this.position[this.getIndex(0)];
        }

        private void setPosition(int direction, double[] position) {
            this.position[this.getIndex(direction)] = position;
        }

        private void setMomentum(int direction, double[] momentum) {
            this.momentum[this.getIndex(direction)] = momentum;
        }

        private void setSample(double[] position) {
            this.setPosition(0, position);
        }

        /** Maps a direction in {-1, 0, 1} to an array slot in {0, 1, 2}. */
        private int getIndex(int direction) {
            assert (direction >= -1 && direction <= 1);
            return direction + 1;
        }

        /**
         * Absorbs {@code nextTree} on the {@code direction} side: adopts its outer edge, possibly
         * its sample, then updates counts, the continuation flag and the acceptance statistics.
         */
        private void mergeNextTree(TreeState nextTree, int direction) {
            this.setPosition(direction, nextTree.getPosition(direction));
            this.setMomentum(direction, nextTree.getMomentum(direction));
            // Must run before numNodes is incremented: the sampling weight uses the old count.
            this.updateSample(nextTree);
            this.numNodes += nextTree.numNodes;
            this.flagContinue = OldNoUTurnOperator.computeStopCriterion(nextTree.flagContinue, this);
            this.cumAcceptProb += nextTree.cumAcceptProb;
            this.numAcceptProbStates += nextTree.numAcceptProbStates;
        }

        /** Replaces the sample with {@code nextTree}'s with probability proportional to its node count. */
        private void updateSample(TreeState nextTree) {
            // The uniform is always drawn, even when nextTree has no nodes.
            double uniform = MathUtils.nextDouble();
            if (nextTree.numNodes > 0 && uniform < (double) nextTree.numNodes / (double) (this.numNodes + nextTree.numNodes)) {
                this.setSample(nextTree.getSample());
            }
        }
    }

    /** Fixed tuning constants for this operator. */
    class Options {
        /** Tolerance added to the joint density when deciding whether a base-case node may continue. */
        private final double logProbErrorTol = 100.0;
        /** Maximum iterations of the initial step-size search. */
        private final int findMax = 100;
        /** Tree growth stops once the height exceeds this value. */
        private final int maxHeight = 10;

        Options() {
        }
    }
}

