/*
 * Decompiled with CFR 0.152.
 */
package ru.itmo.ctlab.virgo.sgmwcs.solver;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.Set;
import java.util.function.Consumer;
import java.util.function.ToDoubleFunction;
import ru.itmo.ctlab.virgo.sgmwcs.Signals;
import ru.itmo.ctlab.virgo.sgmwcs.graph.Edge;
import ru.itmo.ctlab.virgo.sgmwcs.graph.Graph;
import ru.itmo.ctlab.virgo.sgmwcs.graph.Node;
import ru.itmo.ctlab.virgo.sgmwcs.graph.Unit;
import ru.itmo.ctlab.virgo.sgmwcs.solver.Utils;

/**
 * Shortest-path search in which the cost of a path is the total magnitude of the negative
 * signal weights it collects, each distinct signal being charged only once.
 *
 * <p>Every unit (node or edge) carries a list of signal ids ({@code Signals.unitSets}). Extending
 * a path by a unit adds the unit's signals to the path's signal set, and only signals that are new
 * to the set and have a negative weight increase the cost, so a path's cost never decreases as it
 * is extended. Besides plain distances the class offers the reduction helpers {@link #solveNP},
 * {@link #solveNE} and {@link #greedyHeuristic}.
 *
 * <p>Instances are stateful: each search overwrites the results of the previous one.
 */
class Dijkstra {
    private final Graph graph;
    private final Signals signals;
    /** Best known path cost of every node reached by the last search. */
    private Map<Node, Double> d;
    /** Signal set collected along the best known path to every reached unit. */
    private Map<Unit, Set<Integer>> p;
    /** Predecessor of every reached node; a search root maps to itself. */
    private Map<Node, Node> path;
    /** Edges through which nodes were relaxed; getPath uses them to pick the edge to a predecessor. */
    private Set<Edge> touched;
    /** The search stops once all of these nodes were dequeued; empty means "explore everything". */
    private Set<Node> dests;
    /** Scratch signal set of the node currently being expanded. */
    private Set<Integer> currentSignals;

    /** Returns the cost of {@link #currentSignals}: the summed magnitude of its negative weights. */
    private double currentWeight() {
        double w = 0.0;
        for (int sig : this.currentSignals) {
            w -= Math.min(0.0, this.signals.weight(sig));
        }
        return w;
    }

    /** Returns the best known cost of {@code n}, or {@link Double#MAX_VALUE} if it was not reached. */
    private double weight(Node n) {
        return this.d.getOrDefault(n, Double.MAX_VALUE);
    }

    Dijkstra(Graph graph, Signals signals) {
        this.graph = graph;
        this.signals = signals;
        this.dests = new HashSet<Node>();
    }

    /**
     * Runs the search from {@code u}, replacing all results of the previous search.
     *
     * <p>Nodes are dequeued in order of increasing cost. The search ends as soon as every node in
     * {@link #dests} has been dequeued. If {@code dests} is empty, or one of its nodes is
     * unreachable, the search runs until the queue is exhausted.
     *
     * @param u the source node. Its cost is 0 and its signal set is seeded with its positive unit
     *     sets only, so its own negative signals are not charged.
     */
    public void solve(Node u) {
        this.d = new HashMap<Node, Double>();
        this.p = new HashMap<Unit, Set<Integer>>();
        this.touched = new HashSet<Edge>();
        this.path = new HashMap<Node, Node>();
        this.currentSignals = new HashSet<Integer>();
        PriorityQueue<Node> q = new PriorityQueue<Node>(Comparator.comparingDouble(this::weight));
        this.d.put(u, 0.0);
        this.p.put(u, new HashSet<Integer>(this.signals.positiveUnitSets((Unit) u)));
        this.path.put(u, u);
        q.add(u);
        List<Integer> addedN = new ArrayList<Integer>();
        List<Integer> addedE = new ArrayList<Integer>();
        Set<Node> visitedDests = new HashSet<Node>();
        Node cur;
        while ((cur = q.poll()) != null) {
            if (visitedDests.contains(cur)) {
                continue;
            }
            if (this.dests.contains(cur) && visitedDests.add(cur) && visitedDests.containsAll(this.dests)) {
                break;
            }
            // Works directly on cur's stored set. Every signal added below is removed again
            // before the next neighbour/edge, so the stored set ends up unchanged.
            this.currentSignals = this.p.getOrDefault(cur, new HashSet<Integer>());
            // Because the set is restored after each neighbour, its cost is loop-invariant.
            double base = this.currentWeight();
            for (Node node : this.graph.neighborListOf(cur)) {
                double viaNode = base + this.addSignals(this.signals.unitSets((Unit) node), addedN);
                List<Edge> edges = this.graph.getAllEdges(node, cur);
                for (Edge edge : edges) {
                    // Computed fresh per parallel edge, rather than add-then-subtract, so that
                    // rounding error cannot accumulate from one edge to the next.
                    double cw = viaNode + this.addSignals(this.signals.unitSets((Unit) edge), addedE);
                    if (cw < this.weight(node)) {
                        // Remove before changing d: the queue orders nodes by d.
                        q.remove(node);
                        this.d.put(node, cw);
                        this.p.put(node, new HashSet<Integer>(this.currentSignals));
                        q.add(node);
                        edges.forEach(this.touched::remove);
                        this.touched.add(edge);
                        // NOTE(review): only the first predecessor through which a node was reached
                        // is recorded; later improvements update d/p/touched but not path. This
                        // keeps the predecessor tree acyclic, but getPath may then not follow the
                        // cheapest route. Confirm this is intended.
                        this.path.putIfAbsent(node, cur);
                    }
                    addedE.forEach(this.currentSignals::remove);
                    addedE.clear();
                }
                addedN.forEach(this.currentSignals::remove);
                addedN.clear();
            }
        }
    }

