/*
 * Decompiled with CFR 0.152.
 */
package jdplus.toolkit.base.core.bayes;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import jdplus.toolkit.base.api.data.DoubleSeq;
import jdplus.toolkit.base.api.data.DoubleSeqCursor;
import jdplus.toolkit.base.api.dstats.RandomNumberGenerator;
import jdplus.toolkit.base.api.math.matrices.Matrix;
import jdplus.toolkit.base.core.bayes.BayesRegularizedRegressionModel;
import jdplus.toolkit.base.core.data.DataBlock;
import jdplus.toolkit.base.core.data.DataBlockIterator;
import jdplus.toolkit.base.core.dstats.Exponential;
import jdplus.toolkit.base.core.dstats.Gamma;
import jdplus.toolkit.base.core.dstats.InverseGamma;
import jdplus.toolkit.base.core.dstats.InverseGaussian;
import jdplus.toolkit.base.core.dstats.Normal;
import jdplus.toolkit.base.core.dstats.SpecialFunctions;
import jdplus.toolkit.base.core.math.matrices.FastMatrix;
import jdplus.toolkit.base.core.math.matrices.LowerTriangularMatrix;
import jdplus.toolkit.base.core.math.matrices.SymmetricMatrix;
import jdplus.toolkit.base.core.random.MersenneTwister;
import jdplus.toolkit.base.core.stats.samples.Moments;
import lombok.Generated;
import org.jspecify.annotations.NonNull;
import org.jspecify.annotations.Nullable;

/**
 * Gibbs-style sampler for a linear regression with a shrinkage prior on the
 * coefficients.
 *
 * <p>The constructor does all the work: it standardizes the columns of the
 * regressor matrix, initializes the sampler state, runs {@code nsamples}
 * sampling iterations and records one {@link Result} for every iteration whose
 * index is {@code >= burnin} (so at most {@code nsamples - burnin} draws are
 * recorded).
 *
 * <p>NOTE(review): {@link #results} is a {@code public static} list, so draws
 * produced by every instance are appended to the same list and {@link #results()}
 * returns all of them. This also makes concurrent construction unsafe. It is
 * kept as is because the field is part of the public interface.
 */
public class BayesRegularizedRegression {
    /** Copy of the dependent variable. */
    private final DataBlock y;
    /** Regressors; columns are standardized in place by {@link #standardize()}. */
    private final FastMatrix X;
    private final BayesRegularizedRegressionModel.ModelType model;
    private final BayesRegularizedRegressionModel.Prior prior;
    private final int burnin;
    private final int nsamples;
    /** Degrees of freedom, only read by the {@code T} model branches. */
    private final int tdf;
    /** Column means removed by {@link #standardize()}. */
    private double[] xm;
    /** Column scales applied by {@link #standardize()}. */
    private double[] xstd;
    /** Number of observations (length of {@code y}). */
    private int n;
    /** Number of regressors (columns of {@code X}). */
    private int p;
    private double ydiff;
    /** Current intercept draw. */
    private double b0;
    /** Current coefficient draw (on the standardized scale). */
    private DataBlock b;
    /** Per-observation variance weights; stays at 1 for the Gaussian model. */
    private DataBlock omega2;
    private double sigma2;
    private double muSigma2;
    private double tau2;
    private double xi;
    private DataBlock lambda2;
    private DataBlock nu;
    private DataBlock eta2;
    private DataBlock phi;
    private DataBlock kappa;
    private DataBlock z;
    /** Residuals computed by {@link #sampleBeta0()} / {@link #errLinearRegression()}. */
    private DataBlock e;
    private DataBlock waicProb;
    private DataBlock waicLProb;
    private DataBlock waicLProb2;
    /** True when the p x p linear system is used to draw the coefficients. */
    private boolean mvnrue;
    /** True when {@code XtX} and {@code Xty} were computed once in {@link #initialize()}. */
    private boolean precomputedXtX;
    private FastMatrix XtX;
    private DataBlock Xty;
    public static final List<Result> results = new ArrayList<>();
    private RandomNumberGenerator rng = MersenneTwister.fromSystemNanoTime();

