/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.branchratemodel;

import dr.evolution.alignment.PatternList;
import dr.evolution.datatype.DataType;
import dr.evolution.parsimony.FitchParsimony;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeTrait;
import dr.evolution.tree.TreeTraitProvider;
import dr.evolution.util.TaxonList;
import dr.evomodel.branchratemodel.AbstractBranchRateModel;
import dr.evomodel.tree.TreeModel;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import java.util.Arrays;

/**
 * Branch rate model whose rate on each branch is a time-weighted mixture of per-state rates of a
 * discrete character.
 *
 * <p>How the time a branch spends in each state is obtained depends on the {@link Mode} selected
 * by the constructor:
 * <ul>
 *   <li>{@code NODE_STATES} (tree trait named {@code "states"}): half of the branch length is
 *       credited to the child node's state and half to the parent node's state;</li>
 *   <li>{@code MARKOV_JUMP_PROCESS}: one value per state is read from an array of tree traits;</li>
 *   <li>{@code PARSIMONY}: the branch length is shared among the Fitch parsimony state sets of the
 *       child and parent nodes.</li>
 * </ul>
 *
 * <p>Branch rates are cached per node number. The cache is invalidated whenever a registered model
 * or variable changes, and it is saved/restored together with the cached rates.
 */
public class DiscreteTraitBranchRateModel extends AbstractBranchRateModel {
    private static final boolean CACHING_RATES = true;
    public static final String DISCRETE_TRAIT_BRANCH_RATE_MODEL = "discreteTraitRateModel";

    /** Tree trait supplying node states; null for the parsimony and traits-array constructors. */
    protected TreeTrait trait = null;
    /** Index into the {@code int[]} returned by {@link #trait} in NODE_STATES mode. */
    protected int traitIndex;

    /** Per-state rates, or (when relative rates are present) one overall rate at index 0. */
    private final Parameter rateParameter;
    /** Optional per-state multipliers of {@code rateParameter[0]}; may be null. */
    private final Parameter relativeRatesParameter;
    /** Optional per-state 0/1-style multipliers, used only together with relative rates. */
    private final Parameter indicatorParameter;

    private double[] rates;
    private double[] storedRates;
    private boolean[] rateKnown;
    private boolean[] storedRateKnown;

    /** Per-state traits for MARKOV_JUMP_PROCESS mode; only set by the traits-array constructor. */
    private TreeTrait[] traits;
    private FitchParsimony fitchParsimony;
    /** When true, the Fitch state sets must be recomputed before they are next read. */
    private boolean treeChanged = true;
    private Mode mode;
    private DataType dataType;

    /**
     * Creates a model in PARSIMONY mode, reconstructing states with Fitch parsimony on the patterns.
     *
     * @param treeModel     the tree; must have exactly the same taxa as {@code patternList}
     * @param patternList   the discrete character patterns
     * @param traitIndex    stored but not used in parsimony mode
     * @param rateParameter per-state rates; its dimension is set to the data type's state count
     * @throws IllegalArgumentException if the tree and pattern list have different taxa
     */
    public DiscreteTraitBranchRateModel(TreeModel treeModel, PatternList patternList, int traitIndex,
                                        Parameter rateParameter) {
        this(treeModel, traitIndex, rateParameter, null, null);
        if (!TaxonList.Utils.getTaxonListIdSet(treeModel).equals(TaxonList.Utils.getTaxonListIdSet(patternList))) {
            throw new IllegalArgumentException("Tree model and pattern list must have the same list of taxa!");
        }
        rateParameter.setDimension(patternList.getDataType().getStateCount());
        this.fitchParsimony = new FitchParsimony(patternList, false);
        this.mode = Mode.PARSIMONY;
    }

