/*
 * Decompiled with CFR 0.152.
 */
package dr.oldevomodel.MSSD;

import dr.evolution.alignment.AscertainedSitePatterns;
import dr.evolution.alignment.PatternList;
import dr.evolution.datatype.MutationDeathType;
import dr.evolution.tree.NodeRef;
import dr.evomodel.branchratemodel.BranchRateModel;
import dr.evomodel.branchratemodel.DefaultBranchRateModel;
import dr.evomodel.tree.TreeChangedEvent;
import dr.evomodel.tree.TreeModel;
import dr.evomodel.treelikelihood.LikelihoodPartialsProvider;
import dr.inference.model.AbstractModel;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.math.GammaFunction;
import dr.oldevomodel.sitemodel.SiteRateModel;
import dr.oldevomodel.treelikelihood.LikelihoodCore;
import dr.oldevomodel.treelikelihood.ScaleFactorsHelper;

/**
 * Base class for observation processes over {@link MutationDeathType} site patterns.
 * <p>
 * For each pattern it sums, over all tree nodes flagged in {@link #nodePatternInclusion}, the
 * node's site likelihood weighted by the node's survival probability. It then divides by an
 * ascertainment correction (only when the patterns are {@link AscertainedSitePatterns}) and adds
 * the tree-weight and rate terms. Subclasses supply {@link #calculateLogTreeWeight()} and
 * {@link #setNodePatternInclusion()}.
 */
