/*
 * Decompiled with CFR 0.152.
 */
package de.peeeq.wurstscript.translation.imtranslation;

import com.google.common.collect.HashBasedTable;
import com.google.common.collect.Table;
import de.peeeq.wurstscript.jassIm.Element;
import de.peeeq.wurstscript.jassIm.ImExprs;
import de.peeeq.wurstscript.jassIm.ImFunction;
import de.peeeq.wurstscript.jassIm.ImFunctionCall;
import de.peeeq.wurstscript.jassIm.ImProg;
import de.peeeq.wurstscript.jassIm.ImStatementExpr;
import de.peeeq.wurstscript.jassIm.ImStmt;
import de.peeeq.wurstscript.jassIm.ImStmts;
import de.peeeq.wurstscript.jassIm.ImType;
import de.peeeq.wurstscript.jassIm.ImTypeArgument;
import de.peeeq.wurstscript.jassIm.ImVar;
import de.peeeq.wurstscript.jassIm.ImVarAccess;
import de.peeeq.wurstscript.jassIm.ImVarargLoop;
import de.peeeq.wurstscript.jassIm.JassIm;
import de.peeeq.wurstscript.translation.imtranslation.FunctionFlagEnum;
import de.peeeq.wurstscript.translation.imtranslation.ImHelper;
import de.peeeq.wurstscript.translation.imtranslation.ReferenceRewritingCopy;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.stream.Collectors;
import org.jetbrains.annotations.NotNull;

/**
 * Removes vararg functions from the intermediate program by specializing them per call-site
 * argument count.
 *
 * <p>For every (vararg function, argument count) pair that occurs at a call site, a copy of the
 * function is created. In the copy:
 * <ul>
 *   <li>the trailing vararg parameter is replaced by that many ordinary parameters;</li>
 *   <li>every {@link ImVarargLoop} is unrolled once per new parameter;</li>
 *   <li>every remaining use of the vararg parameter (forwarding it as the argument of another
 *       call) is expanded into the individual new parameters.</li>
 * </ul>
 * Afterwards the generic vararg functions are removed from {@code prog.getFunctions()} and all
 * remaining vararg calls are redirected to their specialized copies.
 */
public class VarargEliminator {
    private final ImProg prog;
    /** Specialized copies keyed by (original vararg function, number of call-site arguments). */
    private final Table<ImFunction, Integer, ImFunction> varargFuncs = HashBasedTable.create();

    public VarargEliminator(ImProg prog) {
        this.prog = prog;
    }

    /**
     * Runs the elimination on the program passed to the constructor.
     *
     * @throws IllegalStateException if a vararg call remains for which no specialized copy was
     *     generated
     */
    public void run() {
        // Phase 1: generate one specialized copy per argument count used at a vararg call site.
        for (ImFunctionCall c : collectVarargCalls()) {
            generateVarargFunc(c.getFunc(), c.getArguments().size());
        }
        // Phase 2: the generic vararg functions are no longer needed.
        prog.getFunctions().removeIf(f -> f.hasFlag(FunctionFlagEnum.IS_VARARG));
        // Phase 3: re-collect, because the generated copies may contain vararg calls as well,
        // and point every vararg call at its specialized copy.
        for (ImFunctionCall call : collectVarargCalls()) {
            int argCount = call.getArguments().size();
            ImFunction target = varargFuncs.get(call.getFunc(), argCount);
            if (target == null) {
                throw new IllegalStateException("No specialized vararg function was generated for "
                        + call.getFunc().getName() + " with " + argCount + " arguments");
            }
            redirectCall(call, target);
        }
    }

    /**
     * Collects every call in the program whose target function has the
     * {@link FunctionFlagEnum#IS_VARARG} flag.
     */
    @NotNull
    private Collection<ImFunctionCall> collectVarargCalls() {
        final List<ImFunctionCall> calls = new ArrayList<>();
        prog.accept(new Element.DefaultVisitor() {
            @Override
            public void visit(ImFunctionCall c) {
                super.visit(c);
                if (c.getFunc().hasFlag(FunctionFlagEnum.IS_VARARG)) {
                    calls.add(c);
                }
            }
        });
        return calls;
    }

