diff --git a/crates/prometeu-compiler/src/analysis/symbols/mod.rs b/crates/prometeu-compiler/src/analysis/symbols/mod.rs index d77083ce..21d646c6 100644 --- a/crates/prometeu-compiler/src/analysis/symbols/mod.rs +++ b/crates/prometeu-compiler/src/analysis/symbols/mod.rs @@ -1,5 +1,6 @@ use crate::common::diagnostics::{Diagnostic, DiagnosticLevel}; use crate::common::spans::Span; +use crate::frontends::pbs::ast::{AstArena, NodeId}; use prometeu_analysis::NameId; use serde::{Deserialize, Serialize}; use std::collections::HashMap; @@ -52,11 +53,16 @@ pub struct DefIndex { symbols: HashMap, } -#[derive(Debug, Default, Clone)] +#[derive(Debug, Default, Clone, Serialize, Deserialize)] pub struct RefIndex { refs: Vec>, } +#[derive(Debug, Default, Clone, Serialize, Deserialize)] +pub struct NodeToSymbol { + map: Vec>, +} + impl SymbolArena { pub fn new() -> Self { Self { symbols: Vec::new() } @@ -125,6 +131,32 @@ impl RefIndex { } } +impl NodeToSymbol { + pub fn new() -> Self { + Self { map: Vec::new() } + } + + pub fn bind_node(&mut self, node_id: NodeId, symbol_id: SymbolId) { + self.ensure(node_id); + self.map[node_id.0 as usize] = Some(symbol_id); + } + + pub fn get(&self, node_id: NodeId) -> Option { + self.map.get(node_id.0 as usize).and_then(|opt| *opt) + } + + pub fn ensure(&mut self, node_id: NodeId) { + let index = node_id.0 as usize; + if index >= self.map.len() { + self.map.resize(index + 1, None); + } + } + + pub fn resize_to_fit(&mut self, arena: &AstArena) { + self.map.resize(arena.nodes.len(), None); + } +} + #[cfg(test)] mod tests { use super::*; @@ -212,4 +244,26 @@ mod tests { assert_eq!(index.refs_of(SymbolId(5)), &[span_b1]); assert!(index.refs_of(SymbolId(9)).is_empty()); } + + #[test] + fn node_to_symbol_bind_and_get() { + let mut map = NodeToSymbol::new(); + let nid = NodeId(10); + let sid = SymbolId(5); + + map.bind_node(nid, sid); + assert_eq!(map.get(nid), Some(sid)); + assert_eq!(map.get(NodeId(0)), None); + } + + #[test] + fn node_to_symbol_expands_automatically() { + let mut map = NodeToSymbol::new(); + let nid_high = NodeId(100); + let sid = SymbolId(1); + + map.bind_node(nid_high, sid); + assert_eq!(map.get(nid_high), Some(sid)); + assert!(map.map.len() > 100); + } } \ No newline at end of file