package de.peeeq.wurstscript.translation.imtranslation;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import de.peeeq.datastructures.GraphInterpreter;
import de.peeeq.wurstio.TimeTaker;
import de.peeeq.wurstscript.WurstOperator;
import de.peeeq.wurstscript.ast.Element;
import de.peeeq.wurstscript.jassIm.Element;
import de.peeeq.wurstscript.jassIm.ImExpr;
import de.peeeq.wurstscript.jassIm.ImExprOpt;
import de.peeeq.wurstscript.jassIm.ImExprs;
import de.peeeq.wurstscript.jassIm.ImFuncRef;
import de.peeeq.wurstscript.jassIm.ImFunction;
import de.peeeq.wurstscript.jassIm.ImFunctionCall;
import de.peeeq.wurstscript.jassIm.ImProg;
import de.peeeq.wurstscript.jassIm.ImReturn;
import de.peeeq.wurstscript.jassIm.ImStmt;
import de.peeeq.wurstscript.jassIm.ImStmts;
import de.peeeq.wurstscript.jassIm.ImType;
import de.peeeq.wurstscript.jassIm.ImTypeArgument;
import de.peeeq.wurstscript.jassIm.ImTypeVar;
import de.peeeq.wurstscript.jassIm.ImVar;
import de.peeeq.wurstscript.jassIm.ImVarAccess;
import de.peeeq.wurstscript.jassIm.ImVoid;
import de.peeeq.wurstscript.jassIm.JassIm;
import de.peeeq.wurstscript.types.WurstTypeInt;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicReference;

/* loaded from: input_file:de/peeeq/wurstscript/translation/imtranslation/CyclicFunctionRemover.class */
/**
 * Removes call cycles that span more than one function from an {@link ImProg}.
 *
 * <p>For every strongly connected component of the call graph that contains more than one
 * function, the bodies of all member functions are moved into a single new function whose first
 * parameter ({@code funcChoice}) selects which original body to run. Calls to the original
 * functions are rewritten into calls to the merged function, function references are redirected
 * to small proxy functions, and non-void results are passed back through global
 * {@code tempReturn_*} variables because the merged function itself returns nothing.
 */
public class CyclicFunctionRemover {
    private final ImProg prog;
    private final TimeTaker timeTaker;
    private final ImTranslator tr;
    /** Global variables used to hand back return values, keyed by the translated type name. */
    private final Map<String, ImVar> tempReturnVars = Maps.newLinkedHashMap();
    private final ImFuncGraph graph = new ImFuncGraph();

    /** Call graph view: the nodes incident to a function are the functions it calls. */
    class ImFuncGraph extends GraphInterpreter<ImFunction> {
        ImFuncGraph() {
        }

        @Override
        public Collection<ImFunction> getIncidentNodes(ImFunction imFunction) {
            return CyclicFunctionRemover.this.tr.getCalledFunctions().get(imFunction);
        }
    }

    /**
     * @param imTranslator translator providing call relations, types and default values
     * @param imProg       program whose functions are rewritten in place
     * @param timeTaker    used to measure the two phases of {@link #work()}
     */
    public CyclicFunctionRemover(ImTranslator imTranslator, ImProg imProg, TimeTaker timeTaker) {
        this.tr = imTranslator;
        this.prog = imProg;
        this.timeTaker = timeTaker;
    }

    /**
     * Recomputes the call relations, finds the strongly connected components of the call graph
     * and merges every component that has more than one function.
     */
    public void work() {
        this.tr.calculateCallRelationsAndUsedVariables();
        AtomicReference<Set<Set<ImFunction>>> components = new AtomicReference<>();
        this.timeTaker.measure("finding cycles", () -> {
            components.set(this.graph.findStronglyConnectedComponents(this.prog.getFunctions()));
        });
        this.timeTaker.measure("removing cycles", () -> {
            removeCycles(components);
        });
    }

    /** Merges every component with more than one member; single-function components are kept. */
    private void removeCycles(AtomicReference<Set<Set<ImFunction>>> atomicReference) {
        for (Set<ImFunction> component : atomicReference.get()) {
            if (component.size() > 1) {
                removeCycle(ImmutableList.copyOf(component), component);
            }
        }
    }

