/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.treedatalikelihood.preorder;

import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeTrait;
import dr.evomodel.treedatalikelihood.continuous.ContinuousDataLikelihoodDelegate;
import dr.evomodel.treedatalikelihood.preorder.ConditionalVarianceAndTransform2;
import dr.evomodel.treedatalikelihood.preorder.ModelExtensionProvider;
import dr.inference.model.CompoundParameter;
import dr.math.MathUtils;
import dr.math.distributions.MultivariateNormalDistribution;
import dr.math.matrixAlgebra.CholeskyDecomposition;
import dr.math.matrixAlgebra.WrappedVector;
import org.ejml.data.DenseMatrix64F;

/**
 * Base class for delegates that expand the traits produced by a continuous data likelihood into
 * full per-taxon data vectors.
 *
 * <p>The extended array has {@code nTaxa * dimTrait} entries laid out taxon-major (entry
 * {@code taxon * dimTrait + d}). Observed entries are copied from the data model's parameter.
 * Missing entries, as flagged by {@code getDataMissingIndicators()}, are filled in by the
 * subclass via {@link #sampleMissingValues}.
 */
public abstract class ContinuousExtensionDelegate {
    protected final TreeTrait treeTrait;
    protected final Tree tree;
    private final ContinuousDataLikelihoodDelegate likelihoodDelegate;
    protected final ModelExtensionProvider dataModel;
    /** Data dimension per taxon, taken from {@code ModelExtensionProvider.getDataDimension()}. */
    protected final int dimTrait;
    /** Number of external nodes (tips) of {@link #tree}. */
    protected final int nTaxa;
    // Never reassigned anywhere in this class, so getTreeTraits() always fires fireModelChanged().
    private boolean forceResample = true;

    public ContinuousExtensionDelegate(ContinuousDataLikelihoodDelegate continuousDataLikelihoodDelegate, ModelExtensionProvider modelExtensionProvider, TreeTrait treeTrait, Tree tree) {
        this.treeTrait = treeTrait;
        this.tree = tree;
        this.likelihoodDelegate = continuousDataLikelihoodDelegate;
        this.dataModel = modelExtensionProvider;
        this.dimTrait = modelExtensionProvider.getDataDimension();
        this.nTaxa = tree.getExternalNodeCount();
    }

    /** Returns the model extension provider this delegate was constructed with. */
    public ModelExtensionProvider getDataModel() {
        return this.dataModel;
    }

    /**
     * Returns the raw values of {@link #treeTrait} for {@link #tree}.
     *
     * <p>If {@code forceResample} is set (it always is, as written), this first calls
     * {@code fireModelChanged()} on the likelihood delegate. Presumably this forces the trait
     * to be recomputed rather than reused; confirm against the delegate.
     */
    public double[] getTreeTraits() {
        if (this.forceResample) {
            this.likelihoodDelegate.fireModelChanged();
        }
        return (double[]) this.treeTrait.getTrait(this.tree, null);
    }

    /** Returns the extended data built from the transformed tree traits. */
    public double[] getExtendedValues() {
        return this.getExtendedValues(this.getTransformedTraits());
    }

    /** Returns the tree traits after passing them through {@code dataModel.transformTreeTraits}. */
    public double[] getTransformedTraits() {
        return this.dataModel.transformTreeTraits(this.getTreeTraits());
    }

    /**
     * Builds the extended data array: observed entries are copied from the data parameter,
     * and missing entries are delegated to {@link #sampleMissingValues}.
     *
     * @param treeValues per-taxon values from the tree. Subclasses index it with the same
     *     taxon-major offsets as the returned array.
     * @return a new array of length {@code nTaxa * dimTrait}
     */
    public double[] getExtendedValues(double[] treeValues) {
        final double[] extended = new double[this.nTaxa * this.dimTrait];
        final CompoundParameter dataParameter = this.dataModel.getParameter();
        final boolean[] missing = this.dataModel.getDataMissingIndicators();
        for (int taxon = 0; taxon < this.nTaxa; ++taxon) {
            final IndexPartition partition = new IndexPartition(missing, taxon);
            final int offset = taxon * this.dimTrait;
            for (int observed : partition.obsInds) {
                final int index = offset + observed;
                extended[index] = dataParameter.getParameterValue(index);
            }
            this.sampleMissingValues(extended, treeValues, partition, taxon);
        }
        return extended;
    }

