package de.peeeq.wurstscript.translation.imtranslation;

import com.google.common.base.Preconditions;
import com.google.common.collect.LinkedListMultimap;
import com.google.common.collect.Lists;
import com.google.common.collect.Multimap;
import com.google.common.collect.Sets;
import de.peeeq.datastructures.TransitiveClosure;
import de.peeeq.wurstio.TimeTaker;
import de.peeeq.wurstscript.WurstOperator;
import de.peeeq.wurstscript.ast.FunctionDefinition;
import de.peeeq.wurstscript.ast.NameDef;
import de.peeeq.wurstscript.ast.WurstModel;
import de.peeeq.wurstscript.attributes.CompileError;
import de.peeeq.wurstscript.jassIm.Element;
import de.peeeq.wurstscript.jassIm.ImExpr;
import de.peeeq.wurstscript.jassIm.ImExprOpt;
import de.peeeq.wurstscript.jassIm.ImExprs;
import de.peeeq.wurstscript.jassIm.ImFuncRef;
import de.peeeq.wurstscript.jassIm.ImFuncRefOrCall;
import de.peeeq.wurstscript.jassIm.ImFunction;
import de.peeeq.wurstscript.jassIm.ImFunctionCall;
import de.peeeq.wurstscript.jassIm.ImGetStackTrace;
import de.peeeq.wurstscript.jassIm.ImIf;
import de.peeeq.wurstscript.jassIm.ImMethodCall;
import de.peeeq.wurstscript.jassIm.ImNoExpr;
import de.peeeq.wurstscript.jassIm.ImProg;
import de.peeeq.wurstscript.jassIm.ImReturn;
import de.peeeq.wurstscript.jassIm.ImStatementExpr;
import de.peeeq.wurstscript.jassIm.ImStmt;
import de.peeeq.wurstscript.jassIm.ImStmts;
import de.peeeq.wurstscript.jassIm.ImTypeArgument;
import de.peeeq.wurstscript.jassIm.ImTypeVar;
import de.peeeq.wurstscript.jassIm.ImVar;
import de.peeeq.wurstscript.jassIm.ImVarArrayAccess;
import de.peeeq.wurstscript.jassIm.ImVars;
import de.peeeq.wurstscript.jassIm.ImVoid;
import de.peeeq.wurstscript.jassIm.JassIm;
import de.peeeq.wurstscript.jassIm.JassImElementWithName;
import de.peeeq.wurstscript.parser.WPos;
import de.peeeq.wurstscript.types.TypesHelper;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * Instruments an intermediate program so that stack traces can be produced at runtime.
 *
 * <p>The injector adds two globals: {@code wurst_stack_depth}, an integer initialised to 0, and
 * {@code wurst_stack}, a string array. Every affected function except {@code main} and
 * {@code config} gets three changes:
 * <ul>
 *   <li>It receives an extra string parameter named {@value #STACK_POS_PARAM}. Its call sites
 *       pass a description of the call ("when calling f in file:line").</li>
 *   <li>On entry it stores that string in {@code wurst_stack[wurst_stack_depth]} and increments
 *       the depth.</li>
 *   <li>It decrements the depth again before every return.</li>
 * </ul>
 * Every {@link ImGetStackTrace} expression is then replaced by a loop that concatenates the
 * topmost stack entries (at most {@code MAX_STACKTRACE_SIZE} of them).
 *
 * <p>A function counts as affected when any of these holds:
 * <ul>
 *   <li>It contains a stack trace request or a multi-index array access.</li>
 *   <li>It is reached from such a function in the transitive closure of the
 *       "called function -> calling function" relation, i.e. it is one of its transitive
 *       callers.</li>
 *   <li>For the Lua target, it is any function that is not native, BJ or extern.</li>
 * </ul>
 * Function references and ExecuteFunc-style calls to affected functions are redirected to
 * generated {@code bridge_*} functions. These reset the depth to 0 and supply the stack
 * parameter.
 */
public class StackTraceInjector2 {
    /** Maximum number of stack entries concatenated by a single stack trace request. */
    private static final int MAX_STACKTRACE_SIZE = 20;
    /** Name of the synthetic parameter that carries the call-site description into a function. */
    public static final String STACK_POS_PARAM = "__wurst_stackPos";
    private final ImProg prog;
    private final ImTranslator tr;
    /** Global integer holding the current stack depth; created by {@link #transform}. */
    private ImVar stackSize;
    /** Global string array holding one call-site description per active frame. */
    private ImVar stack;
    /**
     * Marker registered for functions that contain a multi-index array access. It makes the
     * function affected, but it is never rewritten into a stack trace string.
     */
    private final ImGetStackTrace dummyGetStackTrace = JassIm.ImGetStackTrace();

    /**
     * @param imProg the program to instrument; it is modified in place by {@link #transform}
     * @param imTranslator the translator, used to query whether the Lua target is active
     */
    public StackTraceInjector2(ImProg imProg, ImTranslator imTranslator) {
        this.prog = imProg;
        this.tr = imTranslator;
    }

    /**
     * Runs the complete instrumentation on the program passed to the constructor.
     *
     * @param timeTaker not used by this implementation
     */
    public void transform(TimeTaker timeTaker) {
        // function -> stack trace requests inside it (dummy entries mark multi-index array accesses)
        final Multimap<ImFunction, ImGetStackTrace> stackTraceGets = LinkedListMultimap.create();
        // called function -> ordinary (non-EXECUTE) calls of it
        final Multimap<ImFunction, ImFunctionCall> calls = LinkedListMultimap.create();
        // called function -> calling function
        final Multimap<ImFunction, ImFunction> callRelation = LinkedListMultimap.create();
        // function references and EXECUTE calls; affected targets get bridge functions
        final List<ImFuncRefOrCall> funcRefs = Lists.newArrayList();

        prog.accept(new Element.DefaultVisitor() {
            @Override
            public void visit(ImGetStackTrace getStackTrace) {
                super.visit(getStackTrace);
                stackTraceGets.put(getStackTrace.getNearestFunc(), getStackTrace);
            }

            @Override
            public void visit(ImVarArrayAccess arrayAccess) {
                super.visit(arrayAccess);
                if (arrayAccess.getIndexes().size() > 1) {
                    // Only marks the function as affected; the marker itself is never rewritten.
                    stackTraceGets.put(arrayAccess.getNearestFunc(), dummyGetStackTrace);
                }
            }

            @Override
            public void visit(ImFunctionCall call) {
                super.visit(call);
                if (call.getCallType() == CallType.EXECUTE) {
                    funcRefs.add(call);
                    return;
                }
                calls.put(call.getFunc(), call);
                callRelation.put(call.getFunc(), call.getNearestFunc());
            }

            @Override
            public void visit(ImFuncRef funcRef) {
                super.visit(funcRef);
                funcRefs.add(funcRef);
            }
        });

        de.peeeq.wurstscript.ast.Element trace = prog.attrTrace();
        stackSize = JassIm.ImVar(trace, TypesHelper.imInt(), "wurst_stack_depth", false);
        prog.getGlobals().add(stackSize);
        stack = JassIm.ImVar(trace, TypesHelper.imStringArray(), "wurst_stack", false);
        prog.getGlobals().add(stack);
        prog.getGlobalInits().put(stackSize,
                Collections.singletonList(JassIm.ImSet(trace, JassIm.ImVarAccess(stackSize), JassIm.ImIntVal(0))));

        // Insertion-ordered so the generated code does not depend on hash order.
        Set<ImFunction> affectedFuncs = Sets.newLinkedHashSet(stackTraceGets.keySet());
        if (tr.isLuaTarget()) {
            // Lua: instrument every function except native, BJ and extern ones.
            Stream.concat(prog.getFunctions().stream(),
                            prog.getClasses().stream().flatMap(c -> c.getFunctions().stream()))
                    .filter(f -> !f.hasFlag(FunctionFlagEnum.IS_NATIVE)
                            && !f.hasFlag(FunctionFlagEnum.IS_BJ)
                            && !f.hasFlag(FunctionFlagEnum.IS_EXTERN))
                    .forEach(affectedFuncs::add);
        } else {
            TransitiveClosure<ImFunction> callers = new TransitiveClosure<>(callRelation);
            // keySet(): each function's closure only needs to be added once
            for (ImFunction f : stackTraceGets.keySet()) {
                callers.get(f).forEach(affectedFuncs::add);
            }
        }

        passStacktraceParams(calls, affectedFuncs);
        addStackTracePush(calls, affectedFuncs);
        addStackTracePop(affectedFuncs);
        rewriteFuncRefs(funcRefs, affectedFuncs);
        rewriteErrorStatements(stackTraceGets);
        rewriteMethodCalls(affectedFuncs);
    }

    /**
     * For the Lua target, adds a call-site description argument to every method call whose
     * implementation is in {@code affectedFuncs}. Does nothing for other targets.
     */
    private void rewriteMethodCalls(final Set<ImFunction> affectedFuncs) {
        if (!tr.isLuaTarget()) {
            return;
        }
        prog.accept(new Element.DefaultVisitor() {
            @Override
            public void visit(ImMethodCall methodCall) {
                super.visit(methodCall);
                if (!affectedFuncs.contains(methodCall.getMethod().getImplementation())) {
                    return;
                }
                int index = getStacktraceIndex(methodCall);
                String description = "when calling " + name(methodCall.getMethod())
                        + getCallPos(methodCall.attrTrace().attrErrorPos());
                methodCall.getArguments().add(index, str(description));
            }
        });
    }

    /**
     * Prepends {@code stack[depth] = <stack param>; depth = depth + 1} to every affected function
     * except main/config.
     *
     * @param calls unused
     * @param affectedFuncs functions to instrument
     */
    private void addStackTracePush(Multimap<ImFunction, ImFunctionCall> calls, Set<ImFunction> affectedFuncs) {
        for (ImFunction f : affectedFuncs) {
            if (isMainOrConfig(f)) {
                continue;
            }
            ImStmts body = f.getBody();
            de.peeeq.wurstscript.ast.Element trace = f.getTrace();
            // Both statements are inserted at index 0, so the store ends up before the increment.
            body.add(0, increment(trace, stackSize));
            body.add(0, JassIm.ImSet(trace,
                    JassIm.ImVarArrayAccess(trace, stack, JassIm.ImExprs(JassIm.ImVarAccess(stackSize))),
                    getStackPosVar(f)));
        }
    }

    /**
     * Returns an access to {@code f}'s stack trace parameter.
     *
     * @throws CompileError if {@code f} has no parameter named {@value #STACK_POS_PARAM}
     */
    private ImExpr getStackPosVar(ImFunction f) {
        ImVar param = f.getParameters().stream()
                .filter(this::isStackTraceParam)
                .findFirst()
                .orElseThrow(() -> new CompileError(f, "Function " + f.getName() + " has no stacktrace parameter."));
        return JassIm.ImVarAccess(param);
    }

    /**
     * Makes every affected function (except main/config) decrement the stack depth before each
     * return.
     *
     * <p>A return value is stored in a fresh temporary first if it contains a function call or a
     * stack trace request. That way it is evaluated before the depth is decremented. If the body
     * does not end in a return on every path, a trailing decrement is appended as well.
     */
    private void addStackTracePop(Set<ImFunction> affectedFuncs) {
        for (ImFunction f : affectedFuncs) {
            if (isMainOrConfig(f)) {
                continue;
            }
            final Set<ImReturn> returns = new LinkedHashSet<>();
            f.getBody().accept(new Element.DefaultVisitor() {
                @Override
                public void visit(ImReturn ret) {
                    super.visit(ret);
                    returns.add(ret);
                }
            });

            int tempCount = 0;
            for (ImReturn ret : returns) {
                de.peeeq.wurstscript.ast.Element trace = ret.getTrace();
                ImExprOpt returnValue = ret.getReturnValue();
                // Detach the value so it can be attached to a newly created node.
                returnValue.setParent(null);
                ImStmts replacement = JassIm.ImStmts(new ImStmt[0]);
                ImReturn newReturn;
                if (returnValue instanceof ImNoExpr || !containsAffectedFunctioncall(returnValue)) {
                    newReturn = JassIm.ImReturn(trace, returnValue);
                } else {
                    tempCount++;
                    String tempName = tempCount > 1
                            ? "stackTrace_tempReturn_" + tempCount
                            : "stackTrace_tempReturn";
                    ImVar temp = JassIm.ImVar(trace, f.getReturnType(), tempName, false);
                    f.getLocals().add(temp);
                    replacement.add(JassIm.ImSet(trace, JassIm.ImVarAccess(temp), (ImExpr) returnValue));
                    newReturn = JassIm.ImReturn(trace, JassIm.ImVarAccess(temp));
                }
                replacement.add(decrement(trace, stackSize));
                replacement.add(newReturn);
                ret.replaceBy(ImHelper.statementExprVoid(replacement));
            }

            if (!returnsOnAllPaths(f.getBody())) {
                f.getBody().add(decrement(f.getTrace(), stackSize));
            }
        }
    }

    /**
     * Conservatively checks whether {@code stmts} ends in a return on every path.
     *
     * <p>Only direct returns, if statements with returning branches, and statement expressions
     * are recognised. Loops never count as returning.
     */
    private boolean returnsOnAllPaths(ImStmts stmts) {
        for (ImStmt s : stmts) {
            if (s instanceof ImReturn) {
                return true;
            }
            if (s instanceof ImIf) {
                ImIf ifStmt = (ImIf) s;
                if (returnsOnAllPaths(ifStmt.getThenBlock()) && returnsOnAllPaths(ifStmt.getElseBlock())) {
                    return true;
                }
            } else if (s instanceof ImStatementExpr
                    && returnsOnAllPaths(((ImStatementExpr) s).getStatements())) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns true if {@code expr} contains any function call or stack trace request. This
     * includes calls to functions that are not affected.
     */
    private boolean containsAffectedFunctioncall(ImExprOpt expr) {
        final boolean[] found = {false};
        expr.accept(new Element.DefaultVisitor() {
            @Override
            public void visit(ImFunctionCall call) {
                super.visit(call);
                found[0] = true;
            }

            @Override
            public void visit(ImGetStackTrace getStackTrace) {
                super.visit(getStackTrace);
                found[0] = true;
            }
        });
        return found[0];
    }

    /** Returns true for the functions named {@code main} and {@code config}, which are never instrumented. */
    private boolean isMainOrConfig(ImFunction f) {
        Preconditions.checkNotNull(f);
        return f.getName().equals("main") || f.getName().equals("config");
    }

    /**
     * Adds the {@value #STACK_POS_PARAM} parameter to every affected function except main/config.
     * Also adds a matching call-site description argument to each of its recorded calls.
     */
    private void passStacktraceParams(Multimap<ImFunction, ImFunctionCall> calls, Set<ImFunction> affectedFuncs) {
        for (ImFunction f : affectedFuncs) {
            if (isMainOrConfig(f)) {
                continue;
            }
            Collection<ImFunctionCall> callsOfF = calls.get(f);
            f.getParameters().add(getStacktraceIndex(f),
                    JassIm.ImVar(f.getTrace(), TypesHelper.imString(), STACK_POS_PARAM, false));
            for (ImFunctionCall call : callsOfF) {
                call.getArguments().add(getStacktraceIndex(call),
                        str("when calling " + name(f) + getCallPos(call.attrTrace().attrErrorPos())));
            }
        }
    }

    /**
     * Position at which the stack trace parameter is inserted into {@code f}'s parameter list.
     * For vararg functions it goes before the last parameter, so the vararg parameter stays last.
     */
    private int getStacktraceIndex(ImFunction f) {
        int paramCount = f.getParameters().size();
        return f.hasFlag(FunctionFlagEnum.IS_VARARG) ? paramCount - 1 : paramCount;
    }

    /**
     * Position of the stack trace argument in {@code call}. It must be computed after the
     * parameter was added to the called function.
     *
     * @throws CompileError if the position is not a valid insertion index into the argument list
     */
    private int getStacktraceIndex(ImFunctionCall call) {
        ImFunction func = call.getFunc();
        boolean isVararg = func.hasFlag(FunctionFlagEnum.IS_VARARG);
        // The stack parameter is the last parameter, or the second to last one for varargs.
        int index = func.getParameters().size() - (isVararg ? 2 : 1);
        if (index < 0 || index > call.getArguments().size()) {
            throw new CompileError(call, "Call " + call + " invalid index " + index + " for parameters "
                    + func.getParameters() + " and isVararg = " + isVararg);
        }
        return index;
    }

    /**
     * Position of the stack trace argument in a method call. The implementation's first
     * parameter is the receiver, which is not part of the call's argument list.
     *
     * @throws CompileError if the position is not a valid insertion index into the argument list
     */
    private int getStacktraceIndex(ImMethodCall methodCall) {
        ImFunction impl = methodCall.getMethod().getImplementation();
        boolean isVararg = impl.hasFlag(FunctionFlagEnum.IS_VARARG);
        int index = impl.getParameters().size() - (isVararg ? 3 : 2);
        if (index < 0 || index > methodCall.getArguments().size()) {
            throw new CompileError(methodCall, "Call " + methodCall + " invalid index " + index + " for parameters "
                    + impl.getParameters() + " and isVararg = " + isVararg);
        }
        return index;
    }

    /** Prefers the source-level function name; otherwise returns the IM element's name. */
    private String name(JassImElementWithName element) {
        NameDef def = element.attrTrace().tryGetNameDef();
        return def instanceof FunctionDefinition ? def.getName() : element.getName();
    }

    /**
     * Formats a source position as {@code " in <pos>"}.
     *
     * @return the formatted position, or an empty string for synthetic files whose name starts
     *     with {@code "<"}
     */
    public static String getCallPos(WPos wPos) {
        return wPos.getFile().startsWith("<") ? "" : " in " + wPos.printShort();
    }

    /** Builds {@code v = v + 1}. */
    private ImStmt increment(de.peeeq.wurstscript.ast.Element trace, ImVar v) {
        return addOne(trace, v, WurstOperator.PLUS);
    }

    /** Builds {@code v = v - 1}. */
    private ImStmt decrement(de.peeeq.wurstscript.ast.Element trace, ImVar v) {
        return addOne(trace, v, WurstOperator.MINUS);
    }

    /** Builds {@code v = v <op> 1}. */
    private ImStmt addOne(de.peeeq.wurstscript.ast.Element trace, ImVar v, WurstOperator op) {
        return JassIm.ImSet(trace, JassIm.ImVarAccess(v),
                JassIm.ImOperatorCall(op, JassIm.ImExprs(JassIm.ImVarAccess(v), JassIm.ImIntVal(1))));
    }

    /**
     * Redirects function references and EXECUTE calls to affected functions through a newly
     * generated {@code bridge_<name>} function.
     *
     * <p>The bridge has the target's parameters minus the stack parameter. It resets the stack
     * depth to 0, calls the target, and returns the result if the return type is not void. The
     * bridge's own parameters are forwarded in order, and a description of the reference site
     * is passed where the target expects its stack parameter. A target without a stack parameter
     * (main/config) gets only the forwarded parameters.
     */
    private void rewriteFuncRefs(List<ImFuncRefOrCall> funcRefs, Set<ImFunction> affectedFuncs) {
        for (ImFuncRefOrCall ref : funcRefs) {
            ImFunction func = ref.getFunc();
            if (!affectedFuncs.contains(func)) {
                continue;
            }
            ImVars bridgeParams = func.getParameters().copy();
            bridgeParams.removeIf(this::isStackTraceParam);
            // Copy the flags so the bridge and the original do not share one mutable list.
            ImFunction bridge = JassIm.ImFunction(func.getTrace(), "bridge_" + func.getName(),
                    JassIm.ImTypeVars(new ImTypeVar[0]), bridgeParams, func.getReturnType().copy(),
                    JassIm.ImVars(new ImVar[0]), JassIm.ImStmts(new ImStmt[0]), new ArrayList<>(func.getFlags()));
            prog.getFunctions().add(bridge);

            de.peeeq.wurstscript.ast.Element trace = ref.attrTrace();
            String description = describeReferenceSite(ref, trace);

            ImExprs args = JassIm.ImExprs(new ImExpr[0]);
            Iterator<ImVar> forwarded = bridgeParams.iterator();
            for (ImVar param : func.getParameters()) {
                if (isStackTraceParam(param)) {
                    args.add(str(description));
                } else {
                    args.add(JassIm.ImVarAccess(forwarded.next()));
                }
            }

            ImStmts body = bridge.getBody();
            // Entering through a reference resets the depth to 0.
            body.add(JassIm.ImSet(trace, JassIm.ImVarAccess(stackSize), JassIm.ImIntVal(0)));
            ImFunctionCall call = JassIm.ImFunctionCall(trace, func,
                    JassIm.ImTypeArguments(new ImTypeArgument[0]), args, true, CallType.NORMAL);
            body.add(bridge.getReturnType() instanceof ImVoid ? call : JassIm.ImReturn(trace, call));
            ref.setFunc(bridge);
        }
    }

    /**
     * Describes where a function reference or EXECUTE call occurs. If its trace is the whole
     * model (no precise position), the name of the enclosing IM function is used instead.
     */
    private String describeReferenceSite(ImFuncRefOrCall ref, de.peeeq.wurstscript.ast.Element trace) {
        String description = ref instanceof ImFuncRef ? "via function reference" : "via ExecuteFunc";
        if (trace instanceof WurstModel) {
            ImFunction enclosing = getEnclosingFunc(ref);
            return enclosing == null ? description : description + " in function " + enclosing.getName();
        }
        return description + " " + trace.attrSource().printShort();
    }

    /** Walks up the parent chain and returns the first {@link ImFunction}, or null if there is none. */
    private ImFunction getEnclosingFunc(Element element) {
        for (Element e = element; e != null; e = e.getParent()) {
            if (e instanceof ImFunction) {
                return (ImFunction) e;
            }
        }
        return null;
    }

    /** Returns true if {@code v} is the synthetic stack trace parameter (matched by name). */
    private boolean isStackTraceParam(ImVar v) {
        return v.getName().equals(STACK_POS_PARAM);
    }

    /**
     * Replaces every real {@link ImGetStackTrace} (markers are skipped) with a statement
     * expression. That expression builds a string from the topmost entries of the stack.
     *
     * <p>The entries are concatenated from the top down, each prefixed by a newline and three
     * spaces, and at most {@code MAX_STACKTRACE_SIZE} of them are used.
     */
    private void rewriteErrorStatements(Multimap<ImFunction, ImGetStackTrace> stackTraceGets) {
        for (Map.Entry<ImFunction, ImGetStackTrace> entry : stackTraceGets.entries()) {
            ImFunction f = entry.getKey();
            ImGetStackTrace getStackTrace = entry.getValue();
            if (getStackTrace == dummyGetStackTrace) {
                continue;
            }
            de.peeeq.wurstscript.ast.Element trace = getStackTrace.attrTrace();
            ImVar traceStr = JassIm.ImVar(trace, TypesHelper.imString(), "stacktraceStr", false);
            f.getLocals().add(traceStr);
            ImVar index = JassIm.ImVar(trace, TypesHelper.imInt(), "stacktraceIndex", false);
            f.getLocals().add(index);
            ImVar limit = JassIm.ImVar(trace, TypesHelper.imInt(), "stacktraceLimit", false);
            f.getLocals().add(limit);

            ImStmts stmts = JassIm.ImStmts(new ImStmt[0]);
            stmts.add(JassIm.ImSet(trace, JassIm.ImVarAccess(traceStr), JassIm.ImStringVal("")));
            stmts.add(JassIm.ImSet(trace, JassIm.ImVarAccess(index), JassIm.ImVarAccess(stackSize)));
            stmts.add(JassIm.ImSet(trace, JassIm.ImVarAccess(limit), JassIm.ImIntVal(0)));

            ImStmts loopBody = JassIm.ImStmts(new ImStmt[0]);
            stmts.add(JassIm.ImLoop(trace, loopBody));
            loopBody.add(decrement(trace, index));
            loopBody.add(increment(trace, limit));
            loopBody.add(JassIm.ImExitwhen(trace, JassIm.ImOperatorCall(WurstOperator.GREATER,
                    JassIm.ImExprs(JassIm.ImVarAccess(limit), JassIm.ImIntVal(MAX_STACKTRACE_SIZE)))));
            loopBody.add(JassIm.ImExitwhen(trace, JassIm.ImOperatorCall(WurstOperator.LESS,
                    JassIm.ImExprs(JassIm.ImVarAccess(index), JassIm.ImIntVal(0)))));
            // traceStr = traceStr + ("\n   " + stack[index])
            ImExpr entryText = JassIm.ImOperatorCall(WurstOperator.PLUS, JassIm.ImExprs(
                    JassIm.ImStringVal("\n   "),
                    JassIm.ImVarArrayAccess(trace, stack, JassIm.ImExprs(JassIm.ImVarAccess(index)))));
            loopBody.add(JassIm.ImSet(trace, JassIm.ImVarAccess(traceStr),
                    JassIm.ImOperatorCall(WurstOperator.PLUS, JassIm.ImExprs(JassIm.ImVarAccess(traceStr), entryText))));

            getStackTrace.replaceBy(JassIm.ImStatementExpr(stmts, JassIm.ImVarAccess(traceStr)));
        }
    }

    /** Wraps {@code s} in an IM string literal. */
    private ImExpr str(String s) {
        return JassIm.ImStringVal(s);
    }
}