    /**
     * Merges the given functions into one new function and rewrites all uses.
     *
     * @param list the functions of the cycle in a fixed order; the position of a function in this
     *             list is the value of {@code funcChoice} that selects its body
     * @param set  the same functions as a set, used for membership tests
     */
    private void removeCycle(List<ImFunction> list, Set<ImFunction> set) {
        List<ImVar> mergedParams = Lists.newArrayList();
        Map<ImVar, ImVar> oldToNewVar = Maps.newLinkedHashMap();
        calculateNewParameters(list, mergedParams, oldToNewVar);

        Element trace = list.get(0).getTrace();
        ImVar choiceVar = JassIm.ImVar(trace, WurstTypeInt.instance().imTranslateType(this.tr), "funcChoice", false);
        ImFunction merged = JassIm.ImFunction(trace, makeName(list), JassIm.ImTypeVars(new ImTypeVar[0]), JassIm.ImVars(new ImVar[0]), JassIm.ImVoid(), JassIm.ImVars(new ImVar[0]), JassIm.ImStmts(new ImStmt[0]), Lists.newArrayList());
        this.prog.getFunctions().add(merged);
        merged.getParameters().add(choiceVar);
        merged.getParameters().addAll(mergedParams);

        // Build a chain "if funcChoice == 0 {...} else { if funcChoice == 1 {...} else {...} }":
        // each new if-statement is appended to the else-branch of the previous one.
        ImStmts insertionPoint = merged.getBody();
        for (int index = 0; index < list.size(); index++) {
            ImFunction original = list.get(index);
            ImStmts thenBlock = JassIm.ImStmts(new ImStmt[0]);
            thenBlock.addAll(original.getBody().removeAll());
            merged.getLocals().addAll(original.getLocals().removeAll());
            replaceVars(thenBlock, oldToNewVar);
            if (!(original.getReturnType() instanceof ImVoid)) {
                replaceReturn(thenBlock, original.getReturnType());
            }
            ImStmts elseBlock = JassIm.ImStmts(new ImStmt[0]);
            insertionPoint.add(JassIm.ImIf(trace, JassIm.ImOperatorCall(WurstOperator.EQ, JassIm.ImExprs(JassIm.ImVarAccess(choiceVar), JassIm.ImIntVal(index))), thenBlock, elseBlock));
            insertionPoint = elseBlock;
        }

        Map<ImFunction, Integer> funcToIndex = new HashMap<>();
        for (int index = 0; index < list.size(); index++) {
            funcToIndex.put(list.get(index), index);
        }
        replaceCalls(set, funcToIndex, merged, oldToNewVar, this.prog);

        // Drop the merged functions from the recorded call relations. The key set is copied
        // first because removing values may change the underlying multimap.
        for (ImFunction caller : Lists.newArrayList(this.tr.getCalledFunctions().keySet())) {
            this.tr.getCalledFunctions().get(caller).removeAll(list);
        }
        this.prog.getFunctions().removeAll(list);
    }

    /**
     * Recursively redirects every variable access below {@code element} whose variable has an
     * entry in {@code map} to the mapped variable. Children are processed before the node itself.
     */
    private void replaceVars(de.peeeq.wurstscript.jassIm.Element element, Map<ImVar, ImVar> map) {
        for (int i = 0; i < element.size(); i++) {
            replaceVars(element.get(i), map);
        }
        if (element instanceof ImVarAccess) {
            ImVarAccess access = (ImVarAccess) element;
            ImVar replacement = map.get(access.getVar());
            if (replacement != null) {
                access.setVar(replacement);
            }
        }
    }

