/*
 * Decompiled with CFR 0.152.
 */
package eu.amidst.dynamic.learning.parametric;

import com.google.common.util.concurrent.AtomicDouble;
import eu.amidst.core.datastream.DataOnMemory;
import eu.amidst.core.datastream.DataStream;
import eu.amidst.core.exponentialfamily.SufficientStatistics;
import eu.amidst.core.learning.parametric.ParallelMLMissingData;
import eu.amidst.dynamic.datastream.DynamicDataInstance;
import eu.amidst.dynamic.exponentialfamily.EF_DynamicBayesianNetwork;
import eu.amidst.dynamic.learning.parametric.ParameterLearningAlgorithm;
import eu.amidst.dynamic.models.DynamicBayesianNetwork;
import eu.amidst.dynamic.models.DynamicDAG;
import java.util.stream.Stream;

public class ParallelMLMissingData
implements ParameterLearningAlgorithm {
    protected int windowsSize = 1000;
    protected boolean parallelMode = false;
    protected DataStream<DynamicDataInstance> dataStream;
    protected DynamicDAG dag;
    protected AtomicDouble dataInstanceCount;
    protected DynamicPartialSufficientSatistics sumSS;
    protected EF_DynamicBayesianNetwork efBayesianNetwork;
    protected boolean debug = true;
    protected boolean laplace = true;

    public void setLaplace(boolean laplace) {
        this.laplace = laplace;
    }

    public void setDebug(boolean debug) {
        this.debug = debug;
    }

    @Override
    public void setWindowsSize(int windowsSize) {
    }

    @Override
    public int getWindowsSize() {
        return this.windowsSize;
    }

    @Override
    public void initLearning() {
        this.efBayesianNetwork = new EF_DynamicBayesianNetwork(this.dag);
        if (this.laplace) {
            this.sumSS = DynamicPartialSufficientSatistics.createInitPartialSufficientStatistics(this.efBayesianNetwork);
            this.dataInstanceCount = new AtomicDouble(1.0);
        } else {
            this.sumSS = DynamicPartialSufficientSatistics.createZeroPartialSufficientStatistics(this.efBayesianNetwork);
            this.dataInstanceCount = new AtomicDouble(0.0);
        }
    }

    @Override
    public double updateModel(DataOnMemory<DynamicDataInstance> batch) {
        this.sumSS.sum(batch.stream().map(dataInstance -> ParallelMLMissingData.computeCountSufficientStatistics(this.efBayesianNetwork, dataInstance)).reduce(DynamicPartialSufficientSatistics::sumNonStateless).get());
        this.dataInstanceCount.addAndGet(batch.getNumberOfDataInstances());
        return Double.NaN;
    }

    @Override
    public double updateModel(DataStream<DynamicDataInstance> dataStream) {
        Stream<DataOnMemory<DynamicDataInstance>> stream = null;
        stream = this.parallelMode ? dataStream.parallelStreamOfBatches(this.windowsSize) : dataStream.streamOfBatches(this.windowsSize);
        this.dataInstanceCount = new AtomicDouble(0.0);
        this.sumSS = stream.peek(batch -> {
            this.dataInstanceCount.getAndAdd(batch.getNumberOfDataInstances());
            if (this.debug) {
                System.out.println("Parallel ML procesando " + (int)this.dataInstanceCount.get() + " instances");
            }
        }).map(batch -> batch.stream().map(dataInstance -> ParallelMLMissingData.computeCountSufficientStatistics(this.efBayesianNetwork, dataInstance)).reduce(DynamicPartialSufficientSatistics::sumNonStateless).get()).reduce(DynamicPartialSufficientSatistics::sumNonStateless).get();
        if (this.laplace) {
            DynamicPartialSufficientSatistics initSS = DynamicPartialSufficientSatistics.createInitPartialSufficientStatistics(this.efBayesianNetwork);
            this.sumSS.sum(initSS);
        }
        return Double.NaN;
    }

    @Override
    public void setDataStream(DataStream<DynamicDataInstance> data) {
        this.dataStream = data;
    }

    @Override
    public double getLogMarginalProbability() {
        throw new UnsupportedOperationException("Method not implemented yet");
    }

    @Override
    public void runLearning() {
        this.initLearning();
        Stream<DataOnMemory<DynamicDataInstance>> stream = null;
        stream = this.parallelMode ? this.dataStream.parallelStreamOfBatches(this.windowsSize) : this.dataStream.streamOfBatches(this.windowsSize);
        this.dataInstanceCount = new AtomicDouble(0.0);
        this.sumSS = stream.peek(batch -> {
            this.dataInstanceCount.getAndAdd(batch.getNumberOfDataInstances());
            if (this.debug) {
                System.out.println("Parallel ML procesando " + (int)this.dataInstanceCount.get() + " instances");
            }
        }).map(batch -> batch.stream().map(dataInstance -> ParallelMLMissingData.computeCountSufficientStatistics(this.efBayesianNetwork, dataInstance)).reduce(DynamicPartialSufficientSatistics::sumNonStateless).get()).reduce(DynamicPartialSufficientSatistics::sumNonStateless).get();
        if (this.laplace) {
            DynamicPartialSufficientSatistics initSS = DynamicPartialSufficientSatistics.createInitPartialSufficientStatistics(this.efBayesianNetwork);
            this.sumSS.sum(initSS);
        }
    }

    private static DynamicPartialSufficientSatistics computeCountSufficientStatistics(EF_DynamicBayesianNetwork bn, DynamicDataInstance dataInstance) {
        if (dataInstance.getTimeID() == 0L) {
            return DynamicPartialSufficientSatistics.createPartialSufficientStatisticsTime0(eu.amidst.core.learning.parametric.ParallelMLMissingData.computeCountSufficientStatistics(bn.getBayesianNetworkTime0(), dataInstance));
        }
        return DynamicPartialSufficientSatistics.createPartialSufficientStatisticsTimeT(eu.amidst.core.learning.parametric.ParallelMLMissingData.computeCountSufficientStatistics(bn.getBayesianNetworkTimeT(), dataInstance));
    }

    @Override
    public void setDynamicDAG(DynamicDAG dag_) {
        this.dag = dag_;
    }

    @Override
    public void setSeed(int seed) {
    }

    @Override
    public DynamicBayesianNetwork getLearntDBN() {
        DynamicPartialSufficientSatistics partialSufficientSatistics = DynamicPartialSufficientSatistics.createZeroPartialSufficientStatistics(this.efBayesianNetwork);
        partialSufficientSatistics.copy(this.sumSS);
        partialSufficientSatistics.normalize();
        SufficientStatistics finalSS = this.efBayesianNetwork.createZeroSufficientStatistics();
        finalSS.sum(partialSufficientSatistics.getCompoundVector());
        this.efBayesianNetwork.setMomentParameters(finalSS);
        return this.efBayesianNetwork.toDynamicBayesianNetwork(this.dag);
    }

    @Override
    public void setParallelMode(boolean parallelMode_) {
        this.parallelMode = parallelMode_;
    }

    @Override
    public void setOutput(boolean activateOutput) {
    }

    static class DynamicPartialSufficientSatistics {
        ParallelMLMissingData.PartialSufficientSatistics time0;
        ParallelMLMissingData.PartialSufficientSatistics timeT;

        public DynamicPartialSufficientSatistics(ParallelMLMissingData.PartialSufficientSatistics time0, ParallelMLMissingData.PartialSufficientSatistics timeT) {
            this.time0 = time0;
            this.timeT = timeT;
        }

        public static DynamicPartialSufficientSatistics createPartialSufficientStatisticsTime0(ParallelMLMissingData.PartialSufficientSatistics partialSufficientSatistics) {
            return new DynamicPartialSufficientSatistics(partialSufficientSatistics, null);
        }

        public static DynamicPartialSufficientSatistics createPartialSufficientStatisticsTimeT(ParallelMLMissingData.PartialSufficientSatistics partialSufficientSatistics) {
            return new DynamicPartialSufficientSatistics(null, partialSufficientSatistics);
        }

        public static DynamicPartialSufficientSatistics createInitPartialSufficientStatistics(EF_DynamicBayesianNetwork ef_bayesianNetwork) {
            return new DynamicPartialSufficientSatistics(ParallelMLMissingData.PartialSufficientSatistics.createInitPartialSufficientStatistics(ef_bayesianNetwork.getBayesianNetworkTime0()), ParallelMLMissingData.PartialSufficientSatistics.createInitPartialSufficientStatistics(ef_bayesianNetwork.getBayesianNetworkTimeT()));
        }

        public static DynamicPartialSufficientSatistics createZeroPartialSufficientStatistics(EF_DynamicBayesianNetwork ef_bayesianNetwork) {
            return new DynamicPartialSufficientSatistics(ParallelMLMissingData.PartialSufficientSatistics.createZeroPartialSufficientStatistics(ef_bayesianNetwork.getBayesianNetworkTime0()), ParallelMLMissingData.PartialSufficientSatistics.createZeroPartialSufficientStatistics(ef_bayesianNetwork.getBayesianNetworkTimeT()));
        }

        public ParallelMLMissingData.PartialSufficientSatistics getTime0() {
            return this.time0;
        }

        public ParallelMLMissingData.PartialSufficientSatistics getTimeT() {
            return this.timeT;
        }

        public void normalize() {
            this.time0.normalize();
            this.timeT.normalize();
        }

        public void copy(DynamicPartialSufficientSatistics a) {
            this.time0.copy(a.getTime0());
            this.timeT.copy(a.getTimeT());
        }

        public void sum(DynamicPartialSufficientSatistics a) {
            if (a.getTime0() != null) {
                this.time0.sum(a.getTime0());
            }
            if (a.getTimeT() != null) {
                this.timeT.sum(a.getTimeT());
            }
        }

        public static DynamicPartialSufficientSatistics sumNonStateless(DynamicPartialSufficientSatistics a, DynamicPartialSufficientSatistics b) {
            if (b.getTime0() == null) {
                b.time0 = a.getTime0();
            } else if (a.getTime0() != null) {
                b.getTime0().sum(a.getTime0());
            }
            if (b.getTimeT() == null) {
                b.timeT = a.getTimeT();
            } else if (a.getTimeT() != null) {
                b.getTimeT().sum(a.getTimeT());
            }
            return b;
        }

        public EF_DynamicBayesianNetwork.DynamiceBNCompoundVector getCompoundVector() {
            EF_DynamicBayesianNetwork.DynamiceBNCompoundVector vector = new EF_DynamicBayesianNetwork.DynamiceBNCompoundVector(this.getTime0().getCompoundVector(), this.getTimeT().getCompoundVector());
            vector.setIndicatorTime0(1.0);
            vector.setIndicatorTimeT(1.0);
            return vector;
        }
    }
}

