/*
 * Decompiled with CFR 0.152.
 */
package eu.amidst.core.learning.parametric;

import com.google.common.util.concurrent.AtomicDouble;
import eu.amidst.core.datastream.DataInstance;
import eu.amidst.core.datastream.DataOnMemory;
import eu.amidst.core.datastream.DataStream;
import eu.amidst.core.exponentialfamily.EF_BayesianNetwork;
import eu.amidst.core.exponentialfamily.SufficientStatistics;
import eu.amidst.core.learning.parametric.ParameterLearningAlgorithm;
import eu.amidst.core.models.BayesianNetwork;
import eu.amidst.core.models.DAG;
import eu.amidst.core.utils.CompoundVector;
import eu.amidst.core.utils.Utils;
import eu.amidst.core.utils.Vector;
import eu.amidst.core.variables.Variable;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * Maximum-likelihood parameter learning for a Bayesian network whose data may contain missing values.
 *
 * <p>For every distribution of the network, the sufficient statistics of an instance are only counted when the
 * distribution's main variable and all of its conditioning variables are observed in that instance; otherwise the
 * instance is ignored for that distribution. Each distribution therefore keeps its own count of contributing
 * instances, and its accumulated statistics are divided by that count when the learnt network is built.
 *
 * <p>When Laplace smoothing is enabled (the default), the statistics returned by
 * {@code createInitSufficientStatistics()} are added once per distribution, counting as one extra instance.
 */
public class ParallelMLMissingData implements ParameterLearningAlgorithm {
    /** Number of instances per batch when streaming the data. */
    protected int windowsSize = 1000;
    /** Whether batches are processed with a parallel stream. */
    protected boolean parallelMode = true;
    /** Data used by {@link #runLearning()}. */
    protected DataStream<DataInstance> dataStream;
    /** Structure of the network whose parameters are learnt. */
    protected DAG dag;
    /** Running number of instances processed (in this class only reported in debug output). */
    protected AtomicDouble dataInstanceCount;
    /** Accumulated per-distribution statistics and counts. */
    protected PartialSufficientSatistics sumSS;
    /** Exponential-family view of {@link #dag}, created by {@link #initLearning()}. */
    protected EF_BayesianNetwork efBayesianNetwork;
    /** When true, progress is printed to standard output while streaming. */
    protected boolean debug = false;
    /** When true, initial (Laplace) statistics are added as a pseudo-observation. */
    protected boolean laplace = true;

    /**
     * Enables or disables adding the initial statistics as one pseudo-observation per distribution.
     *
     * @param laplace {@code true} to enable Laplace smoothing.
     */
    public void setLaplace(boolean laplace) {
        this.laplace = laplace;
    }

    /**
     * Enables or disables progress printing to standard output.
     *
     * @param debug {@code true} to print progress.
     */
    public void setDebug(boolean debug) {
        this.debug = debug;
    }

    /**
     * Sets the number of instances per batch used when streaming the data.
     *
     * @param windowsSize the batch size (presumably must be positive; not validated here).
     */
    @Override
    public void setWindowsSize(int windowsSize) {
        // Previously this setter was a no-op, so the batch size was stuck at its default.
        this.windowsSize = windowsSize;
    }

    /** @return the number of instances per batch used when streaming the data. */
    @Override
    public int getWindowsSize() {
        return this.windowsSize;
    }

    /**
     * Builds the exponential-family network from the DAG and resets the accumulated statistics, either to the
     * initial (Laplace) statistics or to zero statistics with a zero count.
     */
    @Override
    public void initLearning() {
        this.efBayesianNetwork = new EF_BayesianNetwork(this.dag);
        if (this.laplace) {
            this.sumSS = PartialSufficientSatistics.createInitPartialSufficientStatistics(this.efBayesianNetwork);
            this.dataInstanceCount = new AtomicDouble(1.0);
        } else {
            this.sumSS = PartialSufficientSatistics.createZeroPartialSufficientStatistics(this.efBayesianNetwork);
            this.dataInstanceCount = new AtomicDouble(0.0);
        }
    }

    /**
     * Adds the statistics of one batch to the accumulated statistics. Requires {@link #initLearning()} first.
     *
     * @param batch the instances to add; an empty batch leaves the statistics unchanged.
     * @return always {@link Double#NaN}; no score is computed.
     */
    @Override
    public double updateModel(DataOnMemory<DataInstance> batch) {
        this.sumSS.sum(this.computeBatchStatistics(batch));
        this.dataInstanceCount.addAndGet(batch.getNumberOfDataInstances());
        return Double.NaN;
    }

    /**
     * Replaces the accumulated statistics with those computed from the whole stream (plus the Laplace
     * statistics when enabled). Requires {@link #initLearning()} first.
     *
     * @param dataStream the data to learn from.
     * @return always {@link Double#NaN}; no score is computed.
     */
    @Override
    public double updateModel(DataStream<DataInstance> dataStream) {
        this.learnFromStream(dataStream);
        return Double.NaN;
    }

