/*
 * Decompiled with CFR 0.152.
 */
package dr.evolution.colouring;

import dr.evolution.alignment.Alignment;
import dr.evolution.coalescent.structure.MetaPopulation;
import dr.evolution.colouring.BranchColouring;
import dr.evolution.colouring.ColourChangeMatrix;
import dr.evolution.colouring.ColourSampler;
import dr.evolution.colouring.DefaultBranchColouring;
import dr.evolution.colouring.DefaultTreeColouring;
import dr.evolution.colouring.TreeColouring;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.util.TaxonList;
import dr.math.MathUtils;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.TreeMap;

/**
 * {@link ColourSampler} that proposes a colouring of a tree (a colour for every
 * node plus colour-change events along the branches) given fixed tip colours,
 * and evaluates the log proposal density of an existing colouring.
 *
 * <p>A proposal is built in three passes:
 * <ol>
 * <li>{@link #prune} computes ordinary pruning partials under the
 * {@link ColourChangeMatrix}. These seed per-interval mean colour counts.</li>
 * <li>{@link #pruneEM} recomputes partials on a grid of intervals bounded by the
 * distinct node heights. It can add a branch bias derived from those mean counts
 * and a node bias derived from the {@link MetaPopulation}. With
 * {@code useSecondColourIteration}, the counts are refreshed from these partials
 * and the pass is repeated once.</li>
 * <li>{@link #sampleEM} samples colours from the root down. Within each interval
 * it uses a rejection sampler to draw a history of colour changes conditioned
 * on the colours at both ends.</li>
 * </ol>
 *
 * <p>Leaf colours come from a single-column {@link Alignment} or from membership
 * of {@link TaxonList}s. The per-interval code (from {@link #calculateMatrixElts}
 * onward) uses 2x2 matrices and {@code 1 - colour} flips. The sampling and
 * density code therefore only supports two colours.
 */
public class StructuredColourSampler implements ColourSampler {
    /** Attempts the per-interval rejection sampler makes before it forces a final colour change. */
    static final int maxIterations = 1000;
    /** Offset below a node's height at which the demographic size is read for the node bias. */
    static final double tinyTime = 1.0E-6;
    // Debug switches: all disabled and not read anywhere in this class.
    static final boolean debugMessages = false;
    static final boolean debugMeanColours = false;
    static final boolean debugNodePartials = false;
    static final boolean debugSampleLikelihoods = false;
    static final boolean debugRejectionSampler = false;
    static final boolean debugProposalProbabilityCalculator = false;
    /** Reset to zero at the start of each sampling pass; not otherwise used in this class. */
    double _totalIntegratedRate;
    /** Not used in this class. */
    static final DecimalFormat df = new DecimalFormat("###.####");
    /**
     * Mixing weight of the branch bias in {@link #calculateMatrixElts}. The bias on
     * colour {@code c} is proportional to
     * {@code propAffected * (total - 1) + (1 - propAffected) * (count_c - 1)}.
     */
    static final double propAffected = 0.0;
    /** Partials whose largest entry is below this value are multiplied by {@code rescaleFactor}. */
    private static final double rescaleThreshold = 1.0E-100;
    private static final double rescaleFactor = 1.0E100;
    /** Discriminant below which the 2x2 exponential uses its repeated-eigenvalue form. */
    private static final double degenerateDiscriminant = 1.0E-5;
    private boolean useNodeBias = false;
    private boolean useBranchBias = false;
    private boolean useSecondColourIteration = true;
    private final int colourCount;
    /** Colour of each node, by node number. Tips are set at construction; internal nodes are overwritten by sampling. */
    private final int[] nodeColours;
    private final int[] leafColourCounts;
    /** Mean lineage count per colour, as [interval][colour]. Sized by node count; only the first numIntervals rows are intervals. */
    private double[][] meanColourCounts;
    /** Colour sampled at the bottom of each interval on a node's branch. Written, but not read, in this class. */
    private int[][] nodeColoursEM;
    /** Pruning partials per node, scaled up to avoid underflow (the log scale is tracked in logNodePartialsRescaling). */
    private double[][] nodePartials;
    /** For each non-root node, its partials at the bottom of every interval its branch spans, as [node][offset][colour]. */
    private double[][][] nodePartialsEM;
    /** Log of the total rescaling applied during the current pruning pass. Always zero or negative. */
    private double logNodePartialsRescaling;
    private double[] equilibriumColours;
    /** By node number: the index of the node's height in interval2Height. */
    private int[] node2Interval;
    /** Distinct node heights in ascending order. Interval i spans interval2Height[i] to interval2Height[i + 1]. */
    private double[] interval2Height;
    /**
     * Per interval: interval length divided by metaPopulation.getIntegral(...) for
     * colour 0 (avgN0) and colour 1 (avgN1). Presumably a harmonic-mean deme size.
     * TODO confirm against MetaPopulation.getIntegral.
     */
    private double[] avgN0;
    private double[] avgN1;
    private int numIntervals;

