/*
 * Decompiled with CFR 0.152.
 */
package moa.classifiers.rules.multilabel.functions;

import com.github.javacliparser.FlagOption;
import com.github.javacliparser.FloatOption;
import com.github.javacliparser.IntOption;
import com.yahoo.labs.samoa.instances.MultiLabelInstance;
import com.yahoo.labs.samoa.instances.MultiLabelPrediction;
import com.yahoo.labs.samoa.instances.Prediction;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.Random;
import moa.classifiers.AbstractMultiLabelLearner;
import moa.classifiers.MultiTargetRegressor;
import moa.classifiers.rules.core.Utils;
import moa.classifiers.rules.multilabel.functions.AMRulesFunction;
import moa.core.Measurement;

/**
 * Multi-target regressor built from two stacked linear layers.
 *
 * <p>The first layer is a linear model with one weight column per output plus a bias row. It maps
 * the standardized numeric input attributes to standardized output values. Unless
 * {@code skipStackingOption} is set, a second linear layer is stacked on top of it. That layer is
 * initialized to the identity plus a zero bias row, and maps the first-layer outputs to the final
 * standardized predictions, so each prediction can depend on the first-layer estimates of every
 * output. Predictions are mapped back to the output scale using the running weighted mean and the
 * standard deviation (as computed by {@code Utils.computeSD}) of each output.
 *
 * <p>Only numeric input attributes are used. Which attributes are numeric is decided once, from
 * the first training instance.
 */