    /**
     * Builds the sampler and immediately runs it.
     *
     * @param y dependent variable
     * @param X regressors (one column per regressor)
     * @param model error model
     * @param tdf degrees of freedom used by the {@code T} model
     * @param prior shrinkage prior
     * @param burnin number of leading iterations whose draws are discarded
     * @param nsamples total number of sampling iterations (burn-in included)
     */
    public BayesRegularizedRegression(DoubleSeq y, Matrix X, BayesRegularizedRegressionModel.ModelType model, int tdf, BayesRegularizedRegressionModel.Prior prior, int burnin, int nsamples) {
        this.y = DataBlock.of(y);
        this.X = FastMatrix.of(X);
        this.model = model;
        this.tdf = tdf;
        this.prior = prior;
        this.burnin = burnin;
        this.nsamples = nsamples;
        this.n = y.length();
        this.p = X.getColumnsCount();
        this.standardize();
        this.initialize();
        DoubleSeq scales = DoubleSeq.of((double[])this.xstd);
        for (int k = 0; k < nsamples; ++k) {
            this.samplingIteration();
            if (k < burnin) {
                continue;
            }
            // b is overwritten in place at every iteration; materialize the rescaled
            // coefficients so that a recorded draw can never alias the live state
            // (fn may return a view rather than a copy).
            double[] draw = this.b.fn(scales, (x, q) -> x / q).toArray();
            // NOTE(review): b0 is stored as sampled; no adjustment using the column
            // means xm is applied here - confirm this is intended.
            results.add(new Result(DoubleSeq.of(draw), this.b0, this.tau2));
        }
    }

    /**
     * Returns a read-only view of the recorded draws.
     *
     * @return unmodifiable view of the (static, shared) list of draws
     */
    public List<Result> results() {
        return Collections.unmodifiableList(results);
    }

    /**
     * Centers every column of {@code X} on its mean and divides it by
     * {@code sqrt(variance * n)}, keeping the means in {@code xm} and the scales
     * in {@code xstd}. A constant column yields a zero scale and therefore
     * non-finite values.
     */
    private void standardize() {
        this.xm = new double[this.p];
        this.xstd = new double[this.p];
        int pos = 0;
        DataBlockIterator cols = this.X.columnsIterator();
        while (cols.hasNext()) {
            DataBlock col = cols.next();
            double mean = Moments.mean((DoubleSeq)col);
            double std = Math.sqrt(Moments.variance((DoubleSeq)col, mean, false) * (double)this.n);
            col.apply(a -> (a - mean) / std);
            this.xm[pos] = mean;
            this.xstd[pos++] = std;
        }
    }

    /**
     * Sets the starting values of the sampler state and, for the Gaussian model
     * with {@code p < 2n}, precomputes {@code XtX} and {@code Xty}.
     */
    private void initialize() {
        this.ydiff = 0.0;
        this.b0 = 0.0;
        this.b = DataBlock.make(this.p);
        this.omega2 = DataBlock.make(this.n);
        this.omega2.set(1.0);
        this.sigma2 = 1.0;
        this.tau2 = 1.0;
        this.xi = 0.001;
        this.lambda2 = DataBlock.make(this.p);
        this.lambda2.set(1.0);
        this.nu = DataBlock.make(this.p);
        this.nu.set(1.0);
        this.eta2 = DataBlock.make(this.p);
        this.eta2.set(1.0);
        this.phi = DataBlock.make(this.p);
        this.phi.set(1.0);
        this.kappa = this.y.deepClone();
        this.kappa.sub(0.5);
        this.z = this.y.deepClone();
        this.waicProb = DataBlock.make(this.n);
        this.waicLProb = DataBlock.make(this.n);
        this.waicLProb2 = DataBlock.make(this.n);
        this.mvnrue = true;
        this.precomputedXtX = false;
        if (this.p >= 2 * this.n) {
            this.mvnrue = false;
        }
        this.XtX = null;
        this.Xty = null;
        if (this.model == BayesRegularizedRegressionModel.ModelType.GAUSSIAN && this.mvnrue) {
            this.XtX = SymmetricMatrix.XtX(this.X);
            this.Xty = DataBlock.make(this.p);
            this.Xty.product(this.y, this.X.columnsIterator());
            this.precomputedXtX = true;
        }
    }

