/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.converter.regression;

import com.google.common.collect.Iterables;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Objects;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.FieldRef;
import org.dmg.pmml.MathContext;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.regression.CategoricalPredictor;
import org.dmg.pmml.regression.NumericPredictor;
import org.dmg.pmml.regression.PredictorTerm;
import org.dmg.pmml.regression.RegressionModel;
import org.dmg.pmml.regression.RegressionTable;
import org.jpmml.converter.BinaryFeature;
import org.jpmml.converter.BooleanFeature;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ConstantFeature;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.ContinuousLabel;
import org.jpmml.converter.Feature;
import org.jpmml.converter.InteractionFeature;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.PowerFeature;
import org.jpmml.converter.ProductFeature;
import org.jpmml.converter.Schema;
import org.jpmml.converter.SchemaUtil;
import org.jpmml.converter.ValueUtil;

/**
 * Static factory methods for assembling PMML {@link RegressionModel} and
 * {@link RegressionTable} elements from parallel lists of features and coefficients.
 *
 * <p>This class is not instantiable.
 */
public class RegressionModelUtil {
    private RegressionModelUtil() {
    }

    /**
     * Same as {@link #createRegression(MathContext, List, List, Number, RegressionModel.NormalizationMethod, Schema)},
     * passing {@code null} as the math context.
     */
    public static RegressionModel createRegression(List<? extends Feature> features, List<? extends Number> coefficients, Number intercept, RegressionModel.NormalizationMethod normalizationMethod, Schema schema) {
        return createRegression(null, features, coefficients, intercept, normalizationMethod, schema);
    }