    /**
     * Creates a model driven by a single tree trait, with an overall rate scaled by per-state
     * relative rates and optional indicators.
     *
     * <p>NOTE(review): a trait not named {@code "states"} selects MARKOV_JUMP_PROCESS mode, which
     * in this class only reads the traits array of the traits-array constructor. Such a model
     * fails with an {@link IllegalStateException} when a rate is requested.
     *
     * @param traitProvider          registered as a sub-model when it is a {@link Model}
     * @param dataType               discrete data type defining the number of states
     * @param treeModel              the tree
     * @param trait                  tree trait supplying node states
     * @param traitIndex             index into the trait's {@code int[]} value
     * @param rateParameter          overall rate (index 0 is used)
     * @param relativeRatesParameter per-state relative rates; dimension is set to the state count
     * @param indicatorParameter     optional per-state indicators; may be null
     */
    public DiscreteTraitBranchRateModel(TreeTraitProvider traitProvider, DataType dataType, TreeModel treeModel,
                                        TreeTrait trait, int traitIndex, Parameter rateParameter,
                                        Parameter relativeRatesParameter, Parameter indicatorParameter) {
        this(treeModel, traitIndex, rateParameter, relativeRatesParameter, indicatorParameter);
        this.trait = trait;
        this.dataType = dataType;
        this.mode = modeForTrait(trait);
        // The delegate constructor accepts a null relative-rates parameter, so guard here too.
        if (relativeRatesParameter != null) {
            relativeRatesParameter.setDimension(dataType.getStateCount());
        }
        registerTraitModels(traitProvider, trait);
    }

    /**
     * Creates a model driven by a single tree trait with one rate per state.
     *
     * <p>NOTE(review): see the 8-argument constructor regarding traits not named {@code "states"}.
     *
     * @param traitProvider registered as a sub-model when it is a {@link Model}
     * @param dataType      discrete data type defining the number of states
     * @param treeModel     the tree
     * @param trait         tree trait supplying node states
     * @param traitIndex    index into the trait's {@code int[]} value
     * @param rateParameter per-state rates; dimension is set to the state count
     */
    public DiscreteTraitBranchRateModel(TreeTraitProvider traitProvider, DataType dataType, TreeModel treeModel,
                                        TreeTrait trait, int traitIndex, Parameter rateParameter) {
        this(treeModel, traitIndex, rateParameter, null, null);
        this.trait = trait;
        this.dataType = dataType;
        this.mode = modeForTrait(trait);
        rateParameter.setDimension(dataType.getStateCount());
        registerTraitModels(traitProvider, trait);
    }

    /**
     * Creates a model in MARKOV_JUMP_PROCESS mode, where state {@code i}'s value on a branch is
     * the first element of {@code traits[i]}.
     *
     * @param traitProvider registered as a sub-model when it is a {@link Model}
     * @param traits        one {@link TreeTrait.DA} per state
     * @param treeModel     the tree
     * @param rateParameter per-state rates; dimension is set to {@code traits.length}
     */
    public DiscreteTraitBranchRateModel(TreeTraitProvider traitProvider, TreeTrait[] traits, TreeModel treeModel,
                                        Parameter rateParameter) {
        this(treeModel, 0, rateParameter, null, null);
        this.traits = traits;
        this.mode = Mode.MARKOV_JUMP_PROCESS;
        rateParameter.setDimension(traits.length);
        if (traitProvider instanceof Model) {
            addModel((Model) traitProvider);
        }
    }

    /** Shared initialisation: registers the tree and parameters and allocates the rate caches. */
    private DiscreteTraitBranchRateModel(TreeModel treeModel, int traitIndex, Parameter rateParameter,
                                         Parameter relativeRatesParameter, Parameter indicatorParameter) {
        super(DISCRETE_TRAIT_BRANCH_RATE_MODEL);
        addModel(treeModel);
        this.traitIndex = traitIndex;
        this.rateParameter = rateParameter;
        addVariable(rateParameter);
        this.relativeRatesParameter = relativeRatesParameter;
        if (relativeRatesParameter != null) {
            addVariable(relativeRatesParameter);
        }
        this.indicatorParameter = indicatorParameter;
        if (indicatorParameter != null) {
            addVariable(indicatorParameter);
        }
        final int nodeCount = treeModel.getNodeCount();
        this.rates = new double[nodeCount];
        this.storedRates = new double[nodeCount];
        this.rateKnown = new boolean[nodeCount];
        this.storedRateKnown = new boolean[nodeCount];
    }