    /** Runs one full sweep over all the sampled quantities, in a fixed order. */
    private void samplingIteration() {
        this.sampleBeta();
        this.sampleBeta0();
        this.sampleSigma2();
        this.sampleOmega2();
        this.sampleTau2();
        this.sampleLambda2();
        this.sampleDelta2();
    }

    /**
     * Draws the coefficients {@code b}.
     *
     * <p>NOTE(review): nothing is sampled when {@code mvnrue} is false
     * ({@code p >= 2n}); {@code b} then keeps its previous value.
     */
    private void sampleBeta() {
        DoubleSeq alpha = this.z.fastOp(q -> q - this.b0);
        double d = this.sigma2 * this.tau2;
        DoubleSeq Lambda = this.lambda2.fastOp(q -> q * d);
        double sigma = Math.sqrt(this.sigma2);
        if (this.mvnrue) {
            if (!this.precomputedXtX) {
                DoubleSeq omega = this.omega2.fn(q -> Math.sqrt(q) * sigma);
                FastMatrix X0 = this.X.deepClone();
                // omega holds one weight per observation (length n), so the scaling
                // must be applied row by row: row i of X0 is divided by omega[i].
                DoubleSeqCursor ocur = omega.cursor();
                DataBlockIterator rows = X0.rowsIterator();
                while (rows.hasNext()) {
                    rows.next().div(ocur.getAndNext());
                }
                this.rue_nongaussian(X0, alpha, Lambda, this.sigma2, omega);
            } else {
                this.rue_gaussian(this.X, alpha, Lambda, this.XtX, this.Xty, this.sigma2);
            }
        }
    }

    /**
     * Recomputes the residuals {@code e = y - X b}, draws the intercept
     * {@code b0} from a normal distribution centered on the
     * {@code 1/omega2}-weighted mean of these residuals, and finally removes
     * {@code b0} from {@code e}.
     */
    private void sampleBeta0() {
        DataBlockIterator cols = this.X.columnsIterator();
        DoubleSeqCursor.OnMutable bcur = this.b.cursor();
        this.e = this.y.deepClone();
        while (cols.hasNext()) {
            DataBlock coli = cols.next();
            double bi = bcur.getAndNext();
            this.e.addAY(-bi, coli);
        }
        double W = this.omega2.fastOp(q -> 1.0 / q).sum();
        double muB0 = this.e.fastOp((DoubleSeq)this.omega2, (q, w) -> q / w).sum() / W;
        double v = this.sigma2 / W;
        Normal N = new Normal(muB0, Math.sqrt(v));
        this.b0 = N.random(this.rng);
        this.e.sub(this.b0);
    }

    /**
     * Draws {@code sigma2} from an inverse-gamma distribution whose parameters
     * combine the weighted residual sum of squares and the penalized
     * coefficients; also stores {@code scale / (shape - 1)} in {@code muSigma2}.
     */
    private void sampleSigma2() {
        double shape = (double)(this.n + this.p) / 2.0;
        double scale = (this.e.fastOp((DoubleSeq)this.omega2, (q, w) -> q * q / w).sum() + this.b.fastOp((DoubleSeq)this.lambda2, (q, l) -> q * q / l).sum() / this.tau2) / 2.0;
        this.sigma2 = InverseGamma.random(this.rng, shape, scale);
        this.muSigma2 = scale / (shape - 1.0);
    }

