/*
 * Decompiled with CFR 0.152.
 */
package de.peeeq.wurstscript.translation.imtranslation;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import de.peeeq.datastructures.GraphInterpreter;
import de.peeeq.wurstio.TimeTaker;
import de.peeeq.wurstscript.WurstOperator;
import de.peeeq.wurstscript.jassIm.Element;
import de.peeeq.wurstscript.jassIm.ImExpr;
import de.peeeq.wurstscript.jassIm.ImExprOpt;
import de.peeeq.wurstscript.jassIm.ImExprs;
import de.peeeq.wurstscript.jassIm.ImFuncRef;
import de.peeeq.wurstscript.jassIm.ImFunction;
import de.peeeq.wurstscript.jassIm.ImFunctionCall;
import de.peeeq.wurstscript.jassIm.ImProg;
import de.peeeq.wurstscript.jassIm.ImReturn;
import de.peeeq.wurstscript.jassIm.ImStmt;
import de.peeeq.wurstscript.jassIm.ImStmts;
import de.peeeq.wurstscript.jassIm.ImType;
import de.peeeq.wurstscript.jassIm.ImTypeArgument;
import de.peeeq.wurstscript.jassIm.ImTypeVar;
import de.peeeq.wurstscript.jassIm.ImVar;
import de.peeeq.wurstscript.jassIm.ImVarAccess;
import de.peeeq.wurstscript.jassIm.ImVoid;
import de.peeeq.wurstscript.jassIm.JassIm;
import de.peeeq.wurstscript.translation.imtranslation.CallType;
import de.peeeq.wurstscript.translation.imtranslation.ImHelper;
import de.peeeq.wurstscript.translation.imtranslation.ImTranslator;
import de.peeeq.wurstscript.types.WurstTypeInt;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicReference;

/**
 * Removes mutual recursion from the intermediate program.
 *
 * <p>Every strongly connected component of the call graph that contains more than one function
 * is merged into a single dispatcher function named {@code cyc_<firstFunction>}.
 * <ul>
 *   <li>The dispatcher takes an extra int parameter {@code funcChoice} in first position, followed
 *       by a parameter list shared by all merged functions.</li>
 *   <li>An if/else chain on {@code funcChoice} selects which original body is executed.</li>
 *   <li>Calls to merged functions are rewritten into calls of the dispatcher.</li>
 *   <li>Return values of non-void merged functions are passed back through global
 *       {@code tempReturn_<type>} variables.</li>
 *   <li>Function references to merged functions are redirected to generated {@code <name>_proxy}
 *       wrapper functions.</li>
 * </ul>
 *
 * <p>NOTE(review): presumably required because the JASS backend cannot express mutual recursion
 * between separately declared functions; confirm.
 *
 * <p>The pass mutates {@code prog} and the translator's call relations in place and keeps an
 * unsynchronized cache, so it must run on a single thread.
 */
public class CyclicFunctionRemover {
    /** The program that is transformed in place. */
    private final ImProg prog;
    /** Measures the time spent in the two phases of the pass. */
    private final TimeTaker timeTaker;
    /** Translator providing call relations, type translation and default values. */
    private final ImTranslator tr;
    /** Call-graph view used to find strongly connected components. */
    private final ImFuncGraph graph;
    /**
     * Global variables used to pass return values out of the dispatcher.
     * Keyed by translated type name, so all merged functions returning the same type share one
     * variable. Not synchronized.
     */
    private final Map<String, ImVar> tempReturnVars = Maps.newLinkedHashMap();

    /**
     * Creates the pass.
     *
     * @param tr        translator whose call relations are read and pruned
     * @param prog      program to transform in place
     * @param timeTaker records the time spent in each phase
     */
    public CyclicFunctionRemover(ImTranslator tr, ImProg prog, TimeTaker timeTaker) {
        this.tr = tr;
        this.prog = prog;
        this.timeTaker = timeTaker;
        this.graph = new ImFuncGraph();
    }

    /**
     * Runs the pass.
     *
     * <p>It recomputes the call relations, finds the strongly connected components of the call
     * graph, and merges every component that contains more than one function.
     */
    public void work() {
        tr.calculateCallRelationsAndUsedVariables();
        // The AtomicReference only carries the result out of the timed lambda.
        AtomicReference<Set<Set<ImFunction>>> components = new AtomicReference<>();
        timeTaker.measure("finding cycles",
                () -> components.set(graph.findStronglyConnectedComponents(prog.getFunctions())));
        timeTaker.measure("removing cycles", () -> removeCycles(components.get()));
    }

