/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.lops.compile.linearization;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.lops.compile.linearization.IDagLinearizer;

public class LinearizerResourceAwareFast
extends IDagLinearizer {
    List<Lop> remaining;

    @Override
    public List<Lop> linearize(List<Lop> dag) {
        ArrayList<List<Lop>> sequences = new ArrayList<List<Lop>>();
        this.remaining = new ArrayList<Lop>(dag);
        List outputNodes = this.remaining.stream().filter(node -> node.getOutputs().isEmpty()).collect(Collectors.toList());
        for (Lop outputNode : outputNodes) {
            sequences.add(this.findSequence(outputNode));
        }
        while (!this.remaining.isEmpty()) {
            int maxLevel = this.remaining.stream().mapToInt(Lop::getLevel).max().getAsInt();
            Lop node2 = this.remaining.stream().filter(n -> n.getLevel() == maxLevel).findFirst().orElseThrow();
            sequences.add(this.findSequence(node2));
        }
        return this.scheduleSequences(sequences);
    }

    List<Lop> scheduleSequences(List<List<Lop>> sequences) {
        HashSet<ArrayList<Integer>> visited = new HashSet<ArrayList<Integer>>();
        ArrayList<Item> scheduledItems = new ArrayList<Item>();
        Set<Dependency> dependencies = this.getDependencies(sequences);
        List sequencesMaxIndex = sequences.stream().map(entry -> entry.size() - 1).collect(Collectors.toList());
        Item currentItem = new Item(new ArrayList<Integer>(), Collections.nCopies(sequences.size(), -1), new HashSet<Intermediate>(), 0.0);
        while (!currentItem.getCurrent().equals(sequencesMaxIndex)) {
            for (int i = 0; i < sequences.size(); ++i) {
                List<Lop> sequence = sequences.get(i);
                if (currentItem.getCurrent().get(i) + 1 >= sequence.size()) continue;
                ArrayList<Integer> newCurrent = new ArrayList<Integer>(currentItem.getCurrent());
                newCurrent.set(i, (Integer)newCurrent.get(i) + 1);
                if (visited.contains(newCurrent)) continue;
                Set filteredDependencies = dependencies.stream().filter(entry -> entry.getNodeIndex() == ((Integer)newCurrent.get(entry.getSequenceIndex())).intValue()).collect(Collectors.toSet());
                boolean dependencyIssue = filteredDependencies.parallelStream().anyMatch(dependency -> IntStream.range(0, newCurrent.size()).anyMatch(j -> j != dependency.getSequenceIndex() && (Integer)newCurrent.get(j) < dependency.getDependencies().get(j)));
                if (!dependencyIssue) {
                    HashSet<Intermediate> newIntermediates = new HashSet<Intermediate>(currentItem.getIntermediates());
                    Lop nextLop = sequence.get((Integer)newCurrent.get(i));
                    Iterator intermediateIter = newIntermediates.iterator();
                    while (intermediateIter.hasNext()) {
                        Intermediate entry2 = (Intermediate)intermediateIter.next();
                        entry2.remove(nextLop.getID());
                        if (!entry2.getLopIDs().isEmpty()) continue;
                        intermediateIter.remove();
                    }
                    newIntermediates.add(new Intermediate(nextLop.getOutputs().stream().map(Lop::getID).collect(Collectors.toList()), nextLop.getOutputMemoryEstimate()));
                    ArrayList<Integer> newSteps = new ArrayList<Integer>(currentItem.getSteps());
                    newSteps.add(i);
                    double mem = newIntermediates.stream().map(Intermediate::getMemoryUsage).reduce(0.0, Double::sum);
                    Item newItem = new Item(newSteps, newCurrent, newIntermediates, Math.max(mem, currentItem.getMaxMemoryUsage()));
                    int index = Collections.binarySearch(scheduledItems, newItem, Comparator.comparing(Item::getMaxMemoryUsage));
                    if (index < 0) {
                        index = -index - 1;
                    }
                    scheduledItems.add(index, newItem);
                }
                visited.add(newCurrent);
            }
            currentItem = (Item)scheduledItems.remove(0);
        }
        return this.walkPath(sequences, currentItem.getSteps());
    }

    List<Lop> walkPath(List<List<Lop>> sequences, List<Integer> path) {
        Iterator<Integer> iterator = path.iterator();
        ArrayList<Lop> sequence = new ArrayList<Lop>();
        while (iterator.hasNext()) {
            sequence.add(sequences.get(iterator.next()).remove(0));
        }
        return sequence;
    }

    List<Lop> findSequence(Lop startNode) {
        ArrayList<Lop> sequence = new ArrayList<Lop>();
        Lop currentNode = startNode;
        sequence.add(currentNode);
        this.remaining.remove(currentNode);
        while (currentNode.getInputs().size() == 1) {
            if (this.remaining.contains(currentNode.getInput(0))) {
                currentNode = currentNode.getInput(0);
                sequence.add(currentNode);
                this.remaining.remove(currentNode);
                continue;
            }
            Collections.reverse(sequence);
            return sequence;
        }
        Collections.reverse(sequence);
        ArrayList<Lop> children = currentNode.getInputs();
        if (children.isEmpty()) {
            return sequence;
        }
        ArrayList<List<Lop>> childSequences = new ArrayList<List<Lop>>();
        for (Lop child : children) {
            if (!this.remaining.contains(child)) continue;
            childSequences.add(this.findSequence(child));
        }
        List<Lop> finalSequence = this.scheduleSequences(childSequences);
        return Stream.concat(finalSequence.stream(), sequence.stream()).collect(Collectors.toList());
    }

    Set<Dependency> getDependencies(List<List<Lop>> sequences) {
        HashSet<Dependency> dependencies = new HashSet<Dependency>();
        List sequencesLopIDs = sequences.stream().map(sequence -> sequence.stream().map(Lop::getID).collect(Collectors.toList())).collect(Collectors.toList());
        int lastSequenceWithOutput = -1;
        for (int j = 0; j < sequences.size(); ++j) {
            List<Lop> sequence2 = sequences.get(j);
            int sequenceSize = sequence2.size();
            int sequenceIndex = j;
            sequence2.get(0).getInputs().forEach(input -> {
                long inputID = input.getID();
                List<Integer> dependencyIndices = sequencesLopIDs.stream().map(list -> list.contains(inputID) ? list.indexOf(inputID) : -1).collect(Collectors.toList());
                dependencies.add(new Dependency(sequenceIndex, 0, dependencyIndices));
            });
            for (int k = 0; k < sequenceSize; ++k) {
                int finalK = k;
                int finalJ = j;
                sequence2.get(k).getInputs().forEach(input -> {
                    long inputID = input.getID();
                    if (!((List)sequencesLopIDs.get(finalJ)).contains(inputID)) {
                        List<Integer> dependencyIndices = sequencesLopIDs.stream().map(list -> list.contains(inputID) ? list.indexOf(inputID) : -1).collect(Collectors.toList());
                        dependencies.add(new Dependency(finalJ, finalK, dependencyIndices));
                    }
                });
            }
            if (!sequence2.get(sequenceSize - 1).getOutputs().isEmpty()) continue;
            if (lastSequenceWithOutput != -1) {
                ArrayList<Integer> dependencyList = new ArrayList<Integer>(Collections.nCopies(sequences.size(), -1));
                dependencyList.set(lastSequenceWithOutput, sequences.get(lastSequenceWithOutput).size() - 1);
                dependencies.add(new Dependency(j, sequenceSize - 1, dependencyList));
            }
            lastSequenceWithOutput = j;
        }
        return dependencies;
    }

    static class Intermediate {
        List<Long> lopIDs;
        double memoryUsage;

        Intermediate(List<Long> lopIDs, double memoryUsage) {
            this.lopIDs = lopIDs;
            this.memoryUsage = memoryUsage;
        }

        void remove(long ID) {
            this.lopIDs.remove(ID);
        }

        public List<Long> getLopIDs() {
            return this.lopIDs;
        }

        public double getMemoryUsage() {
            return this.memoryUsage;
        }
    }

    static class Item {
        List<Integer> steps;
        List<Integer> current;
        Set<Intermediate> intermediates;
        double maxMemoryUsage;

        Item(List<Integer> steps, List<Integer> current, Set<Intermediate> intermediates, double maxMemoryUsage) {
            this.steps = steps;
            this.current = current;
            this.intermediates = intermediates;
            this.maxMemoryUsage = maxMemoryUsage;
        }

        public List<Integer> getSteps() {
            return this.steps;
        }

        public List<Integer> getCurrent() {
            return this.current;
        }

        public double getMaxMemoryUsage() {
            return this.maxMemoryUsage;
        }

        public Set<Intermediate> getIntermediates() {
            return this.intermediates;
        }
    }

    static class Dependency {
        int nodeIndex;
        int sequenceIndex;
        List<Integer> dependencies;

        Dependency(int sequenceIndex, int nodeIndex, List<Integer> dependencies) {
            this.sequenceIndex = sequenceIndex;
            this.nodeIndex = nodeIndex;
            this.dependencies = dependencies;
        }

        public int getSequenceIndex() {
            return this.sequenceIndex;
        }

        public int getNodeIndex() {
            return this.nodeIndex;
        }

        public List<Integer> getDependencies() {
            return this.dependencies;
        }
    }
}