    /**
     * Fills the missing entries of one taxon.
     *
     * @param extendedValues output array (taxon-major, {@code nTaxa * dimTrait} long). Its
     *     observed entries for {@code taxon} have already been written.
     * @param treeValues values from the tree, in the same layout
     * @param partition observed/missing dimension indices (relative to the taxon offset)
     * @param taxon index of the taxon being processed
     */
    protected abstract void sampleMissingValues(double[] extendedValues, double[] treeValues, IndexPartition partition, int taxon);

    public TreeTrait getTreeTrait() {
        return this.treeTrait;
    }

    public Tree getTree() {
        return this.tree;
    }

    /**
     * Splits one taxon's {@code dimTrait} dimensions into missing and observed indices.
     * Both index arrays are relative to the taxon's offset and are in increasing order.
     */
    protected class IndexPartition {
        private final int[] obsInds;
        private final int[] misInds;
        private final int nMissing;
        private final int nObserved;

        private IndexPartition(boolean[] missingIndicators, int taxon) {
            final int offset = taxon * dimTrait;

            int missingCount = 0;
            for (int d = 0; d < dimTrait; ++d) {
                if (missingIndicators[offset + d]) {
                    ++missingCount;
                }
            }
            this.nMissing = missingCount;
            this.nObserved = dimTrait - missingCount;
            this.misInds = new int[this.nMissing];
            this.obsInds = new int[this.nObserved];

            int nextMissing = 0;
            int nextObserved = 0;
            for (int d = 0; d < dimTrait; ++d) {
                if (missingIndicators[offset + d]) {
                    this.misInds[nextMissing++] = d;
                } else {
                    this.obsInds[nextObserved++] = d;
                }
            }
        }
    }

    /**
     * Fills each missing entry independently: tree value plus Gaussian noise. The noise
     * standard deviation is the square root of the matching diagonal entry of the
     * extension variance.
     */
    public static class IndependentNormalExtensionDelegate
    extends ContinuousExtensionDelegate {
        // Named differently from the inherited `dataModel` to avoid field hiding.
        private final ModelExtensionProvider.NormalExtensionProvider normalModel;
        private final double[] stdev;

        public IndependentNormalExtensionDelegate(ContinuousDataLikelihoodDelegate continuousDataLikelihoodDelegate, TreeTrait treeTrait, ModelExtensionProvider.NormalExtensionProvider normalExtensionProvider, Tree tree) {
            super(continuousDataLikelihoodDelegate, normalExtensionProvider, treeTrait, tree);
            this.normalModel = normalExtensionProvider;
            this.stdev = new double[this.dimTrait];
        }

        /** Refreshes the per-dimension standard deviations, then builds the extended data. */
        @Override
        public double[] getExtendedValues(double[] treeValues) {
            final DenseMatrix64F variance = this.normalModel.getExtensionVariance();
            for (int d = 0; d < this.dimTrait; ++d) {
                this.stdev[d] = Math.sqrt(variance.get(d, d));
            }
            return super.getExtendedValues(treeValues);
        }

        @Override
        protected void sampleMissingValues(double[] extendedValues, double[] treeValues, IndexPartition partition, int taxon) {
            final int offset = this.dimTrait * taxon;
            for (int missing : partition.misInds) {
                final int index = offset + missing;
                extendedValues[index] = MathUtils.nextGaussian() * this.stdev[missing] + treeValues[index];
            }
        }
    }