    /**
     * Rewrites all function references and function calls below {@code element} that target a
     * function of the cycle.
     *
     * <p>The nodes are collected first and rewritten afterwards, strictly one after another. The
     * rewrite mutates the tree, the program's function and global lists and
     * {@link #tempReturnVars}; this class does no synchronization, so the work must not be
     * spread over several threads.
     *
     * @param set         the functions of the cycle
     * @param map         index of each cycle function inside the merged function
     * @param imFunction  the merged function
     * @param map2        mapping from original parameters to merged parameters
     * @param element     root of the subtree to rewrite
     */
    private void replaceCalls(Set<ImFunction> set, Map<ImFunction, Integer> map, ImFunction imFunction, Map<ImVar, ImVar> map2, de.peeeq.wurstscript.jassIm.Element element) {
        final List<de.peeeq.wurstscript.jassIm.Element> collected = new ArrayList<>();
        // Fully qualified: de.peeeq.wurstscript.ast.Element is imported under the same simple name.
        element.accept(new de.peeeq.wurstscript.jassIm.Element.DefaultVisitor() {
            @Override
            public void visit(ImFuncRef imFuncRef) {
                super.visit(imFuncRef);
                collected.add(imFuncRef);
            }

            @Override
            public void visit(ImFunctionCall imFunctionCall) {
                super.visit(imFunctionCall);
                collected.add(imFunctionCall);
            }
        });
        for (de.peeeq.wurstscript.jassIm.Element node : collected) {
            if (node instanceof ImFuncRef) {
                replaceImFuncRef(set, map, imFunction, map2, (ImFuncRef) node);
            } else if (node instanceof ImFunctionCall) {
                replaceImFunctionCall(set, map, imFunction, map2, (ImFunctionCall) node);
            }
        }
    }

    /**
     * If {@code imFuncRef} refers to a function of the cycle, creates a {@code <name>_proxy}
     * function with a copy of the original signature that forwards its parameters to the original
     * function, rewrites that forwarding call to target the merged function, and points the
     * reference at the proxy. References to other functions are left untouched.
     */
    private void replaceImFuncRef(Set<ImFunction> set, Map<ImFunction, Integer> map, ImFunction imFunction, Map<ImVar, ImVar> map2, ImFuncRef imFuncRef) {
        ImFunction func = imFuncRef.getFunc();
        if (!set.contains(func)) {
            return;
        }
        ImFunction proxy = JassIm.ImFunction(func.attrTrace(), func.getName() + "_proxy", JassIm.ImTypeVars(new ImTypeVar[0]), func.getParameters().copy(), func.getReturnType().copy(), JassIm.ImVars(new ImVar[0]), JassIm.ImStmts(new ImStmt[0]), Collections.emptyList());
        this.prog.getFunctions().add(proxy);

        ImExprs forwardedArgs = JassIm.ImExprs(new ImExpr[0]);
        for (ImVar param : proxy.getParameters()) {
            forwardedArgs.add(JassIm.ImVarAccess(param));
        }
        ImFunctionCall forwardCall = JassIm.ImFunctionCall(imFuncRef.attrTrace(), func, JassIm.ImTypeArguments(new ImTypeArgument[0]), forwardedArgs, true, CallType.NORMAL);
        if (func.getReturnType() instanceof ImVoid) {
            proxy.getBody().add(forwardCall);
        } else {
            proxy.getBody().add(JassIm.ImReturn(proxy.getTrace(), forwardCall));
        }
        // The forwarding call still targets the original function; rewrite it like any other call.
        replaceCalls(set, map, imFunction, map2, forwardCall);
        imFuncRef.setFunc(proxy);
    }

    /**
     * If {@code imFunctionCall} targets a function of the cycle, replaces it by a call to the
     * merged function.
     *
     * <p>The new argument list starts with the index of the called function. For every further
     * parameter of the merged function, the next unused original argument is passed when the
     * corresponding original parameter is mapped to that merged parameter; otherwise a default
     * value of the parameter's type is passed. For a non-void callee the call is wrapped in a
     * statement expression that yields the temporary return variable of the return type.
     */
    private void replaceImFunctionCall(Set<ImFunction> set, Map<ImFunction, Integer> map, ImFunction imFunction, Map<ImVar, ImVar> map2, ImFunctionCall imFunctionCall) {
        ImFunction func = imFunctionCall.getFunc();
        if (!set.contains(func)) {
            return;
        }
        ImExprs newArgs = JassIm.ImExprs(new ImExpr[0]);
        newArgs.add(JassIm.ImIntVal(map.get(func).intValue()));

        List<ImExpr> oldArgs = imFunctionCall.getArguments().removeAll();
        int nextOldArg = 0;
        // Index 0 of the merged parameters is the choice variable, which was handled above.
        for (int paramIndex = 1; paramIndex < imFunction.getParameters().size(); paramIndex++) {
            ImVar mergedParam = imFunction.getParameters().get(paramIndex);
            boolean takesOldArg = nextOldArg < oldArgs.size()
                    && map2.get(func.getParameters().get(nextOldArg)) == mergedParam;
            if (takesOldArg) {
                newArgs.add(oldArgs.get(nextOldArg));
                nextOldArg++;
            } else {
                newArgs.add(this.tr.getDefaultValueForJassType(mergedParam.getType()));
            }
        }

        ImFunctionCall mergedCall = JassIm.ImFunctionCall(imFunctionCall.getTrace(), imFunction, JassIm.ImTypeArguments(new ImTypeArgument[0]), newArgs, true, CallType.NORMAL);
        imFunctionCall.replaceBy(func.getReturnType() instanceof ImVoid
                ? mergedCall
                : JassIm.ImStatementExpr(JassIm.ImStmts(mergedCall), JassIm.ImVarAccess(getTempReturnVar(func.getReturnType()))));
    }

