/*
 * Decompiled with CFR 0.152.
 */
package org.vikamine.kernel.data.discretization;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Random;
import java.util.Set;
import org.vikamine.kernel.data.Attribute;
import org.vikamine.kernel.data.DataRecord;
import org.vikamine.kernel.data.DataView;
import org.vikamine.kernel.data.IDataRecordSet;
import org.vikamine.kernel.data.NumericAttribute;
import org.vikamine.kernel.data.Value;
import org.vikamine.kernel.subgroup.SGDescription;
import org.vikamine.kernel.subgroup.target.BooleanTarget;
import org.vikamine.kernel.subgroup.target.NumericTarget;
import org.vikamine.kernel.subgroup.target.SGTarget;

/**
 * Static helper methods used by the discretization code: sorting records by an
 * attribute, determining value ranges, counting records per cutpoint interval,
 * maintaining sorted cutpoint lists and drawing random samples.
 */
public class DiscretizationUtils {
    /**
     * Collects the records of {@code population} into a new list and sorts that list
     * by the value of {@code att}.
     *
     * @param population      records to sort; the iterable itself is not modified
     * @param att             attribute whose value is the sort key
     * @param includeMissings if {@code false}, records for which
     *                        {@link Value#isMissingValue} reports a missing value of
     *                        {@code att} are left out of the result
     * @param descending      {@code true} for descending, {@code false} for ascending order
     * @return a new, sorted list of the selected records
     */
    public static List<DataRecord> getSortedDataRecords(Iterable<DataRecord> population, Attribute att, boolean includeMissings, boolean descending) {
        LinkedList<DataRecord> dataRecords = new LinkedList<DataRecord>();
        for (DataRecord dr : population) {
            boolean skip = !includeMissings && Value.isMissingValue(dr.getValue(att));
            if (!skip) {
                dataRecords.add(dr);
            }
        }
        Collections.sort(dataRecords, new AttributeComparator(att, descending));
        return dataRecords;
    }

    /**
     * Determines the smallest and largest non-NaN value of {@code na} among the given records.
     * The attribute index is looked up once, in the dataset of the first record, and then
     * used for every record.
     *
     * @param drs records to scan
     * @param na  numeric attribute to inspect
     * @return a two-element array {@code {min, max}}; if there is no record with a non-NaN
     *         value, the array is {@code {Double.MAX_VALUE, -Double.MAX_VALUE}} (min &gt; max)
     */
    public static double[] getMinMaxValue(Iterable<DataRecord> drs, NumericAttribute na) {
        // Double.MIN_VALUE is the smallest POSITIVE double, so it must not be used as the
        // starting point of a maximum search: all-negative data would yield a wrong maximum.
        double maxValue = -Double.MAX_VALUE;
        double minValue = Double.MAX_VALUE;
        Iterator<DataRecord> probe = drs.iterator();
        if (probe.hasNext()) {
            int index = probe.next().getDataset().getIndex(na);
            for (DataRecord record : drs) {
                double value = record.getValue(index);
                if (Double.isNaN(value)) {
                    continue;
                }
                maxValue = Math.max(value, maxValue);
                minValue = Math.min(value, minValue);
            }
        }
        return new double[]{minValue, maxValue};
    }

    /**
     * Counts, for every interval defined by {@code cutpoints}, the records of
     * {@code dataView} whose value of {@code na} lies in that interval.
     * Interval {@code i} is {@code [cutpoints[i-1], cutpoints[i])}, with negative infinity
     * as the lower bound of the first and positive infinity as the upper bound of the last
     * interval. Records whose value falls into no interval (e.g. NaN) are not counted.
     *
     * @param dataView  records to count
     * @param na        numeric attribute whose value selects the interval
     * @param cutpoints interval boundaries
     * @return an array of {@code cutpoints.size() + 1} counts, or {@code null} if
     *         {@code cutpoints} is {@code null} or empty
     */
    public static double[] countsInPopulation(DataView dataView, NumericAttribute na, List<Double> cutpoints) {
        if (cutpoints == null || cutpoints.size() == 0) {
            return null;
        }
        double[] result = new double[cutpoints.size() + 1];
        for (DataRecord dr : dataView) {
            int interval = findIntervalIndex(dr.getValue(na), cutpoints);
            if (interval >= 0) {
                result[interval] += 1.0;
            }
        }
        return result;
    }

    /**
     * Like {@link #countsInPopulation}, but only records for which
     * {@code sgDescription.isMatching(record)} is {@code true} are counted.
     *
     * @param population    records to count
     * @param na            numeric attribute whose value selects the interval
     * @param cutpoints     interval boundaries
     * @param sgDescription description a record has to match in order to be counted
     * @return an array of {@code cutpoints.size() + 1} counts, or {@code null} if
     *         {@code cutpoints} is {@code null} or empty
     */
    public static double[] countsInSubgroup(DataView population, NumericAttribute na, List<Double> cutpoints, SGDescription sgDescription) {
        if (cutpoints == null || cutpoints.size() == 0) {
            return null;
        }
        double[] result = new double[cutpoints.size() + 1];
        for (DataRecord dr : population) {
            int interval = findIntervalIndex(dr.getValue(na), cutpoints);
            // The description is only consulted for records that fall into an interval.
            if (interval >= 0 && sgDescription.isMatching(dr)) {
                result[interval] += 1.0;
            }
        }
        return result;
    }