    /**
     * Fills missing entries jointly from a multivariate normal built from the extension
     * variance.
     *
     * <p>When every dimension of a taxon is missing, the draw uses the tree values as the
     * mean and the full variance's Cholesky factor, computed once per
     * {@link #getExtendedValues(double[])} call. Otherwise the draw uses the mean and
     * Cholesky factor returned by {@code ConditionalVarianceAndTransform2}. Presumably this
     * is the normal of the missing entries conditioned on the observed ones; confirm.
     */
    public static class MultivariateNormalExtensionDelegate
    extends ContinuousExtensionDelegate {
        private final CompoundParameter dataParameter;
        private boolean choleskyKnown = false;
        private double[][] cholesky;
        private DenseMatrix64F extensionVariance;
        // Named differently from the inherited `dataModel` to avoid field hiding.
        private final ModelExtensionProvider.NormalExtensionProvider normalModel;

        public MultivariateNormalExtensionDelegate(ContinuousDataLikelihoodDelegate continuousDataLikelihoodDelegate, TreeTrait treeTrait, ModelExtensionProvider.NormalExtensionProvider normalExtensionProvider, Tree tree) {
            super(continuousDataLikelihoodDelegate, normalExtensionProvider, treeTrait, tree);
            this.normalModel = normalExtensionProvider;
            this.dataParameter = normalExtensionProvider.getParameter();
        }

        /** Snapshots the current extension variance and invalidates the cached Cholesky factor. */
        @Override
        public double[] getExtendedValues(double[] treeValues) {
            this.choleskyKnown = false;
            this.extensionVariance = this.normalModel.getExtensionVariance();
            return super.getExtendedValues(treeValues);
        }

        @Override
        protected void sampleMissingValues(double[] extendedValues, double[] treeValues, IndexPartition partition, int taxon) {
            final int offset = taxon * this.dimTrait;
            // The all-missing check comes first, matching the original branch order.
            if (partition.nMissing == this.dimTrait) {
                this.sampleAllMissing(extendedValues, treeValues, offset);
            } else if (partition.nMissing > 0) {
                this.sampleConditional(extendedValues, treeValues, partition, taxon, offset);
            }
        }

        /** Draws the whole taxon block around the tree values, using the cached full-variance Cholesky factor. */
        private void sampleAllMissing(double[] extendedValues, double[] treeValues, int offset) {
            final double[] mean = new double[this.dimTrait];
            System.arraycopy(treeValues, offset, mean, 0, this.dimTrait);
            if (!this.choleskyKnown) {
                this.cholesky = CholeskyDecomposition.execute(this.extensionVariance.getData(), 0, this.dimTrait);
                this.choleskyKnown = true;
            }
            final double[] draw = MultivariateNormalDistribution.nextMultivariateNormalCholesky(mean, this.cholesky);
            System.arraycopy(draw, 0, extendedValues, offset, this.dimTrait);
        }

        /** Draws only the missing entries, using the observed data of this taxon. */
        private void sampleConditional(double[] extendedValues, double[] treeValues, IndexPartition partition, int taxon, int offset) {
            final ConditionalVarianceAndTransform2 conditional = new ConditionalVarianceAndTransform2(this.extensionVariance, partition.misInds, partition.obsInds);
            // Call order (mean, then Cholesky, then RNG draw) is kept exactly as in the original.
            final WrappedVector conditionalMean = conditional.getConditionalMean(this.dataParameter.getParameter(taxon).getParameterValues(), 0, treeValues, offset);
            final double[][] conditionalCholesky = conditional.getConditionalCholesky();
            final double[] draw = MultivariateNormalDistribution.nextMultivariateNormalCholesky(conditionalMean.getBuffer(), conditionalCholesky);
            // Re-copy observed values from this delegate's own data parameter. The base class
            // already wrote them from dataModel.getParameter(); kept in case the two differ.
            for (int observed : partition.obsInds) {
                extendedValues[offset + observed] = this.dataParameter.getParameterValue(offset + observed);
            }
            for (int i = 0; i < partition.nMissing; ++i) {
                extendedValues[offset + partition.misInds[i]] = draw[i];
            }
        }
    }

    /** Pass-through delegate: performs no extension or sampling. */
    public static class NullExtensionDelegate
    extends ContinuousExtensionDelegate {
        public NullExtensionDelegate(ContinuousDataLikelihoodDelegate continuousDataLikelihoodDelegate, ModelExtensionProvider modelExtensionProvider, TreeTrait treeTrait, Tree tree) {
            super(continuousDataLikelihoodDelegate, modelExtensionProvider, treeTrait, tree);
        }

        /** Returns the raw (untransformed) tree traits. */
        @Override
        public double[] getExtendedValues() {
            return this.getTreeTraits();
        }

        /** Returns the given array unchanged (same instance, not a copy). */
        @Override
        public double[] getExtendedValues(double[] treeValues) {
            return treeValues;
        }

        /** No-op: this delegate never samples. */
        @Override
        protected void sampleMissingValues(double[] extendedValues, double[] treeValues, IndexPartition partition, int taxon) {
        }
    }
}

