2026-03-24 13:40:32 +00:00

367 lines
13 KiB
Rust

use crate::ir_core::ids::SigId;
use crate::ir_core::types::Type;
use std::collections::HashMap;
use std::sync::{Mutex, OnceLock};
/// Canonical function signature: params + return type.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Signature {
    /// Parameter types, in declaration order.
    pub params: Vec<Type>,
    /// Return type of the function.
    pub return_type: Type,
}
impl Signature {
    /// Stable, deterministic descriptor string for this signature.
    ///
    /// Format: `fn(` + `;`-separated parameter encodings + `)->` + return type
    /// encoding, using the canonical `Type` encoding. Examples:
    /// fn(int;string)->void
    /// fn(array{int;4};optional{string})->result{void;error{E}}
    pub fn descriptor(&self) -> String {
        let mut s = String::from("fn(");
        for (i, p) in self.params.iter().enumerate() {
            if i > 0 {
                s.push(';');
            }
            encode_type(p, &mut s);
        }
        s.push_str(")->");
        encode_type(&self.return_type, &mut s);
        s
    }

    /// Parse a descriptor previously produced by [`Signature::descriptor`].
    ///
    /// Parameters are decoded one after another with the recursive type decoder, so
    /// the `)` ending the parameter list is the one that follows the last parameter.
    /// This makes nested function types such as `fn(fn(int)->void;int)->void` work.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message if `desc` is not a well-formed descriptor or
    /// has characters left over after the return type.
    pub fn from_descriptor(desc: &str) -> Result<Self, String> {
        let mut rest = desc
            .strip_prefix("fn(")
            .ok_or_else(|| "Invalid descriptor: missing fn(".to_string())?;

        // Each `decode_type` call consumes exactly one (possibly nested) type, so we
        // never have to guess where the parameter list ends by searching for ')'.
        let mut params = Vec::new();
        if let Some(after) = rest.strip_prefix(')') {
            rest = after;
        } else {
            loop {
                let (ty, consumed) = decode_type(rest)?;
                params.push(ty);
                // `get` instead of indexing: a bad offset becomes an error, not a panic.
                rest = rest
                    .get(consumed..)
                    .ok_or_else(|| "Invalid descriptor: malformed parameter".to_string())?;
                if let Some(after) = rest.strip_prefix(';') {
                    rest = after;
                } else if let Some(after) = rest.strip_prefix(')') {
                    rest = after;
                    break;
                } else if rest.is_empty() {
                    return Err("Invalid descriptor: missing ')'".to_string());
                } else {
                    return Err("Trailing garbage in parameter".to_string());
                }
            }
        }

        let ret_src = rest
            .strip_prefix("->")
            .ok_or_else(|| "Invalid descriptor: missing '->'".to_string())?;
        let (return_type, consumed) = decode_type(ret_src)?;
        if consumed != ret_src.len() {
            return Err("Trailing garbage after return type".to_string());
        }
        Ok(Signature { params, return_type })
    }
}
/// Global signature interner. Thread-safe and process-wide for this compiler instance.
pub struct SignatureInterner {
map: HashMap<Signature, SigId>,
rev: Vec<Signature>,
}
impl SignatureInterner {
pub fn new() -> Self {
Self { map: HashMap::new(), rev: Vec::new() }
}
pub fn intern(&mut self, sig: Signature) -> SigId {
if let Some(id) = self.map.get(&sig) {
return *id;
}
let id = SigId(self.rev.len() as u32);
self.rev.push(sig.clone());
self.map.insert(sig, id);
id
}
pub fn resolve(&self, id: SigId) -> Option<&Signature> {
self.rev.get(id.0 as usize)
}
}
/// Backing storage for the process-wide interner; filled on first access.
static GLOBAL_INTERNER: OnceLock<Mutex<SignatureInterner>> = OnceLock::new();
/// Access the process-wide [`SignatureInterner`].
///
/// The interner is created empty on the first call; every call returns the same
/// `Mutex`, which callers lock to intern or resolve signatures.
pub fn global_signature_interner() -> &'static Mutex<SignatureInterner> {
    GLOBAL_INTERNER.get_or_init(|| {
        let fresh = SignatureInterner::new();
        Mutex::from(fresh)
    })
}
// ==============
// Encoding/Decoding helpers for `Type`
// Canonical grammar (EBNF-ish):
// Type := "void" | "int" | "bounded" | "float" | "bool" | "string"
// | "struct{" Name "}" | "service{" Name "}" | "contract{" Name "}"
// | "error{" Name "}" | "array{" Type ";" UInt "}"
// | "optional{" Type "}" | "result{" Type ";" Type "}"
// | "fn(" [Type { ";" Type }] ")" "->" Type
// Name := UTF-8 text in which '\', '}' and ';' are each escaped with a leading '\'
// UInt := one or more ASCII decimal digits (parsed as u32)
fn encode_type(ty: &Type, out: &mut String) {
match ty {
Type::Void => out.push_str("void"),
Type::Int => out.push_str("int"),
Type::Bounded => out.push_str("bounded"),
Type::Float => out.push_str("float"),
Type::Bool => out.push_str("bool"),
Type::String => out.push_str("string"),
Type::Optional(inner) => {
out.push_str("optional{");
encode_type(inner, out);
out.push('}');
}
Type::Result(ok, err) => {
out.push_str("result{");
encode_type(ok, out);
out.push(';');
encode_type(err, out);
out.push('}');
}
Type::Struct(name) => {
out.push_str("struct{");
encode_name(name, out);
out.push('}');
}
Type::Service(name) => {
out.push_str("service{");
encode_name(name, out);
out.push('}');
}
Type::Contract(name) => {
out.push_str("contract{");
encode_name(name, out);
out.push('}');
}
Type::ErrorType(name) => {
out.push_str("error{");
encode_name(name, out);
out.push('}');
}
Type::Array(inner, n) => {
out.push_str("array{");
encode_type(inner, out);
out.push(';');
out.push_str(&n.to_string());
out.push('}');
}
Type::Function { params, return_type } => {
out.push_str("fn(");
for (i, p) in params.iter().enumerate() {
if i > 0 { out.push(';'); }
encode_type(p, out);
}
out.push(')');
out.push_str("->");
encode_type(return_type, out);
}
}
}
/// Append `name` to `out`, prefixing every `\`, `}` and `;` with a backslash so the
/// name can sit inside `{...}` without closing or splitting the enclosing type.
fn encode_name(name: &str, out: &mut String) {
    for ch in name.chars() {
        if matches!(ch, '\\' | '}' | ';') {
            out.push('\\');
        }
        out.push(ch);
    }
}
/// Decode an escaped name (as written by `encode_name`) from the start of `s`.
///
/// Decoding stops at the first unescaped `}`, which is left unconsumed; a `\` makes
/// the following character literal. If no `}` is found, the whole input is decoded.
///
/// Returns the decoded name and the number of **bytes** of `s` consumed. Callers
/// slice `s` with this value, so it must be a byte offset (not a char count) for
/// non-ASCII names to work.
///
/// # Errors
///
/// Returns an error if `s` ends with a dangling `\`.
fn decode_name(s: &str) -> Result<(String, usize), String> {
    let mut out = String::new();
    let mut chars = s.char_indices();
    while let Some((idx, c)) = chars.next() {
        match c {
            // `idx` is the byte position of the terminator = bytes consumed so far.
            '}' => return Ok((out, idx)),
            '\\' => {
                let (_, escaped) = chars
                    .next()
                    .ok_or_else(|| "Invalid escape in name".to_string())?;
                out.push(escaped);
            }
            _ => out.push(c),
        }
    }
    Ok((out, s.len()))
}
/// Split `input` on every `sep` that is not nested inside `{...}` or `(...)`.
///
/// A `\` escapes the following character (as written by `encode_name`), so escaped
/// braces or separators inside names neither change the nesting depth nor split.
/// Note: `encode_name` does not escape `(`, `)` or `{`, so those characters inside a
/// name still count toward depth here.
///
/// # Errors
///
/// Returns `"Unbalanced delimiters"` if a `}`/`)` appears without a matching opener,
/// or if any `{`/`(` is left unclosed.
fn split_top_level(input: &str, sep: char) -> Result<Vec<&str>, String> {
    let unbalanced = || "Unbalanced delimiters".to_string();
    let mut parts = Vec::new();
    let mut depth_brace = 0i32;
    let mut depth_paren = 0i32; // counts '(' nesting for nested fn types
    let mut start = 0usize;
    let mut escaped = false;
    for (i, ch) in input.char_indices() {
        if escaped {
            // Character after '\' is literal name content.
            escaped = false;
            continue;
        }
        match ch {
            '\\' => {
                escaped = true;
                continue;
            }
            '{' => depth_brace += 1,
            '}' => depth_brace -= 1,
            '(' => depth_paren += 1,
            ')' => depth_paren -= 1,
            _ => {}
        }
        // A closer with no opener can never rebalance into valid input.
        if depth_brace < 0 || depth_paren < 0 {
            return Err(unbalanced());
        }
        if ch == sep && depth_brace == 0 && depth_paren == 0 {
            parts.push(&input[start..i]);
            start = i + ch.len_utf8();
        }
    }
    if depth_brace != 0 || depth_paren != 0 {
        return Err(unbalanced());
    }
    parts.push(&input[start..]);
    Ok(parts)
}
/// Decode one type from the start of `s`.
///
/// Returns the decoded type and the number of bytes of `s` it occupies; whatever
/// follows (e.g. a `;`, `}` or `)` belonging to an enclosing type) is left for the
/// caller. Function parameters are decoded one after another, so nested function
/// types like `fn(fn(int)->void;int)->void` are handled.
///
/// # Errors
///
/// Returns a description of the first malformed construct, or
/// `"Unknown type in descriptor"` if `s` does not start with a known type keyword.
fn decode_type(s: &str) -> Result<(Type, usize), String> {
    /// Advance past `used` bytes of `rest`; a bad offset becomes an error, not a panic.
    fn skip(rest: &str, used: usize) -> Result<&str, String> {
        rest.get(used..)
            .ok_or_else(|| "Invalid offset in descriptor".to_string())
    }

    /// Require `rest` to start with `token`, returning the input after it.
    fn expect<'a>(rest: &'a str, token: char, msg: &str) -> Result<&'a str, String> {
        rest.strip_prefix(token).ok_or_else(|| msg.to_string())
    }

    /// Decode `Name}` (the part after `tag{`); returns the name and the input after `}`.
    fn named<'a>(body: &'a str, what: &str) -> Result<(String, &'a str), String> {
        let (name, used) = decode_name(body)?;
        let rest = skip(body, used)?
            .strip_prefix('}')
            .ok_or_else(|| format!("Missing '}}' for {what}"))?;
        Ok((name, rest))
    }

    // Every `rest` below is a suffix of `s`, so bytes consumed = s.len() - rest.len().
    let consumed = |rest: &str| s.len() - rest.len();

    if let Some(body) = s.strip_prefix("optional{") {
        let (inner, used) = decode_type(body)?;
        let rest = expect(skip(body, used)?, '}', "Missing '}' for optional")?;
        return Ok((Type::Optional(Box::new(inner)), consumed(rest)));
    }
    if let Some(body) = s.strip_prefix("result{") {
        let (ok, used_ok) = decode_type(body)?;
        let rest = expect(skip(body, used_ok)?, ';', "Missing ';' in result")?;
        let (err, used_err) = decode_type(rest)?;
        let rest = expect(skip(rest, used_err)?, '}', "Missing '}' for result")?;
        return Ok((Type::Result(Box::new(ok), Box::new(err)), consumed(rest)));
    }
    if let Some(body) = s.strip_prefix("struct{") {
        let (name, rest) = named(body, "struct")?;
        return Ok((Type::Struct(name), consumed(rest)));
    }
    if let Some(body) = s.strip_prefix("service{") {
        let (name, rest) = named(body, "service")?;
        return Ok((Type::Service(name), consumed(rest)));
    }
    if let Some(body) = s.strip_prefix("contract{") {
        let (name, rest) = named(body, "contract")?;
        return Ok((Type::Contract(name), consumed(rest)));
    }
    if let Some(body) = s.strip_prefix("error{") {
        let (name, rest) = named(body, "error")?;
        return Ok((Type::ErrorType(name), consumed(rest)));
    }
    if let Some(body) = s.strip_prefix("array{") {
        let (elem, used) = decode_type(body)?;
        let digits = expect(skip(body, used)?, ';', "Missing ';' in array")?;
        // Digits are ASCII, so this byte count is also a valid slice boundary.
        let n_len = digits.bytes().take_while(u8::is_ascii_digit).count();
        if n_len == 0 {
            return Err("Missing array size".to_string());
        }
        let n: u32 = digits[..n_len]
            .parse()
            .map_err(|_| "Invalid array size".to_string())?;
        let rest = expect(&digits[n_len..], '}', "Missing '}' for array")?;
        return Ok((Type::Array(Box::new(elem), n), consumed(rest)));
    }
    if let Some(body) = s.strip_prefix("fn(") {
        // Parse params one by one; the list ends at the ')' following the last
        // param, not at the first ')' in the string (which may belong to a nested fn).
        let mut rest = body;
        let mut params = Vec::new();
        if let Some(after) = rest.strip_prefix(')') {
            rest = after;
        } else {
            loop {
                let (param, used) = decode_type(rest)?;
                params.push(param);
                rest = skip(rest, used)?;
                if let Some(after) = rest.strip_prefix(';') {
                    rest = after;
                } else if let Some(after) = rest.strip_prefix(')') {
                    rest = after;
                    break;
                } else if rest.is_empty() {
                    return Err("Missing ')' in fn type".to_string());
                } else {
                    return Err("Trailing data in fn param".to_string());
                }
            }
        }
        let ret_src = rest
            .strip_prefix("->")
            .ok_or_else(|| "Missing '->' in fn type".to_string())?;
        let (ret, used_ret) = decode_type(ret_src)?;
        let rest = skip(ret_src, used_ret)?;
        return Ok((
            Type::Function { params, return_type: Box::new(ret) },
            consumed(rest),
        ));
    }

    // Scalar keywords: the type is exactly the keyword.
    let scalars = [
        ("void", Type::Void),
        ("int", Type::Int),
        ("bounded", Type::Bounded),
        ("float", Type::Float),
        ("bool", Type::Bool),
        ("string", Type::String),
    ];
    for (kw, ty) in scalars {
        if s.starts_with(kw) {
            return Ok((ty, kw.len()));
        }
    }
    Err("Unknown type in descriptor".to_string())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a signature from its parts to keep the test bodies short.
    fn sig(params: Vec<Type>, return_type: Type) -> Signature {
        Signature { params, return_type }
    }

    #[test]
    fn descriptors_are_different_for_overloads() {
        let takes_int = sig(vec![Type::Int], Type::Void);
        let takes_string = sig(vec![Type::String], Type::Void);
        assert_ne!(takes_int.descriptor(), takes_string.descriptor());
    }

    #[test]
    fn descriptor_round_trip_stable() {
        let original = sig(
            vec![
                Type::Array(Box::new(Type::Int), 4),
                Type::Optional(Box::new(Type::String)),
                Type::Result(Box::new(Type::Int), Box::new(Type::ErrorType("E42".into()))),
            ],
            Type::Void,
        );
        let encoded = original.descriptor();
        let reparsed = Signature::from_descriptor(&encoded).expect("parse ok");
        // Re-encoding must reproduce the exact string, and the value must match.
        assert_eq!(encoded, reparsed.descriptor());
        assert_eq!(original, reparsed);
    }
}