diff --git a/crates/prometeu-compiler/src/analysis/types.rs b/crates/prometeu-compiler/src/analysis/types.rs index 73a1eca2..1b7ef4f4 100644 --- a/crates/prometeu-compiler/src/analysis/types.rs +++ b/crates/prometeu-compiler/src/analysis/types.rs @@ -1,4 +1,5 @@ use crate::analysis::symbols::SymbolId; +use crate::frontends::pbs::ast::NodeId; use prometeu_analysis::interner::NameId; use serde::{Deserialize, Serialize}; @@ -37,6 +38,42 @@ impl TypeArena { } } +#[derive(Debug, Default, Serialize, Deserialize)] +pub struct TypeFacts { + pub node_type: Vec>, + pub symbol_type: Vec>, +} + +impl TypeFacts { + pub fn new() -> Self { + Self::default() + } + + pub fn set_node_type(&mut self, node_id: NodeId, type_id: TypeId) { + let idx = node_id.0 as usize; + if idx >= self.node_type.len() { + self.node_type.resize_with(idx + 1, || None); + } + self.node_type[idx] = Some(type_id); + } + + pub fn get_node_type(&self, node_id: NodeId) -> Option { + self.node_type.get(node_id.0 as usize).and_then(|t| *t) + } + + pub fn set_symbol_type(&mut self, symbol_id: SymbolId, type_id: TypeId) { + let idx = symbol_id.0 as usize; + if idx >= self.symbol_type.len() { + self.symbol_type.resize_with(idx + 1, || None); + } + self.symbol_type[idx] = Some(type_id); + } + + pub fn get_symbol_type(&self, symbol_id: SymbolId) -> Option { + self.symbol_type.get(symbol_id.0 as usize).and_then(|t| *t) + } +} + #[cfg(test)] mod tests { use super::*; @@ -71,4 +108,28 @@ mod tests { assert!(matches!(arena.kind(t2), TypeKind::Optional { .. })); assert!(matches!(arena.kind(t3), TypeKind::Array { .. })); } + + #[test] + fn type_facts_auto_grows_for_node_ids() { + let mut facts = TypeFacts::new(); + let nid = NodeId(10); + let tid = TypeId(1); + + assert_eq!(facts.get_node_type(nid), None); + facts.set_node_type(nid, tid); + assert_eq!(facts.get_node_type(nid), Some(tid)); + assert!(facts.node_type.len() > 10); + } + + #[test] + fn type_facts_auto_grows_for_symbol_ids() { + let mut facts = TypeFacts::new(); + let sid = SymbolId(20); + let tid = TypeId(2); + + assert_eq!(facts.get_symbol_type(sid), None); + facts.set_symbol_type(sid, tid); + assert_eq!(facts.get_symbol_type(sid), Some(tid)); + assert!(facts.symbol_type.len() > 20); + } }