    /**
     * Draws the per-observation weights {@code omega2} for the {@code LAPLACE}
     * and {@code T} models; any other model leaves {@code omega2} untouched.
     */
    private void sampleOmega2() {
        switch (this.model) {
            case LAPLACE: {
                this.omega2 = DataBlock.of(this.e.fastOp(q -> Math.sqrt(2.0 * this.sigma2 / (q * q))));
                // NOTE(review): the Lasso branch of sampleLambda2 uses 2.0 as second
                // parameter of the same sampler; confirm that 0.2 is intended here.
                this.omega2.apply(q -> 1.0 / InverseGaussian.random(this.rng, q, 0.2));
                break;
            }
            case T: {
                double shape = (double)(this.tdf + 1) * 0.5;
                this.omega2 = DataBlock.of(this.e.fastOp(q -> (q * q / this.sigma2 + (double)this.tdf) / 2.0));
                this.omega2.apply(q -> Gamma.random(this.rng, shape, 1.0 / q));
            }
        }
    }

    /**
     * Draws the global shrinkage parameter {@code tau2}; for the horseshoe,
     * horseshoe+ and ridge priors the auxiliary variable {@code xi} is drawn
     * as well. Priors without a case leave both unchanged.
     */
    private void sampleTau2() {
        switch (this.prior) {
            case HORSESHOE: 
            case HORSESHOEPLUS: 
            case RIDGE: {
                double shape = (double)(this.p + 1) * 0.5;
                double scale = 1.0 / this.xi + this.b.fastOp((DoubleSeq)this.lambda2, (q, l) -> q * q / l).sum() / (2.0 * this.sigma2);
                this.tau2 = InverseGamma.random(this.rng, shape, scale);
                scale = 1.0 + 1.0 / this.tau2;
                this.xi = InverseGamma.random(this.rng, 1.0, scale);
                break;
            }
            case LASSO: {
                double shape = (double)(this.p + 1) * 0.5;
                double scale = 1.0 + this.b.fastOp((DoubleSeq)this.lambda2, (q, l) -> q * q / l).sum() / (2.0 * this.sigma2);
                this.tau2 = InverseGamma.random(this.rng, shape, scale);
            }
        }
    }

    /**
     * Draws the local shrinkage parameters {@code lambda2} (and the auxiliary
     * variables {@code nu}, {@code eta2}, {@code phi} where the prior uses them).
     * Priors without a case (e.g. ridge) leave {@code lambda2} unchanged.
     */
    private void sampleLambda2() {
        switch (this.prior) {
            case HORSESHOE: {
                double d = 2.0 * this.tau2 * this.sigma2;
                DoubleSeq scale = this.b.fastOp((DoubleSeq)this.nu, (q, c) -> 1.0 / c + q * q / d);
                this.lambda2.set(scale, s -> s / Exponential.random(this.rng, 1.0));
                scale = this.lambda2.fastOp(q -> 1.0 + 1.0 / q);
                this.nu.set(scale, s -> s / Exponential.random(this.rng, 1.0));
                break;
            }
            case HORSESHOEPLUS: {
                double d = 2.0 * this.tau2 * this.sigma2;
                DoubleSeq be = this.b.fastOp((DoubleSeq)this.eta2, (q, t) -> q * q / (d * t));
                DoubleSeq scale = this.nu.fastOp(be, (q, c) -> 1.0 / q + c);
                this.lambda2.set(scale, s -> s / Exponential.random(this.rng, 1.0));
                scale = this.lambda2.fastOp(q -> 1.0 + 1.0 / q);
                this.nu.set(scale, s -> s / Exponential.random(this.rng, 1.0));
                DoubleSeq bl = this.b.fastOp((DoubleSeq)this.lambda2, (q, l) -> q * q / (d * l));
                scale = this.phi.fastOp(bl, (q, l) -> 1.0 / q + l);
                this.eta2.set(scale, s -> s / Exponential.random(this.rng, 1.0));
                scale = this.eta2.fastOp(q -> 1.0 + 1.0 / q);
                this.phi.set(scale, s -> s / Exponential.random(this.rng, 1.0));
                // NOTE(review): the value returned by op(...) is discarded. If op does
                // not modify lambda2 in place this statement has no effect - confirm
                // whether lambda2 was meant to be multiplied by eta2 here.
                this.lambda2.op((DoubleSeq)this.eta2, (l, t) -> l * t);
                break;
            }
            case LASSO: {
                double d = 2.0 * this.sigma2 * this.tau2;
                DoubleSeq mu = this.b.fastOp(q -> Math.sqrt(d / (q * q)));
                this.lambda2.set(mu, q -> 1.0 / InverseGaussian.random(this.rng, q, 2.0));
            }
        }
    }