    /** @param data the data used by {@link #runLearning()}. */
    @Override
    public void setDataStream(DataStream<DataInstance> data) {
        this.dataStream = data;
    }

    /** Not supported by this algorithm. */
    @Override
    public double getLogMarginalProbability() {
        throw new UnsupportedOperationException("Method not implemented yet");
    }

    /** Initializes the model and learns its statistics from the stream set by {@link #setDataStream}. */
    @Override
    public void runLearning() {
        this.initLearning();
        this.learnFromStream(this.dataStream);
    }

    /**
     * Computes and stores the statistics of a whole stream, overwriting any previously accumulated ones.
     *
     * @param data the data to process, batch by batch.
     */
    private void learnFromStream(DataStream<DataInstance> data) {
        Stream<DataOnMemory<DataInstance>> batches = this.parallelMode
                ? data.parallelStreamOfBatches(this.windowsSize)
                : data.streamOfBatches(this.windowsSize);
        this.dataInstanceCount = new AtomicDouble(0.0);
        this.sumSS = batches
                .map(batch -> {
                    this.dataInstanceCount.getAndAdd(batch.getNumberOfDataInstances());
                    if (this.debug) {
                        System.out.println("Parallel ML procesando " + (int) this.dataInstanceCount.get() + " instances");
                    }
                    return this.computeBatchStatistics(batch);
                })
                .reduce(PartialSufficientSatistics::sumNonStateless)
                // An empty stream yields neutral statistics instead of a NoSuchElementException.
                .orElseGet(() -> PartialSufficientSatistics.createZeroPartialSufficientStatistics(this.efBayesianNetwork));
        if (this.laplace) {
            this.sumSS.sum(PartialSufficientSatistics.createInitPartialSufficientStatistics(this.efBayesianNetwork));
        }
    }

    /**
     * Sums the per-instance statistics of one batch.
     *
     * @param batch the instances to process.
     * @return fresh statistics; zero statistics with zero counts when the batch is empty.
     */
    private PartialSufficientSatistics computeBatchStatistics(DataOnMemory<DataInstance> batch) {
        return batch.stream()
                .map(dataInstance -> computeCountSufficientStatistics(this.efBayesianNetwork, dataInstance))
                .reduce(PartialSufficientSatistics::sumNonStateless)
                .orElseGet(() -> PartialSufficientSatistics.createZeroPartialSufficientStatistics(this.efBayesianNetwork));
    }

    /**
     * Computes the statistics contributed by a single instance to every distribution of the network.
     *
     * @param bn the exponential-family network.
     * @param dataInstance the instance.
     * @return one entry per distribution: the instance's statistics with count 1 if the distribution's variable
     *     and all its conditioning variables are observed, otherwise an empty entry (null statistics, count 0).
     */
    public static PartialSufficientSatistics computeCountSufficientStatistics(EF_BayesianNetwork bn, DataInstance dataInstance) {
        List<CountVector> list = bn.getDistributionList().stream().map(dist -> {
            if (Utils.isMissingValue(dataInstance.getValue(dist.getVariable()))) {
                return new CountVector();
            }
            for (Variable var : dist.getConditioningVariables()) {
                if (Utils.isMissingValue(dataInstance.getValue(var))) {
                    return new CountVector();
                }
            }
            return new CountVector(dist.getSufficientStatistics(dataInstance));
        }).collect(Collectors.toList());
        return new PartialSufficientSatistics(list);
    }

    /** @param dag_ the structure whose parameters are learnt. */
    @Override
    public void setDAG(DAG dag_) {
        this.dag = dag_;
    }

    /** No-op: this learning procedure uses no randomness. */
    @Override
    public void setSeed(int seed) {
    }

    /**
     * Builds the learnt network. Each distribution's accumulated statistics are copied and divided by that
     * distribution's count, and the result is used as the moment parameters.
     *
     * @return the learnt Bayesian network.
     * @throws IllegalStateException if no learning has been run yet.
     */
    @Override
    public BayesianNetwork getLearntBayesianNetwork() {
        if (this.sumSS == null || this.efBayesianNetwork == null) {
            throw new IllegalStateException("No model has been learnt yet; call runLearning() or initLearning() first");
        }
        PartialSufficientSatistics partialSufficientSatistics = PartialSufficientSatistics.createZeroPartialSufficientStatistics(this.efBayesianNetwork);
        // Work on a copy so that normalization does not alter the accumulated statistics.
        partialSufficientSatistics.copy(this.sumSS);
        partialSufficientSatistics.normalize();
        SufficientStatistics finalSS = this.efBayesianNetwork.createZeroSufficientStatistics();
        finalSS.sum(partialSufficientSatistics.getCompoundVector());
        this.efBayesianNetwork.setMomentParameters(finalSS);
        return this.efBayesianNetwork.toBayesianNetwork(this.dag);
    }

