/*
 * Decompiled with CFR 0.152.
 */
package jdplus.toolkit.base.core.math.functions.ssq;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import jdplus.toolkit.base.api.data.DoubleSeq;
import jdplus.toolkit.base.core.data.DataBlock;
import jdplus.toolkit.base.core.math.functions.IFunction;
import jdplus.toolkit.base.core.math.functions.ssq.ISsqFunction;
import jdplus.toolkit.base.core.math.functions.ssq.ISsqFunctionDerivatives;
import jdplus.toolkit.base.core.math.functions.ssq.ISsqFunctionPoint;
import jdplus.toolkit.base.core.math.matrices.FastMatrix;

/**
 * Finite-difference derivatives of a sum-of-squares function around a given point.
 *
 * <p>For each parameter {@code i} the residual vector {@code e} is re-evaluated at the
 * point shifted by a step along {@code i}; the column {@code de/dx_i} is the difference
 * quotient of those residuals. From these columns the class derives:
 * <ul>
 *   <li>the gradient, {@code 2 * e . de/dx_i};</li>
 *   <li>the Jacobian, whose columns are {@code de/dx_i};</li>
 *   <li>the Hessian approximation {@code 2 * (de/dx_i . de/dx_j)}.</li>
 * </ul>
 *
 * <p>Results are computed lazily on first request and cached in the instance. The lazy
 * initialisation is not synchronised.
 */
