diff --git a/crates/prometeu-compiler/src/frontends/pbs/lowering.rs b/crates/prometeu-compiler/src/frontends/pbs/lowering.rs index 3d7b1e80..4f32069f 100644 --- a/crates/prometeu-compiler/src/frontends/pbs/lowering.rs +++ b/crates/prometeu-compiler/src/frontends/pbs/lowering.rs @@ -328,6 +328,7 @@ impl<'a> Lowerer<'a> { NodeKind::Binary(n) => self.lower_binary(node, n), NodeKind::Unary(n) => self.lower_unary(node, n), NodeKind::IfExpr(n) => self.lower_if_expr(node, n), + NodeKind::WhenExpr(n) => self.lower_when_expr(node, n), NodeKind::Alloc(n) => self.lower_alloc(node, n), NodeKind::Mutate(n) => self.lower_mutate(node, n), NodeKind::Borrow(n) => self.lower_borrow(node, n), @@ -795,6 +796,9 @@ impl<'a> Lowerer<'a> { // Check for special built-in functions match callee_name.as_str() { "some" | "ok" | "err" => { + for arg in &n.args { + self.lower_node(*arg)?; + } return Ok(()); } _ => {} @@ -974,11 +978,25 @@ impl<'a> Lowerer<'a> { } fn lower_constructor_call(&mut self, ctor: NodeId, args: &[NodeId]) -> Result<(), ()> { + let ctor_id = ctor; let ctor = match self.arena.kind(ctor) { NodeKind::ConstructorDecl(ctor) => ctor, _ => return Err(()), }; + if args.len() != ctor.params.len() { + self.error( + "E_TYPE_MISMATCH", + format!( + "Expected {} arguments, found {}", + ctor.params.len(), + args.len() + ), + self.arena.span(ctor_id), + ); + return Err(()); + } + self.local_vars.push(HashMap::new()); let mut param_slots = Vec::new(); @@ -1093,6 +1111,52 @@ impl<'a> Lowerer<'a> { Ok(()) } + fn lower_when_expr(&mut self, node: NodeId, n: &WhenExprNodeArena) -> Result<(), ()> { + if n.arms.is_empty() { + return Ok(()); + } + + let merge_id = self.reserve_block_id(); + + for (idx, arm_id) in n.arms.iter().enumerate() { + let arm = match self.arena.kind(*arm_id) { + NodeKind::WhenArm(arm) => arm, + _ => { + self.error( + "E_LOWER_UNSUPPORTED", + "Expected when arm".to_string(), + self.arena.span(node), + ); + return Err(()); + } + }; + + let body_id = self.reserve_block_id(); + let next_cond_id = if idx + 1 < n.arms.len() { + self.reserve_block_id() + } else { + merge_id + }; + + self.lower_node(arm.cond)?; + self.terminate(Terminator::JumpIfFalse { + target: next_cond_id, + else_target: body_id, + }); + + self.start_block_with_id(body_id); + self.lower_node(arm.body)?; + self.terminate(Terminator::Jump(merge_id)); + + if idx + 1 < n.arms.len() { + self.start_block_with_id(next_cond_id); + } + } + + self.start_block_with_id(merge_id); + Ok(()) + } + fn lower_type_node(&mut self, node: NodeId) -> Type { match self.arena.kind(node) { NodeKind::TypeName(n) => match self.interner.resolve(n.name) { @@ -1102,7 +1166,25 @@ impl<'a> Lowerer<'a> { "bool" => Type::Bool, "string" => Type::String, "void" => Type::Void, - _ => Type::Struct(self.interner.resolve(n.name).to_string()), + _ => { + if let Some(sym) = self + .module_symbols + .type_symbols + .get(n.name) + .or_else(|| self.imported_symbols.type_symbols.get(n.name)) + { + let name = self.interner.resolve(n.name).to_string(); + match sym.kind { + SymbolKind::Struct => Type::Struct(name), + SymbolKind::Service => Type::Service(name), + SymbolKind::Contract => Type::Contract(name), + SymbolKind::ErrorType => Type::ErrorType(name), + _ => Type::Struct(name), + } + } else { + Type::Struct(self.interner.resolve(n.name).to_string()) + } + } }, NodeKind::TypeApp(ta) => { let base_name = self.interner.resolve(ta.base); @@ -1343,6 +1425,28 @@ mod tests { assert!(instrs.iter().any(|i| matches!(i.kind, ir_core::InstrKind::Or))); } + #[test] + fn test_unary_ops_lowering() { + let code = " + fn main() { + let a = -1; + let b = !true; + } + "; + let program = lower_program(code); + + let module = &program.modules[0]; + let main_func = module.functions.iter().find(|f| f.name == "main").unwrap(); + let instrs: Vec<_> = main_func + .blocks + .iter() + .flat_map(|b| b.instrs.iter()) + .collect(); + + assert!(instrs.iter().any(|i| matches!(i.kind, ir_core::InstrKind::Neg))); + assert!(instrs.iter().any(|i| matches!(i.kind, ir_core::InstrKind::Not))); + } + #[test] fn test_control_flow_lowering() { let code = " @@ -1373,6 +1477,211 @@ mod tests { assert!(max_func.blocks.len() >= 3); } + #[test] + fn test_if_expr_lowering() { + let code = " + fn main(a: int, b: int) { + if (a > b) { + let x = a; + } else { + let y = b; + } + } + "; + let program = lower_program(code); + + let module = &program.modules[0]; + let main_func = module.functions.iter().find(|f| f.name == "main").unwrap(); + let terminators: Vec<_> = main_func.blocks.iter().map(|b| &b.terminator).collect(); + + assert!(terminators.iter().any(|t| matches!(t, ir_core::Terminator::JumpIfFalse { .. }))); + assert!(terminators.iter().any(|t| matches!(t, ir_core::Terminator::Jump(_)))); + assert!(main_func.blocks.len() >= 3); + } + + #[test] + fn test_when_expr_lowering() { + let code = " + fn main(x: int) { + when { + x == 0 -> { return; }, + x == 1 -> { return; } + }; + } + "; + let program = lower_program(code); + + let module = &program.modules[0]; + let main_func = module.functions.iter().find(|f| f.name == "main").unwrap(); + let terminators: Vec<_> = main_func.blocks.iter().map(|b| &b.terminator).collect(); + + assert!(terminators.iter().any(|t| matches!(t, ir_core::Terminator::JumpIfFalse { .. }))); + assert!(main_func.blocks.len() >= 5); + } + + #[test] + fn test_lower_type_node() { + let code = " + service MyService { + fn ping(): void; + } + declare contract MyContract host {} + declare error MyError {} + declare struct Point { x: int } + + fn main( + s: MyService, + c: MyContract, + e: MyError, + p: Point, + o: optional, + r: result, + a: int[3] + ) {} + "; + let program = lower_program(code); + + let module = &program.modules[0]; + let main_func = module.functions.iter().find(|f| f.name == "main").unwrap(); + let params: Vec<_> = main_func.params.iter().map(|p| p.ty.clone()).collect(); + + assert_eq!(params[0], ir_core::Type::Service("MyService".to_string())); + assert_eq!(params[1], ir_core::Type::Contract("MyContract".to_string())); + assert_eq!(params[2], ir_core::Type::ErrorType("MyError".to_string())); + assert_eq!(params[3], ir_core::Type::Struct("Point".to_string())); + assert_eq!( + params[4], + ir_core::Type::Optional(Box::new(ir_core::Type::Int)) + ); + assert_eq!( + params[5], + ir_core::Type::Result( + Box::new(ir_core::Type::Int), + Box::new(ir_core::Type::String) + ) + ); + assert_eq!( + params[6], + ir_core::Type::Array(Box::new(ir_core::Type::Int), 3) + ); + } + + #[test] + fn test_call_lowering() { + let code = " + fn add(a: int, b: int): int { + return a + b; + } + fn main() { + let x = add(1, 2); + } + "; + let program = lower_program(code); + + let module = &program.modules[0]; + let main_func = module.functions.iter().find(|f| f.name == "main").unwrap(); + let instrs: Vec<_> = main_func + .blocks + .iter() + .flat_map(|b| b.instrs.iter()) + .collect(); + + assert!(instrs + .iter() + .any(|i| matches!(i.kind, ir_core::InstrKind::Call(_, 2)))); + } + + #[test] + fn test_host_call_lowering() { + let code = " + declare contract Gfx host {} + fn main() { + Gfx.clear(Color.WHITE); + } + "; + let program = lower_program(code); + + let module = &program.modules[0]; + let main_func = module.functions.iter().find(|f| f.name == "main").unwrap(); + let instrs: Vec<_> = main_func + .blocks + .iter() + .flat_map(|b| b.instrs.iter()) + .collect(); + + assert!(instrs + .iter() + .any(|i| matches!(i.kind, ir_core::InstrKind::HostCall(_, _)))); + assert!(instrs + .iter() + .any(|i| matches!(i.kind, ir_core::InstrKind::PushBounded(_)))); + } + + #[test] + fn test_member_access_lowering() { + let code = " + declare contract Input host {} + fn main() { + let p: Pad = Input.pad(); + let b: ButtonState = p.a; + let d: bool = p.a.down; + let c: Color = Color.WHITE; + } + "; + let program = lower_program(code); + + let module = &program.modules[0]; + let main_func = module.functions.iter().find(|f| f.name == "main").unwrap(); + let instrs: Vec<_> = main_func + .blocks + .iter() + .flat_map(|b| b.instrs.iter()) + .collect(); + + assert!(instrs + .iter() + .any(|i| matches!(i.kind, ir_core::InstrKind::GetLocal(16)))); + assert!(instrs + .iter() + .any(|i| matches!(i.kind, ir_core::InstrKind::GetLocal(18)))); + assert!(instrs + .iter() + .any(|i| matches!(i.kind, ir_core::InstrKind::PushBounded(_)))); + } + + #[test] + fn test_constructor_call_lowering() { + let code = " + declare struct Vec2(x: int, y: int) + [ + (x: int, y: int): (x, y) as default { } + (s: int): (s, s) as square { } + ] + fn main() { + let a = Vec2(1, 2); + let b = Vec2.square(3); + } + "; + let program = lower_program(code); + + let module = &program.modules[0]; + let main_func = module.functions.iter().find(|f| f.name == "main").unwrap(); + let instrs: Vec<_> = main_func + .blocks + .iter() + .flat_map(|b| b.instrs.iter()) + .collect(); + + let push_consts = instrs + .iter() + .filter(|i| matches!(i.kind, ir_core::InstrKind::PushConst(_))) + .count(); + + assert_eq!(push_consts, 3); + assert!(!instrs.iter().any(|i| matches!(i.kind, ir_core::InstrKind::Call(_, _)))); + assert!(!instrs.iter().any(|i| matches!(i.kind, ir_core::InstrKind::HostCall(_, _)))); + } + #[test] fn test_hip_lowering() { let code = "