/*
 * Decompiled with CFR 0.152.
 */
package de.peeeq.wurstscript.translation.imtranslation;

import com.google.common.base.Preconditions;
import com.google.common.collect.LinkedListMultimap;
import com.google.common.collect.Lists;
import com.google.common.collect.Multimap;
import com.google.common.collect.Sets;
import de.peeeq.datastructures.TransitiveClosure;
import de.peeeq.wurstio.TimeTaker;
import de.peeeq.wurstscript.WurstOperator;
import de.peeeq.wurstscript.ast.FunctionDefinition;
import de.peeeq.wurstscript.ast.NameDef;
import de.peeeq.wurstscript.ast.WurstModel;
import de.peeeq.wurstscript.attributes.CompileError;
import de.peeeq.wurstscript.jassIm.Element;
import de.peeeq.wurstscript.jassIm.ImExpr;
import de.peeeq.wurstscript.jassIm.ImExprOpt;
import de.peeeq.wurstscript.jassIm.ImExprs;
import de.peeeq.wurstscript.jassIm.ImFuncRef;
import de.peeeq.wurstscript.jassIm.ImFuncRefOrCall;
import de.peeeq.wurstscript.jassIm.ImFunction;
import de.peeeq.wurstscript.jassIm.ImFunctionCall;
import de.peeeq.wurstscript.jassIm.ImGetStackTrace;
import de.peeeq.wurstscript.jassIm.ImIf;
import de.peeeq.wurstscript.jassIm.ImMethodCall;
import de.peeeq.wurstscript.jassIm.ImNoExpr;
import de.peeeq.wurstscript.jassIm.ImProg;
import de.peeeq.wurstscript.jassIm.ImReturn;
import de.peeeq.wurstscript.jassIm.ImStatementExpr;
import de.peeeq.wurstscript.jassIm.ImStmt;
import de.peeeq.wurstscript.jassIm.ImStmts;
import de.peeeq.wurstscript.jassIm.ImTypeArgument;
import de.peeeq.wurstscript.jassIm.ImTypeVar;
import de.peeeq.wurstscript.jassIm.ImVar;
import de.peeeq.wurstscript.jassIm.ImVarArrayAccess;
import de.peeeq.wurstscript.jassIm.ImVars;
import de.peeeq.wurstscript.jassIm.ImVoid;
import de.peeeq.wurstscript.jassIm.JassIm;
import de.peeeq.wurstscript.jassIm.JassImElementWithName;
import de.peeeq.wurstscript.parser.WPos;
import de.peeeq.wurstscript.translation.imtranslation.CallType;
import de.peeeq.wurstscript.translation.imtranslation.FunctionFlagEnum;
import de.peeeq.wurstscript.translation.imtranslation.ImHelper;
import de.peeeq.wurstscript.translation.imtranslation.ImTranslator;
import de.peeeq.wurstscript.types.TypesHelper;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.eclipse.jdt.annotation.Nullable;

/**
 * Injects a software call stack into the intermediate program so that {@code ImGetStackTrace}
 * expressions can be turned into a readable trace string at runtime.
 *
 * <p>The pass adds two globals: {@code wurst_stack_depth} (int, initialized to 0) and
 * {@code wurst_stack} (string array). Every "affected" function except {@code main} and
 * {@code config} receives an extra string parameter named {@value #STACK_POS_PARAM}.
 * <ul>
 *   <li>On entry the function stores that parameter at
 *       {@code wurst_stack[wurst_stack_depth]} and increments the depth.</li>
 *   <li>It decrements the depth before every return and, if needed, at the end of its body.</li>
 *   <li>Every direct call of an affected function passes a {@code "when calling ..."} message as
 *       the extra argument.</li>
 * </ul>
 * Function references and {@code EXECUTE} calls to affected functions are redirected through
 * generated {@code bridge_<name>} functions.
 */
public class StackTraceInjector2 {
    /** Maximum number of stack entries rendered by a rewritten {@code ImGetStackTrace}. */
    private static final int MAX_STACKTRACE_SIZE = 20;
    /** Name of the synthetic parameter that carries the call-site description into a function. */
    public static final String STACK_POS_PARAM = "__wurst_stackPos";
    private final ImProg prog;
    private final ImTranslator tr;
    /** Global int holding the current stack depth; created in {@link #transform}. */
    private ImVar stackSize;
    /** Global string array holding one call-site description per depth; created in {@link #transform}. */
    private ImVar stack;
    /**
     * Placeholder recorded for functions that need the stack machinery without containing a real
     * {@code ImGetStackTrace}. It is skipped when rewriting.
     */
    private final ImGetStackTrace dummyGetStackTrace = JassIm.ImGetStackTrace();