    /**
     * Adds the signals {@code sets} to {@link #currentSignals}.
     *
     * @param sets signal ids of the unit being stepped onto
     * @param added receives every id that was not already present, so the caller can undo the
     *     addition
     * @return the summed magnitude of the negative weights among the newly added signals
     */
    private double addSignals(List<Integer> sets, List<Integer> added) {
        double cost = 0.0;
        for (int i : sets) {
            if (!this.currentSignals.add(i)) {
                continue;
            }
            added.add(i);
            if (this.signals.weight(i) < 0.0) {
                cost -= this.signals.weight(i);
            }
        }
        return cost;
    }

    /**
     * Tests a node {@code u} with exactly two neighbours {@code v_1}, {@code v_2} by searching
     * from {@code v_1} to {@code v_2}.
     *
     * <p>Presumably this decides whether {@code u} is dominated by the cheapest v_1-v_2 path.
     * TODO confirm against the caller.
     *
     * @return {@code false} if {@code u} does not have exactly two neighbours, or if the signals
     *     collected on the v_1-v_2 path already cover every negative signal of {@code u} and its
     *     incident edges. Otherwise returns {@code true} iff the path covers {@code u}'s remaining
     *     positive signals (those not carried by v_1 or v_2), or its cost is below
     *     {@code -(sum(edgesOf(u)) + weight(u))}.
     */
    boolean solveNP(Node u) {
        List<Node> nbors = this.graph.neighborListOf(u);
        if (nbors.size() != 2) {
            return false;
        }
        Node v_1 = nbors.get(0);
        Node v_2 = nbors.get(1);
        // Reset rather than accumulate: stale targets from earlier calls would only delay the stop.
        this.dests = new HashSet<Node>();
        this.dests.add(v_2);
        this.solve(v_1);
        Set<Integer> reached = this.p.get(v_2);
        HashSet<Integer> neg = new HashSet<Integer>(this.signals.negativeUnitSets((Unit) u));
        neg.addAll(this.signals.negativeUnitSets(this.graph.edgesOf(u)));
        if (reached.containsAll(neg)) {
            return false;
        }
        HashSet<Integer> pos = new HashSet<Integer>(this.signals.positiveUnitSets((Unit) u));
        pos.addAll(this.signals.positiveUnitSets(this.graph.edgesOf(u)));
        pos.removeAll(this.signals.positiveUnitSets(v_1, v_2));
        return reached.containsAll(pos)
                || -(this.signals.sum(this.graph.edgesOf(u)) + this.signals.weight(u)) > this.d.get(v_2);
    }

    /**
     * Searches from {@code u} until every node in {@code neighbors} has been dequeued.
     *
     * <p>For each neighbour {@code n}, the signals of {@code u} and {@code n} are removed from the
     * signal set of the best u-n path (this mutates the stored set). Every edge between {@code n}
     * and {@code u} whose negative signals are not all contained in that reduced set is reported.
     *
     * @param u the search source
     * @param neighbors nodes to stop at; presumably the neighbours of {@code u}, since each must
     *     have been reached
     * @return the edges described above
     */
    Set<Edge> solveNE(Node u, List<Node> neighbors) {
        this.dests = new HashSet<Node>(neighbors);
        this.solve(u);
        HashSet<Edge> res = new HashSet<Edge>();
        for (Node n : neighbors) {
            Set<Integer> reached = this.p.get(n);
            reached.removeAll(this.signals.unitSets((Unit) u, (Unit) n));
            for (Edge e : this.graph.getAllEdges(n, u)) {
                if (!reached.containsAll(this.signals.negativeUnitSets((Unit) e))) {
                    res.add(e);
                }
            }
        }
        return res;
    }

    /** Returns the cost map of the last search (the live internal map, not a copy). */
    Map<Node, Double> distances() {
        return this.d;
    }

