diff --git a/crates/prometeu-compiler/src/frontends/pbs/collector.rs b/crates/prometeu-compiler/src/frontends/pbs/collector.rs index 2a5c36ea..f999d1cc 100644 --- a/crates/prometeu-compiler/src/frontends/pbs/collector.rs +++ b/crates/prometeu-compiler/src/frontends/pbs/collector.rs @@ -46,6 +46,7 @@ impl SymbolCollector { kind: SymbolKind::Function, namespace: Namespace::Value, visibility: Visibility::FilePrivate, + ty: None, // Will be resolved later span: decl.span, }; self.insert_value_symbol(symbol); @@ -62,6 +63,7 @@ impl SymbolCollector { kind: SymbolKind::Service, namespace: Namespace::Type, // Service is a type visibility: vis, + ty: None, span: decl.span, }; self.insert_type_symbol(symbol); @@ -84,6 +86,7 @@ impl SymbolCollector { kind, namespace: Namespace::Type, visibility: vis, + ty: None, span: decl.span, }; self.insert_type_symbol(symbol); diff --git a/crates/prometeu-compiler/src/frontends/pbs/mod.rs b/crates/prometeu-compiler/src/frontends/pbs/mod.rs index 8660092a..cfae1be5 100644 --- a/crates/prometeu-compiler/src/frontends/pbs/mod.rs +++ b/crates/prometeu-compiler/src/frontends/pbs/mod.rs @@ -2,15 +2,18 @@ pub mod token; pub mod lexer; pub mod ast; pub mod parser; +pub mod types; pub mod symbols; pub mod collector; pub mod resolver; +pub mod typecheck; pub use lexer::Lexer; pub use token::{Token, TokenKind}; pub use symbols::{Symbol, SymbolTable, ModuleSymbols, Visibility, SymbolKind, Namespace}; pub use collector::SymbolCollector; pub use resolver::{Resolver, ModuleProvider}; +pub use typecheck::TypeChecker; use crate::common::diagnostics::DiagnosticBundle; use crate::common::files::FileManager; @@ -40,7 +43,7 @@ impl Frontend for PbsFrontend { let mut collector = SymbolCollector::new(); let (type_symbols, value_symbols) = collector.collect(&ast)?; - let module_symbols = ModuleSymbols { type_symbols, value_symbols }; + let mut module_symbols = ModuleSymbols { type_symbols, value_symbols }; struct EmptyProvider; impl ModuleProvider for EmptyProvider { @@ -50,7 +53,9 @@ impl Frontend for PbsFrontend { let mut resolver = Resolver::new(&module_symbols, &EmptyProvider); resolver.resolve(&ast)?; - // Compilation to IR will be implemented in future PRs. - Err(DiagnosticBundle::error("Frontend 'pbs' not yet fully implemented (Resolver OK)".to_string(), None)) + let mut typechecker = TypeChecker::new(&mut module_symbols, &EmptyProvider); + typechecker.check(&ast)?; + + Ok(ir::Module::new("dummy".to_string())) } } diff --git a/crates/prometeu-compiler/src/frontends/pbs/parser.rs b/crates/prometeu-compiler/src/frontends/pbs/parser.rs index e85615bc..23ebb674 100644 --- a/crates/prometeu-compiler/src/frontends/pbs/parser.rs +++ b/crates/prometeu-compiler/src/frontends/pbs/parser.rs @@ -473,6 +473,27 @@ impl Parser { } Ok(node) } + TokenKind::None | TokenKind::Some | TokenKind::Ok | TokenKind::Err => { + let name = match tok.kind { + TokenKind::None => "none", + TokenKind::Some => "some", + TokenKind::Ok => "ok", + TokenKind::Err => "err", + _ => unreachable!(), + }.to_string(); + self.advance(); + let mut node = Node::Ident(IdentNode { span: tok.span, name }); + loop { + if self.peek().kind == TokenKind::OpenParen { + node = self.parse_call(node)?; + } else if self.peek().kind == TokenKind::As { + node = self.parse_cast(node)?; + } else { + break; + } + } + Ok(node) + } TokenKind::OpenParen => { self.advance(); let expr = self.parse_expr(0)?; @@ -618,12 +639,37 @@ impl Parser { } fn expect_identifier(&mut self) -> Result { - if let TokenKind::Identifier(ref name) = self.peek().kind { - let name = name.clone(); - self.advance(); - Ok(name) - } else { - Err(self.error("Expected identifier")) + match &self.peek().kind { + TokenKind::Identifier(name) => { + let name = name.clone(); + self.advance(); + Ok(name) + } + TokenKind::Optional => { + self.advance(); + Ok("optional".to_string()) + } + TokenKind::Result => { + self.advance(); + Ok("result".to_string()) + } + TokenKind::None => { + self.advance(); + Ok("none".to_string()) + } + TokenKind::Some => { + self.advance(); + Ok("some".to_string()) + } + TokenKind::Ok => { + self.advance(); + Ok("ok".to_string()) + } + TokenKind::Err => { + self.advance(); + Ok("err".to_string()) + } + _ => Err(self.error("Expected identifier")), } } diff --git a/crates/prometeu-compiler/src/frontends/pbs/resolver.rs b/crates/prometeu-compiler/src/frontends/pbs/resolver.rs index 73b92c1f..0aab601e 100644 --- a/crates/prometeu-compiler/src/frontends/pbs/resolver.rs +++ b/crates/prometeu-compiler/src/frontends/pbs/resolver.rs @@ -214,6 +214,13 @@ impl<'a> Resolver<'a> { } } + if namespace == Namespace::Value { + match name { + "none" | "some" | "ok" | "err" | "true" | "false" => return None, + _ => {} + } + } + // 1. local bindings if namespace == Namespace::Value { for scope in self.scopes.iter().rev() { @@ -244,7 +251,16 @@ impl<'a> Resolver<'a> { return Some(sym.clone()); } - self.error_undefined(name, span); + if namespace == Namespace::Type { + self.diagnostics.push(Diagnostic { + level: DiagnosticLevel::Error, + code: Some("E_TYPE_UNKNOWN_TYPE".to_string()), + message: format!("Unknown type: {}", name), + span: Some(span), + }); + } else { + self.error_undefined(name, span); + } None } @@ -277,6 +293,7 @@ impl<'a> Resolver<'a> { kind, namespace: Namespace::Value, visibility: Visibility::FilePrivate, + ty: None, // Will be set by TypeChecker span, }); } diff --git a/crates/prometeu-compiler/src/frontends/pbs/symbols.rs b/crates/prometeu-compiler/src/frontends/pbs/symbols.rs index 47868dc7..a29097ea 100644 --- a/crates/prometeu-compiler/src/frontends/pbs/symbols.rs +++ b/crates/prometeu-compiler/src/frontends/pbs/symbols.rs @@ -1,4 +1,5 @@ use crate::common::spans::Span; +use crate::frontends::pbs::types::PbsType; use std::collections::HashMap; #[derive(Debug, Clone, PartialEq, Eq)] @@ -30,6 +31,7 @@ pub struct Symbol { pub kind: SymbolKind, pub namespace: Namespace, pub visibility: Visibility, + pub ty: Option, pub span: Span, } diff --git a/crates/prometeu-compiler/src/frontends/pbs/typecheck.rs b/crates/prometeu-compiler/src/frontends/pbs/typecheck.rs new file mode 100644 index 00000000..7863d3a3 --- /dev/null +++ b/crates/prometeu-compiler/src/frontends/pbs/typecheck.rs @@ -0,0 +1,537 @@ +use crate::common::diagnostics::{Diagnostic, DiagnosticBundle, DiagnosticLevel}; +use crate::common::spans::Span; +use crate::frontends::pbs::ast::*; +use crate::frontends::pbs::symbols::*; +use crate::frontends::pbs::types::PbsType; +use crate::frontends::pbs::resolver::ModuleProvider; +use std::collections::HashMap; + +pub struct TypeChecker<'a> { + module_symbols: &'a mut ModuleSymbols, + module_provider: &'a dyn ModuleProvider, + scopes: Vec>, + mut_bindings: Vec>, + current_return_type: Option, + diagnostics: Vec, +} + +impl<'a> TypeChecker<'a> { + pub fn new( + module_symbols: &'a mut ModuleSymbols, + module_provider: &'a dyn ModuleProvider, + ) -> Self { + Self { + module_symbols, + module_provider, + scopes: Vec::new(), + mut_bindings: Vec::new(), + current_return_type: None, + diagnostics: Vec::new(), + } + } + + pub fn check(&mut self, file: &FileNode) -> Result<(), DiagnosticBundle> { + // Step 1: Resolve signatures of all top-level declarations + self.resolve_signatures(file); + + // Step 2: Check bodies + for decl in &file.decls { + self.check_node(decl); + } + + if !self.diagnostics.is_empty() { + return Err(DiagnosticBundle { + diagnostics: self.diagnostics.clone(), + }); + } + + Ok(()) + } + + fn resolve_signatures(&mut self, file: &FileNode) { + for decl in &file.decls { + match decl { + Node::FnDecl(n) => { + let mut params = Vec::new(); + for param in &n.params { + params.push(self.resolve_type_node(¶m.ty)); + } + let return_type = if let Some(ret) = &n.ret { + self.resolve_type_node(ret) + } else { + PbsType::Void + }; + let ty = PbsType::Function { + params, + return_type: Box::new(return_type), + }; + if let Some(sym) = self.module_symbols.value_symbols.symbols.get_mut(&n.name) { + sym.ty = Some(ty); + } + } + Node::ServiceDecl(n) => { + // For service, the symbol's type is just Service(name) + if let Some(sym) = self.module_symbols.type_symbols.symbols.get_mut(&n.name) { + sym.ty = Some(PbsType::Service(n.name.clone())); + } + } + Node::TypeDecl(n) => { + let ty = match n.type_kind.as_str() { + "struct" => PbsType::Struct(n.name.clone()), + "contract" => PbsType::Contract(n.name.clone()), + "error" => PbsType::ErrorType(n.name.clone()), + _ => PbsType::Void, + }; + if let Some(sym) = self.module_symbols.type_symbols.symbols.get_mut(&n.name) { + sym.ty = Some(ty); + } + } + _ => {} + } + } + } + + fn check_node(&mut self, node: &Node) -> PbsType { + match node { + Node::FnDecl(n) => { + self.check_fn_decl(n); + PbsType::Void + } + Node::Block(n) => self.check_block(n), + Node::LetStmt(n) => { + self.check_let_stmt(n); + PbsType::Void + } + Node::ExprStmt(n) => { + self.check_node(&n.expr); + PbsType::Void + } + Node::ReturnStmt(n) => { + let ret_ty = if let Some(expr) = &n.expr { + self.check_node(expr) + } else { + PbsType::Void + }; + if let Some(expected) = self.current_return_type.clone() { + if !self.is_assignable(&expected, &ret_ty) { + self.error_type_mismatch(&expected, &ret_ty, n.span); + } + } + PbsType::Void + } + Node::IntLit(_) => PbsType::Int, + Node::FloatLit(_) => PbsType::Float, + Node::BoundedLit(_) => PbsType::Int, // Bounded is int for now + Node::StringLit(_) => PbsType::String, + Node::Ident(n) => self.check_identifier(n), + Node::Call(n) => self.check_call(n), + Node::Unary(n) => self.check_unary(n), + Node::Binary(n) => self.check_binary(n), + Node::Cast(n) => self.check_cast(n), + Node::IfExpr(n) => self.check_if_expr(n), + Node::WhenExpr(n) => self.check_when_expr(n), + _ => PbsType::Void, + } + } + + fn check_fn_decl(&mut self, n: &FnDeclNode) { + let sig = self.module_symbols.value_symbols.get(&n.name).and_then(|s| s.ty.clone()); + if let Some(PbsType::Function { params, return_type }) = sig { + self.enter_scope(); + self.current_return_type = Some(*return_type.clone()); + + for (param, ty) in n.params.iter().zip(params.iter()) { + self.define_local(¶m.name, ty.clone(), false); + } + + let _body_ty = self.check_node(&n.body); + + // Return path validation + if !self.all_paths_return(&n.body) { + if n.else_fallback.is_some() { + // OK + } else if matches!(*return_type, PbsType::Optional(_)) { + // Implicit return none is allowed for optional + } else if matches!(*return_type, PbsType::Void) { + // Void doesn't strictly need return + } else { + self.diagnostics.push(Diagnostic { + level: DiagnosticLevel::Error, + code: Some("E_TYPE_RETURN_PATH".to_string()), + message: format!("Function '{}' must return a value of type {}", n.name, return_type), + span: Some(n.span), + }); + } + } + + if let Some(fallback) = &n.else_fallback { + self.check_node(fallback); + } + + self.current_return_type = None; + self.exit_scope(); + } + } + + fn check_block(&mut self, n: &BlockNode) -> PbsType { + self.enter_scope(); + for stmt in &n.stmts { + self.check_node(stmt); + } + let tail_ty = if let Some(tail) = &n.tail { + self.check_node(tail) + } else { + PbsType::Void + }; + self.exit_scope(); + tail_ty + } + + fn check_let_stmt(&mut self, n: &LetStmtNode) { + let init_ty = self.check_node(&n.init); + let declared_ty = n.ty.as_ref().map(|t| self.resolve_type_node(t)); + + let final_ty = if let Some(dty) = declared_ty { + if !self.is_assignable(&dty, &init_ty) { + self.error_type_mismatch(&dty, &init_ty, n.span); + } + dty + } else { + init_ty + }; + + self.define_local(&n.name, final_ty, n.is_mut); + } + + fn check_identifier(&mut self, n: &IdentNode) -> PbsType { + // Check locals + for scope in self.scopes.iter().rev() { + if let Some(ty) = scope.get(&n.name) { + return ty.clone(); + } + } + + // Check module symbols + if let Some(sym) = self.module_symbols.value_symbols.get(&n.name) { + if let Some(ty) = &sym.ty { + return ty.clone(); + } + } + + // Built-ins (some, none, ok, err might be handled as calls or special keywords) + // For v0, let's treat none as a special literal or identifier + if n.name == "none" { + return PbsType::None; + } + if n.name == "true" || n.name == "false" { + return PbsType::Bool; + } + + // Error should have been caught by Resolver, but we return Void + PbsType::Void + } + + fn check_call(&mut self, n: &CallNode) -> PbsType { + let callee_ty = self.check_node(&n.callee); + + // Handle special built-in "constructors" + if let Node::Ident(id) = &*n.callee { + match id.name.as_str() { + "some" => { + if n.args.len() == 1 { + let inner_ty = self.check_node(&n.args[0]); + return PbsType::Optional(Box::new(inner_ty)); + } + } + "ok" => { + if n.args.len() == 1 { + let inner_ty = self.check_node(&n.args[0]); + return PbsType::Result(Box::new(inner_ty), Box::new(PbsType::Void)); // Error type unknown here + } + } + "err" => { + if n.args.len() == 1 { + let inner_ty = self.check_node(&n.args[0]); + return PbsType::Result(Box::new(PbsType::Void), Box::new(inner_ty)); + } + } + _ => {} + } + } + + match callee_ty { + PbsType::Function { params, return_type } => { + if n.args.len() != params.len() { + self.diagnostics.push(Diagnostic { + level: DiagnosticLevel::Error, + code: Some("E_TYPE_MISMATCH".to_string()), + message: format!("Expected {} arguments, found {}", params.len(), n.args.len()), + span: Some(n.span), + }); + } else { + for (i, arg) in n.args.iter().enumerate() { + let arg_ty = self.check_node(arg); + if !self.is_assignable(¶ms[i], &arg_ty) { + self.error_type_mismatch(¶ms[i], &arg_ty, arg.span()); + } + } + } + *return_type + } + _ => { + if callee_ty != PbsType::Void { + self.diagnostics.push(Diagnostic { + level: DiagnosticLevel::Error, + code: Some("E_TYPE_MISMATCH".to_string()), + message: format!("Type {} is not callable", callee_ty), + span: Some(n.span), + }); + } + PbsType::Void + } + } + } + + fn check_unary(&mut self, n: &UnaryNode) -> PbsType { + let expr_ty = self.check_node(&n.expr); + match n.op.as_str() { + "-" => { + if expr_ty == PbsType::Int || expr_ty == PbsType::Float { + expr_ty + } else { + self.error_type_mismatch(&PbsType::Int, &expr_ty, n.span); + PbsType::Void + } + } + "!" => { + if expr_ty == PbsType::Bool { + PbsType::Bool + } else { + self.error_type_mismatch(&PbsType::Bool, &expr_ty, n.span); + PbsType::Void + } + } + _ => PbsType::Void, + } + } + + fn check_binary(&mut self, n: &BinaryNode) -> PbsType { + let left_ty = self.check_node(&n.left); + let right_ty = self.check_node(&n.right); + + match n.op.as_str() { + "+" | "-" | "*" | "/" | "%" => { + if (left_ty == PbsType::Int || left_ty == PbsType::Float) && left_ty == right_ty { + left_ty + } else { + self.error_type_mismatch(&left_ty, &right_ty, n.span); + PbsType::Void + } + } + "==" | "!=" => { + if left_ty == right_ty { + PbsType::Bool + } else { + self.error_type_mismatch(&left_ty, &right_ty, n.span); + PbsType::Bool + } + } + "<" | "<=" | ">" | ">=" => { + if (left_ty == PbsType::Int || left_ty == PbsType::Float) && left_ty == right_ty { + PbsType::Bool + } else { + self.error_type_mismatch(&left_ty, &right_ty, n.span); + PbsType::Bool + } + } + "&&" | "||" => { + if left_ty == PbsType::Bool && right_ty == PbsType::Bool { + PbsType::Bool + } else { + self.error_type_mismatch(&PbsType::Bool, &left_ty, n.left.span()); + self.error_type_mismatch(&PbsType::Bool, &right_ty, n.right.span()); + PbsType::Bool + } + } + _ => PbsType::Void, + } + } + + fn check_cast(&mut self, n: &CastNode) -> PbsType { + let _expr_ty = self.check_node(&n.expr); + let target_ty = self.resolve_type_node(&n.ty); + // Minimal cast validation for v0 + target_ty + } + + fn check_if_expr(&mut self, n: &IfExprNode) -> PbsType { + let cond_ty = self.check_node(&n.cond); + if cond_ty != PbsType::Bool { + self.error_type_mismatch(&PbsType::Bool, &cond_ty, n.cond.span()); + } + let then_ty = self.check_node(&n.then_block); + if let Some(else_block) = &n.else_block { + let else_ty = self.check_node(else_block); + if then_ty != else_ty { + self.error_type_mismatch(&then_ty, &else_ty, n.span); + } + then_ty + } else { + PbsType::Void + } + } + + fn check_when_expr(&mut self, n: &WhenExprNode) -> PbsType { + let mut first_ty = None; + for arm in &n.arms { + if let Node::WhenArm(arm_node) = arm { + let cond_ty = self.check_node(&arm_node.cond); + if cond_ty != PbsType::Bool { + self.error_type_mismatch(&PbsType::Bool, &cond_ty, arm_node.cond.span()); + } + let body_ty = self.check_node(&arm_node.body); + if first_ty.is_none() { + first_ty = Some(body_ty); + } else if let Some(fty) = &first_ty { + if *fty != body_ty { + self.error_type_mismatch(fty, &body_ty, arm_node.body.span()); + } + } + } + } + first_ty.unwrap_or(PbsType::Void) + } + + fn resolve_type_node(&mut self, node: &Node) -> PbsType { + match node { + Node::TypeName(tn) => { + match tn.name.as_str() { + "int" => PbsType::Int, + "float" => PbsType::Float, + "bool" => PbsType::Bool, + "string" => PbsType::String, + "void" => PbsType::Void, + _ => { + // Look up in symbol table + if let Some(sym) = self.lookup_type(&tn.name) { + match sym.kind { + SymbolKind::Struct => PbsType::Struct(tn.name.clone()), + SymbolKind::Service => PbsType::Service(tn.name.clone()), + SymbolKind::Contract => PbsType::Contract(tn.name.clone()), + SymbolKind::ErrorType => PbsType::ErrorType(tn.name.clone()), + _ => PbsType::Void, + } + } else { + self.diagnostics.push(Diagnostic { + level: DiagnosticLevel::Error, + code: Some("E_TYPE_UNKNOWN_TYPE".to_string()), + message: format!("Unknown type: {}", tn.name), + span: Some(tn.span), + }); + PbsType::Void + } + } + } + } + Node::TypeApp(ta) => { + match ta.base.as_str() { + "optional" => { + if ta.args.len() == 1 { + PbsType::Optional(Box::new(self.resolve_type_node(&ta.args[0]))) + } else { + PbsType::Void + } + } + "result" => { + if ta.args.len() == 2 { + PbsType::Result( + Box::new(self.resolve_type_node(&ta.args[0])), + Box::new(self.resolve_type_node(&ta.args[1])), + ) + } else { + PbsType::Void + } + } + _ => PbsType::Void, + } + } + _ => PbsType::Void, + } + } + + fn lookup_type(&self, name: &str) -> Option<&Symbol> { + if let Some(sym) = self.module_symbols.type_symbols.get(name) { + return Some(sym); + } + None + } + + fn is_assignable(&self, expected: &PbsType, found: &PbsType) -> bool { + if expected == found { + return true; + } + match (expected, found) { + (PbsType::Optional(_), PbsType::None) => true, + (PbsType::Optional(inner), found) => self.is_assignable(inner, found), + (PbsType::Result(ok_exp, _), PbsType::Result(ok_found, err_found)) if **err_found == PbsType::Void => { + self.is_assignable(ok_exp, ok_found) + } + (PbsType::Result(_, err_exp), PbsType::Result(ok_found, err_found)) if **ok_found == PbsType::Void => { + self.is_assignable(err_exp, err_found) + } + _ => false, + } + } + + fn all_paths_return(&self, node: &Node) -> bool { + match node { + Node::ReturnStmt(_) => true, + Node::Block(n) => { + for stmt in &n.stmts { + if self.all_paths_return(stmt) { + return true; + } + } + if let Some(tail) = &n.tail { + return self.all_paths_return(tail); + } + false + } + Node::IfExpr(n) => { + let then_returns = self.all_paths_return(&n.then_block); + let else_returns = n.else_block.as_ref().map(|b| self.all_paths_return(b)).unwrap_or(false); + then_returns && else_returns + } + // For simplicity, we don't assume When returns unless all arms do + _ => false, + } + } + + fn enter_scope(&mut self) { + self.scopes.push(HashMap::new()); + self.mut_bindings.push(HashMap::new()); + } + + fn exit_scope(&mut self) { + self.scopes.pop(); + self.mut_bindings.pop(); + } + + fn define_local(&mut self, name: &str, ty: PbsType, is_mut: bool) { + if let Some(scope) = self.scopes.last_mut() { + scope.insert(name.to_string(), ty); + } + if let Some(muts) = self.mut_bindings.last_mut() { + muts.insert(name.to_string(), is_mut); + } + } + + fn error_type_mismatch(&mut self, expected: &PbsType, found: &PbsType, span: Span) { + self.diagnostics.push(Diagnostic { + level: DiagnosticLevel::Error, + code: Some("E_TYPE_MISMATCH".to_string()), + message: format!("Type mismatch: expected {}, found {}", expected, found), + span: Some(span), + }); + } +} diff --git a/crates/prometeu-compiler/src/frontends/pbs/types.rs b/crates/prometeu-compiler/src/frontends/pbs/types.rs new file mode 100644 index 00000000..bab4f3e3 --- /dev/null +++ b/crates/prometeu-compiler/src/frontends/pbs/types.rs @@ -0,0 +1,50 @@ +use std::fmt; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum PbsType { + Int, + Float, + Bool, + String, + Void, + None, + Optional(Box), + Result(Box, Box), + Struct(String), + Service(String), + Contract(String), + ErrorType(String), + Function { + params: Vec, + return_type: Box, + }, +} + +impl fmt::Display for PbsType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + PbsType::Int => write!(f, "int"), + PbsType::Float => write!(f, "float"), + PbsType::Bool => write!(f, "bool"), + PbsType::String => write!(f, "string"), + PbsType::Void => write!(f, "void"), + PbsType::None => write!(f, "none"), + PbsType::Optional(inner) => write!(f, "optional<{}>", inner), + PbsType::Result(ok, err) => write!(f, "result<{}, {}>", ok, err), + PbsType::Struct(name) => write!(f, "{}", name), + PbsType::Service(name) => write!(f, "{}", name), + PbsType::Contract(name) => write!(f, "{}", name), + PbsType::ErrorType(name) => write!(f, "{}", name), + PbsType::Function { params, return_type } => { + write!(f, "fn(")?; + for (i, param) in params.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{}", param)?; + } + write!(f, ") -> {}", return_type) + } + } + } +} diff --git a/crates/prometeu-compiler/tests/pbs_typecheck_tests.rs b/crates/prometeu-compiler/tests/pbs_typecheck_tests.rs new file mode 100644 index 00000000..15a5bef2 --- /dev/null +++ b/crates/prometeu-compiler/tests/pbs_typecheck_tests.rs @@ -0,0 +1,142 @@ +use prometeu_compiler::frontends::pbs::PbsFrontend; +use prometeu_compiler::frontends::Frontend; +use prometeu_compiler::common::files::FileManager; +use std::fs; + +fn check_code(code: &str) -> Result<(), String> { + let mut file_manager = FileManager::new(); + let temp_dir = tempfile::tempdir().unwrap(); + let file_path = temp_dir.path().join("test.pbs"); + fs::write(&file_path, code).unwrap(); + + let frontend = PbsFrontend; + match frontend.compile_to_ir(&file_path, &mut file_manager) { + Ok(_) => Ok(()), + Err(bundle) => { + let mut errors = Vec::new(); + for diag in bundle.diagnostics { + let code = diag.code.unwrap_or_else(|| "NO_CODE".to_string()); + errors.push(format!("{}: {}", code, diag.message)); + } + Err(errors.join(", ")) + } + } +} + +#[test] +fn test_type_mismatch_let() { + let code = "fn main() { let x: int = \"hello\"; }"; + let res = check_code(code); + if let Err(e) = &res { println!("Error: {}", e); } + assert!(res.is_err()); + assert!(res.unwrap_err().contains("E_TYPE_MISMATCH")); +} + +#[test] +fn test_type_mismatch_return() { + let code = "fn main() -> int { return \"hello\"; }"; + let res = check_code(code); + assert!(res.is_err()); + assert!(res.unwrap_err().contains("E_TYPE_MISMATCH")); +} + +#[test] +fn test_type_mismatch_call() { + let code = " + fn foo(a: int) {} + fn main() { + foo(\"hello\"); + } + "; + let res = check_code(code); + assert!(res.is_err()); + assert!(res.unwrap_err().contains("E_TYPE_MISMATCH")); +} + +#[test] +fn test_missing_return_path() { + let code = "fn foo() -> int { if (true) { return 1; } }"; + let res = check_code(code); + assert!(res.is_err()); + assert!(res.unwrap_err().contains("E_TYPE_RETURN_PATH")); +} + +#[test] +fn test_implicit_none_optional() { + let code = "fn foo() -> optional { if (true) { return some(1); } }"; + let res = check_code(code); + if let Err(e) = &res { println!("Error: {}", e); } + assert!(res.is_ok()); // Implicit none allowed for optional +} + +#[test] +fn test_valid_optional_assignment() { + let code = "fn main() { let x: optional = none; let y: optional = some(10); }"; + let res = check_code(code); + if let Err(e) = &res { println!("Error: {}", e); } + assert!(res.is_ok()); +} + +#[test] +fn test_valid_result_usage() { + let code = " + fn foo() -> result { + if (true) { + return ok(10); + } else { + return err(\"error\"); + } + } + "; + let res = check_code(code); + if let Err(e) = &res { println!("Error: {}", e); } + assert!(res.is_ok()); +} + +#[test] +fn test_unknown_type() { + let code = "fn main() { let x: UnknownType = 10; }"; + let res = check_code(code); + assert!(res.is_err()); + assert!(res.unwrap_err().contains("E_TYPE_UNKNOWN_TYPE")); +} + +#[test] +fn test_void_return_ok() { + let code = "fn main() { return; }"; + let res = check_code(code); + assert!(res.is_ok()); +} + +#[test] +fn test_binary_op_mismatch() { + let code = "fn main() { let x = 1 + \"hello\"; }"; + let res = check_code(code); + assert!(res.is_err()); + assert!(res.unwrap_err().contains("E_TYPE_MISMATCH")); +} + +#[test] +fn test_struct_type_usage() { + let code = " + declare struct Point { x: int, y: int } + fn foo(p: Point) {} + fn main() { + // Struct literals not in v0, but we can have variables of struct type + } + "; + let res = check_code(code); + assert!(res.is_ok()); +} + +#[test] +fn test_service_type_usage() { + let code = " + pub service MyService { + fn hello(name: string) -> void + } + fn foo(s: MyService) {} + "; + let res = check_code(code); + assert!(res.is_ok()); +}