    public StackTraceInjector2(ImProg prog, ImTranslator tr) {
        this.prog = prog;
        this.tr = tr;
    }

    /**
     * Runs the whole transformation on the program passed to the constructor.
     *
     * @param timeTaker currently unused by this pass
     */
    public void transform(TimeTaker timeTaker) {
        // Maps each function to the stack-trace reads it contains. Dummy entries only mark a
        // function as needing the stack.
        final Multimap<ImFunction, ImGetStackTrace> stackTraceGets = LinkedListMultimap.create();
        // Maps each called function to its direct (non-EXECUTE) calls.
        final Multimap<ImFunction, ImFunctionCall> calls = LinkedListMultimap.create();
        // Maps each called function to the function that calls it.
        final Multimap<ImFunction, ImFunction> callRelation = LinkedListMultimap.create();
        // Function references and EXECUTE calls; these are redirected by rewriteFuncRefs.
        final List<ImFuncRefOrCall> funcRefs = Lists.newArrayList();

        prog.accept(new Element.DefaultVisitor() {
            @Override
            public void visit(ImGetStackTrace e) {
                super.visit(e);
                stackTraceGets.put(e.getNearestFunc(), e);
            }

            @Override
            public void visit(ImVarArrayAccess va) {
                super.visit(va);
                if (va.getIndexes().size() > 1) {
                    // Multi-dimensional array accesses also mark the enclosing function as needing
                    // the stack (NOTE(review): presumably for a generated error message — confirm).
                    stackTraceGets.put(va.getNearestFunc(), dummyGetStackTrace);
                }
            }

            @Override
            public void visit(ImFunctionCall c) {
                super.visit(c);
                if (c.getCallType() == CallType.EXECUTE) {
                    funcRefs.add(c);
                } else {
                    calls.put(c.getFunc(), c);
                    ImFunction caller = c.getNearestFunc();
                    callRelation.put(c.getFunc(), caller);
                }
            }

            @Override
            public void visit(ImFuncRef imFuncRef) {
                super.visit(imFuncRef);
                funcRefs.add(imFuncRef);
            }
        });

        de.peeeq.wurstscript.ast.Element trace = prog.attrTrace();
        stackSize = JassIm.ImVar(trace, TypesHelper.imInt(), "wurst_stack_depth", false);
        prog.getGlobals().add(stackSize);
        stack = JassIm.ImVar(trace, TypesHelper.imStringArray(), "wurst_stack", false);
        prog.getGlobals().add(stack);
        prog.getGlobalInits().put(stackSize,
                Collections.singletonList(JassIm.ImSet(trace, JassIm.ImVarAccess(stackSize), JassIm.ImIntVal(0))));

        TransitiveClosure<ImFunction> callRelationTr = new TransitiveClosure<ImFunction>(callRelation);
        Set<ImFunction> affectedFuncs = Sets.newLinkedHashSet(stackTraceGets.keySet());
        if (tr.isLuaTarget()) {
            // Lua: every non-native, non-BJ, non-extern function takes part.
            Stream.concat(prog.getFunctions().stream(),
                            prog.getClasses().stream().flatMap(c -> c.getFunctions().stream()))
                    .filter(f -> !f.hasFlag(FunctionFlagEnum.IS_NATIVE)
                            && !f.hasFlag(FunctionFlagEnum.IS_BJ)
                            && !f.hasFlag(FunctionFlagEnum.IS_EXTERN))
                    .forEach(affectedFuncs::add);
        } else {
            // Other targets: functions reading the stack plus everything reachable from them in the
            // callee -> caller relation (i.e. presumably their transitive callers).
            // keySet() avoids recomputing the closure for duplicate keys.
            for (ImFunction stackTraceUse : stackTraceGets.keySet()) {
                callRelationTr.get(stackTraceUse).forEach(affectedFuncs::add);
            }
        }

        // Order matters: parameters must exist before push/pop and call rewriting use them.
        passStacktraceParams(calls, affectedFuncs);
        addStackTracePush(calls, affectedFuncs);
        addStackTracePop(affectedFuncs);
        rewriteFuncRefs(funcRefs, affectedFuncs);
        rewriteErrorStatements(stackTraceGets);
        rewriteMethodCalls(affectedFuncs);
    }

