/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.rexp;

import java.util.ArrayList;
import java.util.List;
import org.dmg.pmml.Apply;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Expression;
import org.dmg.pmml.Field;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.OpType;
import org.dmg.pmml.general_regression.GeneralRegressionModel;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.FortranMatrixUtil;
import org.jpmml.converter.InteractionFeature;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.PMMLEncoder;
import org.jpmml.converter.PMMLUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.SchemaUtil;
import org.jpmml.converter.ValueUtil;
import org.jpmml.converter.general_regression.GeneralRegressionModelUtil;
import org.jpmml.rexp.DecorationUtil;
import org.jpmml.rexp.Formula;
import org.jpmml.rexp.FormulaUtil;
import org.jpmml.rexp.ModelConverter;
import org.jpmml.rexp.RDoubleVector;
import org.jpmml.rexp.RExp;
import org.jpmml.rexp.RExpEncoder;
import org.jpmml.rexp.RGenericVector;
import org.jpmml.rexp.RStringVector;
import org.jpmml.rexp.XLevelsFormulaContext;

/**
 * Converts an R model object that carries {@code dirs}, {@code cuts},
 * {@code selected.terms}, {@code coefficients} and {@code terms} elements
 * into a PMML {@link GeneralRegressionModel} with an identity link function.
 *
 * <p>Schema encoding turns every selected term (except the first one) into a
 * {@link Feature}; model encoding pairs those features with the trailing
 * entries of the {@code coefficients} vector, using the first entry as the
 * intercept.</p>
 */