    /**
     * Builds a sampler whose tip colours are the states in a single-column alignment.
     *
     * @param alignment                one column; each taxon's state is its colour
     * @param tree                     tree whose external nodes are coloured
     * @param useNodeBias              weight internal nodes by equilibrium / demographic size
     * @param useBranchBias            add the mean-count branch bias to the interval rate matrices
     * @param useSecondColourIteration refresh the mean counts and repeat the interval pruning once
     * @throws IllegalArgumentException if the alignment does not have exactly one site
     */
    public StructuredColourSampler(Alignment alignment, Tree tree, boolean useNodeBias, boolean useBranchBias, boolean useSecondColourIteration) {
        if (alignment.getSiteCount() != 1) {
            throw new IllegalArgumentException("Tip colour alignment must consist of a single column!");
        }
        this.nodeColours = new int[tree.getNodeCount()];
        this.colourCount = alignment.getDataType().getStateCount();
        this.leafColourCounts = new int[this.colourCount];
        for (int i = 0; i < tree.getExternalNodeCount(); ++i) {
            NodeRef leaf = tree.getExternalNode(i);
            int colour = alignment.getState(alignment.getTaxonIndex(tree.getTaxonId(i)), 0);
            this.nodeColours[leaf.getNumber()] = colour;
            this.leafColourCounts[colour]++;
        }
        this.useNodeBias = useNodeBias;
        this.useBranchBias = useBranchBias;
        this.useSecondColourIteration = useSecondColourIteration;
        this.initialize(tree);
    }

    /**
     * Builds a sampler whose tip colours come from taxon-list membership.
     * A leaf in none of the lists gets colour 0. A leaf in list {@code j} gets
     * colour {@code j + 1}; if it is in several lists, the last one wins.
     *
     * @param taxonListArray           one list per non-zero colour
     * @param tree                     tree whose external nodes are coloured
     * @param useNodeBias              weight internal nodes by equilibrium / demographic size
     * @param useBranchBias            add the mean-count branch bias to the interval rate matrices
     * @param useSecondColourIteration refresh the mean counts and repeat the interval pruning once
     */
    public StructuredColourSampler(TaxonList[] taxonListArray, Tree tree, boolean useNodeBias, boolean useBranchBias, boolean useSecondColourIteration) {
        this.nodeColours = new int[tree.getNodeCount()];
        this.colourCount = taxonListArray.length + 1;
        this.leafColourCounts = new int[this.colourCount];
        for (int i = 0; i < tree.getExternalNodeCount(); ++i) {
            NodeRef leaf = tree.getExternalNode(i);
            int colour = 0;
            for (int list = 0; list < taxonListArray.length; ++list) {
                if (taxonListArray[list].getTaxonIndex(tree.getTaxonId(i)) != -1) {
                    colour = list + 1;
                }
            }
            this.nodeColours[leaf.getNumber()] = colour;
            this.leafColourCounts[colour]++;
        }
        this.useNodeBias = useNodeBias;
        this.useBranchBias = useBranchBias;
        this.useSecondColourIteration = useSecondColourIteration;
        this.initialize(tree);
    }

    /** Returns the number of leaves of each colour. The internal array is returned, not a copy. */
    @Override
    public int[] getLeafColourCounts() {
        return this.leafColourCounts;
    }

    /** Allocates the per-node and per-colour work arrays for {@code tree}. */
    private void initialize(Tree tree) {
        this.nodePartials = new double[tree.getNodeCount()][this.colourCount];
        this.meanColourCounts = new double[tree.getNodeCount()][this.colourCount];
        this.nodeColoursEM = new int[tree.getNodeCount()][];
        this.nodePartialsEM = new double[tree.getNodeCount()][][];
        this.equilibriumColours = new double[this.colourCount];
    }

    /**
     * Builds the interval grid from the distinct node heights.
     * Nodes at equal heights share one interval index. The per-interval averages
     * are filled for every interval except the topmost (root) one.
     */
    private void computeIntervals(Tree tree, MetaPopulation metaPopulation) {
        // TreeMap keeps the heights in ascending order.
        TreeMap<Double, List<NodeRef>> nodesByHeight = new TreeMap<Double, List<NodeRef>>();
        int nodeCount = tree.getNodeCount();
        for (int i = 0; i < nodeCount; ++i) {
            NodeRef node = tree.getNode(i);
            Double height = Double.valueOf(tree.getNodeHeight(node));
            List<NodeRef> nodesAtHeight = nodesByHeight.get(height);
            if (nodesAtHeight == null) {
                nodesAtHeight = new ArrayList<NodeRef>(1);
                nodesByHeight.put(height, nodesAtHeight);
            }
            nodesAtHeight.add(node);
        }
        int heightCount = nodesByHeight.size();
        this.node2Interval = new int[nodeCount];
        this.interval2Height = new double[heightCount];
        this.avgN0 = new double[heightCount];
        this.avgN1 = new double[heightCount];
        int interval = 0;
        for (Double height : nodesByHeight.keySet()) {
            double top = height.doubleValue();
            this.interval2Height[interval] = top;
            for (NodeRef node : nodesByHeight.get(height)) {
                this.node2Interval[node.getNumber()] = interval;
            }
            if (interval > 0) {
                double bottom = this.interval2Height[interval - 1];
                this.avgN0[interval - 1] = (top - bottom) / metaPopulation.getIntegral(bottom, top, 0);
                this.avgN1[interval - 1] = (top - bottom) / metaPopulation.getIntegral(bottom, top, 1);
            }
            ++interval;
        }
        this.numIntervals = interval;
    }