    /**
     * On the Lua target, adds the call-site argument to every method call whose implementation is
     * affected. On other targets this is a no-op.
     */
    private void rewriteMethodCalls(final Set<ImFunction> affectedFuncs) {
        if (!tr.isLuaTarget()) {
            return;
        }
        prog.accept(new Element.DefaultVisitor() {
            @Override
            public void visit(ImMethodCall call) {
                super.visit(call);
                if (affectedFuncs.contains(call.getMethod().getImplementation())) {
                    String callPos = getCallPos(call.attrTrace().attrErrorPos());
                    call.getArguments().add(getStacktraceIndex(call),
                            str("when calling " + name(call.getMethod()) + callPos));
                }
            }
        });
    }

    /**
     * Prepends the push sequence to every affected function except main/config.
     * The resulting body starts with {@code stack[depth] = <pos param>; depth = depth + 1}.
     */
    private void addStackTracePush(Multimap<ImFunction, ImFunctionCall> calls, Set<ImFunction> affectedFuncs) {
        for (ImFunction f : affectedFuncs) {
            if (isMainOrConfig(f)) {
                continue;
            }
            ImStmts stmts = f.getBody();
            de.peeeq.wurstscript.ast.Element trace = f.getTrace();
            // Both inserts go to index 0, so the store ends up before the increment.
            stmts.add(0, increment(trace, stackSize));
            stmts.add(0, JassIm.ImSet(trace,
                    JassIm.ImVarArrayAccess(trace, stack, JassIm.ImExprs(JassIm.ImVarAccess(stackSize))),
                    getStackPosVar(f)));
        }
    }

    /**
     * Returns an access to {@code f}'s stack-position parameter.
     *
     * @throws CompileError if {@code f} has no such parameter
     */
    private ImExpr getStackPosVar(ImFunction f) {
        ImVar posParam = f.getParameters().stream()
                .filter(this::isStackTraceParam)
                .findFirst()
                .orElseThrow(() -> new CompileError(f, "Function " + f.getName() + " has no stacktrace parameter."));
        return JassIm.ImVarAccess(posParam);
    }

    /**
     * Makes every affected function except main/config decrement the stack depth when it leaves.
     * <ul>
     *   <li>Each {@code return} is replaced by a statement expression that decrements and then
     *       returns.</li>
     *   <li>If the returned value contains a function call or stack-trace read, it is first
     *       evaluated into a temporary, so the value is computed before the depth changes.</li>
     *   <li>If the body does not return on all paths, a final decrement is appended.</li>
     * </ul>
     */
    private void addStackTracePop(Set<ImFunction> affectedFuncs) {
        for (ImFunction f : affectedFuncs) {
            if (isMainOrConfig(f)) {
                continue;
            }
            final Set<ImReturn> returns = new LinkedHashSet<>();
            f.getBody().accept(new Element.DefaultVisitor() {
                @Override
                public void visit(ImReturn imReturn) {
                    super.visit(imReturn);
                    returns.add(imReturn);
                }
            });

            int tempCount = 0;
            for (ImReturn ret : returns) {
                ImStmts stmts = JassIm.ImStmts(new ImStmt[0]);
                ImExprOpt returnedValue = ret.getReturnValue();
                // Detach the value so it can be re-parented into the new statements.
                returnedValue.setParent(null);
                de.peeeq.wurstscript.ast.Element trace = ret.getTrace();
                ImReturn newReturn;
                if (returnedValue instanceof ImNoExpr || !containsAffectedFunctioncall(returnedValue)) {
                    newReturn = JassIm.ImReturn(trace, returnedValue);
                } else {
                    tempCount++;
                    String tempReturnName = "stackTrace_tempReturn";
                    if (tempCount > 1) {
                        tempReturnName += "_" + tempCount;
                    }
                    ImVar temp = JassIm.ImVar(trace, f.getReturnType(), tempReturnName, false);
                    f.getLocals().add(temp);
                    stmts.add(JassIm.ImSet(trace, JassIm.ImVarAccess(temp), (ImExpr) returnedValue));
                    newReturn = JassIm.ImReturn(trace, JassIm.ImVarAccess(temp));
                }
                stmts.add(decrement(trace, stackSize));
                stmts.add(newReturn);
                ret.replaceBy(ImHelper.statementExprVoid(stmts));
            }

            if (!returnsOnAllPaths(f.getBody())) {
                f.getBody().add(decrement(f.getTrace(), stackSize));
            }
        }
    }