    /** A trait named "states" gives node states; any other name selects the jump-process mode. */
    private static Mode modeForTrait(TreeTrait trait) {
        return trait.getTraitName().equals("states") ? Mode.NODE_STATES : Mode.MARKOV_JUMP_PROCESS;
    }

    /** Registers the trait provider and the trait as sub-models when they are {@link Model}s. */
    private void registerTraitModels(TreeTraitProvider traitProvider, TreeTrait trait) {
        if (traitProvider instanceof Model) {
            addModel((Model) traitProvider);
        }
        if (trait instanceof Model) {
            addModel((Model) trait);
        }
    }

    /** Invalidates every cached rate and the Fitch state sets, then notifies listeners. */
    @Override
    public void handleModelChangedEvent(Model model, Object object, int index) {
        Arrays.fill(rateKnown, false);
        treeChanged = true;
        fireModelChanged();
    }

    /** Invalidates every cached rate, then notifies listeners. */
    @Override
    protected final void handleVariableChangedEvent(Variable variable, int index, Variable.ChangeType type) {
        Arrays.fill(rateKnown, false);
        fireModelChanged();
    }

    /** Saves the cached rates together with which of them are currently valid. */
    @Override
    protected void storeState() {
        System.arraycopy(rates, 0, storedRates, 0, rates.length);
        System.arraycopy(rateKnown, 0, storedRateKnown, 0, rateKnown.length);
    }

    /**
     * Restores the cached rates and their validity flags from the last {@link #storeState()}.
     * Rates that were stale when stored stay invalid, so they are not served as known.
     */
    @Override
    protected void restoreState() {
        double[] swapRates = rates;
        rates = storedRates;
        storedRates = swapRates;

        boolean[] swapKnown = rateKnown;
        rateKnown = storedRateKnown;
        storedRateKnown = swapKnown;

        // Fitch state sets may have been computed on the rejected tree; rebuild them on next use.
        treeChanged = true;
    }

    @Override
    protected void acceptState() {
    }

    /**
     * Returns the number of discrete states used to weight the rates.
     *
     * @return the state count, or 0 for a mode this class does not handle
     */
    protected int getStateCount() {
        if (mode == Mode.MARKOV_JUMP_PROCESS) {
            // The traits-array constructor leaves dataType null; there is one trait per state.
            return traits != null ? traits.length : dataType.getStateCount();
        }
        if (mode == Mode.NODE_STATES) {
            return dataType.getStateCount();
        }
        if (mode == Mode.PARSIMONY) {
            return fitchParsimony.getPatterns().getStateCount();
        }
        return 0;
    }

    /** Returns the (cached) rate of the branch above {@code node}. */
    @Override
    public double getBranchRate(Tree tree, NodeRef node) {
        final int index = node.getNumber();
        if (!CACHING_RATES || !rateKnown[index]) {
            rates[index] = getRawBranchRate(tree, node);
            rateKnown[index] = true;
        }
        return rates[index];
    }

    /**
     * Computes the rate of the branch above {@code node} as the average of the per-state rates,
     * weighted by the proportion of the branch's time spent in each state.
     *
     * <p>NOTE(review): when the per-state times sum to zero (for example a zero-length branch), the
     * proportions and therefore the rate are NaN.
     */
    protected double getRawBranchRate(Tree tree, NodeRef node) {
        final int stateCount = getStateCount();
        final double[] proportions = toProportions(getProcessValues(tree, node), stateCount);

        double rate = 0.0;
        if (relativeRatesParameter != null) {
            final double overallRate = rateParameter.getParameterValue(0);
            for (int i = 0; i < stateCount; i++) {
                double term = overallRate * relativeRatesParameter.getParameterValue(i) * proportions[i];
                if (indicatorParameter != null) {
                    term *= indicatorParameter.getParameterValue(i);
                }
                rate += term;
            }
        } else {
            // Previously used raw times here, yielding rate * total time instead of a rate.
            for (int i = 0; i < stateCount; i++) {
                rate += rateParameter.getParameterValue(i) * proportions[i];
            }
        }
        return rate;
    }

