This commit is contained in:
bQUARKz 2026-02-04 19:46:03 +00:00
parent 9a62b7b643
commit 00ad4730c8
Signed by: bquarkz
SSH Key Fingerprint: SHA256:Z7dgqoglWwoK6j6u4QC87OveEq74WOhFN+gitsxtkf8

View File

@ -1,4 +1,5 @@
use crate::analysis::symbols::SymbolId;
use crate::frontends::pbs::ast::NodeId;
use prometeu_analysis::interner::NameId;
use serde::{Deserialize, Serialize};
@ -37,6 +38,42 @@ impl TypeArena {
}
}
#[derive(Debug, Default, Serialize, Deserialize)]
pub struct TypeFacts {
pub node_type: Vec<Option<TypeId>>,
pub symbol_type: Vec<Option<TypeId>>,
}
impl TypeFacts {
pub fn new() -> Self {
Self::default()
}
pub fn set_node_type(&mut self, node_id: NodeId, type_id: TypeId) {
let idx = node_id.0 as usize;
if idx >= self.node_type.len() {
self.node_type.resize_with(idx + 1, || None);
}
self.node_type[idx] = Some(type_id);
}
pub fn get_node_type(&self, node_id: NodeId) -> Option<TypeId> {
self.node_type.get(node_id.0 as usize).and_then(|t| *t)
}
pub fn set_symbol_type(&mut self, symbol_id: SymbolId, type_id: TypeId) {
let idx = symbol_id.0 as usize;
if idx >= self.symbol_type.len() {
self.symbol_type.resize_with(idx + 1, || None);
}
self.symbol_type[idx] = Some(type_id);
}
pub fn get_symbol_type(&self, symbol_id: SymbolId) -> Option<TypeId> {
self.symbol_type.get(symbol_id.0 as usize).and_then(|t| *t)
}
}
#[cfg(test)]
mod tests {
use super::*;
@ -71,4 +108,28 @@ mod tests {
assert!(matches!(arena.kind(t2), TypeKind::Optional { .. }));
assert!(matches!(arena.kind(t3), TypeKind::Array { .. }));
}
#[test]
fn type_facts_auto_grows_for_node_ids() {
let mut facts = TypeFacts::new();
let nid = NodeId(10);
let tid = TypeId(1);
assert_eq!(facts.get_node_type(nid), None);
facts.set_node_type(nid, tid);
assert_eq!(facts.get_node_type(nid), Some(tid));
assert!(facts.node_type.len() > 10);
}
#[test]
fn type_facts_auto_grows_for_symbol_ids() {
let mut facts = TypeFacts::new();
let sid = SymbolId(20);
let tid = TypeId(2);
assert_eq!(facts.get_symbol_type(sid), None);
facts.set_symbol_type(sid, tid);
assert_eq!(facts.get_symbol_type(sid), Some(tid));
assert!(facts.symbol_type.len() > 20);
}
}