    /**
     * Runs the shared preparation for sampling and for density evaluation.
     * The passes are: equilibrium, intervals, ordinary pruning, mean counts,
     * interval pruning, and an optional second interval pruning.
     *
     * @return the equilibrium-weighted sum of the root's interval partials. This
     *         value is scaled; add {@code logNodePartialsRescaling} to its log to
     *         obtain the unscaled log value.
     */
    private double computeScaledRootLikelihood(Tree tree, ColourChangeMatrix colourChangeMatrix, MetaPopulation metaPopulation) {
        this.populateEquilibriumColourArray(colourChangeMatrix);
        this.computeIntervals(tree, metaPopulation);
        this.logNodePartialsRescaling = 0.0;
        this.prune(tree, tree.getRoot(), colourChangeMatrix);
        this.calculateMeanColourCounts(tree, colourChangeMatrix);
        this.logNodePartialsRescaling = 0.0;
        double[] rootPartials = this.pruneEM(tree, tree.getRoot(), colourChangeMatrix, metaPopulation);
        if (this.useSecondColourIteration) {
            this.calculateMeanColourCountsEM(tree, tree.getRoot(), colourChangeMatrix);
            this.logNodePartialsRescaling = 0.0;
            rootPartials = this.pruneEM(tree, tree.getRoot(), colourChangeMatrix, metaPopulation);
        }
        double likelihood = 0.0;
        for (int c = 0; c < this.colourCount; ++c) {
            likelihood += this.equilibriumColours[c] * rootPartials[c];
        }
        return likelihood;
    }

    /**
     * Samples a new colouring of {@code tree}.
     * The colouring's log probability density is set to the log density that
     * {@link #sampleEM} accumulates, minus the log of the root normaliser.
     */
    @Override
    public DefaultTreeColouring sampleTreeColouring(Tree tree, ColourChangeMatrix colourChangeMatrix, MetaPopulation metaPopulation) {
        double scaledRootLikelihood = this.computeScaledRootLikelihood(tree, colourChangeMatrix, metaPopulation);
        double logNormaliser = Math.log(scaledRootLikelihood) + this.logNodePartialsRescaling;
        // Two colours: the per-interval machinery in this class is 2x2 only.
        DefaultTreeColouring colouring = new DefaultTreeColouring(2, tree);
        double logDensity = this.sampleEM(tree, tree.getRoot(), colourChangeMatrix, metaPopulation, colouring) - logNormaliser;
        colouring.setLogProbabilityDensity(logDensity);
        return colouring;
    }

    /**
     * Returns the log density with which {@link #sampleTreeColouring} would
     * propose {@code treeColouring}. The density is recomputed from the
     * colouring's node colours and branch events.
     */
    @Override
    public double getProposalProbability(TreeColouring treeColouring, Tree tree, ColourChangeMatrix colourChangeMatrix, MetaPopulation metaPopulation) {
        double scaledRootLikelihood = this.computeScaledRootLikelihood(tree, colourChangeMatrix, metaPopulation);
        double logProposal = this.calculateEMProposal(tree, tree.getRoot(), colourChangeMatrix, metaPopulation, treeColouring);
        return logProposal - Math.log(scaledRootLikelihood) - this.logNodePartialsRescaling;
    }

    private int getColour(NodeRef nodeRef) {
        return this.nodeColours[nodeRef.getNumber()];
    }

    /**
     * Records {@code n} as the colour of {@code nodeRef}.
     *
     * @throws IllegalArgumentException if {@code n} is not in {@code [0, colourCount - 1]}
     */
    private void setColour(NodeRef nodeRef, int n) {
        if (n < 0 || n >= this.colourCount) {
            throw new IllegalArgumentException("colour value " + n + " is outside of range of colours, [0, " + Integer.toString(this.colourCount - 1) + "]");
        }
        this.nodeColours[nodeRef.getNumber()] = n;
    }

    /** Copies each colour's equilibrium frequency from the matrix into {@code equilibriumColours}. */
    void populateEquilibriumColourArray(ColourChangeMatrix colourChangeMatrix) {
        for (int c = 0; c < this.colourCount; ++c) {
            this.equilibriumColours[c] = colourChangeMatrix.getEquilibrium(c);
        }
    }

    /**
     * Returns node {@code n}'s pruning partials weighted by equilibrium frequency
     * and normalised to sum to one. {@code colourChangeMatrix} is unused.
     */
    double[] getMeanColours(int n, ColourChangeMatrix colourChangeMatrix) {
        return this.equilibriumWeightedColours(this.nodePartials[n]);
    }

    /**
     * Same as {@link #getMeanColours}, but uses the interval partials stored for
     * node {@code n} at interval offset {@code n2}. {@code colourChangeMatrix} is unused.
     */
    double[] getMeanColoursEM(int n, int n2, ColourChangeMatrix colourChangeMatrix) {
        return this.equilibriumWeightedColours(this.nodePartialsEM[n][n2]);
    }