public abstract class AbstractObservationProcess
extends AbstractModel {
    /**
     * Per-(node, pattern) inclusion flags, indexed as {@code node * patternCount + pattern}.
     * Not allocated in this class; presumably filled by {@link #setNodePatternInclusion()}.
     */
    protected boolean[] nodePatternInclusion;
    /** Snapshot of {@link #nodePatternInclusion} taken by {@link #storeState()}. */
    protected boolean[] storedNodePatternInclusion;
    /** Per-pattern likelihood summed over all included nodes (one entry per pattern). */
    protected double[] cumLike;
    /** Scratch buffer for one node's partials ({@code patternCount * stateCount} entries). */
    protected double[] nodePartials;
    /** One entry per pattern; not read in this class (presumably used by subclasses). */
    protected double[] nodeLikelihoods;
    protected int nodeCount;
    protected int patternCount;
    protected int stateCount;
    protected TreeModel treeModel;
    protected PatternList patterns;
    protected double[] patternWeights;
    /** Rate parameter {@code mu}; used in node survival probabilities and the final rate term. */
    protected Parameter mu;
    /** Rate parameter {@code lam}; used in the final rate term of the likelihood. */
    protected Parameter lam;
    /** Whether {@link #logTreeWeight} is valid for the current state. */
    protected boolean weightKnown;
    protected double logTreeWeight;
    protected double storedLogTreeWeight;
    /** {@code -lnGamma(totalPatterns + 1)}. */
    private double gammaNorm;
    /** Sum of all pattern weights. */
    private double totalPatterns;
    protected MutationDeathType dataType;
    protected int deathState;
    protected SiteRateModel siteModel;
    /** {@code log(totalPatterns)}. */
    private double logN;
    protected boolean nodePatternInclusionKnown = false;
    BranchRateModel branchRateModel;
    private boolean integrateGainRate = false;
    private double averageRate;
    private boolean averageRateKnown = false;
    /** Value of {@link #weightKnown} captured by {@link #storeState()}. */
    private boolean storedWeightKnown = false;

    /**
     * Creates the observation process and registers the sub-models and parameters it listens to.
     * <p>
     * NOTE: this constructor calls the abstract {@link #setNodePatternInclusion()}. That happens
     * before any subclass constructor body or subclass field initialiser has run, so an
     * implementation must not rely on its own instance fields at that point.
     *
     * @param name            model id forwarded to {@link AbstractModel}
     * @param treeModel       tree supplying nodes and branch lengths
     * @param patternList     site patterns; the data type must be a {@link MutationDeathType}
     * @param siteRateModel   site-rate model used for {@link #getAverageRate()}
     * @param branchRateModel branch rate model, or {@code null} for a {@link DefaultBranchRateModel}
     * @param mu              rate parameter {@code mu}
     * @param lam             rate parameter {@code lam}
     * @throws ClassCastException if the pattern data type is not a {@link MutationDeathType}
     */
    public AbstractObservationProcess(String name, TreeModel treeModel, PatternList patternList, SiteRateModel siteRateModel, BranchRateModel branchRateModel, Parameter mu, Parameter lam) {
        super(name);
        this.treeModel = treeModel;
        this.patterns = patternList;
        this.mu = mu;
        this.lam = lam;
        this.siteModel = siteRateModel;
        this.branchRateModel = branchRateModel != null ? branchRateModel : new DefaultBranchRateModel();
        this.addModel(treeModel);
        this.addModel(siteRateModel);
        this.addModel(this.branchRateModel);
        this.addVariable(mu);
        this.addVariable(lam);
        this.nodeCount = treeModel.getNodeCount();
        this.stateCount = patternList.getDataType().getStateCount();
        this.patternCount = patternList.getPatternCount();
        this.patternWeights = patternList.getPatternWeights();
        this.totalPatterns = 0.0;
        for (int i = 0; i < this.patternCount; ++i) {
            this.totalPatterns += this.patternWeights[i];
        }
        this.logN = Math.log(this.totalPatterns);
        this.gammaNorm = -GammaFunction.lnGamma(this.totalPatterns + 1.0);
        this.dataType = (MutationDeathType) patternList.getDataType();
        this.deathState = this.dataType.DEATHSTATE;
        this.setNodePatternInclusion();
        this.cumLike = new double[this.patternCount];
        this.nodeLikelihoods = new double[this.patternCount];
        this.weightKnown = false;
    }

    /**
     * Returns {@code log(sum_s freqs[s] * partials[pattern * stateCount + s])}.
     */
    private double calculateSiteLogLikelihood(int pattern, double[] partials, double[] freqs) {
        final int offset = pattern * this.stateCount;
        double sum = 0.0;
        for (int state = 0; state < this.stateCount; ++state) {
            sum += freqs[state] * partials[offset + state];
        }
        return Math.log(sum);
    }

    /**
     * Adds one node's contribution to {@code patternLikelihoods} for every pattern flagged for
     * that node in {@link #nodePatternInclusion}.
     * <p>
     * The node's partials are loaded into {@link #nodePartials} and rescaled first. Each
     * contribution is {@code exp(siteLogLikelihood + log(survivalProbability))}.
     */
    private void accumulateNodePatternLikelihood(int node, double[] freqs, LikelihoodPartialsProvider partialsProvider, ScaleFactorsHelper scaleFactorsHelper, double rate, double[] patternLikelihoods) {
        partialsProvider.getPartials(node, this.nodePartials);
        scaleFactorsHelper.rescalePartials(node, this.nodePartials);
        final double logSurvival = Math.log(this.getNodeSurvivalProbability(node, rate));
        final int offset = node * this.patternCount;
        for (int pattern = 0; pattern < this.patternCount; ++pattern) {
            if (this.nodePatternInclusion[offset + pattern]) {
                patternLikelihoods[pattern] += Math.exp(this.calculateSiteLogLikelihood(pattern, this.nodePartials, freqs) + logSurvival);
            }
        }
    }

    /**
     * Adds the weighted, ascertainment-corrected log pattern likelihoods onto {@code logL}.
     * The running total starts from {@code logL}, which keeps the same summation order as
     * accumulating directly into it.
     *
     * @return {@code logL + sum_p log(patternLikelihoods[p] / correction) * patternWeights[p]}
     */
    private double accumulateCorrectedLikelihoods(double logL, double[] patternLikelihoods, double correction) {
        for (int pattern = 0; pattern < this.patternCount; ++pattern) {
            logL += Math.log(patternLikelihoods[pattern] / correction) * this.patternWeights[pattern];
        }
        return logL;
    }

    /**
     * Computes the log likelihood of the observed patterns under this observation process.
     *
     * @param freqs              state frequencies, one per state
     * @param partialsProvider   source of per-node partials
     * @param scaleFactorsHelper used to rescale each node's partials after retrieval
     * @return the log likelihood, including the {@code gammaNorm}, tree-weight and rate terms
     */
    public final double nodePatternLikelihood(double[] freqs, LikelihoodPartialsProvider partialsProvider, ScaleFactorsHelper scaleFactorsHelper) {
        final double lamValue = this.lam.getParameterValue(0);
        if (!this.nodePatternInclusionKnown) {
            this.setNodePatternInclusion();
        }
        if (this.nodePartials == null) {
            this.nodePartials = new double[this.patternCount * this.stateCount];
        }
        final double rate = this.getAverageRate();
        for (int pattern = 0; pattern < this.patternCount; ++pattern) {
            this.cumLike[pattern] = 0.0;
        }
        for (int node = 0; node < this.nodeCount; ++node) {
            this.accumulateNodePatternLikelihood(node, freqs, partialsProvider, scaleFactorsHelper, rate, this.cumLike);
        }
        final double correction = this.getAscertainmentCorrection(this.cumLike);
        double logL = this.accumulateCorrectedLikelihoods(this.gammaNorm, this.cumLike, correction);

        final double muValue = this.mu.getParameterValue(0);
        final double logWeight = this.getLogTreeWeight();
        if (this.integrateGainRate) {
            logL -= this.gammaNorm + this.logN + Math.log(-logWeight * muValue / lamValue) * this.totalPatterns;
        } else {
            logL += logWeight + Math.log(lamValue / muValue) * this.totalPatterns;
        }
        return logL;
    }

    /**
     * Returns the normalising constant for ascertained patterns.
     * <p>
     * If {@link #patterns} is not an {@link AscertainedSitePatterns}, the constant is 1.0.
     * Otherwise, with {@code include}/{@code exclude} being the summed probabilities of the
     * included/excluded patterns, it is:
     * <ul>
     *   <li>{@code 1 - exclude} if {@code include == 0};</li>
     *   <li>{@code include} if {@code exclude == 0};</li>
     *   <li>{@code include - exclude} otherwise.</li>
     * </ul>
     *
     * @param patternProbs per-pattern probabilities, indexed by pattern
     */
    protected double getAscertainmentCorrection(double[] patternProbs) {
        if (!(this.patterns instanceof AscertainedSitePatterns)) {
            return 1.0;
        }
        final AscertainedSitePatterns ascertained = (AscertainedSitePatterns) this.patterns;
        final int[] includeIndices = ascertained.getIncludePatternIndices();
        final int[] excludeIndices = ascertained.getExcludePatternIndices();
        double includeProb = 0.0;
        for (int i = 0; i < ascertained.getIncludePatternCount(); ++i) {
            includeProb += patternProbs[includeIndices[i]];
        }
        double excludeProb = 0.0;
        for (int i = 0; i < ascertained.getExcludePatternCount(); ++i) {
            excludeProb += patternProbs[excludeIndices[i]];
        }
        if (includeProb == 0.0) {
            return 1.0 - excludeProb;
        }
        if (excludeProb == 0.0) {
            return includeProb;
        }
        return includeProb - excludeProb;
    }

    /**
     * Returns the log tree weight, recomputing it via {@link #calculateLogTreeWeight()} only
     * when it has been invalidated.
     */
    public final double getLogTreeWeight() {
        if (!this.weightKnown) {
            this.logTreeWeight = this.calculateLogTreeWeight();
            this.weightKnown = true;
        }
        return this.logTreeWeight;
    }

    /** Computes the log tree weight term for the current state. */
    public abstract double calculateLogTreeWeight();

    /**
     * (Re)computes {@link #nodePatternInclusion}. Called from the constructor and whenever
     * {@link #nodePatternInclusionKnown} is false at likelihood time.
     */
    abstract void setNodePatternInclusion();

    /**
     * Returns the site model's mean rate, {@code sum_c proportion[c] * rate[c]}, cached until the
     * site model changes or the state is restored.
     */
    public final double getAverageRate() {
        if (!this.averageRateKnown) {
            final double[] proportions = this.siteModel.getCategoryProportions();
            double mean = 0.0;
            for (int category = 0; category < this.siteModel.getCategoryCount(); ++category) {
                mean += proportions[category] * this.siteModel.getRateForCategory(category);
            }
            this.averageRate = mean;
            this.averageRateKnown = true;
        }
        return this.averageRate;
    }

    /**
     * Returns {@code 1 - exp(-mu * rate * branchRate * branchLength)} for the branch above
     * node {@code nodeIndex}, or 1.0 for a node without a parent.
     *
     * @param nodeIndex index of the node in {@link #treeModel}
     * @param rate      multiplier applied to {@code mu} (callers here pass the average site rate)
     */
    public double getNodeSurvivalProbability(int nodeIndex, double rate) {
        final NodeRef node = this.treeModel.getNode(nodeIndex);
        if (this.treeModel.getParent(node) == null) {
            return 1.0;
        }
        final double scaledMu = this.mu.getParameterValue(0) * rate;
        final double branchRate = this.branchRateModel.getBranchRate(this.treeModel, node);
        final double branchTime = branchRate * this.treeModel.getBranchLength(node);
        return 1.0 - Math.exp(-scaledMu * branchTime);
    }

    /**
     * Invalidates cached quantities.
     * <ul>
     *   <li>A site-model change invalidates the average rate.</li>
     *   <li>Any change to the tree, site, or branch-rate model invalidates the tree weight.</li>
     *   <li>A tree-topology/structure change ({@link TreeChangedEvent#isTreeChanged()})
     *       invalidates the node/pattern inclusion flags.</li>
     * </ul>
     */
    @Override
    protected void handleModelChangedEvent(Model model, Object object, int index) {
        if (model == this.siteModel) {
            this.averageRateKnown = false;
        }
        if (model == this.treeModel || model == this.siteModel || model == this.branchRateModel) {
            this.weightKnown = false;
        }
        if (model == this.treeModel && object instanceof TreeChangedEvent && ((TreeChangedEvent) object).isTreeChanged()) {
            this.nodePatternInclusionKnown = false;
        }
    }

    /**
     * Invalidates the tree weight when {@code mu} or {@code lam} changes. Logs a warning to
     * standard error for any other variable.
     */
    @Override
    protected final void handleVariableChangedEvent(Variable variable, int index, Variable.ChangeType changeType) {
        if (variable == this.mu || variable == this.lam) {
            this.weightKnown = false;
        } else {
            System.err.println("AbstractObservationProcess: Got unexpected parameter changed event. (Parameter = " + variable + ")");
        }
    }

    /**
     * Snapshots the tree weight, its validity flag, and the node/pattern inclusion flags.
     * The stored inclusion buffer is (re)allocated if it is missing or has the wrong length.
     */
    @Override
    protected void storeState() {
        this.storedLogTreeWeight = this.logTreeWeight;
        // Keep the validity flag with the value: a weight stored while invalid must not be
        // treated as valid after a restore.
        this.storedWeightKnown = this.weightKnown;
        if (this.nodePatternInclusion != null) {
            if (this.storedNodePatternInclusion == null || this.storedNodePatternInclusion.length != this.nodePatternInclusion.length) {
                this.storedNodePatternInclusion = new boolean[this.nodePatternInclusion.length];
            }
            System.arraycopy(this.nodePatternInclusion, 0, this.storedNodePatternInclusion, 0, this.nodePatternInclusion.length);
        }
    }

    /**
     * Restores the state captured by {@link #storeState()}. The inclusion arrays are swapped
     * rather than copied, and the average rate is marked for recomputation.
     */
    @Override
    protected void restoreState() {
        this.averageRateKnown = false;
        this.logTreeWeight = this.storedLogTreeWeight;
        this.weightKnown = this.storedWeightKnown;
        final boolean[] swap = this.storedNodePatternInclusion;
        this.storedNodePatternInclusion = this.nodePatternInclusion;
        this.nodePatternInclusion = swap;
    }

    /** Nothing to do on accept; the live state already is the accepted state. */
    @Override
    protected void acceptState() {
    }

    /**
     * Selects which final rate term {@link #nodePatternLikelihood} adds.
     *
     * @param integrate {@code true} to use the integrated-gain-rate form of the term
     */
    public void setIntegrateGainRate(boolean integrate) {
        this.integrateGainRate = integrate;
    }
}

