/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.stuff;

import dr.evomodel.stuff.HDPPolyaUrn;
import dr.inference.model.CompoundParameter;
import dr.inference.model.Parameter;
import dr.inference.operators.GibbsOperator;
import dr.inference.operators.MCMCOperator;
import dr.inference.operators.SimpleMCMCOperator;
import dr.math.MathUtils;
import dr.xml.AbstractXMLObjectParser;
import dr.xml.AttributeRule;
import dr.xml.ElementRule;
import dr.xml.XMLObject;
import dr.xml.XMLObjectParser;
import dr.xml.XMLParseException;
import dr.xml.XMLSyntaxRule;
import java.math.BigInteger;
import java.util.ArrayList;

/**
 * Gibbs operator for an {@link HDPPolyaUrn} that resamples, for every (group, category) pair
 * holding at least one item, the number of draws that pair takes from the common base
 * distribution (stored in the HDP's table-count parameter).
 *
 * <p>For a pair holding {@code n} items, with common mass {@code alpha} and common-base weight
 * {@code beta} for the category, a new count {@code m} in {@code 1..n} is drawn with log-weight
 * <pre>
 *   m * log(alpha * beta)
 *     + pathWeight * ( log|s(n, m)| - sum_{j=0}^{n-1} log(alpha * beta + j) )
 * </pre>
 * where {@code |s(n, m)|} is the unsigned Stirling number of the first kind. With
 * {@code pathWeight == 1} this is the distribution of the number of occupied tables in a
 * Chinese restaurant process with concentration {@code alpha * beta} and {@code n} customers.
 */
public class HDPDrawsFromCommonBaseOperator extends SimpleMCMCOperator implements GibbsOperator {

    private static final String HDP_DRAWS_OPERATOR = "hdpDrawsFromCommonBaseOperator";

    /** Tolerance used when checking that the common base distribution weights sum to one. */
    public static final double ACCURACY_THRESHOLD = 1.0E-12;

    private final HDPPolyaUrn hdp;
    private final CompoundParameter tableCounts;
    private final CompoundParameter stickProportions;
    private final Parameter groupAssignments;
    private final Parameter categoriesParameter;

    /**
     * Memoised unsigned Stirling numbers of the first kind:
     * {@code cachedStirlingNumbers.get(n).get(k)} holds {@code |s(n, k)|}. Rows 0 and 1 are
     * seeded by the constructor; rows may have different lengths.
     */
    public final ArrayList<ArrayList<BigInteger>> cachedStirlingNumbers;

    /** Exponent applied to the Stirling and rising-factorial terms; set via {@link #setPathParameter}. */
    private double pathWeight = 1.0;

    public static XMLObjectParser PARSER = new AbstractXMLObjectParser() {
        private final XMLSyntaxRule[] rules = new XMLSyntaxRule[]{
                AttributeRule.newDoubleRule("weight"),
                new ElementRule(HDPPolyaUrn.class, false)
        };

        @Override
        public String getParserName() {
            return HDPDrawsFromCommonBaseOperator.HDP_DRAWS_OPERATOR;
        }

        @Override
        public Object parseXMLObject(XMLObject xo) throws XMLParseException {
            double weight = xo.getDoubleAttribute("weight");
            HDPPolyaUrn hdp = (HDPPolyaUrn) xo.getChild(HDPPolyaUrn.class);
            return new HDPDrawsFromCommonBaseOperator(hdp, weight);
        }

        @Override
        public String getParserDescription() {
            return "This element returns a Gibbs operator for sampling counts of draws from the common base distribution for HDPs";
        }

        @Override
        public Class getReturnType() {
            return MCMCOperator.class;
        }

        @Override
        public XMLSyntaxRule[] getSyntaxRules() {
            return this.rules;
        }
    };

    /**
     * Creates the operator for the given HDP.
     *
     * @param hdp    the HDP whose table counts are resampled; its parameters are captured here
     * @param weight operator weight passed to {@code setWeight}
     */
    public HDPDrawsFromCommonBaseOperator(HDPPolyaUrn hdp, double weight) {
        this.hdp = hdp;
        this.tableCounts = hdp.getTableCounts();
        this.stickProportions = hdp.getStickProportions();
        this.categoriesParameter = hdp.getCategoriesParameter();
        this.groupAssignments = hdp.getGroupAssignments();

        this.cachedStirlingNumbers = new ArrayList<ArrayList<BigInteger>>();
        // Seed rows: |s(0,0)| = 1, |s(0,1)| = 0, |s(1,0)| = 0, |s(1,1)| = 1.
        ArrayList<BigInteger> row0 = new ArrayList<BigInteger>();
        row0.add(BigInteger.ONE);
        row0.add(BigInteger.ZERO);
        this.cachedStirlingNumbers.add(row0);
        ArrayList<BigInteger> row1 = new ArrayList<BigInteger>();
        row1.add(BigInteger.ZERO);
        row1.add(BigInteger.ONE);
        this.cachedStirlingNumbers.add(row1);

        this.setWeight(weight);
    }