public class StackedPredictor extends AbstractMultiLabelLearner
        implements MultiTargetRegressor, AMRulesFunction {
    private static final long serialVersionUID = 1L;
    /** Standard deviations not above this value are treated as zero: values are only centered. */
    private static final double SD_THRESHOLD = 1.0E-7;
    public FlagOption constantLearningRatioDecayOption = new FlagOption("learningRatio_Decay_set_constant", 'd', "Learning Ratio Decay in Perceptron set to be constant. (The next parameter).");
    public FloatOption learningRatioOption = new FloatOption("learningRatio", 'l', "Learning Ratio to use for training the 1st layer.", 0.025);
    public FloatOption learningRatio2ndLayerOption = new FloatOption("learningRatio2ndLayer", 'n', "Learning Ratio to use in the second layer.", 0.001);
    public FloatOption learningRateDecayOption = new FloatOption("learningRateDecay", 'm', " Learning Rate decay to use for training the 1st layer.", 0.001);
    public FlagOption skipStackingOption = new FlagOption("skipStackingOption", 's', "Predicts the outputs of the first layer (no dependence among output is computed)");
    public IntOption randomSeedOption = new IntOption("randomSeed", 'r', "Seed for random behaviour of the classifier.", 1);
    public FlagOption printWeightsOption = new FlagOption("printWeights", 'p', "Outputs the 2nd layer weights as measurements.");

    /** True once statistics and weights have been allocated from a training instance. */
    private boolean hasStarted;
    /** Total instance weight seen during training. */
    private double count;
    /** Weighted sum of each numeric input (same order as {@link #numericIndices}). */
    private double[] inAttrSum;
    /** Weighted sum of squares of each numeric input (same order as {@link #numericIndices}). */
    private double[] inAttrSquaredSum;
    /** Weighted sum of each output. */
    private double[] outAttrSum;
    /** Weighted sum of squares of each output. */
    private double[] outAttrSquaredSum;
    /** First layer: (numInputs + 1) rows (the last row is the bias) by numOutputs columns. */
    private double[][] layer1Weights;
    /** Second layer: (numOutputs + 1) rows (the last row is the bias) by numOutputs columns. */
    private double[][] layer2Weights;
    /** Learning rate currently applied to the first layer. */
    double currentLearningRate;
    /** Indices of the numeric input attributes, fixed on the first training instance. */
    LinkedList<Integer> numericIndices;

    @Override
    public boolean isRandomizable() {
        return true;
    }

    /**
     * Restores the first-layer learning rate to its configured value.
     * Learned statistics and weights are kept.
     */
    @Override
    public void resetWithMemory() {
        this.currentLearningRate = this.learningRatioOption.getValue();
    }

    /**
     * Updates the running statistics with {@code instance}, then applies one weighted delta-rule
     * step to the first layer and, unless stacking is skipped, to the second layer. Both steps use
     * the predictions made with the weights as they were before this update.
     *
     * @param instance the training instance; its weight scales every update
     */
    @Override
    public void trainOnInstanceImpl(MultiLabelInstance instance) {
        if (!this.hasStarted) {
            this.initializeModel(instance);
        }
        int numOutputs = instance.numOutputAttributes();
        double w = instance.weight();
        this.updateRunningStatistics(instance, w);

        double[] normInputs = this.getNormalizedInput(instance);
        double[] firstLayerOutput = this.predict1stLayer(normInputs);
        double[] secondLayerOutput = this.skipStackingOption.isSet() ? null : this.predict2ndLayer(firstLayerOutput);

        if (!this.constantLearningRatioDecayOption.isSet()) {
            // The first-layer rate decays hyperbolically with the total weight seen so far.
            this.currentLearningRate = this.learningRatioOption.getValue() / (1.0 + this.count * this.learningRateDecayOption.getValue());
        }
        double[] normOutputs = this.getNormalizedOutput(instance);

        for (int j = 0; j < numOutputs; ++j) {
            updateLayerColumn(this.layer1Weights, j, normInputs, normOutputs[j] - firstLayerOutput[j], this.currentLearningRate, w);
        }
        if (secondLayerOutput != null) {
            double learningRate2ndLayer = this.learningRatio2ndLayerOption.getValue();
            for (int j = 0; j < numOutputs; ++j) {
                updateLayerColumn(this.layer2Weights, j, firstLayerOutput, normOutputs[j] - secondLayerOutput[j], learningRate2ndLayer, w);
            }
        }
    }

    /**
     * Records the numeric input attributes of {@code instance} and allocates statistics and
     * weights. First-layer weights are drawn uniformly from [-1, 1). The second layer starts as
     * the identity matrix with a zero bias row.
     */
    private void initializeModel(MultiLabelInstance instance) {
        this.hasStarted = true;
        this.numericIndices = new LinkedList<Integer>();
        for (int i = 0; i < instance.numInputAttributes(); ++i) {
            if (instance.inputAttribute(i).isNumeric()) {
                this.numericIndices.add(i);
            }
        }
        int numInputs = this.numericIndices.size();
        int numOutputs = instance.numOutputAttributes();
        this.inAttrSum = new double[numInputs];
        this.inAttrSquaredSum = new double[numInputs];
        this.outAttrSum = new double[numOutputs];
        this.outAttrSquaredSum = new double[numOutputs];
        this.layer1Weights = new double[numInputs + 1][numOutputs];
        this.layer2Weights = new double[numOutputs + 1][numOutputs];
        // Draw order (output-major, then input) determines which random value goes to which weight.
        for (int j = 0; j < numOutputs; ++j) {
            for (int i = 0; i <= numInputs; ++i) {
                this.layer1Weights[i][j] = 2.0 * this.classifierRandom.nextDouble() - 1.0;
            }
            this.layer2Weights[j][j] = 1.0;
        }
    }

    /** Adds {@code instance}, weighted by {@code w}, to the running input and output sums. */
    private void updateRunningStatistics(MultiLabelInstance instance, double w) {
        this.count += w;
        int ct = 0;
        for (int attIndex : this.numericIndices) {
            double value = instance.valueInputAttribute(attIndex);
            this.inAttrSum[ct] += value * w;
            this.inAttrSquaredSum[ct] += value * value * w;
            ++ct;
        }
        int numOutputs = instance.numOutputAttributes();
        for (int i = 0; i < numOutputs; ++i) {
            double value = instance.valueOutputAttribute(i);
            this.outAttrSum[i] += value * w;
            this.outAttrSquaredSum[i] += value * value * w;
        }
    }

    /**
     * Applies one weighted delta-rule step to one column of a linear layer. If the column's L1
     * norm (bias included) then exceeds the number of inputs, the column is divided by that norm.
     *
     * @param weights        layer weights with {@code inputs.length + 1} rows; the last row is the bias
     * @param column         output column to update
     * @param inputs         inputs the layer saw for this instance
     * @param delta          target minus prediction for this column
     * @param learningRate   step size
     * @param instanceWeight weight of the training instance
     */
    private static void updateLayerColumn(double[][] weights, int column, double[] inputs, double delta, double learningRate, double instanceWeight) {
        int numInputs = inputs.length;
        double sumLayer = 0.0;
        for (int i = 0; i < numInputs; ++i) {
            weights[i][column] += learningRate * delta * inputs[i] * instanceWeight;
            sumLayer += Math.abs(weights[i][column]);
        }
        weights[numInputs][column] += learningRate * delta * instanceWeight;
        sumLayer += Math.abs(weights[numInputs][column]);
        if (sumLayer > (double) numInputs) {
            for (int i = 0; i <= numInputs; ++i) {
                weights[i][column] /= sumLayer;
            }
        }
    }

    /**
     * Predicts all outputs for {@code inst} on the original output scale.
     *
     * @return one single-vote prediction per output, or {@code null} if the model has not been
     *         trained yet
     */
    @Override
    public Prediction getPredictionForInstance(MultiLabelInstance inst) {
        if (!this.hasStarted) {
            return null;
        }
        int numOutputs = this.outAttrSum.length;
        MultiLabelPrediction pred = new MultiLabelPrediction(numOutputs);
        double[] firstLayerOutput = this.predict1stLayer(this.getNormalizedInput(inst));
        double[] normOutput = this.skipStackingOption.isSet() ? firstLayerOutput : this.predict2ndLayer(firstLayerOutput);
        double[] denormalizedOutput = this.getDenormalizedOutput(normOutput);
        for (int i = 0; i < numOutputs; ++i) {
            pred.setVotes(i, new double[]{denormalizedOutput[i]});
        }
        return pred;
    }

    /** Discards all learned state and re-seeds the random generator from {@code randomSeed}. */
    @Override
    public void resetLearningImpl() {
        this.hasStarted = false;
        this.count = 0.0;
        this.inAttrSum = null;
        this.inAttrSquaredSum = null;
        this.outAttrSum = null;
        this.outAttrSquaredSum = null;
        this.layer1Weights = null;
        this.layer2Weights = null;
        this.numericIndices = null;
        this.currentLearningRate = this.learningRatioOption.getValue();
        this.classifierRandom = new Random(this.randomSeedOption.getValue());
    }

    /**
     * Centers {@code value} on the running mean and, if the running standard deviation exceeds
     * {@link #SD_THRESHOLD}, divides by it.
     */
    private double standardizeValue(double value, double sum, double squaredSum) {
        double mean = sum / this.count;
        double std = Utils.computeSD(squaredSum, sum, this.count);
        double centered = value - mean;
        return std > SD_THRESHOLD ? centered / std : centered;
    }

    /** Inverse of {@link #standardizeValue} for the same running statistics. */
    private double destandardizeValue(double normValue, double sum, double squaredSum) {
        double mean = sum / this.count;
        double std = Utils.computeSD(squaredSum, sum, this.count);
        return std > SD_THRESHOLD ? normValue * std + mean : normValue + mean;
    }

    /** Returns the standardized numeric inputs of {@code instance}, in {@code numericIndices} order. */
    protected double[] getNormalizedInput(MultiLabelInstance instance) {
        double[] normalizedInput = new double[this.numericIndices.size()];
        int i = 0;
        for (int attIndex : this.numericIndices) {
            normalizedInput[i] = this.standardizeValue(instance.valueInputAttribute(attIndex), this.inAttrSum[i], this.inAttrSquaredSum[i]);
            ++i;
        }
        return normalizedInput;
    }

    /** Returns the standardized output values of {@code instance}. */
    protected double[] getNormalizedOutput(MultiLabelInstance instance) {
        int numOutputs = instance.numOutputAttributes();
        double[] normalizedOutput = new double[numOutputs];
        for (int i = 0; i < numOutputs; ++i) {
            normalizedOutput[i] = this.standardizeValue(instance.valueOutputAttribute(i), this.outAttrSum[i], this.outAttrSquaredSum[i]);
        }
        return normalizedOutput;
    }

    /** Maps standardized outputs back to the original output scale. */
    protected double[] getDenormalizedOutput(double[] normOutputs) {
        int numOutputs = normOutputs.length;
        double[] denormalizedOutput = new double[numOutputs];
        for (int i = 0; i < numOutputs; ++i) {
            denormalizedOutput[i] = this.destandardizeValue(normOutputs[i], this.outAttrSum[i], this.outAttrSquaredSum[i]);
        }
        return denormalizedOutput;
    }

    /** Forward pass of the first layer over standardized inputs. */
    private double[] predict1stLayer(double[] normInputs) {
        return applyLinearLayer(normInputs, this.layer1Weights, this.outAttrSum.length);
    }

    /** Forward pass of the second layer over the first-layer outputs. */
    private double[] predict2ndLayer(double[] firstLayerOutput) {
        return applyLinearLayer(firstLayerOutput, this.layer2Weights, firstLayerOutput.length);
    }

    /**
     * Computes {@code out[j] = sum_i inputs[i] * weights[i][j] + weights[inputs.length][j]}.
     *
     * @param weights    weights with {@code inputs.length + 1} rows; the last row is the bias
     * @param numOutputs number of output columns to compute
     */
    private static double[] applyLinearLayer(double[] inputs, double[][] weights, int numOutputs) {
        int numInputs = inputs.length;
        double[] output = new double[numOutputs];
        for (int j = 0; j < numOutputs; ++j) {
            double acc = 0.0;
            for (int i = 0; i < numInputs; ++i) {
                acc += inputs[i] * weights[i][j];
            }
            output[j] = acc + weights[numInputs][j];
        }
        return output;
    }

    /**
     * Reports the second-layer weights (one measurement per weight, bias included) when
     * {@code printWeights} is set.
     *
     * @return the measurements, or {@code null} when the option is off or no weights exist yet
     */
    @Override
    protected Measurement[] getModelMeasurementsImpl() {
        // Weights are only allocated by the first training instance.
        if (!this.printWeightsOption.isSet() || this.layer2Weights == null) {
            return null;
        }
        int numOutputs = this.layer2Weights.length - 1;
        Measurement[] measurements = new Measurement[numOutputs * (numOutputs + 1)];
        int ct = 0;
        for (int j = 0; j < numOutputs; ++j) {
            for (int i = 0; i < numOutputs; ++i) {
                measurements[ct++] = new Measurement("W Out" + (i + 1) + ": Out" + (j + 1), this.layer2Weights[i][j]);
            }
            measurements[ct++] = new Measurement("W Bias: Out" + (j + 1), this.layer2Weights[numOutputs][j]);
        }
        return measurements;
    }

    /** This learner produces no textual model description. */
    @Override
    public void getModelDescription(StringBuilder out, int indent) {
    }

    /**
     * Restricts the model to the given subset of outputs, keeping the learned statistics and
     * weights of those outputs. The new output {@code j} is the old output
     * {@code outputAtributtes[j]}. Does nothing before the first training instance, because the
     * model is then sized from that instance.
     *
     * @param outputAtributtes indices, in the current output numbering, of the outputs to keep
     */
    @Override
    public void selectOutputsToLearn(int[] outputAtributtes) {
        if (!this.hasStarted) {
            return;
        }
        int numOutputs = outputAtributtes.length;
        int numInputsPlus1 = this.layer1Weights.length;
        int oldNumOutputs = this.layer2Weights.length - 1;
        double[] newOutAttrSum = new double[numOutputs];
        double[] newOutAttrSquaredSum = new double[numOutputs];
        double[][] newLayer1Weights = new double[numInputsPlus1][numOutputs];
        // The second layer has one row per kept output plus the bias row.
        double[][] newLayer2Weights = new double[numOutputs + 1][numOutputs];
        for (int j = 0; j < numOutputs; ++j) {
            int out = outputAtributtes[j];
            newOutAttrSum[j] = this.outAttrSum[out];
            newOutAttrSquaredSum[j] = this.outAttrSquaredSum[out];
            for (int i = 0; i < numInputsPlus1; ++i) {
                newLayer1Weights[i][j] = this.layer1Weights[i][out];
            }
            for (int i = 0; i < numOutputs; ++i) {
                newLayer2Weights[i][j] = this.layer2Weights[outputAtributtes[i]][out];
            }
            newLayer2Weights[numOutputs][j] = this.layer2Weights[oldNumOutputs][out];
        }
        this.outAttrSum = newOutAttrSum;
        this.outAttrSquaredSum = newOutAttrSquaredSum;
        this.layer1Weights = newLayer1Weights;
        this.layer2Weights = newLayer2Weights;
    }
}

