/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.treedatalikelihood.continuous;

import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evomodel.branchratemodel.BranchRateModel;
import dr.evomodel.branchratemodel.StrictClockBranchRates;
import dr.evomodel.continuous.MultivariateDiffusionModel;
import dr.evomodel.treedatalikelihood.continuous.AbstractDiffusionModelDelegate;
import dr.evomodel.treedatalikelihood.continuous.ContinuousDataLikelihoodDelegate;
import dr.evomodel.treedatalikelihood.continuous.cdi.ContinuousDiffusionIntegrator;
import dr.inference.model.Model;
import dr.math.KroneckerOperation;
import java.util.List;
import org.ejml.data.DenseMatrix64F;
import org.ejml.ops.CommonOps;

/**
 * Base delegate for multivariate diffusion models that carry a drift term.
 *
 * <p>Drift is supplied by an optional list of {@link BranchRateModel}s, one per trait dimension:
 * the rate that model {@code j} reports for a node is used as the drift of dimension {@code j}.
 * When no list is supplied, every drift rate reported by this delegate is zero.
 */
public abstract class AbstractDriftDiffusionModelDelegate extends AbstractDiffusionModelDelegate {

    /** One branch-rate model per trait dimension, or {@code null} when no drift models were given. */
    private final List<BranchRateModel> branchRateModels;

    /**
     * Creates the delegate and registers each drift model as a sub-model.
     *
     * @param tree the tree, forwarded to the superclass
     * @param multivariateDiffusionModel the diffusion model, forwarded to the superclass
     * @param list one branch-rate model per trait dimension, or {@code null} for zero drift
     * @param n forwarded unchanged to the superclass constructor
     * @throws IllegalArgumentException if {@code list} is non-null and its size differs from the
     *     trait dimension ({@code dim}) established by the superclass
     */
    AbstractDriftDiffusionModelDelegate(Tree tree, MultivariateDiffusionModel multivariateDiffusionModel, List<BranchRateModel> list, int n) {
        super(tree, multivariateDiffusionModel, n);
        this.branchRateModels = list;
        if (list != null) {
            // Validate before registering, so a rejected instance never attaches itself to the models.
            if (list.size() != this.dim) {
                throw new IllegalArgumentException("Invalid dimensions");
            }
            for (BranchRateModel driftModel : list) {
                this.addModel(driftModel);
            }
        }
    }

    /**
     * Forwards change events from the drift models, and delegates all other events to the superclass.
     */
    @Override
    protected void handleModelChangedEvent(Model model, Object object, int n) {
        // branchRateModels is null when no drift models were supplied; guard to avoid an NPE.
        if (this.branchRateModels != null && this.branchRateModels.contains(model)) {
            this.fireModelChanged(model);
        } else {
            super.handleModelChangedEvent(model, object, n);
        }
    }

    /** Always {@code true}: this delegate family models drift. */
    @Override
    public boolean hasDrift() {
        return true;
    }

    /**
     * Collects the per-dimension drift rates for a batch of nodes.
     *
     * @param nodeNumbers node numbers looked up via {@code tree.getNode}; only the first
     *     {@code nodeCount} entries are read
     * @param nodeCount number of nodes to process
     * @return a node-major array of length {@code nodeCount * dim}. Entry {@code i * dim + j} holds
     *     the rate that drift model {@code j} reports for node {@code nodeNumbers[i]}. The array is
     *     all zeros when no drift models were supplied.
     */
    @Override
    protected double[] getDriftRates(int[] nodeNumbers, int nodeCount) {
        double[] rates = new double[nodeCount * this.dim];
        if (this.branchRateModels == null) {
            return rates;
        }
        for (int i = 0; i < nodeCount; ++i) {
            NodeRef node = this.tree.getNode(nodeNumbers[i]);
            int offset = i * this.dim;
            for (int j = 0; j < this.dim; ++j) {
                rates[offset + j] = this.branchRateModels.get(j).getBranchRate(this.tree, node);
            }
        }
        return rates;
    }

    /**
     * Returns the drift rate of every dimension for a single node.
     *
     * @param nodeRef the node to query
     * @return an array of length {@code dim}, all zeros when no drift models were supplied
     */
    double[] getDriftRate(NodeRef nodeRef) {
        double[] rates = new double[this.dim];
        if (this.branchRateModels != null) {
            for (int j = 0; j < this.dim; ++j) {
                rates[j] = this.branchRateModels.get(j).getBranchRate(this.tree, nodeRef);
            }
        }
        return rates;
    }

