363 lines
13 KiB
Rust
363 lines
13 KiB
Rust
use super::ids::ValueId;
|
|
use super::instr::InstrKind;
|
|
use super::program::Program;
|
|
use super::terminator::Terminator;
|
|
use std::collections::{HashMap, VecDeque};
|
|
|
|
/// Kind of HIP access opened by a `Begin*` instruction and closed by the matching `End*`.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum HipOpKind {
    /// Opened by `BeginPeek`, closed by `EndPeek`.
    Peek,
    /// Opened by `BeginBorrow`, closed by `EndBorrow`.
    Borrow,
    /// Opened by `BeginMutate`, closed by `EndMutate`; the only kind under which the
    /// validator accepts gate stores.
    Mutate,
}
|
|
|
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
|
pub struct HipOp {
|
|
pub kind: HipOpKind,
|
|
pub gate: ValueId,
|
|
}
|
|
|
|
pub fn validate_program(program: &Program) -> Result<(), String> {
|
|
for module in &program.modules {
|
|
for func in &module.functions {
|
|
validate_function(func)?;
|
|
}
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
fn validate_function(func: &super::function::Function) -> Result<(), String> {
|
|
let mut block_entry_stacks: HashMap<u32, Vec<HipOp>> = HashMap::new();
|
|
let mut worklist: VecDeque<u32> = VecDeque::new();
|
|
|
|
if func.blocks.is_empty() {
|
|
return Ok(());
|
|
}
|
|
|
|
// Assume the first block is the entry block (usually ID 0)
|
|
let entry_block_id = func.blocks[0].id;
|
|
block_entry_stacks.insert(entry_block_id, Vec::new());
|
|
worklist.push_back(entry_block_id);
|
|
|
|
let blocks_by_id: HashMap<u32, &super::block::Block> = func.blocks.iter().map(|b| (b.id, b)).collect();
|
|
let mut visited_with_stack: HashMap<u32, Vec<HipOp>> = HashMap::new();
|
|
|
|
while let Some(block_id) = worklist.pop_front() {
|
|
let block = blocks_by_id.get(&block_id).ok_or_else(|| format!("Invalid block ID: {}", block_id))?;
|
|
let mut current_stack = block_entry_stacks.get(&block_id).unwrap().clone();
|
|
|
|
// If we've already visited this block with the same stack, skip it to avoid infinite loops
|
|
if let Some(prev_stack) = visited_with_stack.get(&block_id) {
|
|
if prev_stack == ¤t_stack {
|
|
continue;
|
|
} else {
|
|
return Err(format!("Block {} reached with inconsistent HIP stacks: {:?} vs {:?}", block_id, prev_stack, current_stack));
|
|
}
|
|
}
|
|
visited_with_stack.insert(block_id, current_stack.clone());
|
|
|
|
for instr in &block.instrs {
|
|
match &instr.kind {
|
|
InstrKind::BeginPeek { gate } => {
|
|
current_stack.push(HipOp { kind: HipOpKind::Peek, gate: *gate });
|
|
}
|
|
InstrKind::BeginBorrow { gate } => {
|
|
current_stack.push(HipOp { kind: HipOpKind::Borrow, gate: *gate });
|
|
}
|
|
InstrKind::BeginMutate { gate } => {
|
|
current_stack.push(HipOp { kind: HipOpKind::Mutate, gate: *gate });
|
|
}
|
|
InstrKind::EndPeek => {
|
|
match current_stack.pop() {
|
|
Some(op) if op.kind == HipOpKind::Peek => {},
|
|
Some(op) => return Err(format!("EndPeek doesn't match current HIP op: {:?}", op)),
|
|
None => return Err("EndPeek without matching BeginPeek".to_string()),
|
|
}
|
|
}
|
|
InstrKind::EndBorrow => {
|
|
match current_stack.pop() {
|
|
Some(op) if op.kind == HipOpKind::Borrow => {},
|
|
Some(op) => return Err(format!("EndBorrow doesn't match current HIP op: {:?}", op)),
|
|
None => return Err("EndBorrow without matching BeginBorrow".to_string()),
|
|
}
|
|
}
|
|
InstrKind::EndMutate => {
|
|
match current_stack.pop() {
|
|
Some(op) if op.kind == HipOpKind::Mutate => {},
|
|
Some(op) => return Err(format!("EndMutate doesn't match current HIP op: {:?}", op)),
|
|
None => return Err("EndMutate without matching BeginMutate".to_string()),
|
|
}
|
|
}
|
|
InstrKind::GateLoadField { .. } | InstrKind::GateLoadIndex { .. } => {
|
|
if current_stack.is_empty() {
|
|
return Err("GateLoad outside of HIP operation".to_string());
|
|
}
|
|
}
|
|
InstrKind::GateStoreField { .. } | InstrKind::GateStoreIndex { .. } => {
|
|
match current_stack.last() {
|
|
Some(op) if op.kind == HipOpKind::Mutate => {},
|
|
_ => return Err("GateStore outside of BeginMutate".to_string()),
|
|
}
|
|
}
|
|
InstrKind::Call(id, _) => {
|
|
if id.0 == 0 {
|
|
return Err("Call to FunctionId(0)".to_string());
|
|
}
|
|
}
|
|
InstrKind::Alloc { ty, .. } => {
|
|
if ty.0 == 0 {
|
|
return Err("Alloc with TypeId(0)".to_string());
|
|
}
|
|
}
|
|
_ => {}
|
|
}
|
|
}
|
|
|
|
match &block.terminator {
|
|
Terminator::Return => {
|
|
if !current_stack.is_empty() {
|
|
return Err(format!("Function returns with non-empty HIP stack: {:?}", current_stack));
|
|
}
|
|
}
|
|
Terminator::Jump(target) => {
|
|
propagate_stack(&mut block_entry_stacks, &mut worklist, *target, ¤t_stack)?;
|
|
}
|
|
Terminator::JumpIfFalse { target, else_target } => {
|
|
propagate_stack(&mut block_entry_stacks, &mut worklist, *target, ¤t_stack)?;
|
|
propagate_stack(&mut block_entry_stacks, &mut worklist, *else_target, ¤t_stack)?;
|
|
}
|
|
}
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
fn propagate_stack(
|
|
entry_stacks: &mut HashMap<u32, Vec<HipOp>>,
|
|
worklist: &mut VecDeque<u32>,
|
|
target: u32,
|
|
stack: &Vec<HipOp>
|
|
) -> Result<(), String> {
|
|
if let Some(existing) = entry_stacks.get(&target) {
|
|
if existing != stack {
|
|
return Err(format!("Control flow merge at block {} with inconsistent HIP stacks: {:?} vs {:?}", target, existing, stack));
|
|
}
|
|
} else {
|
|
entry_stacks.insert(target, stack.clone());
|
|
worklist.push_back(target);
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir_core::*;

    /// Builds a parameterless `void` function named "test" (FunctionId 1) around `blocks`.
    fn create_dummy_function(blocks: Vec<Block>) -> Function {
        Function {
            id: FunctionId(1),
            name: "test".to_string(),
            sig: {
                let mut i = global_signature_interner().lock().unwrap();
                i.intern(Signature { params: vec![], return_type: Type::Void })
            },
            param_slots: 0,
            local_slots: 0,
            return_slots: 0,
            params: vec![],
            return_type: Type::Void,
            blocks,
            local_types: HashMap::new(),
        }
    }

    /// Wraps `func` in a single-module program with an empty constant pool and field tables.
    fn create_dummy_program(func: Function) -> Program {
        Program {
            const_pool: ConstPool::new(),
            modules: vec![Module {
                name: "test".to_string(),
                functions: vec![func],
            }],
            field_offsets: HashMap::new(),
            field_types: HashMap::new(),
        }
    }

    #[test]
    fn test_valid_hip_nesting() {
        let block = Block {
            id: 0,
            instrs: vec![
                Instr::from(InstrKind::BeginPeek { gate: ValueId(0) }),
                Instr::from(InstrKind::GateLoadField { gate: ValueId(0), field: FieldId(0) }),
                Instr::from(InstrKind::BeginMutate { gate: ValueId(1) }),
                Instr::from(InstrKind::GateStoreField {
                    gate: ValueId(1),
                    field: FieldId(0),
                    value: ValueId(2),
                }),
                Instr::from(InstrKind::EndMutate),
                Instr::from(InstrKind::EndPeek),
            ],
            terminator: Terminator::Return,
        };
        let prog = create_dummy_program(create_dummy_function(vec![block]));
        assert!(validate_program(&prog).is_ok());
    }

    #[test]
    fn test_invalid_hip_unbalanced() {
        let block = Block {
            id: 0,
            instrs: vec![Instr::from(InstrKind::BeginPeek { gate: ValueId(0) })],
            terminator: Terminator::Return,
        };
        let prog = create_dummy_program(create_dummy_function(vec![block]));
        let res = validate_program(&prog);
        assert!(res.is_err());
        assert!(res.unwrap_err().contains("non-empty HIP stack"));
    }

    #[test]
    fn test_invalid_hip_wrong_end() {
        let block = Block {
            id: 0,
            instrs: vec![
                Instr::from(InstrKind::BeginPeek { gate: ValueId(0) }),
                Instr::from(InstrKind::EndMutate),
            ],
            terminator: Terminator::Return,
        };
        let prog = create_dummy_program(create_dummy_function(vec![block]));
        let res = validate_program(&prog);
        assert!(res.is_err());
        assert!(res.unwrap_err().contains("EndMutate doesn't match"));
    }

    #[test]
    fn test_invalid_end_without_begin() {
        let block = Block {
            id: 0,
            instrs: vec![Instr::from(InstrKind::EndPeek)],
            terminator: Terminator::Return,
        };
        let prog = create_dummy_program(create_dummy_function(vec![block]));
        let res = validate_program(&prog);
        assert!(res.unwrap_err().contains("EndPeek without matching BeginPeek"));
    }

    #[test]
    fn test_invalid_store_outside_mutate() {
        let block = Block {
            id: 0,
            instrs: vec![
                Instr::from(InstrKind::BeginBorrow { gate: ValueId(0) }),
                Instr::from(InstrKind::GateStoreField {
                    gate: ValueId(0),
                    field: FieldId(0),
                    value: ValueId(1),
                }),
                Instr::from(InstrKind::EndBorrow),
            ],
            terminator: Terminator::Return,
        };
        let prog = create_dummy_program(create_dummy_function(vec![block]));
        let res = validate_program(&prog);
        assert!(res.is_err());
        assert!(res.unwrap_err().contains("GateStore outside of BeginMutate"));
    }

    #[test]
    fn test_valid_store_in_mutate() {
        let block = Block {
            id: 0,
            instrs: vec![
                Instr::from(InstrKind::BeginMutate { gate: ValueId(0) }),
                Instr::from(InstrKind::GateStoreField {
                    gate: ValueId(0),
                    field: FieldId(0),
                    value: ValueId(1),
                }),
                Instr::from(InstrKind::EndMutate),
            ],
            terminator: Terminator::Return,
        };
        let prog = create_dummy_program(create_dummy_function(vec![block]));
        assert!(validate_program(&prog).is_ok());
    }

    #[test]
    fn test_invalid_load_outside_hip() {
        let block = Block {
            id: 0,
            instrs: vec![Instr::from(InstrKind::GateLoadField {
                gate: ValueId(0),
                field: FieldId(0),
            })],
            terminator: Terminator::Return,
        };
        let prog = create_dummy_program(create_dummy_function(vec![block]));
        let res = validate_program(&prog);
        assert!(res.is_err());
        assert!(res.unwrap_err().contains("GateLoad outside of HIP operation"));
    }

    #[test]
    fn test_valid_hip_across_blocks() {
        let block0 = Block {
            id: 0,
            instrs: vec![Instr::from(InstrKind::BeginPeek { gate: ValueId(0) })],
            terminator: Terminator::Jump(1),
        };
        let block1 = Block {
            id: 1,
            instrs: vec![
                Instr::from(InstrKind::GateLoadField { gate: ValueId(0), field: FieldId(0) }),
                Instr::from(InstrKind::EndPeek),
            ],
            terminator: Terminator::Return,
        };
        let prog = create_dummy_program(create_dummy_function(vec![block0, block1]));
        assert!(validate_program(&prog).is_ok());
    }

    #[test]
    fn test_invalid_hip_across_blocks_inconsistent() {
        let block0 = Block {
            id: 0,
            instrs: vec![
                Instr::from(InstrKind::PushConst(ConstId(0))), // branch condition
            ],
            terminator: Terminator::JumpIfFalse { target: 2, else_target: 1 },
        };
        let block1 = Block {
            id: 1,
            instrs: vec![Instr::from(InstrKind::BeginPeek { gate: ValueId(0) })],
            terminator: Terminator::Jump(3),
        };
        let block2 = Block {
            id: 2,
            instrs: vec![
                // No BeginPeek here
            ],
            terminator: Terminator::Jump(3),
        };
        let block3 = Block {
            id: 3,
            instrs: vec![
                // ERROR: block 3 is reached with an empty stack from block 2
                // and with [Peek] from block 1.
                Instr::from(InstrKind::EndPeek),
            ],
            terminator: Terminator::Return,
        };
        let prog =
            create_dummy_program(create_dummy_function(vec![block0, block1, block2, block3]));
        let res = validate_program(&prog);
        assert!(res.is_err());
        assert!(res.unwrap_err().contains("Control flow merge at block 3"));
    }

    #[test]
    fn test_valid_loop_with_consistent_stack() {
        // block1 loops back to itself while a peek stays open; the back-edge carries the
        // same stack it was first entered with, so validation must terminate successfully.
        let block0 = Block {
            id: 0,
            instrs: vec![Instr::from(InstrKind::BeginPeek { gate: ValueId(0) })],
            terminator: Terminator::Jump(1),
        };
        let block1 = Block {
            id: 1,
            instrs: vec![Instr::from(InstrKind::GateLoadField {
                gate: ValueId(0),
                field: FieldId(0),
            })],
            terminator: Terminator::JumpIfFalse { target: 1, else_target: 2 },
        };
        let block2 = Block {
            id: 2,
            instrs: vec![Instr::from(InstrKind::EndPeek)],
            terminator: Terminator::Return,
        };
        let prog = create_dummy_program(create_dummy_function(vec![block0, block1, block2]));
        assert!(validate_program(&prog).is_ok());
    }

    #[test]
    fn test_invalid_loop_leaks_hip_op() {
        // Each iteration of block1 opens a peek without closing it, so the back-edge
        // re-enters block1 with a deeper stack than its first entry.
        let block0 = Block {
            id: 0,
            instrs: vec![],
            terminator: Terminator::Jump(1),
        };
        let block1 = Block {
            id: 1,
            instrs: vec![Instr::from(InstrKind::BeginPeek { gate: ValueId(0) })],
            terminator: Terminator::JumpIfFalse { target: 1, else_target: 2 },
        };
        let block2 = Block {
            id: 2,
            instrs: vec![Instr::from(InstrKind::EndPeek)],
            terminator: Terminator::Return,
        };
        let prog = create_dummy_program(create_dummy_function(vec![block0, block1, block2]));
        let res = validate_program(&prog);
        assert!(res.unwrap_err().contains("Control flow merge at block 1"));
    }

    #[test]
    fn test_invalid_jump_target() {
        let block = Block {
            id: 0,
            instrs: vec![],
            terminator: Terminator::Jump(7),
        };
        let prog = create_dummy_program(create_dummy_function(vec![block]));
        let res = validate_program(&prog);
        assert!(res.unwrap_err().contains("Invalid block ID: 7"));
    }

    #[test]
    fn test_empty_function_is_valid() {
        let prog = create_dummy_program(create_dummy_function(vec![]));
        assert!(validate_program(&prog).is_ok());
    }

    #[test]
    fn test_silent_fallback_checks() {
        let block_func0 = Block {
            id: 0,
            instrs: vec![Instr::from(InstrKind::Call(FunctionId(0), 0))],
            terminator: Terminator::Return,
        };
        let prog_func0 = create_dummy_program(create_dummy_function(vec![block_func0]));
        assert!(validate_program(&prog_func0).is_err());

        let block_ty0 = Block {
            id: 0,
            instrs: vec![Instr::from(InstrKind::Alloc { ty: TypeId(0), slots: 1 })],
            terminator: Terminator::Return,
        };
        let prog_ty0 = create_dummy_program(create_dummy_function(vec![block_ty0]));
        assert!(validate_program(&prog_ty0).is_err());
    }
}
|