    /**
     * Builds a regression-type model holding a single regression table.
     *
     * @param mathContext the math context used for coefficient arithmetic; may be {@code null}
     * @param features the features; must have the same size as {@code coefficients}
     * @param coefficients the coefficients, one per feature
     * @param intercept the intercept; may be {@code null}
     * @param normalizationMethod the normalization method; {@code null} is passed through unchecked
     * @param schema the schema, whose label is cast to {@link ContinuousLabel}
     * @return the new regression model
     * @throws IllegalArgumentException if the normalization method is not one of NONE, SOFTMAX,
     *         LOGIT, PROBIT, CLOGLOG, EXP, LOGLOG or CAUCHIT, or if the list sizes differ
     */
    public static RegressionModel createRegression(MathContext mathContext, List<? extends Feature> features, List<? extends Number> coefficients, Number intercept, RegressionModel.NormalizationMethod normalizationMethod, Schema schema) {
        ContinuousLabel continuousLabel = (ContinuousLabel) schema.getLabel();

        if (normalizationMethod != null) {
            switch (normalizationMethod) {
                case NONE:
                case SOFTMAX:
                case LOGIT:
                case PROBIT:
                case CLOGLOG:
                case EXP:
                case LOGLOG:
                case CAUCHIT:
                    break;
                default:
                    throw new IllegalArgumentException("Normalization method " + normalizationMethod + " is not supported for regression models");
            }
        }

        // The regression table is deliberately built last, inside the call chain,
        // so that the evaluation order of the collaborating calls stays unchanged.
        return new RegressionModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(continuousLabel), null)
            .setNormalizationMethod(normalizationMethod)
            .setMathContext(ModelUtil.simplifyMathContext(mathContext))
            .addRegressionTables(createRegressionTable(mathContext, features, coefficients, intercept));
    }

    /**
     * Same as {@link #createBinaryLogisticClassification(MathContext, List, List, Number, RegressionModel.NormalizationMethod, boolean, Schema)},
     * passing {@code null} as the math context.
     */
    public static RegressionModel createBinaryLogisticClassification(List<? extends Feature> features, List<? extends Number> coefficients, Number intercept, RegressionModel.NormalizationMethod normalizationMethod, boolean hasProbabilityDistribution, Schema schema) {
        return createBinaryLogisticClassification(null, features, coefficients, intercept, normalizationMethod, hasProbabilityDistribution, schema);
    }

    /**
     * Builds a classification-type model consisting of two regression tables: an "active" table
     * that carries the features, coefficients and intercept and targets the label value at index 1,
     * and a "passive" table without any terms that targets the label value at index 0.
     *
     * @param mathContext the math context used for coefficient arithmetic; may be {@code null}
     * @param features the features; must have the same size as {@code coefficients}
     * @param coefficients the coefficients, one per feature
     * @param intercept the intercept of the active table; may be {@code null}
     * @param normalizationMethod the normalization method; {@code null} is passed through unchecked
     * @param hasProbabilityDistribution whether a probability output is attached to the model
     * @param schema the schema, whose label is cast to {@link CategoricalLabel}
     * @return the new regression model
     * @throws IllegalArgumentException if the normalization method is not one of NONE, LOGIT,
     *         PROBIT, CLOGLOG, LOGLOG or CAUCHIT, or if the list sizes differ
     */
    public static RegressionModel createBinaryLogisticClassification(MathContext mathContext, List<? extends Feature> features, List<? extends Number> coefficients, Number intercept, RegressionModel.NormalizationMethod normalizationMethod, boolean hasProbabilityDistribution, Schema schema) {
        CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();

        // NOTE(review): presumably asserts that the label has exactly two categories - confirm against SchemaUtil.
        SchemaUtil.checkSize(2, categoricalLabel);

        if (normalizationMethod != null) {
            switch (normalizationMethod) {
                case NONE:
                case LOGIT:
                case PROBIT:
                case CLOGLOG:
                case LOGLOG:
                case CAUCHIT:
                    break;
                default:
                    throw new IllegalArgumentException("Normalization method " + normalizationMethod + " is not supported for binary logistic classification models");
            }
        }

        RegressionTable activeRegressionTable = createRegressionTable(mathContext, features, coefficients, intercept)
            .setTargetCategory(categoricalLabel.getValue(1));
        RegressionTable passiveRegressionTable = createRegressionTable(mathContext, Collections.emptyList(), Collections.emptyList(), null)
            .setTargetCategory(categoricalLabel.getValue(0));

        return new RegressionModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(categoricalLabel), null)
            .setNormalizationMethod(normalizationMethod)
            .setMathContext(ModelUtil.simplifyMathContext(mathContext))
            .addRegressionTables(activeRegressionTable, passiveRegressionTable)
            .setOutput(hasProbabilityDistribution ? ModelUtil.createProbabilityOutput(mathContext, categoricalLabel) : null);
    }

    /**
     * Same as {@link #createRegressionTable(MathContext, List, List, Number)},
     * passing {@code null} as the math context.
     */
    public static RegressionTable createRegressionTable(List<? extends Feature> features, List<? extends Number> coefficients, Number intercept) {
        return createRegressionTable(null, features, coefficients, intercept);
    }

    /**
     * Builds a regression table from parallel lists of features and coefficients.
     *
     * <p>Entries whose coefficient is {@code null} or zero-like (as decided by
     * {@link ValueUtil#isZeroLike}) are skipped. A {@link ProductFeature} is unwrapped once,
     * with its factor multiplied into the coefficient. The remaining feature is then mapped as follows:
     * <ul>
     *   <li>{@link BinaryFeature} and {@link BooleanFeature}: a categorical predictor; repeated
     *       occurrences of the same (name, value) pair are summed into one predictor.</li>
     *   <li>{@link ConstantFeature}: coefficient times constant is added to the intercept.</li>
     *   <li>{@link InteractionFeature}: a predictor term, degenerating to an intercept contribution
     *       or a numeric predictor when it references zero or one field.</li>
     *   <li>{@link PowerFeature}: a numeric predictor with an exponent (never merged).</li>
     *   <li>Anything else: converted to a continuous feature and emitted as a numeric predictor;
     *       repeated occurrences of the same name are summed into one predictor.</li>
     * </ul>
     *
     * @param mathContext the math context handed to {@link ValueUtil} arithmetic; may be {@code null}
     * @param features the features
     * @param coefficients the coefficients, one per feature
     * @param intercept the initial intercept; {@code null} or a zero-like value leaves it at {@code 0.0}
     * @return the new regression table
     * @throws IllegalArgumentException if {@code features} and {@code coefficients} differ in size
     */
    public static RegressionTable createRegressionTable(MathContext mathContext, List<? extends Feature> features, List<? extends Number> coefficients, Number intercept) {
        if (features.size() != coefficients.size()) {
            throw new IllegalArgumentException("Expected as many coefficients as features, got " + features.size() + " feature(s) and " + coefficients.size() + " coefficient(s)");
        }

        RegressionTable regressionTable = new RegressionTable(0d);
        if (intercept != null && !ValueUtil.isZeroLike(intercept)) {
            regressionTable.setIntercept(intercept);
        }

        // Already emitted predictors, so that repeated features are summed instead of duplicated.
        LinkedHashMap<PredictorKey, NumericPredictor> numericPredictors = new LinkedHashMap<>();
        LinkedHashMap<PredictorKey, CategoricalPredictor> categoricalPredictors = new LinkedHashMap<>();

        for (int i = 0; i < features.size(); i++) {
            Feature feature = features.get(i);
            Number coefficient = coefficients.get(i);

            if (coefficient == null || ValueUtil.isZeroLike(coefficient)) {
                continue;
            }

            if (feature instanceof ProductFeature) {
                ProductFeature productFeature = (ProductFeature) feature;

                // Fold the scaling factor into the coefficient and continue with the wrapped feature.
                feature = productFeature.getFeature();
                coefficient = ValueUtil.multiply(mathContext, coefficient, productFeature.getFactor());
            }

            if (feature instanceof BinaryFeature) {
                BinaryFeature binaryFeature = (BinaryFeature) feature;

                addCategoricalTerm(mathContext, regressionTable, categoricalPredictors, binaryFeature.getName(), binaryFeature.getValue(), coefficient);
            } else if (feature instanceof BooleanFeature) {
                BooleanFeature booleanFeature = (BooleanFeature) feature;

                addCategoricalTerm(mathContext, regressionTable, categoricalPredictors, booleanFeature.getName(), BooleanFeature.VALUE_TRUE, coefficient);
            } else if (feature instanceof ConstantFeature) {
                ConstantFeature constantFeature = (ConstantFeature) feature;

                addToIntercept(mathContext, regressionTable, ValueUtil.multiply(mathContext, coefficient, constantFeature.getValue()));
            } else if (feature instanceof InteractionFeature) {
                addInteractionTerm(mathContext, regressionTable, (InteractionFeature) feature, coefficient);
            } else if (feature instanceof PowerFeature) {
                PowerFeature powerFeature = (PowerFeature) feature;

                NumericPredictor numericPredictor = new NumericPredictor()
                    .setName(powerFeature.getName())
                    .setExponent(Integer.valueOf(powerFeature.getPower()))
                    .setCoefficient(coefficient);
                regressionTable.addNumericPredictors(numericPredictor);
            } else {
                addNumericTerm(mathContext, regressionTable, numericPredictors, feature.toContinuousFeature(), coefficient);
            }
        }

        return regressionTable;
    }

    /** Adds {@code increment} to the current intercept of the regression table. */
    private static void addToIntercept(MathContext mathContext, RegressionTable regressionTable, Number increment) {
        regressionTable.setIntercept(ValueUtil.add(mathContext, regressionTable.getIntercept(), increment));
    }

    /** Emits a categorical predictor for (name, value), or sums the coefficient into an existing one. */
    private static void addCategoricalTerm(MathContext mathContext, RegressionTable regressionTable, LinkedHashMap<PredictorKey, CategoricalPredictor> categoricalPredictors, FieldName name, Object value, Number coefficient) {
        PredictorKey predictorKey = new PredictorKey(name, value);

        CategoricalPredictor categoricalPredictor = categoricalPredictors.get(predictorKey);
        if (categoricalPredictor == null) {
            categoricalPredictor = new CategoricalPredictor()
                .setName(name)
                .setValue(value)
                .setCoefficient(coefficient);
            categoricalPredictors.put(predictorKey, categoricalPredictor);
            regressionTable.addCategoricalPredictors(categoricalPredictor);
        } else {
            categoricalPredictor.setCoefficient(ValueUtil.add(mathContext, categoricalPredictor.getCoefficient(), coefficient));
        }
    }

    /** Emits a numeric predictor for the continuous feature, or sums the coefficient into an existing one. */
    private static void addNumericTerm(MathContext mathContext, RegressionTable regressionTable, LinkedHashMap<PredictorKey, NumericPredictor> numericPredictors, ContinuousFeature continuousFeature, Number coefficient) {
        PredictorKey predictorKey = new PredictorKey(continuousFeature.getName());

        NumericPredictor numericPredictor = numericPredictors.get(predictorKey);
        if (numericPredictor == null) {
            numericPredictor = new NumericPredictor()
                .setName(continuousFeature.getName())
                .setCoefficient(coefficient);
            numericPredictors.put(predictorKey, numericPredictor);
            regressionTable.addNumericPredictors(numericPredictor);
        } else {
            numericPredictor.setCoefficient(ValueUtil.add(mathContext, numericPredictor.getCoefficient(), coefficient));
        }
    }

    /** Emits an interaction as a predictor term, simplifying it when it references fewer than two fields. */
    private static void addInteractionTerm(MathContext mathContext, RegressionTable regressionTable, InteractionFeature interactionFeature, Number coefficient) {
        PredictorTerm predictorTerm = new PredictorTerm()
            .setName(interactionFeature.getName())
            .setCoefficient(coefficient);

        List<? extends Feature> inputFeatures = interactionFeature.getInputFeatures();
        for (Feature inputFeature : inputFeatures) {
            if (inputFeature instanceof ConstantFeature) {
                ConstantFeature constantFeature = (ConstantFeature) inputFeature;

                // Constants do not need a field reference; they only scale the term coefficient.
                predictorTerm.setCoefficient(ValueUtil.multiply(mathContext, predictorTerm.getCoefficient(), constantFeature.getValue()));
            } else {
                ContinuousFeature continuousFeature = inputFeature.toContinuousFeature();

                predictorTerm.addFieldRefs(continuousFeature.ref());
            }
        }

        List<FieldRef> fieldRefs = predictorTerm.getFieldRefs();
        if (fieldRefs.size() == 0) {
            // All inputs were constants: the whole term is a constant.
            addToIntercept(mathContext, regressionTable, predictorTerm.getCoefficient());
        } else if (fieldRefs.size() == 1) {
            // A single field is a plain linear term. It is added as a separate numeric
            // predictor and is not merged with the ones emitted for continuous features.
            FieldRef fieldRef = Iterables.getOnlyElement(fieldRefs);

            NumericPredictor numericPredictor = new NumericPredictor()
                .setName(fieldRef.getField())
                .setCoefficient(predictorTerm.getCoefficient());
            regressionTable.addNumericPredictors(numericPredictor);
        } else {
            regressionTable.addPredictorTerms(predictorTerm);
        }
    }

    /**
     * Immutable map key identifying a predictor by field name and, for categorical predictors,
     * by category value (which is {@code null} for numeric predictors).
     */
    private static class PredictorKey {
        private final FieldName name;
        private final Object value;

        private PredictorKey(FieldName name) {
            this(name, null);
        }

        private PredictorKey(FieldName name, Object value) {
            this.name = name;
            this.value = value;
        }

        @Override
        public boolean equals(Object object) {
            if (object instanceof PredictorKey) {
                PredictorKey that = (PredictorKey) object;

                return Objects.equals(this.name, that.name) && Objects.equals(this.value, that.value);
            }
            return false;
        }

        @Override
        public int hashCode() {
            return 31 * Objects.hashCode(this.name) + Objects.hashCode(this.value);
        }
    }
}