    /**
     * Generates, if not already present, the copy of {@code func} specialized for a call with
     * {@code numberOfParams} arguments, and adds it to the program.
     *
     * <p>Calls inside the copy that forward the vararg parameter are expanded, and the functions
     * they target are specialized recursively. Calls to functions without the {@code IS_VARARG}
     * flag are ignored.
     *
     * @param func           the generic vararg function
     * @param numberOfParams the total number of arguments at the call site, fixed parameters
     *                       included
     */
    private void generateVarargFunc(ImFunction func, int numberOfParams) {
        if (!func.hasFlag(FunctionFlagEnum.IS_VARARG) || varargFuncs.contains(func, numberOfParams)) {
            return;
        }
        // How many call-site arguments land in the vararg part (the vararg parameter itself is
        // the last entry of func's parameter list).
        int argumentSize = 1 + numberOfParams - func.getParameters().size();
        ImFunction newFunc = ReferenceRewritingCopy.copy(func);
        newFunc.setName(func.getName() + "_" + argumentSize);
        // The copy is no longer a vararg function. Clear the flag and register the copy BEFORE
        // rewriting its body. Otherwise a (mutually) recursive function that forwards its
        // varargs would re-enter this method for the same key and recurse without bound.
        newFunc.setFlags(newFunc.getFlags().stream()
                .filter(flag -> flag != FunctionFlagEnum.IS_VARARG)
                .collect(Collectors.toList()));
        varargFuncs.put(func, numberOfParams, newFunc);

        // Replace the trailing vararg parameter with argumentSize ordinary parameters.
        ImVar varargParam = newFunc.getParameters().remove(newFunc.getParameters().size() - 1);
        ImType type = varargParam.getType();
        final List<ImVar> newParams = new ArrayList<>();
        for (int i = 0; i < argumentSize; i++) {
            ImVar param = JassIm.ImVar(func.getTrace(), type, varargParam.getName() + "_" + i, false);
            newParams.add(param);
            newFunc.getParameters().add(param);
        }

        // Unroll all vararg loops. super.visit runs first, so nested loops are unrolled before
        // the enclosing loop's body is copied.
        newFunc.getBody().accept(new Element.DefaultVisitor() {
            @Override
            public void visit(ImVarargLoop imLoop) {
                super.visit(imLoop);
                unrollVarargLoop(imLoop, newParams);
            }
        });

        // Any remaining use of the vararg parameter forwards it to another call. Assumed to be a
        // direct argument of an ImFunctionCall (the casts below fail otherwise).
        for (ImVarAccess va : collectUsesOfVar(newFunc, varargParam)) {
            ImExprs args = (ImExprs) va.getParent();
            ImFunctionCall call = (ImFunctionCall) args.getParent();
            args.remove(va);
            args.addAll(newParams.stream().map(JassIm::ImVarAccess).collect(Collectors.toList()));
            generateVarargFunc(call.getFunc(), call.getArguments().size());
        }
        prog.getFunctions().add(newFunc);
    }

    /** Collects every access to {@code varargParam} inside the body of {@code newFunc}. */
    @NotNull
    private List<ImVarAccess> collectUsesOfVar(ImFunction newFunc, final ImVar varargParam) {
        final List<ImVarAccess> varargParamUses = new ArrayList<>();
        newFunc.getBody().accept(new Element.DefaultVisitor() {
            @Override
            public void visit(ImVarAccess va) {
                super.visit(va);
                if (va.getVar() == varargParam) {
                    varargParamUses.add(va);
                }
            }
        });
        return varargParamUses;
    }

    /**
     * Replaces {@code call} with a call to {@code newFunc}. The arguments are moved out of the
     * old call, and the new call gets an empty type-argument list.
     */
    private void redirectCall(ImFunctionCall call, ImFunction newFunc) {
        ImFunctionCall newCall = JassIm.ImFunctionCall(call.getTrace(), newFunc,
                JassIm.ImTypeArguments(new ImTypeArgument[0]),
                JassIm.ImExprs(call.getArguments().removeAll()),
                call.getTuplesEliminated(), call.getCallType());
        call.replaceBy(newCall);
    }

    /**
     * Replaces {@code imLoop} with a void statement expression. It holds one copy of the loop
     * body per element of {@code newParams}; in the i-th copy, accesses to the loop variable are
     * rebound to {@code newParams.get(i)}.
     */
    private void unrollVarargLoop(final ImVarargLoop imLoop, final List<ImVar> newParams) {
        ImStatementExpr stmtExpr = ImHelper.statementExprVoid(JassIm.ImStmts(new ImStmt[0]));
        for (final ImVar param : newParams) {
            ImStmts bodyCopy = imLoop.getBody().copy();
            bodyCopy.accept(new Element.DefaultVisitor() {
                @Override
                public void visit(ImVarAccess access) {
                    super.visit(access);
                    if (access.getVar() == imLoop.getLoopVar()) {
                        access.setVar(param);
                    }
                }
            });
            stmtExpr.getStatements().addAll(bodyCopy.removeAll());
        }
        imLoop.replaceBy(stmtExpr);
    }
}

