diff --git a/prometeu-compiler/frontends/prometeu-frontend-pbs/src/main/java/p/studio/compiler/pbs/PbsFrontendCompiler.java b/prometeu-compiler/frontends/prometeu-frontend-pbs/src/main/java/p/studio/compiler/pbs/PbsFrontendCompiler.java index 78687576..2a190660 100644 --- a/prometeu-compiler/frontends/prometeu-frontend-pbs/src/main/java/p/studio/compiler/pbs/PbsFrontendCompiler.java +++ b/prometeu-compiler/frontends/prometeu-frontend-pbs/src/main/java/p/studio/compiler/pbs/PbsFrontendCompiler.java @@ -3,6 +3,7 @@ package p.studio.compiler.pbs; import p.studio.compiler.models.IRFunction; import p.studio.compiler.models.SourceKind; import p.studio.compiler.models.IRBackendFile; +import p.studio.compiler.models.IRBackendExecutableFunction; import p.studio.compiler.messages.HostAdmissionContext; import p.studio.compiler.models.IRReservedMetadata; import p.studio.compiler.pbs.ast.PbsAst; @@ -16,6 +17,9 @@ import p.studio.compiler.source.identifiers.FileId; import p.studio.utilities.structures.ReadOnlyList; import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; public final class PbsFrontendCompiler { private final PbsDeclarationSemanticsValidator declarationSemanticsValidator = new PbsDeclarationSemanticsValidator(); @@ -61,7 +65,7 @@ public final class PbsFrontendCompiler { final FileId fileId, final PbsAst.File ast, final DiagnosticSink diagnostics) { - return compileParsedFile(fileId, ast, diagnostics, SourceKind.PROJECT); + return compileParsedFile(fileId, ast, diagnostics, SourceKind.PROJECT, "", HostAdmissionContext.permissiveDefault()); } public IRBackendFile compileParsedFile( @@ -69,7 +73,7 @@ public final class PbsFrontendCompiler { final PbsAst.File ast, final DiagnosticSink diagnostics, final SourceKind sourceKind) { - return compileParsedFile(fileId, ast, diagnostics, sourceKind, HostAdmissionContext.permissiveDefault()); + return compileParsedFile(fileId, ast, diagnostics, sourceKind, "", HostAdmissionContext.permissiveDefault()); } public IRBackendFile compileParsedFile( @@ -78,6 +82,16 @@ public final class PbsFrontendCompiler { final DiagnosticSink diagnostics, final SourceKind sourceKind, final HostAdmissionContext hostAdmissionContext) { + return compileParsedFile(fileId, ast, diagnostics, sourceKind, "", hostAdmissionContext); + } + + public IRBackendFile compileParsedFile( + final FileId fileId, + final PbsAst.File ast, + final DiagnosticSink diagnostics, + final SourceKind sourceKind, + final String moduleKey, + final HostAdmissionContext hostAdmissionContext) { final var semanticsErrorBaseline = diagnostics.errorCount(); declarationSemanticsValidator.validate(ast, sourceKind, diagnostics); flowSemanticsValidator.validate(ast, diagnostics); @@ -103,7 +117,10 @@ public final class PbsFrontendCompiler { final ReadOnlyList functions = sourceKind == SourceKind.SDK_INTERFACE ? ReadOnlyList.empty() : lowerFunctions(fileId, ast); - return new IRBackendFile(fileId, functions, reservedMetadata); + final ReadOnlyList executableFunctions = sourceKind == SourceKind.SDK_INTERFACE + ? ReadOnlyList.empty() + : lowerExecutableFunctions(fileId, ast, moduleKey, reservedMetadata); + return new IRBackendFile(fileId, functions, executableFunctions, reservedMetadata); } private ReadOnlyList lowerFunctions(final FileId fileId, final PbsAst.File ast) { @@ -118,4 +135,258 @@ public final class PbsFrontendCompiler { } return ReadOnlyList.wrap(functions); } + + private ReadOnlyList lowerExecutableFunctions( + final FileId fileId, + final PbsAst.File ast, + final String moduleKey, + final IRReservedMetadata reservedMetadata) { + final var hostByMethodName = new HashMap(); + for (final var hostBinding : reservedMetadata.hostMethodBindings()) { + hostByMethodName.put(hostBinding.sourceMethodName(), hostBinding); + } + final var intrinsicByMethodName = new HashMap(); + for (final var builtinType : reservedMetadata.builtinTypeSurfaces()) { + for (final var intrinsicSurface : builtinType.intrinsics()) { + intrinsicByMethodName.put(intrinsicSurface.sourceMethodName(), intrinsicSurface); + } + } + + final var executableFunctions = new ArrayList(ast.functions().size()); + for (final var fn : ast.functions()) { + final var instructions = new ArrayList(); + final var callsites = new ArrayList(); + collectCallsFromBlock(fn.body(), callsites); + for (final var callExpr : callsites) { + final var calleeName = extractSimpleCalleeName(callExpr.callee()); + final var host = hostByMethodName.get(calleeName); + if (host != null) { + instructions.add(new IRBackendExecutableFunction.Instruction( + IRBackendExecutableFunction.InstructionKind.CALL_HOST, + "", + "", + new IRBackendExecutableFunction.HostCallMetadata( + host.abiModule(), + host.abiMethod(), + host.abiVersion(), + callExpr.arguments().size(), + 0), + null, + callExpr.span())); + continue; + } + + final var intrinsic = intrinsicByMethodName.get(calleeName); + if (intrinsic != null) { + instructions.add(new IRBackendExecutableFunction.Instruction( + IRBackendExecutableFunction.InstructionKind.CALL_INTRINSIC, + "", + "", + null, + new IRBackendExecutableFunction.IntrinsicCallMetadata( + intrinsic.canonicalName(), + intrinsic.canonicalVersion(), + intrinsicIdFor(intrinsic.canonicalName(), intrinsic.canonicalVersion())), + callExpr.span())); + continue; + } + + instructions.add(new IRBackendExecutableFunction.Instruction( + IRBackendExecutableFunction.InstructionKind.CALL_FUNC, + moduleKey == null ? "" : moduleKey, + calleeName, + null, + null, + callExpr.span())); + } + instructions.add(new IRBackendExecutableFunction.Instruction( + IRBackendExecutableFunction.InstructionKind.RET, + "", + "", + null, + null, + fn.span())); + + final var returnSlots = switch (fn.returnKind()) { + case INFERRED_UNIT, EXPLICIT_UNIT -> 0; + case PLAIN, RESULT -> 1; + }; + final var start = safeToInt(fn.span().getStart()); + final var end = safeToInt(fn.span().getEnd()); + executableFunctions.add(new IRBackendExecutableFunction( + fileId, + moduleKey == null ? "" : moduleKey, + fn.name(), + start, + end, + fn.parameters().size(), + 0, + returnSlots, + Math.max(4, fn.parameters().size() + 2), + ReadOnlyList.wrap(instructions), + fn.span())); + } + return ReadOnlyList.wrap(executableFunctions); + } + + private int intrinsicIdFor( + final String canonicalName, + final long canonicalVersion) { + return (canonicalName + "#" + canonicalVersion).hashCode(); + } + + private int safeToInt(final long value) { + if (value > Integer.MAX_VALUE) { + return Integer.MAX_VALUE; + } + if (value < Integer.MIN_VALUE) { + return Integer.MIN_VALUE; + } + return (int) value; + } + + private void collectCallsFromBlock( + final PbsAst.Block block, + final List output) { + for (final var statement : block.statements()) { + collectCallsFromStatement(statement, output); + } + if (block.tailExpression() != null) { + collectCallsFromExpression(block.tailExpression(), output); + } + } + + private void collectCallsFromStatement( + final PbsAst.Statement statement, + final List output) { + switch (statement) { + case PbsAst.LetStatement letStatement -> collectCallsFromExpression(letStatement.initializer(), output); + case PbsAst.AssignStatement assignStatement -> collectCallsFromExpression(assignStatement.value(), output); + case PbsAst.ReturnStatement returnStatement -> { + if (returnStatement.value() != null) { + collectCallsFromExpression(returnStatement.value(), output); + } + } + case PbsAst.IfStatement ifStatement -> { + collectCallsFromExpression(ifStatement.condition(), output); + collectCallsFromBlock(ifStatement.thenBlock(), output); + if (ifStatement.elseIf() != null) { + collectCallsFromStatement(ifStatement.elseIf(), output); + } + if (ifStatement.elseBlock() != null) { + collectCallsFromBlock(ifStatement.elseBlock(), output); + } + } + case PbsAst.ForStatement forStatement -> { + collectCallsFromExpression(forStatement.fromExpression(), output); + collectCallsFromExpression(forStatement.untilExpression(), output); + if (forStatement.stepExpression() != null) { + collectCallsFromExpression(forStatement.stepExpression(), output); + } + collectCallsFromBlock(forStatement.body(), output); + } + case PbsAst.WhileStatement whileStatement -> { + collectCallsFromExpression(whileStatement.condition(), output); + collectCallsFromBlock(whileStatement.body(), output); + } + case PbsAst.ExpressionStatement expressionStatement -> + collectCallsFromExpression(expressionStatement.expression(), output); + case PbsAst.BreakStatement ignored -> { + } + case PbsAst.ContinueStatement ignored -> { + } + } + } + + private void collectCallsFromExpression( + final PbsAst.Expression expression, + final List output) { + switch (expression) { + case PbsAst.CallExpr callExpr -> { + output.add(callExpr); + collectCallsFromExpression(callExpr.callee(), output); + for (final var arg : callExpr.arguments()) { + collectCallsFromExpression(arg, output); + } + } + case PbsAst.ApplyExpr applyExpr -> { + collectCallsFromExpression(applyExpr.callee(), output); + collectCallsFromExpression(applyExpr.argument(), output); + } + case PbsAst.BinaryExpr binaryExpr -> { + collectCallsFromExpression(binaryExpr.left(), output); + collectCallsFromExpression(binaryExpr.right(), output); + } + case PbsAst.UnaryExpr unaryExpr -> collectCallsFromExpression(unaryExpr.expression(), output); + case PbsAst.ElseExpr elseExpr -> { + collectCallsFromExpression(elseExpr.optionalExpression(), output); + collectCallsFromExpression(elseExpr.fallbackExpression(), output); + } + case PbsAst.IfExpr ifExpr -> { + collectCallsFromExpression(ifExpr.condition(), output); + collectCallsFromBlock(ifExpr.thenBlock(), output); + collectCallsFromExpression(ifExpr.elseExpression(), output); + } + case PbsAst.SwitchExpr switchExpr -> { + collectCallsFromExpression(switchExpr.selector(), output); + for (final var arm : switchExpr.arms()) { + collectCallsFromBlock(arm.block(), output); + } + } + case PbsAst.HandleExpr handleExpr -> { + collectCallsFromExpression(handleExpr.value(), output); + for (final var arm : handleExpr.arms()) { + collectCallsFromBlock(arm.block(), output); + } + } + case PbsAst.AsExpr asExpr -> collectCallsFromExpression(asExpr.expression(), output); + case PbsAst.MemberExpr memberExpr -> collectCallsFromExpression(memberExpr.receiver(), output); + case PbsAst.PropagateExpr propagateExpr -> collectCallsFromExpression(propagateExpr.expression(), output); + case PbsAst.GroupExpr groupExpr -> collectCallsFromExpression(groupExpr.expression(), output); + case PbsAst.NewExpr newExpr -> { + for (final var arg : newExpr.arguments()) { + collectCallsFromExpression(arg, output); + } + } + case PbsAst.BindExpr bindExpr -> collectCallsFromExpression(bindExpr.contextExpression(), output); + case PbsAst.SomeExpr someExpr -> collectCallsFromExpression(someExpr.value(), output); + case PbsAst.OkExpr okExpr -> collectCallsFromExpression(okExpr.value(), output); + case PbsAst.TupleExpr tupleExpr -> { + for (final var item : tupleExpr.items()) { + collectCallsFromExpression(item.expression(), output); + } + } + case PbsAst.BlockExpr blockExpr -> collectCallsFromBlock(blockExpr.block(), output); + case PbsAst.IdentifierExpr ignored -> { + } + case PbsAst.IntLiteralExpr ignored -> { + } + case PbsAst.FloatLiteralExpr ignored -> { + } + case PbsAst.BoundedLiteralExpr ignored -> { + } + case PbsAst.StringLiteralExpr ignored -> { + } + case PbsAst.BoolLiteralExpr ignored -> { + } + case PbsAst.ThisExpr ignored -> { + } + case PbsAst.NoneExpr ignored -> { + } + case PbsAst.ErrExpr ignored -> { + } + case PbsAst.UnitExpr ignored -> { + } + } + } + + private String extractSimpleCalleeName(final PbsAst.Expression callee) { + return switch (callee) { + case PbsAst.IdentifierExpr identifierExpr -> identifierExpr.name(); + case PbsAst.MemberExpr memberExpr -> memberExpr.memberName(); + case PbsAst.BindExpr bindExpr -> bindExpr.functionName(); + case PbsAst.GroupExpr groupExpr -> extractSimpleCalleeName(groupExpr.expression()); + default -> ""; + }; + } } diff --git a/prometeu-compiler/frontends/prometeu-frontend-pbs/src/main/java/p/studio/compiler/services/PBSFrontendPhaseService.java b/prometeu-compiler/frontends/prometeu-frontend-pbs/src/main/java/p/studio/compiler/services/PBSFrontendPhaseService.java index ccda745e..c0aee957 100644 --- a/prometeu-compiler/frontends/prometeu-frontend-pbs/src/main/java/p/studio/compiler/services/PBSFrontendPhaseService.java +++ b/prometeu-compiler/frontends/prometeu-frontend-pbs/src/main/java/p/studio/compiler/services/PBSFrontendPhaseService.java @@ -148,6 +148,7 @@ public class PBSFrontendPhaseService implements FrontendPhaseService { parsedSource.ast(), diagnostics, parsedSource.sourceKind(), + parsedSource.moduleKey(), ctx.hostAdmissionContext()); if (diagnostics.errorCount() > compileErrorBaseline) { failedModuleKeys.add(parsedSource.moduleKey()); diff --git a/prometeu-compiler/frontends/prometeu-frontend-pbs/src/test/java/p/studio/compiler/pbs/PbsFrontendCompilerTest.java b/prometeu-compiler/frontends/prometeu-frontend-pbs/src/test/java/p/studio/compiler/pbs/PbsFrontendCompilerTest.java index f30a4a6c..60fc11c4 100644 --- a/prometeu-compiler/frontends/prometeu-frontend-pbs/src/test/java/p/studio/compiler/pbs/PbsFrontendCompilerTest.java +++ b/prometeu-compiler/frontends/prometeu-frontend-pbs/src/test/java/p/studio/compiler/pbs/PbsFrontendCompilerTest.java @@ -72,6 +72,32 @@ class PbsFrontendCompilerTest { assertEquals(2, fileBackend.functions().size()); } + @Test + void shouldLowerExecutableFunctionsWithCallsiteCategories() { + final var source = """ + fn b(x: int) -> int { + return x; + } + + fn a() -> int { + return b(1); + } + """; + + final var diagnostics = DiagnosticSink.empty(); + final var compiler = new PbsFrontendCompiler(); + final var fileBackend = compiler.compileFile(new FileId(42), source, diagnostics); + + assertTrue(diagnostics.isEmpty(), "Valid program should not report diagnostics"); + assertEquals(2, fileBackend.executableFunctions().size()); + final var executableA = fileBackend.executableFunctions().stream() + .filter(fn -> fn.callableName().equals("a")) + .findFirst() + .orElseThrow(); + assertTrue(executableA.instructions().stream().anyMatch(i -> + i.kind() == p.studio.compiler.models.IRBackendExecutableFunction.InstructionKind.CALL_FUNC)); + } + @Test void shouldNotLowerWhenSyntaxErrorsExist() { final var source = """