    /**
     * Conservatively checks whether {@code body} ends in a return on every path.
     * Only top-level returns, if/else with both branches returning, and statement expressions
     * (which is how rewritten returns are wrapped) are considered.
     */
    private boolean returnsOnAllPaths(ImStmts body) {
        for (ImStmt stmt : body) {
            if (stmt instanceof ImReturn) {
                return true;
            }
            if (stmt instanceof ImIf) {
                ImIf imIf = (ImIf) stmt;
                if (returnsOnAllPaths(imIf.getThenBlock()) && returnsOnAllPaths(imIf.getElseBlock())) {
                    return true;
                }
            } else if (stmt instanceof ImStatementExpr) {
                if (returnsOnAllPaths(((ImStatementExpr) stmt).getStatements())) {
                    return true;
                }
            }
        }
        return false;
    }

    /**
     * Returns true if {@code ret} contains any function call or stack-trace read.
     * Despite the name, this does not check whether the called function is affected.
     */
    private boolean containsAffectedFunctioncall(ImExprOpt ret) {
        final boolean[] res = {false};
        ret.accept(new Element.DefaultVisitor() {
            @Override
            public void visit(ImFunctionCall imFunctionCall) {
                super.visit(imFunctionCall);
                res[0] = true;
            }

            @Override
            public void visit(ImGetStackTrace imGetStackTrace) {
                super.visit(imGetStackTrace);
                res[0] = true;
            }
        });
        return res[0];
    }

    /** Returns true for the functions named {@code main} or {@code config}, which are never instrumented. */
    private boolean isMainOrConfig(ImFunction f) {
        Preconditions.checkNotNull(f);
        return f.getName().equals("main") || f.getName().equals("config");
    }

    /**
     * Adds the {@value #STACK_POS_PARAM} parameter to each affected function (except main/config).
     * It also passes a {@code "when calling <name><pos>"} string at every direct call of that function.
     */
    private void passStacktraceParams(Multimap<ImFunction, ImFunctionCall> calls, Set<ImFunction> affectedFuncs) {
        for (ImFunction f : affectedFuncs) {
            if (isMainOrConfig(f)) {
                continue;
            }
            Collection<ImFunctionCall> callsForF = calls.get(f);
            f.getParameters().add(getStacktraceIndex(f),
                    JassIm.ImVar(f.getTrace(), TypesHelper.imString(), STACK_POS_PARAM, false));
            for (ImFunctionCall call : callsForF) {
                String callPos = getCallPos(call.attrTrace().attrErrorPos());
                call.getArguments().add(getStacktraceIndex(call), str("when calling " + name(f) + callPos));
            }
        }
    }

    /**
     * Returns the index at which the stack-position parameter is inserted into {@code f}'s
     * parameters. This is the end of the list, or just before a trailing vararg parameter.
     */
    private int getStacktraceIndex(ImFunction f) {
        if (f.hasFlag(FunctionFlagEnum.IS_VARARG)) {
            return f.getParameters().size() - 1;
        }
        return f.getParameters().size();
    }

    /**
     * Returns the argument index for the stack-position value at call {@code c}.
     * The callee must already carry the extra parameter.
     *
     * @throws CompileError if the index is not a valid insertion point for the call's arguments
     */
    private int getStacktraceIndex(ImFunctionCall c) {
        ImFunction f = c.getFunc();
        int res = f.getParameters().size() - 1;
        if (f.hasFlag(FunctionFlagEnum.IS_VARARG)) {
            res--;
        }
        // List.add(index, e) accepts 0..size(), so anything larger must be reported here.
        if (res < 0 || res > c.getArguments().size()) {
            throw new CompileError(c, "Call " + c + " invalid index " + res + " for parameters "
                    + f.getParameters() + " and isVararg = " + f.hasFlag(FunctionFlagEnum.IS_VARARG));
        }
        return res;
    }