    /**
     * Resamples the table count of every (group, category) pair that currently holds items,
     * writing the result into the HDP's table-count parameter.
     *
     * @return always {@code 0.0}
     * @throws IllegalStateException if the common base weights do not sum to one, or an occupied
     *                               category has a common base weight outside {@code (0, 1]}
     */
    @Override
    public double doOperation() {
        double commonMass = this.hdp.getCommonMass().getParameterValue(0);
        int[][] counts = this.countItems();
        double[] commonBaseWeights = this.computeCommonBaseWeights(counts);

        for (int g = 0; g < this.hdp.maxGroupCount; ++g) {
            for (int c = 0; c < this.hdp.maxCategoryCount; ++c) {
                if (counts[g][c] <= 0) {
                    continue;
                }
                double beta = commonBaseWeights[c];
                if (beta <= 0.0 || beta > 1.0) {
                    throw new IllegalStateException("commonBaseDistWeight has inappropriate value");
                }
                int tables = this.sampleTableCount(counts[g][c], commonMass * beta);
                this.tableCounts.getParameter(g).setParameterValue(c, tables);
            }
        }
        return 0.0;
    }

    /** Tallies how many items are assigned to each (group, category) pair. */
    private int[][] countItems() {
        int[][] counts = new int[this.hdp.maxGroupCount][this.hdp.maxCategoryCount];
        for (int i = 0; i < this.categoriesParameter.getSize(); ++i) {
            int group = (int) this.groupAssignments.getParameterValue(i);
            int category = (int) this.categoriesParameter.getParameterValue(i);
            counts[group][category]++;
        }
        return counts;
    }

    /**
     * Builds the common base weights by stick-breaking over occupied categories only.
     * Unoccupied categories get weight 0, and the final slot holds the left-over stick.
     *
     * @param counts per-(group, category) item counts from {@link #countItems()}
     * @return array of length {@code maxCategoryCount + 1}
     * @throws IllegalStateException if the weights do not sum to one within {@link #ACCURACY_THRESHOLD}
     */
    private double[] computeCommonBaseWeights(int[][] counts) {
        int categoryCount = this.hdp.maxCategoryCount;
        double[] weights = new double[categoryCount + 1];
        double remaining = 1.0;
        double total = 0.0;
        for (int c = 0; c < categoryCount; ++c) {
            boolean occupied = false;
            for (int g = 0; g < counts.length && !occupied; ++g) {
                occupied = counts[g][c] > 0;
            }
            if (occupied) {
                double proportion = this.stickProportions.getParameter(c).getParameterValue(0);
                weights[c] = proportion * remaining;
                remaining = (1.0 - proportion) * remaining;
            }
            total += weights[c];
        }
        weights[categoryCount] = remaining;
        total += remaining;
        if (Math.abs(total - 1.0) > ACCURACY_THRESHOLD) {
            throw new IllegalStateException("common base dist weights must sum to 1");
        }
        return weights;
    }

    /**
     * Draws a new table count in {@code 1..customers} for one (group, category) pair.
     *
     * @param customers  number of items in the pair; expected to be positive
     * @param scaledMass common mass times the category's common base weight
     * @return the sampled count (at least 1)
     */
    private int sampleTableCount(int customers, double scaledMass) {
        double logScaledMass = Math.log(scaledMass);

        // log of the rising factorial scaledMass * (scaledMass + 1) * ... * (scaledMass + customers - 1).
        // It does not depend on m, so it is computed once rather than once per candidate count.
        double logRisingFactorial = 0.0;
        for (int j = 0; j < customers; ++j) {
            logRisingFactorial += Math.log(scaledMass + (double) j);
        }

        double[] logWeights = new double[customers];
        double maxLogWeight = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < customers; ++i) {
            int m = i + 1;
            BigInteger stirling = this.getStirlingNumber(customers, m);
            double logStirling = logOfBigInteger(stirling);
            // NOTE(review): pathWeight tempers the Stirling and rising-factorial terms but not
            // m * log(scaledMass); confirm this split is intended.
            logWeights[i] = (double) m * logScaledMass
                    - this.pathWeight * logRisingFactorial
                    + this.pathWeight * logStirling;
            if (Double.isNaN(logWeights[i]) || logWeights[i] == Double.POSITIVE_INFINITY) {
                System.out.println("logTableCountProbs[" + i + "]: " + logWeights[i]);
                System.out.println("logsn: " + logStirling);
                System.out.println("counts[g][c]: " + customers);
                System.out.println("m + 1: " + m);
                System.out.println("getStirlingNumber(counts[g][c], m + 1): " + stirling);
            }
            if (logWeights[i] > maxLogWeight) {
                maxLogWeight = logWeights[i];
            }
        }

