/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.treedatalikelihood.discrete;

import dr.evolution.alignment.PatternList;
import dr.evolution.datatype.DataType;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeUtils;
import dr.evolution.util.TaxonList;
import dr.evomodel.branchratemodel.BranchRateModel;
import dr.evomodel.substmodel.SubstitutionModel;
import dr.evomodel.treelikelihood.AncestralStateBeagleTreeLikelihood;
import dr.inference.model.Statistic;
import dr.math.UnivariateFunction;
import dr.math.UnivariateMinimum;
import dr.math.matrixAlgebra.Vector;
import dr.xml.Reportable;
import java.util.Set;

/**
 * Statistic that, for every taxon of a {@link PatternList}, compares that taxon's observed sequence with the
 * sequence reconstructed at a focal node by an {@link AncestralStateBeagleTreeLikelihood}.
 *
 * <p>The focal node is the most recent common ancestor of the taxa given to the constructor, or the root of
 * the tree when no taxa are given. For each taxon, a single branch length separating the observed sequence
 * from the reconstructed one is optimized under the supplied {@link SubstitutionModel} over the interval
 * [{@value #MIN_BRANCH_LENGTH}, {@value #MAX_BRANCH_LENGTH}]. The {@link DistanceType} then selects the
 * reported quantity: the optimal branch length scaled by the branch rate, or the optimal log-likelihood.
 */
public class SequenceDistanceStatistic extends Statistic.Abstract implements Reportable {

    /** Lower bound of the branch-length search interval. */
    private static final double MIN_BRANCH_LENGTH = 0.0;

    /** Upper bound of the branch-length search interval. */
    private static final double MAX_BRANCH_LENGTH = 10.0;

    private final AncestralStateBeagleTreeLikelihood asrLikelihood;
    private final BranchRateModel branchRates;
    private final PatternList patternList;
    private final SubstitutionModel substitutionModel;
    /** Leaf set whose common ancestor is the focal node, or {@code null} to use the root. */
    private final Set<String> leafSet;
    private final Tree tree;
    private final DistanceType type;
    private final DataType dataType;

    /**
     * Creates the statistic.
     *
     * @param ancestralStateBeagleTreeLikelihood likelihood that supplies the tree and the reconstructed states
     * @param substitutionModel model supplying transition probabilities and equilibrium frequencies
     * @param branchRateModel rate model queried for the focal node's branch rate
     * @param patternList observed data; the statistic has one dimension per taxon in it
     * @param taxonList taxa whose common ancestor is the focal node, or {@code null} to use the root
     * @param distanceType which quantity to report
     * @throws TreeUtils.MissingTaxonException propagated from {@link TreeUtils#getLeavesForTaxa}
     */
    public SequenceDistanceStatistic(AncestralStateBeagleTreeLikelihood ancestralStateBeagleTreeLikelihood,
                                     SubstitutionModel substitutionModel,
                                     BranchRateModel branchRateModel,
                                     PatternList patternList,
                                     TaxonList taxonList,
                                     DistanceType distanceType) throws TreeUtils.MissingTaxonException {
        this.asrLikelihood = ancestralStateBeagleTreeLikelihood;
        this.substitutionModel = substitutionModel;
        this.branchRates = branchRateModel;
        this.patternList = patternList;
        this.dataType = patternList.getDataType();
        this.type = distanceType;
        this.tree = ancestralStateBeagleTreeLikelihood.getTreeModel();
        this.leafSet = taxonList != null ? TreeUtils.getLeavesForTaxa(this.tree, taxonList) : null;
    }

    /** Returns the number of taxa in the pattern list (one value per taxon). */
    @Override
    public int getDimension() {
        return patternList.getTaxonCount();
    }

    /** Returns {@code label(taxonId)} for the taxon at index {@code n}. */
    @Override
    public String getDimensionName(int n) {
        return type.getLabel() + "(" + patternList.getTaxonId(n) + ")";
    }

    // NOTE(review): returns the literal placeholder "name"; confirm whether a descriptive name was intended.
    @Override
    public String getStatisticName() {
        return "name";
    }

    /**
     * Computes the statistic for one taxon.
     *
     * @param n index of the taxon in the pattern list
     * @return the quantity selected by the configured {@link DistanceType}
     */
    @Override
    public double getStatisticValue(int n) {
        // Resolve the focal node once, so that the optimization and the rate lookup use the same node
        // and the common-ancestor search is not repeated.
        NodeRef node = getNode();
        UnivariateMinimum optimum = optimizeBranchLength(n, node);
        return type.extractResultForType(optimum, branchRates.getBranchRate(tree, node));
    }

    /** Lists every dimension name followed by the vector of all statistic values. */
    @Override
    public String getReport() {
        StringBuilder report = new StringBuilder("sequenceDistanceStatistic Report\n\n");
        report.append("dimension names: ");
        int dimension = getDimension();
        double[] values = new double[dimension];
        for (int i = 0; i < dimension; ++i) {
            report.append(getDimensionName(i));
            if (i != dimension - 1) {
                report.append(" ");
            }
            values[i] = getStatisticValue(i);
        }
        report.append("\n\n");
        report.append("values: ");
        report.append(new Vector(values));
        report.append("\n\n");
        return report.toString();
    }