    /**
     * Returns the argument index for the stack-position value at method call {@code c}.
     * The implementation's parameter list includes the receiver, which is not among the call's
     * arguments.
     *
     * @throws CompileError if the index is not a valid insertion point for the call's arguments
     */
    private int getStacktraceIndex(ImMethodCall c) {
        ImFunction f = c.getMethod().getImplementation();
        int res = f.getParameters().size() - 2;
        if (f.hasFlag(FunctionFlagEnum.IS_VARARG)) {
            res--;
        }
        if (res < 0 || res > c.getArguments().size()) {
            throw new CompileError(c, "Call " + c + " invalid index " + res + " for parameters "
                    + f.getParameters() + " and isVararg = " + f.hasFlag(FunctionFlagEnum.IS_VARARG));
        }
        return res;
    }

    /** Prefers the source-level function name from the trace; falls back to the IM name. */
    private String name(JassImElementWithName f) {
        @Nullable NameDef nameDef = f.attrTrace().tryGetNameDef();
        if (nameDef instanceof FunctionDefinition) {
            return nameDef.getName();
        }
        return f.getName();
    }

    /**
     * Formats a call position as {@code " in <short position>"}.
     * Returns an empty string when the source file name starts with {@code '<'}.
     */
    public static String getCallPos(WPos source) {
        if (source.getFile().startsWith("<")) {
            return "";
        }
        return " in " + source.printShort();
    }

    /** Builds {@code v = v + 1}. */
    private ImStmt increment(de.peeeq.wurstscript.ast.Element trace, ImVar v) {
        return JassIm.ImSet(trace, JassIm.ImVarAccess(v),
                JassIm.ImOperatorCall(WurstOperator.PLUS, JassIm.ImExprs(JassIm.ImVarAccess(v), JassIm.ImIntVal(1))));
    }

    /** Builds {@code v = v - 1}. */
    private ImStmt decrement(de.peeeq.wurstscript.ast.Element trace, ImVar v) {
        return JassIm.ImSet(trace, JassIm.ImVarAccess(v),
                JassIm.ImOperatorCall(WurstOperator.MINUS, JassIm.ImExprs(JassIm.ImVarAccess(v), JassIm.ImIntVal(1))));
    }

    /**
     * Redirects every function reference or EXECUTE call that targets an affected function to a
     * new {@code bridge_<name>} function.
     * <ul>
     *   <li>The bridge resets the stack depth to 0.</li>
     *   <li>It then calls the original function with a message describing how it was reached.</li>
     *   <li>It returns the result if the function is non-void.</li>
     * </ul>
     */
    private void rewriteFuncRefs(List<ImFuncRefOrCall> funcRefs, Set<ImFunction> affectedFuncs) {
        for (ImFuncRefOrCall fr : funcRefs) {
            ImFunction f = fr.getFunc();
            if (!affectedFuncs.contains(f)) {
                continue;
            }
            ImVars params = f.getParameters().copy();
            params.removeIf(this::isStackTraceParam);
            ImFunction bridgeFunc = JassIm.ImFunction(f.getTrace(), "bridge_" + f.getName(),
                    JassIm.ImTypeVars(new ImTypeVar[0]), params, f.getReturnType().copy(),
                    JassIm.ImVars(new ImVar[0]), JassIm.ImStmts(new ImStmt[0]), f.getFlags());
            prog.getFunctions().add(bridgeFunc);

            de.peeeq.wurstscript.ast.Element frTrace = fr.attrTrace();
            String msg = fr instanceof ImFuncRef ? "via function reference" : "via ExecuteFunc";
            if (frTrace instanceof WurstModel) {
                // No concrete source trace: fall back to the enclosing IM function's name.
                ImFunction frEnclosingFunc = getEnclosingFunc(fr);
                if (frEnclosingFunc != null) {
                    msg += " in function " + frEnclosingFunc.getName();
                }
            } else {
                msg += " " + frTrace.attrSource().printShort();
            }

            ImStmts body = bridgeFunc.getBody();
            body.add(JassIm.ImSet(frTrace, JassIm.ImVarAccess(stackSize), JassIm.ImIntVal(0)));
            // NOTE(review): only the position message is passed. The bridge's own parameters are not
            // forwarded, so this assumes referenced functions have no other parameters — confirm.
            ImFunctionCall call = JassIm.ImFunctionCall(frTrace, f, JassIm.ImTypeArguments(new ImTypeArgument[0]),
                    JassIm.ImExprs(str(msg)), true, CallType.NORMAL);
            ImStmt stmt = bridgeFunc.getReturnType() instanceof ImVoid ? call : JassIm.ImReturn(frTrace, call);
            body.add(stmt);
            fr.setFunc(bridgeFunc);
        }
    }