    /**
     * Recursively replaces every return statement below {@code element} by a statement
     * expression that stores the returned value in the temporary return variable for
     * {@code imType} and then returns without a value. Children are processed before the node
     * itself.
     */
    private void replaceReturn(de.peeeq.wurstscript.jassIm.Element element, ImType imType) {
        for (int i = 0; i < element.size(); i++) {
            replaceReturn(element.get(i), imType);
        }
        if (element instanceof ImReturn) {
            ImReturn ret = (ImReturn) element;
            ImExprOpt returnValue = ret.getReturnValue();
            // Detach the value from the old return so it can be attached to the assignment.
            returnValue.setParent(null);
            ret.replaceBy(ImHelper.statementExprVoid(JassIm.ImStmts(
                    JassIm.ImSet(ret.getTrace(), JassIm.ImVarAccess(getTempReturnVar(imType)), (ImExpr) returnValue),
                    JassIm.ImReturn(ret.getTrace(), JassIm.ImNoExpr()))));
        }
    }

    /**
     * Returns the global variable used to pass back values of the given type, creating it and
     * registering it in the program's globals on first use. Variables are shared between all
     * types with the same translated type name.
     */
    private ImVar getTempReturnVar(ImType imType) {
        String typeName = imType.translateType();
        ImVar tempVar = this.tempReturnVars.get(typeName);
        if (tempVar == null) {
            tempVar = JassIm.ImVar(imType.attrTrace(), imType, "tempReturn_" + typeName, false);
            this.prog.getGlobals().add(tempVar);
            this.tempReturnVars.put(typeName, tempVar);
        }
        return tempVar;
    }

    /** Name of the merged function: {@code "cyc_"} followed by the name of the first function. */
    private String makeName(List<ImFunction> list) {
        return "cyc_" + list.get(0).getName();
    }

    /**
     * Computes the parameter list of the merged function (not including the choice parameter).
     *
     * <p>For each function, its parameters are matched left to right against the merged
     * parameters collected so far: a parameter reuses the first merged parameter with the same
     * translated type that lies after the one used by the previous parameter of the same
     * function. If there is none, a new merged parameter is appended.
     *
     * @param list  the functions of the cycle
     * @param list2 output: the merged parameters, appended to in place
     * @param map   output: maps every original parameter to its merged parameter
     */
    private void calculateNewParameters(List<ImFunction> list, List<ImVar> list2, Map<ImVar, ImVar> map) {
        for (ImFunction func : list) {
            int searchFrom = 0;
            for (ImVar param : func.getParameters()) {
                String paramType = param.getType().translateType();
                ImVar target = null;
                for (int k = searchFrom; k < list2.size(); k++) {
                    ImVar candidate = list2.get(k);
                    if (candidate.getType().translateType().equals(paramType)) {
                        target = candidate;
                        searchFrom = k + 1;
                        break;
                    }
                }
                if (target == null) {
                    target = JassIm.ImVar(param.getTrace(), param.getType().copy(), param.getName(), false);
                    list2.add(target);
                    // Everything collected so far is used up for this function.
                    searchFrom = list2.size();
                }
                map.put(param, target);
            }
        }
    }
}