public class SsqNumericalDerivatives
implements ISsqFunctionDerivatives {
    /** Upper bound on the size of the thread pool used for parallel evaluations. */
    private static final int NTHREADS = Runtime.getRuntime().availableProcessors();
    /** Number of boundary checks performed while halving a step that leaves the domain. */
    private static final int NSTEPS = 2;

    /** Residuals at the forward-shifted points, one entry per parameter. */
    private DoubleSeq[] m_ep;
    /** Residuals at the backward-shifted points (symmetric mode only). */
    private DoubleSeq[] m_em;
    /** Derivatives of the residuals with respect to each parameter. */
    private DoubleSeq[] m_de;
    /** Forward steps; 0 means that no admissible step was found. */
    private double[] m_epsp;
    /** Backward steps (symmetric mode only); 0 means that the backward point is not used. */
    private double[] m_epsm;
    private double[] m_grad;
    private FastMatrix m_h;
    private final ISsqFunction fn;
    /** Parameters of the point at which the derivatives are taken. */
    private DoubleSeq m_pt;
    /** Residuals at {@link #m_pt}. */
    private DoubleSeq m_ecur;
    /** True to use two-sided differences when the backward point is admissible. */
    private final boolean m_sym;
    /** True to evaluate the shifted points in a thread pool. */
    private final boolean m_mt;

    /**
     * Creates one-sided, single-threaded numerical derivatives.
     *
     * @param point the point at which the derivatives are computed
     */
    public SsqNumericalDerivatives(ISsqFunctionPoint point) {
        this(point, false, false);
    }

    /**
     * Creates numerical derivatives at the given point.
     *
     * @param point the point at which the derivatives are computed
     * @param sym true to use symmetric (two-sided) differences where possible
     * @param mt true to evaluate the shifted points concurrently
     */
    public SsqNumericalDerivatives(ISsqFunctionPoint point, boolean sym, boolean mt) {
        this.m_sym = sym;
        this.m_mt = mt;
        this.fn = point.getSsqFunction();
        this.m_ecur = point.getE();
        this.m_pt = point.getParameters();
    }

    /**
     * Creates single-threaded numerical derivatives.
     *
     * @param point the point at which the derivatives are computed
     * @param sym true to use symmetric (two-sided) differences where possible
     */
    public SsqNumericalDerivatives(ISsqFunctionPoint point, boolean sym) {
        this(point, sym, false);
    }

    /**
     * Computes the steps, the shifted residuals, the residual derivatives and the gradient.
     */
    private void calcgrad() {
        int n = this.m_pt.length();
        this.m_grad = new double[n];
        this.m_epsp = new double[n];
        this.m_ep = new DoubleSeq[n];
        if (this.m_sym) {
            this.m_epsm = new double[n];
            this.m_em = new DoubleSeq[n];
        }
        this.m_de = new DoubleSeq[n];

        // Select admissible steps first; this part is always sequential.
        for (int i = 0; i < n; ++i) {
            this.m_epsp[i] = this.fn.getDomain().epsilon(this.m_pt, i);
            this.checkepsilon(i);
            if (this.m_sym) {
                this.checkmepsilon(i);
            }
        }

        // Evaluate the residuals at the shifted points.
        if (!this.m_mt || n < 2) {
            for (int i = 0; i < n; ++i) {
                this.m_ep[i] = this.err(i, this.m_epsp[i]);
                if (this.m_sym) {
                    this.m_em[i] = this.err(i, this.m_epsm[i]);
                }
            }
        } else {
            this.evaluateInParallel(n);
        }

        int ne = this.m_ecur.length();
        for (int i = 0; i < n; ++i) {
            // A slot can still be empty if the parallel evaluation was interrupted.
            final DoubleSeq ep = this.orCurrent(this.m_ep[i]);
            final DoubleSeq ref = this.m_sym ? this.orCurrent(this.m_em[i]) : this.m_ecur;
            final double step = this.m_sym ? this.m_epsp[i] - this.m_epsm[i] : this.m_epsp[i];
            DataBlock de = DataBlock.make(ne);
            if (step == 0.0) {
                // No admissible step: use a zero derivative instead of 0/0 = NaN,
                // as is implicitly done when an evaluation fails.
                de.set(ep, ref, (x, y) -> 0.0);
            } else {
                de.set(ep, ref, (x, y) -> (x - y) / step);
            }
            this.m_grad[i] = 2.0 * this.m_ecur.dot((DoubleSeq)de);
            this.m_de[i] = de;
        }
    }

    /**
     * Evaluates all the shifted points in a fixed thread pool and waits for completion.
     * The pool is always shut down, and the interrupt flag is restored on interruption.
     *
     * @param n the number of parameters
     */
    private void evaluateInParallel(int n) {
        List<Callable<Void>> tasks = this.createTasks(n, this.m_sym);
        ExecutorService executorService =
                Executors.newFixedThreadPool(Math.max(1, Math.min(NTHREADS, tasks.size())));
        try {
            executorService.invokeAll(tasks);
        }
        catch (InterruptedException ex) {
            Thread.currentThread().interrupt();
        }
        finally {
            executorService.shutdown();
        }
    }

    /**
     * Returns the given residuals, or the residuals at the current point when they are missing.
     */
    private DoubleSeq orCurrent(DoubleSeq e) {
        return e != null ? e : this.m_ecur;
    }

    /**
     * Computes the Hessian approximation {@code 2 * (de/dx_i . de/dx_j)}.
     */
    private void calch() {
        if (this.m_grad == null) {
            this.calcgrad();
        }
        int n = this.m_grad.length;
        this.m_h = FastMatrix.square(n);
        for (int i = 0; i < n; ++i) {
            DoubleSeq dei = this.m_de[i];
            this.m_h.set(i, i, 2.0 * dei.ssq());
            // The matrix is symmetric: compute the lower triangle and mirror it.
            for (int j = 0; j < i; ++j) {
                double z = 2.0 * dei.dot(this.m_de[j]);
                this.m_h.set(i, j, z);
                this.m_h.set(j, i, z);
            }
        }
    }

    /**
     * Adjusts the forward step of parameter {@code i} so that the shifted point passes the
     * boundary check of the domain. The step is tried as is, then halved; if that fails,
     * the opposite step is tried, then halved. If nothing is admissible the step is set to 0.
     *
     * @param i the index of the parameter
     */
    private void checkepsilon(int i) {
        double eps = this.m_epsp[i];
        if (eps == 0.0) {
            return;
        }
        DataBlock pcur = DataBlock.of(this.m_pt);
        double pi = pcur.get(i);
        pcur.add(i, eps);
        if (this.fn.getDomain().checkBoundaries((DoubleSeq)pcur)) {
            return;
        }
        // Same direction, smaller steps.
        int k = 0;
        do {
            eps /= 2.0;
            pcur.set(i, pi + eps);
        } while (++k <= NSTEPS && !this.fn.getDomain().checkBoundaries((DoubleSeq)pcur));
        if (k <= NSTEPS) {
            this.m_epsp[i] = eps;
            return;
        }
        // Opposite direction, starting again from the initial step size.
        eps = -this.m_epsp[i];
        pcur.set(i, pi + eps);
        if (this.fn.getDomain().checkBoundaries((DoubleSeq)pcur)) {
            this.m_epsp[i] = eps;
            return;
        }
        k = 0;
        do {
            eps /= 2.0;
            pcur.set(i, pi + eps);
        } while (++k <= NSTEPS && !this.fn.getDomain().checkBoundaries((DoubleSeq)pcur));
        if (k <= NSTEPS) {
            this.m_epsp[i] = eps;
            return;
        }
        this.m_epsp[i] = 0.0;
    }

    /**
     * Sets the backward step of parameter {@code i} to the opposite of the forward step when
     * the resulting point passes the boundary check; otherwise the backward step stays 0 and
     * the derivative for that parameter is one-sided.
     *
     * @param i the index of the parameter
     */
    private void checkmepsilon(int i) {
        double eps = -this.m_epsp[i];
        DataBlock pcur = DataBlock.of(this.m_pt);
        double pi = pcur.get(i);
        pcur.set(i, pi + eps);
        if (this.fn.getDomain().checkBoundaries((DoubleSeq)pcur)) {
            this.m_epsm[i] = eps;
        }
    }

    @Override
    public IFunction getFunction() {
        return this.fn.asFunction();
    }

    /**
     * Returns the derivative of the residuals with respect to the parameter {@code idx}.
     */
    @Override
    public DoubleSeq dEdX(int idx) {
        if (this.m_de == null) {
            this.calcgrad();
        }
        return this.m_de[idx];
    }

    /**
     * Evaluates the residuals at the current point shifted by {@code dx} along parameter {@code i}.
     *
     * @param i the index of the shifted parameter
     * @param dx the shift; 0 returns the residuals at the current point without evaluation
     * @return the residuals at the shifted point, or the residuals at the current point
     * when the evaluation throws
     */
    private DoubleSeq err(int i, double dx) {
        try {
            if (dx == 0.0) {
                return this.m_ecur;
            }
            DataBlock pcur = DataBlock.of(this.m_pt);
            pcur.add(i, dx);
            ISsqFunctionPoint shifted = this.fn.ssqEvaluate((DoubleSeq)pcur);
            return shifted.getE();
        }
        catch (Exception ignored) {
            // Deliberate best effort: a failed evaluation falls back to the current residuals.
            return this.m_ecur;
        }
    }

    @Override
    public DoubleSeq gradient() {
        if (this.m_grad == null) {
            this.calcgrad();
        }
        return DataBlock.of(this.m_grad);
    }

    /**
     * Copies the residual derivatives into the columns of {@code m}.
     */
    @Override
    public void jacobian(FastMatrix m) {
        if (this.m_de == null) {
            this.calcgrad();
        }
        for (int i = 0; i < this.m_de.length; ++i) {
            m.column(i).copy(this.m_de[i]);
        }
    }

    /**
     * Copies the Hessian approximation into {@code h}.
     */
    @Override
    public void hessian(FastMatrix h) {
        if (this.m_h == null) {
            this.calch();
        }
        h.copy(this.m_h);
    }

    /**
     * Creates one evaluation task per forward point and, in symmetric mode, per backward point.
     */
    private List<Callable<Void>> createTasks(int n, boolean sym) {
        ArrayList<Callable<Void>> result = new ArrayList<Callable<Void>>(sym ? 2 * n : n);
        for (int i = 0; i < n; ++i) {
            result.add(new Err(this.m_ep, i, this.m_epsp[i]));
        }
        if (sym) {
            for (int i = 0; i < n; ++i) {
                result.add(new Err(this.m_em, i, this.m_epsm[i]));
            }
        }
        return result;
    }

    /**
     * Task that stores the residuals of one shifted point in a slot of a result array.
     */
    private class Err
    implements Callable<Void> {
        final DoubleSeq[] rslt;
        final int pos;
        final double eps;

        private Err(DoubleSeq[] rslt, int pos, double eps) {
            this.rslt = rslt;
            this.pos = pos;
            this.eps = eps;
        }

        @Override
        public Void call() {
            // Same evaluation and fallback rules as the sequential path.
            this.rslt[this.pos] = SsqNumericalDerivatives.this.err(this.pos, this.eps);
            return null;
        }
    }
}