    /** Placeholder: currently does nothing. */
    private void sampleDelta2() {
    }

    /** Stores the residuals {@code e = y - b0 - X b} for the current draw. */
    private void errLinearRegression() {
        DataBlockIterator cols = this.X.columnsIterator();
        DoubleSeqCursor.OnMutable bcur = this.b.cursor();
        this.e = this.y.deepClone();
        this.e.sub(this.b0);
        while (cols.hasNext()) {
            DataBlock coli = cols.next();
            double bi = bcur.getAndNext();
            this.e.addAY(-bi, coli);
        }
    }

    /**
     * Recomputes the residuals and evaluates the negative log-likelihood of the
     * current error model.
     *
     * @return the negative log-likelihood, or {@code NaN} for a model without a
     *         dedicated formula
     */
    private double nllLinearRegression() {
        this.errLinearRegression();
        switch (this.model) {
            case GAUSSIAN: {
                return this.glr();
            }
            case LAPLACE: {
                return this.llr();
            }
            case T: {
                return this.tlr();
            }
        }
        return Double.NaN;
    }

    /**
     * Gaussian negative log-likelihood of the residuals:
     * {@code n/2 * log(2*pi*sigma2) + ssq(e) / (2*sigma2)}.
     */
    private double glr() {
        // n / 2 must be a floating-point division, otherwise odd n is truncated.
        return (double)this.n / 2.0 * Math.log(Math.PI * 2 * this.sigma2) + 1.0 / (2.0 * this.sigma2) * this.e.ssq();
    }

    /**
     * Laplace negative log-likelihood of the residuals, with scale
     * {@code sqrt(sigma2 / 2)}.
     */
    private double llr() {
        double scale = Math.sqrt(this.sigma2 / 2.0);
        return (double)this.n * Math.log(2.0 * scale) + this.e.norm1() / scale;
    }

    /**
     * Student-t negative log-likelihood of the residuals, with {@code tdf}
     * degrees of freedom and squared scale {@code sigma2}.
     */
    private double tlr() {
        DoubleSeq te = this.e.fastOp(z -> Math.log(1.0 + z * z / ((double)this.tdf * this.sigma2)));
        double c = te.sum();
        // (tdf + 1) / 2 must be a floating-point division, otherwise even tdf is truncated.
        return (double)this.n * (-SpecialFunctions.logGamma(((double)this.tdf + 1.0) / 2.0) + SpecialFunctions.logGamma((double)this.tdf / 2.0) + Math.log(Math.PI * (double)this.tdf * this.sigma2) / 2.0) + ((double)this.tdf + 1.0) / 2.0 * c;
    }

