/*
 * Decompiled with CFR 0.152.
 */
package dr.inference.hmc;

import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.hmc.HessianWrtParameterProvider;
import dr.inference.model.CompoundParameter;
import dr.inference.model.Likelihood;
import dr.inference.model.Parameter;
import dr.xml.Reportable;
import java.util.ArrayList;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;

/**
 * Wraps a gradient provider whose parameter is a {@link CompoundParameter} and exposes the
 * gradient with respect to a "compact" parameter in which every distinct leaf parameter
 * (compared by identity) appears exactly once.
 *
 * <p>The source compound parameter is recursively unrolled into its leaf parameters. When the
 * same leaf object occurs more than once, all of its occurrences are routed onto the slots of
 * its first occurrence, and their gradient entries are summed.
 */
public class CompactGradient implements HessianWrtParameterProvider, Reportable {
    private final GradientWrtParameterProvider source;
    private final Parameter sourceParameter;
    private final Likelihood likelihood;
    private final Parameter parameter;
    /** For each source-gradient index, the index in the compact gradient it is added into. */
    private final int[] map;
    private final int dimension;

    /**
     * Builds the compact view of {@code gradientWrtParameterProvider}.
     *
     * @param gradientWrtParameterProvider provider whose parameter must be a CompoundParameter
     * @throws IllegalArgumentException if the provider's parameter is not a CompoundParameter
     */
    public CompactGradient(GradientWrtParameterProvider gradientWrtParameterProvider) {
        this.source = gradientWrtParameterProvider;
        this.sourceParameter = gradientWrtParameterProvider.getParameter();
        this.likelihood = gradientWrtParameterProvider.getLikelihood();
        ParameterMap parameterMap = this.constructParameter(this.sourceParameter);
        this.parameter = parameterMap.parameter;
        this.map = parameterMap.map;
        this.dimension = this.parameter.getDimension();
    }

    @Override
    public Likelihood getLikelihood() {
        return this.likelihood;
    }

    /** Returns the compact parameter, containing each distinct leaf parameter once. */
    @Override
    public Parameter getParameter() {
        return this.parameter;
    }

    @Override
    public int getDimension() {
        return this.dimension;
    }

    /**
     * Returns the source gradient collapsed onto the compact parameter; entries belonging to
     * repeated occurrences of the same leaf parameter are summed.
     */
    @Override
    public double[] getGradientLogDensity() {
        return this.compact(this.source.getGradientLogDensity());
    }

    /**
     * Returns the source diagonal Hessian collapsed onto the compact parameter by summation.
     *
     * <p>NOTE(review): for a leaf that occurs more than once, summing diagonal entries omits the
     * cross-derivative terms between its occurrences, so this is not the exact diagonal of the
     * compact Hessian in that case — confirm this approximation is acceptable.
     *
     * @throws UnsupportedOperationException if the source is not a HessianWrtParameterProvider
     */
    @Override
    public double[] getDiagonalHessianLogDensity() {
        if (!(this.source instanceof HessianWrtParameterProvider)) {
            throw new UnsupportedOperationException("Must use Hessian providers");
        }
        double[] sourceDiagonal = ((HessianWrtParameterProvider) this.source).getDiagonalHessianLogDensity();
        return this.compact(sourceDiagonal);
    }

    /** @throws UnsupportedOperationException always; the full Hessian is not supported. */
    @Override
    public double[][] getHessianLogDensity() {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    @Override
    public String getReport() {
        return "compactGradient." + this.sourceParameter.getParameterName() + "\n"
                + GradientWrtParameterProvider.getReportAndCheckForError(this,
                        Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY,
                        GradientWrtParameterProvider.TOLERANCE);
    }

    /**
     * Accumulates a source-indexed vector into a compact-indexed vector using {@link #map}.
     *
     * @param sourceValues vector indexed like the source parameter (length at least map.length)
     * @return new vector of length {@link #dimension}
     */
    private double[] compact(double[] sourceValues) {
        double[] compacted = new double[this.dimension];
        for (int i = 0; i < this.map.length; ++i) {
            compacted[this.map[i]] += sourceValues[i];
        }
        return compacted;
    }

    /**
     * Records that the entries of {@code parameter}, starting at {@code sourceOffset} in the source
     * layout, map onto consecutive compact indices starting at {@code compactOffset}. When
     * {@code compoundParameter} is non-null, {@code parameter} is also appended to it.
     */
    private void map(Parameter parameter, int sourceOffset, int compactOffset,
                     CompoundParameter compoundParameter, int[] indexMap) {
        if (compoundParameter != null) {
            compoundParameter.addParameter(parameter);
        }
        for (int i = 0; i < parameter.getDimension(); ++i) {
            indexMap[sourceOffset + i] = compactOffset + i;
        }
    }

    /** Recursively flattens nested CompoundParameters into their leaf parameters, in order. */
    private List<Parameter> unroll(CompoundParameter compoundParameter) {
        List<Parameter> leaves = new ArrayList<Parameter>();
        for (int i = 0; i < compoundParameter.getParameterCount(); ++i) {
            Parameter child = compoundParameter.getParameter(i);
            if (child instanceof CompoundParameter) {
                leaves.addAll(this.unroll((CompoundParameter) child));
            } else {
                leaves.add(child);
            }
        }
        return leaves;
    }

    /**
     * Builds the compact parameter and the source-to-compact index map.
     *
     * <p>The source layout is taken to be the concatenation of the unrolled leaves. The compact
     * layout is the concatenation of the distinct leaves in first-occurrence order, so each
     * leaf's compact offset is the total dimension of distinct leaves seen before it (which can
     * differ from its source offset once any leaf has repeated).
     *
     * @throws IllegalArgumentException if {@code parameter} is not a CompoundParameter
     */
    private ParameterMap constructParameter(Parameter parameter) {
        if (!(parameter instanceof CompoundParameter)) {
            throw new IllegalArgumentException("Can only compact compound gradients");
        }
        ParameterMap parameterMap = new ParameterMap(this.sourceParameter.getDimension());
        // Identity semantics: two leaves are "the same" only if they are the same object.
        Map<Parameter, Integer> compactOffsets = new IdentityHashMap<Parameter, Integer>();
        int sourceOffset = 0;
        int compactDimension = 0;
        for (Parameter leaf : this.unroll((CompoundParameter) parameter)) {
            Integer firstCompactOffset = compactOffsets.get(leaf);
            if (firstCompactOffset == null) {
                // First occurrence: append to the compact parameter at the next free compact slot.
                compactOffsets.put(leaf, compactDimension);
                this.map(leaf, sourceOffset, compactDimension, parameterMap.parameter, parameterMap.map);
                compactDimension += leaf.getDimension();
            } else {
                // Repeat: route this occurrence onto the compact slots of the first occurrence.
                this.map(leaf, sourceOffset, firstCompactOffset, null, parameterMap.map);
            }
            sourceOffset += leaf.getDimension();
        }
        return parameterMap;
    }

    /** Result of {@link #constructParameter}: the compact parameter plus the index map. */
    private static class ParameterMap {
        final CompoundParameter parameter = new CompoundParameter("compact");
        final int[] map;

        ParameterMap(int sourceDimension) {
            this.map = new int[sourceDimension];
        }
    }
}