    /**
     * Accumulates the target per cutpoint interval (intervals as in
     * {@link #countsInPopulation}). For a boolean target, each record that is positive
     * adds 1 to its interval; otherwise the target is treated as a {@link NumericTarget}
     * and its non-NaN value for the record is added to the interval.
     *
     * @param population records to evaluate
     * @param na         numeric attribute whose value selects the interval
     * @param cutpoints  interval boundaries
     * @param target     target to accumulate; if {@code null}, all sums are zero
     * @return an array of {@code cutpoints.size() + 1} sums, or {@code null} if
     *         {@code cutpoints} is {@code null} or empty
     */
    public static double[] countTargets(DataView population, NumericAttribute na, List<Double> cutpoints, SGTarget target) {
        if (cutpoints == null || cutpoints.size() == 0) {
            return null;
        }
        double[] result = new double[cutpoints.size() + 1];
        if (target == null) {
            return result;
        }
        for (DataRecord dr : population) {
            int interval = findIntervalIndex(dr.getValue(na), cutpoints);
            if (interval < 0) {
                continue;
            }
            if (target.isBoolean()) {
                if (((BooleanTarget)target).isPositive(dr)) {
                    result[interval] += 1.0;
                }
            } else {
                double valueNum = ((NumericTarget)target).getValue(dr);
                if (!Double.isNaN(valueNum)) {
                    result[interval] += valueNum;
                }
            }
        }
        return result;
    }

    /**
     * Returns the index of the first interval {@code [lower, upper)} that contains
     * {@code value}, scanning the intervals from left to right, or {@code -1} if no
     * interval contains it (which is the case for NaN and positive infinity).
     */
    private static int findIntervalIndex(double value, List<Double> cutpoints) {
        int lastInterval = cutpoints.size();
        for (int i = 0; i <= lastInterval; i++) {
            double lowerBound = i == 0 ? Double.NEGATIVE_INFINITY : cutpoints.get(i - 1);
            double upperBound = i == lastInterval ? Double.POSITIVE_INFINITY : cutpoints.get(i);
            if (value >= lowerBound && value < upperBound) {
                return i;
            }
        }
        return -1;
    }

    /**
     * Inserts {@code newCutpoint} in front of the first element of {@code cutpointList}
     * that is greater than it, or appends it if there is no such element. Nothing happens
     * if the list already contains the value. For a list sorted in ascending order this
     * keeps the list sorted.
     *
     * @param cutpointList list to modify in place
     * @param newCutpoint  cutpoint to add
     */
    public static void addCutpoint(List<Double> cutpointList, double newCutpoint) {
        if (cutpointList.contains(newCutpoint)) {
            return;
        }
        for (int i = 0; i < cutpointList.size(); i++) {
            if (cutpointList.get(i) > newCutpoint) {
                cutpointList.add(i, newCutpoint);
                return;
            }
        }
        cutpointList.add(newCutpoint);
    }

    /**
     * Draws records at random positions in {@code [0, size)} from {@code set} and adds
     * them to a {@link HashSet} until that set holds {@code sampleSize} elements.
     * <p>
     * NOTE(review): the loop only ends once the hash set has grown to {@code sampleSize}
     * elements. If records fetched from the same position compare equal, a
     * {@code sampleSize} larger than the number of distinct records would never be
     * reached - confirm against the callers and the record implementation.
     *
     * @param set        source of the records
     * @param size       exclusive upper bound of the random positions
     * @param sampleSize number of elements the returned set must contain
     * @return the sampled records
     */
    public static Set<DataRecord> resample(IDataRecordSet set, int size, int sampleSize) {
        HashSet<DataRecord> sample = new HashSet<DataRecord>();
        Random random = new Random();
        while (sample.size() < sampleSize) {
            int position = random.nextInt(size);
            sample.add(set.get(position));
        }
        return sample;
    }

    /**
     * Extracts the value of {@code attribute} from each record, in list order.
     *
     * @param records   records to read
     * @param attribute attribute whose value is extracted
     * @return a new list with one value per record
     */
    public static List<Double> Records2Double(List<DataRecord> records, Attribute attribute) {
        ArrayList<Double> doubles = new ArrayList<Double>(records.size());
        for (DataRecord record : records) {
            doubles.add(record.getValue(attribute));
        }
        return doubles;
    }

    /**
     * Orders data records by their value of a fixed attribute, using
     * {@link Double#compare(double, double)}, in ascending or descending direction.
     */
    private static class AttributeComparator
    implements Comparator<DataRecord>,
    Serializable {
        private static final long serialVersionUID = 1473923143827865392L;
        private final Attribute att;
        private final boolean descending;

        public AttributeComparator(Attribute attribute, boolean descending) {
            this.att = attribute;
            this.descending = descending;
        }

        @Override
        public int compare(DataRecord o1, DataRecord o2) {
            double first = o1.getValue(this.att);
            double second = o2.getValue(this.att);
            // Swapping the operands reverses the order without negating an int result.
            return this.descending ? Double.compare(second, first) : Double.compare(first, second);
        }
    }
}