    /** Divides the first {@code stateCount} values by their sum. */
    private static double[] toProportions(double[] values, int stateCount) {
        double total = 0.0;
        for (int i = 0; i < stateCount; i++) {
            total += values[i];
        }
        final double[] proportions = new double[stateCount];
        for (int i = 0; i < stateCount; i++) {
            proportions[i] = values[i] / total;
        }
        return proportions;
    }

    /**
     * Returns, for each state, the amount of the branch above {@code node} attributed to it.
     *
     * @return per-state values, or null for a mode this class does not handle
     */
    private double[] getProcessValues(Tree tree, NodeRef node) {
        final double branchLength = tree.getBranchLength(node);
        if (mode == Mode.MARKOV_JUMP_PROCESS) {
            return getJumpProcessValues(tree, node);
        }
        if (mode == Mode.PARSIMONY) {
            return getParsimonyValues(tree, node, branchLength);
        }
        if (mode == Mode.NODE_STATES) {
            return getNodeStateValues(tree, node, branchLength);
        }
        return null;
    }

    /** Reads the first element of each per-state trait at {@code node}. */
    private double[] getJumpProcessValues(Tree tree, NodeRef node) {
        if (traits == null) {
            throw new IllegalStateException(
                    "Markov jump process mode requires an array of per-state tree traits");
        }
        final int stateCount = getStateCount();
        final double[] values = new double[stateCount];
        for (int i = 0; i < stateCount; i++) {
            values[i] = ((double[]) ((TreeTrait.DA) traits[i]).getTrait(tree, node))[0];
        }
        return values;
    }

    /**
     * Shares the branch length between the Fitch state sets of the child and parent nodes. Each
     * node contributes half the branch length per state in its set. The result is rescaled so
     * the values sum to the branch length.
     */
    private double[] getParsimonyValues(Tree tree, NodeRef node, double branchLength) {
        if (treeChanged) {
            fitchParsimony.initialize(tree);
            treeChanged = false;
        }
        final int[] childStates = fitchParsimony.getStates(tree, node);
        final int[] parentStates = fitchParsimony.getStates(tree, tree.getParent(node));
        final double[] values = new double[fitchParsimony.getPatterns().getStateCount()];
        final double halfBranch = branchLength / 2.0;
        for (int state : childStates) {
            values[state] += halfBranch;
        }
        for (int state : parentStates) {
            values[state] += halfBranch;
        }
        // Floating-point division: integer division mis-scaled the values when the set sizes sum to an odd number.
        final double scale = (childStates.length + parentStates.length) / 2.0;
        for (int i = 0; i < values.length; i++) {
            values[i] /= scale;
        }
        return values;
    }

    /** Credits half the branch length to the child's state and half to the parent's state. */
    private double[] getNodeStateValues(Tree tree, NodeRef node, double branchLength) {
        final double[] values = new double[getStateCount()];
        final double halfBranch = branchLength / 2.0;
        final int childState = ((int[]) trait.getTrait(tree, node))[traitIndex];
        values[childState] += halfBranch;
        final int parentState = ((int[]) trait.getTrait(tree, tree.getParent(node)))[traitIndex];
        values[parentState] += halfBranch;
        return values;
    }

    /** How the per-state times along a branch are obtained. */
    enum Mode {
        NODE_STATES,
        MARKOV_JUMP_PROCESS,
        /** Declared but never selected or handled by this class. */
        MARKOV_JUMP_COUNT,
        PARSIMONY
    }
}