    /**
     * Returns {@code partial[c] * equilibrium[c]} normalised to sum to one.
     * Yields NaNs if every weighted entry is zero.
     */
    private double[] equilibriumWeightedColours(double[] partial) {
        double[] weights = new double[this.colourCount];
        double total = 0.0;
        for (int c = 0; c < this.colourCount; ++c) {
            weights[c] = partial[c] * this.equilibriumColours[c];
            total += weights[c];
        }
        for (int c = 0; c < this.colourCount; ++c) {
            weights[c] /= total;
        }
        return weights;
    }

    /**
     * Accumulates each branch into {@code meanColourCounts} as a difference array.
     * The branch's mean colour (the average of its two end nodes) is added at the
     * child's interval and subtracted at the parent's interval.
     * {@link #calculateMeanColourCounts} then turns these into running sums.
     */
    void fillMeanColourCounts(Tree tree, NodeRef nodeRef, ColourChangeMatrix colourChangeMatrix) {
        if (!tree.isRoot(nodeRef)) {
            int parent = tree.getParent(nodeRef).getNumber();
            int child = nodeRef.getNumber();
            double[] parentMean = this.getMeanColours(parent, colourChangeMatrix);
            double[] childMean = this.getMeanColours(child, colourChangeMatrix);
            double[] startRow = this.meanColourCounts[this.node2Interval[child]];
            double[] endRow = this.meanColourCounts[this.node2Interval[parent]];
            for (int c = 0; c < this.colourCount; ++c) {
                double branchMean = (parentMean[c] + childMean[c]) / 2.0;
                startRow[c] += branchMean;
                endRow[c] -= branchMean;
            }
        }
        if (!tree.isExternal(nodeRef)) {
            this.fillMeanColourCounts(tree, tree.getChild(nodeRef, 0), colourChangeMatrix);
            this.fillMeanColourCounts(tree, tree.getChild(nodeRef, 1), colourChangeMatrix);
        }
    }

    /**
     * Recomputes {@code meanColourCounts} from the interval partials.
     * When called on the root, it first clears the counts. Each branch adds, to
     * every interval it spans, the average of the mean colours at the interval's
     * two ends. When the parent is the root, the top end reuses the bottom end.
     */
    void calculateMeanColourCountsEM(Tree tree, NodeRef nodeRef, ColourChangeMatrix colourChangeMatrix) {
        if (tree.isRoot(nodeRef)) {
            this.clearMeanColourCounts();
        } else {
            NodeRef parentRef = tree.getParent(nodeRef);
            int parent = parentRef.getNumber();
            int child = nodeRef.getNumber();
            int firstInterval = this.node2Interval[child];
            int parentInterval = this.node2Interval[parent];
            double[] below = this.getMeanColoursEM(child, 0, colourChangeMatrix);
            for (int offset = 0; offset < parentInterval - firstInterval; ++offset) {
                int nextOffset = offset + 1;
                double[] above;
                if (nextOffset + firstInterval < parentInterval) {
                    above = this.getMeanColoursEM(child, nextOffset, colourChangeMatrix);
                } else if (!tree.isRoot(parentRef)) {
                    // The parent's own partials at its height close the branch.
                    above = this.getMeanColoursEM(parent, 0, colourChangeMatrix);
                } else {
                    // The root has no interval partials of its own.
                    above = below;
                }
                double[] row = this.meanColourCounts[offset + firstInterval];
                for (int c = 0; c < this.colourCount; ++c) {
                    row[c] += (above[c] + below[c]) / 2.0;
                }
                below = above;
            }
        }
        if (!tree.isExternal(nodeRef)) {
            this.calculateMeanColourCountsEM(tree, tree.getChild(nodeRef, 0), colourChangeMatrix);
            this.calculateMeanColourCountsEM(tree, tree.getChild(nodeRef, 1), colourChangeMatrix);
        }
    }

    /** Sets every entry of {@code meanColourCounts} to zero. */
    private void clearMeanColourCounts() {
        for (double[] row : this.meanColourCounts) {
            for (int c = 0; c < this.colourCount; ++c) {
                row[c] = 0.0;
            }
        }
    }

    /**
     * Computes per-interval mean colour counts from the pruning partials.
     * Fills the difference array, then replaces each column with its running sum.
     */
    void calculateMeanColourCounts(Tree tree, ColourChangeMatrix colourChangeMatrix) {
        this.clearMeanColourCounts();
        this.fillMeanColourCounts(tree, tree.getRoot(), colourChangeMatrix);
        for (int c = 0; c < this.colourCount; ++c) {
            double running = 0.0;
            for (double[] row : this.meanColourCounts) {
                running += row[c];
                row[c] = running;
            }
        }
    }