    /**
     * Merges each component that has at least two functions. Smaller components are skipped.
     *
     * @param components strongly connected components of the call graph
     */
    private void removeCycles(Set<Set<ImFunction>> components) {
        for (Set<ImFunction> component : components) {
            if (component.size() > 1) {
                // Snapshot into a list: a function's index in it becomes its funcChoice value.
                removeCycle(ImmutableList.copyOf(component), component);
            }
        }
    }

    /**
     * Merges the functions of one cyclic component into a single dispatcher function.
     *
     * <p>Afterwards the original functions are removed from {@code prog}. All calls and
     * references to them are redirected to the dispatcher.
     *
     * <p>NOTE(review): the translator's call relations are only pruned here. They do not gain
     * entries for the dispatcher or the proxies, so they are stale until recomputed.
     *
     * @param funcs   the component's functions in a fixed order; the index of a function in this
     *                list is the {@code funcChoice} value that selects it
     * @param funcSet the same functions as a set, used for membership tests
     */
    private void removeCycle(List<ImFunction> funcs, Set<ImFunction> funcSet) {
        List<ImVar> newParameters = Lists.newArrayList();
        Map<ImVar, ImVar> oldToNewVar = Maps.newLinkedHashMap();
        calculateNewParameters(funcs, newParameters, oldToNewVar);

        de.peeeq.wurstscript.ast.Element trace = funcs.get(0).getTrace();
        ImVar choiceVar = JassIm.ImVar(trace, WurstTypeInt.instance().imTranslateType(tr), "funcChoice", false);
        ImFunction newFunc = JassIm.ImFunction(trace, makeName(funcs), JassIm.ImTypeVars(),
                JassIm.ImVars(), JassIm.ImVoid(), JassIm.ImVars(), JassIm.ImStmts(),
                Lists.newArrayList());
        prog.getFunctions().add(newFunc);
        // funcChoice must come first: rewritten call sites pass the function index as first argument.
        newFunc.getParameters().add(choiceVar);
        newFunc.getParameters().addAll(newParameters);

        // Build: if funcChoice == 0 {body0} else { if funcChoice == 1 {body1} else { ... } }
        ImStmts stmts = newFunc.getBody();
        for (int i = 0; i < funcs.size(); i++) {
            ImFunction f = funcs.get(i);
            ImStmts thenBlock = JassIm.ImStmts();
            thenBlock.addAll(f.getBody().removeAll());
            newFunc.getLocals().addAll(f.getLocals().removeAll());
            replaceVars(thenBlock, oldToNewVar);
            if (!(f.getReturnType() instanceof ImVoid)) {
                // The dispatcher is void, so return values go through a temp global instead.
                replaceReturn(thenBlock, f.getReturnType());
            }
            ImStmts elseBlock = JassIm.ImStmts();
            stmts.add(JassIm.ImIf(trace,
                    JassIm.ImOperatorCall(WurstOperator.EQ,
                            JassIm.ImExprs(JassIm.ImVarAccess(choiceVar), JassIm.ImIntVal(i))),
                    thenBlock, elseBlock));
            stmts = elseBlock;
        }

        Map<ImFunction, Integer> funcToIndex = new HashMap<>();
        for (int i = 0; i < funcs.size(); i++) {
            funcToIndex.put(funcs.get(i), i);
        }
        replaceCalls(funcSet, funcToIndex, newFunc, oldToNewVar, prog);

        // Drop the merged functions from the recorded call relations. Copy the key set first,
        // because emptying a caller's collection may remove the key from the live view.
        for (ImFunction caller : new ArrayList<>(tr.getCalledFunctions().keySet())) {
            tr.getCalledFunctions().get(caller).removeAll(funcs);
        }
        prog.getFunctions().removeAll(funcs);
    }

    /**
     * Recursively rewrites every variable access below {@code e} whose variable has a
     * replacement in {@code oldToNewVar}.
     */
    private void replaceVars(Element e, Map<ImVar, ImVar> oldToNewVar) {
        for (int i = 0; i < e.size(); i++) {
            replaceVars(e.get(i), oldToNewVar);
        }
        if (e instanceof ImVarAccess) {
            ImVarAccess va = (ImVarAccess) e;
            ImVar newVar = oldToNewVar.get(va.getVar());
            if (newVar != null) {
                va.setVar(newVar);
            }
        }
    }

