diff --git a/prometeu-compiler/frontends/prometeu-frontend-pbs/src/main/java/p/studio/compiler/pbs/semantics/PbsFlowBodyAnalyzer.java b/prometeu-compiler/frontends/prometeu-frontend-pbs/src/main/java/p/studio/compiler/pbs/semantics/PbsFlowBodyAnalyzer.java index e59253ae..633b2324 100644 --- a/prometeu-compiler/frontends/prometeu-frontend-pbs/src/main/java/p/studio/compiler/pbs/semantics/PbsFlowBodyAnalyzer.java +++ b/prometeu-compiler/frontends/prometeu-frontend-pbs/src/main/java/p/studio/compiler/pbs/semantics/PbsFlowBodyAnalyzer.java @@ -146,18 +146,14 @@ final class PbsFlowBodyAnalyzer { } if (statement instanceof PbsAst.ReturnStatement returnStatement) { if (returnStatement.value() != null) { - expressionAnalyzer.analyzeExpression( + analyzeReturnStatement( returnStatement.value(), scope, returnType, - returnType, resultErrorName, receiverType, model, - diagnostics, - ExprUse.VALUE, - true, - this::analyzeBlock); + diagnostics); } return; } @@ -292,6 +288,128 @@ final class PbsFlowBodyAnalyzer { } } + private void analyzeReturnStatement( + final PbsAst.Expression value, + final Scope scope, + final TypeView returnType, + final String resultErrorName, + final TypeView receiverType, + final Model model, + final DiagnosticSink diagnostics) { + final var root = unwrapGroup(value); + if (root instanceof PbsAst.OkExpr okExpr) { + analyzeReturnOk(okExpr, scope, returnType, resultErrorName, receiverType, model, diagnostics); + return; + } + if (root instanceof PbsAst.ErrExpr errExpr) { + analyzeReturnErr(errExpr, returnType, resultErrorName, model, diagnostics); + return; + } + expressionAnalyzer.analyzeExpression( + value, + scope, + returnType, + returnType, + resultErrorName, + receiverType, + model, + diagnostics, + ExprUse.VALUE, + true, + this::analyzeBlock); + } + + private void analyzeReturnOk( + final PbsAst.OkExpr okExpr, + final Scope scope, + final TypeView returnType, + final String resultErrorName, + final TypeView receiverType, + final Model model, + final DiagnosticSink diagnostics) { + if (returnType.kind() != PbsFlowSemanticSupport.Kind.RESULT || resultErrorName == null) { + p.studio.compiler.source.diagnostics.Diagnostics.error(diagnostics, + PbsSemanticsErrors.E_SEM_RESULT_FLOW_INVALID_POSITION.name(), + "'ok(...)' is only allowed when returning from a result callable", + okExpr.span()); + expressionAnalyzer.analyzeExpression( + okExpr.value(), + scope, + null, + returnType, + resultErrorName, + receiverType, + model, + diagnostics, + ExprUse.VALUE, + true, + this::analyzeBlock); + return; + } + + final var payloadType = returnType.inner(); + final var actualType = expressionAnalyzer.analyzeExpression( + okExpr.value(), + scope, + payloadType, + returnType, + resultErrorName, + receiverType, + model, + diagnostics, + ExprUse.VALUE, + true, + this::analyzeBlock).type(); + if (!typeOps.compatible(actualType, payloadType)) { + p.studio.compiler.source.diagnostics.Diagnostics.error(diagnostics, + PbsSemanticsErrors.E_SEM_RESULT_OK_PAYLOAD_MISMATCH.name(), + "Payload in 'ok(...)' is incompatible with result payload type", + okExpr.value().span()); + } + } + + private void analyzeReturnErr( + final PbsAst.ErrExpr errExpr, + final TypeView returnType, + final String resultErrorName, + final Model model, + final DiagnosticSink diagnostics) { + if (returnType.kind() != PbsFlowSemanticSupport.Kind.RESULT || resultErrorName == null) { + p.studio.compiler.source.diagnostics.Diagnostics.error(diagnostics, + PbsSemanticsErrors.E_SEM_RESULT_FLOW_INVALID_POSITION.name(), + "'err(...)' is only allowed when returning from a result callable", + errExpr.span()); + return; + } + + if (!matchesTargetError(errExpr.errorPath(), resultErrorName, model)) { + p.studio.compiler.source.diagnostics.Diagnostics.error(diagnostics, + PbsSemanticsErrors.E_SEM_RESULT_ERROR_LABEL_INVALID.name(), + "Error label in 'err(...)' does not match enclosing result error type", + errExpr.errorPath().span()); + } + } + + private PbsAst.Expression unwrapGroup(final PbsAst.Expression expression) { + if (expression instanceof PbsAst.GroupExpr groupExpr) { + return unwrapGroup(groupExpr.expression()); + } + return expression; + } + + private boolean matchesTargetError( + final PbsAst.ErrorPath path, + final String resultErrorName, + final Model model) { + if (path == null || resultErrorName == null || path.segments().size() != 2) { + return false; + } + final var errorName = path.segments().getFirst(); + final var caseName = path.segments().get(1); + final var targetCases = model.errors.get(errorName); + return resultErrorName.equals(errorName) && targetCases != null && targetCases.contains(caseName); + } + private void analyzeAssignmentStatement( final PbsAst.AssignStatement assignStatement, final Scope scope, diff --git a/prometeu-compiler/frontends/prometeu-frontend-pbs/src/main/java/p/studio/compiler/pbs/semantics/PbsFlowExpressionAnalyzer.java b/prometeu-compiler/frontends/prometeu-frontend-pbs/src/main/java/p/studio/compiler/pbs/semantics/PbsFlowExpressionAnalyzer.java index f53bcb82..2d1d4c0f 100644 --- a/prometeu-compiler/frontends/prometeu-frontend-pbs/src/main/java/p/studio/compiler/pbs/semantics/PbsFlowExpressionAnalyzer.java +++ b/prometeu-compiler/frontends/prometeu-frontend-pbs/src/main/java/p/studio/compiler/pbs/semantics/PbsFlowExpressionAnalyzer.java @@ -509,6 +509,10 @@ final class PbsFlowExpressionAnalyzer { } if (expression instanceof PbsAst.OkExpr okExpr) { + p.studio.compiler.source.diagnostics.Diagnostics.error(diagnostics, + PbsSemanticsErrors.E_SEM_RESULT_FLOW_INVALID_POSITION.name(), + "'ok(...)' is only allowed in result return flow and handle arm terminals", + okExpr.span()); analyzeExpressionInternal( okExpr.value(), scope, @@ -523,6 +527,10 @@ final class PbsFlowExpressionAnalyzer { return ExprResult.type(TypeView.unknown()); } if (expression instanceof PbsAst.ErrExpr errExpr) { + p.studio.compiler.source.diagnostics.Diagnostics.error(diagnostics, + PbsSemanticsErrors.E_SEM_RESULT_FLOW_INVALID_POSITION.name(), + "'err(...)' is only allowed in result return flow and handle arm terminals", + errExpr.span()); return ExprResult.type(TypeView.unknown()); } @@ -1062,6 +1070,7 @@ final class PbsFlowExpressionAnalyzer { final var sourceErrorName = sourceType.errorType() == null ? null : sourceType.errorType().name(); final var sourceCases = sourceErrorName == null ? Set.of() : model.errors.getOrDefault(sourceErrorName, Set.of()); + final var sourcePayloadType = sourceType.inner() == null ? TypeView.unknown() : sourceType.inner(); final var matchedCases = new HashSet(); var hasWildcard = false; @@ -1096,7 +1105,15 @@ final class PbsFlowExpressionAnalyzer { } continue; } - blockAnalyzer.analyze(arm.block(), scope, returnType, resultErrorName, receiverType, model, diagnostics, true); + analyzeHandleBlockArm( + arm, + scope, + sourcePayloadType, + returnType, + resultErrorName, + receiverType, + model, + diagnostics); } if (!hasWildcard && !sourceCases.isEmpty() && !matchedCases.containsAll(sourceCases)) { @@ -1109,6 +1126,93 @@ final class PbsFlowExpressionAnalyzer { return sourceType.inner() == null ? TypeView.unknown() : sourceType.inner(); } + private void analyzeHandleBlockArm( + final PbsAst.HandleArm arm, + final Scope scope, + final TypeView sourcePayloadType, + final TypeView returnType, + final String resultErrorName, + final TypeView receiverType, + final Model model, + final DiagnosticSink diagnostics) { + if (arm.block() == null) { + p.studio.compiler.source.diagnostics.Diagnostics.error(diagnostics, + PbsSemanticsErrors.E_SEM_HANDLE_ARM_TERMINAL_INVALID.name(), + "Handle block arm must terminate with 'ok(...)' or 'err(E.case)'", + arm.span()); + return; + } + + final var block = arm.block(); + final var terminal = unwrapGroup(block.tailExpression()); + if (terminal instanceof PbsAst.OkExpr okExpr) { + final var payloadBlock = new PbsAst.Block(block.statements(), okExpr.value(), block.span()); + final var actualPayloadType = blockAnalyzer.analyze( + payloadBlock, + scope, + returnType, + resultErrorName, + receiverType, + model, + diagnostics, + true); + if (!typeOps.compatible(actualPayloadType, sourcePayloadType)) { + p.studio.compiler.source.diagnostics.Diagnostics.error(diagnostics, + PbsSemanticsErrors.E_SEM_RESULT_OK_PAYLOAD_MISMATCH.name(), + "Handle arm 'ok(...)' payload is incompatible with source result payload type", + okExpr.value().span()); + } + return; + } + + if (terminal instanceof PbsAst.ErrExpr errExpr) { + final var statementsOnly = new PbsAst.Block(block.statements(), null, block.span()); + blockAnalyzer.analyze( + statementsOnly, + scope, + returnType, + resultErrorName, + receiverType, + model, + diagnostics, + false); + if (resultErrorName != null && !matchesTargetError(errExpr.errorPath(), resultErrorName, model)) { + p.studio.compiler.source.diagnostics.Diagnostics.error(diagnostics, + PbsSemanticsErrors.E_SEM_RESULT_ERROR_LABEL_INVALID.name(), + "Error label in handle arm 'err(...)' does not match enclosing result error type", + errExpr.errorPath().span()); + } + return; + } + + if (block.tailExpression() == null) { + final var statementsOnly = new PbsAst.Block(block.statements(), null, block.span()); + blockAnalyzer.analyze( + statementsOnly, + scope, + returnType, + resultErrorName, + receiverType, + model, + diagnostics, + false); + } else { + blockAnalyzer.analyze( + block, + scope, + returnType, + resultErrorName, + receiverType, + model, + diagnostics, + true); + } + p.studio.compiler.source.diagnostics.Diagnostics.error(diagnostics, + PbsSemanticsErrors.E_SEM_HANDLE_ARM_TERMINAL_INVALID.name(), + "Handle block arm must terminate with 'ok(...)' or 'err(E.case)'", + block.span()); + } + private boolean matchesTargetError( final PbsAst.ErrorPath path, final String resultErrorName, @@ -1122,4 +1226,14 @@ final class PbsFlowExpressionAnalyzer { return resultErrorName.equals(errorName) && targetCases != null && targetCases.contains(caseName); } + private PbsAst.Expression unwrapGroup(final PbsAst.Expression expression) { + if (expression == null) { + return null; + } + if (expression instanceof PbsAst.GroupExpr groupExpr) { + return unwrapGroup(groupExpr.expression()); + } + return expression; + } + } diff --git a/prometeu-compiler/frontends/prometeu-frontend-pbs/src/main/java/p/studio/compiler/pbs/semantics/PbsSemanticsErrors.java b/prometeu-compiler/frontends/prometeu-frontend-pbs/src/main/java/p/studio/compiler/pbs/semantics/PbsSemanticsErrors.java index 067d81fc..d6c3648c 100644 --- a/prometeu-compiler/frontends/prometeu-frontend-pbs/src/main/java/p/studio/compiler/pbs/semantics/PbsSemanticsErrors.java +++ b/prometeu-compiler/frontends/prometeu-frontend-pbs/src/main/java/p/studio/compiler/pbs/semantics/PbsSemanticsErrors.java @@ -41,8 +41,12 @@ public enum PbsSemanticsErrors { E_SEM_NONE_WITHOUT_EXPECTED_OPTIONAL, E_SEM_ELSE_NON_OPTIONAL_LEFT, E_SEM_ELSE_FALLBACK_TYPE_MISMATCH, + E_SEM_RESULT_FLOW_INVALID_POSITION, + E_SEM_RESULT_OK_PAYLOAD_MISMATCH, + E_SEM_RESULT_ERROR_LABEL_INVALID, E_SEM_RESULT_PROPAGATE_NON_RESULT, E_SEM_RESULT_PROPAGATE_ERROR_MISMATCH, E_SEM_HANDLE_NON_RESULT, E_SEM_HANDLE_ERROR_MISMATCH, + E_SEM_HANDLE_ARM_TERMINAL_INVALID, } diff --git a/prometeu-compiler/frontends/prometeu-frontend-pbs/src/test/java/p/studio/compiler/pbs/semantics/PbsSemanticsResultFlowRulesTest.java b/prometeu-compiler/frontends/prometeu-frontend-pbs/src/test/java/p/studio/compiler/pbs/semantics/PbsSemanticsResultFlowRulesTest.java new file mode 100644 index 00000000..7cb66df8 --- /dev/null +++ b/prometeu-compiler/frontends/prometeu-frontend-pbs/src/test/java/p/studio/compiler/pbs/semantics/PbsSemanticsResultFlowRulesTest.java @@ -0,0 +1,133 @@ +package p.studio.compiler.pbs.semantics; + +import org.junit.jupiter.api.Test; +import p.studio.compiler.pbs.PbsFrontendCompiler; +import p.studio.compiler.source.diagnostics.DiagnosticSink; +import p.studio.compiler.source.identifiers.FileId; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +class PbsSemanticsResultFlowRulesTest { + + @Test + void shouldAcceptValidResultReturnAndHandleArmTerminals() { + final var source = """ + declare error InErr { A; B; } + declare error OutErr { Retry; Abort; } + + fn inner(v: int) -> result int { + if v > 0 { return ok(v); } + return err(InErr.A); + } + + fn wrap(v: int) -> result int { + let recovered: int = handle inner(v) { + InErr.A -> { ok(1) }, + InErr.B -> { err(OutErr.Retry) } + }; + return ok(recovered); + } + """; + final var diagnostics = DiagnosticSink.empty(); + + new PbsFrontendCompiler().compileFile(new FileId(0), source, diagnostics); + + assertFalse(diagnostics.stream().anyMatch(d -> + d.getCode().equals(PbsSemanticsErrors.E_SEM_RESULT_FLOW_INVALID_POSITION.name()))); + assertFalse(diagnostics.stream().anyMatch(d -> + d.getCode().equals(PbsSemanticsErrors.E_SEM_RESULT_OK_PAYLOAD_MISMATCH.name()))); + assertFalse(diagnostics.stream().anyMatch(d -> + d.getCode().equals(PbsSemanticsErrors.E_SEM_RESULT_ERROR_LABEL_INVALID.name()))); + assertFalse(diagnostics.stream().anyMatch(d -> + d.getCode().equals(PbsSemanticsErrors.E_SEM_HANDLE_ARM_TERMINAL_INVALID.name()))); + } + + @Test + void shouldRejectOkAndErrOutsideAllowedResultFlowPositions() { + final var source = """ + declare error Err { Fail; } + + fn bad(v: int) -> int { + ok(v); + err(Err.Fail); + return v; + } + """; + final var diagnostics = DiagnosticSink.empty(); + + new PbsFrontendCompiler().compileFile(new FileId(0), source, diagnostics); + + assertTrue(diagnostics.stream().anyMatch(d -> + d.getCode().equals(PbsSemanticsErrors.E_SEM_RESULT_FLOW_INVALID_POSITION.name()))); + } + + @Test + void shouldRejectMismatchedResultPayloadAndErrorLabels() { + final var source = """ + declare error ErrA { Fail; } + declare error ErrB { Stop; } + + fn badPayload() -> result int { + return ok("oops"); + } + + fn badLabel() -> result int { + return err(ErrB.Stop); + } + """; + final var diagnostics = DiagnosticSink.empty(); + + new PbsFrontendCompiler().compileFile(new FileId(0), source, diagnostics); + + assertTrue(diagnostics.stream().anyMatch(d -> + d.getCode().equals(PbsSemanticsErrors.E_SEM_RESULT_OK_PAYLOAD_MISMATCH.name()))); + assertTrue(diagnostics.stream().anyMatch(d -> + d.getCode().equals(PbsSemanticsErrors.E_SEM_RESULT_ERROR_LABEL_INVALID.name()))); + } + + @Test + void shouldRejectInvalidHandleBlockTerminalAndInvalidHandleTerminalForms() { + final var source = """ + declare error InErr { Fail; } + declare error OutErr { Retry; Abort; } + declare error OtherErr { Oops; } + + fn source() -> result int { return err(InErr.Fail); } + + fn invalidTerminal() -> result int { + let value: int = handle source() { + InErr.Fail -> { 1 } + }; + return ok(value); + } + + fn badHandleOkPayload() -> result int { + let value: int = handle source() { + InErr.Fail -> { + let payload: str = "x"; + ok(payload) + } + }; + return ok(value); + } + + fn badHandleErrLabel() -> result int { + let value: int = handle source() { + InErr.Fail -> { err(OtherErr.Oops) } + }; + return ok(value); + } + """; + final var diagnostics = DiagnosticSink.empty(); + + new PbsFrontendCompiler().compileFile(new FileId(0), source, diagnostics); + + assertTrue(diagnostics.stream().anyMatch(d -> + d.getCode().equals(PbsSemanticsErrors.E_SEM_HANDLE_ARM_TERMINAL_INVALID.name()))); + assertTrue(diagnostics.stream().anyMatch(d -> + d.getCode().equals(PbsSemanticsErrors.E_SEM_RESULT_OK_PAYLOAD_MISMATCH.name()))); + assertTrue(diagnostics.stream().anyMatch(d -> + d.getCode().equals(PbsSemanticsErrors.E_SEM_RESULT_ERROR_LABEL_INVALID.name()))); + } +}