    /**
     * Collects the units on the predecessor chain from {@code n} back to a node that is its own
     * predecessor.
     *
     * <p>That node is the search source, or a node that greedyHeuristic marked as absorbed. For
     * each step, the first touched edge between the node and its predecessor (if any) is included.
     * {@code n} must have been reached by the last search; otherwise a NullPointerException is
     * thrown.
     *
     * @return a fresh, mutable set of the path's nodes and edges, including both endpoints
     */
    Set<Unit> getPath(Node n) {
        HashSet<Unit> result = new HashSet<Unit>();
        Node cur = n;
        Node pred = this.path.get(cur);
        while (!pred.equals(cur)) {
            result.add(cur);
            for (Edge e : this.graph.getAllEdges(pred, cur)) {
                if (this.touched.contains(e)) {
                    result.add(e);
                    break;
                }
            }
            cur = pred;
            pred = this.path.get(cur);
        }
        result.add(cur);
        return result;
    }

    /**
     * Greedily grows a root node by absorbing paths whose gain exceeds the root's own weight.
     *
     * <p>Works on a copy of the graph and signals. The search runs from the root. Nodes are
     * visited in decreasing order of gain, and every node whose (re-evaluated) gain still exceeds
     * the root's weight has its path merged into the root. The procedure then recurses on the
     * shrunken copy until no node improves on the root.
     *
     * @param rt the root node (copied via {@code new Node(rt)})
     * @param absorbed returned, as a new set, when no node improves on the root
     * @return the units absorbed into the root
     */
    public Set<Unit> greedyHeuristic(Node rt, List<Unit> absorbed) {
        Node r = new Node(rt);
        Graph graph = new Graph();
        Signals signals = new Signals();
        Utils.copy(this.graph, this.signals, graph, signals);
        List<Node> nodes = new ArrayList<Node>(graph.vertexSet());
        // Drop targets left behind by solveNP/solveNE. getPath is needed for every node, so the
        // search must not stop early.
        this.dests = new HashSet<Node>();
        this.solve(r);
        // Gain of attaching n to r: weight of the distinct signals on the path to n (minus units
        // already absorbed into r), together with those of n and r.
        ToDoubleFunction<Node> w = n -> {
            Set<Unit> route = this.getPath(n);
            r.getAbsorbed().forEach(route::remove);
            Set<Integer> sig = signals.unitSets(route);
            sig.addAll(signals.unitSets((Unit) n, r));
            return signals.weightSum(sig);
        };
        nodes.remove(r);
        // Nodes the search never reached have no path to r; getPath would throw for them.
        nodes.removeIf(n -> !this.path.containsKey(n));
        // Evaluate each gain once. A comparator calling w would recompute getPath per comparison.
        Map<Node, Double> gain = new HashMap<Node, Double>();
        for (Node n : nodes) {
            gain.put(n, w.applyAsDouble(n));
        }
        nodes.sort(Comparator.comparingDouble(n -> gain.get(n)));
        if (nodes.isEmpty() || gain.get(nodes.get(nodes.size() - 1)) <= signals.weight(r)) {
            return new HashSet<Unit>(absorbed);
        }
        for (int i = nodes.size() - 1; i >= 0; --i) {
            Node v = nodes.get(i);
            // w is deliberately re-evaluated here: absorptions made earlier in this loop change
            // the graph copy, the signals and the predecessor map.
            if (r.getAbsorbed().contains(v) || v == r || w.applyAsDouble(v) <= signals.weight(r)) {
                continue;
            }
            Set<Unit> pt = this.getPath(v);
            pt.remove(r);
            r.getAbsorbed().forEach(pt::remove);
            for (Unit unit : pt) {
                this.absorbIntoRoot(r, unit, pt, graph, signals);
            }
        }
        return new Dijkstra(graph, signals).greedyHeuristic(r, r.getAbsorbed());
    }

    /**
     * Merges one path unit {@code u} into the root {@code r} within greedyHeuristic's working copy.
     *
     * <p>If {@code u} is a node, it becomes its own predecessor, so later getPath calls stop
     * there. Its edges to nodes outside {@code pathUnits} are re-attached to {@code r}. Finally
     * {@code u} is removed from the graph copy if it is still present.
     */
    private void absorbIntoRoot(Node r, Unit u, Set<Unit> pathUnits, Graph graph, Signals signals) {
        r.absorb(u, false);
        signals.join(u, r);
        if (u instanceof Node) {
            Node n = (Node) u;
            this.path.put(n, n);
            // Iterate over a snapshot: edges are removed from and added to the graph inside the loop.
            List<Edge> incident = new ArrayList<Edge>(graph.edgesOf(n));
            for (Edge e : incident) {
                Node m = graph.getOppositeVertex(n, e);
                if (pathUnits.contains(m)) {
                    continue;
                }
                graph.removeEdge(e);
                graph.addEdge(r, m, e);
            }
        }
        if (graph.containsUnit(u)) {
            graph.removeUnit(u);
        }
    }
}