    /**
     * Computes the summed log transition probability from {@code fromStates} to {@code toStates} over a branch.
     *
     * <p>The transition matrix is indexed as {@code from * stateCount + to}. A "from" code that is
     * {@code >= stateCount} (presumably an ambiguity/gap code) is marginalized: each possible starting state is
     * weighted by the model's equilibrium frequencies.
     *
     * @param branchLength branch length passed to the substitution model
     * @param fromStates per-position starting state codes (the observed sequence)
     * @param toStates per-position ending state codes (the reconstructed sequence)
     * @return the log-likelihood summed over all positions
     */
    private double computeLogLikelihood(double branchLength, int[] fromStates, int[] toStates) {
        int stateCount = dataType.getStateCount();
        double[] transitionProbabilities = new double[stateCount * stateCount];
        substitutionModel.getTransitionProbabilities(branchLength, transitionProbabilities);
        double[] frequencies = substitutionModel.getFrequencyModel().getFrequencies();

        double logLikelihood = 0.0;
        for (int i = 0; i < fromStates.length; ++i) {
            int from = fromStates[i];
            // NOTE(review): 'to' is assumed to be a concrete state (< stateCount). An ambiguity code here
            // would index the wrong matrix entry; confirm what getStatesForNode returns for tip nodes.
            int to = toStates[i];
            if (from < stateCount) {
                // Only the entry actually needed is log-transformed, rather than the whole matrix.
                logLikelihood += Math.log(transitionProbabilities[from * stateCount + to]);
            } else {
                double probability = 0.0;
                for (int j = 0; j < stateCount; ++j) {
                    probability += transitionProbabilities[j * stateCount + to] * frequencies[j];
                }
                logLikelihood += Math.log(probability);
            }
        }
        return logLikelihood;
    }

    /** Returns the common ancestor of {@link #leafSet}, or the root when no taxa were given. */
    private NodeRef getNode() {
        return leafSet != null ? TreeUtils.getCommonAncestorNode(tree, leafSet) : tree.getRoot();
    }

    /**
     * Finds the branch length that minimizes the negative log-likelihood between one taxon's observed states
     * and the states reconstructed at {@code node}.
     *
     * @param taxonIndex index of the taxon in the pattern list
     * @param node focal node whose reconstructed states are compared against
     * @return the minimizer after {@link UnivariateMinimum#findMinimum} has run
     */
    private UnivariateMinimum optimizeBranchLength(int taxonIndex, NodeRef node) {
        final int[] ancestralStates = asrLikelihood.getStatesForNode(tree, node);
        final int[] observedStates = new int[ancestralStates.length];
        for (int i = 0; i < ancestralStates.length; ++i) {
            observedStates[i] = patternList.getPatternState(taxonIndex, i);
        }
        UnivariateFunction negativeLogLikelihood = new UnivariateFunction() {

            @Override
            public double evaluate(double branchLength) {
                return -SequenceDistanceStatistic.this.computeLogLikelihood(
                        branchLength, observedStates, ancestralStates);
            }

            @Override
            public double getLowerBound() {
                return MIN_BRANCH_LENGTH;
            }

            @Override
            public double getUpperBound() {
                return MAX_BRANCH_LENGTH;
            }
        };
        UnivariateMinimum minimizer = new UnivariateMinimum();
        minimizer.findMinimum(negativeLogLikelihood);
        return minimizer;
    }

    /** Selects which quantity the statistic reports for each taxon. */
    public enum DistanceType {
        /** Optimal branch length ({@code minx}) divided by the focal node's branch rate. */
        MAXIMIZED_DISTANCE("distance", "distanceFrom") {
            @Override
            public double extractResultForType(UnivariateMinimum optimum, double branchRate) {
                return optimum.minx / branchRate;
            }
        },
        /** Optimal log-likelihood, i.e. the negated minimum ({@code fminx}) of the objective. */
        LOG_LIKELIHOOD("likelihood", "lnL") {
            @Override
            public double extractResultForType(UnivariateMinimum optimum, double branchRate) {
                return -optimum.fminx;
            }
        };

        private final String name;
        private final String label;

        private DistanceType(String name, String label) {
            this.name = name;
            this.label = label;
        }

        /** Returns this type's name string. */
        public String getName() {
            return name;
        }

        /** Returns the prefix used in dimension names. */
        public String getLabel() {
            return label;
        }

        /**
         * Converts the optimizer result into the reported value.
         *
         * @param optimum minimizer after the branch-length optimization has run
         * @param branchRate branch rate of the focal node from the branch rate model
         * @return the value reported for one taxon
         */
        public abstract double extractResultForType(UnivariateMinimum optimum, double branchRate);
    }
}