    /**
     * Standard pruning pass using {@code forwardTimeEvolution} over whole branches.
     * Stores each node's partials in {@code nodePartials} and returns them.
     */
    private double[] prune(Tree tree, NodeRef nodeRef, ColourChangeMatrix colourChangeMatrix) {
        double[] partial = new double[this.colourCount];
        if (tree.isExternal(nodeRef)) {
            partial[this.getColour(nodeRef)] = 1.0;
        } else {
            NodeRef left = tree.getChild(nodeRef, 0);
            NodeRef right = tree.getChild(nodeRef, 1);
            double[] leftPartial = this.prune(tree, left, colourChangeMatrix);
            double[] rightPartial = this.prune(tree, right, colourChangeMatrix);
            double height = tree.getNodeHeight(nodeRef);
            double leftLength = height - tree.getNodeHeight(left);
            double rightLength = height - tree.getNodeHeight(right);
            double largest = 0.0;
            for (int parentColour = 0; parentColour < this.colourCount; ++parentColour) {
                double leftSum = 0.0;
                double rightSum = 0.0;
                for (int childColour = 0; childColour < this.colourCount; ++childColour) {
                    leftSum += colourChangeMatrix.forwardTimeEvolution(parentColour, childColour, leftLength) * leftPartial[childColour];
                    rightSum += colourChangeMatrix.forwardTimeEvolution(parentColour, childColour, rightLength) * rightPartial[childColour];
                }
                partial[parentColour] = leftSum * rightSum;
                if (partial[parentColour] > largest) {
                    largest = partial[parentColour];
                }
            }
            this.rescaleIfTiny(partial, largest);
        }
        this.nodePartials[nodeRef.getNumber()] = partial;
        return partial;
    }

    /**
     * Multiplies {@code partial} by {@code rescaleFactor} when its largest entry
     * is below {@code rescaleThreshold}, and records the scaling in
     * {@code logNodePartialsRescaling}.
     */
    private void rescaleIfTiny(double[] partial, double largest) {
        if (largest < rescaleThreshold) {
            for (int c = 0; c < this.colourCount; ++c) {
                partial[c] *= rescaleFactor;
            }
            this.logNodePartialsRescaling -= Math.log(rescaleFactor);
        }
    }

    /**
     * Returns the four entries {e00, e01, e10, e11} of {@code exp(-M)}, where
     * {@code M = [[m[0], -m[1]], [-m[2], m[3]]]}. When the discriminant
     * {@code (m0 - m3)^2 + 4 m1 m2} is below {@code degenerateDiscriminant}, it
     * uses the repeated-eigenvalue form.
     */
    private static double[] negatedMatrixExponential(double[] m) {
        double a = m[0];
        double d = m[3];
        double b = m[1];
        double c = m[2];
        double disc = Math.sqrt((a - d) * (a - d) + 4.0 * b * c);
        if (disc < degenerateDiscriminant) {
            double expA = Math.exp(-a);
            double expD = Math.exp(-d);
            return new double[]{expA, b * expA, c * expD, expD};
        }
        double fast = Math.exp(-(a + d + disc) / 2.0);
        double slow = Math.exp(-(a + d - disc) / 2.0);
        return new double[]{
            ((d - a + disc) * slow - (d - a - disc) * fast) / (2.0 * disc),
            b * (slow - fast) / disc,
            c * (slow - fast) / disc,
            ((a - d + disc) * slow - (a - d - disc) * fast) / (2.0 * disc)};
    }

    /**
     * Returns one row of {@code exp(-M)} (see {@code calculateMatrixElts} for the
     * layout of {@code dArray}). Row 0 if {@code n == 0}, otherwise row 1.
     */
    static double[] matrixEvolve(double[] dArray, int n) {
        double[] e = negatedMatrixExponential(dArray);
        return n == 0 ? new double[]{e[0], e[1]} : new double[]{e[2], e[3]};
    }

    /** Replaces the 2-vector {@code dArray2} in place with {@code exp(-M) * dArray2}. */
    static void matrixPullBack(double[] dArray, double[] dArray2) {
        double[] e = negatedMatrixExponential(dArray);
        double first = dArray2[0] * e[0] + dArray2[1] * e[1];
        dArray2[1] = dArray2[0] * e[2] + dArray2[1] * e[3];
        dArray2[0] = first;
    }

    /**
     * Builds the 2x2 rate entries for one interval, already multiplied by its length {@code d}.
     * The layout is {@code {r01 + bias0, r01, r10, r10 + bias1}}. Each bias grows
     * with that colour's mean lineage count and shrinks with its average size
     * ({@code d2} for colour 0, {@code d3} for colour 1). Both biases are clamped
     * at zero, reduced by their common minimum, and forced to zero unless
     * {@code useBranchBias} is set. {@code nodeRef} and {@code tree} are unused.
     */
    double[] calculateMatrixElts(int n, NodeRef nodeRef, Tree tree, double d, double d2, double d3, ColourChangeMatrix colourChangeMatrix) {
        double count0 = this.meanColourCounts[n][0];
        double count1 = this.meanColourCounts[n][1];
        double total = count0 + count1;
        double bias0 = (propAffected * (total - 1.0) + (1.0 - propAffected) * (count0 - 1.0)) / (2.0 * d2) * d;
        if (bias0 < 0.0) {
            bias0 = 0.0;
        }
        double bias1 = (propAffected * (total - 1.0) + (1.0 - propAffected) * (count1 - 1.0)) / (2.0 * d3) * d;
        if (bias1 < 0.0) {
            bias1 = 0.0;
        }
        if (!this.useBranchBias) {
            bias0 = 0.0;
            bias1 = 0.0;
        }
        double shared = Math.min(bias0, bias1);
        bias0 -= shared;
        bias1 -= shared;
        double rate01 = colourChangeMatrix.getForwardRate(0, 1) * d;
        double rate10 = colourChangeMatrix.getForwardRate(1, 0) * d;
        return new double[]{rate01 + bias0, rate01, rate10, rate10 + bias1};
    }