        // Shift by the largest log-weight before exponentiating, so large counts cannot overflow
        // exp() to Infinity or underflow every weight to 0. This assumes MathUtils.randomChoicePDF
        // samples proportionally to its (unnormalised) input. Skip the shift if the maximum is infinite.
        double shift = Double.isInfinite(maxLogWeight) ? 0.0 : maxLogWeight;
        double[] weights = new double[customers];
        for (int i = 0; i < customers; ++i) {
            weights[i] = Math.exp(logWeights[i] - shift);
        }

        int tables = MathUtils.randomChoicePDF(weights) + 1;
        if (tables < 1) {
            throw new IllegalStateException("New table count is 0");
        }
        return tables;
    }

    /**
     * Returns the unsigned Stirling number of the first kind {@code |s(n, k)|}, memoised in
     * {@link #cachedStirlingNumbers}.
     *
     * <p>Values follow the recurrence {@code |s(r, j)| = |s(r-1, j-1)| + (r-1) * |s(r-1, j)|}.
     * The cache is filled iteratively, row by row, so there is no recursion depth proportional
     * to {@code n}. Every row is complete up to the needed column before the next row reads it.
     *
     * @param n  number of elements, must be non-negative
     * @param n2 number of cycles {@code k}, must be non-negative
     * @return {@code |s(n, k)|}, which is zero when {@code k > n}
     * @throws IllegalArgumentException if either index is negative
     */
    public BigInteger getStirlingNumber(int n, int n2) {
        if (n < 0 || n2 < 0) {
            throw new IllegalArgumentException("Stirling number indices must be non-negative: n=" + n + ", k=" + n2);
        }
        if (n2 > n) {
            return BigInteger.ZERO;
        }
        if (n < this.cachedStirlingNumbers.size()) {
            ArrayList<BigInteger> cachedRow = this.cachedStirlingNumbers.get(n);
            if (n2 < cachedRow.size()) {
                return cachedRow.get(n2);
            }
        }

        // Append missing rows; |s(r, 0)| = 0 for every r >= 1.
        for (int r = this.cachedStirlingNumbers.size(); r <= n; ++r) {
            ArrayList<BigInteger> newRow = new ArrayList<BigInteger>();
            newRow.add(BigInteger.ZERO);
            this.cachedStirlingNumbers.add(newRow);
        }

        // Row 0: |s(0, j)| = 0 for j >= 1.
        ArrayList<BigInteger> row0 = this.cachedStirlingNumbers.get(0);
        while (row0.size() <= n2) {
            row0.add(BigInteger.ZERO);
        }

        // Extend rows 1..n up to column n2. Row r-1 already has columns 0..n2 when row r is
        // processed, so prev.get(j) is always in range.
        for (int r = 1; r <= n; ++r) {
            ArrayList<BigInteger> prev = this.cachedStirlingNumbers.get(r - 1);
            ArrayList<BigInteger> current = this.cachedStirlingNumbers.get(r);
            BigInteger multiplier = BigInteger.valueOf(r - 1);
            for (int j = current.size(); j <= n2; ++j) {
                if (j > r) {
                    current.add(BigInteger.ZERO);
                } else {
                    current.add(prev.get(j - 1).add(multiplier.multiply(prev.get(j))));
                }
            }
        }
        return this.cachedStirlingNumbers.get(n).get(n2);
    }

    /**
     * Natural logarithm of a {@link BigInteger}, valid even when the value exceeds the
     * {@code double} range.
     *
     * <p>Values wider than 1022 bits are shifted right before conversion, and
     * {@code shift * log(2)} is added back. Zero gives {@code -Infinity}; negative values give NaN.
     *
     * @param bigInteger value whose logarithm is taken
     * @return {@code ln(bigInteger)}
     */
    public static double logOfBigInteger(BigInteger bigInteger) {
        int shift = bigInteger.bitLength() - 1022;
        if (shift > 0) {
            return Math.log(bigInteger.shiftRight(shift).doubleValue()) + (double) shift * Math.log(2.0);
        }
        return Math.log(bigInteger.doubleValue());
    }

    /**
     * Sets the exponent applied to the Stirling and rising-factorial terms of the table-count
     * log-weights.
     *
     * @param d path weight in {@code [0, 1]}
     * @throws IllegalArgumentException if {@code d} is outside {@code [0, 1]}
     */
    @Override
    public void setPathParameter(double d) {
        if (d < 0.0 || d > 1.0) {
            throw new IllegalArgumentException("Illegal path weight of " + d);
        }
        this.pathWeight = d;
    }

    @Override
    public String getOperatorName() {
        return HDP_DRAWS_OPERATOR;
    }
}

