/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.branchratemodel;

import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evomodel.branchratemodel.AbstractBranchRateModel;
import dr.evomodel.tree.TreeModel;
import dr.inference.loggers.Loggable;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.xml.Reportable;
import java.util.Arrays;
import java.util.Comparator;

/**
 * Branch rate model driven by a piecewise-defined rate function on a grid of heights.
 *
 * <p>With {@code K = gridPoints.getDimension()} grid points, the height axis is split into
 * {@code K + 1} intervals: interval {@code j < K} is the part just below grid point {@code j},
 * and interval {@code K} is everything above the last grid point. {@code rateFunction} supplies
 * one value per interval (so it needs at least {@code K + 1} entries). Each branch's value is
 * the sum over intervals of (branch time inside the interval) x (rate of the interval).
 *
 * <p>NOTE(review): the interval scan assumes grid points are sorted in increasing order;
 * nothing here enforces that.
 *
 * <p>Three caches are kept and recomputed lazily:
 * <ul>
 *   <li>the height ordering of nodes (depends on the tree only);</li>
 *   <li>the branch/interval intersection matrix (depends on the tree and the grid);</li>
 *   <li>the per-branch values (depend on the intersections and the rate function).</li>
 * </ul>
 */
public class GridBasedBranchRateModel
extends AbstractBranchRateModel
implements Reportable,
Loggable {
    private final TreeModel tree;
    private final Parameter gridPoints;
    private final Parameter rateFunction;
    /** Row-major (node x interval) matrix: entry {@code j + (K + 1) * node} is the time the branch above {@code node} spends in interval {@code j}. */
    private final double[] branchesIntersections;
    /** Cached per-node branch values, indexed by node number; root entries are never written. */
    private final double[] branchRates;
    private boolean ratesKnown;
    private boolean nodesOrderKnown;
    private boolean sufficientStatisticKnown;
    /** Node numbers sorted by increasing node height (stable for ties). */
    private final Integer[] orderedNodesIndexes;

    /**
     * Creates the model, registers its dependencies and eagerly computes the initial branch values.
     *
     * @param treeModel    tree whose branches are mapped onto the grid
     * @param gridPoints   grid point heights (NOTE(review): assumed sorted ascending)
     * @param rateFunction one value per grid interval, i.e. at least
     *                     {@code gridPoints.getDimension() + 1} entries
     */
    public GridBasedBranchRateModel(TreeModel treeModel, Parameter gridPoints, Parameter rateFunction) {
        super("gridBasedBranchRateModel");
        this.tree = treeModel;
        this.gridPoints = gridPoints;
        this.rateFunction = rateFunction;
        this.ratesKnown = false;
        this.nodesOrderKnown = false;
        this.sufficientStatisticKnown = false;
        final int nodeCount = treeModel.getNodeCount();
        this.branchRates = new double[nodeCount];
        this.branchesIntersections = new double[(gridPoints.getDimension() + 1) * nodeCount];
        this.orderedNodesIndexes = new Integer[nodeCount];
        this.addModel(treeModel);
        this.addVariable(gridPoints);
        this.addVariable(rateFunction);
        this.getBranchRates();
    }

    /** Ensures {@link #branchRates} is up to date, recomputing it only if invalidated. */
    private void getBranchRates() {
        if (!ratesKnown) {
            computeBranchRates();
            ratesKnown = true;
        }
    }

    /**
     * Recomputes the value of every non-root branch as the sum over intervals of
     * intersection length times the interval's rate.
     *
     * <p>NOTE(review): the stored value is the rate integrated over the branch (Σ Δt·r), not
     * its time average. If {@code getBranchRate} is meant to be a per-unit-time rate, this
     * presumably needs dividing by the branch length; confirm the intended semantics.
     */
    private void computeBranchRates() {
        getIntersectionsMatrix();
        final int intervalCount = gridPoints.getDimension() + 1;
        for (int i = 0; i < tree.getNodeCount(); ++i) {
            if (tree.isRoot(tree.getNode(i))) {
                continue;
            }
            final int rowOffset = intervalCount * i;
            double total = 0.0;
            for (int j = 0; j < intervalCount; ++j) {
                total += branchesIntersections[rowOffset + j] * rateFunction.getParameterValue(j);
            }
            branchRates[i] = total;
        }
    }

    /** Ensures {@link #branchesIntersections} is up to date, recomputing it only if invalidated. */
    private void getIntersectionsMatrix() {
        if (!sufficientStatisticKnown) {
            computeIntersectionsMatrix();
            sufficientStatisticKnown = true;
        }
    }

    /**
     * Rebuilds the branch/interval intersection matrix.
     *
     * <p>Nodes are visited in increasing height, so the index of the first grid point at or
     * above the current node height only ever moves forward. The whole sweep is therefore a
     * single pass over the grid plus the per-branch crossings.
     */
    private void computeIntersectionsMatrix() {
        orderNodesByHeight();
        final int gridCount = gridPoints.getDimension();
        final int intervalCount = gridCount + 1;
        // Cells not written below must read as zero. Without this, a branch that now spans
        // fewer intervals than on a previous call would keep stale lengths.
        Arrays.fill(branchesIntersections, 0.0);
        int firstGrid = 0; // first grid point at or above the current node's height
        for (int i = 0; i < tree.getNodeCount(); ++i) {
            final int node = orderedNodesIndexes[i];
            final NodeRef nodeRef = tree.getNode(node);
            // Visit every node and skip the root explicitly rather than assuming it sorts
            // last. A height tie with the root would otherwise skip a real branch.
            if (tree.isRoot(nodeRef)) {
                continue;
            }
            final double nodeHeight = tree.getNodeHeight(nodeRef);
            final double parentHeight = tree.getNodeHeight(tree.getParent(nodeRef));
            while (firstGrid < gridCount && gridPoints.getParameterValue(firstGrid) < nodeHeight) {
                ++firstGrid;
            }
            final int rowOffset = intervalCount * node;
            double lower = nodeHeight;
            int interval = firstGrid;
            // Each grid point strictly below the parent closes one interval the branch crosses.
            while (interval < gridCount && gridPoints.getParameterValue(interval) < parentHeight) {
                final double gridPoint = gridPoints.getParameterValue(interval);
                branchesIntersections[rowOffset + interval] = gridPoint - lower;
                lower = gridPoint;
                ++interval;
            }
            // Remaining piece, from the last crossed grid point (or the node itself) up to the parent.
            branchesIntersections[rowOffset + interval] = parentHeight - lower;
        }
    }

    /** Fills {@link #orderedNodesIndexes} with node numbers sorted by increasing height, if invalidated. */
    private void orderNodesByHeight() {
        if (nodesOrderKnown) {
            return;
        }
        for (int i = 0; i < tree.getNodeCount(); ++i) {
            orderedNodesIndexes[i] = i;
        }
        // Arrays.sort on objects is stable, so nodes of equal height keep index order.
        Arrays.sort(orderedNodesIndexes,
                Comparator.comparingDouble((Integer node) -> tree.getNodeHeight(tree.getNode(node))));
        nodesOrderKnown = true;
    }

    /**
     * Returns the cached value for the branch above {@code nodeRef}, recomputing caches if needed.
     * The {@code tree} argument is not consulted; the model's own tree is always used.
     */
    @Override
    public double getBranchRate(Tree tree, NodeRef nodeRef) {
        getBranchRates();
        return branchRates[nodeRef.getNumber()];
    }

    /** A tree change invalidates every cache; the change is then propagated to listeners. */
    @Override
    protected void handleModelChangedEvent(Model model, Object object, int n) {
        if (model == tree) {
            sufficientStatisticKnown = false;
            ratesKnown = false;
            nodesOrderKnown = false;
        }
        fireModelChanged();
    }

    /** Nothing is snapshotted; caches are instead invalidated in {@link #restoreState()}. */
    @Override
    protected void storeState() {
    }

    /**
     * Invalidates all caches after a rejected proposal.
     *
     * <p>Since {@link #storeState()} saves nothing, any cache may hold values computed for the
     * rejected state. Previously only the rates were invalidated, so a height ordering or
     * intersection matrix recomputed during the proposal was reused against the restored tree.
     * A restore is not assumed to be announced through a change event.
     */
    @Override
    protected void restoreState() {
        ratesKnown = false;
        sufficientStatisticKnown = false;
        nodesOrderKnown = false;
    }

    @Override
    protected void acceptState() {
    }

    /**
     * A parameter change always invalidates the branch values. The intersection matrix does not
     * depend on the rate function, so it is rebuilt only when some other variable (the grid)
     * changes.
     */
    @Override
    protected void handleVariableChangedEvent(Variable variable, int n, Variable.ChangeType changeType) {
        if (variable != rateFunction) {
            sufficientStatisticKnown = false;
        }
        ratesKnown = false;
        fireModelChanged();
    }

    /** Reports the intersection matrix and branch values, refreshing them first so the report is current. */
    @Override
    public String getReport() {
        getBranchRates();
        return "Branches intersections matrix: " + Arrays.toString(this.branchesIntersections) + "\nBranch rates: " + Arrays.toString(this.branchRates);
    }

    /** @return the grid point parameter */
    protected Parameter getGridPoints() {
        return gridPoints;
    }

    /** @return the height of grid point {@code n} */
    protected double getGridPoint(int n) {
        return gridPoints.getParameterValue(n);
    }

    /** @return the tree this model is defined on */
    protected TreeModel getTree() {
        return tree;
    }

    /**
     * Returns the number of the {@code n}-th node in increasing height order, refreshing the
     * ordering first if the tree has changed.
     */
    protected int getOrderedNodesIndexes(int n) {
        orderNodesByHeight();
        return orderedNodesIndexes[n];
    }

    /**
     * Returns entry {@code n} of the flattened intersection matrix
     * (row-major, {@code gridPoints.getDimension() + 1} entries per node), refreshing it if needed.
     */
    protected double getSufficientStatistic(int n) {
        getIntersectionsMatrix();
        return branchesIntersections[n];
    }
}

