package de.peeeq.wurstio; import com.google.common.base.Preconditions; import com.google.common.collect.LinkedListMultimap; import com.google.common.collect.Lists; import com.google.common.collect.Maps; import config.WurstProjectConfigData; import de.peeeq.wurstio.intermediateLang.interpreter.CompiletimeNatives; import de.peeeq.wurstio.intermediateLang.interpreter.ProgramStateIO; import de.peeeq.wurstio.jassinterpreter.InterpreterException; import de.peeeq.wurstio.jassinterpreter.ReflectionNativeProvider; import de.peeeq.wurstio.jassinterpreter.providers.HashtableProvider; import de.peeeq.wurstio.mpq.MpqEditor; import de.peeeq.wurstscript.WLogger; import de.peeeq.wurstscript.ast.Element; import de.peeeq.wurstscript.attributes.CompileError; import de.peeeq.wurstscript.attributes.ErrorHandler; import de.peeeq.wurstscript.gui.WurstGui; import de.peeeq.wurstscript.intermediatelang.*; import de.peeeq.wurstscript.intermediatelang.interpreter.ILInterpreter; import de.peeeq.wurstscript.intermediatelang.interpreter.ILStackFrame; import de.peeeq.wurstscript.intermediatelang.interpreter.LocalState; import de.peeeq.wurstscript.intermediatelang.interpreter.ProgramState; import de.peeeq.wurstscript.intermediatelang.optimizer.FunctionSplitter; import de.peeeq.wurstscript.jassIm.*; import de.peeeq.wurstscript.jassinterpreter.TestFailException; import de.peeeq.wurstscript.jassinterpreter.TestSuccessException; import de.peeeq.wurstscript.parser.WPos; import de.peeeq.wurstscript.translation.imtranslation.*; import de.peeeq.wurstscript.types.TypesHelper; import de.peeeq.wurstscript.utils.Pair; import de.peeeq.wurstscript.utils.Utils; import org.eclipse.lsp4j.jsonrpc.messages.Either; import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.Nullable; import java.io.File; import java.io.PrintStream; import java.util.*; import java.util.stream.Collectors; public class CompiletimeFunctionRunner { private final ImProg imProg; private final ILInterpreter interpreter; private final WurstGui gui; private final FunctionFlagToRun functionFlag; private final List successTests = Lists.newArrayList(); private final Map> failTests = Maps.newLinkedHashMap(); private final ProgramStateIO globalState; private final ImTranslator translator; private boolean injectObjects; private final Deque delayedActions = new ArrayDeque<>(); public ILInterpreter getInterpreter() { return interpreter; } public ProgramStateIO getGlobalState() { return globalState; } public enum FunctionFlagToRun { Tests { @Override public boolean matches(ImFunction f) { return f.hasFlag(FunctionFlagEnum.IS_TEST) || f.isCompiletime(); } }, CompiletimeFunctions { @Override public boolean matches(ImFunction f) { return f.isCompiletime(); } }; public abstract boolean matches(ImFunction f); } public CompiletimeFunctionRunner( ImTranslator tr, ImProg imProg, Optional mapFile, MpqEditor mpqEditor, WurstGui gui, FunctionFlagToRun flag, WurstProjectConfigData projectConfigData, boolean isProd, boolean cache) { Preconditions.checkNotNull(imProg); this.translator = tr; this.imProg = imProg; globalState = new ProgramStateIO(mapFile, mpqEditor, gui, imProg, true); this.interpreter = new ILInterpreter(imProg, gui, mapFile, globalState, cache); interpreter.addNativeProvider(new CompiletimeNatives(globalState, projectConfigData, isProd)); interpreter.addNativeProvider(new ReflectionNativeProvider(interpreter)); this.gui = gui; this.functionFlag = flag; } public void run() { try { List> toExecute = new ArrayList<>(); collectCompiletimeExpressions(toExecute); collectCompiletimeFunctions(toExecute); toExecute.sort(Comparator.comparing(this::getOrderIndex)); execute(toExecute); if (functionFlag == FunctionFlagToRun.CompiletimeFunctions) { interpreter.writebackGlobalState(isInjectObjects()); } runDelayedActions(); partitionCompiletimeStateInitFunction(); } catch (InterpreterException e) { Element origin = e.getTrace(); sendErrors(origin, e.getMessage(), e); if (isUnitTestMode()) { throw e; } } catch (Throwable e) { WLogger.severe(e); de.peeeq.wurstscript.jassIm.Element s = interpreter.getLastStatement(); Element origin = s == null ? null : s.attrTrace(); if (origin != null) { String msg = e.getMessage(); sendErrors(origin, msg, e); } else { throw new Error("could not get origin", e); } if (isUnitTestMode()) { throw e; } } } private void partitionCompiletimeStateInitFunction() { if (compiletimeStateInitFunction == null) { return; } FunctionSplitter.splitFunc(translator, compiletimeStateInitFunction); } private boolean isUnitTestMode() { return Optional.ofNullable(imProg) .map(ImProg::attrTrace) .map(Element::getErrorHandler) .map(ErrorHandler::isUnitTestMode) .orElse(false); } private void sendErrors(Element origin, String msg, Throwable ex) { gui.sendError(new CompileError(origin.attrSource(), msg, CompileError.ErrorType.ERROR, ex)); // stackframe messages ... for (ILStackFrame sf : Utils.iterateReverse(interpreter.getStackFrames().getStackFrames())) { gui.sendError(sf.makeCompileError()); } } /** * Run actions that must be run after all other code */ private void runDelayedActions() { while (!delayedActions.isEmpty()) { delayedActions.removeFirst().run(); } } private void execute(List> es) { for (Either e : es) { if (e.isLeft()) { ImCompiletimeExpr cte = e.getLeft(); executeCompiletimeExpr(cte); } else { ImFunction f = e.getRight(); executeCompiletimeFunction(f); } } interpreter.completeTimers(); } private int getOrderIndex(Either e) { if (e.isLeft()) { ImCompiletimeExpr cte = e.getLeft(); return cte.getExecutionOrderIndex(); } else { ImFunction f = e.getRight(); for (FunctionFlag flag : f.getFlags()) { if (flag instanceof FunctionFlagCompiletime) { FunctionFlagCompiletime cflag = (FunctionFlagCompiletime) flag; return cflag.getOrderIndex(); } } return 0; } } private void collectCompiletimeFunctions(List> toExecute) { for (ImFunction f : imProg.getFunctions()) { if (functionFlag.matches(f)) { toExecute.add(Either.forRight(f)); } } } private void collectCompiletimeExpressions(List> toExecute) { imProg.accept(new de.peeeq.wurstscript.jassIm.Element.DefaultVisitor() { @Override public void visit(ImCompiletimeExpr e) { super.visit(e); toExecute.add(Either.forLeft(e)); } }); } private void executeCompiletimeExpr(ImCompiletimeExpr cte) { try { ProgramState globalState = interpreter.getGlobalState(); globalState.setLastStatement(cte); globalState.resetStackframes(); globalState.pushStackframe(cte, cte.attrTrace().attrErrorPos()); LocalState localState = new LocalState(); ILconst value = cte.evaluate(globalState, localState); ImExpr newExpr = constantToExpr(cte.getTrace(), value); if(translator.isLuaTarget() && value.toString().equals("0")) { // convert 0 to null/nil, if the value is 0 and not a numeric type ImExpr expr = cte.getExpr(); if(expr instanceof ImNull) { newExpr = ImHelper.nullExpr(); } else { @Nullable ImType exprType = null; if(expr instanceof ImFunctionCall) { exprType = ((ImFunctionCall) expr).getFunc().getReturnType(); } else if(expr instanceof ImVarAccess) { exprType = ((ImVarAccess)expr).getVar().getType(); } else if(expr instanceof ImVarArrayAccess) { ImType type = ((ImVarArrayAccess)expr).getVar().getType(); if(type instanceof ImArrayLikeType) { exprType = ((ImArrayLikeType) type).getEntryType(); } } if(exprType != null && !TypesHelper.isIntType(exprType) && !TypesHelper.isRealType(exprType)) { newExpr = ImHelper.nullExpr(); } } // TODO is this complete? Are there more cases where 0 must be replaced? // A function can return null // null can be a literal // null can be a variable } cte.replaceBy(newExpr); } catch (InterpreterException e) { String msg = ILInterpreter.buildStacktrace(globalState, e); e.setStacktrace(msg); e.setTrace(cte.attrTrace()); throw e; } } private final GetAForB globalForObject = new GetAForB() { @Override public ImVar initFor(ILconstObject obj) { ImVar res = JassIm.ImVar(obj.getTrace(), obj.getType(), obj.getType() + "_compiletime", false); imProg.getGlobals().add(res); ImAlloc alloc = JassIm.ImAlloc(obj.getTrace(), obj.getType()); addCompiletimeStateInitAlloc(alloc.getTrace(), res, alloc); globalState.setVal(res, obj); Element trace = obj.getTrace(); delayedActions.add(() -> { for (Map.Entry, ILconst>> entry : obj.getAttributes().rowMap().entrySet()) { ImVar var = entry.getKey(); Map, ILconst> value1 = entry.getValue(); for (Map.Entry, ILconst> entry2 : value1.entrySet()) { List indexes = entry2.getKey(); ILconst attrValue = entry2.getValue(); ImExprs indexesT = indexes.stream() .map(i -> constantToExpr(trace, ILconstInt.create(i))) .collect(Collectors.toCollection(JassIm::ImExprs)); ImExpr value2 = constantToExpr(trace, attrValue); if(translator.isLuaTarget() && value2.toString().equals("0")) { ImType varType = var.getType(); if(varType instanceof ImArrayLikeType) { varType = ((ImArrayLikeType) varType).getEntryType(); } if (!TypesHelper.isIntType(varType) && !TypesHelper.isRealType(varType)) { value2 = ImHelper.nullExpr(); } } addCompiletimeStateInit(JassIm.ImSet(trace, JassIm.ImMemberAccess(trace, JassIm.ImVarAccess(res), JassIm.ImTypeArguments(), var, indexesT), value2)); } } }); return res; } }; private final GetAForB globalForHandle = new GetAForB() { @Override public ImVar initFor(IlConstHandle a) { Element trace = imProg.getTrace(); ImExpr init; Object obj = a.getObj(); if (obj instanceof LinkedListMultimap) { @SuppressWarnings("unchecked") LinkedListMultimap map = (LinkedListMultimap) obj; ImType type = TypesHelper.imHashTable(); ImVar res = JassIm.ImVar(trace, type, type + "_compiletime", false); imProg.getGlobals().add(res); globalState.setVal(res, a); init = constantToExprHashtable(trace, res, a, map); addCompiletimeStateInitAlloc(trace, res, init); return res; } else { throw new RuntimeException("Handle value " + obj + " (" + obj.getClass() + ") can not be persistet at compiletime"); } } }; private ImExpr constantToExpr(Element trace, ILconst value) { if (value instanceof ILconstBool) { return JassIm.ImBoolVal(((ILconstBool) value).getVal()); } else if (value instanceof ILconstInt) { return JassIm.ImIntVal(((ILconstInt) value).getVal()); } else if (value instanceof ILconstReal) { return JassIm.ImRealVal("" + ((ILconstReal) value).getVal()); } else if (value instanceof ILconstString) { return JassIm.ImStringVal(((ILconstString) value).getVal()); } else if (value instanceof ILconstTuple) { return JassIm.ImTupleExpr( JassIm.ImExprs( ((ILconstTuple) value).values().stream() .map(e -> constantToExpr(trace, e)) .collect(Collectors.toList()) ) ); } else if (value instanceof IlConstHandle) { IlConstHandle h = (IlConstHandle) value; ImVar hVar = globalForHandle.getFor(h); return JassIm.ImVarAccess(hVar); } else if (value instanceof ILconstObject) { ILconstObject obj = globalState.toObject(value); ImVar v = globalForObject.getFor(obj); return JassIm.ImVarAccess(v); } throw new InterpreterException(trace, "Compiletime expression returned unsupported value " + value); } private ImFunction compiletimeStateInitFunction = null; private ImFunction getCompiletimeStateInitFunction() { ImFunction res = this.compiletimeStateInitFunction; if (res == null) { Element trace = imProg.getTrace(); res = JassIm.ImFunction(trace, "initCompiletimeState", JassIm.ImTypeVars(), JassIm.ImVars(), JassIm.ImVoid(), JassIm.ImVars(), JassIm.ImStmts(), Collections.emptyList()); imProg.getFunctions().add(res); compiletimeStateInitFunction = res; ImFunction mainFunc = translator.getMainFunc(); ImFunction globalInitFunc = translator.getGlobalInitFunc(); Preconditions.checkNotNull(mainFunc); ListIterator iterator = mainFunc.getBody().listIterator(); ImFunctionCall call = JassIm.ImFunctionCall(trace, res, JassIm.ImTypeArguments(), JassIm.ImExprs(), true, CallType.NORMAL); while (iterator.hasNext()) { ImStmt stmt = iterator.next(); if (stmt instanceof ImFunctionCall) { ImFunctionCall fc = (ImFunctionCall) stmt; if (fc.getFunc() == globalInitFunc) { // call initCompiletimeState right after globalInitFunc iterator.add(call); return res; } } } iterator.add(call); } return res; } // insert at the beginning private void addCompiletimeStateInitAlloc(Element trace, ImVar v, ImExpr init) { ImSet imSet = JassIm.ImSet(trace, JassIm.ImVarAccess(v), init.copy()); imProg.getGlobalInits().put(v, Collections.singletonList(imSet)); getCompiletimeStateInitFunction().getBody().add(0, imSet); } // insert at the end private void addCompiletimeStateInit(ImStmt stmt) { getCompiletimeStateInitFunction().getBody().add(stmt); } /** * Stores a hashtable value in a compiletime expression * by generating the respective native calls */ private ImExpr constantToExprHashtable(Element trace, ImVar htVar, IlConstHandle handle, LinkedListMultimap map) { WPos errorPos = trace.attrErrorPos(); // we have to collect all values after all compiletime functions have run, so use delayedActions delayedActions.add(() -> { for (Map.Entry entry : map.entries()) { HashtableProvider.KeyPair key = entry.getKey(); Object v = entry.getValue(); if (v instanceof ILconstInt) { ILconstInt iv = (ILconstInt) v; ImFunction SaveInteger = findNative("SaveInteger", errorPos); addCompiletimeStateInit(JassIm.ImFunctionCall(trace, SaveInteger, JassIm.ImTypeArguments(), JassIm.ImExprs( JassIm.ImVarAccess(htVar), JassIm.ImIntVal(key.getParentkey()), JassIm.ImIntVal(key.getChildkey()), JassIm.ImIntVal(iv.getVal()) ), false, CallType.NORMAL)); } else if (v instanceof ILconstReal) { ILconstReal iv = (ILconstReal) v; ImFunction SaveReal = findNative("SaveReal", errorPos); addCompiletimeStateInit(JassIm.ImFunctionCall(trace, SaveReal, JassIm.ImTypeArguments(), JassIm.ImExprs( JassIm.ImVarAccess(htVar), JassIm.ImIntVal(key.getParentkey()), JassIm.ImIntVal(key.getChildkey()), JassIm.ImRealVal("" + iv.getVal()) ), false, CallType.NORMAL)); } else if (v instanceof ILconstString) { ILconstString iv = (ILconstString) v; ImFunction SaveStr = findNative("SaveStr", errorPos); addCompiletimeStateInit(JassIm.ImFunctionCall(trace, SaveStr, JassIm.ImTypeArguments(), JassIm.ImExprs( JassIm.ImVarAccess(htVar), JassIm.ImIntVal(key.getParentkey()), JassIm.ImIntVal(key.getChildkey()), JassIm.ImStringVal(iv.getVal()) ), false, CallType.NORMAL)); } else if (v instanceof ILconstBool) { ILconstBool iv = (ILconstBool) v; ImFunction SaveBoolean = findNative("SaveBoolean", errorPos); addCompiletimeStateInit(JassIm.ImFunctionCall(trace, SaveBoolean, JassIm.ImTypeArguments(), JassIm.ImExprs( JassIm.ImVarAccess(htVar), JassIm.ImIntVal(key.getParentkey()), JassIm.ImIntVal(key.getChildkey()), JassIm.ImBoolVal(iv.getVal()) ), false, CallType.NORMAL)); } else if (v instanceof ILconstNull) { // treat null like no entry } else { throw new CompileError(errorPos, "Unsupported value stored in HashMap: " + v + " // " + v.getClass().getSimpleName()); } } }); // we already return the expr and fill out stmts in delayedActions (see above) ImFunction initHashtable = findNative("InitHashtable", errorPos); return JassIm.ImFunctionCall(trace, initHashtable, JassIm.ImTypeArguments(), JassIm.ImExprs(), false, CallType.NORMAL); } @NotNull private ImFunction findNative(String funcName, WPos trace) { return imProg.getFunctions() .stream() .filter(ImFunction::isNative) .filter(func -> func.getName().equals(funcName)) .findFirst() .orElseGet(() -> { throw new CompileError(trace, "Could not find native 'InitHashtable'"); }); } private void executeCompiletimeFunction(ImFunction f) { if (functionFlag.matches(f)) { try { if (!f.getBody().isEmpty()) { interpreter.getGlobalState().setLastStatement(f.getBody().get(0)); } WLogger.info("running " + functionFlag + " function " + f.getName()); interpreter.runVoidFunc(f, null); successTests.add(f); } catch (TestSuccessException e) { successTests.add(f); } catch (TestFailException e) { failTests.put(f, Pair.create(interpreter.getLastStatement(), e.toString())); } catch (Throwable e) { failTests.put(f, Pair.create(interpreter.getLastStatement(), e.toString())); throw e; } } } public List getSuccessTests() { return successTests; } public Map> getFailTests() { return failTests; } public boolean isInjectObjects() { return injectObjects; } public void setInjectObjects(boolean injectObjects) { this.injectObjects = injectObjects; } public void setOutputStream(PrintStream printStream) { interpreter.getGlobalState().setOutStream(printStream); } }