    /**
     * Redirects all calls of and references to functions in {@code funcSet} found below
     * {@code e} to the dispatcher {@code newFunc}.
     */
    private void replaceCalls(Set<ImFunction> funcSet, Map<ImFunction, Integer> funcToIndex, ImFunction newFunc, Map<ImVar, ImVar> oldToNewVar, Element e) {
        final List<Element> relevant = new ArrayList<>();
        // Children are visited before the node is recorded, so the list is in post-order.
        e.accept(new Element.DefaultVisitor() {

            @Override
            public void visit(ImFuncRef imFuncRef) {
                super.visit(imFuncRef);
                relevant.add(imFuncRef);
            }

            @Override
            public void visit(ImFunctionCall imFunctionCall) {
                super.visit(imFunctionCall);
                relevant.add(imFunctionCall);
            }
        });
        // Must run sequentially and in post-order.
        // - A call nested in another call's arguments has to be replaced before the outer call
        //   detaches its argument list.
        // - The rewrites also mutate shared state: prog's function/global lists and the
        //   unsynchronized tempReturnVars map.
        for (Element relevantElem : relevant) {
            if (relevantElem instanceof ImFuncRef) {
                replaceImFuncRef(funcSet, funcToIndex, newFunc, oldToNewVar, (ImFuncRef) relevantElem);
            } else if (relevantElem instanceof ImFunctionCall) {
                replaceImFunctionCall(funcSet, funcToIndex, newFunc, oldToNewVar, (ImFunctionCall) relevantElem);
            }
        }
    }

    /**
     * Points a reference to a merged function at a fresh {@code <name>_proxy} function.
     *
     * <p>The proxy has the original signature and forwards its parameters to the original
     * function. That forwarding call is itself rewritten into a dispatcher call.
     */
    private void replaceImFuncRef(Set<ImFunction> funcSet, Map<ImFunction, Integer> funcToIndex, ImFunction newFunc, Map<ImVar, ImVar> oldToNewVar, ImFuncRef fr) {
        ImFunction f = fr.getFunc();
        if (!funcSet.contains(f)) {
            return;
        }
        // Mutable flag list, as for the dispatcher, in case flags are added to the proxy later.
        ImFunction proxyFunc = JassIm.ImFunction(f.attrTrace(), f.getName() + "_proxy", JassIm.ImTypeVars(),
                f.getParameters().copy(), f.getReturnType().copy(), JassIm.ImVars(), JassIm.ImStmts(),
                Lists.newArrayList());
        prog.getFunctions().add(proxyFunc);
        ImExprs arguments = JassIm.ImExprs();
        for (ImVar p : proxyFunc.getParameters()) {
            arguments.add(JassIm.ImVarAccess(p));
        }
        ImFunctionCall call = JassIm.ImFunctionCall(fr.attrTrace(), f, JassIm.ImTypeArguments(), arguments, true, CallType.NORMAL);
        if (f.getReturnType() instanceof ImVoid) {
            proxyFunc.getBody().add(call);
        } else {
            proxyFunc.getBody().add(JassIm.ImReturn(proxyFunc.getTrace(), call));
        }
        // The call must already be attached to the proxy body so that replaceBy can swap it.
        replaceCalls(funcSet, funcToIndex, newFunc, oldToNewVar, call);
        fr.setFunc(proxyFunc);
    }

    /**
     * Replaces a call to a merged function with a call to the dispatcher.
     *
     * <p>The first argument is the function's index. Each dispatcher parameter then gets either
     * the matching original argument or the default value of its type. For non-void functions
     * the result is read from the type's temp return global after the call.
     */
    private void replaceImFunctionCall(Set<ImFunction> funcSet, Map<ImFunction, Integer> funcToIndex, ImFunction newFunc, Map<ImVar, ImVar> oldToNewVar, ImFunctionCall fc) {
        ImFunction oldFunc = fc.getFunc();
        if (!funcSet.contains(oldFunc)) {
            return;
        }
        ImExprs arguments = JassIm.ImExprs();
        arguments.add(JassIm.ImIntVal(funcToIndex.get(oldFunc)));
        List<ImExpr> oldArgs = fc.getArguments().removeAll();
        // Old parameters map to dispatcher parameters in increasing position
        // (see calculateNewParameters), so a single forward scan pairs them up.
        int pos = 0;
        for (int i = 1; i < newFunc.getParameters().size(); i++) {
            ImVar p = (ImVar) newFunc.getParameters().get(i);
            if (pos < oldArgs.size() && oldToNewVar.get(oldFunc.getParameters().get(pos)) == p) {
                arguments.add(oldArgs.get(pos));
                pos++;
            } else {
                arguments.add(tr.getDefaultValueForJassType(p.getType()));
            }
        }
        ImFunctionCall newCall = JassIm.ImFunctionCall(fc.getTrace(), newFunc, JassIm.ImTypeArguments(), arguments, true, CallType.NORMAL);
        ImExpr ret = oldFunc.getReturnType() instanceof ImVoid
                ? newCall
                : JassIm.ImStatementExpr(JassIm.ImStmts(newCall), JassIm.ImVarAccess(getTempReturnVar(oldFunc.getReturnType())));
        fc.replaceBy(ret);
    }