    /**
     * Draws {@code b} for the non-Gaussian models from the p x p system built on
     * the re-weighted regressors.
     *
     * <p>The code relies on {@code SymmetricMatrix.solve(S, y, false)} replacing
     * {@code y} by the solution and leaving in {@code S} the triangular factor
     * used afterwards by {@code solvexL} - presumably a Cholesky factor; verify
     * against the matrix library.
     *
     * @param X0 regressors whose rows were divided by {@code omega}
     * @param alpha dependent variable minus the intercept (length n)
     * @param Lambda prior variances of the coefficients (length p)
     * @param sigma2 unused
     * @param omega per-observation scaling (length n)
     */
    private void rue_nongaussian(FastMatrix X0, DoubleSeq alpha, DoubleSeq Lambda, double sigma2, DoubleSeq omega) {
        FastMatrix S = SymmetricMatrix.XtX(X0);
        S.diagonal().add(Lambda.fastOp(q -> 1.0 / q));
        // The right-hand side of the p x p system is X0' * (alpha / omega): the
        // n-length scaled response has to be projected on the columns of X0 first.
        DataBlock u = DataBlock.of(alpha.fastOp(omega, (q, r) -> q / r));
        DataBlock y = DataBlock.make(this.p);
        y.product(u, X0.columnsIterator());
        SymmetricMatrix.solve(S, y, false);
        DataBlock w = DataBlock.of(this.p, i -> Normal.random(this.rng, 0.0, 1.0));
        LowerTriangularMatrix.solvexL(S, w);
        this.b.copy(y);
        this.b.add(w);
    }

    /**
     * Draws {@code b} for the Gaussian model, using the precomputed {@code XtX}
     * and {@code Xty} when available.
     *
     * <p>The code relies on {@code dividedBy} returning a new matrix (so that the
     * cached {@code XtX} is not altered) and on {@code SymmetricMatrix.solve}
     * leaving a triangular factor in {@code S} - verify against the matrix
     * library.
     *
     * @param X regressors, only used when {@code XtX} is null
     * @param alpha unused
     * @param Lambda prior variances of the coefficients (length p)
     * @param XtX precomputed cross-product of the regressors, or null
     * @param Xty precomputed product of the regressors with the response
     * @param sigma2 current error variance
     */
    private void rue_gaussian(FastMatrix X, DoubleSeq alpha, DoubleSeq Lambda, FastMatrix XtX, DataBlock Xty, double sigma2) {
        FastMatrix S = XtX;
        if (XtX == null) {
            S = SymmetricMatrix.XtX(X);
        }
        S = S.dividedBy(sigma2);
        S.diagonal().add(Lambda.fastOp(q -> 1.0 / q));
        DataBlock y = Xty.deepClone();
        y.div(sigma2);
        SymmetricMatrix.solve(S, y, false);
        DataBlock w = DataBlock.of(this.p, i -> Normal.random(this.rng, 0.0, 1.0));
        LowerTriangularMatrix.solvexL(S, w);
        this.b.copy(y);
        this.b.add(w);
    }

    /** One recorded draw: rescaled coefficients, intercept and global shrinkage. */
    public static final class Result {
        private final DoubleSeq b;
        private final double b0;
        private final double tau2;

        @Generated
        public Result(DoubleSeq b, double b0, double tau2) {
            this.b = b;
            this.b0 = b0;
            this.tau2 = tau2;
        }

        @Generated
        public DoubleSeq getB() {
            return this.b;
        }

        @Generated
        public double getB0() {
            return this.b0;
        }

        @Generated
        public double getTau2() {
            return this.tau2;
        }

        @Generated
        public boolean equals(@Nullable Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof Result)) {
                return false;
            }
            Result other = (Result)o;
            if (Double.compare(this.getB0(), other.getB0()) != 0) {
                return false;
            }
            if (Double.compare(this.getTau2(), other.getTau2()) != 0) {
                return false;
            }
            DoubleSeq this$b = this.getB();
            DoubleSeq other$b = other.getB();
            return this$b == null ? other$b == null : this$b.equals(other$b);
        }

        @Generated
        public int hashCode() {
            final int PRIME = 59;
            int result = 1;
            // Double.hashCode(x) is (int)(bits ^ (bits >>> 32)): same values as before.
            result = result * PRIME + Double.hashCode(this.getB0());
            result = result * PRIME + Double.hashCode(this.getTau2());
            DoubleSeq $b = this.getB();
            result = result * PRIME + ($b == null ? 43 : $b.hashCode());
            return result;
        }

        @Generated
        public @NonNull String toString() {
            return "BayesRegularizedRegression.Result(b=" + String.valueOf(this.getB()) + ", b0=" + this.getB0() + ", tau2=" + this.getTau2() + ")";
        }
    }
}

