/*
 * Decompiled with CFR 0.152.
 */
package eu.amidst.dynamic.inference;

import eu.amidst.core.datastream.DataStream;
import eu.amidst.core.distribution.ConditionalDistribution;
import eu.amidst.core.distribution.Distribution;
import eu.amidst.core.distribution.UnivariateDistribution;
import eu.amidst.core.inference.ImportanceSampling;
import eu.amidst.core.inference.InferenceAlgorithm;
import eu.amidst.core.inference.messagepassing.VMP;
import eu.amidst.core.models.BayesianNetwork;
import eu.amidst.core.utils.Serialization;
import eu.amidst.core.utils.Utils;
import eu.amidst.core.variables.Assignment;
import eu.amidst.core.variables.HashMapAssignment;
import eu.amidst.core.variables.Variable;
import eu.amidst.dynamic.datastream.DynamicDataInstance;
import eu.amidst.dynamic.inference.DynamicVMP;
import eu.amidst.dynamic.inference.InferenceAlgorithmForDBN;
import eu.amidst.dynamic.inference.InferenceEngineForDBN;
import eu.amidst.dynamic.learning.parametric.DynamicNaiveBayesClassifier;
import eu.amidst.dynamic.models.DynamicBayesianNetwork;
import eu.amidst.dynamic.utils.DataSetGenerator;
import eu.amidst.dynamic.variables.DynamicAssignment;
import eu.amidst.dynamic.variables.HashMapDynamicAssignment;
import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.stream.Collectors;

/**
 * Factored-frontier style filtering for dynamic Bayesian networks, built on top of a static
 * {@link InferenceAlgorithm}.
 *
 * <p>The dynamic model is converted into two static networks: one for time slice 0
 * ({@code toBayesianNetworkTime0()}) and one for a generic slice ({@code toBayesianNetworkTimeT()}).
 * For every incoming time step the static algorithm is run on the matching network. Afterwards,
 * the posterior of each unobserved variable that has an interface (previous-slice) parent is
 * copied into the distribution of its interface variable in the time-T network. The next step
 * therefore starts from those posteriors.
 *
 * <p>Instances hold mutable per-sequence state (current time ID, sequence ID, evidence and the
 * modified time-T network). Call {@link #reset()} before feeding a new sequence.
 */
public class FactoredFrontierForDBN implements InferenceAlgorithmForDBN {
    /** Static algorithm run on the time-0 network. */
    private final InferenceAlgorithm infAlgTime0;
    /** Independent deep copy of the static algorithm, run on the time-T network. */
    private final InferenceAlgorithm infAlgTimeT;
    private BayesianNetwork bnTime0;
    private BayesianNetwork bnTimeT;
    private DynamicBayesianNetwork model;
    private DynamicAssignment assignment = new HashMapDynamicAssignment(0);
    /** Time ID of the last processed step; -1 means no step has been processed yet. */
    private long timeID;
    /** Sequence ID fixed by the first evidence after construction/reset; -1 means not fixed yet. */
    private long sequenceID;

    /**
     * Creates a factored-frontier wrapper around a static inference algorithm.
     *
     * @param inferenceAlgorithm algorithm used for time slice 0; a deep copy of it is used for
     *                           time slice T so the two keep independent internal state
     */
    public FactoredFrontierForDBN(InferenceAlgorithm inferenceAlgorithm) {
        this.infAlgTime0 = inferenceAlgorithm;
        this.infAlgTimeT = Serialization.deepCopy(inferenceAlgorithm);
        this.timeID = -1L;
        // Previously left at the field default 0, which rejected any first sequence whose ID != 0.
        this.sequenceID = -1L;
        this.setSeed(0);
    }

    /**
     * Sets the random seed of both underlying static algorithms.
     *
     * @param seed seed forwarded to {@code setSeed} of both algorithms
     */
    public void setSeed(int seed) {
        this.infAlgTime0.setSeed(seed);
        this.infAlgTimeT.setSeed(seed);
    }

