/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.continuous.hmc;

import dr.evolution.tree.TreeTrait;
import dr.evomodel.continuous.hmc.TreePrecisionTraitProductProvider;
import dr.evomodel.treedatalikelihood.TreeDataLikelihood;
import dr.evomodel.treedatalikelihood.continuous.ContinuousDataLikelihoodDelegate;
import dr.evomodel.treedatalikelihood.preorder.WrappedNormalSufficientStatistics;
import dr.evomodel.treedatalikelihood.preorder.WrappedTipFullConditionalDistributionDelegate;
import dr.inference.model.Parameter;
import dr.math.MathUtils;
import dr.math.MaximumEigenvalue;
import dr.math.matrixAlgebra.ReadableMatrix;
import dr.math.matrixAlgebra.ReadableVector;
import dr.math.matrixAlgebra.WrappedMatrix;
import dr.math.matrixAlgebra.WrappedVector;
import dr.util.TaskPool;
import java.util.List;

/**
 * Precision-times-trait product provider that works tip by tip.
 *
 * <p>For every external node (tip) {@code i} of the tree, the provider reads the wrapped
 * full-conditional sufficient statistics (mean {@code mu_i}, precision matrix {@code P_i},
 * precision scalar {@code s_i}). It writes
 * {@code s_i * P_i * (x_i - mu_i)} into entries {@code [i * dimTrait, (i + 1) * dimTrait)}
 * of the result, where {@code x_i} is the matching slice of the trait data parameter.
 *
 * <p>Per-tip work runs on the calling thread when the task pool has exactly one thread.
 * Otherwise it is forked over a {@link TaskPool}. Each tip owns its own scratch row in
 * {@link #delta}, so tasks for different tips never write to the same buffer.
 */