    /**
     * Reports whether every drift model is a {@link StrictClockBranchRates}.
     *
     * <p>Returns {@code false} when no drift models were supplied, even though all drift rates are
     * zero in that case.
     */
    public boolean isConstantDrift() {
        if (this.branchRateModels == null) {
            return false;
        }
        for (int j = 0; j < this.dim; ++j) {
            if (!(this.branchRateModels.get(j) instanceof StrictClockBranchRates)) {
                return false;
            }
        }
        return true;
    }

    /** Delegates directly to {@code scaleGradient} with the same arguments. */
    DenseMatrix64F getGradientDisplacementWrtDrift(NodeRef nodeRef, ContinuousDiffusionIntegrator continuousDiffusionIntegrator, ContinuousDataLikelihoodDelegate continuousDataLikelihoodDelegate, DenseMatrix64F denseMatrix64F) {
        return this.scaleGradient(nodeRef, continuousDiffusionIntegrator, continuousDataLikelihoodDelegate, denseMatrix64F);
    }

    /**
     * Applies the branch expectations along the root-to-node path, starting from {@code initial}.
     *
     * <p>For each non-root node on the path, the branches are processed from the root downward.
     * Each branch's displacement (and actualization, when the model has one) is read from the
     * integrator. {@code getBranchExpectation} is then applied to the running value.
     *
     * @param nodeRef the node at the end of the path
     * @param initial the starting value; copied, never modified. Its length must not exceed
     *     {@code length}.
     * @param continuousDiffusionIntegrator source of branch displacements and actualizations
     * @param length size of the returned vector and of the per-branch displacement buffer
     * @return a new array of length {@code length} holding the accumulated value
     */
    @Override
    public double[] getAccumulativeDrift(NodeRef nodeRef, double[] initial, ContinuousDiffusionIntegrator continuousDiffusionIntegrator, int length) {
        double[] accumulated = new double[length];
        System.arraycopy(initial, 0, accumulated, 0, initial.length);
        double[] displacement = new double[length];
        double[] actualization = null;
        if (this.hasActualization()) {
            // Diagonal actualization is stored as a vector, otherwise as a full length x length matrix.
            actualization = this.hasDiagonalActualization() ? new double[length] : new double[length * length];
        }
        // Shared output buffer for getBranchExpectation, reused at every level of the recursion.
        double[] scratch = new double[length];
        this.recursivelyAccumulateDrift(nodeRef, accumulated, continuousDiffusionIntegrator, displacement, actualization, scratch, length);
        return accumulated;
    }

    /**
     * Updates {@code accumulated} in place with the branch expectation of every branch between the
     * root and {@code nodeRef}, applying branches closest to the root first.
     */
    private void recursivelyAccumulateDrift(NodeRef nodeRef, double[] accumulated, ContinuousDiffusionIntegrator integrator, double[] displacement, double[] actualization, double[] scratch, int length) {
        if (this.tree.isRoot(nodeRef)) {
            return;
        }
        // Recurse first, so ancestors' branches are applied before this node's branch.
        this.recursivelyAccumulateDrift(this.tree.getParent(nodeRef), accumulated, integrator, displacement, actualization, scratch, length);
        int bufferIndex = this.getMatrixBufferOffsetIndex(nodeRef.getNumber());
        integrator.getBranchDisplacement(bufferIndex, displacement);
        if (this.hasActualization()) {
            integrator.getBranchActualization(bufferIndex, actualization);
        }
        // Write to a separate buffer rather than in place, in case getBranchExpectation does not
        // tolerate its input and output aliasing (not verifiable from here); then copy back.
        integrator.getBranchExpectation(actualization, accumulated, displacement, scratch);
        System.arraycopy(scratch, 0, accumulated, 0, length);
    }

    /**
     * Returns {@code KroneckerOperation.product(dArray, dArray3)}.
     *
     * <p>{@code d} and {@code dArray2} are not used by this implementation.
     */
    @Override
    public double[][] getJointVariance(double d, double[][] dArray, double[][] dArray2, double[][] dArray3) {
        return KroneckerOperation.product(dArray, dArray3);
    }

    /**
     * Writes {@code mean(values) * source} into {@code destination}.
     *
     * <p>{@code d} is not used. If {@code values} is empty, the mean is NaN, and so is every entry
     * of {@code destination}.
     */
    @Override
    public void getMeanTipVariances(double d, double[] values, DenseMatrix64F source, DenseMatrix64F destination) {
        double sum = 0.0;
        for (double value : values) {
            sum += value;
        }
        double mean = sum / values.length;
        CommonOps.scale(mean, source, destination);
    }
}