    /** @param parallelMode_ whether batches are processed with a parallel stream. */
    @Override
    public void setParallelMode(boolean parallelMode_) {
        this.parallelMode = parallelMode_;
    }

    /** No-op in this implementation; use {@link #setDebug(boolean)} for progress output. */
    @Override
    public void setOutput(boolean activateOutput) {
    }

    /** Sufficient statistics of one distribution together with the number of instances summed into them. */
    static class CountVector {
        SufficientStatistics sufficientStatistics;
        int count;

        /** Creates an empty entry (no statistics, count 0). */
        public CountVector() {
            this(null, 0);
        }

        /** Creates an entry holding the statistics of a single instance (count 1). */
        public CountVector(SufficientStatistics sufficientStatistics) {
            this(sufficientStatistics, 1);
        }

        /** Creates an entry with an explicit count. */
        CountVector(SufficientStatistics sufficientStatistics, int count) {
            this.sufficientStatistics = sufficientStatistics;
            this.count = count;
        }

        /** Divides the statistics in place by the count; a no-op for empty or zero-count entries. */
        public void normalize() {
            // Dividing by a zero count would produce NaNs, and null statistics have nothing to normalize.
            if (this.sufficientStatistics == null || this.count == 0) {
                return;
            }
            this.sufficientStatistics.divideBy(this.count);
        }

        /**
         * Copies {@code a} into this entry. If this entry has no statistics yet, it shares (aliases) the
         * statistics object of {@code a} rather than copying it.
         */
        public void copy(CountVector a) {
            this.count = a.count;
            if (a.sufficientStatistics == null) {
                this.sufficientStatistics = null;
            } else if (this.sufficientStatistics == null) {
                this.sufficientStatistics = a.sufficientStatistics;
            } else {
                this.sufficientStatistics.copy(a.sufficientStatistics);
            }
        }

        /**
         * Adds {@code a} into this entry; empty entries are ignored. If this entry has no statistics yet, it
         * takes over (aliases) the statistics object of {@code a}.
         */
        public void sum(CountVector a) {
            if (a.sufficientStatistics == null) {
                return;
            }
            this.count += a.count;
            if (this.sufficientStatistics == null) {
                this.sufficientStatistics = a.sufficientStatistics;
            } else {
                this.sufficientStatistics.sum(a.sufficientStatistics);
            }
        }
    }

    /** Per-distribution statistics and counts, one {@link CountVector} per distribution of the network. */
    public static class PartialSufficientSatistics {
        List<CountVector> list;

        public PartialSufficientSatistics(List<CountVector> list) {
            this.list = list;
        }

        /** Creates the initial (Laplace) statistics, each counting as one instance. */
        public static PartialSufficientSatistics createInitPartialSufficientStatistics(EF_BayesianNetwork ef_bayesianNetwork) {
            return new PartialSufficientSatistics(ef_bayesianNetwork.getDistributionList().stream().map(w -> new CountVector(w.createInitSufficientStatistics())).collect(Collectors.toList()));
        }

        /** Creates zero statistics with zero counts, a neutral element for {@link #sum}. */
        public static PartialSufficientSatistics createZeroPartialSufficientStatistics(EF_BayesianNetwork ef_bayesianNetwork) {
            // Count 0: zero statistics represent no observed instance; a count of 1 would bias normalization.
            return new PartialSufficientSatistics(ef_bayesianNetwork.getDistributionList().stream().map(w -> new CountVector(w.createZeroSufficientStatistics(), 0)).collect(Collectors.toList()));
        }

        /** Divides every entry in place by its own count. */
        public void normalize() {
            this.list.forEach(CountVector::normalize);
        }

        /** Copies {@code a} entry by entry into this object (see {@link CountVector#copy}). */
        public void copy(PartialSufficientSatistics a) {
            for (int i = 0; i < this.list.size(); ++i) {
                this.list.get(i).copy(a.list.get(i));
            }
        }

        /** Adds {@code a} entry by entry into this object (see {@link CountVector#sum}). */
        public void sum(PartialSufficientSatistics a) {
            for (int i = 0; i < this.list.size(); ++i) {
                this.list.get(i).sum(a.list.get(i));
            }
        }

        /**
         * Reduction operator: adds {@code a} into {@code b}, mutating and returning {@code b}. Only safe when
         * neither argument is used afterwards, as in a stream reduction over freshly created values.
         */
        public static PartialSufficientSatistics sumNonStateless(PartialSufficientSatistics a, PartialSufficientSatistics b) {
            for (int i = 0; i < b.list.size(); ++i) {
                b.list.get(i).sum(a.list.get(i));
            }
            return b;
        }

        /** @return the statistics of all entries (possibly including nulls for empty entries) as one vector. */
        public CompoundVector getCompoundVector() {
            List<Vector> ssList = this.list.stream().map(a -> a.sufficientStatistics).collect(Collectors.toList());
            return new CompoundVector(ssList);
        }
    }
}