    /**
     * Propagates child {@code nodeRef2}'s partials {@code dArray} up its branch,
     * interval by interval, to the height of parent {@code nodeRef}. The partials
     * at the bottom of each interval are stored in {@code nodePartialsEM}.
     *
     * @return the partials at the parent's height
     */
    double[] pruneBranchEM(ColourChangeMatrix colourChangeMatrix, double[] dArray, NodeRef nodeRef, NodeRef nodeRef2, Tree tree, MetaPopulation metaPopulation) {
        int parentInterval = this.node2Interval[nodeRef.getNumber()];
        int childInterval = this.node2Interval[nodeRef2.getNumber()];
        double[][] intervalPartials = new double[parentInterval - childInterval][2];
        double[] partial = dArray.clone();
        for (int interval = childInterval; interval != parentInterval; ++interval) {
            double[] slot = intervalPartials[interval - childInterval];
            slot[0] = partial[0];
            slot[1] = partial[1];
            double length = this.interval2Height[interval + 1] - this.interval2Height[interval];
            double[] matrix = this.calculateMatrixElts(interval, nodeRef2, tree, length, this.avgN0[interval], this.avgN1[interval], colourChangeMatrix);
            matrixPullBack(matrix, partial);
        }
        this.nodePartialsEM[nodeRef2.getNumber()] = intervalPartials;
        return partial;
    }

    /** Node-bias factor {@code equilibrium(colour) / demographic(height - tinyTime, colour)}. */
    private double nodeBias(ColourChangeMatrix colourChangeMatrix, MetaPopulation metaPopulation, double nodeHeight, int colour) {
        return colourChangeMatrix.getEquilibrium(colour) / metaPopulation.getDemographic(nodeHeight - tinyTime, colour);
    }

    /**
     * Interval-based pruning pass.
     * Child partials are pulled up through each interval, combined at the node,
     * and optionally multiplied by the node bias. The result is stored in
     * {@code nodePartials}.
     */
    private double[] pruneEM(Tree tree, NodeRef nodeRef, ColourChangeMatrix colourChangeMatrix, MetaPopulation metaPopulation) {
        double[] partial = new double[this.colourCount];
        if (tree.isExternal(nodeRef)) {
            partial[this.getColour(nodeRef)] = 1.0;
        } else {
            NodeRef left = tree.getChild(nodeRef, 0);
            NodeRef right = tree.getChild(nodeRef, 1);
            double[] leftBottom = this.pruneEM(tree, left, colourChangeMatrix, metaPopulation);
            double[] rightBottom = this.pruneEM(tree, right, colourChangeMatrix, metaPopulation);
            double[] leftTop = this.pruneBranchEM(colourChangeMatrix, leftBottom, nodeRef, left, tree, metaPopulation);
            double[] rightTop = this.pruneBranchEM(colourChangeMatrix, rightBottom, nodeRef, right, tree, metaPopulation);
            double largest = 0.0;
            for (int c = 0; c < this.colourCount; ++c) {
                partial[c] = leftTop[c] * rightTop[c];
                if (this.useNodeBias) {
                    partial[c] *= this.nodeBias(colourChangeMatrix, metaPopulation, tree.getNodeHeight(nodeRef), c);
                }
                if (partial[c] > largest) {
                    largest = partial[c];
                }
            }
            this.rescaleIfTiny(partial, largest);
        }
        this.nodePartials[nodeRef.getNumber()] = partial;
        return partial;
    }

    /**
     * Samples colours from the root down and records each branch's colouring.
     * The root colour is drawn from equilibrium * root partials, contributing
     * log(equilibrium) to the density. On each branch, working from the top
     * interval down, the colour at each interval's bottom is drawn and a
     * conditional change history is sampled for that interval.
     *
     * @return the accumulated log density of the sampled colouring in this subtree
     */
    private double sampleEM(Tree tree, NodeRef nodeRef, ColourChangeMatrix colourChangeMatrix, MetaPopulation metaPopulation, DefaultTreeColouring defaultTreeColouring) {
        double logDensity = 0.0;
        int colour;
        if (tree.isRoot(nodeRef)) {
            this._totalIntegratedRate = 0.0;
            double[] equilibrium = colourChangeMatrix.getEquilibrium();
            double[] rootPartial = this.nodePartials[nodeRef.getNumber()];
            double[] weights = new double[this.colourCount];
            for (int c = 0; c < equilibrium.length; ++c) {
                weights[c] = equilibrium[c] * rootPartial[c];
            }
            colour = MathUtils.randomChoicePDF(weights);
            logDensity += Math.log(equilibrium[colour]);
        } else {
            int node = nodeRef.getNumber();
            double[][] intervalPartials = this.nodePartialsEM[node];
            int firstInterval = this.node2Interval[node];
            this.nodeColoursEM[node] = new int[intervalPartials.length];
            // The parent was coloured before this recursion reached its children.
            colour = this.getColour(tree.getParent(nodeRef));
            DefaultBranchColouring branchColouring = new DefaultBranchColouring(colour, colour);
            double[] weights = new double[this.colourCount];
            for (int offset = intervalPartials.length - 1; offset >= 0; --offset) {
                int interval = offset + firstInterval;
                double bottom = this.interval2Height[interval];
                double top = this.interval2Height[interval + 1];
                double length = top - bottom;
                double[] matrix = this.calculateMatrixElts(interval, nodeRef, tree, length, this.avgN0[interval], this.avgN1[interval], colourChangeMatrix);
                double[] transition = matrixEvolve(matrix, colour);
                for (int c = 0; c < this.colourCount; ++c) {
                    weights[c] = transition[c] * intervalPartials[offset][c];
                }
                int bottomColour = MathUtils.randomChoicePDF(weights);
                this.nodeColoursEM[node][offset] = bottomColour;
                logDensity += this.sampleConditionalBranchColouringEM(nodeRef, colour, bottomColour, length, bottom, matrix, branchColouring);
                colour = bottomColour;
            }
            defaultTreeColouring.setBranchColouring(nodeRef, branchColouring);
        }
        this.setColour(nodeRef, colour);
        if (!tree.isExternal(nodeRef) && this.useNodeBias) {
            logDensity += Math.log(this.nodeBias(colourChangeMatrix, metaPopulation, tree.getNodeHeight(nodeRef), colour));
        }
        for (int i = 0; i < tree.getChildCount(nodeRef); ++i) {
            logDensity += this.sampleEM(tree, tree.getChild(nodeRef, i), colourChangeMatrix, metaPopulation, defaultTreeColouring);
        }
        return logDensity;
    }