public class EarthConverter
extends ModelConverter<RGenericVector> {
    /**
     * @param earth the R model object to convert
     */
    public EarthConverter(RGenericVector earth) {
        super(earth);
    }

    /**
     * Registers the label and one feature per selected term with the encoder.
     *
     * <p>The {@code dirs} and {@code cuts} matrices must have identical row and
     * column names. The label is the field named by the single column name of
     * {@code coefficients}. Iteration over {@code selected.terms} starts at
     * position 1; position 0 is skipped (presumably the intercept term, which
     * matches {@link #encodeModel(Schema)} treating coefficient 0 as the
     * intercept).</p>
     *
     * @param encoder the encoder that receives the label, derived fields and features
     * @throws IllegalArgumentException if the matrix names disagree, if a direction
     *         code is unsupported, or if a selected term references no predictor
     */
    @Override
    public void encodeSchema(RExpEncoder encoder) {
        RGenericVector earth = getObject();
        RDoubleVector dirs = earth.getDoubleElement("dirs");
        RDoubleVector cuts = earth.getDoubleElement("cuts");
        RDoubleVector selectedTerms = earth.getDoubleElement("selected.terms");
        RDoubleVector coefficients = earth.getDoubleElement("coefficients");
        RExp terms = earth.getElement("terms");
        RGenericVector xlevels = DecorationUtil.getGenericElement(earth, "xlevels");

        RStringVector dirsRows = dirs.dimnames(0);
        RStringVector dirsColumns = dirs.dimnames(1);
        RStringVector cutsRows = cuts.dimnames(0);
        RStringVector cutsColumns = cuts.dimnames(1);
        if (!dirsRows.getValues().equals(cutsRows.getValues()) || !dirsColumns.getValues().equals(cutsColumns.getValues())) {
            throw new IllegalArgumentException("Row and column names of the \"dirs\" and \"cuts\" matrices do not match");
        }

        int rows = dirsRows.size();
        int columns = dirsColumns.size();
        List<String> predictorNames = dirsColumns.getValues();

        XLevelsFormulaContext context = new XLevelsFormulaContext(xlevels);
        Formula formula = FormulaUtil.createFormula(terms, context, encoder);

        // Dependent variable: named by the (single) column name of the coefficients matrix.
        RStringVector yNames = coefficients.dimnames(1);
        DataField dataField = (DataField)encoder.getField(FieldName.create(yNames.asScalar()));
        encoder.setLabel(dataField);

        // Independent variables: one feature per selected term, skipping position 0.
        for (int i = 1; i < selectedTerms.size(); ++i) {
            // selected.terms holds 1-based row numbers into dirs/cuts.
            int termIndex = ValueUtil.asInt(selectedTerms.getValue(i)) - 1;
            List<Double> dirsRow = FortranMatrixUtil.getRow(dirs.getValues(), rows, columns, termIndex);
            List<Double> cutsRow = FortranMatrixUtil.getRow(cuts.getValues(), rows, columns, termIndex);
            // The term's name lives in the same dirs row as its directions and cuts,
            // so it must be looked up by termIndex rather than by the loop position.
            String termName = dirsRows.getValue(termIndex);

            List<Feature> factors = new ArrayList<>();
            for (int j = 0; j < predictorNames.size(); ++j) {
                int dir = ValueUtil.asInt(dirsRow.get(j));
                // A zero direction means the predictor does not take part in this term.
                if (dir == 0) {
                    continue;
                }
                double cut = cutsRow.get(j);
                Feature predictor = formula.resolveFeature(predictorNames.get(j));
                factors.add(EarthConverter.encodeFactor(encoder, predictor, dir, cut));
            }

            Feature termFeature;
            if (factors.size() == 1) {
                termFeature = factors.get(0);
            } else if (factors.size() > 1) {
                termFeature = new InteractionFeature(encoder, FieldName.create(termName), DataType.DOUBLE, factors);
            } else {
                throw new IllegalArgumentException("Selected term \"" + termName + "\" (row " + (termIndex + 1) + ") does not reference any predictor");
            }
            encoder.addFeature(termFeature);
        }
    }

    /**
     * Builds the regression model from the schema prepared by
     * {@link #encodeSchema(RExpEncoder)}.
     *
     * <p>Coefficient 0 becomes the intercept; the following
     * {@code features.size()} coefficients are assigned to the schema features
     * in order. {@code SchemaUtil.checkSize} is invoked with the expected
     * feature count {@code coefficients.size() - 1}.</p>
     *
     * @param schema the schema holding the label and the term features
     * @return a generalized linear regression model with an identity link function
     */
    public GeneralRegressionModel encodeModel(Schema schema) {
        RGenericVector earth = getObject();
        RDoubleVector coefficients = earth.getDoubleElement("coefficients");
        Double intercept = coefficients.getValue(0);

        List<? extends Feature> features = schema.getFeatures();
        SchemaUtil.checkSize(coefficients.size() - 1, features);
        List<Double> featureCoefficients = coefficients.getValues().subList(1, features.size() + 1);

        GeneralRegressionModel generalRegressionModel = new GeneralRegressionModel(GeneralRegressionModel.ModelType.GENERALIZED_LINEAR, MiningFunction.REGRESSION, ModelUtil.createMiningSchema(schema.getLabel()), null, null, null)
            .setLinkFunction(GeneralRegressionModel.LinkFunction.IDENTITY);
        GeneralRegressionModelUtil.encodeRegressionTable(generalRegressionModel, features, featureCoefficients, intercept, null);
        return generalRegressionModel;
    }

    /**
     * Turns one predictor of a term into the feature that represents it.
     *
     * <p>Directions {@code -1} and {@code 1} wrap the predictor in a hinge
     * function that is registered as a continuous double derived field.
     * Direction {@code 2} uses the predictor feature unchanged (presumably a
     * predictor that enters the term without a hinge; confirm against the
     * producer of the {@code dirs} matrix).</p>
     *
     * @throws IllegalArgumentException if {@code dir} is not -1, 1 or 2
     */
    private static Feature encodeFactor(RExpEncoder encoder, Feature feature, int dir, double cut) {
        switch (dir) {
            case -1:
            case 1: {
                ContinuousFeature continuousFeature = feature.toContinuousFeature();
                FieldName name = FieldName.create(EarthConverter.formatHingeFunction(dir, continuousFeature, cut));
                DerivedField derivedField = encoder.ensureDerivedField(name, OpType.CONTINUOUS, DataType.DOUBLE, () -> EarthConverter.createHingeFunction(dir, continuousFeature, cut));
                return new ContinuousFeature(encoder, derivedField);
            }
            case 2: {
                return feature;
            }
            default: {
                throw new IllegalArgumentException("Unsupported direction code " + dir + " (expected -1, 1 or 2)");
            }
        }
    }

    /**
     * Formats the derived field name of a hinge function:
     * {@code h(cut - name)} for direction -1 and {@code h(name - cut)} for direction 1.
     *
     * @throws IllegalArgumentException if {@code dir} is neither -1 nor 1
     */
    private static String formatHingeFunction(int dir, Feature feature, double cut) {
        String featureName = feature.getName().getValue();
        switch (dir) {
            case -1: {
                return "h(" + cut + " - " + featureName + ")";
            }
            case 1: {
                return "h(" + featureName + " - " + cut + ")";
            }
            default: {
                throw new IllegalArgumentException("Unsupported hinge direction " + dir + " (expected -1 or 1)");
            }
        }
    }

    /**
     * Creates the PMML expression of a hinge function:
     * {@code max(cut - x, 0)} for direction -1 and {@code max(x - cut, 0)} for direction 1.
     *
     * @throws IllegalArgumentException if {@code dir} is neither -1 nor 1
     */
    private static Apply createHingeFunction(int dir, Feature feature, double cut) {
        Apply difference;
        switch (dir) {
            case -1: {
                difference = PMMLUtil.createApply("-", PMMLUtil.createConstant(cut), feature.ref());
                break;
            }
            case 1: {
                difference = PMMLUtil.createApply("-", feature.ref(), PMMLUtil.createConstant(cut));
                break;
            }
            default: {
                throw new IllegalArgumentException("Unsupported hinge direction " + dir + " (expected -1 or 1)");
            }
        }
        return PMMLUtil.createApply("max", difference, PMMLUtil.createConstant(0.0));
    }
}