    /**
     * Recursively rewrites every {@code return v} below {@code e} into
     * {@code tempReturn_<type> = v; return}.
     */
    private void replaceReturn(Element e, ImType returnType) {
        for (int i = 0; i < e.size(); i++) {
            replaceReturn(e.get(i), returnType);
        }
        if (e instanceof ImReturn) {
            ImReturn r = (ImReturn) e;
            ImExprOpt returnValue = r.getReturnValue();
            // Detach the value so it can be re-parented into the new assignment.
            returnValue.setParent(null);
            ImStmts stmts = JassIm.ImStmts(
                    JassIm.ImSet(r.getTrace(), JassIm.ImVarAccess(getTempReturnVar(returnType)), (ImExpr) returnValue),
                    JassIm.ImReturn(r.getTrace(), JassIm.ImNoExpr()));
            r.replaceBy(ImHelper.statementExprVoid(stmts));
        }
    }

    /**
     * Returns the global variable that carries return values of type {@code t}.
     *
     * <p>On first use for a translated type name, the variable is created and added to
     * {@code prog}'s globals.
     */
    private ImVar getTempReturnVar(ImType t) {
        String typeName = t.translateType();
        ImVar r = tempReturnVars.get(typeName);
        if (r == null) {
            r = JassIm.ImVar(t.attrTrace(), t, "tempReturn_" + typeName, false);
            prog.getGlobals().add(r);
            tempReturnVars.put(typeName, r);
        }
        return r;
    }

    /** Name of the dispatcher: {@code cyc_} followed by the first merged function's name. */
    private String makeName(List<ImFunction> funcs) {
        return "cyc_" + funcs.get(0).getName();
    }

    /**
     * Computes the dispatcher's shared parameter list, excluding {@code funcChoice}.
     *
     * <p>For each function, each parameter reuses the first not-yet-used dispatcher parameter
     * with the same translated type, searching forward from the last one used by that function.
     * If none is found, a new parameter is appended. The mapping is therefore monotonic per
     * function, which {@link #replaceImFunctionCall} relies on.
     *
     * @param funcs         functions being merged
     * @param newParameters output: the shared parameter list
     * @param oldToNewVar   output: maps every original parameter to its shared parameter
     */
    private void calculateNewParameters(List<ImFunction> funcs, List<ImVar> newParameters, Map<ImVar, ImVar> oldToNewVar) {
        for (ImFunction f : funcs) {
            int pos = 0;
            for (ImVar v : f.getParameters()) {
                String typeName = v.getType().translateType();
                int match = -1;
                for (int i = pos; i < newParameters.size(); i++) {
                    if (newParameters.get(i).getType().translateType().equals(typeName)) {
                        match = i;
                        break;
                    }
                }
                if (match >= 0) {
                    oldToNewVar.put(v, newParameters.get(match));
                    pos = match + 1;
                } else {
                    ImVar newVar = JassIm.ImVar(v.getTrace(), v.getType().copy(), v.getName(), false);
                    oldToNewVar.put(v, newVar);
                    newParameters.add(newVar);
                    // Continue after the parameter just appended, i.e. past the current end.
                    pos = newParameters.size();
                }
            }
        }
    }

    /** Call-graph adapter: the successors of a function are the functions it calls. */
    class ImFuncGraph extends GraphInterpreter<ImFunction> {
        @Override
        protected Collection<ImFunction> getIncidentNodes(ImFunction f) {
            return tr.getCalledFunctions().get(f);
        }
    }
}