    /** Returns true when any time-T parent of {@code var} is an interface variable. */
    private boolean hasInterfaceParent(Variable var) {
        for (Variable parent : this.model.getDynamicDAG().getParentSetTimeT(var)) {
            if (parent.isInterfaceVariable()) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns the non-interface dynamic variables that are missing in the current evidence and
     * have at least one interface parent. Their posteriors are propagated to the next slice.
     */
    private List<Variable> getTargetVarsTimeT() {
        return this.model.getDynamicVariables().getListOfDynamicVariables().stream()
                .filter(v -> !v.isInterfaceVariable())
                .filter(v -> Utils.isMissingValue(this.assignment.getValue(v)))
                .filter(this::hasInterfaceParent)
                .collect(Collectors.toList());
    }

    /** Returns every dynamic variable whose value is missing in the current evidence. */
    private List<Variable> getTargetVarsTime0() {
        return this.model.getDynamicVariables().getListOfDynamicVariables().stream()
                .filter(v -> Utils.isMissingValue(this.assignment.getValue(v)))
                .collect(Collectors.toList());
    }

    /**
     * Processes the evidence most recently passed to {@link #addDynamicEvidence}.
     *
     * <p>If nothing has been processed yet and the evidence is for a time after 0, the time-0
     * network is first run with no evidence to obtain a starting frontier. Time steps skipped
     * between the last processed step and the evidence are rolled forward with no observations.
     */
    @Override
    public void runInference() {
        if (this.timeID == -1L && this.assignment.getTimeID() > 0L) {
            this.infAlgTime0.setModel(this.bnTime0);
            this.infAlgTime0.setEvidence(null);
            this.infAlgTime0.runInference();
            this.timeID = 0L;
            this.getTargetVarsTime0().forEach(v -> this.moveNodeQDist(this.infAlgTime0, this.bnTimeT, v));
        }
        if (this.assignment.getTimeID() == 0L) {
            this.infAlgTime0.setModel(this.bnTime0);
            this.infAlgTime0.setEvidence(this.updateDynamicAssignmentTime0(this.assignment));
            this.infAlgTime0.runInference();
            this.timeID = 0L;
            this.getTargetVarsTimeT().forEach(v -> this.moveNodeQDist(this.infAlgTime0, this.bnTimeT, v));
        } else {
            long gap = this.assignment.getTimeID() - this.timeID;
            if (gap > 1L) {
                // NOTE(review): moveWindow seeds the interface evidence of the first skipped step
                // from this.assignment, which addDynamicEvidence has already replaced with the NEW
                // (later) evidence, not the last processed step's evidence. Verify this is intended.
                this.moveWindow((int) (gap - 1L));
            }
            this.timeID = this.assignment.getTimeID();
            this.infAlgTimeT.setModel(this.bnTimeT);
            this.infAlgTimeT.setEvidence(this.updateDynamicAssignmentTimeT(this.assignment));
            this.infAlgTimeT.runInference();
            this.getTargetVarsTimeT().forEach(v -> this.moveNodeQDist(this.infAlgTimeT, this.bnTimeT, v));
        }
    }

    /**
     * Copies the posterior of {@code var} from {@code infAlg} into the distribution of its
     * interface variable in {@code bnTo}.
     */
    private void moveNodeQDist(InferenceAlgorithm infAlg, BayesianNetwork bnTo, Variable var) {
        Variable temporalClone = this.model.getDynamicVariables().getInterfaceVariable(var);
        UnivariateDistribution posterior = infAlg.getPosterior(var);
        bnTo.setConditionalDistribution(temporalClone, posterior.deepCopy(temporalClone));
    }

    /**
     * Rolls the time-T network forward {@code nsteps} slices with no new observations.
     * {@code infAlgTimeT} then holds the posterior of the last rolled slice.
     *
     * <p>NOTE(review): every step uses the current evidence's values as interface evidence, and
     * only variables missing in {@code this.assignment} get their posteriors propagated. For
     * observed variables with interface parents, steps 2..n are therefore conditioned on the
     * frontier value. Confirm whether that approximation is intended.
     *
     * @param nsteps number of slices to roll; values {@code <= 0} do nothing
     */
    private void moveWindow(int nsteps) {
        if (nsteps <= 0) {
            return;
        }
        HashMapDynamicAssignment stepEvidence = new HashMapDynamicAssignment(this.model.getNumberOfDynamicVars());
        for (Variable var : this.model.getDynamicVariables()) {
            // Without any evidence the frontier is treated as unobserved instead of failing with an NPE.
            double frontierValue = this.assignment != null ? this.assignment.getValue(var) : Utils.missingValue();
            stepEvidence.setValue(this.model.getDynamicVariables().getInterfaceVariable(var), frontierValue);
            stepEvidence.setValue(var, Utils.missingValue());
        }
        Assignment staticEvidence = this.updateDynamicAssignmentTimeT(stepEvidence);
        for (int step = 0; step < nsteps; ++step) {
            this.infAlgTimeT.setModel(this.bnTimeT);
            this.infAlgTimeT.setEvidence(staticEvidence);
            this.infAlgTimeT.runInference();
            this.getTargetVarsTimeT().forEach(v -> this.moveNodeQDist(this.infAlgTimeT, this.bnTimeT, v));
        }
    }

    /**
     * Sets the dynamic model and builds its time-0 and time-T static networks.
     *
     * @param model_ the dynamic Bayesian network to reason with
     */
    @Override
    public void setModel(DynamicBayesianNetwork model_) {
        this.model = model_;
        this.bnTime0 = this.model.toBayesianNetworkTime0();
        this.bnTimeT = this.model.toBayesianNetworkTimeT();
    }

    /** Returns the dynamic model passed to {@link #setModel}. */
    @Override
    public DynamicBayesianNetwork getOriginalModel() {
        return this.model;
    }

    /**
     * Registers the evidence for the next time step. The first evidence after construction or
     * {@link #reset()} fixes the sequence ID.
     *
     * @param assignment_ evidence for one time step of the current sequence
     * @throws IllegalArgumentException if the sequence ID differs from the current sequence, or the
     *                                  time ID is not after the last processed step
     */
    @Override
    public void addDynamicEvidence(DynamicAssignment assignment_) {
        if (this.sequenceID != -1L && this.sequenceID != assignment_.getSequenceID()) {
            throw new IllegalArgumentException("The sequence ID does not match. If you want to change the sequence, invoke reset method");
        }
        if (this.timeID >= assignment_.getTimeID()) {
            throw new IllegalArgumentException("The provided assignment is not posterior to the previous provided assignment.");
        }
        // Record the sequence; previously it was never stored, so the check above never fired.
        this.sequenceID = assignment_.getSequenceID();
        this.assignment = assignment_;
    }

    /** Forgets the processed time steps and sequence, and restores the model's static networks. */
    @Override
    public void reset() {
        this.timeID = -1L;
        this.sequenceID = -1L;
        if (this.model != null) {
            this.resetInfAlgorithms();
        }
    }

    /**
     * Rebuilds both static networks from the model and hands them to the algorithms. Posteriors
     * previously copied into {@code bnTimeT} are discarded rather than leaking into later queries.
     */
    private void resetInfAlgorithms() {
        this.bnTime0 = this.model.toBayesianNetworkTime0();
        this.bnTimeT = this.model.toBayesianNetworkTimeT();
        this.infAlgTime0.setModel(this.bnTime0);
        this.infAlgTimeT.setModel(this.bnTimeT);
    }

    /**
     * Returns the posterior of {@code var} at the last processed time step, read from the time-0
     * algorithm when that step is 0 and from the time-T algorithm otherwise.
     */
    @Override
    public <E extends UnivariateDistribution> E getFilteredPosterior(Variable var) {
        if (this.getTimeIDOfPosterior() == 0L) {
            return this.infAlgTime0.getPosterior(var);
        }
        return this.infAlgTimeT.getPosterior(var);
    }

    /**
     * Returns the posterior of {@code var} {@code nTimesAhead} steps after the last processed step.
     * The time-T network's interface distributions are restored afterwards.
     *
     * <p>NOTE(review): when the last processed step is after 0, {@code infAlgTimeT} keeps the
     * predictive state. {@link #getFilteredPosterior} reflects it until the next runInference.
     *
     * @param var         variable to predict
     * @param nTimesAhead number of steps ahead; values {@code <= 0} return the filtered posterior
     */
    @Override
    public <E extends UnivariateDistribution> E getPredictivePosterior(Variable var, int nTimesAhead) {
        if (this.timeID == -1L) {
            return this.predictFromPrior(var, nTimesAhead);
        }
        if (nTimesAhead <= 0) {
            return this.getFilteredPosterior(var);
        }
        InferenceAlgorithm filteringAlg = this.timeID == 0L ? this.infAlgTime0 : this.infAlgTimeT;
        Map<Variable, UnivariateDistribution> savedFrontier = this.snapshotFrontier(filteringAlg);
        this.moveWindow(nTimesAhead);
        // Read from infAlgTimeT directly: getFilteredPosterior would return the time-0 filtered
        // posterior when timeID == 0, not the prediction just computed.
        E resultQ = this.infAlgTimeT.getPosterior(var);
        savedFrontier.forEach((interfaceVar, frozen) -> this.bnTimeT.setConditionalDistribution(interfaceVar, frozen));
        return resultQ;
    }

    /**
     * Predicts from the time-0 prior when no evidence has been processed yet. One step ahead
     * is slice 0 itself. The static networks are rebuilt afterwards.
     */
    private <E extends UnivariateDistribution> E predictFromPrior(Variable var, int nTimesAhead) {
        this.infAlgTime0.setModel(this.bnTime0);
        this.infAlgTime0.setEvidence(null);
        this.infAlgTime0.runInference();
        E resultQ;
        if (nTimesAhead <= 1) {
            // Previously this read infAlgTimeT, which had not been run for this query.
            resultQ = this.infAlgTime0.getPosterior(var);
        } else {
            this.getTargetVarsTimeT().forEach(v -> this.moveNodeQDist(this.infAlgTime0, this.bnTimeT, v));
            this.moveWindow(nTimesAhead - 1);
            resultQ = this.infAlgTimeT.getPosterior(var);
        }
        this.resetInfAlgorithms();
        return resultQ;
    }

    /**
     * Copies the current posteriors of the propagated variables, converted to their interface
     * variables, so the time-T frontier can be restored after a prediction.
     */
    private Map<Variable, UnivariateDistribution> snapshotFrontier(InferenceAlgorithm source) {
        Map<Variable, UnivariateDistribution> saved = new HashMap<>();
        for (Variable v : this.getTargetVarsTimeT()) {
            Variable interfaceVar = this.model.getDynamicVariables().getInterfaceVariable(v);
            UnivariateDistribution posterior = source.getPosterior(v);
            saved.put(interfaceVar, posterior.deepCopy(interfaceVar));
        }
        return saved;
    }

    /** Returns the time ID of the evidence most recently passed to addDynamicEvidence. */
    @Override
    public long getTimeIDOfLastEvidence() {
        return this.assignment.getTimeID();
    }

    /** Returns the time ID of the last processed step (-1 if none). */
    @Override
    public long getTimeIDOfPosterior() {
        return this.timeID;
    }

    /** Converts dynamic evidence into a static assignment over the dynamic variables of slice 0. */
    private Assignment updateDynamicAssignmentTime0(DynamicAssignment dynamicAssignment) {
        HashMapAssignment staticEvidence = new HashMapAssignment();
        for (Variable v : this.model.getDynamicVariables().getListOfDynamicVariables()) {
            staticEvidence.setValue(v, dynamicAssignment.getValue(v));
        }
        return staticEvidence;
    }

    /**
     * Converts dynamic evidence into a static assignment for slice T. It contains the values of
     * all dynamic variables, plus the interface-variable values of those with an interface parent.
     */
    private Assignment updateDynamicAssignmentTimeT(DynamicAssignment dynamicAssignment) {
        HashMapAssignment staticEvidence = new HashMapAssignment();
        List<Variable> dynamicVars = this.model.getDynamicVariables().getListOfDynamicVariables();
        for (Variable v : dynamicVars) {
            staticEvidence.setValue(v, dynamicAssignment.getValue(v));
        }
        for (Variable v : dynamicVars) {
            if (this.hasInterfaceParent(v)) {
                Variable interfaceVar = v.getInterfaceVariable();
                staticEvidence.setValue(interfaceVar, dynamicAssignment.getValue(interfaceVar));
            }
        }
        return staticEvidence;
    }

    /** Demo: compares dynamic VMP with factored frontier over VMP and importance sampling. */
    public static void main(String[] arguments) throws IOException, ClassNotFoundException {
        DataStream<DynamicDataInstance> data = DataSetGenerator.generate(15, 10000, 10, 0);
        DynamicNaiveBayesClassifier model = new DynamicNaiveBayesClassifier();
        model.setClassVarID(0);
        model.setParallelMode(true);
        model.learn(data);
        DynamicBayesianNetwork bn = model.getDynamicBNModel();
        bn.randomInitialization(new Random(0L));
        System.out.println(bn.toString());
        Variable targetVar = bn.getDynamicVariables().getVariableByName("DiscreteVar0");

        runDemo("---------------- Dynamic VMP----------------", new DynamicVMP(), bn, targetVar);
        runDemo("---------------- FF - VMP--------------", new FactoredFrontierForDBN(new VMP()), bn, targetVar);

        ImportanceSampling importanceSampling = new ImportanceSampling();
        importanceSampling.setKeepDataOnMemory(true);
        runDemo("---------------- FF - Importance Sampling--------------",
                new FactoredFrontierForDBN(importanceSampling), bn, targetVar);
    }

    /**
     * Runs filtering and one-step prediction of {@code targetVar} over freshly generated data,
     * printing both posteriors for every instance.
     */
    private static void runDemo(String header, InferenceAlgorithmForDBN algorithm,
                                DynamicBayesianNetwork bn, Variable targetVar)
            throws IOException, ClassNotFoundException {
        System.out.println(header);
        InferenceEngineForDBN.setInferenceAlgorithmForDBN(algorithm);
        InferenceEngineForDBN.setModel(bn);
        DataStream<DynamicDataInstance> dataPredict = DataSetGenerator.generate(50, 1000, 10, 0);
        UnivariateDistribution dist = null;
        for (DynamicDataInstance instance : dataPredict) {
            if (instance.getTimeID() == 0L && dist != null) {
                System.out.println("\nNew sequence #" + instance.getSequenceID());
                InferenceEngineForDBN.reset();
            }
            instance.setValue(targetVar, Utils.missingValue());
            InferenceEngineForDBN.addDynamicEvidence(instance);
            InferenceEngineForDBN.runInference();
            dist = InferenceEngineForDBN.getFilteredPosterior(targetVar);
            System.out.println("[" + instance.getSequenceID() + "," + instance.getTimeID() + "]" + dist.toString());
            UnivariateDistribution distAhead = InferenceEngineForDBN.getPredictivePosterior(targetVar, 1);
            System.out.println("PP: " + distAhead.toString());
        }
    }
}