    /** Walks up the parent chain of {@code e}; returns the first {@link ImFunction}, or null if there is none. */
    private ImFunction getEnclosingFunc(Element e) {
        Element current = e;
        while (current != null) {
            if (current instanceof ImFunction) {
                return (ImFunction) current;
            }
            current = current.getParent();
        }
        return null;
    }

    private boolean isStackTraceParam(ImVar p) {
        return p.getName().equals(STACK_POS_PARAM);
    }

    /**
     * Replaces each real {@code ImGetStackTrace} with a statement expression.
     * The expression concatenates {@code "\n   " + stack[i]} for {@code i} counting down from
     * {@code depth - 1} to 0, limited to {@link #MAX_STACKTRACE_SIZE} entries. Dummy entries are
     * skipped.
     */
    private void rewriteErrorStatements(Multimap<ImFunction, ImGetStackTrace> stackTraceGets) {
        for (Map.Entry<ImFunction, ImGetStackTrace> entry : stackTraceGets.entries()) {
            ImFunction f = entry.getKey();
            ImGetStackTrace s = entry.getValue();
            if (s == dummyGetStackTrace) {
                continue;
            }
            de.peeeq.wurstscript.ast.Element trace = s.attrTrace();
            ImVar traceStr = JassIm.ImVar(trace, TypesHelper.imString(), "stacktraceStr", false);
            f.getLocals().add(traceStr);
            ImVar traceI = JassIm.ImVar(trace, TypesHelper.imInt(), "stacktraceIndex", false);
            f.getLocals().add(traceI);
            ImVar traceLimit = JassIm.ImVar(trace, TypesHelper.imInt(), "stacktraceLimit", false);
            f.getLocals().add(traceLimit);

            ImStmts stmts = JassIm.ImStmts(new ImStmt[0]);
            stmts.add(JassIm.ImSet(trace, JassIm.ImVarAccess(traceStr), JassIm.ImStringVal("")));
            stmts.add(JassIm.ImSet(trace, JassIm.ImVarAccess(traceI), JassIm.ImVarAccess(stackSize)));
            stmts.add(JassIm.ImSet(trace, JassIm.ImVarAccess(traceLimit), JassIm.ImIntVal(0)));

            ImStmts loopBody = JassIm.ImStmts(new ImStmt[0]);
            stmts.add(JassIm.ImLoop(trace, loopBody));
            loopBody.add(decrement(trace, traceI));
            loopBody.add(increment(trace, traceLimit));
            loopBody.add(JassIm.ImExitwhen(trace, JassIm.ImOperatorCall(WurstOperator.GREATER,
                    JassIm.ImExprs(JassIm.ImVarAccess(traceLimit), JassIm.ImIntVal(MAX_STACKTRACE_SIZE)))));
            loopBody.add(JassIm.ImExitwhen(trace, JassIm.ImOperatorCall(WurstOperator.LESS,
                    JassIm.ImExprs(JassIm.ImVarAccess(traceI), JassIm.ImIntVal(0)))));
            ImExpr frame = JassIm.ImVarArrayAccess(trace, stack, JassIm.ImExprs(JassIm.ImVarAccess(traceI)));
            ImExpr line = JassIm.ImOperatorCall(WurstOperator.PLUS, JassIm.ImExprs(JassIm.ImStringVal("\n   "), frame));
            loopBody.add(JassIm.ImSet(trace, JassIm.ImVarAccess(traceStr),
                    JassIm.ImOperatorCall(WurstOperator.PLUS, JassIm.ImExprs(JassIm.ImVarAccess(traceStr), line))));

            s.replaceBy(JassIm.ImStatementExpr(stmts, JassIm.ImVarAccess(traceStr)));
        }
    }

    private ImExpr str(String s) {
        return JassIm.ImStringVal(s);
    }
}