    /**
     * Samples colour-change events within one interval, conditioned on colour
     * {@code n} at the top and {@code n2} at the bottom.
     *
     * <p>Waiting times are drawn at the colour's total exit rate. If the end
     * colours differ, the first wait is conditioned to fall inside the interval.
     * The first event is always a change. Later events are changes with
     * probability changeRate / exitRate; otherwise the attempt is rejected. An
     * attempt that ends on the wrong colour is also rejected. After
     * {@code maxIterations} failed attempts, a change to {@code n2} is forced
     * just above the interval's bottom.
     *
     * @param d   interval length
     * @param d2  height of the interval's bottom
     * @param dArray interval rate entries from {@link #calculateMatrixElts}
     * @param defaultBranchColouring receives the sampled events
     * @return the log density accumulated by the last attempt
     */
    private double sampleConditionalBranchColouringEM(NodeRef nodeRef, int n, int n2, double d, double d2, double[] dArray, DefaultBranchColouring defaultBranchColouring) {
        DefaultBranchColouring history = new DefaultBranchColouring(n, n2);
        int attempts = 0;
        int colour;
        double timeLeft;
        double waitingTime;
        double logDensity;
        boolean rejected;
        do {
            history.clear();
            colour = n;
            timeLeft = d;
            logDensity = 0.0;
            rejected = false;
            boolean firstEvent = true;
            do {
                double exitRate;
                double changeRate;
                if (colour == 0) {
                    exitRate = dArray[0] / d;
                    changeRate = dArray[1] / d;
                } else {
                    exitRate = dArray[3] / d;
                    changeRate = dArray[2] / d;
                }
                double u;
                do {
                    u = MathUtils.nextDouble();
                } while (u == 0.0);
                if (firstEvent && n != n2) {
                    // Restrict u so that the first waiting time is shorter than d.
                    double noEventProbability = Math.exp(-exitRate * d);
                    u = noEventProbability + u * (1.0 - noEventProbability);
                }
                waitingTime = -Math.log(u) / exitRate;
                timeLeft -= waitingTime;
                if (timeLeft > 0.0) {
                    if (firstEvent || changeRate == exitRate || MathUtils.nextDouble() < changeRate / exitRate) {
                        colour = 1 - colour;
                        history.addEvent(colour, timeLeft + d2);
                        logDensity += -exitRate * waitingTime + Math.log(changeRate);
                    } else {
                        rejected = true;
                    }
                } else {
                    // No further event before the bottom: survival over the remaining time.
                    logDensity += -exitRate * (timeLeft + waitingTime);
                }
                firstEvent = false;
            } while (!rejected && timeLeft > 0.0);
            ++attempts;
            if (colour != n2) {
                rejected = true;
            }
        } while (rejected && attempts < maxIterations);
        if (rejected && colour != n2) {
            history.addEvent(n2, 0.01 * (timeLeft + waitingTime) + d2);
        }
        defaultBranchColouring.addHistory(history);
        return logDensity;
    }