public class LinearOrderTreePrecisionTraitProductProvider
extends TreePrecisionTraitProductProvider {
    /** Tree trait yielding the wrapped full-conditional statistics at the tips. */
    private final TreeTrait<List<WrappedNormalSufficientStatistics>> fullConditionalDensity;
    /** Pool of per-tip tasks; only used when it reports more than one thread. */
    private final TaskPool taxonTaskPool;
    /** Per-tip scratch rows (length {@code dimTrait}) holding {@code data - mean}. */
    private final double[][] delta;
    /** Fixed time scale; used by {@link #getTimeScale()} when strictly positive. */
    private final double roughTimeGuess;
    /** Number of random directions averaged in {@link #getRoughLowerBoundForTravelTime()}. */
    private final int eigenvalueReplicates;
    /** Multiplier applied in {@link #getMaxEigenvalueAsTravelTime()}. */
    private final double optimalTravelTimeScalar;
    /** Power-method estimator, constructed with arguments {@code (50, 0.01)}. */
    private final MaximumEigenvalue eigenvalue;

    /**
     * Creates the provider, registering the wrapped full-conditional trait if needed.
     *
     * @param treeDataLikelihood               likelihood whose tree traits are queried
     * @param continuousDataLikelihoodDelegate delegate asked to add the wrapped
     *                                         full-conditional trait when it is missing
     * @param traitName                        base trait name; the looked-up trait name is
     *                                         derived via
     *                                         {@link WrappedTipFullConditionalDistributionDelegate#getName(String)}
     * @param threadCount                      forwarded to {@link TaskPool} together with the
     *                                         tip count (presumably the number of worker
     *                                         threads; TODO confirm)
     * @param roughTimeGuess                   time scale returned directly when {@code > 0}
     * @param optimalTravelTimeScalar          multiplier for the eigenvalue-based time scale
     * @param eigenvalueReplicates             number of random directions used by the rough
     *                                         travel-time lower bound
     */
    public LinearOrderTreePrecisionTraitProductProvider(TreeDataLikelihood treeDataLikelihood,
                                                        ContinuousDataLikelihoodDelegate continuousDataLikelihoodDelegate,
                                                        String traitName,
                                                        int threadCount,
                                                        double roughTimeGuess,
                                                        double optimalTravelTimeScalar,
                                                        int eigenvalueReplicates) {
        super(treeDataLikelihood, continuousDataLikelihoodDelegate);
        String fullConditionalTraitName = WrappedTipFullConditionalDistributionDelegate.getName(traitName);
        if (treeDataLikelihood.getTreeTrait(fullConditionalTraitName) == null) {
            // The likelihood does not expose the trait yet; ask the delegate to add it before lookup.
            continuousDataLikelihoodDelegate.addWrappedFullConditionalDensityTrait(traitName);
        }
        this.fullConditionalDensity = castTreeTrait(treeDataLikelihood.getTreeTrait(fullConditionalTraitName));
        final int tipCount = this.tree.getExternalNodeCount();
        this.delta = new double[tipCount][this.dimTrait];
        this.roughTimeGuess = roughTimeGuess;
        this.optimalTravelTimeScalar = optimalTravelTimeScalar;
        this.eigenvalueReplicates = eigenvalueReplicates;
        this.taxonTaskPool = new TaskPool(tipCount, threadCount);
        this.eigenvalue = new MaximumEigenvalue.PowerMethod(50, 0.01);
    }

    /**
     * Computes the concatenated per-tip products {@code s_i * P_i * (x_i - mu_i)}.
     *
     * @param parameter must be the trait data parameter this provider was built for
     * @return a fresh array of length {@code parameter.getDimension()}
     * @throws IllegalArgumentException if {@code parameter} is not the trait data parameter
     */
    @Override
    public double[] getProduct(Parameter parameter) {
        if (parameter != this.dataParameter) {
            throw new IllegalArgumentException("May only compute for trait data vector");
        }
        final double[] product = new double[parameter.getDimension()];
        if (this.taxonTaskPool.getNumThreads() == 1) {
            computeProductSerially(product);
        } else {
            computeProductInParallel(product);
        }
        return product;
    }

    /** Fills {@code product} tip by tip on the calling thread, querying each tip's statistics separately. */
    private void computeProductSerially(double[] product) {
        final int tipCount = this.tree.getExternalNodeCount();
        for (int tip = 0; tip < tipCount; ++tip) {
            List<WrappedNormalSufficientStatistics> statistics =
                    this.fullConditionalDensity.getTrait(this.tree, this.tree.getExternalNode(tip));
            assert statistics.size() == 1;
            computeProductForOneTaxon(tip, statistics.get(0), product);
        }
    }

    /** Fetches statistics for all tips at once (node {@code null}) and forks one task per tip. */
    private void computeProductInParallel(double[] product) {
        final List<WrappedNormalSufficientStatistics> statistics =
                this.fullConditionalDensity.getTrait(this.tree, null);
        assert statistics.size() == this.tree.getExternalNodeCount();
        // First lambda argument is used as the tip index; the second (presumably the thread index) is unused.
        this.taxonTaskPool.fork((tip, unusedThread) -> computeProductForOneTaxon(tip, statistics.get(tip), product));
    }

    /**
     * Writes {@code s * P * (x_taxon - mu)} into {@code product[taxon * dimTrait ...]}.
     * Only touches this taxon's scratch row and output slice.
     */
    private void computeProductForOneTaxon(int taxon, WrappedNormalSufficientStatistics statistics, double[] product) {
        final WrappedVector mean = statistics.getMean();
        final WrappedMatrix precision = statistics.getPrecision();
        final double precisionScalar = statistics.getPrecisionScalar();
        final double[] deviation = this.delta[taxon];
        computeDelta(taxon, deviation, this.dataParameter, mean);
        computePrecisionDeltaProduct(product, taxon * this.dimTrait, precision, deviation, precisionScalar);
    }

    /** Stores {@code data[taxon * dim + k] - mean[k]} into {@code deviation[k]}, with {@code dim = deviation.length}. */
    private static void computeDelta(int taxon, double[] deviation, Parameter data, ReadableVector mean) {
        final int dim = deviation.length;
        final int offset = taxon * dim;
        for (int k = 0; k < dim; ++k) {
            deviation[k] = data.getParameterValue(offset + k) - mean.get(k);
        }
    }

    /** Writes {@code scalar * (precision * deviation)} into {@code product[offset .. offset + dim)}. */
    private static void computePrecisionDeltaProduct(double[] product, int offset, ReadableMatrix precision,
                                                     double[] deviation, double scalar) {
        final int dim = deviation.length;
        for (int row = 0; row < dim; ++row) {
            double sum = 0.0;
            for (int col = 0; col < dim; ++col) {
                sum += precision.get(row, col) * deviation[col];
            }
            product[offset + row] = sum * scalar;
        }
    }

    /** Not provided by this implementation; always returns {@code null}. */
    @Override
    public double[] getMassVector() {
        return null;
    }

    /**
     * Returns {@code roughTimeGuess} when it is strictly positive, otherwise the
     * eigenvalue-based estimate from {@link #getMaxEigenvalueAsTravelTime()}.
     */
    @Override
    public double getTimeScale() {
        return this.roughTimeGuess > 0.0 ? this.roughTimeGuess : getMaxEigenvalueAsTravelTime();
    }

    /** Returns the power-method estimate for the likelihood delegate's trait variance. */
    @Override
    public double getTimeScaleEigen() {
        return this.eigenvalue.find(this.likelihoodDelegate.getTraitVariance());
    }

    /** {@code optimalTravelTimeScalar * sqrt(eig(treeVariance) * eig(traitVariance))}. */
    private double getMaxEigenvalueAsTravelTime() {
        final double treeEigenvalue = this.eigenvalue.find(this.likelihoodDelegate.getTreeVariance());
        final double traitEigenvalue = this.eigenvalue.find(this.likelihoodDelegate.getTraitVariance());
        return this.optimalTravelTimeScalar * Math.sqrt(treeEigenvalue * traitEigenvalue);
    }

    /**
     * Monte-Carlo travel-time estimate: {@code sqrt(1 / mean_r(u_r . getProduct(u_r)))} over
     * {@code eigenvalueReplicates} random unit vectors {@code u_r}.
     *
     * <p>Temporarily overwrites the trait data parameter with each {@code u_r}. The original
     * values are restored in a {@code finally} block so an exception cannot leave the model
     * in a corrupted state. Currently not called from within this class.
     */
    private double getRoughLowerBoundForTravelTime() {
        final WrappedVector.Raw savedValues = new WrappedVector.Raw(this.dataParameter.getParameterValues());
        double quadraticFormSum = 0.0;
        try {
            for (int replicate = 0; replicate < this.eigenvalueReplicates; ++replicate) {
                WrappedVector direction = drawUniformSphere(this.dataParameter.getDimension());
                ReadableVector.Utils.setParameter((ReadableVector) direction, this.dataParameter);
                WrappedVector.Raw image = new WrappedVector.Raw(this.getProduct(this.dataParameter));
                quadraticFormSum += ReadableVector.Utils.innerProduct((ReadableVector) direction, (ReadableVector) image);
            }
        } finally {
            ReadableVector.Utils.setParameter((ReadableVector) savedValues, this.dataParameter);
        }
        final double meanQuadraticForm = quadraticFormSum / (double) this.eigenvalueReplicates;
        return Math.sqrt(1.0 / meanQuadraticForm);
    }

    /** Draws {@code dimension} Gaussian deviates and normalises them to unit Euclidean length. */
    private static WrappedVector drawUniformSphere(int dimension) {
        final double[] point = new double[dimension];
        double squaredNorm = 0.0;
        for (int k = 0; k < dimension; ++k) {
            point[k] = MathUtils.nextGaussian();
            squaredNorm += point[k] * point[k];
        }
        final double norm = Math.sqrt(squaredNorm);
        for (int k = 0; k < dimension; ++k) {
            point[k] = point[k] / norm;
        }
        return new WrappedVector.Raw(point);
    }

    /**
     * Re-types a raw tree trait as one producing lists of wrapped normal statistics.
     * The caller is responsible for passing a trait of that actual element type.
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    static TreeTrait<List<WrappedNormalSufficientStatistics>> castTreeTrait(TreeTrait treeTrait) {
        return treeTrait;
    }
}

