/*
 * Decompiled with CFR 0.152.
 */
package de.peeeq.wurstscript.translation.imoptimizer;

import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import de.peeeq.wurstscript.jassIm.Element;
import de.peeeq.wurstscript.jassIm.ImConst;
import de.peeeq.wurstscript.jassIm.ImExpr;
import de.peeeq.wurstscript.jassIm.ImExprOpt;
import de.peeeq.wurstscript.jassIm.ImFunction;
import de.peeeq.wurstscript.jassIm.ImFunctionCall;
import de.peeeq.wurstscript.jassIm.ImProg;
import de.peeeq.wurstscript.jassIm.ImReturn;
import de.peeeq.wurstscript.jassIm.ImStatementExpr;
import de.peeeq.wurstscript.jassIm.ImStmt;
import de.peeeq.wurstscript.jassIm.ImStmts;
import de.peeeq.wurstscript.jassIm.ImVar;
import de.peeeq.wurstscript.jassIm.JassIm;
import de.peeeq.wurstscript.translation.imtranslation.CallType;
import de.peeeq.wurstscript.translation.imtranslation.FunctionFlag;
import de.peeeq.wurstscript.translation.imtranslation.FunctionFlagAnnotation;
import de.peeeq.wurstscript.translation.imtranslation.FunctionFlagEnum;
import de.peeeq.wurstscript.translation.imtranslation.ImHelper;
import de.peeeq.wurstscript.translation.imtranslation.ImTranslator;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * Inlines function calls in an {@link ImProg}.
 *
 * <p>Functions are processed bottom-up along the translator's call relation: the functions called
 * by {@code f} are processed before {@code f} itself. A call is inlined when the callee is
 * inlinable (see {@link #collectInlinableFunctions()}) and its rating (see
 * {@link #getRating(ImFunction)}) is below a threshold (see
 * {@link #shouldInline(ImFunctionCall, ImFunction)}).
 */
public class ImInliner {
    /** Annotation that makes an inlinable function always pass the rating check. */
    private static final String FORCEINLINE = "@inline";
    /** Annotation that prevents a function from being inlined. */
    private static final String NOINLINE = "@noinline";
    /** Base threshold: a call is inlined when the callee's rating is below this value. */
    private static final double INLINE_THRESHOLD = 50.0;
    /** Factor applied to the threshold when at least one call argument is a constant. */
    private static final double THRESHOLD_MODIFIER_CONSTANT_ARG = 2.0;
    /** Functions whose estimated size is below this value always get the minimal rating. */
    private static final int SMALL_FUNCTION_SIZE = 20;
    /**
     * Maximum number of times the same function may be inlined along one nesting chain. This bounds
     * repeated expansion, e.g. through mutually recursive functions, which {@link #isRecursive}
     * (direct self-calls only) does not detect.
     */
    private static final int MAX_NESTED_INLINES = 5;
    /** Names of functions that are never inlined. */
    private static final Set<String> DONT_INLINE = Sets.newLinkedHashSet();

    static {
        DONT_INLINE.add("SetPlayerAllianceStateAllyBJ");
        DONT_INLINE.add("InitBlizzard");
        DONT_INLINE.add("error");
    }

    private final ImTranslator translator;
    private final ImProg prog;
    /** Functions that may be inlined at all. */
    private final Set<ImFunction> inlinableFunctions = Sets.newLinkedHashSet();
    /** Number of known call sites per called function. */
    private final Map<ImFunction, Integer> callCounts = Maps.newLinkedHashMap();
    /** Estimated size (number of descendant nodes of the body) per function. */
    private final Map<ImFunction, Integer> funcSizes = Maps.newLinkedHashMap();
    /** Functions already processed by {@link #inlineFunctions(ImFunction)}. */
    private final Set<ImFunction> done = Sets.newLinkedHashSet();

    public ImInliner(ImTranslator translator) {
        this.translator = translator;
        this.prog = translator.getImProg();
    }

    /**
     * Runs the inliner: flattens the program, determines the inlinable functions and their
     * ratings, then inlines calls in every function of the program.
     */
    public void doInlining() {
        prog.flatten(translator);
        collectInlinableFunctions();
        rateInlinableFunctions();
        inlineFunctions();
    }

    /** Inlines calls in all functions of the program. */
    private void inlineFunctions() {
        for (ImFunction f : ImHelper.calculateFunctionsOfProg(prog)) {
            inlineFunctions(f);
        }
    }

    /**
     * Inlines calls inside {@code f}, after first processing every function that {@code f} calls.
     * Each function is processed at most once.
     */
    private void inlineFunctions(ImFunction f) {
        if (!done.add(f)) {
            return;
        }
        // Callees first, so that their bodies are already processed when copied into f.
        for (ImFunction called : translator.getCalledFunctions().get(f)) {
            inlineFunctions(called);
        }
        inlineFunctions(f, f, 0, f.getBody(), Collections.emptyMap());
    }

    /**
     * Walks the subtree {@code e} of function {@code f} and inlines eligible calls in place.
     *
     * @param f the function whose body is being optimized
     * @param parent the parent element of {@code e}
     * @param parentI the index of {@code e} inside {@code parent}
     * @param e the element currently visited
     * @param alreadyInlined how often each function was already inlined on the path to {@code e};
     *     never modified by this method
     * @return the inlined function if {@code e} itself was a call that got replaced, else null
     */
    private ImFunction inlineFunctions(ImFunction f, Element parent, int parentI, Element e,
            Map<ImFunction, Integer> alreadyInlined) {
        if (e instanceof ImFunctionCall) {
            ImFunctionCall call = (ImFunctionCall) e;
            ImFunction called = call.getFunc();
            if (f != called
                    && shouldInline(call, called)
                    && alreadyInlined.getOrDefault(called, 0) < MAX_NESTED_INLINES) {
                inlineCall(f, parent, parentI, call);
                funcSizes.put(f, estimateSize(f));
                return called;
            }
        }
        for (int i = 0; i < e.size(); i++) {
            Map<ImFunction, Integer> childInlined = alreadyInlined;
            ImFunction inlined;
            // Re-visit position i after a replacement: the inlined body may contain further calls.
            while ((inlined = inlineFunctions(f, e, i, e.get(i), childInlined)) != null) {
                if (childInlined == alreadyInlined) {
                    // copy-on-write: the caller's map must stay unchanged
                    childInlined = new HashMap<>(alreadyInlined);
                }
                childInlined.merge(inlined, 1, Integer::sum);
            }
        }
        return null;
    }

    /**
     * Replaces {@code call} (located at {@code parent[parentI]} inside {@code f}) with a statement
     * expression containing a copy of the called function's body.
     *
     * <p>Each argument is assigned to a fresh local of {@code f}. The callee's parameters and
     * locals are replaced by fresh locals of {@code f}. A trailing return statement becomes the
     * result expression.
     *
     * @throws Error if {@code call} is a call of {@code f} itself
     */
    private void inlineCall(ImFunction f, Element parent, int parentI, ImFunctionCall call) {
        ImFunction called = call.getFunc();
        if (called == f) {
            throw new Error("cannot inline self.");
        }
        List<ImStmt> stmts = new ArrayList<>();
        List<ImExpr> args = call.getArguments().removeAll();
        Map<ImVar, ImVar> varSubstitutions = new LinkedHashMap<>();
        // Bind every argument to a fresh local, evaluated once and in parameter order.
        for (int pi = 0; pi < called.getParameters().size(); pi++) {
            ImVar param = called.getParameters().get(pi);
            ImExpr arg = args.get(pi);
            ImVar tempVar = JassIm.ImVar(arg.attrTrace(), param.getType(), param.getName(), false);
            f.getLocals().add(tempVar);
            varSubstitutions.put(param, tempVar);
            stmts.add(JassIm.ImSet(arg.attrTrace(), JassIm.ImVarAccess(tempVar), arg));
        }
        for (ImVar local : called.getLocals()) {
            ImVar newLocal = JassIm.ImVar(local.getTrace(), local.getType(), local.getName(), false);
            f.getLocals().add(newLocal);
            varSubstitutions.put(local, newLocal);
        }
        for (int i = 0; i < called.getBody().size(); i++) {
            ImStmt s = called.getBody().get(i).copy();
            ImHelper.replaceVar(s, varSubstitutions);
            // The copied body introduces new call sites; keep the call counts in sync.
            s.accept(new Element.DefaultVisitor() {
                @Override
                public void visit(ImFunctionCall c) {
                    super.visit(c);
                    incCallCount(c.getFunc());
                }
            });
            stmts.add(s);
        }
        ImStatementExpr newExpr = null;
        if (!stmts.isEmpty() && stmts.get(stmts.size() - 1) instanceof ImReturn) {
            // Inlinable functions (maxOneReturn) only return as their last top-level statement,
            // so that return turns into the result expression of the inlined block.
            ImReturn ret = (ImReturn) stmts.remove(stmts.size() - 1);
            ImExprOpt valOpt = ret.getReturnValue();
            if (valOpt instanceof ImExpr) {
                ImExpr val = (ImExpr) valOpt.copy();
                ImHelper.replaceVar(val, varSubstitutions);
                newExpr = JassIm.ImStatementExpr(JassIm.ImStmts(stmts), val);
            }
        }
        if (newExpr == null) {
            newExpr = ImHelper.statementExprVoid(JassIm.ImStmts(stmts));
        }
        parent.set(parentI, newExpr);
    }

    /**
     * Initializes the call counts and the size estimates of all inlinable functions.
     *
     * <p>The call relation maps each caller to the functions it calls, so the counted side is the
     * value (the callee). This matches how {@link #inlineCall} updates counts.
     */
    private void rateInlinableFunctions() {
        for (ImFunction callee : translator.getCalledFunctions().values()) {
            incCallCount(callee);
        }
        for (ImFunction f : inlinableFunctions) {
            funcSizes.put(f, estimateSize(f));
        }
    }

    /**
     * Returns the inlining rating of {@code f}; lower means "cheaper to inline".
     * {@link Double#MAX_VALUE} means "never inline".
     */
    private double getRating(ImFunction f) {
        if (f.isNative() || !inlinableFunctions.contains(f) || DONT_INLINE.contains(f.getName())) {
            return Double.MAX_VALUE;
        }
        for (FunctionFlag flag : f.getFlags()) {
            if (flag instanceof FunctionFlagAnnotation) {
                String annotation = ((FunctionFlagAnnotation) flag).getAnnotation();
                if (annotation.equals(FORCEINLINE)) {
                    return 1.0;
                }
                if (annotation.equals(NOINLINE)) {
                    return Double.MAX_VALUE;
                }
            }
        }
        int size = getFuncSize(f);
        if (size < SMALL_FUNCTION_SIZE) {
            return 1.0;
        }
        // Zero for a single call site; grows with the size and with every further call site.
        return size * (getCallCount(f) - 1.0);
    }

    /** Returns the estimated size of {@code f}, or {@link Integer#MAX_VALUE} if it is unknown. */
    private int getFuncSize(ImFunction f) {
        return funcSizes.getOrDefault(f, Integer.MAX_VALUE);
    }

    /**
     * Decides whether {@code call} (a call of {@code f}) should be inlined. Native functions,
     * execute-calls, non-inlinable and directly recursive functions are never inlined. The rating
     * threshold is raised when the call has a constant argument.
     */
    private boolean shouldInline(ImFunctionCall call, ImFunction f) {
        if (f.isNative() || call.getCallType() == CallType.EXECUTE) {
            return false;
        }
        double threshold = INLINE_THRESHOLD;
        if (hasConstantArgument(call)) {
            // presumably because constant arguments allow further simplification after inlining
            threshold *= THRESHOLD_MODIFIER_CONSTANT_ARG;
        }
        return inlinableFunctions.contains(f) && getRating(f) < threshold && !isRecursive(f);
    }

    /** Returns true if at least one argument of {@code call} is an {@link ImConst}. */
    private static boolean hasConstantArgument(ImFunctionCall call) {
        for (ImExpr arg : call.getArguments()) {
            if (arg instanceof ImConst) {
                return true;
            }
        }
        return false;
    }

    /** Returns true if the body of {@code f} contains a direct call to {@code f}. */
    private boolean isRecursive(ImFunction f) {
        return containsCallTo(f, f.getBody());
    }

    /** Returns true if {@code e} or any of its descendants is a call of {@code f}. */
    private boolean containsCallTo(ImFunction f, Element e) {
        if (e instanceof ImFunctionCall && ((ImFunctionCall) e).getFunc() == f) {
            return true;
        }
        for (int i = 0; i < e.size(); i++) {
            if (containsCallTo(f, e.get(i))) {
                return true;
            }
        }
        return false;
    }

    /** Estimates the size of {@code f} as the number of descendant nodes of its body. */
    private int estimateSize(ImFunction f) {
        return countDescendants(f.getBody());
    }

    /** Counts all descendants of {@code e} (excluding {@code e} itself). */
    private int countDescendants(Element e) {
        int count = 0;
        for (int i = 0; i < e.size(); i++) {
            count += 1 + countDescendants(e.get(i));
        }
        return count;
    }

    private void incCallCount(ImFunction f) {
        callCounts.merge(f, 1, Integer::sum);
    }

    private int getCallCount(ImFunction f) {
        return callCounts.getOrDefault(f, 0);
    }

    /**
     * Collects every program function that may be inlined. Excluded are compiletime-native,
     * native and vararg functions, the global init function, and functions with a return
     * anywhere but as their last statement.
     */
    private void collectInlinableFunctions() {
        for (ImFunction f : ImHelper.calculateFunctionsOfProg(prog)) {
            boolean excluded = f.hasFlag(FunctionFlagEnum.IS_COMPILETIME_NATIVE)
                    || f.hasFlag(FunctionFlagEnum.IS_NATIVE)
                    || f == translator.getGlobalInitFunc()
                    || f.hasFlag(FunctionFlagEnum.IS_VARARG);
            if (!excluded && maxOneReturn(f)) {
                inlinableFunctions.add(f);
            }
        }
    }

    private boolean maxOneReturn(ImFunction f) {
        return maxOneReturn(f.getBody());
    }

    /**
     * Returns true if {@code body} contains no return statement except possibly a top-level
     * {@link ImReturn} as its last statement.
     */
    private boolean maxOneReturn(ImStmts body) {
        int last = body.size() - 1;
        if (last < 0) {
            return true;
        }
        for (int i = 0; i < last; i++) {
            if (hasReturn(body.get(i))) {
                return false;
            }
        }
        ImStmt lastStmt = body.get(last);
        return lastStmt instanceof ImReturn || !hasReturn(lastStmt);
    }

    /** Returns true if {@code s} contains an {@link ImReturn} anywhere. */
    private boolean hasReturn(ImStmt s) {
        final boolean[] found = {false};
        s.accept(new Element.DefaultVisitor() {
            @Override
            public void visit(ImReturn rs) {
                super.visit(rs);
                found[0] = true;
            }
        });
        return found[0];
    }
}