    /**
     * Mirrors {@link #sampleEM} for an existing colouring.
     * Returns the log density of its root colour, interval event histories and
     * optional node biases over this subtree.
     */
    private double calculateEMProposal(Tree tree, NodeRef nodeRef, ColourChangeMatrix colourChangeMatrix, MetaPopulation metaPopulation, TreeColouring treeColouring) {
        double logDensity = 0.0;
        int colour;
        if (tree.isRoot(nodeRef)) {
            double[] equilibrium = colourChangeMatrix.getEquilibrium();
            colour = treeColouring.getNodeColour(nodeRef);
            logDensity += Math.log(equilibrium[colour]);
        } else {
            int node = nodeRef.getNumber();
            double[][] intervalPartials = this.nodePartialsEM[node];
            int firstInterval = this.node2Interval[node];
            BranchColouring branchColouring = treeColouring.getBranchColouring(nodeRef);
            for (int offset = intervalPartials.length - 1; offset >= 0; --offset) {
                int interval = offset + firstInterval;
                double bottom = this.interval2Height[interval];
                double top = this.interval2Height[interval + 1];
                double length = top - bottom;
                double[] matrix = this.calculateMatrixElts(interval, nodeRef, tree, length, this.avgN0[interval], this.avgN1[interval], colourChangeMatrix);
                logDensity += this.calculateConditionalBranchColouringEM(nodeRef, length, bottom, matrix, branchColouring);
            }
            colour = treeColouring.getNodeColour(nodeRef);
        }
        if (!tree.isExternal(nodeRef) && this.useNodeBias) {
            logDensity += Math.log(this.nodeBias(colourChangeMatrix, metaPopulation, tree.getNodeHeight(nodeRef), colour));
        }
        for (int i = 0; i < tree.getChildCount(nodeRef); ++i) {
            logDensity += this.calculateEMProposal(tree, tree.getChild(nodeRef, i), colourChangeMatrix, metaPopulation, treeColouring);
        }
        return logDensity;
    }

    /**
     * Log density of the events of {@code branchColouring} that fall in one
     * interval. {@code d} is the interval length and {@code d2} its bottom height.
     * Each accepted change contributes {@code -exitRate * wait + log(changeRate)};
     * the stretch after the last change contributes {@code -exitRate * remaining}.
     */
    private double calculateConditionalBranchColouringEM(NodeRef nodeRef, double d, double d2, double[] dArray, BranchColouring branchColouring) {
        double time = d + d2;
        int event = branchColouring.getNextForwardEvent(time);
        int colour = branchColouring.getForwardColourBelow(event - 1);
        double logDensity = 0.0;
        while (time > d2) {
            // Past the last event, use a sentinel below the interval.
            double eventTime = event == branchColouring.getNumEvents() + 1 ? d2 - 1.0 : branchColouring.getForwardTime(event);
            double exitRate;
            double changeRate;
            if (colour == 0) {
                exitRate = dArray[0] / d;
                changeRate = dArray[1] / d;
            } else {
                exitRate = dArray[3] / d;
                changeRate = dArray[2] / d;
            }
            if (eventTime < d2) {
                logDensity += -exitRate * (time - d2);
            } else {
                logDensity += -exitRate * (time - eventTime) + Math.log(changeRate);
                colour = branchColouring.getForwardColourBelow(event);
            }
            time = eventTime;
            ++event;
        }
        return logDensity;
    }

    /** Debug helper that prints {@code string= (v0, v1, ..., )}. Not called in this class. */
    private void prettyPrint(String string, double[] dArray) {
        System.out.print(string + "= (");
        for (double value : dArray) {
            System.out.print(value + ", ");
        }
        System.out.println(")");
    }

    /**
     * Throws {@link AssertionError} labelled {@code label} unless {@code actual}
     * is within 1e-6 of {@code expected}. A NaN value counts as a mismatch.
     */
    private static void checkClose(double actual, double expected, String label) {
        if (!(Math.abs(actual - expected) <= 1.0E-6)) {
            throw new AssertionError(label);
        }
    }

    /**
     * Self-check of {@link #matrixEvolve} and {@link #matrixPullBack} against
     * {@code dArray2}, which holds the expected {e00, e01, e10, e11}.
     * Failures throw {@link AssertionError} labelled "1" to "8".
     */
    static void testMatrix(double[] dArray, double[] dArray2) {
        double[] row0 = matrixEvolve(dArray, 0);
        double[] row1 = matrixEvolve(dArray, 1);
        checkClose(row0[0], dArray2[0], "1");
        checkClose(row0[1], dArray2[1], "2");
        checkClose(row1[0], dArray2[2], "3");
        checkClose(row1[1], dArray2[3], "4");
        // Pulling back unit vectors yields the columns of the exponential.
        double[] column0 = new double[]{1.0, 0.0};
        double[] column1 = new double[]{0.0, 1.0};
        matrixPullBack(dArray, column0);
        matrixPullBack(dArray, column1);
        checkClose(column0[0], dArray2[0], "5");
        checkClose(column0[1], dArray2[2], "7");
        checkClose(column1[0], dArray2[1], "6");
        checkClose(column1[1], dArray2[3], "8");
    }

    /** Runs {@link #testMatrix} on one generic and two repeated-eigenvalue matrices. */
    public static void main(String[] stringArray) {
        testMatrix(new double[]{5.0, 3.0, 2.0, 3.0}, new double[]{0.0811818, 0.145616, 0.097077, 0.178259});
        System.out.println("First matrix OK");
        testMatrix(new double[]{1.0, 1.0, 0.0, 1.0}, new double[]{0.367879, 0.367879, 0.0, 0.367879});
        System.out.println("Second matrix OK");
        testMatrix(new double[]{1.0, 0.0, 1.0, 1.0}, new double[]{0.367879, 0.0, 0.367879, 0.367879});
        System.out.println("Third matrix OK");
    }
}

