From 6961597db8b153b8cea2792477cfb83e6ea96b50 Mon Sep 17 00:00:00 2001 From: coord_e Date: Sun, 17 Aug 2025 13:22:40 +0900 Subject: [PATCH 01/75] Move rty::FormulaWithAtoms to chc::Body --- src/analyze/basic_block.rs | 2 +- src/chc.rs | 176 ++++++++++++++++++++++++++++++++----- src/chc/clause_builder.rs | 32 ++----- src/chc/format_context.rs | 5 +- src/chc/smtlib2.rs | 36 ++++++-- src/chc/unbox.rs | 16 ++-- src/refine/env.rs | 28 +++--- src/rty.rs | 158 +++------------------------------ src/rty/clause_builder.rs | 12 ++- 9 files changed, 227 insertions(+), 238 deletions(-) diff --git a/src/analyze/basic_block.rs b/src/analyze/basic_block.rs index 07c4813..ee9023b 100644 --- a/src/analyze/basic_block.rs +++ b/src/analyze/basic_block.rs @@ -767,7 +767,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { #[derive(Debug, Clone)] pub struct UnbindAtoms { existentials: IndexVec, - formula: rty::FormulaWithAtoms>, + formula: chc::Body>, target_equations: Vec<(rty::RefinedTypeVar, chc::Term>)>, } diff --git a/src/chc.rs b/src/chc.rs index 8da6587..f3eb058 100644 --- a/src/chc.rs +++ b/src/chc.rs @@ -1193,16 +1193,16 @@ impl Formula { } } - pub fn atoms(&self) -> impl Iterator> { - self.atoms_impl() + pub fn iter_atoms(&self) -> impl Iterator> { + self.iter_atoms_impl() } - fn atoms_impl(&self) -> Box> + '_> { + fn iter_atoms_impl(&self) -> Box> + '_> { match self { Formula::Atom(atom) => Box::new(std::iter::once(atom)), - Formula::Not(fo) => Box::new(fo.atoms()), - Formula::And(fs) => Box::new(fs.iter().flat_map(Formula::atoms)), - Formula::Or(fs) => Box::new(fs.iter().flat_map(Formula::atoms)), + Formula::Not(fo) => Box::new(fo.iter_atoms()), + Formula::And(fs) => Box::new(fs.iter().flat_map(Formula::iter_atoms)), + Formula::Or(fs) => Box::new(fs.iter().flat_map(Formula::iter_atoms)), } } @@ -1248,13 +1248,154 @@ impl Formula { } } +#[derive(Debug, Clone)] +pub struct Body { + pub atoms: Vec>, + // `formula` doesn't contain PredVar + pub formula: Formula, +} + +impl Default for Body { + fn default() -> Self { + Body { + atoms: Default::default(), + formula: Default::default(), + } + } +} + +impl From> for Body { + fn from(atom: Atom) -> Self { + Body { + atoms: vec![atom], + formula: Formula::top(), + } + } +} + +impl From>> for Body { + fn from(atoms: Vec>) -> Self { + Body { + atoms, + formula: Formula::top(), + } + } +} + +impl From> for Body { + fn from(formula: Formula) -> Self { + Body { + atoms: vec![], + formula, + } + } +} + +impl Body { + pub fn new(atoms: Vec>, formula: Formula) -> Self { + Body { atoms, formula } + } + + pub fn top() -> Self { + Body { + atoms: vec![], + formula: Formula::top(), + } + } + + pub fn bottom() -> Self { + Body { + atoms: vec![], + formula: Formula::bottom(), + } + } + + pub fn is_top(&self) -> bool { + self.formula.is_top() && self.atoms.iter().all(|a| a.is_top()) + } + + pub fn is_bottom(&self) -> bool { + self.formula.is_bottom() || self.atoms.iter().any(|a| a.is_bottom()) + } + + pub fn push_conj(&mut self, other: impl Into>) { + let Body { atoms, formula } = other.into(); + self.atoms.extend(atoms); + self.formula.push_conj(formula); + } + + pub fn map_var(self, mut f: F) -> Body + where + F: FnMut(V) -> W, + { + Body { + atoms: self.atoms.into_iter().map(|a| a.map_var(&mut f)).collect(), + formula: self.formula.map_var(f), + } + } + + pub fn subst_var(self, mut f: F) -> Body + where + F: FnMut(V) -> Term, + { + Body { + atoms: self + .atoms + .into_iter() + .map(|a| a.subst_var(&mut f)) + .collect(), + formula: self.formula.subst_var(f), + } + } + + pub fn simplify(&mut self) { + self.formula.simplify(); + self.atoms.retain(|a| !a.is_top()); + if self.is_bottom() { + self.atoms = vec![Atom::bottom()]; + self.formula = Formula::top(); + } + } + + pub fn iter_atoms(&self) -> impl Iterator> { + self.formula.iter_atoms().chain(&self.atoms) + } +} + +impl<'a, 'b, D, V> Pretty<'a, D, termcolor::ColorSpec> for &'b Body +where + V: Var, + D: pretty::DocAllocator<'a, termcolor::ColorSpec>, + D::Doc: Clone, +{ + fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, termcolor::ColorSpec> { + let atoms = allocator.intersperse( + &self.atoms, + allocator + .text("∧") + .enclose(allocator.line(), allocator.space()), + ); + let formula = self.formula.pretty(allocator); + if self.atoms.is_empty() { + formula + } else if self.formula.is_top() { + atoms.group() + } else { + atoms + .append(allocator.line()) + .append(allocator.text("∧")) + .append(allocator.line()) + .append(formula) + .group() + } + } +} + #[derive(Debug, Clone)] pub struct Clause { pub vars: IndexVec, pub head: Atom, - pub body_atoms: Vec>, - // body_formula doesn't contain PredVar - pub body_formula: Formula, + pub body: Body, pub debug_info: DebugInfo, } @@ -1272,7 +1413,7 @@ where allocator.text(",").append(allocator.line()), ) .group(); - let body = self.body().pretty(allocator); + let body = self.body.pretty(allocator); let imp = self .head .pretty(allocator) @@ -1291,21 +1432,8 @@ where } impl Clause { - pub fn body(&self) -> Formula { - Formula::And( - self.body_atoms - .clone() - .into_iter() - .map(Formula::from) - .collect(), - ) - .and(self.body_formula.clone()) - } - pub fn is_nop(&self) -> bool { - self.head.is_top() - || self.body_atoms.iter().any(Atom::is_bottom) - || self.body_formula.is_bottom() + self.head.is_top() || self.body.is_bottom() } fn term_sort(&self, term: &Term) -> Sort { diff --git a/src/chc/clause_builder.rs b/src/chc/clause_builder.rs index 387e69e..0aec55b 100644 --- a/src/chc/clause_builder.rs +++ b/src/chc/clause_builder.rs @@ -6,7 +6,7 @@ use std::rc::Rc; use rustc_index::IndexVec; -use super::{Atom, Clause, DebugInfo, Formula, Sort, TermVarIdx}; +use super::{Atom, Body, Clause, DebugInfo, Sort, TermVarIdx}; pub trait Var: Eq + Ord + Hash + Copy + Debug + 'static {} impl Var for T {} @@ -58,8 +58,7 @@ impl Hash for dyn Key { pub struct ClauseBuilder { vars: IndexVec, mapped_var_indices: HashMap, TermVarIdx>, - body_atoms: Vec>, - body_formula: Formula, + body: Body, } impl ClauseBuilder { @@ -94,36 +93,19 @@ impl ClauseBuilder { .unwrap_or_else(|| panic!("unbound var {:?}", v)) } - pub fn add_body(&mut self, atom: Atom) -> &mut Self { - self.body_atoms.push(atom); - self - } - - pub fn add_body_formula(&mut self, formula: Formula) -> &mut Self { - self.body_formula.push_conj(formula); + pub fn add_body(&mut self, body: impl Into>) -> &mut Self { + self.body.push_conj(body); self } pub fn head(&self, head: Atom) -> Clause { let vars = self.vars.clone(); - let mut body_atoms: Vec<_> = self - .body_atoms - .clone() - .into_iter() - .filter(|a| !a.is_top()) - .collect(); - if body_atoms.is_empty() { - body_atoms = vec![Atom::top()]; - } else if body_atoms.iter().any(Atom::is_bottom) { - body_atoms = vec![Atom::bottom()]; - } - let mut body_formula = self.body_formula.clone(); - body_formula.simplify(); + let mut body = self.body.clone(); + body.simplify(); Clause { vars, head, - body_atoms, - body_formula, + body, debug_info: DebugInfo::from_current_span(), } } diff --git a/src/chc/format_context.rs b/src/chc/format_context.rs index f526f17..97c3a04 100644 --- a/src/chc/format_context.rs +++ b/src/chc/format_context.rs @@ -197,10 +197,7 @@ fn collect_sorts(system: &chc::System) -> BTreeSet { for clause in &system.clauses { sorts.extend(clause.vars.clone()); atom_sorts(clause, &clause.head, &mut sorts); - for a in &clause.body_atoms { - atom_sorts(clause, a, &mut sorts); - } - for a in clause.body_formula.atoms() { + for a in clause.body.iter_atoms() { atom_sorts(clause, a, &mut sorts); } } diff --git a/src/chc/smtlib2.rs b/src/chc/smtlib2.rs index 1346f95..6e9ebfa 100644 --- a/src/chc/smtlib2.rs +++ b/src/chc/smtlib2.rs @@ -278,6 +278,32 @@ impl<'ctx, 'a> Formula<'ctx, 'a> { } } +#[derive(Debug, Clone)] +pub struct Body<'ctx, 'a> { + ctx: &'ctx FormatContext, + clause: &'a chc::Clause, + inner: &'a chc::Body, +} + +impl<'ctx, 'a> std::fmt::Display for Body<'ctx, 'a> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let atoms = List::open( + self.inner + .atoms + .iter() + .map(|a| Atom::new(self.ctx, self.clause, a)), + ); + let formula = Formula::new(self.ctx, self.clause, &self.inner.formula); + write!(f, "(and {atoms} {formula})") + } +} + +impl<'ctx, 'a> Body<'ctx, 'a> { + pub fn new(ctx: &'ctx FormatContext, clause: &'a chc::Clause, inner: &'a chc::Body) -> Self { + Self { ctx, clause, inner } + } +} + #[derive(Debug, Clone)] pub struct Clause<'ctx, 'a> { ctx: &'ctx FormatContext, @@ -289,13 +315,7 @@ impl<'ctx, 'a> std::fmt::Display for Clause<'ctx, 'a> { if !self.inner.debug_info.is_empty() { writeln!(f, "{}", self.inner.debug_info.display("; "))?; } - let body_atoms = List::open( - self.inner - .body_atoms - .iter() - .map(|a| Atom::new(self.ctx, self.inner, a)), - ); - let body_formula = Formula::new(self.ctx, self.inner, &self.inner.body_formula); + let body = Body::new(self.ctx, self.inner, &self.inner.body); let head = Atom::new(self.ctx, self.inner, &self.inner.head); if !self.inner.vars.is_empty() { let vars = List::closed( @@ -306,7 +326,7 @@ impl<'ctx, 'a> std::fmt::Display for Clause<'ctx, 'a> { ); write!(f, "(forall {vars} ")?; } - write!(f, "(=> (and {body_atoms} {body_formula}) {head})")?; + write!(f, "(=> {body} {head})")?; if !self.inner.vars.is_empty() { write!(f, ")")?; } diff --git a/src/chc/unbox.rs b/src/chc/unbox.rs index d5b8131..441c95b 100644 --- a/src/chc/unbox.rs +++ b/src/chc/unbox.rs @@ -53,23 +53,27 @@ fn unbox_formula(formula: Formula) -> Formula { } } +fn unbox_body(body: Body) -> Body { + let Body { atoms, formula } = body; + let atoms = atoms.into_iter().map(unbox_atom).collect(); + let formula = unbox_formula(formula); + Body { atoms, formula } +} + fn unbox_clause(clause: Clause) -> Clause { let Clause { vars, head, - body_atoms, - body_formula, + body, debug_info, } = clause; let vars = vars.into_iter().map(unbox_sort).collect(); let head = unbox_atom(head); - let body_atoms = body_atoms.into_iter().map(unbox_atom).collect(); - let body_formula = unbox_formula(body_formula); + let body = unbox_body(body); Clause { vars, head, - body_atoms, - body_formula, + body, debug_info, } } diff --git a/src/refine/env.rs b/src/refine/env.rs index 8311dfe..dd5d27d 100644 --- a/src/refine/env.rs +++ b/src/refine/env.rs @@ -225,7 +225,7 @@ pub struct PlaceType { pub ty: rty::Type, pub existentials: IndexVec, pub term: chc::Term, - pub formula: rty::FormulaWithAtoms, + pub formula: chc::Body, } impl From for rty::RefinedType { @@ -564,7 +564,7 @@ impl PlaceType { tys: Vec>, terms: Vec>, existentials: IndexVec, - formula: rty::FormulaWithAtoms, + formula: chc::Body, } let State { tys, @@ -612,7 +612,7 @@ impl PlaceType { struct State { existentials: IndexVec, terms: Vec>, - formula: rty::FormulaWithAtoms, + formula: chc::Body, } let State { mut existentials, @@ -658,14 +658,13 @@ impl PlaceType { #[derive(Debug, Clone, Default)] pub struct UnboundAssumption { pub existentials: IndexVec, - pub formula: rty::FormulaWithAtoms, + pub formula: chc::Body, } impl From> for UnboundAssumption { fn from(atom: chc::Atom) -> Self { let existentials = IndexVec::new(); - let formula = - rty::FormulaWithAtoms::new(vec![atom.map_var(Into::into)], Default::default()); + let formula = chc::Body::new(vec![atom.map_var(Into::into)], Default::default()); UnboundAssumption { existentials, formula, @@ -732,7 +731,7 @@ where impl UnboundAssumption { pub fn new( existentials: IndexVec, - formula: rty::FormulaWithAtoms, + formula: chc::Body, ) -> Self { UnboundAssumption { existentials, @@ -776,11 +775,11 @@ impl rty::ClauseScope for Env { if !rty.ty.to_sort().is_singleton() { instantiator.value_var(builder.mapped_var(var)); } - let rty::FormulaWithAtoms { formula, atoms } = instantiator.instantiate(); + let chc::Body { formula, atoms } = instantiator.instantiate(); for atom in atoms { builder.add_body(atom); } - builder.add_body_formula(formula); + builder.add_body(formula); } for assumption in &self.unbound_assumptions { let mut evs = HashMap::new(); @@ -788,15 +787,14 @@ impl rty::ClauseScope for Env { let tv = builder.add_var(sort.clone()); evs.insert(ev, tv); } - let rty::FormulaWithAtoms { formula, atoms } = - assumption.formula.clone().map_var(|v| match v { - PlaceTypeVar::Var(v) => builder.mapped_var(v), - PlaceTypeVar::Existential(ev) => evs[&ev], - }); + let chc::Body { formula, atoms } = assumption.formula.clone().map_var(|v| match v { + PlaceTypeVar::Var(v) => builder.mapped_var(v), + PlaceTypeVar::Existential(ev) => evs[&ev], + }); for atom in atoms { builder.add_body(atom); } - builder.add_body_formula(formula); + builder.add_body(formula); } builder } diff --git a/src/rty.rs b/src/rty.rs index 2641da7..a68b862 100644 --- a/src/rty.rs +++ b/src/rty.rs @@ -904,152 +904,14 @@ impl RefinedTypeVar { } } -#[derive(Debug, Clone)] -pub struct FormulaWithAtoms { - pub atoms: Vec>, - pub formula: chc::Formula, -} - -impl Default for FormulaWithAtoms { - fn default() -> Self { - FormulaWithAtoms { - atoms: Default::default(), - formula: Default::default(), - } - } -} - -impl From> for FormulaWithAtoms { - fn from(atom: chc::Atom) -> Self { - FormulaWithAtoms { - atoms: vec![atom], - formula: chc::Formula::top(), - } - } -} - -impl From>> for FormulaWithAtoms { - fn from(atoms: Vec>) -> Self { - FormulaWithAtoms { - atoms, - formula: chc::Formula::top(), - } - } -} - -impl From> for FormulaWithAtoms { - fn from(formula: chc::Formula) -> Self { - FormulaWithAtoms { - atoms: vec![], - formula, - } - } -} - -impl FormulaWithAtoms { - pub fn new(atoms: Vec>, formula: chc::Formula) -> Self { - FormulaWithAtoms { atoms, formula } - } - - pub fn top() -> Self { - FormulaWithAtoms { - atoms: vec![], - formula: chc::Formula::top(), - } - } - - pub fn bottom() -> Self { - FormulaWithAtoms { - atoms: vec![], - formula: chc::Formula::bottom(), - } - } - - pub fn is_top(&self) -> bool { - self.formula.is_top() && self.atoms.iter().all(|a| a.is_top()) - } - - pub fn is_bottom(&self) -> bool { - self.formula.is_bottom() || self.atoms.iter().any(|a| a.is_bottom()) - } - - pub fn push_conj(&mut self, other: impl Into>) { - let FormulaWithAtoms { atoms, formula } = other.into(); - self.atoms.extend(atoms); - self.formula.push_conj(formula); - } - - pub fn map_var(self, mut f: F) -> FormulaWithAtoms - where - F: FnMut(V) -> W, - { - FormulaWithAtoms { - atoms: self.atoms.into_iter().map(|a| a.map_var(&mut f)).collect(), - formula: self.formula.map_var(f), - } - } - - pub fn subst_var(self, mut f: F) -> FormulaWithAtoms - where - F: FnMut(V) -> chc::Term, - { - FormulaWithAtoms { - atoms: self - .atoms - .into_iter() - .map(|a| a.subst_var(&mut f)) - .collect(), - formula: self.formula.subst_var(f), - } - } - - pub fn simplify(&mut self) { - self.formula.simplify(); - self.atoms.retain(|a| !a.is_top()); - if self.is_bottom() { - self.atoms = vec![chc::Atom::bottom()]; - self.formula = chc::Formula::top(); - } - } -} - -impl<'a, 'b, D, V> Pretty<'a, D, termcolor::ColorSpec> for &'b FormulaWithAtoms -where - V: chc::Var, - D: pretty::DocAllocator<'a, termcolor::ColorSpec>, - D::Doc: Clone, -{ - fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, termcolor::ColorSpec> { - let atoms = allocator.intersperse( - &self.atoms, - allocator - .text("∧") - .enclose(allocator.line(), allocator.space()), - ); - let formula = self.formula.pretty(allocator); - if self.atoms.is_empty() { - formula - } else if self.formula.is_top() { - atoms.group() - } else { - atoms - .append(allocator.line()) - .append(allocator.text("∧")) - .append(allocator.line()) - .append(formula) - .group() - } - } -} - #[derive(Debug, Clone)] pub struct Refinement { pub existentials: IndexVec, - pub formula: FormulaWithAtoms>, + pub formula: chc::Body>, } -impl From>> for Refinement { - fn from(formula: FormulaWithAtoms>) -> Self { +impl From>> for Refinement { + fn from(formula: chc::Body>) -> Self { Refinement { existentials: IndexVec::new(), formula, @@ -1061,7 +923,7 @@ impl From>> for Refinement { fn from(atom: chc::Atom>) -> Self { Refinement { existentials: IndexVec::new(), - formula: FormulaWithAtoms::new(vec![atom], chc::Formula::top()), + formula: chc::Body::new(vec![atom], chc::Formula::top()), } } } @@ -1070,7 +932,7 @@ impl From>> for Refinement { fn from(formula: chc::Formula>) -> Self { Refinement { existentials: IndexVec::new(), - formula: FormulaWithAtoms::new(vec![], formula), + formula: chc::Body::new(vec![], formula), } } } @@ -1107,7 +969,7 @@ where impl Refinement { pub fn with_formula( existentials: IndexVec, - formula: FormulaWithAtoms>, + formula: chc::Body>, ) -> Self { Refinement { existentials, @@ -1121,7 +983,7 @@ impl Refinement { ) -> Self { let mut refinement = Refinement { existentials, - formula: FormulaWithAtoms::new(atoms, chc::Formula::top()), + formula: chc::Body::new(atoms, chc::Formula::top()), }; refinement.formula.simplify(); refinement @@ -1144,11 +1006,11 @@ impl Refinement { } pub fn top() -> Self { - Refinement::with_formula(IndexVec::new(), FormulaWithAtoms::top()) + Refinement::with_formula(IndexVec::new(), chc::Body::top()) } pub fn bottom() -> Self { - Refinement::with_formula(IndexVec::new(), FormulaWithAtoms::bottom()) + Refinement::with_formula(IndexVec::new(), chc::Body::bottom()) } pub fn extend(&mut self, other: Refinement) { @@ -1231,7 +1093,7 @@ impl Instantiator { self } - pub fn instantiate(self) -> FormulaWithAtoms + pub fn instantiate(self) -> chc::Body where T: Clone, { diff --git a/src/rty/clause_builder.rs b/src/rty/clause_builder.rs index 813a6d9..f45e948 100644 --- a/src/rty/clause_builder.rs +++ b/src/rty/clause_builder.rs @@ -1,6 +1,6 @@ use crate::chc; -use super::{FormulaWithAtoms, Refinement, Type}; +use super::{Refinement, Type}; pub trait ClauseBuilderExt { fn with_value_var<'a, T>(&'a mut self, ty: &Type) -> RefinementClauseBuilder<'a>; @@ -55,11 +55,11 @@ impl<'a> RefinementClauseBuilder<'a> { if let Some(value_var) = self.value_var { instantiator.value_var(value_var); } - let FormulaWithAtoms { atoms, formula } = instantiator.instantiate(); + let chc::Body { atoms, formula } = instantiator.instantiate(); for atom in atoms { self.builder.add_body(atom); } - self.builder.add_body_formula(formula); + self.builder.add_body(formula); self } @@ -76,7 +76,7 @@ impl<'a> RefinementClauseBuilder<'a> { if let Some(value_var) = self.value_var { instantiator.value_var(value_var); } - let FormulaWithAtoms { atoms, formula } = instantiator.instantiate(); + let chc::Body { atoms, formula } = instantiator.instantiate(); let mut cs = atoms .into_iter() .map(|a| self.builder.head(a)) @@ -84,9 +84,7 @@ impl<'a> RefinementClauseBuilder<'a> { if !formula.is_top() { cs.push({ let mut builder = self.builder.clone(); - builder - .add_body_formula(formula.not()) - .head(chc::Atom::bottom()) + builder.add_body(formula.not()).head(chc::Atom::bottom()) }); } cs From 0709cec5101961b00e131393c8149aa1b722935f Mon Sep 17 00:00:00 2001 From: coord_e Date: Sun, 17 Aug 2025 15:15:49 +0900 Subject: [PATCH 02/75] rty::Refinement = rty::Formula --- src/analyze/basic_block.rs | 32 ++++---- src/refine/env.rs | 23 +++--- src/refine/template.rs | 7 +- src/rty.rs | 147 +++++++++++++++++++++++-------------- src/rty/params.rs | 2 +- 5 files changed, 117 insertions(+), 94 deletions(-) diff --git a/src/analyze/basic_block.rs b/src/analyze/basic_block.rs index ee9023b..5d71d99 100644 --- a/src/analyze/basic_block.rs +++ b/src/analyze/basic_block.rs @@ -16,7 +16,9 @@ use crate::refine::{ self, BasicBlockType, Env, PlaceType, PlaceTypeVar, TempVarIdx, TemplateTypeGenerator, UnboundAssumption, UnrefinedTypeGenerator, Var, }; -use crate::rty::{self, ClauseBuilderExt as _, ClauseScope as _, Subtyping as _}; +use crate::rty::{ + self, ClauseBuilderExt as _, ClauseScope as _, ShiftExistential as _, Subtyping as _, +}; mod drop_point; mod visitor; @@ -767,7 +769,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { #[derive(Debug, Clone)] pub struct UnbindAtoms { existentials: IndexVec, - formula: chc::Body>, + body: chc::Body>, target_equations: Vec<(rty::RefinedTypeVar, chc::Term>)>, } @@ -775,7 +777,7 @@ impl Default for UnbindAtoms { fn default() -> Self { UnbindAtoms { existentials: Default::default(), - formula: Default::default(), + body: Default::default(), target_equations: Default::default(), } } @@ -783,7 +785,7 @@ impl Default for UnbindAtoms { impl UnbindAtoms { pub fn push(&mut self, target: rty::RefinedTypeVar, var_ty: PlaceType) { - self.formula.push_conj( + self.body.push_conj( var_ty .formula .map_var(|v| v.shift_existential(self.existentials.len()).into()), @@ -802,13 +804,10 @@ impl UnbindAtoms { ty: src_ty, refinement, } = ty; - let rty::Refinement { - existentials, - formula, - } = refinement; + let rty::Refinement { existentials, body } = refinement; - self.formula - .push_conj(formula.map_var(|v| v.shift_existential(self.existentials.len()))); + self.body + .push_conj(body.map_var(|v| v.shift_existential(self.existentials.len()))); self.existentials.extend(existentials); let mut substs = HashMap::new(); @@ -817,12 +816,12 @@ impl UnbindAtoms { substs.insert(v, ev); } - let mut formula = self.formula.map_var(|v| match v { + let mut body = self.body.map_var(|v| match v { rty::RefinedTypeVar::Value => rty::RefinedTypeVar::Value, rty::RefinedTypeVar::Free(v) => rty::RefinedTypeVar::Existential(substs[&v]), rty::RefinedTypeVar::Existential(ev) => rty::RefinedTypeVar::Existential(ev), }); - formula.push_conj( + body.push_conj( self.target_equations .into_iter() .map(|(t, term)| { @@ -839,7 +838,7 @@ impl UnbindAtoms { .collect::>(), ); - let refinement = rty::Refinement::with_formula(self.existentials, formula); + let refinement = rty::Refinement::new(self.existentials, body); // TODO: polymorphic datatypes: template needed? rty::RefinedType::new(src_ty.assert_closed().vacuous(), refinement) } @@ -892,11 +891,8 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { } for (idx, param) in expected_params.iter_enumerated() { - let rty::Refinement { - existentials, - formula, - } = param.refinement.clone(); - assumption.formula.push_conj(formula.subst_var(|v| match v { + let rty::Refinement { existentials, body } = param.refinement.clone(); + assumption.formula.push_conj(body.subst_var(|v| match v { rty::RefinedTypeVar::Value => param_terms[&idx].clone(), rty::RefinedTypeVar::Free(v) => param_terms[&v].clone(), rty::RefinedTypeVar::Existential(ev) => chc::Term::var(PlaceTypeVar::Existential( diff --git a/src/refine/env.rs b/src/refine/env.rs index dd5d27d..35f2999 100644 --- a/src/refine/env.rs +++ b/src/refine/env.rs @@ -236,11 +236,11 @@ impl From for rty::RefinedType { term, formula, } = ty; - let mut formula = formula.map_var(Into::into); - formula.push_conj( + let mut body = formula.map_var(Into::into); + body.push_conj( chc::Term::var(rty::RefinedTypeVar::Value).equal_to(term.map_var(Into::into)), ); - let refinement = rty::Refinement::with_formula(existentials, formula); + let refinement = rty::Refinement::new(existentials, body); rty::RefinedType::new(ty, refinement) } } @@ -328,7 +328,7 @@ impl PlaceType { let rty::RefinedType { ty, refinement } = *inner_ty.elem; let rty::Refinement { existentials: inner_existentials, - formula: inner_formula, + body: inner_formula, } = refinement; let value_var_ex = existentials.push(ty.to_sort()); let term = chc::Term::var(value_var_ex.into()); @@ -360,7 +360,7 @@ impl PlaceType { let rty::RefinedType { ty, refinement } = inner_ty.elems[idx].clone(); let rty::Refinement { existentials: inner_existentials, - formula: inner_formula, + body: inner_formula, } = refinement; let value_var_ex = existentials.push(ty.to_sort()); let term = chc::Term::var(value_var_ex.into()); @@ -408,7 +408,7 @@ impl PlaceType { field_terms.push(chc::Term::var(field_ex_var.into())); field_tys.push(ty); - formula.push_conj(refinement.formula.map_var(|v| match v { + formula.push_conj(refinement.body.map_var(|v| match v { rty::RefinedTypeVar::Value => PlaceTypeVar::Existential(field_ex_var), rty::RefinedTypeVar::Existential(ev) => { PlaceTypeVar::Existential(ev + existentials.len()) @@ -925,7 +925,7 @@ impl Env { .collect(), ); let mut existentials = tuple_ty.existentials; - let mut formula = refinement.formula.subst_var(|v| match v { + let mut formula = refinement.body.subst_var(|v| match v { rty::RefinedTypeVar::Value => tuple_ty.term.clone(), rty::RefinedTypeVar::Free(v) => chc::Term::var(PlaceTypeVar::Var(v)), rty::RefinedTypeVar::Existential(ev) => { @@ -984,7 +984,7 @@ impl Env { let value_var_ev = existentials.push(rty::Type::Enum(ty.clone()).to_sort()); let mut assumption = UnboundAssumption { existentials, - formula: refinement.formula.map_var(|v| match v { + formula: refinement.body.map_var(|v| match v { rty::RefinedTypeVar::Value => PlaceTypeVar::Existential(value_var_ev), rty::RefinedTypeVar::Free(v) => PlaceTypeVar::Var(v), rty::RefinedTypeVar::Existential(ev) => PlaceTypeVar::Existential(ev), @@ -1388,15 +1388,12 @@ impl Env { ty: field_ty, refinement, } = field_rty; - let rty::Refinement { - formula, - existentials, - } = refinement; + let rty::Refinement { body, existentials } = refinement; PlaceType { ty: field_ty, existentials, term: chc::Term::var(ev.into()), - formula: formula.map_var(|v| match v { + formula: body.map_var(|v| match v { rty::RefinedTypeVar::Value => PlaceTypeVar::Existential(ev), rty::RefinedTypeVar::Free(v) => PlaceTypeVar::Var(v), // TODO: (but otherwise we can't distinguish field existentials from them...) diff --git a/src/refine/template.rs b/src/refine/template.rs index 285a216..b6ae7b8 100644 --- a/src/refine/template.rs +++ b/src/refine/template.rs @@ -147,13 +147,10 @@ where &mut self, refinement: rty::Refinement, ) -> &mut Self { - let rty::Refinement { - existentials, - formula, - } = refinement; + let rty::Refinement { existentials, body } = refinement; let refinement = rty::Refinement { existentials, - formula: formula.map_var(|v| match v { + body: body.map_var(|v| match v { rty::RefinedTypeVar::Free(idx) if idx.index() == self.param_tys.len() - 1 => { rty::RefinedTypeVar::Value } diff --git a/src/rty.rs b/src/rty.rs index a68b862..57e32cd 100644 --- a/src/rty.rs +++ b/src/rty.rs @@ -842,6 +842,10 @@ where } } +pub trait ShiftExistential { + fn shift_existential(self, offset: usize) -> Self; +} + #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] pub enum RefinedTypeVar { Value, @@ -895,8 +899,8 @@ where } } -impl RefinedTypeVar { - pub fn shift_existential(self, offset: usize) -> Self { +impl ShiftExistential for RefinedTypeVar { + fn shift_existential(self, offset: usize) -> Self { match self { RefinedTypeVar::Existential(v) => RefinedTypeVar::Existential(v + offset), v => v, @@ -905,41 +909,79 @@ impl RefinedTypeVar { } #[derive(Debug, Clone)] -pub struct Refinement { +pub struct Formula { pub existentials: IndexVec, - pub formula: chc::Body>, + pub body: chc::Body, } -impl From>> for Refinement { - fn from(formula: chc::Body>) -> Self { - Refinement { +impl Default for Formula { + fn default() -> Self { + Formula { + existentials: Default::default(), + body: Default::default(), + } + } +} + +impl From> for Formula { + fn from(body: chc::Body) -> Self { + Formula { existentials: IndexVec::new(), - formula, + body, } } } -impl From>> for Refinement { - fn from(atom: chc::Atom>) -> Self { - Refinement { +impl From> for Formula { + fn from(atom: chc::Atom) -> Self { + Formula { existentials: IndexVec::new(), - formula: chc::Body::new(vec![atom], chc::Formula::top()), + body: chc::Body::new(vec![atom], chc::Formula::top()), } } } -impl From>> for Refinement { - fn from(formula: chc::Formula>) -> Self { - Refinement { +impl From> for Formula { + fn from(formula: chc::Formula) -> Self { + Formula { existentials: IndexVec::new(), - formula: chc::Body::new(vec![], formula), + body: chc::Body::new(vec![], formula), } } } -impl<'a, 'b, D, FV> Pretty<'a, D, termcolor::ColorSpec> for &'b Refinement +impl Extend> for Formula where - FV: chc::Var, + V: ShiftExistential, +{ + fn extend(&mut self, iter: T) + where + T: IntoIterator>, + { + for formula in iter { + self.push_conj(formula); + } + self.body.simplify(); + } +} + +impl FromIterator> for Formula +where + V: ShiftExistential, +{ + fn from_iter(iter: T) -> Self + where + T: IntoIterator>, + { + let mut result = Formula::default(); + result.extend(iter); + result + } +} + +impl<'a, 'b, D, V> Pretty<'a, D, termcolor::ColorSpec> for &'b Formula +where + V: chc::Var, D: pretty::DocAllocator<'a, termcolor::ColorSpec>, D::Doc: Clone, { @@ -952,41 +994,26 @@ where allocator.text(",").append(allocator.line()), ) .group(); - let formula = self.formula.pretty(allocator); + let body = self.body.pretty(allocator); if self.existentials.is_empty() { - formula + body } else { allocator .text("∃") .append(existentials.nest(2)) .append(allocator.text(".")) - .append(allocator.line().append(formula).nest(2)) + .append(allocator.line().append(body).nest(2)) .group() } } } -impl Refinement { - pub fn with_formula( - existentials: IndexVec, - formula: chc::Body>, - ) -> Self { - Refinement { - existentials, - formula, - } - } - +impl Formula { pub fn new( existentials: IndexVec, - atoms: Vec>>, + body: chc::Body, ) -> Self { - let mut refinement = Refinement { - existentials, - formula: chc::Body::new(atoms, chc::Formula::top()), - }; - refinement.formula.simplify(); - refinement + Formula { existentials, body } } pub fn has_existentials(&self) -> bool { @@ -998,39 +1025,45 @@ impl Refinement { } pub fn is_top(&self) -> bool { - self.formula.is_top() + self.body.is_top() } pub fn is_bottom(&self) -> bool { - self.formula.is_bottom() + self.body.is_bottom() } pub fn top() -> Self { - Refinement::with_formula(IndexVec::new(), chc::Body::top()) + Formula::new(IndexVec::new(), chc::Body::top()) } pub fn bottom() -> Self { - Refinement::with_formula(IndexVec::new(), chc::Body::bottom()) + Formula::new(IndexVec::new(), chc::Body::bottom()) } +} - pub fn extend(&mut self, other: Refinement) { - let Refinement { - existentials, - formula, - } = other; - self.formula - .push_conj(formula.map_var(|v| v.shift_existential(self.existentials.len()))); +impl Formula +where + V: ShiftExistential, +{ + pub fn push_conj(&mut self, other: Self) { + let Formula { existentials, body } = other; + self.body + .push_conj(body.map_var(|v| v.shift_existential(self.existentials.len()))); self.existentials.extend(existentials); - self.formula.simplify(); + self.body.simplify(); } +} +pub type Refinement = Formula>; + +impl Refinement { pub fn subst_var(self, mut f: F) -> Refinement where F: FnMut(FV) -> chc::Term, { Refinement { existentials: self.existentials, - formula: self.formula.subst_var(|v| match v { + body: self.body.subst_var(|v| match v { RefinedTypeVar::Value => chc::Term::var(RefinedTypeVar::Value), RefinedTypeVar::Existential(v) => chc::Term::var(RefinedTypeVar::Existential(v)), RefinedTypeVar::Free(v) => f(v).map_var(RefinedTypeVar::Free), @@ -1044,7 +1077,7 @@ impl Refinement { { Refinement { existentials: self.existentials, - formula: self.formula.subst_var(|v| match v { + body: self.body.subst_var(|v| match v { RefinedTypeVar::Value => f(), RefinedTypeVar::Existential(v) => chc::Term::var(RefinedTypeVar::Existential(v)), RefinedTypeVar::Free(v) => chc::Term::var(RefinedTypeVar::Free(v)), @@ -1058,7 +1091,7 @@ impl Refinement { { Refinement { existentials: self.existentials, - formula: self.formula.map_var(|v| match v { + body: self.body.map_var(|v| match v { RefinedTypeVar::Value => RefinedTypeVar::Value, RefinedTypeVar::Existential(v) => RefinedTypeVar::Existential(v), RefinedTypeVar::Free(v) => RefinedTypeVar::Free(f(v)), @@ -1102,7 +1135,7 @@ impl Instantiator { existentials, refinement, } = self; - refinement.formula.map_var(move |v| match v { + refinement.body.map_var(move |v| match v { RefinedTypeVar::Value => value_var.clone().unwrap(), RefinedTypeVar::Existential(v) => existentials[&v].clone(), RefinedTypeVar::Free(v) => v, @@ -1209,7 +1242,7 @@ impl RefinedType { } pub fn extend_refinement(&mut self, refinement: Refinement) { - self.refinement.extend(refinement); + self.refinement.push_conj(refinement); } pub fn strip_refinement(self) -> Type { @@ -1232,7 +1265,7 @@ impl RefinedType { ty: replacement_ty, refinement, } = rty.clone(); - self.refinement.extend(refinement); + self.refinement.push_conj(refinement); self.ty = replacement_ty; } } diff --git a/src/rty/params.rs b/src/rty/params.rs index d2c0573..26bd289 100644 --- a/src/rty/params.rs +++ b/src/rty/params.rs @@ -84,7 +84,7 @@ impl TypeParamSubst { for (idx, mut t1) in other.subst { t1.subst_ty_params(self); if let Some(t2) = self.subst.remove(&idx) { - t1.refinement.extend(t2.refinement); + t1.refinement.push_conj(t2.refinement); } self.subst.insert(idx, t1); } From 383ccc7a5a65c099520910a701a30eae88af265f Mon Sep 17 00:00:00 2001 From: coord_e Date: Sun, 17 Aug 2025 15:37:05 +0900 Subject: [PATCH 03/75] Assumption = rty::Formula --- src/analyze/basic_block.rs | 18 ++-- src/refine.rs | 2 +- src/refine/env.rs | 169 ++++++++----------------------------- src/rty.rs | 5 +- 4 files changed, 45 insertions(+), 149 deletions(-) diff --git a/src/analyze/basic_block.rs b/src/analyze/basic_block.rs index 5d71d99..d5cefba 100644 --- a/src/analyze/basic_block.rs +++ b/src/analyze/basic_block.rs @@ -13,8 +13,8 @@ use crate::analyze; use crate::chc; use crate::pretty::PrettyDisplayExt as _; use crate::refine::{ - self, BasicBlockType, Env, PlaceType, PlaceTypeVar, TempVarIdx, TemplateTypeGenerator, - UnboundAssumption, UnrefinedTypeGenerator, Var, + self, Assumption, BasicBlockType, Env, PlaceType, PlaceTypeVar, TempVarIdx, + TemplateTypeGenerator, UnrefinedTypeGenerator, Var, }; use crate::rty::{ self, ClauseBuilderExt as _, ClauseScope as _, ShiftExistential as _, Subtyping as _, @@ -341,11 +341,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { self.ctx.extend_clauses(clauses); } - fn with_assumptions( - &mut self, - assumptions: Vec>, - callback: F, - ) -> T + fn with_assumptions(&mut self, assumptions: Vec>, callback: F) -> T where F: FnOnce(&mut Self) -> T, { @@ -356,7 +352,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { result } - fn with_assumption(&mut self, assumption: impl Into, callback: F) -> T + fn with_assumption(&mut self, assumption: impl Into, callback: F) -> T where F: FnOnce(&mut Self) -> T, { @@ -851,7 +847,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { expected_params: &IndexVec>, ) { let mut param_terms = HashMap::>::new(); - let mut assumption = UnboundAssumption::default(); + let mut assumption = Assumption::default(); let bb_ty = self.basic_block_ty(self.basic_block).clone(); let params = &bb_ty.as_ref().params; @@ -877,7 +873,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { } let local_ty = self.env.local_type(local); - assumption.formula.push_conj( + assumption.body.push_conj( local_ty .formula .map_var(|v| v.shift_existential(assumption.existentials.len())), @@ -892,7 +888,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { for (idx, param) in expected_params.iter_enumerated() { let rty::Refinement { existentials, body } = param.refinement.clone(); - assumption.formula.push_conj(body.subst_var(|v| match v { + assumption.body.push_conj(body.subst_var(|v| match v { rty::RefinedTypeVar::Value => param_terms[&idx].clone(), rty::RefinedTypeVar::Free(v) => param_terms[&v].clone(), rty::RefinedTypeVar::Existential(ev) => chc::Term::var(PlaceTypeVar::Existential( diff --git a/src/refine.rs b/src/refine.rs index 73de912..9c058cf 100644 --- a/src/refine.rs +++ b/src/refine.rs @@ -5,7 +5,7 @@ mod basic_block; pub use basic_block::BasicBlockType; mod env; -pub use env::{Env, PlaceType, PlaceTypeVar, TempVarIdx, UnboundAssumption, Var}; +pub use env::{Assumption, Env, PlaceType, PlaceTypeVar, TempVarIdx, Var}; use crate::chc::DatatypeSymbol; use rustc_middle::ty as mir_ty; diff --git a/src/refine/env.rs b/src/refine/env.rs index 35f2999..26e1ad3 100644 --- a/src/refine/env.rs +++ b/src/refine/env.rs @@ -9,7 +9,7 @@ use rustc_target::abi::{FieldIdx, VariantIdx}; use crate::chc; use crate::pretty::PrettyDisplayExt as _; use crate::refine; -use crate::rty; +use crate::rty::{self, ShiftExistential as _}; rustc_index::newtype_index! { #[orderable] @@ -183,6 +183,15 @@ impl From for rty::RefinedTypeVar { } } +impl rty::ShiftExistential for PlaceTypeVar { + fn shift_existential(self, amount: usize) -> Self { + match self { + PlaceTypeVar::Var(v) => PlaceTypeVar::Var(v), + PlaceTypeVar::Existential(ev) => PlaceTypeVar::Existential(ev + amount), + } + } +} + impl std::fmt::Debug for PlaceTypeVar { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { @@ -211,13 +220,6 @@ impl PlaceTypeVar { _ => None, } } - - pub fn shift_existential(self, amount: usize) -> Self { - match self { - PlaceTypeVar::Var(v) => PlaceTypeVar::Var(v), - PlaceTypeVar::Existential(ev) => PlaceTypeVar::Existential(ev + amount), - } - } } #[derive(Debug, Clone)] @@ -300,7 +302,7 @@ impl PlaceType { } } - pub fn into_assumption(self, term_to_atom: F) -> UnboundAssumption + pub fn into_assumption(self, term_to_atom: F) -> Assumption where F: FnOnce(chc::Term) -> chc::Atom, { @@ -311,10 +313,7 @@ impl PlaceType { mut formula, } = self; formula.push_conj(term_to_atom(term)); - UnboundAssumption { - existentials, - formula, - } + Assumption::new(existentials, formula) } pub fn deref(self) -> PlaceType { @@ -477,7 +476,7 @@ impl PlaceType { } } - pub fn merge_into_assumption(self, other: PlaceType, f: F) -> UnboundAssumption + pub fn merge_into_assumption(self, other: PlaceType, f: F) -> Assumption where F: FnOnce(chc::Term, chc::Term) -> chc::Atom, { @@ -500,10 +499,7 @@ impl PlaceType { formula.push_conj(formula2.map_var(|v| v.shift_existential(existentials.len()))); formula.push_conj(atom); existentials.extend(evs2); - UnboundAssumption { - existentials, - formula, - } + Assumption::new(existentials, formula) } pub fn merge(self, other: PlaceType, f: F) -> PlaceType @@ -655,101 +651,14 @@ impl PlaceType { } } -#[derive(Debug, Clone, Default)] -pub struct UnboundAssumption { - pub existentials: IndexVec, - pub formula: chc::Body, -} - -impl From> for UnboundAssumption { - fn from(atom: chc::Atom) -> Self { - let existentials = IndexVec::new(); - let formula = chc::Body::new(vec![atom.map_var(Into::into)], Default::default()); - UnboundAssumption { - existentials, - formula, - } - } -} - -impl Extend for UnboundAssumption { - fn extend(&mut self, iter: T) - where - T: IntoIterator, - { - for assumption in iter { - self.formula.push_conj( - assumption - .formula - .map_var(|v| v.shift_existential(self.existentials.len())), - ); - self.existentials.extend(assumption.existentials); - } - self.formula.simplify(); - } -} - -impl FromIterator for UnboundAssumption { - fn from_iter(iter: T) -> Self - where - T: IntoIterator, - { - let mut result = UnboundAssumption::default(); - result.extend(iter); - result - } -} - -impl<'a, 'b, D> Pretty<'a, D, termcolor::ColorSpec> for &'b UnboundAssumption -where - D: pretty::DocAllocator<'a, termcolor::ColorSpec>, - D::Doc: Clone, -{ - fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, termcolor::ColorSpec> { - let existentials = allocator - .intersperse( - self.existentials - .iter_enumerated() - .map(|(v, s)| v.pretty(allocator).append(allocator.text(":")).append(s)), - allocator.text(",").append(allocator.line()), - ) - .group(); - let formula = self.formula.pretty(allocator); - if self.existentials.is_empty() { - formula - } else { - allocator - .text("∃") - .append(existentials.nest(2)) - .append(allocator.text(".")) - .append(allocator.line().append(formula).nest(2)) - .group() - } - } -} - -impl UnboundAssumption { - pub fn new( - existentials: IndexVec, - formula: chc::Body, - ) -> Self { - UnboundAssumption { - existentials, - formula, - } - } - - pub fn is_top(&self) -> bool { - self.formula.is_top() - } -} +pub type Assumption = rty::Formula; #[derive(Debug, Clone)] pub struct Env { locals: BTreeMap>, flow_locals: BTreeMap, temp_vars: IndexVec, - unbound_assumptions: Vec, + assumptions: Vec, enum_defs: HashMap, @@ -781,13 +690,13 @@ impl rty::ClauseScope for Env { } builder.add_body(formula); } - for assumption in &self.unbound_assumptions { + for assumption in &self.assumptions { let mut evs = HashMap::new(); for (ev, sort) in assumption.existentials.iter_enumerated() { let tv = builder.add_var(sort.clone()); evs.insert(ev, tv); } - let chc::Body { formula, atoms } = assumption.formula.clone().map_var(|v| match v { + let chc::Body { formula, atoms } = assumption.body.clone().map_var(|v| match v { PlaceTypeVar::Var(v) => builder.mapped_var(v), PlaceTypeVar::Existential(ev) => evs[&ev], }); @@ -816,7 +725,7 @@ impl Env { locals: Default::default(), flow_locals: Default::default(), temp_vars: IndexVec::new(), - unbound_assumptions: Vec::new(), + assumptions: Vec::new(), enum_defs, enum_expansion_depth_limit: std::env::var("THRUST_ENUM_EXPANSION_DEPTH_LIMIT") .ok() @@ -934,10 +843,7 @@ impl Env { }); formula.push_conj(tuple_ty.formula); existentials.extend(refinement.existentials); - UnboundAssumption { - existentials, - formula, - } + Assumption::new(existentials, formula) }; self.assume(assumption); let binding = FlowBinding::Tuple(xs.clone()); @@ -982,20 +888,20 @@ impl Env { let mut existentials = refinement.existentials; let value_var_ev = existentials.push(rty::Type::Enum(ty.clone()).to_sort()); - let mut assumption = UnboundAssumption { + let mut assumption = Assumption::new( existentials, - formula: refinement.body.map_var(|v| match v { + refinement.body.map_var(|v| match v { rty::RefinedTypeVar::Value => PlaceTypeVar::Existential(value_var_ev), rty::RefinedTypeVar::Free(v) => PlaceTypeVar::Var(v), rty::RefinedTypeVar::Existential(ev) => PlaceTypeVar::Existential(ev), }), - }; + ); let mut pred_args: Vec<_> = variants .iter() .flat_map(|v| &v.fields) .map(|&x| { let ty = self.var_type(x.into()); - assumption.formula.push_conj( + assumption.body.push_conj( ty.formula .map_var(|v| v.shift_existential(assumption.existentials.len())), ); @@ -1008,7 +914,7 @@ impl Env { .collect(); pred_args.push(chc::Term::var(value_var_ev.into())); assumption - .formula + .body .push_conj(chc::Atom::new(matcher_pred.into(), pred_args)); let discr_var = self .temp_vars @@ -1016,7 +922,7 @@ impl Env { rty::Type::int(), ))); assumption - .formula + .body .push_conj( chc::Term::var(discr_var.into()).equal_to(chc::Term::datatype_discr( def.name.clone(), @@ -1079,14 +985,14 @@ impl Env { tracing::debug!(local = ?local, rty = %rty_disp.display(), place_type = %self.local_type(local).display(), "immut_bind"); } - pub fn assume(&mut self, assumption: impl Into) { + pub fn assume(&mut self, assumption: impl Into) { let assumption = assumption.into(); tracing::debug!(assumption = %assumption.display(), "assume"); - self.unbound_assumptions.push(assumption); + self.assumptions.push(assumption); } - pub fn extend_assumptions(&mut self, assumptions: Vec>) { - self.unbound_assumptions + pub fn extend_assumptions(&mut self, assumptions: Vec>) { + self.assumptions .extend(assumptions.into_iter().map(Into::into)); } @@ -1343,7 +1249,7 @@ impl Env { self.path_type(&place.into()) } - fn dropping_assumption(&mut self, path: &Path) -> UnboundAssumption { + fn dropping_assumption(&mut self, path: &Path) -> Assumption { let ty = self.path_type(path); if ty.ty.is_mut() { ty.into_assumption(|term| { @@ -1404,24 +1310,21 @@ impl Env { } }; - let UnboundAssumption { + let Assumption { existentials: assumption_existentials, - formula: assumption_formula, + body: assumption_body, } = self.dropping_assumption(&Path::PlaceTy(field_pty)); // dropping assumption should not generate any existential assert!(assumption_existentials.is_empty()); - formula.push_conj(assumption_formula); + formula.push_conj(assumption_body); } pred_args.push(term); formula.push_conj(chc::Atom::new(matcher_pred.into(), pred_args)); - UnboundAssumption { - existentials, - formula, - } + Assumption::new(existentials, formula) } else { - UnboundAssumption::default() + Assumption::default() } } diff --git a/src/rty.rs b/src/rty.rs index 57e32cd..c8e6e9e 100644 --- a/src/rty.rs +++ b/src/rty.rs @@ -1009,10 +1009,7 @@ where } impl Formula { - pub fn new( - existentials: IndexVec, - body: chc::Body, - ) -> Self { + pub fn new(existentials: IndexVec, body: chc::Body) -> Self { Formula { existentials, body } } From 821cab2f275d09cd63220b1314c0c30f39bd109e Mon Sep 17 00:00:00 2001 From: coord_e Date: Tue, 26 Aug 2025 22:54:22 +0900 Subject: [PATCH 04/75] Introduce PlaceTypeBuilder to simplify PlaceType construction --- src/analyze/basic_block.rs | 150 +++++++------- src/refine.rs | 2 +- src/refine/env.rs | 401 ++++++++++++------------------------- 3 files changed, 206 insertions(+), 347 deletions(-) diff --git a/src/analyze/basic_block.rs b/src/analyze/basic_block.rs index d5cefba..4d01eff 100644 --- a/src/analyze/basic_block.rs +++ b/src/analyze/basic_block.rs @@ -13,7 +13,7 @@ use crate::analyze; use crate::chc; use crate::pretty::PrettyDisplayExt as _; use crate::refine::{ - self, Assumption, BasicBlockType, Env, PlaceType, PlaceTypeVar, TempVarIdx, + self, Assumption, BasicBlockType, Env, PlaceType, PlaceTypeBuilder, PlaceTypeVar, TempVarIdx, TemplateTypeGenerator, UnrefinedTypeGenerator, Var, }; use crate::rty::{ @@ -138,12 +138,15 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { Rvalue::Use(operand) => self.operand_type(operand), Rvalue::UnaryOp(op, operand) => { let operand_ty = self.operand_type(operand); - match (&operand_ty.ty, op) { + + let mut builder = PlaceTypeBuilder::default(); + let (operand_ty, operand_term) = builder.subsume(operand_ty); + match (&operand_ty, op) { (rty::Type::Bool, mir::UnOp::Not) => { - operand_ty.replace(|_, term| (rty::Type::Bool, term.not())) + builder.build(rty::Type::Bool, operand_term.not()) } (rty::Type::Int, mir::UnOp::Neg) => { - operand_ty.replace(|_, term| (rty::Type::Int, term.neg())) + builder.build(rty::Type::Int, operand_term.neg()) } _ => unimplemented!("ty={}, op={:?}", operand_ty.display(), op), } @@ -152,51 +155,47 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { let (lhs, rhs) = *operands; let lhs_ty = self.operand_type(lhs); let rhs_ty = self.operand_type(rhs); - match (&lhs_ty.ty, op) { - (rty::Type::Int, mir::BinOp::Add) => lhs_ty - .merge(rhs_ty, |(lhs_ty, lhs_term), (_, rhs_term)| { - (lhs_ty, lhs_term.add(rhs_term)) - }), - (rty::Type::Int, mir::BinOp::Sub) => lhs_ty - .merge(rhs_ty, |(lhs_ty, lhs_term), (_, rhs_term)| { - (lhs_ty, lhs_term.sub(rhs_term)) - }), - (rty::Type::Int, mir::BinOp::Mul) => lhs_ty - .merge(rhs_ty, |(lhs_ty, lhs_term), (_, rhs_term)| { - (lhs_ty, lhs_term.mul(rhs_term)) - }), - (rty::Type::Int | rty::Type::Bool, mir::BinOp::Ge) => lhs_ty - .merge(rhs_ty, |(_, lhs_term), (_, rhs_term)| { - (rty::Type::Bool, lhs_term.ge(rhs_term)) - }), - (rty::Type::Int | rty::Type::Bool, mir::BinOp::Gt) => lhs_ty - .merge(rhs_ty, |(_, lhs_term), (_, rhs_term)| { - (rty::Type::Bool, lhs_term.gt(rhs_term)) - }), - (rty::Type::Int | rty::Type::Bool, mir::BinOp::Le) => lhs_ty - .merge(rhs_ty, |(_, lhs_term), (_, rhs_term)| { - (rty::Type::Bool, lhs_term.le(rhs_term)) - }), - (rty::Type::Int | rty::Type::Bool, mir::BinOp::Lt) => lhs_ty - .merge(rhs_ty, |(_, lhs_term), (_, rhs_term)| { - (rty::Type::Bool, lhs_term.lt(rhs_term)) - }), - (rty::Type::Int | rty::Type::Bool, mir::BinOp::Eq) => lhs_ty - .merge(rhs_ty, |(_, lhs_term), (_, rhs_term)| { - (rty::Type::Bool, lhs_term.eq(rhs_term)) - }), - (rty::Type::Int | rty::Type::Bool, mir::BinOp::Ne) => lhs_ty - .merge(rhs_ty, |(_, lhs_term), (_, rhs_term)| { - (rty::Type::Bool, lhs_term.ne(rhs_term)) - }), + + let mut builder = PlaceTypeBuilder::default(); + let (lhs_ty, lhs_term) = builder.subsume(lhs_ty); + let (_rhs_ty, rhs_term) = builder.subsume(rhs_ty); + match (&lhs_ty, op) { + (rty::Type::Int, mir::BinOp::Add) => { + builder.build(lhs_ty, rhs_term.add(lhs_term)) + } + (rty::Type::Int, mir::BinOp::Sub) => { + builder.build(lhs_ty, lhs_term.sub(rhs_term)) + } + (rty::Type::Int, mir::BinOp::Mul) => { + builder.build(lhs_ty, lhs_term.mul(rhs_term)) + } + (rty::Type::Int | rty::Type::Bool, mir::BinOp::Ge) => { + builder.build(rty::Type::Bool, lhs_term.ge(rhs_term)) + } + (rty::Type::Int | rty::Type::Bool, mir::BinOp::Gt) => { + builder.build(rty::Type::Bool, lhs_term.gt(rhs_term)) + } + (rty::Type::Int | rty::Type::Bool, mir::BinOp::Le) => { + builder.build(rty::Type::Bool, lhs_term.le(rhs_term)) + } + (rty::Type::Int | rty::Type::Bool, mir::BinOp::Lt) => { + builder.build(rty::Type::Bool, lhs_term.lt(rhs_term)) + } + (rty::Type::Int | rty::Type::Bool, mir::BinOp::Eq) => { + builder.build(rty::Type::Bool, lhs_term.eq(rhs_term)) + } + (rty::Type::Int | rty::Type::Bool, mir::BinOp::Ne) => { + builder.build(rty::Type::Bool, lhs_term.ne(rhs_term)) + } _ => unimplemented!("ty={}, op={:?}", lhs_ty.display(), op), } } Rvalue::Ref(_, mir::BorrowKind::Shared, place) => { let ty = self.env.place_type(self.elaborate_place(&place)); - ty.replace(|ty, term| { - (rty::PointerType::immut_to(ty).into(), chc::Term::box_(term)) - }) + + let mut builder = PlaceTypeBuilder::default(); + let (ty, term) = builder.subsume(ty); + builder.build(rty::PointerType::immut_to(ty).into(), chc::Term::box_(term)) } Rvalue::Aggregate(kind, fields) => { // elaboration: all fields are boxed @@ -240,12 +239,16 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { let sort_args: Vec<_> = params.iter().map(|rty| rty.ty.to_sort()).collect(); let ty = rty::EnumType::new(ty_sym.clone(), params).into(); - PlaceType::aggregate( - field_tys, - |_| ty, - |fields_term| { - chc::Term::datatype_ctor(ty_sym, sort_args, v_sym, fields_term) - }, + + let mut builder = PlaceTypeBuilder::default(); + let mut field_terms = Vec::new(); + for field_ty in field_tys { + let (_, field_term) = builder.subsume(field_ty); + field_terms.push(field_term); + } + builder.build( + ty, + chc::Term::datatype_ctor(ty_sym, sort_args, v_sym, field_terms), ) } _ => PlaceType::tuple(field_tys), @@ -276,7 +279,10 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { .expect("discriminant of non-enum") .symbol .clone(); - ty.replace(|_ty, term| (rty::Type::Int, chc::Term::datatype_discr(sym, term))) + + let mut builder = PlaceTypeBuilder::default(); + let (_, term) = builder.subsume(ty); + builder.build(rty::Type::Int, chc::Term::datatype_discr(sym, term)) } _ => unimplemented!( "rvalue={:?} ({:?})", @@ -383,20 +389,23 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { _ => unimplemented!(), }; - self.with_assumption( - discr_ty - .clone() - .into_assumption(|term| term.equal_to(target_term.clone())), - |ecx| { - callback(ecx, bb); - ecx.type_goto(bb, expected_ret); - }, - ); - negations.push( - discr_ty - .clone() - .into_assumption(|term| term.not_equal_to(target_term)), - ); + let pos_assumption = { + let mut builder = PlaceTypeBuilder::default(); + let (_, discr_term) = builder.subsume(discr_ty.clone()); + builder.push_formula(discr_term.equal_to(target_term.clone())); + builder.build_assumption() + }; + self.with_assumption(pos_assumption, |ecx| { + callback(ecx, bb); + ecx.type_goto(bb, expected_ret); + }); + let neg_assumption = { + let mut builder = PlaceTypeBuilder::default(); + let (_, discr_term) = builder.subsume(discr_ty.clone()); + builder.push_formula(discr_term.not_equal_to(target_term.clone())); + builder.build_assumption() + }; + negations.push(neg_assumption); } self.with_assumptions(negations, |ecx| { callback(ecx, targets.otherwise()); @@ -508,11 +517,12 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { let local_ty = self.env.local_type(local); let rvalue_ty = self.rvalue_type(rvalue); if !rvalue_ty.ty.to_sort().is_singleton() { - self.env.assume( - local_ty.merge_into_assumption(rvalue_ty, |local_term, rvalue_term| { - local_term.mut_final().equal_to(rvalue_term) - }), - ); + let mut builder = PlaceTypeBuilder::default(); + let (_, local_term) = builder.subsume(local_ty); + let (_, rvalue_term) = builder.subsume(rvalue_ty); + builder.push_formula(local_term.mut_final().equal_to(rvalue_term)); + let assumption = builder.build_assumption(); + self.env.assume(assumption); } } diff --git a/src/refine.rs b/src/refine.rs index 9c058cf..60f29f4 100644 --- a/src/refine.rs +++ b/src/refine.rs @@ -5,7 +5,7 @@ mod basic_block; pub use basic_block::BasicBlockType; mod env; -pub use env::{Assumption, Env, PlaceType, PlaceTypeVar, TempVarIdx, Var}; +pub use env::{Assumption, Env, PlaceType, PlaceTypeBuilder, PlaceTypeVar, TempVarIdx, Var}; use crate::chc::DatatypeSymbol; use rustc_middle::ty as mir_ty; diff --git a/src/refine/env.rs b/src/refine/env.rs index 26e1ad3..e647765 100644 --- a/src/refine/env.rs +++ b/src/refine/env.rs @@ -222,6 +222,75 @@ impl PlaceTypeVar { } } +/// A builder for `PlaceType` and `Assumption`. +/// +/// We often combine multiple [`PlaceType`]s and [`rty::RefinedType`]s into a single one in order to +/// construct another [`PlaceType`], while retaining formulas from each. [`PlaceTypeBuilder`] helps +/// this by properly managing existential variable indices. +#[derive(Debug, Clone, Default)] +pub struct PlaceTypeBuilder { + existentials: IndexVec, + formula: chc::Body, +} + +impl PlaceTypeBuilder { + pub fn subsume(&mut self, pty: PlaceType) -> (rty::Type, chc::Term) { + let PlaceType { + ty, + existentials, + term, + formula, + } = pty; + self.formula + .push_conj(formula.map_var(|v| v.shift_existential(self.existentials.len()))); + let term = term.map_var(|v| v.shift_existential(self.existentials.len())); + self.existentials.extend(existentials); + (ty, term) + } + + pub fn subsume_rty( + &mut self, + rty: rty::RefinedType, + ) -> (rty::Type, rty::ExistentialVarIdx) { + let rty::RefinedType { ty, refinement } = rty; + let rty::Refinement { existentials, body } = refinement; + let value_var_ex = self.existentials.push(ty.to_sort()); + self.formula.push_conj(body.map_var(|v| match v { + rty::RefinedTypeVar::Value => PlaceTypeVar::Existential(value_var_ex), + rty::RefinedTypeVar::Existential(ev) => { + PlaceTypeVar::Existential(ev + self.existentials.len()) + } + rty::RefinedTypeVar::Free(v) => PlaceTypeVar::Var(v), + })); + self.existentials.extend(existentials); + (ty, value_var_ex) + } + + pub fn push_formula(&mut self, formula: impl Into>) { + self.formula.push_conj(formula); + } + + pub fn push_existential(&mut self, sort: chc::Sort) -> rty::ExistentialVarIdx { + self.existentials.push(sort) + } + + pub fn build(self, ty: rty::Type, term: chc::Term) -> PlaceType { + PlaceType { + ty, + existentials: self.existentials, + term, + formula: self.formula, + } + } + + pub fn build_assumption(self) -> Assumption { + Assumption { + existentials: self.existentials, + body: self.formula, + } + } +} + #[derive(Debug, Clone)] pub struct PlaceType { pub ty: rty::Type, @@ -302,82 +371,26 @@ impl PlaceType { } } - pub fn into_assumption(self, term_to_atom: F) -> Assumption - where - F: FnOnce(chc::Term) -> chc::Atom, - { - let PlaceType { - ty: _, - existentials, - term, - mut formula, - } = self; - formula.push_conj(term_to_atom(term)); - Assumption::new(existentials, formula) - } - pub fn deref(self) -> PlaceType { - let PlaceType { - ty: inner_ty, - mut existentials, - term: inner_term, - mut formula, - } = self; + let mut builder = PlaceTypeBuilder::default(); + let (inner_ty, inner_term) = builder.subsume(self); let inner_ty = inner_ty.into_pointer().unwrap(); - let rty::RefinedType { ty, refinement } = *inner_ty.elem; - let rty::Refinement { - existentials: inner_existentials, - body: inner_formula, - } = refinement; - let value_var_ex = existentials.push(ty.to_sort()); + let (ty, value_var_ex) = builder.subsume_rty(*inner_ty.elem); + let term = chc::Term::var(value_var_ex.into()); - formula.push_conj(term.clone().equal_to(inner_ty.kind.deref_term(inner_term))); - formula.push_conj(inner_formula.map_var(|v| match v { - rty::RefinedTypeVar::Value => PlaceTypeVar::Existential(value_var_ex), - rty::RefinedTypeVar::Existential(ev) => { - PlaceTypeVar::Existential(ev + existentials.len()) - } - rty::RefinedTypeVar::Free(v) => PlaceTypeVar::Var(v), - })); - existentials.extend(inner_existentials); - PlaceType { - ty, - existentials, - term, - formula, - } + builder.push_formula(term.clone().equal_to(inner_ty.kind.deref_term(inner_term))); + builder.build(ty, term) } pub fn tuple_proj(self, idx: usize) -> PlaceType { - let PlaceType { - ty: inner_ty, - mut existentials, - term: inner_term, - mut formula, - } = self; + let mut builder = PlaceTypeBuilder::default(); + let (inner_ty, inner_term) = builder.subsume(self); let inner_ty = inner_ty.into_tuple().unwrap(); - let rty::RefinedType { ty, refinement } = inner_ty.elems[idx].clone(); - let rty::Refinement { - existentials: inner_existentials, - body: inner_formula, - } = refinement; - let value_var_ex = existentials.push(ty.to_sort()); + let (ty, value_var_ex) = builder.subsume_rty(inner_ty.elems[idx].clone()); + let term = chc::Term::var(value_var_ex.into()); - formula.push_conj(term.clone().equal_to(inner_term.tuple_proj(idx))); - formula.push_conj(inner_formula.map_var(|v| match v { - rty::RefinedTypeVar::Value => PlaceTypeVar::Existential(value_var_ex), - rty::RefinedTypeVar::Existential(ev) => { - PlaceTypeVar::Existential(ev + existentials.len()) - } - rty::RefinedTypeVar::Free(v) => PlaceTypeVar::Var(v), - })); - existentials.extend(inner_existentials); - PlaceType { - ty, - existentials, - term, - formula, - } + builder.push_formula(term.clone().equal_to(inner_term.tuple_proj(idx))); + builder.build(ty, term) } pub fn downcast( @@ -386,12 +399,8 @@ impl PlaceType { field_idx: FieldIdx, enum_defs: &HashMap, ) -> PlaceType { - let PlaceType { - ty: inner_ty, - mut existentials, - term: inner_term, - mut formula, - } = self; + let mut builder = PlaceTypeBuilder::default(); + let (inner_ty, inner_term) = builder.subsume(self); let inner_ty = inner_ty.into_enum().unwrap(); let def = &enum_defs[&inner_ty.symbol]; let variant = &def.variants[variant_idx]; @@ -401,23 +410,13 @@ impl PlaceType { for field_ty in variant.field_tys.clone() { let mut rty = rty::RefinedType::unrefined(field_ty.vacuous()); rty.instantiate_ty_params(inner_ty.args.clone()); - let rty::RefinedType { ty, refinement } = rty.boxed(); + let (ty, field_ex_var) = builder.subsume_rty(rty.boxed()); - let field_ex_var = existentials.push(ty.to_sort()); field_terms.push(chc::Term::var(field_ex_var.into())); field_tys.push(ty); - - formula.push_conj(refinement.body.map_var(|v| match v { - rty::RefinedTypeVar::Value => PlaceTypeVar::Existential(field_ex_var), - rty::RefinedTypeVar::Existential(ev) => { - PlaceTypeVar::Existential(ev + existentials.len()) - } - rty::RefinedTypeVar::Free(v) => PlaceTypeVar::Var(v), - })); - existentials.extend(refinement.existentials); } - formula.push_conj( + builder.push_formula( chc::Term::datatype_ctor( def.name.clone(), inner_ty.arg_sorts(), @@ -429,113 +428,15 @@ impl PlaceType { let ty = field_tys[field_idx.index()].clone(); let term = field_terms[field_idx.index()].clone(); - PlaceType { - ty, - existentials, - term, - formula, - } + builder.build(ty, term) } pub fn boxed(self) -> PlaceType { - let PlaceType { - ty: inner_ty, - existentials, - term: inner_term, - formula, - } = self; + let mut builder = PlaceTypeBuilder::default(); + let (inner_ty, inner_term) = builder.subsume(self); let term = chc::Term::box_(inner_term); let ty = rty::PointerType::own(inner_ty).into(); - PlaceType { - ty, - existentials, - term, - formula, - } - } - - pub fn replace(self, f: F) -> PlaceType - where - F: FnOnce( - rty::Type, - chc::Term, - ) -> (rty::Type, chc::Term), - { - let PlaceType { - ty, - existentials, - term, - formula, - } = self; - let (ty, term) = f(ty, term); - PlaceType { - ty, - existentials, - term, - formula, - } - } - - pub fn merge_into_assumption(self, other: PlaceType, f: F) -> Assumption - where - F: FnOnce(chc::Term, chc::Term) -> chc::Atom, - { - let PlaceType { - ty: _ty1, - mut existentials, - term: term1, - mut formula, - } = self; - let PlaceType { - ty: _ty2, - existentials: evs2, - term: term2, - formula: formula2, - } = other; - let atom = f( - term1, - term2.map_var(|v| v.shift_existential(existentials.len())), - ); - formula.push_conj(formula2.map_var(|v| v.shift_existential(existentials.len()))); - formula.push_conj(atom); - existentials.extend(evs2); - Assumption::new(existentials, formula) - } - - pub fn merge(self, other: PlaceType, f: F) -> PlaceType - where - F: FnOnce( - (rty::Type, chc::Term), - (rty::Type, chc::Term), - ) -> (rty::Type, chc::Term), - { - let PlaceType { - ty: ty1, - mut existentials, - term: term1, - mut formula, - } = self; - let PlaceType { - ty: ty2, - existentials: evs2, - term: term2, - formula: formula2, - } = other; - let (ty, term) = f( - (ty1, term1), - ( - ty2, - term2.map_var(|v| v.shift_existential(existentials.len())), - ), - ); - formula.push_conj(formula2.map_var(|v| v.shift_existential(existentials.len()))); - existentials.extend(evs2); - PlaceType { - ty, - existentials, - term, - formula, - } + builder.build(ty, term) } pub fn mut_with_proph_term(self, proph: chc::Term) -> PlaceType { @@ -544,58 +445,29 @@ impl PlaceType { } pub fn mut_with(self, proph: PlaceType) -> PlaceType { - self.merge(proph, |(ty1, term1), (_, term2)| - // TODO: check ty1 = ty2 - (rty::PointerType::mut_to(ty1).into(), chc::Term::mut_(term1, term2))) - } - - pub fn aggregate(ptys: I, make_ty: F, make_term: G) -> PlaceType - where - I: IntoIterator, - F: FnOnce(Vec>) -> rty::Type, - G: FnOnce(Vec>) -> chc::Term, - { - #[derive(Default)] - struct State { - tys: Vec>, - terms: Vec>, - existentials: IndexVec, - formula: chc::Body, - } - let State { - tys, - terms, - existentials, - formula, - } = ptys - .into_iter() - .fold(Default::default(), |mut st: State, ty| { - let PlaceType { - ty, - existentials, - term, - formula, - } = ty; - st.tys.push(ty); - st.terms - .push(term.map_var(|v| v.shift_existential(st.existentials.len()))); - st.formula - .push_conj(formula.map_var(|v| v.shift_existential(st.existentials.len()))); - st.existentials.extend(existentials); - st - }); - let ty = make_ty(tys); - let term = make_term(terms); - PlaceType { - ty, - existentials, - term, - formula, + let mut builder = PlaceTypeBuilder::default(); + let (ty1, term1) = builder.subsume(self); + let (_ty2, term2) = builder.subsume(proph); + // TODO: check ty1 = ty2 + let ty = rty::PointerType::mut_to(ty1).into(); + let term = chc::Term::mut_(term1, term2); + builder.build(ty, term) + } + + pub fn tuple(ptys: Vec) -> PlaceType { + let mut builder = PlaceTypeBuilder::default(); + let mut tys = Vec::new(); + let mut terms = Vec::new(); + + for ty in ptys { + let (ty, term) = builder.subsume(ty); + tys.push(ty); + terms.push(term); } - } - pub fn tuple(tys: Vec) -> PlaceType { - PlaceType::aggregate(tys, |tys| rty::TupleType::new(tys).into(), chc::Term::tuple) + let ty = rty::TupleType::new(tys); + let term = chc::Term::tuple(terms); + builder.build(ty.into(), term) } pub fn enum_( @@ -604,50 +476,27 @@ impl PlaceType { discr: TempVarIdx, field_tys: Vec, ) -> PlaceType { - #[derive(Default)] - struct State { - existentials: IndexVec, - terms: Vec>, - formula: chc::Body, + let mut builder = PlaceTypeBuilder::default(); + let mut terms = Vec::new(); + + for ty in field_tys { + let (_, term) = builder.subsume(ty); + terms.push(term); } - let State { - mut existentials, - terms, - mut formula, - } = field_tys - .into_iter() - .fold(Default::default(), |mut st: State, ty| { - let PlaceType { - ty: _, - existentials, - term, - formula, - } = ty; - st.terms - .push(term.map_var(|v| v.shift_existential(st.existentials.len()))); - st.formula - .push_conj(formula.map_var(|v| v.shift_existential(st.existentials.len()))); - st.existentials.extend(existentials); - st - }); + let ty: rty::Type<_> = enum_ty.clone().into(); - let value_var_ev = existentials.push(ty.to_sort()); + let value_var_ev = builder.push_existential(ty.to_sort()); let term = chc::Term::var(value_var_ev.into()); let mut pred_args = terms; pred_args.push(chc::Term::var(value_var_ev.into())); - formula.push_conj(chc::Atom::new(matcher_pred, pred_args)); - formula.push_conj( + builder.push_formula(chc::Atom::new(matcher_pred, pred_args)); + builder.push_formula( chc::Term::var(discr.into()).equal_to(chc::Term::datatype_discr( enum_ty.symbol.clone(), chc::Term::var(value_var_ev.into()), )), ); - PlaceType { - ty, - existentials, - term, - formula, - } + builder.build(ty, term) } } @@ -1252,10 +1101,10 @@ impl Env { fn dropping_assumption(&mut self, path: &Path) -> Assumption { let ty = self.path_type(path); if ty.ty.is_mut() { - ty.into_assumption(|term| { - let term = term.map_var(Into::into); - term.clone().mut_final().equal_to(term.mut_current()) - }) + let mut builder = PlaceTypeBuilder::default(); + let (_, term) = builder.subsume(ty); + builder.push_formula(term.clone().mut_final().equal_to(term.mut_current())); + builder.build_assumption() } else if ty.ty.is_own() { self.dropping_assumption(&path.clone().deref()) } else if let Some(tty) = ty.ty.as_tuple() { From 24d307c5ae0f034205bceb5518499ffba1e65628 Mon Sep 17 00:00:00 2001 From: coord_e Date: Thu, 28 Aug 2025 22:30:59 +0900 Subject: [PATCH 05/75] Add documentation to src/annot.rs --- src/annot.rs | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/src/annot.rs b/src/annot.rs index 0cd6ca4..d00d3a8 100644 --- a/src/annot.rs +++ b/src/annot.rs @@ -1,3 +1,13 @@ +//! A parser for refinement type and formula annotations. +//! +//! This module provides a parser for `#[thrust::...]` attributes in Thrust. The parser is +//! parameterized by the [`Resolver`] trait, which abstracts over the resolution of variable +//! names. This allows the parser to be used in different contexts with different sets of +//! available variables. +//! +//! The main entry point is [`AnnotParser`], which parses a [`TokenStream`] into a +//! [`rty::RefinedType`] or a [`chc::Formula`]. + use rustc_ast::token::{BinOpToken, Delimiter, LitKind, Token, TokenKind}; use rustc_ast::tokenstream::{RefTokenTreeCursor, Spacing, TokenStream, TokenTree}; use rustc_index::IndexVec; @@ -7,6 +17,9 @@ use crate::chc; use crate::pretty::PrettyDisplayExt as _; use crate::rty; +/// A formula in an annotation. +/// +/// This can be either a logical formula or the special value `auto` which tells Thrust to infer it. #[derive(Debug, Clone)] pub enum AnnotFormula { Auto, @@ -19,6 +32,7 @@ impl AnnotFormula { } } +/// A trait for resolving variables in annotations to their logical representation and their sorts. pub trait Resolver { type Output; fn resolve(&self, ident: Ident) -> Option<(Self::Output, chc::Sort)>; @@ -56,6 +70,7 @@ pub trait ResolverExt: Resolver { impl ResolverExt for T where T: Resolver {} +/// An error that can occur when parsing an attribute. #[derive(Debug, Clone, thiserror::Error)] pub enum ParseAttrError { #[error("unexpected end of input (expected {expected:?})")] @@ -96,6 +111,7 @@ impl ParseAttrError { type Result = std::result::Result; +/// A parser for refinement type annotations and formula annotations. struct Parser<'a, T> { resolver: T, cursor: RefTokenTreeCursor<'a>, @@ -610,6 +626,7 @@ where } } +/// A [`Resolver`] implementation for resolving specific variable as [`rty::RefinedTypeVar::Value`]. struct RefinementResolver<'a, T> { resolver: Box + 'a>, self_: Option<(Ident, chc::Sort)>, @@ -643,6 +660,7 @@ impl<'a, T> RefinementResolver<'a, T> { } } +/// A [`Resolver`] that maps the output of another [`Resolver`]. pub struct MappedResolver<'a, T, F> { resolver: Box + 'a>, map: F, @@ -669,6 +687,9 @@ impl<'a, T, F> MappedResolver<'a, T, F> { } } +/// A [`Resolver`] that stacks multiple [`Resolver`]s. +/// +/// This resolver tries to resolve an identifier by querying a list of resolvers in order. pub struct StackedResolver<'a, T> { resolvers: Vec + 'a>>, } @@ -698,6 +719,7 @@ impl<'a, T> StackedResolver<'a, T> { } } +/// A parser for annotations. #[derive(Debug, Clone)] pub struct AnnotParser { resolver: T, From ac1e098f1af2435eebe01e8d1e023d4a3c01e511 Mon Sep 17 00:00:00 2001 From: coord_e Date: Fri, 29 Aug 2025 01:15:00 +0900 Subject: [PATCH 06/75] Add documentation to src/chc.rs --- src/chc.rs | 77 ++++++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 75 insertions(+), 2 deletions(-) diff --git a/src/chc.rs b/src/chc.rs index f3eb058..8a3309f 100644 --- a/src/chc.rs +++ b/src/chc.rs @@ -1,4 +1,5 @@ -/// Multi-sorted CHC system with tuples. +//! A multi-sorted CHC system with tuples. + use pretty::{termcolor, Pretty}; use rustc_index::IndexVec; @@ -17,6 +18,7 @@ pub use debug::DebugInfo; pub use solver::{CheckSatError, Config}; pub use unbox::unbox; +/// A name of a datatype. #[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] pub struct DatatypeSymbol { inner: String, @@ -43,6 +45,10 @@ impl DatatypeSymbol { } } +/// A datatype sort. +/// +/// A datatype sort is a sort that is defined by a datatype. It has a symbol and a list of +/// argument sorts. #[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] pub struct DatatypeSort { symbol: DatatypeSymbol, @@ -74,6 +80,7 @@ impl DatatypeSort { } } +/// A sort is the type of a logical term. #[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] pub enum Sort { Null, @@ -271,6 +278,10 @@ impl Sort { } rustc_index::newtype_index! { + /// An index representing term-level variable. + /// + /// We manage term-level variables using indices that are unique in each clause. + /// [`Clause`] contains `IndexVec` that manages the indices and the sorts of the variables. #[orderable] #[debug_format = "v{}"] pub struct TermVarIdx { } @@ -297,6 +308,7 @@ impl TermVarIdx { } } +/// A known function symbol. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub struct Function { name: &'static str, @@ -376,6 +388,7 @@ impl Function { pub const NEG: Function = Function::new("-"); } +/// A logical term. #[derive(Debug, Clone)] pub enum Term { Null, @@ -693,6 +706,11 @@ impl Term { } rustc_index::newtype_index! { + /// An index representing predicate variables. + /// + /// We manage predicate variables using indices that are unique in a CHC system. + /// [`System`] contains `IndexVec` that manages the indices + /// and signatures of the predicate variables. #[debug_format = "p{}"] pub struct PredVarId { } } @@ -720,6 +738,9 @@ impl PredVarId { } } +/// A known predicate. +/// +/// A known predicate is a predicate that has a fixed meaning, such as `true`, `false`, `=`, etc. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub struct KnownPred { name: &'static str, @@ -804,6 +825,25 @@ impl KnownPred { pub const GREATER_THAN: KnownPred = KnownPred::infix(">"); } +/// A matcher predicate. +/// +/// A matcher predicate is a predicate that relates a value of a datatype with its contents. +/// For example, given the following `enum` datatype: +/// +/// ```rust +/// pub enum X { +/// A(i64), +/// B(bool), +/// } +/// ``` +/// +/// The corresponding matcher predicate is defined as: +/// +/// ```smtlib2 +/// (define-fun matcher_pred ((x0 Int) (x1 Bool) (v X)) Bool (or (= v (X.A x0)) (= v (X.B x1)))) +/// ``` +/// +/// See the implementation in the [`smtlib2`] module for the details. #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct MatcherPred { datatype_symbol: DatatypeSymbol, @@ -851,6 +891,7 @@ impl MatcherPred { } } +/// A predicate. #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum Pred { Known(KnownPred), @@ -942,6 +983,7 @@ impl Pred { } } +/// An atom is a predicate applied to a list of terms. #[derive(Debug, Clone)] pub struct Atom { pub pred: Pred, @@ -1030,6 +1072,11 @@ impl Atom { } } +/// An arbitrary formula built with atoms and logical connectives. +/// +/// While it allows arbitrary [`Atom`] in its `Atom` variant, we only expect atoms with known +/// predicates (i.e., predicates other than `Pred::Var`) to appear in formulas. It is our TODO to +/// enforce this restriction statically. Also see the definition of [`Body`]. #[derive(Debug, Clone)] pub enum Formula { Atom(Atom), @@ -1248,10 +1295,11 @@ impl Formula { } } +/// The body part of a clause, consisting of atoms and a formula. #[derive(Debug, Clone)] pub struct Body { pub atoms: Vec>, - // `formula` doesn't contain PredVar + /// NOTE: This doesn't contain predicate variables. Also see [`Formula`]. pub formula: Formula, } @@ -1391,6 +1439,10 @@ where } } +/// A constrained Horn clause. +/// +/// A constrained Horn clause is a formula of the form `∀vars. body ⇒ head` where `body` is a conjunction of +/// atoms and underlying logical formula, and `head` is an atom. #[derive(Debug, Clone)] pub struct Clause { pub vars: IndexVec, @@ -1441,12 +1493,18 @@ impl Clause { } } +/// A selector for a datatype constructor. +/// +/// A selector is a function that extracts a field from a datatype value. +/// Through currently we don't use selectors to access datatype fields in [`Term`]s, +/// we require the symbol as selector name to emit SMT-LIB2 representation of datatypes. #[derive(Debug, Clone)] pub struct DatatypeSelector { pub symbol: DatatypeSymbol, pub sort: Sort, } +/// A datatype constructor. #[derive(Debug, Clone)] pub struct DatatypeCtor { pub symbol: DatatypeSymbol, @@ -1454,6 +1512,7 @@ pub struct DatatypeCtor { pub discriminant: u32, } +/// A datatype definition. #[derive(Debug, Clone)] pub struct Datatype { pub symbol: DatatypeSymbol, @@ -1462,18 +1521,25 @@ pub struct Datatype { } rustc_index::newtype_index! { + /// An index of [`Clause`]. + /// + /// [`System`] contains `IndexVec` that manages the indices and the clauses. #[debug_format = "c{}"] pub struct ClauseId { } } pub type PredSig = Vec; +/// A predicate variable definition. #[derive(Debug, Clone)] pub struct PredVarDef { pub sig: PredSig, + /// We may attach a debug information to include in the resulting SMT-LIB2 file to indicate the + /// origin of predicate variables. pub debug_info: DebugInfo, } +/// A CHC system. #[derive(Debug, Clone, Default)] pub struct System { pub datatypes: Vec, @@ -1498,6 +1564,13 @@ impl System { smtlib2::System::new(self) } + /// Solves the CHC using an external SMT solver. + /// + /// This method first performs some optimization of the CHC, + /// then formats it to SMT-LIB2, and finally calls the configured CHC solver. + /// The solver and its arguments can be configured using environment + /// variables + /// (see ). pub fn solve(&self) -> Result<(), CheckSatError> { let system = unbox(self.clone()); if let Ok(file) = std::env::var("THRUST_PRETTY_OUTPUT") { From 7c10c7b8387d59c40d6d43fc97deb342c6cade26 Mon Sep 17 00:00:00 2001 From: coord_e Date: Sat, 30 Aug 2025 12:16:19 +0900 Subject: [PATCH 07/75] Add documentation to src/chc/*.rs --- src/chc/clause_builder.rs | 20 ++++++++++++++++++++ src/chc/debug.rs | 8 ++++++++ src/chc/format_context.rs | 15 +++++++++++++++ src/chc/hoice.rs | 7 ++++++- src/chc/smtlib2.rs | 22 ++++++++++++++++++++++ src/chc/solver.rs | 15 +++++++++++++++ src/chc/unbox.rs | 7 +++++++ 7 files changed, 93 insertions(+), 1 deletion(-) diff --git a/src/chc/clause_builder.rs b/src/chc/clause_builder.rs index 0aec55b..1eb78ca 100644 --- a/src/chc/clause_builder.rs +++ b/src/chc/clause_builder.rs @@ -1,3 +1,11 @@ +//! A builder for [`Clause`]s. +//! +//! This module provides [`ClauseBuilder`], a helper for constructing [`Clause`]s. It is +//! particularly useful for managing the universally quantified variables of a clause. It can +//! automatically create fresh [`TermVarIdx`] for variables from other domains (e.g., +//! [`crate::rty::FunctionParamIdx`]), simplifying the generation of clauses from higher-level +//! representations. + use std::any::{Any, TypeId}; use std::collections::HashMap; use std::fmt::Debug; @@ -8,6 +16,7 @@ use rustc_index::IndexVec; use super::{Atom, Body, Clause, DebugInfo, Sort, TermVarIdx}; +/// A convenience trait to represent constraints on variables used in [`ClauseBuilder`] at once. pub trait Var: Eq + Ord + Hash + Copy + Debug + 'static {} impl Var for T {} @@ -54,6 +63,17 @@ impl Hash for dyn Key { } } +/// A builder for a [`Clause`]. +/// +/// [`Clause`] contains a list of universally quantified variables, a head atom, and a body formula. +/// When building the head and body, we usually have some formulas that represents variables using +/// something other than [`TermVarIdx`] (e.g. [`crate::rty::FunctionParamIdx`] or [`crate::refine::Var`]). +/// These variables are usually OK to be universally quantified in the clause, so we want to keep +/// the mapping of them to [`TermVarIdx`] and use it to convert the variables in the formulas +/// during the construction of the clause. +/// +/// Also see [`crate::rty::ClauseBuilderExt`], which provides a higher-level API on top of this +/// to build clauses from [`crate::rty::Refinement`]s. #[derive(Clone, Default)] pub struct ClauseBuilder { vars: IndexVec, diff --git a/src/chc/debug.rs b/src/chc/debug.rs index 0c0816b..e200c09 100644 --- a/src/chc/debug.rs +++ b/src/chc/debug.rs @@ -1,3 +1,10 @@ +//! Attachable debug information for CHC clauses. +//! +//! The [`DebugInfo`] struct captures contextual information (like `tracing` spans) at the time +//! of a clause's creation. This information is then pretty-printed as comments in the +//! generated SMT-LIB2 file, which helps in tracing a clause back to its origin in the +//! Thrust codebase. + #[derive(Debug, Clone)] pub struct Display<'a> { inner: &'a DebugInfo, @@ -19,6 +26,7 @@ impl<'a> std::fmt::Display for Display<'a> { } } +/// A purely informational metadata that can be attached to a clause. #[derive(Debug, Clone, Default)] pub struct DebugInfo { contexts: Vec<(String, String)>, diff --git a/src/chc/format_context.rs b/src/chc/format_context.rs index 97c3a04..ed82811 100644 --- a/src/chc/format_context.rs +++ b/src/chc/format_context.rs @@ -1,7 +1,22 @@ +//! A context for formatting a CHC system into SMT-LIB2. +//! +//! This module provides [`FormatContext`], which is responsible for translating parts of [`chc::System`] +//! into a representation that is compatible with SMT solvers. It handles tasks like +//! monomorphization of polymorphic datatypes and applying solver-specific workarounds. +//! The [`super::smtlib2`] module uses this context to perform the final rendering to the SMT-LIB2 format. + use std::collections::BTreeSet; use crate::chc::{self, hoice::HoiceDatatypeRenamer}; +/// A context for formatting a CHC system. +/// +/// This subsumes a representational difference between [`chc::System`] and resulting SMT-LIB2. +/// - Gives a naming convention of symbols to represent built-in datatypes of [`chc::System`] in SMT-LIB2, +/// - Gives a stringified representation of [`chc::Sort`]s, +/// - Monomorphizes polymorphic datatypes of [`chc::System`] to be compatible with several CHC solvers, +/// - Renames datatypes to be compatible with Hoice (see [`HoiceDatatypeRenamer`]), +/// - etc. #[derive(Debug, Clone)] pub struct FormatContext { renamer: HoiceDatatypeRenamer, diff --git a/src/chc/hoice.rs b/src/chc/hoice.rs index 86b64e7..c7c1cd8 100644 --- a/src/chc/hoice.rs +++ b/src/chc/hoice.rs @@ -1,4 +1,5 @@ -/// hopv/hoice#73 +//! A workaround for a bug in the Hoice CHC solver. + use std::collections::{HashMap, HashSet}; use crate::chc; @@ -53,6 +54,10 @@ impl<'a> SortDatatypes<'a> { } } +/// Rename to ensure the referring datatype name is lexicographically larger than the referred one. +/// +/// Workaround for . Applied indirectly via +/// [`crate::chc::format_context::FormatContext`] when formatting [`crate::chc::System`] as SMT-LIB2. #[derive(Debug, Clone, Default)] pub struct HoiceDatatypeRenamer { prefixes: HashMap, diff --git a/src/chc/smtlib2.rs b/src/chc/smtlib2.rs index 6e9ebfa..c708770 100644 --- a/src/chc/smtlib2.rs +++ b/src/chc/smtlib2.rs @@ -1,5 +1,14 @@ +//! Wrappers around CHC structures to display them in SMT-LIB2 format. +//! +//! The main entry point is the [`System`] wrapper, which takes a [`chc::System`] and provides a +//! [`std::fmt::Display`] implementation that produces a complete SMT-LIB2. +//! It uses [`FormatContext`] to handle the complexities of the conversion, +//! such as naming convention and solver-specific workarounds. +//! The output of this module is what gets passed to the external CHC solver. + use crate::chc::{self, format_context::FormatContext}; +/// A helper struct to display a list of items. #[derive(Debug, Clone)] struct List { open: Option<&'static str>, @@ -79,6 +88,7 @@ impl List { } } +/// A wrapper around a [`chc::Term`] that provides a [`std::fmt::Display`] implementation in SMT-LIB2 format. #[derive(Debug, Clone)] struct Term<'ctx, 'a> { ctx: &'ctx FormatContext, @@ -202,6 +212,7 @@ impl<'ctx, 'a> Term<'ctx, 'a> { } } +/// A wrapper around a [`chc::Atom`] that provides a [`std::fmt::Display`] implementation in SMT-LIB2 format. #[derive(Debug, Clone)] pub struct Atom<'ctx, 'a> { ctx: &'ctx FormatContext, @@ -242,6 +253,7 @@ impl<'ctx, 'a> Atom<'ctx, 'a> { } } +/// A wrapper around a [`chc::Formula`] that provides a [`std::fmt::Display`] implementation in SMT-LIB2 format. #[derive(Debug, Clone)] pub struct Formula<'ctx, 'a> { ctx: &'ctx FormatContext, @@ -278,6 +290,7 @@ impl<'ctx, 'a> Formula<'ctx, 'a> { } } +/// A wrapper around a [`chc::Body`] that provides a [`std::fmt::Display`] implementation in SMT-LIB2 format. #[derive(Debug, Clone)] pub struct Body<'ctx, 'a> { ctx: &'ctx FormatContext, @@ -304,6 +317,7 @@ impl<'ctx, 'a> Body<'ctx, 'a> { } } +/// A wrapper around a [`chc::Clause`] that provides a [`std::fmt::Display`] implementation in SMT-LIB2 format. #[derive(Debug, Clone)] pub struct Clause<'ctx, 'a> { ctx: &'ctx FormatContext, @@ -340,6 +354,7 @@ impl<'ctx, 'a> Clause<'ctx, 'a> { } } +/// A wrapper around a [`chc::DatatypeSelector`] that provides a [`std::fmt::Display`] implementation in SMT-LIB2 format. #[derive(Debug, Clone)] pub struct DatatypeSelector<'ctx, 'a> { ctx: &'ctx FormatContext, @@ -363,6 +378,7 @@ impl<'ctx, 'a> DatatypeSelector<'ctx, 'a> { } } +/// A wrapper around a [`chc::DatatypeCtor`] that provides a [`std::fmt::Display`] implementation in SMT-LIB2 format. #[derive(Debug, Clone)] pub struct DatatypeCtor<'ctx, 'a> { ctx: &'ctx FormatContext, @@ -386,6 +402,7 @@ impl<'ctx, 'a> DatatypeCtor<'ctx, 'a> { } } +/// A wrapper around a slice of [`chc::Datatype`] that provides a [`std::fmt::Display`] implementation in SMT-LIB2 format. #[derive(Debug, Clone)] pub struct Datatypes<'ctx, 'a> { ctx: &'ctx FormatContext, @@ -423,6 +440,8 @@ impl<'ctx, 'a> Datatypes<'ctx, 'a> { } } +/// A wrapper around a [`chc::Datatype`] that provides a [`std::fmt::Display`] implementation for the +/// discriminant function in SMT-LIB2 format. #[derive(Debug, Clone)] pub struct DatatypeDiscrFun<'ctx, 'a> { ctx: &'ctx FormatContext, @@ -458,6 +477,8 @@ impl<'ctx, 'a> DatatypeDiscrFun<'ctx, 'a> { } } +/// A wrapper around a [`chc::Datatype`] that provides a [`std::fmt::Display`] implementation for the +/// matcher predicate in SMT-LIB2 format. #[derive(Debug, Clone)] pub struct MatcherPredFun<'ctx, 'a> { ctx: &'ctx FormatContext, @@ -507,6 +528,7 @@ impl<'ctx, 'a> MatcherPredFun<'ctx, 'a> { } } +/// A wrapper around a [`chc::System`] that provides a [`std::fmt::Display`] implementation in SMT-LIB2 format. #[derive(Debug, Clone)] pub struct System<'a> { ctx: FormatContext, diff --git a/src/chc/solver.rs b/src/chc/solver.rs index 4d8c4ba..5792d75 100644 --- a/src/chc/solver.rs +++ b/src/chc/solver.rs @@ -1,3 +1,10 @@ +//! A generic interface for running external command-line CHC solvers. +//! +//! This module provides the [`Config`] struct for configuring and running an external +//! CHC solver. It supports setting the solver command, arguments, and timeout, and can +//! be configured through environment variables. + +/// An error that can occur when solving a [`crate::chc::System`]. #[derive(Debug, thiserror::Error)] pub enum CheckSatError { #[error("unsat")] @@ -12,6 +19,7 @@ pub enum CheckSatError { Io(#[from] std::io::Error), } +/// A configuration for running a command-line CHC solver. #[derive(Debug, Clone)] pub struct CommandConfig { pub name: String, @@ -84,6 +92,13 @@ impl CommandConfig { } } +/// A configuration for solving a [`crate::chc::System`]. +/// +/// This struct holds the configuration for the solver, including the solver command, its +/// arguments, and a timeout. It can also be configured to run a preprocessor on the SMT-LIB2 +/// file before passing it to the solver. +/// +/// The configuration can be loaded from environment variables using [`Config::from_env`]. #[derive(Debug, Clone)] pub struct Config { pub solver: CommandConfig, diff --git a/src/chc/unbox.rs b/src/chc/unbox.rs index 441c95b..532f623 100644 --- a/src/chc/unbox.rs +++ b/src/chc/unbox.rs @@ -1,3 +1,5 @@ +//! An optimization that removes `Box` sorts and terms from a CHC system. + use super::*; fn unbox_term(term: Term) -> Term { @@ -118,6 +120,11 @@ fn unbox_datatype(datatype: Datatype) -> Datatype { } } +/// Remove all `Box` sorts and `Box`/`BoxCurrent` terms from the system. +/// +/// The box values in Thrust represent an owned pointer, but are logically equivalent to the inner type. +/// This pass removes them to reduce the complexity of the CHCs sent to the solver. +/// This function traverses a [`System`] and removes all `Box` related constructs. pub fn unbox(system: System) -> System { let System { clauses, From 3bcd10c6ef592a371b32be2832a5e854423d9ca2 Mon Sep 17 00:00:00 2001 From: coord_e Date: Sat, 30 Aug 2025 13:21:50 +0900 Subject: [PATCH 08/75] Add documentation to src/pretty.rs --- src/pretty.rs | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/pretty.rs b/src/pretty.rs index 31d1f1f..bb347d8 100644 --- a/src/pretty.rs +++ b/src/pretty.rs @@ -1,7 +1,18 @@ +//! A set of utilities for pretty-printing various data structures. +//! +//! It uses the [`pretty`] crate to provide a flexible and configurable way to format complex +//! data structures for display. The main entry point is the [`PrettyDisplayExt`] trait, +//! which provides a [`PrettyDisplayExt::display`] method that returns a [`Display`] object to +//! turn [`Pretty`] values into [`std::fmt::Display`] that can be used with standard formatting macros. +//! +//! This is primarily used for debugging and logging purposes, to make the output of the tool +//! more readable. + use rustc_index::{IndexSlice, IndexVec}; use pretty::{termcolor, BuildDoc, DocAllocator, DocPtr, Pretty}; +/// A wrapper around a [`crate::rty::FunctionType`] that provides a [`Pretty`] implementation. pub struct FunctionType<'a, FV> { pub params: &'a rustc_index::IndexVec>, @@ -36,6 +47,7 @@ impl<'a, FV> FunctionType<'a, FV> { } } +/// A wrapper around a slice that provides a [`Pretty`] implementation. #[derive(Debug, Clone)] pub struct PrettySlice<'a, T> { slice: &'a [T], @@ -95,6 +107,7 @@ impl PrettySliceExt for IndexVec { } } +/// A wrapper around a type that provides a [`std::fmt::Display`] implementation via [`Pretty`]. #[derive(Debug, Clone)] pub struct Display<'a, T> { value: &'a T, From 73226a9e4eea49c564c70b1724b1cd92261acfa6 Mon Sep 17 00:00:00 2001 From: coord_e Date: Sat, 30 Aug 2025 14:03:01 +0900 Subject: [PATCH 09/75] Fix term order in Add --- src/analyze/basic_block.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/analyze/basic_block.rs b/src/analyze/basic_block.rs index 4d01eff..cbb6589 100644 --- a/src/analyze/basic_block.rs +++ b/src/analyze/basic_block.rs @@ -161,7 +161,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { let (_rhs_ty, rhs_term) = builder.subsume(rhs_ty); match (&lhs_ty, op) { (rty::Type::Int, mir::BinOp::Add) => { - builder.build(lhs_ty, rhs_term.add(lhs_term)) + builder.build(lhs_ty, lhs_term.add(rhs_term)) } (rty::Type::Int, mir::BinOp::Sub) => { builder.build(lhs_ty, lhs_term.sub(rhs_term)) From c42766f824698e6523d842b3fbd2b3774b9dd51d Mon Sep 17 00:00:00 2001 From: coord_e Date: Sun, 7 Sep 2025 21:54:51 +0900 Subject: [PATCH 10/75] Add documentation to src/rty.rs and src/rty/*.rs --- src/rty.rs | 96 +++++++++++++++++++++++++++++++++++++++ src/rty/clause_builder.rs | 16 +++++++ src/rty/params.rs | 4 ++ src/rty/subtyping.rs | 10 ++++ src/rty/template.rs | 18 ++++++++ 5 files changed, 144 insertions(+) diff --git a/src/rty.rs b/src/rty.rs index c8e6e9e..72fede8 100644 --- a/src/rty.rs +++ b/src/rty.rs @@ -1,3 +1,42 @@ +//! Data structures for refinement types. +//! +//! Thrust is a refinement type system, and so its analysis is about giving refinement types to +//! variables and terms. This module defines the data structures for representing refinement +//! types. The central data structure is [`RefinedType`], which is just [`Type`] plus [`Refinement`]. +//! +//! Note that there are two notions of "variables" in this module. First is a type variable which is +//! a polymorphic type parameter which is represented by [`TypeParamIdx`], and forms one of valid type as [`ParamType`]. +//! `T` in `struct S { f: T }` is a type variable. See [`params`] module for the handling of type parameters. +//! Second is a logical variable which is a variable that can appear in logical predicates. +//! `x` and `y` in `{ x: int | x > y }` are logical variables. +//! The actual representation of logical variables varies by context and so it is often parameterized +//! as a type parameter in this module (`T` in `Type`). In [`FunctionType`], logical variables are +//! function parameters (`Type`, see [`crate::refine::Var`]). +//! We have [`RefinedTypeVar`] to denote logical variables in refinement predicates, which +//! accepts existential variables and a variable bound in the refinement type (`x` in `{ x: T | phi +//! }`) in addition to variables `T` from the outer scope. This module also contains [`Closed`] which is +//! used to denote a closed type, so `Type` ensures no logical variables from the outer scope +//! appear in that type. +//! +//! We have distinct types for each variant of [`Type`], such as [`FunctionType`] and +//! [`PointerType`]. [`Type`], [`RefinedType`] and these variant types share some common operations: +//! +//! - `subst_var`: Substitutes logical variables with logical terms. +//! - `map_var`: Maps logical variables to other logical variables. +//! - `free_ty_params`: Collects free type parameters [`TypeParamIdx`] in the type. +//! - `subst_ty_params`: Substitutes type parameters with other types. Since this replaces +//! type parameters with refinement types, [`Type`] does not implement this, and +//! [`RefinedType::subst_ty_params`] handles the substitution logic instead. +//! - `strip_refinement`: Strips the refinements recursively and returns a [`Closed`] type. +//! +//! This module also implements several logics for manipulating refinement types and is extensively used in +//! upstream logic in the [`crate::refine`] and [`crate::analyze`] modules. +//! +//! - [`template`]: Generates "template" refinement types with unknown predicates to be inferred. +//! - [`subtyping`]: Generates CHC constraints [`crate::chc`] from subtyping relations between types. +//! - [`clause_builder`]: Helper to build [`crate::chc::Clause`] from the refinement types. + use std::collections::{HashMap, HashSet}; use pretty::{termcolor, Pretty}; @@ -19,6 +58,11 @@ mod params; pub use params::{TypeParamIdx, TypeParamSubst, TypeParams}; rustc_index::newtype_index! { + /// An index representing function parameter. + /// + /// We manage function parameters using indices that are unique in a function. + /// [`FunctionType`] contains `IndexVec>` + /// that manages the indices and types of the function parameters. #[orderable] #[debug_format = "${}"] pub struct FunctionParamIdx { } @@ -39,6 +83,11 @@ where } } +/// A function type. +/// +/// In Thrust, function types are closed. Because of that, function types, thus its parameters and +/// return type only refer to the parameters of the function itself using [`FunctionParamIdx`] and +/// do not accept other type of variables from the environment. #[derive(Debug, Clone)] pub struct FunctionType { pub params: IndexVec>, @@ -74,6 +123,8 @@ impl FunctionType { } } + /// Because function types are always closed in Thrust, we can convert this into + /// [`Type`]. pub fn into_closed_ty(self) -> Type { Type::Function(self) } @@ -104,6 +155,7 @@ impl FunctionType { } } +/// The kind of a reference, which is either mutable or immutable. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum RefKind { Mut, @@ -128,6 +180,7 @@ where } } +/// The kind of a pointer, which is either a reference or an owned pointer. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum PointerKind { Ref(RefKind), @@ -167,6 +220,7 @@ impl PointerKind { } } +/// A pointer type. #[derive(Debug, Clone)] pub struct PointerType { pub kind: PointerKind, @@ -275,6 +329,11 @@ impl PointerType { } } +/// A tuple type. +/// +/// Note that the current implementation uses tuples to represent structs. See +/// implementation in `crate::refine::template` module for details. +/// It is our TODO to improve the struct representation. #[derive(Debug, Clone)] pub struct TupleType { pub elems: Vec>, @@ -374,6 +433,7 @@ impl TupleType { } } +/// A definition of an enum variant, found in [`EnumDatatypeDef`]. #[derive(Debug, Clone)] pub struct EnumVariantDef { pub name: chc::DatatypeSymbol, @@ -381,6 +441,7 @@ pub struct EnumVariantDef { pub field_tys: Vec>, } +/// A definition of an enum datatype. #[derive(Debug, Clone)] pub struct EnumDatatypeDef { pub name: chc::DatatypeSymbol, @@ -394,6 +455,9 @@ impl EnumDatatypeDef { } } +/// An enum type. +/// +/// An enum type includes its type arguments and the argument types can refer to outer variables `T`. #[derive(Debug, Clone)] pub struct EnumType { pub symbol: chc::DatatypeSymbol, @@ -495,6 +559,7 @@ impl EnumType { } } +/// A type parameter. #[derive(Debug, Clone)] pub struct ParamType { pub idx: TypeParamIdx, @@ -523,6 +588,7 @@ impl ParamType { } } +/// An underlying type of a refinement type. #[derive(Debug, Clone)] pub enum Type { Int, @@ -803,6 +869,10 @@ impl Type { } } +/// A marker type for closed types. +/// +/// Because the value of `Closed` can never exist, `Type` ensures that no +/// logical variables from outer scope appear in the type. #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] pub enum Closed {} @@ -822,6 +892,11 @@ where } rustc_index::newtype_index! { + /// An index representing existential variables. + /// + /// We manage existential variables using indices that are unique in a [`Formula`]. + /// [`Formula`] contains `IndexVec` that manages the indices + /// and sorts of the existential variables. #[orderable] #[debug_format = "e{}"] pub struct ExistentialVarIdx { } @@ -846,6 +921,15 @@ pub trait ShiftExistential { fn shift_existential(self, offset: usize) -> Self; } +/// A logical variable in a refinement predicate. +/// +/// Given a refinement type `{ v: T | ∃e. φ }`, there are three cases of logical variables +/// occurring in the predicate `φ`: +/// +/// - `RefinedTypeVar::Value`: a variable `v` representing the value of the type +/// - `RefinedTypeVar::Existential`: an existential variable `e` +/// - `RefinedTypeVar::Free`: a variable from the outer scope, such as function parameters or +/// variables bound in the environment #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] pub enum RefinedTypeVar { Value, @@ -908,6 +992,9 @@ impl ShiftExistential for RefinedTypeVar { } } +/// A formula, potentially equipped with an existential quantifier. +/// +/// Note: This is not to be confused with [`crate::chc::Formula`] in the [`crate::chc`] module, which is a different notion. #[derive(Debug, Clone)] pub struct Formula { pub existentials: IndexVec, @@ -1051,6 +1138,10 @@ where } } +/// A refinement predicate in a refinement type. +/// +/// This is a [`Formula`], but gives an explicit representation of the kinds of variables that can appear in +/// the refinement predicates. See [`RefinedTypeVar`] for details. pub type Refinement = Formula>; impl Refinement { @@ -1105,6 +1196,10 @@ impl Refinement { } } +/// A helper type to map logical variables in a refinement at once. +/// +/// This is essentially just calling `Refinement::map_var`, but provides a convenient interface to +/// construct the mapping of the variables. #[derive(Debug, Clone)] pub struct Instantiator { value_var: Option, @@ -1140,6 +1235,7 @@ impl Instantiator { } } +/// A refinement type. #[derive(Debug, Clone)] pub struct RefinedType { pub ty: Type, diff --git a/src/rty/clause_builder.rs b/src/rty/clause_builder.rs index f45e948..ce3c54f 100644 --- a/src/rty/clause_builder.rs +++ b/src/rty/clause_builder.rs @@ -1,7 +1,19 @@ +//! Helpers to build [`crate::chc::Clause`] from the refinement types. +//! +//! This module provides an extension trait named [`ClauseBuilderExt`] for [`chc::ClauseBuilder`] +//! that allows it to work with refinement types. It provides a convenient way to generate CHC clauses from +//! [`Refinement`]s by handling the mapping of [`super::RefinedTypeVar`] to [`chc::TermVarIdx`]. +//! This is primarily used to generate clauses from [`super::subtyping`] constraints between refinement types. + use crate::chc; use super::{Refinement, Type}; +/// An extension trait for [`chc::ClauseBuilder`] to work with refinement types. +/// +/// We implement a custom logic to deal with [`Refinement`]s in [`RefinementClauseBuilder`], +/// and this extension trait provides methods to create [`RefinementClauseBuilder`]s while +/// specifying how to handle [`super::RefinedTypeVar::Value`] during the instantiation. pub trait ClauseBuilderExt { fn with_value_var<'a, T>(&'a mut self, ty: &Type) -> RefinementClauseBuilder<'a>; fn with_mapped_value_var(&mut self, v: T) -> RefinementClauseBuilder<'_> @@ -31,6 +43,10 @@ impl ClauseBuilderExt for chc::ClauseBuilder { } } +/// A builder for a CHC clause with a refinement. +/// +/// You can supply [`Refinement`]s as the body and head of the clause directly, and this builder +/// will take care of mapping the variables appropriately. pub struct RefinementClauseBuilder<'a> { builder: &'a mut chc::ClauseBuilder, value_var: Option, diff --git a/src/rty/params.rs b/src/rty/params.rs index 26bd289..17ebc2b 100644 --- a/src/rty/params.rs +++ b/src/rty/params.rs @@ -1,3 +1,5 @@ +//! Data structures for type parameters and substitutions. + use std::collections::BTreeMap; use pretty::{termcolor, Pretty}; @@ -8,6 +10,7 @@ use crate::chc; use super::{Closed, RefinedType}; rustc_index::newtype_index! { + /// An index representing a type parameter. #[orderable] #[debug_format = "T{}"] pub struct TypeParamIdx { } @@ -38,6 +41,7 @@ impl TypeParamIdx { pub type TypeParams = IndexVec>; +/// A substitution for type parameters that maps type parameters to refinement types. #[derive(Debug, Clone)] pub struct TypeParamSubst { subst: BTreeMap>, diff --git a/src/rty/subtyping.rs b/src/rty/subtyping.rs index df3d446..82465a8 100644 --- a/src/rty/subtyping.rs +++ b/src/rty/subtyping.rs @@ -1,8 +1,17 @@ +//! Translation of subtyping relations into CHC constraints. + use crate::chc; use crate::pretty::PrettyDisplayExt; use super::{ClauseBuilderExt as _, Closed, PointerKind, RefKind, RefinedType, Type}; +/// A scope for building clauses. +/// +/// The construction of CHC clauses requires knowledge of the current +/// environment to determine variable sorts and include necessary premises. +/// This trait abstracts the preparation of a [`chc::ClauseBuilder`] to allow an +/// environment defined outside of this module (in Thrust, [`crate::refine::Env`]) +/// to build a [`chc::ClauseBuilder`] equipped with in-scope variables and assumptions. pub trait ClauseScope { fn build_clause(&self) -> chc::ClauseBuilder; } @@ -22,6 +31,7 @@ impl ClauseScope for chc::ClauseBuilder { } } +/// Produces CHC constraints for subtyping relations. pub trait Subtyping { #[must_use] fn relate_sub_type( diff --git a/src/rty/template.rs b/src/rty/template.rs index f2f1112..3d75820 100644 --- a/src/rty/template.rs +++ b/src/rty/template.rs @@ -1,9 +1,22 @@ +//! "Template" refinement types with unknown predicates to be inferred. +//! +//! A [`Template`] is used to generate a [`RefinedType`] with a refinement consisting of a +//! single atom with a fresh predicate variable. The unknown predicate can carry dependencies, +//! which are the arguments to the predicate. When Thrust infers refinement types, it +//! first generates template refinement types with unknown refinements, and then +//! generates constraints on the predicate variables in these templates. + use std::collections::BTreeMap; use crate::chc; use super::{RefinedType, RefinedTypeVar, Type}; +/// A template of a refinement type. +/// +/// This is a refinement type in the form of `{ T | P(x1, ..., xn) }` where `P` is a +/// predicate variable, `T` is a type, and `x1, ..., xn` are the dependencies. The predicate +/// variable is actually allocated when [`Template::into_refined_type`] is called. #[derive(Debug, Clone)] pub struct Template { pred_sig: chc::PredSig, @@ -24,6 +37,11 @@ impl Template { } } +/// A builder for a [`Template`]. +/// +/// Note that we have a convenient mechanism in the [`crate::refine`] module +/// to prepare a [`TemplateBuilder`] with variables in scope, and we usually don't +/// construct a [`TemplateBuilder`] directly. #[derive(Debug, Clone)] pub struct TemplateBuilder { dependencies: BTreeMap, chc::Sort>, From 1a5130e9fc52661f4ecb4d59beb2e1d4ef238ef4 Mon Sep 17 00:00:00 2001 From: coord_e Date: Sun, 7 Sep 2025 21:55:54 +0900 Subject: [PATCH 11/75] Add documentation to src/{analyze,refine,refine/basic_block}.rs --- src/analyze.rs | 8 ++++++++ src/refine.rs | 9 +++++++++ src/refine/basic_block.rs | 9 +++++++-- 3 files changed, 24 insertions(+), 2 deletions(-) diff --git a/src/analyze.rs b/src/analyze.rs index 9913f26..898df5b 100644 --- a/src/analyze.rs +++ b/src/analyze.rs @@ -1,3 +1,11 @@ +//! Analysis of Rust MIR to generate a CHC system. +//! +//! The [`Analyzer`] generates subtyping constraints in the form of CHCs ([`chc::System`]). +//! The entry point is [`crate_::Analyzer::run`], followed by [`local_def::Analyzer::run`] +//! and [`basic_block::Analyzer::run`], while accumulating the necessary information in +//! [`Analyzer`]. Once [`chc::System`] is collected for the entire input, it invokes an external +//! CHC solver with the [`Analyzer::solve`] and subsequently reports the result. + use std::cell::RefCell; use std::collections::HashMap; use std::rc::Rc; diff --git a/src/refine.rs b/src/refine.rs index 60f29f4..0371da1 100644 --- a/src/refine.rs +++ b/src/refine.rs @@ -1,3 +1,12 @@ +//! Core logic of refinement typing. +//! +//! This module includes the definition of the refinement typing environment and the template +//! type generation from MIR types. +//! +//! This module is used by the [`crate::analyze`] module. There is currently no clear boundary between +//! the `analyze` and `refine` modules, so it is a TODO to integrate this into the `analyze` +//! module and remove this one. + mod template; pub use template::{TemplateScope, TemplateTypeGenerator, UnrefinedTypeGenerator}; diff --git a/src/refine/basic_block.rs b/src/refine/basic_block.rs index 57e9930..ac39a1c 100644 --- a/src/refine/basic_block.rs +++ b/src/refine/basic_block.rs @@ -1,3 +1,5 @@ +//! The refinement type for a basic block. + use pretty::{termcolor, Pretty}; use rustc_index::IndexVec; use rustc_middle::mir::Local; @@ -5,8 +7,11 @@ use rustc_middle::ty as mir_ty; use crate::rty; -/// `BasicBlockType` is a special case of `FunctionType` whose parameters are -/// associated with `Local`s. +/// A special case of [`rty::FunctionType`] whose parameters are associated with [`Local`]s. +/// +/// Thrust handles basic blocks as functions, but it needs to associate function +/// parameters with MIR [`Local`]s during its analysis. [`BasicBlockType`] includes this mapping +/// from function parameters to [`Local`]s, along with the underlying function type. #[derive(Debug, Clone)] pub struct BasicBlockType { // TODO: make this completely private by exposing appropriate ctor From 0e6b9079e3a068e9e2690e0ed002cb6cf678bfb5 Mon Sep 17 00:00:00 2001 From: coord_e Date: Sun, 7 Sep 2025 22:29:34 +0900 Subject: [PATCH 12/75] Add documentation of src/analyze/{annot,crate_,did_cache,local_def}.rs --- src/analyze/annot.rs | 10 ++++++++++ src/analyze/crate_.rs | 12 ++++++++++++ src/analyze/did_cache.rs | 6 ++++++ src/analyze/local_def.rs | 6 ++++++ 4 files changed, 34 insertions(+) diff --git a/src/analyze/annot.rs b/src/analyze/annot.rs index 7d23cfe..91dd209 100644 --- a/src/analyze/annot.rs +++ b/src/analyze/annot.rs @@ -1,3 +1,5 @@ +//! Supporting implementation for parsing Thrust annotations. + use rustc_ast::ast::Attribute; use rustc_ast::tokenstream::TokenStream; use rustc_index::IndexVec; @@ -31,6 +33,10 @@ pub fn callable_path() -> [Symbol; 2] { [Symbol::intern("thrust"), Symbol::intern("callable")] } +/// A [`annot::Resolver`] implementation for resolving function parameters. +/// +/// The parameter names and their sorts needs to be configured via +/// [`ParamResolver::push_param`] before use. #[derive(Debug, Clone, Default)] pub struct ParamResolver { params: IndexVec, @@ -52,6 +58,10 @@ impl ParamResolver { } } +/// A [`annot::Resolver`] implementation for resolving the special identifier `result`. +/// +/// The `result` identifier is used to refer to [`rty::RefinedTypeVar::Value`] in postconditions, denoting +/// the return value of a function. #[derive(Debug, Clone)] pub struct ResultResolver { result_symbol: Symbol, diff --git a/src/analyze/crate_.rs b/src/analyze/crate_.rs index 016c709..9a1fa67 100644 --- a/src/analyze/crate_.rs +++ b/src/analyze/crate_.rs @@ -1,3 +1,5 @@ +//! Analyze a local crate. + use std::collections::HashSet; use rustc_hir::def::DefKind; @@ -12,6 +14,16 @@ use crate::chc; use crate::refine::{self, TemplateTypeGenerator, UnrefinedTypeGenerator}; use crate::rty::{self, ClauseBuilderExt as _}; +/// An implementation of local crate analysis. +/// +/// The entry point is [`Analyzer::run`], which performs the following steps in order: +/// +/// 1. Register enum definitions found in the crate. +/// 2. Give initial refinement types to local function definitions based on their signatures and +/// annotations. This generates template refinement types with predicate variables for parameters and +/// return types that are not known via annotations. +/// 3. Type local function definition bodies via [`super::local_def::Analyzer`] using the refinement types +/// generated in the previous step. pub struct Analyzer<'tcx, 'ctx> { tcx: TyCtxt<'tcx>, ctx: &'ctx mut analyze::Analyzer<'tcx>, diff --git a/src/analyze/did_cache.rs b/src/analyze/did_cache.rs index d671f8a..efeea6b 100644 --- a/src/analyze/did_cache.rs +++ b/src/analyze/did_cache.rs @@ -1,3 +1,5 @@ +//! Retrieves and caches well-known [`DefId`]s. + use std::cell::OnceCell; use rustc_middle::ty::TyCtxt; @@ -10,6 +12,10 @@ struct DefIds { nonnull: OnceCell>, } +/// Retrieves and caches well-known [`DefId`]s. +/// +/// [`DefId`]s of some well-known types can be retrieved as lang items or via the definition of +/// lang items. This struct implements that logic and caches the results. #[derive(Clone)] pub struct DefIdCache<'tcx> { tcx: TyCtxt<'tcx>, diff --git a/src/analyze/local_def.rs b/src/analyze/local_def.rs index b1df237..ef5870e 100644 --- a/src/analyze/local_def.rs +++ b/src/analyze/local_def.rs @@ -1,3 +1,5 @@ +//! Analyze a local definition. + use std::collections::{HashMap, HashSet}; use rustc_index::bit_set::BitSet; @@ -12,6 +14,10 @@ use crate::pretty::PrettyDisplayExt as _; use crate::refine::{BasicBlockType, TemplateTypeGenerator}; use crate::rty; +/// An implementation of the typing of local definitions. +/// +/// The current implementation only applies to function definitions. The entry point is +/// [`Analyzer::run`], which generates constraints during typing, given the expected type of the function. pub struct Analyzer<'tcx, 'ctx> { ctx: &'ctx mut analyze::Analyzer<'tcx>, tcx: TyCtxt<'tcx>, From e896fdca8b70f24136ef9b95e635670d50637e8f Mon Sep 17 00:00:00 2001 From: coord_e Date: Mon, 15 Sep 2025 12:39:31 +0900 Subject: [PATCH 13/75] Handle parens in annotations properly --- src/annot.rs | 459 ++++++++++++++++++++++++++--------------- tests/ui/fail/annot.rs | 13 ++ tests/ui/pass/annot.rs | 13 ++ 3 files changed, 319 insertions(+), 166 deletions(-) diff --git a/src/annot.rs b/src/annot.rs index d00d3a8..757eb68 100644 --- a/src/annot.rs +++ b/src/annot.rs @@ -84,6 +84,10 @@ pub enum ParseAttrError { UnknownIdent { ident: Ident }, #[error("operator {op:?} cannot be applied to a term of sort {}", .sort.display())] UnsortedOp { op: &'static str, sort: chc::Sort }, + #[error("unexpected term {context}")] + UnexpectedTerm { context: &'static str }, + #[error("unexpected formula {context}")] + UnexpectedFormula { context: &'static str }, } impl ParseAttrError { @@ -107,10 +111,76 @@ impl ParseAttrError { fn unsorted_op(op: &'static str, sort: chc::Sort) -> Self { ParseAttrError::UnsortedOp { op, sort } } + + fn unexpected_term(context: &'static str) -> Self { + ParseAttrError::UnexpectedTerm { context } + } + + fn unexpected_formula(context: &'static str) -> Self { + ParseAttrError::UnexpectedFormula { context } + } } type Result = std::result::Result; +#[derive(Debug, Clone, Copy)] +enum AmbiguousBinOp { + Eq, + Ne, + Ge, + Le, + Gt, + Lt, +} + +#[derive(Debug, Clone)] +enum FormulaOrTerm { + Formula(chc::Formula), + Term(chc::Term, chc::Sort), + BinOp(chc::Term, AmbiguousBinOp, chc::Term), + Not(Box>), +} + +impl FormulaOrTerm { + fn into_formula(self) -> Option> { + let fo = match self { + FormulaOrTerm::Formula(fo) => fo, + FormulaOrTerm::Term { .. } => return None, + FormulaOrTerm::BinOp(lhs, binop, rhs) => { + let pred = match binop { + AmbiguousBinOp::Eq => chc::KnownPred::EQUAL, + AmbiguousBinOp::Ne => chc::KnownPred::NOT_EQUAL, + AmbiguousBinOp::Ge => chc::KnownPred::GREATER_THAN_OR_EQUAL, + AmbiguousBinOp::Le => chc::KnownPred::LESS_THAN_OR_EQUAL, + AmbiguousBinOp::Gt => chc::KnownPred::GREATER_THAN, + AmbiguousBinOp::Lt => chc::KnownPred::LESS_THAN, + }; + chc::Formula::Atom(chc::Atom::new(pred.into(), vec![lhs, rhs])) + } + FormulaOrTerm::Not(formula_or_term) => formula_or_term.into_formula()?.not(), + }; + Some(fo) + } + + fn into_term(self) -> Option<(chc::Term, chc::Sort)> { + let (t, s) = match self { + FormulaOrTerm::Formula { .. } => return None, + FormulaOrTerm::Term(t, s) => (t, s), + FormulaOrTerm::BinOp(lhs, AmbiguousBinOp::Eq, rhs) => (lhs.eq(rhs), chc::Sort::bool()), + FormulaOrTerm::BinOp(lhs, AmbiguousBinOp::Ne, rhs) => (lhs.ne(rhs), chc::Sort::bool()), + FormulaOrTerm::BinOp(lhs, AmbiguousBinOp::Ge, rhs) => (lhs.ge(rhs), chc::Sort::bool()), + FormulaOrTerm::BinOp(lhs, AmbiguousBinOp::Le, rhs) => (lhs.le(rhs), chc::Sort::bool()), + FormulaOrTerm::BinOp(lhs, AmbiguousBinOp::Gt, rhs) => (lhs.gt(rhs), chc::Sort::bool()), + FormulaOrTerm::BinOp(lhs, AmbiguousBinOp::Lt, rhs) => (lhs.lt(rhs), chc::Sort::bool()), + FormulaOrTerm::Not(formula_or_term) => { + let (t, _) = formula_or_term.into_term()?; + (t.not(), chc::Sort::bool()) + } + }; + Some((t, s)) + } +} + /// A parser for refinement type annotations and formula annotations. struct Parser<'a, T> { resolver: T, @@ -184,10 +254,10 @@ where } } - fn parse_term_or_tuple(&mut self) -> Result<(chc::Term, chc::Sort)> { - let mut terms_and_sorts = Vec::new(); + fn parse_formula_or_term_or_tuple(&mut self) -> Result> { + let mut formula_or_terms = Vec::new(); loop { - terms_and_sorts.push(self.parse_term()?); + formula_or_terms.push(self.parse_formula_or_term()?); if let Some(Token { kind: TokenKind::Comma, .. @@ -198,16 +268,27 @@ where break; } } - if terms_and_sorts.len() == 1 { - Ok(terms_and_sorts.pop().unwrap()) + if formula_or_terms.len() == 1 { + Ok(formula_or_terms.pop().unwrap()) } else { - let (terms, sorts) = terms_and_sorts.into_iter().unzip(); - Ok((chc::Term::tuple(terms), chc::Sort::tuple(sorts))) + let mut terms = Vec::new(); + let mut sorts = Vec::new(); + for ft in formula_or_terms { + let (t, s) = ft + .into_term() + .ok_or_else(|| ParseAttrError::unexpected_formula("in tuple"))?; + terms.push(t); + sorts.push(s); + } + Ok(FormulaOrTerm::Term( + chc::Term::tuple(terms), + chc::Sort::tuple(sorts), + )) } } - fn parse_atom_term(&mut self) -> Result<(chc::Term, chc::Sort)> { - let tt = self.next_token_tree("term")?.clone(); + fn parse_atom(&mut self) -> Result> { + let tt = self.next_token_tree("term or formula")?.clone(); let t = match &tt { TokenTree::Token(t, _) => t, @@ -216,20 +297,26 @@ where resolver: self.boxed_resolver(), cursor: s.trees(), }; - let (term, sort) = parser.parse_term_or_tuple()?; + let formula_or_term = parser.parse_formula_or_term_or_tuple()?; parser.end_of_input()?; - return Ok((term, sort)); + return Ok(formula_or_term); } _ => return Err(ParseAttrError::unexpected_token_tree("token", tt)), }; - let (term, sort) = if let Some((ident, _)) = t.ident() { - let (v, sort) = self.resolve(ident)?; - (chc::Term::var(v), sort) + let formula_or_term = if let Some((ident, _)) = t.ident() { + match ident.as_str() { + "true" => FormulaOrTerm::Formula(chc::Formula::top()), + "false" => FormulaOrTerm::Formula(chc::Formula::bottom()), + _ => { + let (v, sort) = self.resolve(ident)?; + FormulaOrTerm::Term(chc::Term::var(v), sort) + } + } } else { match t.kind { TokenKind::Literal(lit) => match lit.kind { - LitKind::Integer => ( + LitKind::Integer => FormulaOrTerm::Term( chc::Term::int(lit.symbol.as_str().parse().unwrap()), chc::Sort::int(), ), @@ -244,11 +331,11 @@ where } }; - Ok((term, sort)) + Ok(formula_or_term) } - fn parse_postfix_term(&mut self) -> Result<(chc::Term, chc::Sort)> { - let (term, sort) = self.parse_atom_term()?; + fn parse_postfix(&mut self) -> Result> { + let formula_or_term = self.parse_atom()?; let mut fields = Vec::new(); while let Some(Token { @@ -269,201 +356,234 @@ where } } + if fields.is_empty() { + return Ok(formula_or_term); + } + + let (term, sort) = formula_or_term + .into_term() + .ok_or_else(|| ParseAttrError::unexpected_formula("before projection"))?; let term = fields.iter().fold(term, |acc, idx| acc.tuple_proj(*idx)); let sort = fields.iter().fold(sort, |acc, idx| acc.tuple_elem(*idx)); - Ok((term, sort)) + Ok(FormulaOrTerm::Term(term, sort)) } - fn parse_prefix_term(&mut self) -> Result<(chc::Term, chc::Sort)> { - let (term, sort) = match self.look_ahead_token(0).map(|t| &t.kind) { - Some(TokenKind::BinOp(BinOpToken::Minus)) => { - self.consume(); - let (operand, sort) = self.parse_postfix_term()?; - (operand.neg(), sort) - } - Some(TokenKind::BinOp(BinOpToken::Star)) => { - self.consume(); - let (operand, sort) = self.parse_postfix_term()?; - match sort { - chc::Sort::Box(inner) => (chc::Term::box_current(operand), *inner), - chc::Sort::Mut(inner) => (chc::Term::mut_current(operand), *inner), - _ => return Err(ParseAttrError::unsorted_op("*", sort)), + fn parse_prefix(&mut self) -> Result> { + let formula_or_term = + match self.look_ahead_token(0).map(|t| &t.kind) { + Some(TokenKind::BinOp(BinOpToken::Minus)) => { + self.consume(); + let (operand, sort) = self.parse_postfix()?.into_term().ok_or_else(|| { + ParseAttrError::unexpected_formula("after unary - operator") + })?; + FormulaOrTerm::Term(operand.neg(), sort) } - } - Some(TokenKind::BinOp(BinOpToken::Caret)) => { - self.consume(); - let (operand, sort) = self.parse_postfix_term()?; - if let chc::Sort::Mut(inner) = sort { - (chc::Term::mut_final(operand), *inner) - } else { - return Err(ParseAttrError::unsorted_op("^", sort)); + Some(TokenKind::BinOp(BinOpToken::Star)) => { + self.consume(); + let (operand, sort) = self.parse_postfix()?.into_term().ok_or_else(|| { + ParseAttrError::unexpected_formula("after unary * operator") + })?; + let (t, s) = match sort { + chc::Sort::Box(inner) => (chc::Term::box_current(operand), *inner), + chc::Sort::Mut(inner) => (chc::Term::mut_current(operand), *inner), + _ => return Err(ParseAttrError::unsorted_op("*", sort)), + }; + FormulaOrTerm::Term(t, s) } + Some(TokenKind::BinOp(BinOpToken::Caret)) => { + self.consume(); + let (operand, sort) = self.parse_postfix()?.into_term().ok_or_else(|| { + ParseAttrError::unexpected_formula("after unary ^ operator") + })?; + if let chc::Sort::Mut(inner) = sort { + FormulaOrTerm::Term(chc::Term::mut_final(operand), *inner) + } else { + return Err(ParseAttrError::unsorted_op("^", sort)); + } + } + Some(TokenKind::Not) => { + self.consume(); + let formula_or_term = self.parse_postfix()?; + FormulaOrTerm::Not(Box::new(formula_or_term)) + } + _ => self.parse_postfix()?, + }; + Ok(formula_or_term) + } + + fn parse_binop_1(&mut self) -> Result> { + let lhs = self.parse_prefix()?; + + let formula_or_term = match self.look_ahead_token(0).map(|t| &t.kind) { + Some(TokenKind::BinOp(BinOpToken::Star)) => { + self.consume(); + let (lhs, _) = lhs + .into_term() + .ok_or_else(|| ParseAttrError::unexpected_formula("before * operator"))?; + let (rhs, _) = self + .parse_prefix()? + .into_term() + .ok_or_else(|| ParseAttrError::unexpected_formula("after * operator"))?; + FormulaOrTerm::Term(lhs.mul(rhs), chc::Sort::int()) } - _ => self.parse_postfix_term()?, + _ => return Ok(lhs), }; - Ok((term, sort)) + + Ok(formula_or_term) } - fn parse_term(&mut self) -> Result<(chc::Term, chc::Sort)> { - let (lhs, lhs_sort) = self.parse_prefix_term()?; + fn parse_binop_2(&mut self) -> Result> { + let lhs = self.parse_binop_1()?; - let (term, sort) = match self.look_ahead_token(0).map(|t| &t.kind) { + let formula_or_term = match self.look_ahead_token(0).map(|t| &t.kind) { Some(TokenKind::BinOp(BinOpToken::Plus)) => { self.consume(); - let (rhs, _) = self.parse_prefix_term()?; - (lhs.add(rhs), chc::Sort::int()) + let (lhs, _) = lhs + .into_term() + .ok_or_else(|| ParseAttrError::unexpected_formula("before + operator"))?; + let (rhs, _) = self + .parse_binop_1()? + .into_term() + .ok_or_else(|| ParseAttrError::unexpected_formula("after + operator"))?; + FormulaOrTerm::Term(lhs.add(rhs), chc::Sort::int()) } Some(TokenKind::BinOp(BinOpToken::Minus)) => { self.consume(); - let (rhs, _) = self.parse_prefix_term()?; - (lhs.sub(rhs), chc::Sort::int()) - } - Some(TokenKind::BinOp(BinOpToken::Star)) => { - self.consume(); - let (rhs, _) = self.parse_prefix_term()?; - (lhs.mul(rhs), chc::Sort::int()) + let (lhs, _) = lhs + .into_term() + .ok_or_else(|| ParseAttrError::unexpected_formula("before - operator"))?; + let (rhs, _) = self + .parse_binop_1()? + .into_term() + .ok_or_else(|| ParseAttrError::unexpected_formula("after - operator"))?; + FormulaOrTerm::Term(lhs.sub(rhs), chc::Sort::int()) } + _ => return Ok(lhs), + }; + + Ok(formula_or_term) + } + + fn parse_binop_3(&mut self) -> Result> { + let lhs = self.parse_binop_2()?; + + let formula_or_term = match self.look_ahead_token(0).map(|t| &t.kind) { Some(TokenKind::EqEq) => { self.consume(); - let (rhs, _) = self.parse_prefix_term()?; - (lhs.eq(rhs), chc::Sort::bool()) + let (lhs, _) = lhs + .into_term() + .ok_or_else(|| ParseAttrError::unexpected_formula("before == operator"))?; + let (rhs, _) = self + .parse_binop_2()? + .into_term() + .ok_or_else(|| ParseAttrError::unexpected_formula("after == operator"))?; + FormulaOrTerm::BinOp(lhs, AmbiguousBinOp::Eq, rhs) } Some(TokenKind::Ne) => { self.consume(); - let (rhs, _) = self.parse_prefix_term()?; - (lhs.ne(rhs), chc::Sort::bool()) + let (lhs, _) = lhs + .into_term() + .ok_or_else(|| ParseAttrError::unexpected_formula("before != operator"))?; + let (rhs, _) = self + .parse_binop_2()? + .into_term() + .ok_or_else(|| ParseAttrError::unexpected_formula("after != operator"))?; + FormulaOrTerm::BinOp(lhs, AmbiguousBinOp::Ne, rhs) } Some(TokenKind::Ge) => { self.consume(); - let (rhs, _) = self.parse_prefix_term()?; - (lhs.ge(rhs), chc::Sort::bool()) + let (lhs, _) = lhs + .into_term() + .ok_or_else(|| ParseAttrError::unexpected_formula("before >= operator"))?; + let (rhs, _) = self + .parse_binop_2()? + .into_term() + .ok_or_else(|| ParseAttrError::unexpected_formula("after >= operator"))?; + FormulaOrTerm::BinOp(lhs, AmbiguousBinOp::Ge, rhs) } Some(TokenKind::Le) => { self.consume(); - let (rhs, _) = self.parse_prefix_term()?; - (lhs.le(rhs), chc::Sort::bool()) + let (lhs, _) = lhs + .into_term() + .ok_or_else(|| ParseAttrError::unexpected_formula("before <= operator"))?; + let (rhs, _) = self + .parse_binop_2()? + .into_term() + .ok_or_else(|| ParseAttrError::unexpected_formula("after <= operator"))?; + FormulaOrTerm::BinOp(lhs, AmbiguousBinOp::Le, rhs) } Some(TokenKind::Gt) => { self.consume(); - let (rhs, _) = self.parse_prefix_term()?; - (lhs.gt(rhs), chc::Sort::bool()) + let (lhs, _) = lhs + .into_term() + .ok_or_else(|| ParseAttrError::unexpected_formula("before > operator"))?; + let (rhs, _) = self + .parse_binop_2()? + .into_term() + .ok_or_else(|| ParseAttrError::unexpected_formula("after > operator"))?; + FormulaOrTerm::BinOp(lhs, AmbiguousBinOp::Gt, rhs) } Some(TokenKind::Lt) => { self.consume(); - let (rhs, _) = self.parse_prefix_term()?; - (lhs.lt(rhs), chc::Sort::bool()) + let (lhs, _) = lhs + .into_term() + .ok_or_else(|| ParseAttrError::unexpected_formula("before < operator"))?; + let (rhs, _) = self + .parse_binop_2()? + .into_term() + .ok_or_else(|| ParseAttrError::unexpected_formula("after < operator"))?; + FormulaOrTerm::BinOp(lhs, AmbiguousBinOp::Lt, rhs) } - _ => return Ok((lhs, lhs_sort)), + _ => return Ok(lhs), }; - Ok((term, sort)) + Ok(formula_or_term) } - fn parse_atom(&mut self) -> Result> { - if let Some((ident, _)) = self.look_ahead_token(0).and_then(|t| t.ident()) { - if ident.as_str() == "true" { - self.consume(); - return Ok(chc::Atom::top()); - } - if ident.as_str() == "false" { - self.consume(); - return Ok(chc::Atom::bottom()); - } - } + fn parse_binop_4(&mut self) -> Result> { + let lhs = self.parse_binop_3()?; - let (lhs, _) = self.parse_prefix_term()?; - let pred = match self.next_token("<=, >=, <, >, == or !=")? { - Token { - kind: TokenKind::EqEq, - .. - } => chc::KnownPred::EQUAL, - Token { - kind: TokenKind::Ne, - .. - } => chc::KnownPred::NOT_EQUAL, - Token { - kind: TokenKind::Ge, - .. - } => chc::KnownPred::GREATER_THAN_OR_EQUAL, - Token { - kind: TokenKind::Le, - .. - } => chc::KnownPred::LESS_THAN_OR_EQUAL, - Token { - kind: TokenKind::Gt, - .. - } => chc::KnownPred::GREATER_THAN, - Token { - kind: TokenKind::Lt, - .. - } => chc::KnownPred::LESS_THAN, - t => { - return Err(ParseAttrError::unexpected_token( - "<=, >=, <, >, == or !=", - t.clone(), - )) + let formula_or_term = match self.look_ahead_token(0).map(|t| &t.kind) { + Some(TokenKind::AndAnd) => { + self.consume(); + let lhs = lhs + .into_formula() + .ok_or_else(|| ParseAttrError::unexpected_term("before && operator"))?; + let rhs = self + .parse_binop_3()? + .into_formula() + .ok_or_else(|| ParseAttrError::unexpected_term("after && operator"))?; + FormulaOrTerm::Formula(lhs.and(rhs)) } + _ => return Ok(lhs), }; - let (rhs, _) = self.parse_term()?; - Ok(chc::Atom::new(pred.into(), vec![lhs, rhs])) + + Ok(formula_or_term) } - fn parse_formula_atom(&mut self) -> Result> { - match self.look_ahead_token_tree(0).cloned() { - Some(TokenTree::Token( - Token { - kind: TokenKind::Not, - .. - }, - _, - )) => { + fn parse_binop_5(&mut self) -> Result> { + let lhs = self.parse_binop_4()?; + + let formula_or_term = match self.look_ahead_token(0).map(|t| &t.kind) { + Some(TokenKind::OrOr) => { self.consume(); - let atom = self.parse_atom()?; - Ok(chc::Formula::Atom(atom).not()) - } - //Some(TokenTree::Delimited(_, _, Delimiter::Parenthesis, s)) => { - // self.consume(); - // let mut parser = Parser { - // resolver: self.boxed_resolver(), - // cursor: s.trees(), - // }; - // let formula = parser.parse_formula()?; - // parser.end_of_input()?; - // Ok(formula) - //} - _ => { - let atom = self.parse_atom()?; - Ok(chc::Formula::Atom(atom)) + let lhs = lhs + .into_formula() + .ok_or_else(|| ParseAttrError::unexpected_term("before || operator"))?; + let rhs = self + .parse_binop_4()? + .into_formula() + .ok_or_else(|| ParseAttrError::unexpected_term("after || operator"))?; + FormulaOrTerm::Formula(lhs.or(rhs)) } - } - } + _ => return Ok(lhs), + }; - fn parse_formula_and(&mut self) -> Result> { - let mut formula = self.parse_formula_atom()?; - while let Some(Token { - kind: TokenKind::AndAnd, - .. - }) = self.look_ahead_token(0) - { - self.consume(); - let next_formula = self.parse_formula_atom()?; - formula = formula.and(next_formula); - } - Ok(formula) + Ok(formula_or_term) } - fn parse_formula(&mut self) -> Result> { - let mut formula = self.parse_formula_and()?; - while let Some(Token { - kind: TokenKind::OrOr, - .. - }) = self.look_ahead_token(0) - { - self.consume(); - let next_formula = self.parse_formula_and()?; - formula = formula.or(next_formula); - } - Ok(formula) + fn parse_formula_or_term(&mut self) -> Result> { + self.parse_binop_5() } fn parse_ty(&mut self) -> Result> { @@ -611,7 +731,10 @@ where if let Some(self_ident) = self_ident { parser.resolver.set_self(self_ident, ty.to_sort()); } - let formula = parser.parse_formula()?; + let formula = parser + .parse_formula_or_term()? + .into_formula() + .ok_or_else(|| ParseAttrError::unexpected_term("in refinement"))?; parser.end_of_input()?; Ok(rty::RefinedType::new(ty, formula.into())) } @@ -622,7 +745,11 @@ where return Ok(AnnotFormula::Auto); } } - self.parse_formula().map(AnnotFormula::Formula) + let formula = self + .parse_formula_or_term()? + .into_formula() + .ok_or_else(|| ParseAttrError::unexpected_term("in annotation"))?; + Ok(AnnotFormula::Formula(formula)) } } diff --git a/tests/ui/fail/annot.rs b/tests/ui/fail/annot.rs index bf5f98e..278524d 100644 --- a/tests/ui/fail/annot.rs +++ b/tests/ui/fail/annot.rs @@ -10,6 +10,19 @@ fn rand_except(x: i64) -> i64 { } } +#[thrust::requires(true)] +#[thrust::ensures((result == 1) || (result == -1) && result == 0)] +fn f(x: i64) -> i64 { + let y = rand_except(x); + if y > x { + 1 + } else if y < x { + -1 + } else { + 0 + } +} + fn main() { assert!(rand_except(1) == 0); } diff --git a/tests/ui/pass/annot.rs b/tests/ui/pass/annot.rs index 5b64773..c2e3e73 100644 --- a/tests/ui/pass/annot.rs +++ b/tests/ui/pass/annot.rs @@ -10,6 +10,19 @@ fn rand_except(x: i64) -> i64 { } } +#[thrust::requires(true)] +#[thrust::ensures((result == 1) || (result == -1))] +fn f(x: i64) -> i64 { + let y = rand_except(x); + if y > x { + 1 + } else if y < x { + -1 + } else { + 0 + } +} + fn main() { assert!(rand_except(1) != 1); } From 317955f23411f0f6e4a8033a1abada8d7b5f8a65 Mon Sep 17 00:00:00 2001 From: coord_e Date: Mon, 15 Sep 2025 13:04:29 +0900 Subject: [PATCH 14/75] Handle associativity --- src/annot.rs | 137 ++++++++++++++++++++++++--------------------------- 1 file changed, 65 insertions(+), 72 deletions(-) diff --git a/src/annot.rs b/src/annot.rs index 757eb68..f559bb8 100644 --- a/src/annot.rs +++ b/src/annot.rs @@ -412,56 +412,55 @@ where } fn parse_binop_1(&mut self) -> Result> { - let lhs = self.parse_prefix()?; + let mut formula_or_term = self.parse_prefix()?; - let formula_or_term = match self.look_ahead_token(0).map(|t| &t.kind) { - Some(TokenKind::BinOp(BinOpToken::Star)) => { - self.consume(); - let (lhs, _) = lhs - .into_term() - .ok_or_else(|| ParseAttrError::unexpected_formula("before * operator"))?; - let (rhs, _) = self - .parse_prefix()? - .into_term() - .ok_or_else(|| ParseAttrError::unexpected_formula("after * operator"))?; - FormulaOrTerm::Term(lhs.mul(rhs), chc::Sort::int()) - } - _ => return Ok(lhs), - }; + while let Some(TokenKind::BinOp(BinOpToken::Star)) = + self.look_ahead_token(0).map(|t| &t.kind) + { + self.consume(); + let (lhs, _) = formula_or_term + .into_term() + .ok_or_else(|| ParseAttrError::unexpected_formula("before * operator"))?; + let (rhs, _) = self + .parse_prefix()? + .into_term() + .ok_or_else(|| ParseAttrError::unexpected_formula("after * operator"))?; + formula_or_term = FormulaOrTerm::Term(lhs.mul(rhs), chc::Sort::int()) + } Ok(formula_or_term) } fn parse_binop_2(&mut self) -> Result> { - let lhs = self.parse_binop_1()?; + let mut formula_or_term = self.parse_binop_1()?; - let formula_or_term = match self.look_ahead_token(0).map(|t| &t.kind) { - Some(TokenKind::BinOp(BinOpToken::Plus)) => { - self.consume(); - let (lhs, _) = lhs - .into_term() - .ok_or_else(|| ParseAttrError::unexpected_formula("before + operator"))?; - let (rhs, _) = self - .parse_binop_1()? - .into_term() - .ok_or_else(|| ParseAttrError::unexpected_formula("after + operator"))?; - FormulaOrTerm::Term(lhs.add(rhs), chc::Sort::int()) - } - Some(TokenKind::BinOp(BinOpToken::Minus)) => { - self.consume(); - let (lhs, _) = lhs - .into_term() - .ok_or_else(|| ParseAttrError::unexpected_formula("before - operator"))?; - let (rhs, _) = self - .parse_binop_1()? - .into_term() - .ok_or_else(|| ParseAttrError::unexpected_formula("after - operator"))?; - FormulaOrTerm::Term(lhs.sub(rhs), chc::Sort::int()) + loop { + match self.look_ahead_token(0).map(|t| &t.kind) { + Some(TokenKind::BinOp(BinOpToken::Plus)) => { + self.consume(); + let (lhs, _) = formula_or_term + .into_term() + .ok_or_else(|| ParseAttrError::unexpected_formula("before + operator"))?; + let (rhs, _) = self + .parse_binop_1()? + .into_term() + .ok_or_else(|| ParseAttrError::unexpected_formula("after + operator"))?; + formula_or_term = FormulaOrTerm::Term(lhs.add(rhs), chc::Sort::int()) + } + Some(TokenKind::BinOp(BinOpToken::Minus)) => { + self.consume(); + let (lhs, _) = formula_or_term + .into_term() + .ok_or_else(|| ParseAttrError::unexpected_formula("before - operator"))?; + let (rhs, _) = self + .parse_binop_1()? + .into_term() + .ok_or_else(|| ParseAttrError::unexpected_formula("after - operator"))?; + formula_or_term = FormulaOrTerm::Term(lhs.sub(rhs), chc::Sort::int()) + } + _ => return Ok(formula_or_term), } - _ => return Ok(lhs), - }; - - Ok(formula_or_term) + } } fn parse_binop_3(&mut self) -> Result> { @@ -541,43 +540,37 @@ where } fn parse_binop_4(&mut self) -> Result> { - let lhs = self.parse_binop_3()?; + let mut formula_or_term = self.parse_binop_3()?; - let formula_or_term = match self.look_ahead_token(0).map(|t| &t.kind) { - Some(TokenKind::AndAnd) => { - self.consume(); - let lhs = lhs - .into_formula() - .ok_or_else(|| ParseAttrError::unexpected_term("before && operator"))?; - let rhs = self - .parse_binop_3()? - .into_formula() - .ok_or_else(|| ParseAttrError::unexpected_term("after && operator"))?; - FormulaOrTerm::Formula(lhs.and(rhs)) - } - _ => return Ok(lhs), - }; + while let Some(TokenKind::AndAnd) = self.look_ahead_token(0).map(|t| &t.kind) { + self.consume(); + let lhs = formula_or_term + .into_formula() + .ok_or_else(|| ParseAttrError::unexpected_term("before && operator"))?; + let rhs = self + .parse_binop_3()? + .into_formula() + .ok_or_else(|| ParseAttrError::unexpected_term("after && operator"))?; + formula_or_term = FormulaOrTerm::Formula(lhs.and(rhs)) + } Ok(formula_or_term) } fn parse_binop_5(&mut self) -> Result> { - let lhs = self.parse_binop_4()?; + let mut formula_or_term = self.parse_binop_4()?; - let formula_or_term = match self.look_ahead_token(0).map(|t| &t.kind) { - Some(TokenKind::OrOr) => { - self.consume(); - let lhs = lhs - .into_formula() - .ok_or_else(|| ParseAttrError::unexpected_term("before || operator"))?; - let rhs = self - .parse_binop_4()? - .into_formula() - .ok_or_else(|| ParseAttrError::unexpected_term("after || operator"))?; - FormulaOrTerm::Formula(lhs.or(rhs)) - } - _ => return Ok(lhs), - }; + while let Some(TokenKind::OrOr) = self.look_ahead_token(0).map(|t| &t.kind) { + self.consume(); + let lhs = formula_or_term + .into_formula() + .ok_or_else(|| ParseAttrError::unexpected_term("before || operator"))?; + let rhs = self + .parse_binop_4()? + .into_formula() + .ok_or_else(|| ParseAttrError::unexpected_term("after || operator"))?; + formula_or_term = FormulaOrTerm::Formula(lhs.or(rhs)) + } Ok(formula_or_term) } From 0db538ce923a9398a04e3c912b5ca0df6516b2bc Mon Sep 17 00:00:00 2001 From: coord_e Date: Mon, 15 Sep 2025 13:36:58 +0900 Subject: [PATCH 15/75] Add GitHub Actions workflow to host docs in GitHub pages --- .github/workflows/docs.yml | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) create mode 100644 .github/workflows/docs.yml diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml new file mode 100644 index 0000000..6488fff --- /dev/null +++ b/.github/workflows/docs.yml @@ -0,0 +1,30 @@ +on: + push: + branches: + - main + +permissions: {} + +concurrency: + group: ${{ github.workflow }} + +jobs: + docs: + runs-on: ubuntu-latest + permissions: + contents: read + pages: write + id-token: write + environment: + name: github-pages + url: ${{ steps.deployment.outputs.page_url }} + steps: + - uses: actions/checkout@v4 + - run: cargo doc --no-deps --document-private-items + - run: echo '' > target/doc/index.html + - uses: actions/configure-pages@v5 + - uses: actions/upload-pages-artifact@v3 + with: + path: 'target/doc' + - id: deployment + uses: actions/deploy-pages@v4 From 8487ab4218171f9c6316fe4e39c2d5362172da3b Mon Sep 17 00:00:00 2001 From: coord_e Date: Tue, 23 Sep 2025 17:15:36 +0900 Subject: [PATCH 16/75] Support enums with lifetime params --- src/analyze/crate_.rs | 58 ++++++++++++++++++++++++++--------- src/rty/params.rs | 4 +++ tests/ui/fail/adt_poly_ref.rs | 14 +++++++++ tests/ui/pass/adt_poly_ref.rs | 14 +++++++++ 4 files changed, 76 insertions(+), 14 deletions(-) create mode 100644 tests/ui/fail/adt_poly_ref.rs create mode 100644 tests/ui/pass/adt_poly_ref.rs diff --git a/src/analyze/crate_.rs b/src/analyze/crate_.rs index 9a1fa67..fa7143b 100644 --- a/src/analyze/crate_.rs +++ b/src/analyze/crate_.rs @@ -251,6 +251,29 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { continue; }; let adt = self.tcx.adt_def(local_def_id); + + // The index of TyKind::ParamTy is based on the every generic parameters in + // the definition, including lifetimes. Given the following definition: + // + // struct X<'a, T> { f: &'a T } + // + // The type of field `f` is &T1 (not T0). However, in Thrust, we ignore lifetime + // parameters and the index of rty::ParamType is based on type parameters only. + // We're building a mapping from the original index to the new index here. + let generics = self.tcx.generics_of(local_def_id); + let mut type_param_mapping: std::collections::HashMap = + Default::default(); + for i in 0..generics.count() { + let generic_param = generics.param_at(i, self.tcx); + match generic_param.kind { + mir_ty::GenericParamDefKind::Lifetime => {} + mir_ty::GenericParamDefKind::Type { .. } => { + type_param_mapping.insert(i, type_param_mapping.len()); + } + mir_ty::GenericParamDefKind::Const { .. } => unimplemented!(), + } + } + let name = refine::datatype_symbol(self.tcx, local_def_id.to_def_id()); let variants: IndexVec<_, _> = adt .variants() @@ -264,7 +287,26 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { .iter() .map(|field| { let field_ty = self.tcx.type_of(field.did).instantiate_identity(); - self.ctx.unrefined_ty(field_ty) + + // see the comment above about this mapping + let subst = rty::TypeParamSubst::new( + type_param_mapping + .iter() + .map(|(old, new)| { + let old = rty::TypeParamIdx::from(*old); + let new = + rty::ParamType::new(rty::TypeParamIdx::from(*new)); + (old, rty::RefinedType::unrefined(new.into())) + }) + .collect(), + ); + + // the subst doesn't contain refinements, so it's OK to take ty only + // after substitution + let mut field_rty = + rty::RefinedType::unrefined(self.ctx.unrefined_ty(field_ty)); + field_rty.subst_ty_params(&subst); + field_rty.ty }) .collect(); rty::EnumVariantDef { @@ -275,19 +317,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { }) .collect(); - let ty_params = adt - .all_fields() - .map(|f| self.tcx.type_of(f.did).instantiate_identity()) - .flat_map(|ty| { - if let mir_ty::TyKind::Param(p) = ty.kind() { - Some(p.index as usize) - } else { - None - } - }) - .max() - .map(|max| max + 1) - .unwrap_or(0); + let ty_params = type_param_mapping.len(); tracing::debug!(?local_def_id, ?name, ?ty_params, "ty_params count"); let def = rty::EnumDatatypeDef { diff --git a/src/rty/params.rs b/src/rty/params.rs index 17ebc2b..fa0daa3 100644 --- a/src/rty/params.rs +++ b/src/rty/params.rs @@ -71,6 +71,10 @@ impl std::ops::Index for TypeParamSubst { } impl TypeParamSubst { + pub fn new(subst: BTreeMap>) -> Self { + Self { subst } + } + pub fn singleton(idx: TypeParamIdx, ty: RefinedType) -> Self { let mut subst = BTreeMap::default(); subst.insert(idx, ty); diff --git a/tests/ui/fail/adt_poly_ref.rs b/tests/ui/fail/adt_poly_ref.rs new file mode 100644 index 0000000..8c42d4b --- /dev/null +++ b/tests/ui/fail/adt_poly_ref.rs @@ -0,0 +1,14 @@ +//@error-in-other-file: Unsat +//@compile-flags: -C debug-assertions=off + +enum X<'a, T> { + A(&'a T), +} + +fn main() { + let i = 42; + let x = X::A(&i); + match x { + X::A(i) => assert!(*i == 41), + } +} diff --git a/tests/ui/pass/adt_poly_ref.rs b/tests/ui/pass/adt_poly_ref.rs new file mode 100644 index 0000000..f0e5e30 --- /dev/null +++ b/tests/ui/pass/adt_poly_ref.rs @@ -0,0 +1,14 @@ +//@check-pass +//@compile-flags: -C debug-assertions=off + +enum X<'a, T> { + A(&'a T), +} + +fn main() { + let i = 42; + let x = X::A(&i); + match x { + X::A(i) => assert!(*i == 42), + } +} From acddb410ab4744315bdebd4a275e6897ff217748 Mon Sep 17 00:00:00 2001 From: coord_e Date: Tue, 23 Sep 2025 17:17:51 +0900 Subject: [PATCH 17/75] Rename TypeParams to TypeArgs --- src/analyze/basic_block.rs | 8 ++++---- src/refine/env.rs | 2 +- src/rty.rs | 6 +++--- src/rty/params.rs | 14 +++++++------- 4 files changed, 15 insertions(+), 15 deletions(-) diff --git a/src/analyze/basic_block.rs b/src/analyze/basic_block.rs index cbb6589..e189695 100644 --- a/src/analyze/basic_block.rs +++ b/src/analyze/basic_block.rs @@ -219,7 +219,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { .into_iter() .map(|ty| rty::RefinedType::unrefined(ty.vacuous())); - let params: IndexVec<_, _> = args + let rty_args: IndexVec<_, _> = args .types() .map(|ty| { self.ctx @@ -230,15 +230,15 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { for (field_pty, mut variant_rty) in field_tys.clone().into_iter().zip(variant_rtys) { - variant_rty.instantiate_ty_params(params.clone()); + variant_rty.instantiate_ty_params(rty_args.clone()); let cs = self .env .relate_sub_refined_type(&field_pty.into(), &variant_rty.boxed()); self.ctx.extend_clauses(cs); } - let sort_args: Vec<_> = params.iter().map(|rty| rty.ty.to_sort()).collect(); - let ty = rty::EnumType::new(ty_sym.clone(), params).into(); + let sort_args: Vec<_> = rty_args.iter().map(|rty| rty.ty.to_sort()).collect(); + let ty = rty::EnumType::new(ty_sym.clone(), rty_args).into(); let mut builder = PlaceTypeBuilder::default(); let mut field_terms = Vec::new(); diff --git a/src/refine/env.rs b/src/refine/env.rs index e647765..5569485 100644 --- a/src/refine/env.rs +++ b/src/refine/env.rs @@ -930,7 +930,7 @@ impl Env { .field_tys() .map(|ty| rty::RefinedType::unrefined(ty.clone().vacuous()).boxed()); let got_tys = field_tys.iter().map(|ty| ty.clone().into()); - rty::unify_tys_params(expected_tys, got_tys).into_params(def.ty_params, |_| { + rty::unify_tys_params(expected_tys, got_tys).into_args(def.ty_params, |_| { panic!("var_type: should unify all params") }) }; diff --git a/src/rty.rs b/src/rty.rs index 72fede8..c706897 100644 --- a/src/rty.rs +++ b/src/rty.rs @@ -55,7 +55,7 @@ mod subtyping; pub use subtyping::{relate_sub_closed_type, ClauseScope, Subtyping}; mod params; -pub use params::{TypeParamIdx, TypeParamSubst, TypeParams}; +pub use params::{TypeArgs, TypeParamIdx, TypeParamSubst}; rustc_index::newtype_index! { /// An index representing function parameter. @@ -487,7 +487,7 @@ where } impl EnumType { - pub fn new(symbol: chc::DatatypeSymbol, args: TypeParams) -> Self { + pub fn new(symbol: chc::DatatypeSymbol, args: TypeArgs) -> Self { EnumType { symbol, args } } @@ -1372,7 +1372,7 @@ impl RefinedType { } } - pub fn instantiate_ty_params(&mut self, params: TypeParams) + pub fn instantiate_ty_params(&mut self, params: TypeArgs) where FV: chc::Var, { diff --git a/src/rty/params.rs b/src/rty/params.rs index fa0daa3..b57ff55 100644 --- a/src/rty/params.rs +++ b/src/rty/params.rs @@ -39,7 +39,7 @@ impl TypeParamIdx { } } -pub type TypeParams = IndexVec>; +pub type TypeArgs = IndexVec>; /// A substitution for type parameters that maps type parameters to refinement types. #[derive(Debug, Clone)] @@ -55,8 +55,8 @@ impl Default for TypeParamSubst { } } -impl From> for TypeParamSubst { - fn from(params: TypeParams) -> Self { +impl From> for TypeParamSubst { + fn from(params: TypeArgs) -> Self { let subst = params.into_iter_enumerated().collect(); Self { subst } } @@ -98,20 +98,20 @@ impl TypeParamSubst { } } - pub fn into_params(mut self, expected_len: usize, mut default: F) -> TypeParams + pub fn into_args(mut self, expected_len: usize, mut default: F) -> TypeArgs where T: chc::Var, F: FnMut(TypeParamIdx) -> RefinedType, { - let mut params = TypeParams::new(); + let mut args = TypeArgs::new(); for idx in 0..expected_len { let ty = self .subst .remove(&idx.into()) .unwrap_or_else(|| default(idx.into())); - params.push(ty); + args.push(ty); } - params + args } pub fn strip_refinement(self) -> TypeParamSubst { From 4ec987356f363a5f012a2e3d316dd62b1389ec46 Mon Sep 17 00:00:00 2001 From: coord_e Date: Fri, 24 Oct 2025 16:05:24 +0900 Subject: [PATCH 18/75] Refactor construction of type templates --- src/analyze.rs | 12 +- src/analyze/basic_block.rs | 38 +-- src/analyze/crate_.rs | 13 +- src/analyze/local_def.rs | 6 +- src/refine.rs | 2 +- src/refine/env.rs | 3 +- src/refine/template.rs | 509 +++++++++++++++++-------------------- 7 files changed, 276 insertions(+), 307 deletions(-) diff --git a/src/analyze.rs b/src/analyze.rs index 898df5b..3550294 100644 --- a/src/analyze.rs +++ b/src/analyze.rs @@ -123,22 +123,12 @@ pub struct Analyzer<'tcx> { enum_defs: Rc>>, } -impl<'tcx> crate::refine::TemplateTypeGenerator<'tcx> for Analyzer<'tcx> { - fn tcx(&self) -> TyCtxt<'tcx> { - self.tcx - } - +impl<'tcx> crate::refine::TemplateRegistry for Analyzer<'tcx> { fn register_template(&mut self, tmpl: rty::Template) -> rty::RefinedType { tmpl.into_refined_type(|pred_sig| self.generate_pred_var(pred_sig)) } } -impl<'tcx> crate::refine::UnrefinedTypeGenerator<'tcx> for Analyzer<'tcx> { - fn tcx(&self) -> TyCtxt<'tcx> { - self.tcx - } -} - impl<'tcx> Analyzer<'tcx> { pub fn generate_pred_var(&mut self, sig: chc::PredSig) -> chc::PredVarId { self.system diff --git a/src/analyze/basic_block.rs b/src/analyze/basic_block.rs index e189695..2258f33 100644 --- a/src/analyze/basic_block.rs +++ b/src/analyze/basic_block.rs @@ -14,7 +14,7 @@ use crate::chc; use crate::pretty::PrettyDisplayExt as _; use crate::refine::{ self, Assumption, BasicBlockType, Env, PlaceType, PlaceTypeBuilder, PlaceTypeVar, TempVarIdx, - TemplateTypeGenerator, UnrefinedTypeGenerator, Var, + TypeBuilder, Var, }; use crate::rty::{ self, ClauseBuilderExt as _, ClauseScope as _, ShiftExistential as _, Subtyping as _, @@ -222,9 +222,10 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { let rty_args: IndexVec<_, _> = args .types() .map(|ty| { - self.ctx - .build_template_ty_with_scope(&self.env) - .refined_ty(ty) + TypeBuilder::new(self.tcx) + .for_template(&mut self.ctx) + .with_scope(&self.env) + .build_refined(ty) }) .collect(); for (field_pty, mut variant_rty) in @@ -237,7 +238,8 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { self.ctx.extend_clauses(cs); } - let sort_args: Vec<_> = rty_args.iter().map(|rty| rty.ty.to_sort()).collect(); + let sort_args: Vec<_> = + rty_args.iter().map(|rty| rty.ty.to_sort()).collect(); let ty = rty::EnumType::new(ty_sym.clone(), rty_args).into(); let mut builder = PlaceTypeBuilder::default(); @@ -433,7 +435,10 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { let func_ty = match func.const_fn_def() { // TODO: move this to well-known defs? Some((def_id, args)) if self.is_box_new(def_id) => { - let inner_ty = self.ctx.build_template_ty().ty(args.type_at(0)).vacuous(); + let inner_ty = TypeBuilder::new(self.tcx) + .for_template(&mut self.ctx) + .build(args.type_at(0)) + .vacuous(); let param = rty::RefinedType::unrefined(inner_ty.clone()); let ret_term = chc::Term::box_(chc::Term::var(rty::FunctionParamIdx::from(0_usize))); @@ -444,7 +449,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { rty::FunctionType::new([param].into_iter().collect(), ret).into() } Some((def_id, args)) if self.is_mem_swap(def_id) => { - let inner_ty = self.ctx.unrefined_ty(args.type_at(0)).vacuous(); + let inner_ty = TypeBuilder::new(self.tcx).build(args.type_at(0)).vacuous(); let param1 = rty::RefinedType::unrefined(rty::PointerType::mut_to(inner_ty.clone()).into()); let param2 = @@ -531,7 +536,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { } fn add_prophecy_var(&mut self, statement_index: usize, ty: mir_ty::Ty<'tcx>) { - let ty = self.ctx.unrefined_ty(ty); + let ty = TypeBuilder::new(self.tcx).build(ty); let temp_var = self.env.push_temp_var(ty.vacuous()); self.prophecy_vars.insert(statement_index, temp_var); tracing::debug!(stmt_idx = %statement_index, temp_var = ?temp_var, "add_prophecy_var"); @@ -552,7 +557,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { referent: mir::Place<'tcx>, prophecy_ty: mir_ty::Ty<'tcx>, ) -> rty::RefinedType { - let prophecy_ty = self.ctx.unrefined_ty(prophecy_ty); + let prophecy_ty = TypeBuilder::new(self.tcx).build(prophecy_ty); let prophecy = self.env.push_temp_var(prophecy_ty.vacuous()); let place = self.elaborate_place_for_borrow(&referent); self.env.borrow_place(place, prophecy).into() @@ -664,10 +669,10 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { } let decl = self.local_decls[destination].clone(); - let rty = self - .ctx - .build_template_ty_with_scope(&self.env) - .refined_ty(decl.ty); + let rty = TypeBuilder::new(self.tcx) + .for_template(&mut self.ctx) + .with_scope(&self.env) + .build_refined(decl.ty); self.type_call(func.clone(), args.clone().into_iter().map(|a| a.node), &rty); self.bind_local(destination, rty); } @@ -738,9 +743,10 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { #[tracing::instrument(skip(self))] fn ret_template(&mut self) -> rty::RefinedType { let ret_ty = self.body.local_decls[mir::RETURN_PLACE].ty; - self.ctx - .build_template_ty_with_scope(&self.env) - .refined_ty(ret_ty) + TypeBuilder::new(self.tcx) + .for_template(&mut self.ctx) + .with_scope(&self.env) + .build_refined(ret_ty) } // TODO: remove this diff --git a/src/analyze/crate_.rs b/src/analyze/crate_.rs index fa7143b..727b496 100644 --- a/src/analyze/crate_.rs +++ b/src/analyze/crate_.rs @@ -11,7 +11,7 @@ use rustc_span::symbol::Ident; use crate::analyze; use crate::annot::{self, AnnotFormula, AnnotParser, ResolverExt as _}; use crate::chc; -use crate::refine::{self, TemplateTypeGenerator, UnrefinedTypeGenerator}; +use crate::refine::{self, TypeBuilder}; use crate::rty::{self, ClauseBuilderExt as _}; /// An implementation of local crate analysis. @@ -132,13 +132,13 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { let mut param_resolver = analyze::annot::ParamResolver::default(); for (input_ident, input_ty) in self.tcx.fn_arg_names(def_id).iter().zip(sig.inputs()) { - let input_ty = self.ctx.unrefined_ty(*input_ty); + let input_ty = TypeBuilder::new(self.tcx).build(*input_ty); param_resolver.push_param(input_ident.name, input_ty.to_sort()); } let mut require_annot = self.extract_require_annot(¶m_resolver, def_id); let mut ensure_annot = { - let output_ty = self.ctx.unrefined_ty(sig.output()); + let output_ty = TypeBuilder::new(self.tcx).build(sig.output()); let resolver = annot::StackedResolver::default() .resolver(analyze::annot::ResultResolver::new(output_ty.to_sort())) .resolver((¶m_resolver).map(rty::RefinedTypeVar::Free)); @@ -175,7 +175,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { self.trusted.insert(def_id); } - let mut builder = self.ctx.build_function_template_ty(sig); + let mut builder = TypeBuilder::new(self.tcx).for_function_template(&mut self.ctx, sig); if let Some(AnnotFormula::Formula(require)) = require_annot { let formula = require.map_var(|idx| { if idx.index() == sig.inputs().len() - 1 { @@ -303,8 +303,9 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { // the subst doesn't contain refinements, so it's OK to take ty only // after substitution - let mut field_rty = - rty::RefinedType::unrefined(self.ctx.unrefined_ty(field_ty)); + let mut field_rty = rty::RefinedType::unrefined( + TypeBuilder::new(self.tcx).build(field_ty), + ); field_rty.subst_ty_params(&subst); field_rty.ty }) diff --git a/src/analyze/local_def.rs b/src/analyze/local_def.rs index ef5870e..c1ab72c 100644 --- a/src/analyze/local_def.rs +++ b/src/analyze/local_def.rs @@ -11,7 +11,7 @@ use rustc_span::def_id::LocalDefId; use crate::analyze; use crate::chc; use crate::pretty::PrettyDisplayExt as _; -use crate::refine::{BasicBlockType, TemplateTypeGenerator}; +use crate::refine::{BasicBlockType, TypeBuilder}; use crate::rty; /// An implementation of the typing of local definitions. @@ -306,7 +306,9 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { } // function return type is basic block return type let ret_ty = self.body.local_decls[mir::RETURN_PLACE].ty; - let rty = self.ctx.basic_block_template_ty(live_locals, ret_ty); + let rty = TypeBuilder::new(self.tcx) + .for_template(&mut self.ctx) + .build_basic_block(live_locals, ret_ty); self.ctx.register_basic_block_ty(self.local_def_id, bb, rty); } } diff --git a/src/refine.rs b/src/refine.rs index 0371da1..4736b39 100644 --- a/src/refine.rs +++ b/src/refine.rs @@ -8,7 +8,7 @@ //! module and remove this one. mod template; -pub use template::{TemplateScope, TemplateTypeGenerator, UnrefinedTypeGenerator}; +pub use template::{TemplateRegistry, TemplateScope, TypeBuilder}; mod basic_block; pub use basic_block::BasicBlockType; diff --git a/src/refine/env.rs b/src/refine/env.rs index 5569485..a1edbc1 100644 --- a/src/refine/env.rs +++ b/src/refine/env.rs @@ -558,7 +558,8 @@ impl rty::ClauseScope for Env { } } -impl refine::TemplateScope for Env { +impl refine::TemplateScope for Env { + type Var = Var; fn build_template(&self) -> rty::TemplateBuilder { let mut builder = rty::TemplateBuilder::default(); for (v, sort) in self.dependencies() { diff --git a/src/refine/template.rs b/src/refine/template.rs index b6ae7b8..a2380e0 100644 --- a/src/refine/template.rs +++ b/src/refine/template.rs @@ -9,54 +9,146 @@ use crate::chc; use crate::refine; use crate::rty; -pub trait TemplateScope { - fn build_template(&self) -> rty::TemplateBuilder; +pub trait TemplateRegistry { + fn register_template(&mut self, tmpl: rty::Template) -> rty::RefinedType; } -impl TemplateScope for &U +impl TemplateRegistry for &mut T where - U: TemplateScope, + T: TemplateRegistry + ?Sized, { - fn build_template(&self) -> rty::TemplateBuilder { - U::build_template(self) + fn register_template(&mut self, tmpl: rty::Template) -> rty::RefinedType { + T::register_template(self, tmpl) + } +} + +#[derive(Clone, Default)] +pub struct EmptyTemplateScope; + +impl TemplateScope for EmptyTemplateScope { + type Var = rty::Closed; + fn build_template(&self) -> rty::TemplateBuilder { + rty::TemplateBuilder::default() } } -impl TemplateScope for rty::TemplateBuilder +pub trait TemplateScope { + type Var: chc::Var; + fn build_template(&self) -> rty::TemplateBuilder; +} + +impl TemplateScope for &T +where + T: TemplateScope, +{ + type Var = T::Var; + fn build_template(&self) -> rty::TemplateBuilder { + T::build_template(self) + } +} + +impl TemplateScope for rty::TemplateBuilder where T: chc::Var, { + type Var = T; fn build_template(&self) -> rty::TemplateBuilder { self.clone() } } -pub trait TemplateTypeGenerator<'tcx> { - fn tcx(&self) -> mir_ty::TyCtxt<'tcx>; - fn register_template(&mut self, tmpl: rty::Template) -> rty::RefinedType; +#[derive(Clone)] +pub struct TypeBuilder<'tcx> { + tcx: mir_ty::TyCtxt<'tcx>, +} - fn build_template_ty_with_scope(&mut self, scope: T) -> TemplateTypeBuilder { - TemplateTypeBuilder { - gen: self, - scope, - _marker: std::marker::PhantomData, +impl<'tcx> TypeBuilder<'tcx> { + pub fn new(tcx: mir_ty::TyCtxt<'tcx>) -> Self { + Self { tcx } + } + + // TODO: consolidate two impls + pub fn build(&self, ty: mir_ty::Ty<'tcx>) -> rty::Type { + match ty.kind() { + mir_ty::TyKind::Bool => rty::Type::bool(), + mir_ty::TyKind::Uint(_) | mir_ty::TyKind::Int(_) => rty::Type::int(), + mir_ty::TyKind::Str => rty::Type::string(), + mir_ty::TyKind::Ref(_, elem_ty, mutbl) => { + let elem_ty = self.build(*elem_ty); + match mutbl { + mir_ty::Mutability::Mut => rty::PointerType::mut_to(elem_ty).into(), + mir_ty::Mutability::Not => rty::PointerType::immut_to(elem_ty).into(), + } + } + mir_ty::TyKind::Tuple(ts) => { + // elaboration: all fields are boxed + let elems = ts + .iter() + .map(|ty| rty::PointerType::own(self.build(ty)).into()) + .collect(); + rty::TupleType::new(elems).into() + } + mir_ty::TyKind::Never => rty::Type::never(), + mir_ty::TyKind::Param(ty) => rty::ParamType::new(ty.index.into()).into(), + mir_ty::TyKind::FnPtr(sig) => { + // TODO: justification for skip_binder + let sig = sig.skip_binder(); + let params = sig + .inputs() + .iter() + .map(|ty| rty::RefinedType::unrefined(self.build(*ty)).vacuous()) + .collect(); + let ret = rty::RefinedType::unrefined(self.build(sig.output())); + rty::FunctionType::new(params, ret.vacuous()).into() + } + mir_ty::TyKind::Adt(def, params) if def.is_box() => { + rty::PointerType::own(self.build(params.type_at(0))).into() + } + mir_ty::TyKind::Adt(def, params) => { + if def.is_enum() { + let sym = refine::datatype_symbol(self.tcx, def.did()); + let args: IndexVec<_, _> = params + .types() + .map(|ty| rty::RefinedType::unrefined(self.build(ty))) + .collect(); + rty::EnumType::new(sym, args).into() + } else if def.is_struct() { + let elem_tys = def + .all_fields() + .map(|field| { + let ty = field.ty(self.tcx, params); + // elaboration: all fields are boxed + rty::PointerType::own(self.build(ty)).into() + }) + .collect(); + rty::TupleType::new(elem_tys).into() + } else { + unimplemented!("unsupported ADT: {:?}", ty); + } + } + kind => unimplemented!("unrefined_ty: {:?}", kind), } } - fn build_template_ty(&mut self) -> TemplateTypeBuilder, V> { + pub fn for_template<'a, R>( + &self, + registry: &'a mut R, + ) -> TemplateTypeBuilder<'tcx, 'a, R, EmptyTemplateScope> { TemplateTypeBuilder { - gen: self, + tcx: self.tcx, + registry, scope: Default::default(), - _marker: std::marker::PhantomData, } } - fn build_function_template_ty( - &mut self, + pub fn for_function_template<'a, R>( + &self, + registry: &'a mut R, sig: mir_ty::FnSig<'tcx>, - ) -> FunctionTemplateTypeBuilder<'_, 'tcx, Self> { + ) -> FunctionTemplateTypeBuilder<'tcx, 'a, R> { FunctionTemplateTypeBuilder { - gen: self, + tcx: self.tcx, + registry, param_tys: sig .inputs() .iter() @@ -71,12 +163,101 @@ pub trait TemplateTypeGenerator<'tcx> { ret_rty: None, } } +} + +pub struct TemplateTypeBuilder<'tcx, 'a, R, S> { + tcx: mir_ty::TyCtxt<'tcx>, + registry: &'a mut R, + scope: S, +} - fn build_basic_block_template_ty( +impl<'tcx, 'a, R, S> TemplateTypeBuilder<'tcx, 'a, R, S> { + pub fn with_scope(self, scope: T) -> TemplateTypeBuilder<'tcx, 'a, R, T> { + TemplateTypeBuilder { + tcx: self.tcx, + registry: self.registry, + scope, + } + } +} + +impl<'tcx, 'a, R, S> TemplateTypeBuilder<'tcx, 'a, R, S> +where + R: TemplateRegistry, + S: TemplateScope, +{ + pub fn build(&mut self, ty: mir_ty::Ty<'tcx>) -> rty::Type { + match ty.kind() { + mir_ty::TyKind::Bool => rty::Type::bool(), + mir_ty::TyKind::Uint(_) | mir_ty::TyKind::Int(_) => rty::Type::int(), + mir_ty::TyKind::Str => rty::Type::string(), + mir_ty::TyKind::Ref(_, elem_ty, mutbl) => { + let elem_ty = self.build(*elem_ty); + match mutbl { + mir_ty::Mutability::Mut => rty::PointerType::mut_to(elem_ty).into(), + mir_ty::Mutability::Not => rty::PointerType::immut_to(elem_ty).into(), + } + } + mir_ty::TyKind::Tuple(ts) => { + // elaboration: all fields are boxed + let elems = ts + .iter() + .map(|ty| rty::PointerType::own(self.build(ty)).into()) + .collect(); + rty::TupleType::new(elems).into() + } + mir_ty::TyKind::Never => rty::Type::never(), + mir_ty::TyKind::Param(ty) => rty::ParamType::new(ty.index.into()).into(), + mir_ty::TyKind::FnPtr(sig) => { + // TODO: justification for skip_binder + let sig = sig.skip_binder(); + let ty = TypeBuilder::new(self.tcx) + .for_function_template(self.registry, sig) + .build(); + rty::Type::function(ty) + } + mir_ty::TyKind::Adt(def, params) if def.is_box() => { + rty::PointerType::own(self.build(params.type_at(0))).into() + } + mir_ty::TyKind::Adt(def, params) => { + if def.is_enum() { + let sym = refine::datatype_symbol(self.tcx, def.did()); + let args: IndexVec<_, _> = + params.types().map(|ty| self.build_refined(ty)).collect(); + rty::EnumType::new(sym, args).into() + } else if def.is_struct() { + let elem_tys = def + .all_fields() + .map(|field| { + let ty = field.ty(self.tcx, params); + // elaboration: all fields are boxed + rty::PointerType::own(self.build(ty)).into() + }) + .collect(); + rty::TupleType::new(elem_tys).into() + } else { + unimplemented!("unsupported ADT: {:?}", ty); + } + } + kind => unimplemented!("ty: {:?}", kind), + } + } + + pub fn build_refined(&mut self, ty: mir_ty::Ty<'tcx>) -> rty::RefinedType { + // TODO: consider building ty with scope + let ty = TypeBuilder::new(self.tcx) + .for_template(self.registry) + .build(ty) + .vacuous(); + let tmpl = self.scope.build_template().build(ty); + self.registry.register_template(tmpl) + } + + pub fn build_basic_block( &mut self, live_locals: I, ret_ty: mir_ty::Ty<'tcx>, - ) -> BasicBlockTemplateTypeBuilder<'_, 'tcx, Self> + ) -> BasicBlockType where I: IntoIterator)>, { @@ -87,51 +268,23 @@ pub trait TemplateTypeGenerator<'tcx> { locals.push((local, ty.mutbl)); tys.push(ty); } - let inner = FunctionTemplateTypeBuilder { - gen: self, + let ty = FunctionTemplateTypeBuilder { + tcx: self.tcx, + registry: self.registry, param_tys: tys, ret_ty, param_rtys: Default::default(), param_refinement: None, ret_rty: None, - }; - BasicBlockTemplateTypeBuilder { inner, locals } - } - - fn basic_block_template_ty( - &mut self, - live_locals: I, - ret_ty: mir_ty::Ty<'tcx>, - ) -> BasicBlockType - where - I: IntoIterator)>, - { - self.build_basic_block_template_ty(live_locals, ret_ty) - .build() - } - - fn function_template_ty(&mut self, sig: mir_ty::FnSig<'tcx>) -> rty::FunctionType { - self.build_function_template_ty(sig).build() - } -} - -impl<'tcx, T> TemplateTypeGenerator<'tcx> for &mut T -where - T: TemplateTypeGenerator<'tcx> + ?Sized, -{ - fn tcx(&self) -> mir_ty::TyCtxt<'tcx> { - T::tcx(self) - } - - fn register_template(&mut self, tmpl: rty::Template) -> rty::RefinedType { - T::register_template(self, tmpl) + } + .build(); + BasicBlockType { ty, locals } } } -#[derive(Debug)] -pub struct FunctionTemplateTypeBuilder<'a, 'tcx, T: ?Sized> { - // can't use T: TemplateTypeGenerator<'tcx> directly because of recursive instantiation - gen: &'a mut T, +pub struct FunctionTemplateTypeBuilder<'tcx, 'a, R> { + tcx: mir_ty::TyCtxt<'tcx>, + registry: &'a mut R, param_tys: Vec>, ret_ty: mir_ty::Ty<'tcx>, param_refinement: Option>, @@ -139,10 +292,7 @@ pub struct FunctionTemplateTypeBuilder<'a, 'tcx, T: ?Sized> { ret_rty: Option>, } -impl<'a, 'tcx, T> FunctionTemplateTypeBuilder<'a, 'tcx, T> -where - T: TemplateTypeGenerator<'tcx> + ?Sized, -{ +impl<'tcx, 'a, R> FunctionTemplateTypeBuilder<'tcx, 'a, R> { pub fn param_refinement( &mut self, refinement: rty::Refinement, @@ -174,7 +324,7 @@ where &mut self, refinement: rty::Refinement, ) -> &mut Self { - let ty = UnrefinedTypeGeneratorWrapper(&mut self.gen).unrefined_ty(self.ret_ty); + let ty = TypeBuilder::new(self.tcx).build(self.ret_ty); self.ret_rty = Some(rty::RefinedType::new(ty.vacuous(), refinement)); self } @@ -183,7 +333,12 @@ where self.ret_rty = Some(rty); self } +} +impl<'tcx, 'a, R> FunctionTemplateTypeBuilder<'tcx, 'a, R> +where + R: TemplateRegistry, +{ pub fn build(&mut self) -> rty::FunctionType { let mut builder = rty::TemplateBuilder::default(); let mut param_rtys = IndexVec::::new(); @@ -195,19 +350,20 @@ where .unwrap_or_else(|| { if idx == self.param_tys.len() - 1 { if let Some(param_refinement) = &self.param_refinement { - let ty = UnrefinedTypeGeneratorWrapper(&mut self.gen) - .unrefined_ty(param_ty.ty); + let ty = TypeBuilder::new(self.tcx).build(param_ty.ty); rty::RefinedType::new(ty.vacuous(), param_refinement.clone()) } else { - self.gen - .build_template_ty_with_scope(&builder) - .refined_ty(param_ty.ty) + TypeBuilder::new(self.tcx) + .for_template(self.registry) + .with_scope(&builder) + .build_refined(param_ty.ty) } } else { rty::RefinedType::unrefined( - self.gen - .build_template_ty_with_scope(&builder) - .ty(param_ty.ty), + TypeBuilder::new(self.tcx) + .for_template(self.registry) + .with_scope(&builder) + .build(param_ty.ty), ) } }); @@ -227,208 +383,21 @@ where let param_rty = if let Some(param_refinement) = &self.param_refinement { rty::RefinedType::new(rty::Type::unit(), param_refinement.clone()) } else { - let unit_ty = mir_ty::Ty::new_unit(self.gen.tcx()); - self.gen - .build_template_ty_with_scope(&builder) - .refined_ty(unit_ty) + let unit_ty = mir_ty::Ty::new_unit(self.tcx); + TypeBuilder::new(self.tcx) + .for_template(self.registry) + .with_scope(&builder) + .build_refined(unit_ty) }; param_rtys.push(param_rty); } let ret_rty = self.ret_rty.clone().unwrap_or_else(|| { - self.gen - .build_template_ty_with_scope(&builder) - .refined_ty(self.ret_ty) + TypeBuilder::new(self.tcx) + .for_template(self.registry) + .with_scope(&builder) + .build_refined(self.ret_ty) }); rty::FunctionType::new(param_rtys, ret_rty) } } - -#[derive(Debug)] -pub struct BasicBlockTemplateTypeBuilder<'a, 'tcx, T: ?Sized> { - inner: FunctionTemplateTypeBuilder<'a, 'tcx, T>, - locals: IndexVec, -} - -impl<'a, 'tcx, T> BasicBlockTemplateTypeBuilder<'a, 'tcx, T> -where - T: TemplateTypeGenerator<'tcx> + ?Sized, -{ - #[allow(dead_code)] - pub fn param_refinement( - &mut self, - refinement: rty::Refinement, - ) -> &mut Self { - self.inner.param_refinement(refinement); - self - } - - #[allow(dead_code)] - pub fn ret_rty(&mut self, rty: rty::RefinedType) -> &mut Self { - self.inner.ret_rty(rty); - self - } - - pub fn build(&mut self) -> BasicBlockType { - let ty = self.inner.build(); - BasicBlockType { - ty, - locals: self.locals.clone(), - } - } -} - -#[derive(Debug)] -pub struct TemplateTypeBuilder<'a, T: ?Sized, U, V> { - // can't use T: TemplateTypeGenerator<'tcx> directly because of recursive instantiation - gen: &'a mut T, - scope: U, - _marker: std::marker::PhantomData V>, -} - -impl<'a, 'tcx, T, U, V> TemplateTypeBuilder<'a, T, U, V> -where - T: TemplateTypeGenerator<'tcx> + ?Sized, - U: TemplateScope, - V: chc::Var, -{ - pub fn ty(&mut self, ty: mir_ty::Ty<'tcx>) -> rty::Type { - match ty.kind() { - mir_ty::TyKind::Bool => rty::Type::bool(), - mir_ty::TyKind::Uint(_) | mir_ty::TyKind::Int(_) => rty::Type::int(), - mir_ty::TyKind::Str => rty::Type::string(), - mir_ty::TyKind::Ref(_, elem_ty, mutbl) => { - let elem_ty = self.ty(*elem_ty); - match mutbl { - mir_ty::Mutability::Mut => rty::PointerType::mut_to(elem_ty).into(), - mir_ty::Mutability::Not => rty::PointerType::immut_to(elem_ty).into(), - } - } - mir_ty::TyKind::Tuple(ts) => { - // elaboration: all fields are boxed - let elems = ts - .iter() - .map(|ty| rty::PointerType::own(self.ty(ty)).into()) - .collect(); - rty::TupleType::new(elems).into() - } - mir_ty::TyKind::Never => rty::Type::never(), - mir_ty::TyKind::Param(ty) => rty::ParamType::new(ty.index.into()).into(), - mir_ty::TyKind::FnPtr(sig) => { - // TODO: justification for skip_binder - let sig = sig.skip_binder(); - let ty = self.gen.function_template_ty(sig); - rty::Type::function(ty) - } - mir_ty::TyKind::Adt(def, params) if def.is_box() => { - rty::PointerType::own(self.ty(params.type_at(0))).into() - } - mir_ty::TyKind::Adt(def, params) => { - if def.is_enum() { - let sym = refine::datatype_symbol(self.gen.tcx(), def.did()); - let args: IndexVec<_, _> = - params.types().map(|ty| self.refined_ty(ty)).collect(); - rty::EnumType::new(sym, args).into() - } else if def.is_struct() { - let elem_tys = def - .all_fields() - .map(|field| { - let ty = field.ty(self.gen.tcx(), params); - // elaboration: all fields are boxed - rty::PointerType::own(self.ty(ty)).into() - }) - .collect(); - rty::TupleType::new(elem_tys).into() - } else { - unimplemented!("unsupported ADT: {:?}", ty); - } - } - kind => unimplemented!("ty: {:?}", kind), - } - } - - pub fn refined_ty(&mut self, ty: mir_ty::Ty<'tcx>) -> rty::RefinedType { - // TODO: consider building ty with scope - let ty = self.gen.build_template_ty().ty(ty); - let tmpl = self.scope.build_template().build(ty); - self.gen.register_template(tmpl) - } -} - -pub trait UnrefinedTypeGenerator<'tcx> { - fn tcx(&self) -> mir_ty::TyCtxt<'tcx>; - - // TODO: consolidate two defs - fn unrefined_ty(&mut self, ty: mir_ty::Ty<'tcx>) -> rty::Type { - match ty.kind() { - mir_ty::TyKind::Bool => rty::Type::bool(), - mir_ty::TyKind::Uint(_) | mir_ty::TyKind::Int(_) => rty::Type::int(), - mir_ty::TyKind::Str => rty::Type::string(), - mir_ty::TyKind::Ref(_, elem_ty, mutbl) => { - let elem_ty = self.unrefined_ty(*elem_ty); - match mutbl { - mir_ty::Mutability::Mut => rty::PointerType::mut_to(elem_ty).into(), - mir_ty::Mutability::Not => rty::PointerType::immut_to(elem_ty).into(), - } - } - mir_ty::TyKind::Tuple(ts) => { - // elaboration: all fields are boxed - let elems = ts - .iter() - .map(|ty| rty::PointerType::own(self.unrefined_ty(ty)).into()) - .collect(); - rty::TupleType::new(elems).into() - } - mir_ty::TyKind::Never => rty::Type::never(), - mir_ty::TyKind::Param(ty) => rty::ParamType::new(ty.index.into()).into(), - mir_ty::TyKind::FnPtr(sig) => { - // TODO: justification for skip_binder - let sig = sig.skip_binder(); - let params = sig - .inputs() - .iter() - .map(|ty| rty::RefinedType::unrefined(self.unrefined_ty(*ty)).vacuous()) - .collect(); - let ret = rty::RefinedType::unrefined(self.unrefined_ty(sig.output())); - rty::FunctionType::new(params, ret.vacuous()).into() - } - mir_ty::TyKind::Adt(def, params) if def.is_box() => { - rty::PointerType::own(self.unrefined_ty(params.type_at(0))).into() - } - mir_ty::TyKind::Adt(def, params) => { - if def.is_enum() { - let sym = refine::datatype_symbol(self.tcx(), def.did()); - let args: IndexVec<_, _> = params - .types() - .map(|ty| rty::RefinedType::unrefined(self.unrefined_ty(ty))) - .collect(); - rty::EnumType::new(sym, args).into() - } else if def.is_struct() { - let elem_tys = def - .all_fields() - .map(|field| { - let ty = field.ty(self.tcx(), params); - // elaboration: all fields are boxed - rty::PointerType::own(self.unrefined_ty(ty)).into() - }) - .collect(); - rty::TupleType::new(elem_tys).into() - } else { - unimplemented!("unsupported ADT: {:?}", ty); - } - } - kind => unimplemented!("unrefined_ty: {:?}", kind), - } - } -} - -struct UnrefinedTypeGeneratorWrapper(T); - -impl<'tcx, T> UnrefinedTypeGenerator<'tcx> for UnrefinedTypeGeneratorWrapper -where - T: TemplateTypeGenerator<'tcx>, -{ - fn tcx(&self) -> mir_ty::TyCtxt<'tcx> { - self.0.tcx() - } -} From 69d34556691a4b47138ad74789a4791545c3a941 Mon Sep 17 00:00:00 2001 From: coord_e Date: Fri, 24 Oct 2025 17:21:14 +0900 Subject: [PATCH 19/75] Handle parameter shifting in TypeBuilder --- src/analyze/basic_block.rs | 20 ++++++---- src/analyze/crate_.rs | 61 +++++++---------------------- src/analyze/local_def.rs | 2 +- src/refine/template.rs | 80 +++++++++++++++++++++++++------------- 4 files changed, 82 insertions(+), 81 deletions(-) diff --git a/src/analyze/basic_block.rs b/src/analyze/basic_block.rs index 2258f33..7b346ce 100644 --- a/src/analyze/basic_block.rs +++ b/src/analyze/basic_block.rs @@ -56,6 +56,10 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { self.ctx.basic_block_ty(self.local_def_id, bb) } + fn type_builder(&self) -> TypeBuilder<'tcx> { + TypeBuilder::new(self.tcx, self.local_def_id.to_def_id()) + } + fn bind_local(&mut self, local: Local, rty: rty::RefinedType) { let rty = if self.is_mut_local(local) { // elaboration: @@ -222,7 +226,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { let rty_args: IndexVec<_, _> = args .types() .map(|ty| { - TypeBuilder::new(self.tcx) + self.type_builder() .for_template(&mut self.ctx) .with_scope(&self.env) .build_refined(ty) @@ -435,7 +439,8 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { let func_ty = match func.const_fn_def() { // TODO: move this to well-known defs? Some((def_id, args)) if self.is_box_new(def_id) => { - let inner_ty = TypeBuilder::new(self.tcx) + let inner_ty = self + .type_builder() .for_template(&mut self.ctx) .build(args.type_at(0)) .vacuous(); @@ -449,7 +454,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { rty::FunctionType::new([param].into_iter().collect(), ret).into() } Some((def_id, args)) if self.is_mem_swap(def_id) => { - let inner_ty = TypeBuilder::new(self.tcx).build(args.type_at(0)).vacuous(); + let inner_ty = self.type_builder().build(args.type_at(0)).vacuous(); let param1 = rty::RefinedType::unrefined(rty::PointerType::mut_to(inner_ty.clone()).into()); let param2 = @@ -536,7 +541,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { } fn add_prophecy_var(&mut self, statement_index: usize, ty: mir_ty::Ty<'tcx>) { - let ty = TypeBuilder::new(self.tcx).build(ty); + let ty = self.type_builder().build(ty); let temp_var = self.env.push_temp_var(ty.vacuous()); self.prophecy_vars.insert(statement_index, temp_var); tracing::debug!(stmt_idx = %statement_index, temp_var = ?temp_var, "add_prophecy_var"); @@ -557,7 +562,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { referent: mir::Place<'tcx>, prophecy_ty: mir_ty::Ty<'tcx>, ) -> rty::RefinedType { - let prophecy_ty = TypeBuilder::new(self.tcx).build(prophecy_ty); + let prophecy_ty = self.type_builder().build(prophecy_ty); let prophecy = self.env.push_temp_var(prophecy_ty.vacuous()); let place = self.elaborate_place_for_borrow(&referent); self.env.borrow_place(place, prophecy).into() @@ -669,7 +674,8 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { } let decl = self.local_decls[destination].clone(); - let rty = TypeBuilder::new(self.tcx) + let rty = self + .type_builder() .for_template(&mut self.ctx) .with_scope(&self.env) .build_refined(decl.ty); @@ -743,7 +749,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { #[tracing::instrument(skip(self))] fn ret_template(&mut self) -> rty::RefinedType { let ret_ty = self.body.local_decls[mir::RETURN_PLACE].ty; - TypeBuilder::new(self.tcx) + self.type_builder() .for_template(&mut self.ctx) .with_scope(&self.env) .build_refined(ret_ty) diff --git a/src/analyze/crate_.rs b/src/analyze/crate_.rs index 727b496..54a02d1 100644 --- a/src/analyze/crate_.rs +++ b/src/analyze/crate_.rs @@ -132,13 +132,13 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { let mut param_resolver = analyze::annot::ParamResolver::default(); for (input_ident, input_ty) in self.tcx.fn_arg_names(def_id).iter().zip(sig.inputs()) { - let input_ty = TypeBuilder::new(self.tcx).build(*input_ty); + let input_ty = TypeBuilder::new(self.tcx, def_id).build(*input_ty); param_resolver.push_param(input_ident.name, input_ty.to_sort()); } let mut require_annot = self.extract_require_annot(¶m_resolver, def_id); let mut ensure_annot = { - let output_ty = TypeBuilder::new(self.tcx).build(sig.output()); + let output_ty = TypeBuilder::new(self.tcx, def_id).build(sig.output()); let resolver = annot::StackedResolver::default() .resolver(analyze::annot::ResultResolver::new(output_ty.to_sort())) .resolver((¶m_resolver).map(rty::RefinedTypeVar::Free)); @@ -175,7 +175,8 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { self.trusted.insert(def_id); } - let mut builder = TypeBuilder::new(self.tcx).for_function_template(&mut self.ctx, sig); + let mut builder = + TypeBuilder::new(self.tcx, def_id).for_function_template(&mut self.ctx, sig); if let Some(AnnotFormula::Formula(require)) = require_annot { let formula = require.map_var(|idx| { if idx.index() == sig.inputs().len() - 1 { @@ -252,28 +253,6 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { }; let adt = self.tcx.adt_def(local_def_id); - // The index of TyKind::ParamTy is based on the every generic parameters in - // the definition, including lifetimes. Given the following definition: - // - // struct X<'a, T> { f: &'a T } - // - // The type of field `f` is &T1 (not T0). However, in Thrust, we ignore lifetime - // parameters and the index of rty::ParamType is based on type parameters only. - // We're building a mapping from the original index to the new index here. - let generics = self.tcx.generics_of(local_def_id); - let mut type_param_mapping: std::collections::HashMap = - Default::default(); - for i in 0..generics.count() { - let generic_param = generics.param_at(i, self.tcx); - match generic_param.kind { - mir_ty::GenericParamDefKind::Lifetime => {} - mir_ty::GenericParamDefKind::Type { .. } => { - type_param_mapping.insert(i, type_param_mapping.len()); - } - mir_ty::GenericParamDefKind::Const { .. } => unimplemented!(), - } - } - let name = refine::datatype_symbol(self.tcx, local_def_id.to_def_id()); let variants: IndexVec<_, _> = adt .variants() @@ -287,27 +266,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { .iter() .map(|field| { let field_ty = self.tcx.type_of(field.did).instantiate_identity(); - - // see the comment above about this mapping - let subst = rty::TypeParamSubst::new( - type_param_mapping - .iter() - .map(|(old, new)| { - let old = rty::TypeParamIdx::from(*old); - let new = - rty::ParamType::new(rty::TypeParamIdx::from(*new)); - (old, rty::RefinedType::unrefined(new.into())) - }) - .collect(), - ); - - // the subst doesn't contain refinements, so it's OK to take ty only - // after substitution - let mut field_rty = rty::RefinedType::unrefined( - TypeBuilder::new(self.tcx).build(field_ty), - ); - field_rty.subst_ty_params(&subst); - field_rty.ty + TypeBuilder::new(self.tcx, local_def_id.to_def_id()).build(field_ty) }) .collect(); rty::EnumVariantDef { @@ -318,7 +277,15 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { }) .collect(); - let ty_params = type_param_mapping.len(); + let generics = self.tcx.generics_of(local_def_id); + let ty_params = (0..generics.count()) + .filter(|idx| { + matches!( + generics.param_at(*idx, self.tcx).kind, + mir_ty::GenericParamDefKind::Type { .. } + ) + }) + .count(); tracing::debug!(?local_def_id, ?name, ?ty_params, "ty_params count"); let def = rty::EnumDatatypeDef { diff --git a/src/analyze/local_def.rs b/src/analyze/local_def.rs index c1ab72c..7e7c737 100644 --- a/src/analyze/local_def.rs +++ b/src/analyze/local_def.rs @@ -306,7 +306,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { } // function return type is basic block return type let ret_ty = self.body.local_decls[mir::RETURN_PLACE].ty; - let rty = TypeBuilder::new(self.tcx) + let rty = TypeBuilder::new(self.tcx, self.local_def_id.to_def_id()) .for_template(&mut self.ctx) .build_basic_block(live_locals, ret_ty); self.ctx.register_basic_block_ty(self.local_def_id, bb, rty); diff --git a/src/refine/template.rs b/src/refine/template.rs index a2380e0..1614b38 100644 --- a/src/refine/template.rs +++ b/src/refine/template.rs @@ -3,6 +3,7 @@ use std::collections::HashMap; use rustc_index::IndexVec; use rustc_middle::mir::{Local, Mutability}; use rustc_middle::ty as mir_ty; +use rustc_span::def_id::DefId; use super::basic_block::BasicBlockType; use crate::chc; @@ -60,11 +61,43 @@ where #[derive(Clone)] pub struct TypeBuilder<'tcx> { tcx: mir_ty::TyCtxt<'tcx>, + type_param_mapping: HashMap, } impl<'tcx> TypeBuilder<'tcx> { - pub fn new(tcx: mir_ty::TyCtxt<'tcx>) -> Self { - Self { tcx } + pub fn new(tcx: mir_ty::TyCtxt<'tcx>, def_id: DefId) -> Self { + // The index of TyKind::ParamTy is based on the every generic parameters in + // the definition, including lifetimes. Given the following definition: + // + // struct X<'a, T> { f: &'a T } + // + // The type of field `f` is &T1 (not T0). However, in Thrust, we ignore lifetime + // parameters and the index of rty::ParamType is based on type parameters only. + // We're building a mapping from the original index to the new index here. + let generics = tcx.generics_of(def_id); + let mut type_param_mapping: HashMap = Default::default(); + for i in 0..generics.count() { + let generic_param = generics.param_at(i, tcx); + match generic_param.kind { + mir_ty::GenericParamDefKind::Lifetime => {} + mir_ty::GenericParamDefKind::Type { .. } => { + type_param_mapping.insert(i as u32, type_param_mapping.len().into()); + } + mir_ty::GenericParamDefKind::Const { .. } => unimplemented!(), + } + } + Self { + tcx, + type_param_mapping, + } + } + + fn translate_param_type(&self, ty: &mir_ty::ParamTy) -> rty::ParamType { + let index = *self + .type_param_mapping + .get(&ty.index) + .expect("unknown type param idx"); + rty::ParamType::new(index) } // TODO: consolidate two impls @@ -89,7 +122,7 @@ impl<'tcx> TypeBuilder<'tcx> { rty::TupleType::new(elems).into() } mir_ty::TyKind::Never => rty::Type::never(), - mir_ty::TyKind::Param(ty) => rty::ParamType::new(ty.index.into()).into(), + mir_ty::TyKind::Param(ty) => self.translate_param_type(ty).into(), mir_ty::TyKind::FnPtr(sig) => { // TODO: justification for skip_binder let sig = sig.skip_binder(); @@ -135,7 +168,7 @@ impl<'tcx> TypeBuilder<'tcx> { registry: &'a mut R, ) -> TemplateTypeBuilder<'tcx, 'a, R, EmptyTemplateScope> { TemplateTypeBuilder { - tcx: self.tcx, + inner: self.clone(), registry, scope: Default::default(), } @@ -147,7 +180,7 @@ impl<'tcx> TypeBuilder<'tcx> { sig: mir_ty::FnSig<'tcx>, ) -> FunctionTemplateTypeBuilder<'tcx, 'a, R> { FunctionTemplateTypeBuilder { - tcx: self.tcx, + inner: self.clone(), registry, param_tys: sig .inputs() @@ -166,7 +199,7 @@ impl<'tcx> TypeBuilder<'tcx> { } pub struct TemplateTypeBuilder<'tcx, 'a, R, S> { - tcx: mir_ty::TyCtxt<'tcx>, + inner: TypeBuilder<'tcx>, registry: &'a mut R, scope: S, } @@ -174,7 +207,7 @@ pub struct TemplateTypeBuilder<'tcx, 'a, R, S> { impl<'tcx, 'a, R, S> TemplateTypeBuilder<'tcx, 'a, R, S> { pub fn with_scope(self, scope: T) -> TemplateTypeBuilder<'tcx, 'a, R, T> { TemplateTypeBuilder { - tcx: self.tcx, + inner: self.inner, registry: self.registry, scope, } @@ -207,13 +240,11 @@ where rty::TupleType::new(elems).into() } mir_ty::TyKind::Never => rty::Type::never(), - mir_ty::TyKind::Param(ty) => rty::ParamType::new(ty.index.into()).into(), + mir_ty::TyKind::Param(ty) => self.inner.translate_param_type(ty).into(), mir_ty::TyKind::FnPtr(sig) => { // TODO: justification for skip_binder let sig = sig.skip_binder(); - let ty = TypeBuilder::new(self.tcx) - .for_function_template(self.registry, sig) - .build(); + let ty = self.inner.for_function_template(self.registry, sig).build(); rty::Type::function(ty) } mir_ty::TyKind::Adt(def, params) if def.is_box() => { @@ -221,7 +252,7 @@ where } mir_ty::TyKind::Adt(def, params) => { if def.is_enum() { - let sym = refine::datatype_symbol(self.tcx, def.did()); + let sym = refine::datatype_symbol(self.inner.tcx, def.did()); let args: IndexVec<_, _> = params.types().map(|ty| self.build_refined(ty)).collect(); rty::EnumType::new(sym, args).into() @@ -229,7 +260,7 @@ where let elem_tys = def .all_fields() .map(|field| { - let ty = field.ty(self.tcx, params); + let ty = field.ty(self.inner.tcx, params); // elaboration: all fields are boxed rty::PointerType::own(self.build(ty)).into() }) @@ -245,10 +276,7 @@ where pub fn build_refined(&mut self, ty: mir_ty::Ty<'tcx>) -> rty::RefinedType { // TODO: consider building ty with scope - let ty = TypeBuilder::new(self.tcx) - .for_template(self.registry) - .build(ty) - .vacuous(); + let ty = self.inner.for_template(self.registry).build(ty).vacuous(); let tmpl = self.scope.build_template().build(ty); self.registry.register_template(tmpl) } @@ -269,7 +297,7 @@ where tys.push(ty); } let ty = FunctionTemplateTypeBuilder { - tcx: self.tcx, + inner: self.inner.clone(), registry: self.registry, param_tys: tys, ret_ty, @@ -283,7 +311,7 @@ where } pub struct FunctionTemplateTypeBuilder<'tcx, 'a, R> { - tcx: mir_ty::TyCtxt<'tcx>, + inner: TypeBuilder<'tcx>, registry: &'a mut R, param_tys: Vec>, ret_ty: mir_ty::Ty<'tcx>, @@ -324,7 +352,7 @@ impl<'tcx, 'a, R> FunctionTemplateTypeBuilder<'tcx, 'a, R> { &mut self, refinement: rty::Refinement, ) -> &mut Self { - let ty = TypeBuilder::new(self.tcx).build(self.ret_ty); + let ty = self.inner.build(self.ret_ty); self.ret_rty = Some(rty::RefinedType::new(ty.vacuous(), refinement)); self } @@ -350,17 +378,17 @@ where .unwrap_or_else(|| { if idx == self.param_tys.len() - 1 { if let Some(param_refinement) = &self.param_refinement { - let ty = TypeBuilder::new(self.tcx).build(param_ty.ty); + let ty = self.inner.build(param_ty.ty); rty::RefinedType::new(ty.vacuous(), param_refinement.clone()) } else { - TypeBuilder::new(self.tcx) + self.inner .for_template(self.registry) .with_scope(&builder) .build_refined(param_ty.ty) } } else { rty::RefinedType::unrefined( - TypeBuilder::new(self.tcx) + self.inner .for_template(self.registry) .with_scope(&builder) .build(param_ty.ty), @@ -383,8 +411,8 @@ where let param_rty = if let Some(param_refinement) = &self.param_refinement { rty::RefinedType::new(rty::Type::unit(), param_refinement.clone()) } else { - let unit_ty = mir_ty::Ty::new_unit(self.tcx); - TypeBuilder::new(self.tcx) + let unit_ty = mir_ty::Ty::new_unit(self.inner.tcx); + self.inner .for_template(self.registry) .with_scope(&builder) .build_refined(unit_ty) @@ -393,7 +421,7 @@ where } let ret_rty = self.ret_rty.clone().unwrap_or_else(|| { - TypeBuilder::new(self.tcx) + self.inner .for_template(self.registry) .with_scope(&builder) .build_refined(self.ret_ty) From df2a843d2a6c293949872fa1a7783ea4d3e67b13 Mon Sep 17 00:00:00 2001 From: coord_e Date: Sat, 25 Oct 2025 11:24:19 +0900 Subject: [PATCH 20/75] Enhance docs --- src/refine/template.rs | 38 +++++++++++++++++++++++++------------- src/rty/params.rs | 14 ++++++++++++++ 2 files changed, 39 insertions(+), 13 deletions(-) diff --git a/src/refine/template.rs b/src/refine/template.rs index 1614b38..bec3f8a 100644 --- a/src/refine/template.rs +++ b/src/refine/template.rs @@ -23,6 +23,7 @@ where } } +/// [`TemplateScope`] with no variables in scope. #[derive(Clone, Default)] pub struct EmptyTemplateScope; @@ -58,43 +59,45 @@ where } } +/// Translates [`mir_ty::Ty`] to [`rty::Type`]. +/// +/// This struct implements a translation from Rust MIR types to Thrust types. +/// Thrust types may contain refinement predicates which do not exist in MIR types, +/// and [`TypeBuilder`] solely builds types with null refinement (true) in +/// [`TypeBuilder::build`]. This also provides [`TypeBuilder::for_template`] to build +/// refinement types by filling unknown predicates with templates with predicate variables. #[derive(Clone)] pub struct TypeBuilder<'tcx> { tcx: mir_ty::TyCtxt<'tcx>, - type_param_mapping: HashMap, + /// Maps index in [`mir_ty::ParamTy`] to [`rty::TypeParamIdx`]. + /// These indices may differ because we skip lifetime parameters. + /// See [`rty::TypeParamIdx`] for more details. + param_idx_mapping: HashMap, } impl<'tcx> TypeBuilder<'tcx> { pub fn new(tcx: mir_ty::TyCtxt<'tcx>, def_id: DefId) -> Self { - // The index of TyKind::ParamTy is based on the every generic parameters in - // the definition, including lifetimes. Given the following definition: - // - // struct X<'a, T> { f: &'a T } - // - // The type of field `f` is &T1 (not T0). However, in Thrust, we ignore lifetime - // parameters and the index of rty::ParamType is based on type parameters only. - // We're building a mapping from the original index to the new index here. let generics = tcx.generics_of(def_id); - let mut type_param_mapping: HashMap = Default::default(); + let mut param_idx_mapping: HashMap = Default::default(); for i in 0..generics.count() { let generic_param = generics.param_at(i, tcx); match generic_param.kind { mir_ty::GenericParamDefKind::Lifetime => {} mir_ty::GenericParamDefKind::Type { .. } => { - type_param_mapping.insert(i as u32, type_param_mapping.len().into()); + param_idx_mapping.insert(i as u32, param_idx_mapping.len().into()); } mir_ty::GenericParamDefKind::Const { .. } => unimplemented!(), } } Self { tcx, - type_param_mapping, + param_idx_mapping, } } fn translate_param_type(&self, ty: &mir_ty::ParamTy) -> rty::ParamType { let index = *self - .type_param_mapping + .param_idx_mapping .get(&ty.index) .expect("unknown type param idx"); rty::ParamType::new(index) @@ -198,8 +201,16 @@ impl<'tcx> TypeBuilder<'tcx> { } } +/// Translates [`mir_ty::Ty`] to [`rty::Type`] using templates for refinements. +/// +/// [`rty::Template`] is a refinement type in the form of `{ T | P(x1, ..., xn) }` where `P` is a +/// predicate variable. When constructing a template, we need to know which variables can affect the +/// predicate of the template (dependencies, `x1, ..., xn`), and they are provided by the +/// [`TemplateScope`]. No variables are in scope by default and you can provide a scope using +/// [`TemplateTypeBuilder::with_scope`]. pub struct TemplateTypeBuilder<'tcx, 'a, R, S> { inner: TypeBuilder<'tcx>, + // XXX: this can't be simply `R` because monomorphization instantiates types recursively registry: &'a mut R, scope: S, } @@ -310,6 +321,7 @@ where } } +/// A builder for function template types. pub struct FunctionTemplateTypeBuilder<'tcx, 'a, R> { inner: TypeBuilder<'tcx>, registry: &'a mut R, diff --git a/src/rty/params.rs b/src/rty/params.rs index b57ff55..6414f6a 100644 --- a/src/rty/params.rs +++ b/src/rty/params.rs @@ -11,6 +11,20 @@ use super::{Closed, RefinedType}; rustc_index::newtype_index! { /// An index representing a type parameter. + /// + /// ## Note on indexing of type parameters + /// + /// The index of [`rustc_middle::ty::ParamTy`] is based on all generic parameters in + /// the definition, including lifetimes. Given the following definition: + /// + /// ```rust + /// struct X<'a, T> { f: &'a T } + /// ``` + /// + /// The type of field `f` is `&T1` (not `&T0`) in MIR. However, in Thrust, we ignore lifetime + /// parameters and the index of [`rty::ParamType`](super::ParamType) is based on type parameters only, giving `f` + /// the type `&T0`. [`TypeBuilder`](crate::refine::TypeBuilder) takes care of this difference when translating MIR + /// types to Thrust types. #[orderable] #[debug_format = "T{}"] pub struct TypeParamIdx { } From bb6118332a61071bc7f75946c8f4c1193f70043c Mon Sep 17 00:00:00 2001 From: coord_e Date: Fri, 24 Oct 2025 17:54:08 +0900 Subject: [PATCH 21/75] Enable to handle annotated polymorphic function --- src/analyze/basic_block.rs | 43 +++++++++++++++++++--------------- src/analyze/crate_.rs | 18 ++++++++++++-- src/analyze/local_def.rs | 12 +++++++++- src/refine/template.rs | 32 +++++++++++++++++++++---- tests/ui/fail/fn_poly_annot.rs | 11 +++++++++ tests/ui/pass/fn_poly_annot.rs | 11 +++++++++ 6 files changed, 101 insertions(+), 26 deletions(-) create mode 100644 tests/ui/fail/fn_poly_annot.rs create mode 100644 tests/ui/pass/fn_poly_annot.rs diff --git a/src/analyze/basic_block.rs b/src/analyze/basic_block.rs index 7b346ce..05b6c1b 100644 --- a/src/analyze/basic_block.rs +++ b/src/analyze/basic_block.rs @@ -33,6 +33,7 @@ pub struct Analyzer<'tcx, 'ctx> { basic_block: BasicBlock, body: Cow<'tcx, Body<'tcx>>, + type_builder: TypeBuilder<'tcx>, env: Env, local_decls: IndexVec>, // TODO: remove this @@ -56,10 +57,6 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { self.ctx.basic_block_ty(self.local_def_id, bb) } - fn type_builder(&self) -> TypeBuilder<'tcx> { - TypeBuilder::new(self.tcx, self.local_def_id.to_def_id()) - } - fn bind_local(&mut self, local: Local, rty: rty::RefinedType) { let rty = if self.is_mut_local(local) { // elaboration: @@ -226,7 +223,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { let rty_args: IndexVec<_, _> = args .types() .map(|ty| { - self.type_builder() + self.type_builder .for_template(&mut self.ctx) .with_scope(&self.env) .build_refined(ty) @@ -440,7 +437,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { // TODO: move this to well-known defs? Some((def_id, args)) if self.is_box_new(def_id) => { let inner_ty = self - .type_builder() + .type_builder .for_template(&mut self.ctx) .build(args.type_at(0)) .vacuous(); @@ -454,7 +451,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { rty::FunctionType::new([param].into_iter().collect(), ret).into() } Some((def_id, args)) if self.is_mem_swap(def_id) => { - let inner_ty = self.type_builder().build(args.type_at(0)).vacuous(); + let inner_ty = self.type_builder.build(args.type_at(0)).vacuous(); let param1 = rty::RefinedType::unrefined(rty::PointerType::mut_to(inner_ty.clone()).into()); let param2 = @@ -472,15 +469,16 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { rty::FunctionType::new([param1, param2].into_iter().collect(), ret).into() } Some((def_id, args)) => { - if !args.is_empty() { - tracing::warn!(?args, ?def_id, "generic args ignored"); + if args.consts().next().is_some() { + tracing::warn!(?args, ?def_id, "const generic args ignored"); } - self.ctx - .def_ty(def_id) - .expect("unknown def") - .ty - .clone() - .vacuous() + let ty_args = args + .types() + .map(|ty| rty::RefinedType::unrefined(self.type_builder.build(ty))) + .collect(); + let mut def_ty = self.ctx.def_ty(def_id).expect("unknown def").clone(); + def_ty.instantiate_ty_params(ty_args); + def_ty.ty.vacuous() } _ => self.operand_type(func.clone()).ty, }; @@ -541,7 +539,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { } fn add_prophecy_var(&mut self, statement_index: usize, ty: mir_ty::Ty<'tcx>) { - let ty = self.type_builder().build(ty); + let ty = self.type_builder.build(ty); let temp_var = self.env.push_temp_var(ty.vacuous()); self.prophecy_vars.insert(statement_index, temp_var); tracing::debug!(stmt_idx = %statement_index, temp_var = ?temp_var, "add_prophecy_var"); @@ -562,7 +560,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { referent: mir::Place<'tcx>, prophecy_ty: mir_ty::Ty<'tcx>, ) -> rty::RefinedType { - let prophecy_ty = self.type_builder().build(prophecy_ty); + let prophecy_ty = self.type_builder.build(prophecy_ty); let prophecy = self.env.push_temp_var(prophecy_ty.vacuous()); let place = self.elaborate_place_for_borrow(&referent); self.env.borrow_place(place, prophecy).into() @@ -675,7 +673,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { let decl = self.local_decls[destination].clone(); let rty = self - .type_builder() + .type_builder .for_template(&mut self.ctx) .with_scope(&self.env) .build_refined(decl.ty); @@ -749,7 +747,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { #[tracing::instrument(skip(self))] fn ret_template(&mut self) -> rty::RefinedType { let ret_ty = self.body.local_decls[mir::RETURN_PLACE].ty; - self.type_builder() + self.type_builder .for_template(&mut self.ctx) .with_scope(&self.env) .build_refined(ret_ty) @@ -955,6 +953,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { let env = ctx.new_env(); let local_decls = body.local_decls.clone(); let prophecy_vars = Default::default(); + let type_builder = TypeBuilder::new(tcx, local_def_id.to_def_id()); Self { ctx, tcx, @@ -962,6 +961,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { drop_points, basic_block, body, + type_builder, env, local_decls, prophecy_vars, @@ -989,6 +989,11 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { self } + pub fn type_builder(&mut self, type_builder: TypeBuilder<'tcx>) -> &mut Self { + self.type_builder = type_builder; + self + } + pub fn run(&mut self, expected: &BasicBlockType) { let span = tracing::info_span!("bb", bb = ?self.basic_block); let _guard = span.enter(); diff --git a/src/analyze/crate_.rs b/src/analyze/crate_.rs index 54a02d1..3e0fb69 100644 --- a/src/analyze/crate_.rs +++ b/src/analyze/crate_.rs @@ -211,8 +211,22 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { tracing::info!(?local_def_id, "trusted"); continue; } - let expected = self.ctx.def_ty(local_def_id.to_def_id()).unwrap().clone(); - self.ctx.local_def_analyzer(*local_def_id).run(&expected); + // check polymorphic function def by replacing type params with some opaque type + let type_builder = TypeBuilder::new(self.tcx, local_def_id.to_def_id()) + .with_param_mapper(|_| rty::Type::int()); + let mut expected = self.ctx.def_ty(local_def_id.to_def_id()).unwrap().clone(); + let subst = rty::TypeParamSubst::new( + expected + .free_ty_params() + .into_iter() + .map(|ty_param| (ty_param, rty::RefinedType::unrefined(rty::Type::int()))) + .collect(), + ); + expected.subst_ty_params(&subst); + self.ctx + .local_def_analyzer(*local_def_id) + .type_builder(type_builder) + .run(&expected); } } diff --git a/src/analyze/local_def.rs b/src/analyze/local_def.rs index 7e7c737..d0c4aa7 100644 --- a/src/analyze/local_def.rs +++ b/src/analyze/local_def.rs @@ -26,6 +26,7 @@ pub struct Analyzer<'tcx, 'ctx> { body: Body<'tcx>, drop_points: HashMap, + type_builder: TypeBuilder<'tcx>, } impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { @@ -306,7 +307,8 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { } // function return type is basic block return type let ret_ty = self.body.local_decls[mir::RETURN_PLACE].ty; - let rty = TypeBuilder::new(self.tcx, self.local_def_id.to_def_id()) + let rty = self + .type_builder .for_template(&mut self.ctx) .build_basic_block(live_locals, ret_ty); self.ctx.register_basic_block_ty(self.local_def_id, bb, rty); @@ -321,6 +323,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { .basic_block_analyzer(self.local_def_id, bb) .body(self.body.clone()) .drop_points(drop_points) + .type_builder(self.type_builder.clone()) .run(&rty); } } @@ -426,15 +429,22 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { let tcx = ctx.tcx; let body = tcx.optimized_mir(local_def_id.to_def_id()).clone(); let drop_points = Default::default(); + let type_builder = TypeBuilder::new(tcx, local_def_id.to_def_id()); Self { ctx, tcx, local_def_id, body, drop_points, + type_builder, } } + pub fn type_builder(&mut self, type_builder: TypeBuilder<'tcx>) -> &mut Self { + self.type_builder = type_builder; + self + } + pub fn run(&mut self, expected: &rty::RefinedType) { let span = tracing::info_span!("def", def = %self.tcx.def_path_str(self.local_def_id)); let _guard = span.enter(); diff --git a/src/refine/template.rs b/src/refine/template.rs index bec3f8a..2ba150a 100644 --- a/src/refine/template.rs +++ b/src/refine/template.rs @@ -59,6 +59,19 @@ where } } +trait ParamTypeMapper { + fn map_param_ty(&self, ty: rty::ParamType) -> rty::Type; +} + +impl ParamTypeMapper for F +where + F: Fn(rty::ParamType) -> rty::Type, +{ + fn map_param_ty(&self, ty: rty::ParamType) -> rty::Type { + self(ty) + } +} + /// Translates [`mir_ty::Ty`] to [`rty::Type`]. /// /// This struct implements a translation from Rust MIR types to Thrust types. @@ -73,6 +86,7 @@ pub struct TypeBuilder<'tcx> { /// These indices may differ because we skip lifetime parameters. /// See [`rty::TypeParamIdx`] for more details. param_idx_mapping: HashMap, + param_type_mapper: std::rc::Rc, } impl<'tcx> TypeBuilder<'tcx> { @@ -92,15 +106,25 @@ impl<'tcx> TypeBuilder<'tcx> { Self { tcx, param_idx_mapping, + param_type_mapper: std::rc::Rc::new(|ty: rty::ParamType| ty.into()), } } - fn translate_param_type(&self, ty: &mir_ty::ParamTy) -> rty::ParamType { + pub fn with_param_mapper(mut self, mapper: F) -> Self + where + F: Fn(rty::ParamType) -> rty::Type + 'static, + { + self.param_type_mapper = std::rc::Rc::new(mapper); + self + } + + fn translate_param_type(&self, ty: &mir_ty::ParamTy) -> rty::Type { let index = *self .param_idx_mapping .get(&ty.index) .expect("unknown type param idx"); - rty::ParamType::new(index) + let param_ty = rty::ParamType::new(index); + self.param_type_mapper.map_param_ty(param_ty) } // TODO: consolidate two impls @@ -125,7 +149,7 @@ impl<'tcx> TypeBuilder<'tcx> { rty::TupleType::new(elems).into() } mir_ty::TyKind::Never => rty::Type::never(), - mir_ty::TyKind::Param(ty) => self.translate_param_type(ty).into(), + mir_ty::TyKind::Param(ty) => self.translate_param_type(ty), mir_ty::TyKind::FnPtr(sig) => { // TODO: justification for skip_binder let sig = sig.skip_binder(); @@ -251,7 +275,7 @@ where rty::TupleType::new(elems).into() } mir_ty::TyKind::Never => rty::Type::never(), - mir_ty::TyKind::Param(ty) => self.inner.translate_param_type(ty).into(), + mir_ty::TyKind::Param(ty) => self.inner.translate_param_type(ty).vacuous(), mir_ty::TyKind::FnPtr(sig) => { // TODO: justification for skip_binder let sig = sig.skip_binder(); diff --git a/tests/ui/fail/fn_poly_annot.rs b/tests/ui/fail/fn_poly_annot.rs new file mode 100644 index 0000000..2458a98 --- /dev/null +++ b/tests/ui/fail/fn_poly_annot.rs @@ -0,0 +1,11 @@ +//@error-in-other-file: Unsat + +#[thrust::requires(true)] +#[thrust::ensures(result != x.0)] +fn left(x: (T, U)) -> T { + x.0 +} + +fn main() { + assert!(left((42, 0)) == 42); +} diff --git a/tests/ui/pass/fn_poly_annot.rs b/tests/ui/pass/fn_poly_annot.rs new file mode 100644 index 0000000..3176816 --- /dev/null +++ b/tests/ui/pass/fn_poly_annot.rs @@ -0,0 +1,11 @@ +//@check-pass + +#[thrust::requires(true)] +#[thrust::ensures(result == x.0)] +fn left(x: (T, U)) -> T { + x.0 +} + +fn main() { + assert!(left((42, 0)) == 42); +} From 91a957c6e937171afa8f5d1f7d11db972e71e14c Mon Sep 17 00:00:00 2001 From: coord_e Date: Fri, 24 Oct 2025 23:52:19 +0900 Subject: [PATCH 22/75] Enable to handle unannotated polymorphic function --- src/analyze.rs | 83 +++++++++++++++++++++++++++++++++++--- src/analyze/basic_block.rs | 30 ++++++-------- src/analyze/crate_.rs | 33 ++++++++++++--- src/refine/template.rs | 18 ++++++++- tests/ui/fail/fn_poly.rs | 9 +++++ tests/ui/pass/fn_poly.rs | 9 +++++ 6 files changed, 152 insertions(+), 30 deletions(-) create mode 100644 tests/ui/fail/fn_poly.rs create mode 100644 tests/ui/pass/fn_poly.rs diff --git a/src/analyze.rs b/src/analyze.rs index 3550294..fe8ca00 100644 --- a/src/analyze.rs +++ b/src/analyze.rs @@ -11,13 +11,14 @@ use std::collections::HashMap; use std::rc::Rc; use rustc_hir::lang_items::LangItem; +use rustc_index::IndexVec; use rustc_middle::mir::{self, BasicBlock, Local}; use rustc_middle::ty::{self as mir_ty, TyCtxt}; use rustc_span::def_id::{DefId, LocalDefId}; use crate::chc; use crate::pretty::PrettyDisplayExt as _; -use crate::refine::{self, BasicBlockType}; +use crate::refine::{self, BasicBlockType, TypeBuilder}; use crate::rty; mod annot; @@ -103,6 +104,17 @@ impl<'tcx> ReplacePlacesVisitor<'tcx> { } } +#[derive(Debug, Clone)] +struct DeferredDefTy<'tcx> { + cache: Rc>, rty::RefinedType>>>, +} + +#[derive(Debug, Clone)] +enum DefTy<'tcx> { + Concrete(rty::RefinedType), + Deferred(DeferredDefTy<'tcx>), +} + #[derive(Clone)] pub struct Analyzer<'tcx> { tcx: TyCtxt<'tcx>, @@ -112,7 +124,7 @@ pub struct Analyzer<'tcx> { /// currently contains only local-def templates, /// but will be extended to contain externally known def's refinement types /// (at least for every defs referenced by local def bodies) - defs: HashMap, + defs: HashMap>, /// Resulting CHC system. system: Rc>, @@ -207,11 +219,72 @@ impl<'tcx> Analyzer<'tcx> { pub fn register_def(&mut self, def_id: DefId, rty: rty::RefinedType) { tracing::info!(def_id = ?def_id, rty = %rty.display(), "register_def"); - self.defs.insert(def_id, rty); + self.defs.insert(def_id, DefTy::Concrete(rty)); + } + + pub fn register_deferred_def(&mut self, def_id: DefId) { + tracing::info!(def_id = ?def_id, "register_deferred_def"); + self.defs.insert( + def_id, + DefTy::Deferred(DeferredDefTy { + cache: Rc::new(RefCell::new(HashMap::new())), + }), + ); + } + + pub fn concrete_def_ty(&self, def_id: DefId) -> Option<&rty::RefinedType> { + self.defs.get(&def_id).and_then(|def_ty| match def_ty { + DefTy::Concrete(rty) => Some(rty), + DefTy::Deferred(_) => None, + }) } - pub fn def_ty(&self, def_id: DefId) -> Option<&rty::RefinedType> { - self.defs.get(&def_id) + pub fn def_ty_with_args( + &mut self, + def_id: DefId, + args: mir_ty::GenericArgsRef<'tcx>, + ) -> Option { + let type_builder = TypeBuilder::new(self.tcx, def_id); + let rty_args: IndexVec<_, _> = args.types().map(|ty| type_builder.build(ty)).collect(); + + let deferred_ty = match self.defs.get(&def_id)? { + DefTy::Concrete(rty) => { + let mut def_ty = rty.clone(); + def_ty.instantiate_ty_params( + rty_args + .clone() + .into_iter() + .map(rty::RefinedType::unrefined) + .collect(), + ); + return Some(def_ty); + } + DefTy::Deferred(deferred) => deferred, + }; + + let ty_args: Vec<_> = args.types().collect(); + let deferred_ty_cache = Rc::clone(&deferred_ty.cache); // to cut reference to allow &mut self + if let Some(rty) = deferred_ty_cache.borrow().get(&ty_args) { + return Some(rty.clone()); + } + let local_def_id = def_id.as_local()?; + + let sig = self + .tcx + .fn_sig(def_id) + .instantiate(self.tcx, args) + .skip_binder(); + let expected = self + .crate_analyzer() + .fn_def_ty_with_sig(local_def_id.to_def_id(), sig) + .unwrap(); + self.local_def_analyzer(local_def_id) + .type_builder(type_builder.with_param_mapper(move |ty| rty_args[ty.idx].clone())) + .run(&expected); + deferred_ty_cache + .borrow_mut() + .insert(ty_args, expected.clone()); + Some(expected) } pub fn register_basic_block_ty( diff --git a/src/analyze/basic_block.rs b/src/analyze/basic_block.rs index 05b6c1b..5719406 100644 --- a/src/analyze/basic_block.rs +++ b/src/analyze/basic_block.rs @@ -263,12 +263,12 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { _ty, ) => { let func_ty = match operand.const_fn_def() { - Some((def_id, args)) => { - if !args.is_empty() { - tracing::warn!(?args, ?def_id, "generic args ignored"); - } - self.ctx.def_ty(def_id).expect("unknown def").ty.clone() - } + Some((def_id, args)) => self + .ctx + .def_ty_with_args(def_id, args) + .expect("unknown def") + .ty + .clone(), _ => unimplemented!(), }; PlaceType::with_ty_and_term(func_ty.vacuous(), chc::Term::null()) @@ -468,18 +468,12 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { let ret = rty::RefinedType::new(rty::Type::unit(), ret_formula.into()); rty::FunctionType::new([param1, param2].into_iter().collect(), ret).into() } - Some((def_id, args)) => { - if args.consts().next().is_some() { - tracing::warn!(?args, ?def_id, "const generic args ignored"); - } - let ty_args = args - .types() - .map(|ty| rty::RefinedType::unrefined(self.type_builder.build(ty))) - .collect(); - let mut def_ty = self.ctx.def_ty(def_id).expect("unknown def").clone(); - def_ty.instantiate_ty_params(ty_args); - def_ty.ty.vacuous() - } + Some((def_id, args)) => self + .ctx + .def_ty_with_args(def_id, args) + .expect("unknown def") + .ty + .vacuous(), _ => self.operand_type(func.clone()).ty, }; let expected_args: IndexVec<_, _> = args diff --git a/src/analyze/crate_.rs b/src/analyze/crate_.rs index 3e0fb69..389647b 100644 --- a/src/analyze/crate_.rs +++ b/src/analyze/crate_.rs @@ -128,8 +128,19 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { #[tracing::instrument(skip(self), fields(def_id = %self.tcx.def_path_str(def_id)))] fn refine_fn_def(&mut self, def_id: DefId) { let sig = self.tcx.fn_sig(def_id); - let sig = sig.instantiate_identity().skip_binder(); // TODO: is it OK? + let sig = sig.instantiate_identity().skip_binder(); + if let Some(rty) = self.fn_def_ty_with_sig(def_id, sig) { + self.ctx.register_def(def_id, rty); + } else { + self.ctx.register_deferred_def(def_id); + } + } + pub fn fn_def_ty_with_sig( + &mut self, + def_id: DefId, + sig: mir_ty::FnSig<'tcx>, + ) -> Option { let mut param_resolver = analyze::annot::ParamResolver::default(); for (input_ident, input_ty) in self.tcx.fn_arg_names(def_id).iter().zip(sig.inputs()) { let input_ty = TypeBuilder::new(self.tcx, def_id).build(*input_ty); @@ -198,8 +209,14 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { if let Some(ret_rty) = ret_annot { builder.ret_rty(ret_rty); } - let rty = rty::RefinedType::unrefined(builder.build().into()); - self.ctx.register_def(def_id, rty); + + // can't generate template with type parameter... + use mir_ty::TypeVisitableExt as _; + if builder.would_contain_template() && sig.has_param() { + None + } else { + Some(rty::RefinedType::unrefined(builder.build().into())) + } } fn analyze_local_defs(&mut self) { @@ -211,10 +228,16 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { tracing::info!(?local_def_id, "trusted"); continue; } + let Some(expected) = self.ctx.concrete_def_ty(local_def_id.to_def_id()) else { + // when the local_def_id is deferred it would be skipped + continue; + }; + // check polymorphic function def by replacing type params with some opaque type + // (and this is no-op if the function is mono) let type_builder = TypeBuilder::new(self.tcx, local_def_id.to_def_id()) .with_param_mapper(|_| rty::Type::int()); - let mut expected = self.ctx.def_ty(local_def_id.to_def_id()).unwrap().clone(); + let mut expected = expected.clone(); let subst = rty::TypeParamSubst::new( expected .free_ty_params() @@ -236,7 +259,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { // TODO: replace code here with relate_* in Env + Refine context (created with empty env) let entry_ty = self .ctx - .def_ty(def_id) + .concrete_def_ty(def_id) .unwrap() .ty .as_function() diff --git a/src/refine/template.rs b/src/refine/template.rs index 2ba150a..c859419 100644 --- a/src/refine/template.rs +++ b/src/refine/template.rs @@ -83,9 +83,12 @@ where pub struct TypeBuilder<'tcx> { tcx: mir_ty::TyCtxt<'tcx>, /// Maps index in [`mir_ty::ParamTy`] to [`rty::TypeParamIdx`]. - /// These indices may differ because we skip lifetime parameters. + /// These indices may differ because we skip lifetime parameters and they always need to be + /// mapped when we translate a [`mir_ty::ParamTy`] to [`rty::ParamType`]. /// See [`rty::TypeParamIdx`] for more details. param_idx_mapping: HashMap, + /// Optionally also want to further map rty::ParamType to other rty::Type before generating + /// templates. This is no-op by default. param_type_mapper: std::rc::Rc, } @@ -100,7 +103,7 @@ impl<'tcx> TypeBuilder<'tcx> { mir_ty::GenericParamDefKind::Type { .. } => { param_idx_mapping.insert(i as u32, param_idx_mapping.len().into()); } - mir_ty::GenericParamDefKind::Const { .. } => unimplemented!(), + mir_ty::GenericParamDefKind::Const { .. } => {} } } Self { @@ -397,6 +400,17 @@ impl<'tcx, 'a, R> FunctionTemplateTypeBuilder<'tcx, 'a, R> { self.ret_rty = Some(rty); self } + + pub fn would_contain_template(&self) -> bool { + if self.param_tys.is_empty() { + return self.ret_rty.is_none(); + } + + let last_param_idx = rty::FunctionParamIdx::from(self.param_tys.len() - 1); + let param_annotated = + self.param_refinement.is_some() || self.param_rtys.contains_key(&last_param_idx); + self.ret_rty.is_none() || !param_annotated + } } impl<'tcx, 'a, R> FunctionTemplateTypeBuilder<'tcx, 'a, R> diff --git a/tests/ui/fail/fn_poly.rs b/tests/ui/fail/fn_poly.rs new file mode 100644 index 0000000..15351dd --- /dev/null +++ b/tests/ui/fail/fn_poly.rs @@ -0,0 +1,9 @@ +//@error-in-other-file: Unsat + +fn left(x: (T, U)) -> T { + x.0 +} + +fn main() { + assert!(left((42, 0)) == 0); +} diff --git a/tests/ui/pass/fn_poly.rs b/tests/ui/pass/fn_poly.rs new file mode 100644 index 0000000..4a8e678 --- /dev/null +++ b/tests/ui/pass/fn_poly.rs @@ -0,0 +1,9 @@ +//@check-pass + +fn left(x: (T, U)) -> T { + x.0 +} + +fn main() { + assert!(left((42, 0)) == 42); +} From dd67486c40b2a2337c44afc5f4a1b2f4a1a6df98 Mon Sep 17 00:00:00 2001 From: coord_e Date: Sun, 9 Nov 2025 23:10:46 +0900 Subject: [PATCH 23/75] Implement Eq and Hash for Type --- src/chc.rs | 8 ++++---- src/rty.rs | 20 ++++++++++---------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/src/chc.rs b/src/chc.rs index 8a3309f..6f5db32 100644 --- a/src/chc.rs +++ b/src/chc.rs @@ -389,7 +389,7 @@ impl Function { } /// A logical term. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum Term { Null, Var(V), @@ -984,7 +984,7 @@ impl Pred { } /// An atom is a predicate applied to a list of terms. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct Atom { pub pred: Pred, pub args: Vec>, @@ -1077,7 +1077,7 @@ impl Atom { /// While it allows arbitrary [`Atom`] in its `Atom` variant, we only expect atoms with known /// predicates (i.e., predicates other than `Pred::Var`) to appear in formulas. It is our TODO to /// enforce this restriction statically. Also see the definition of [`Body`]. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum Formula { Atom(Atom), Not(Box>), @@ -1296,7 +1296,7 @@ impl Formula { } /// The body part of a clause, consisting of atoms and a formula. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct Body { pub atoms: Vec>, /// NOTE: This doesn't contain predicate variables. Also see [`Formula`]. diff --git a/src/rty.rs b/src/rty.rs index c706897..eacf26e 100644 --- a/src/rty.rs +++ b/src/rty.rs @@ -88,7 +88,7 @@ where /// In Thrust, function types are closed. Because of that, function types, thus its parameters and /// return type only refer to the parameters of the function itself using [`FunctionParamIdx`] and /// do not accept other type of variables from the environment. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct FunctionType { pub params: IndexVec>, pub ret: Box>, @@ -156,7 +156,7 @@ impl FunctionType { } /// The kind of a reference, which is either mutable or immutable. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum RefKind { Mut, Immut, @@ -181,7 +181,7 @@ where } /// The kind of a pointer, which is either a reference or an owned pointer. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum PointerKind { Ref(RefKind), Own, @@ -221,7 +221,7 @@ impl PointerKind { } /// A pointer type. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct PointerType { pub kind: PointerKind, pub elem: Box>, @@ -334,7 +334,7 @@ impl PointerType { /// Note that the current implementation uses tuples to represent structs. See /// implementation in `crate::refine::template` module for details. /// It is our TODO to improve the struct representation. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct TupleType { pub elems: Vec>, } @@ -458,7 +458,7 @@ impl EnumDatatypeDef { /// An enum type. /// /// An enum type includes its type arguments and the argument types can refer to outer variables `T`. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct EnumType { pub symbol: chc::DatatypeSymbol, pub args: IndexVec>, @@ -560,7 +560,7 @@ impl EnumType { } /// A type parameter. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct ParamType { pub idx: TypeParamIdx, } @@ -589,7 +589,7 @@ impl ParamType { } /// An underlying type of a refinement type. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum Type { Int, Bool, @@ -995,7 +995,7 @@ impl ShiftExistential for RefinedTypeVar { /// A formula, potentially equipped with an existential quantifier. /// /// Note: This is not to be confused with [`crate::chc::Formula`] in the [`crate::chc`] module, which is a different notion. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct Formula { pub existentials: IndexVec, pub body: chc::Body, @@ -1236,7 +1236,7 @@ impl Instantiator { } /// A refinement type. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct RefinedType { pub ty: Type, pub refinement: Refinement, From 23f5e522b9840d7b7067bac690c25c876a66c772 Mon Sep 17 00:00:00 2001 From: coord_e Date: Sun, 9 Nov 2025 23:11:30 +0900 Subject: [PATCH 24/75] Rename TypeArgs to RefinedTypeArgs --- src/rty.rs | 6 +++--- src/rty/params.rs | 19 +++++++++++++++---- 2 files changed, 18 insertions(+), 7 deletions(-) diff --git a/src/rty.rs b/src/rty.rs index eacf26e..ce6ef5e 100644 --- a/src/rty.rs +++ b/src/rty.rs @@ -55,7 +55,7 @@ mod subtyping; pub use subtyping::{relate_sub_closed_type, ClauseScope, Subtyping}; mod params; -pub use params::{TypeArgs, TypeParamIdx, TypeParamSubst}; +pub use params::{RefinedTypeArgs, TypeArgs, TypeParamIdx, TypeParamSubst}; rustc_index::newtype_index! { /// An index representing function parameter. @@ -487,7 +487,7 @@ where } impl EnumType { - pub fn new(symbol: chc::DatatypeSymbol, args: TypeArgs) -> Self { + pub fn new(symbol: chc::DatatypeSymbol, args: RefinedTypeArgs) -> Self { EnumType { symbol, args } } @@ -1372,7 +1372,7 @@ impl RefinedType { } } - pub fn instantiate_ty_params(&mut self, params: TypeArgs) + pub fn instantiate_ty_params(&mut self, params: RefinedTypeArgs) where FV: chc::Var, { diff --git a/src/rty/params.rs b/src/rty/params.rs index 6414f6a..b76b6d1 100644 --- a/src/rty/params.rs +++ b/src/rty/params.rs @@ -7,7 +7,7 @@ use rustc_index::IndexVec; use crate::chc; -use super::{Closed, RefinedType}; +use super::{Closed, RefinedType, Type}; rustc_index::newtype_index! { /// An index representing a type parameter. @@ -53,7 +53,8 @@ impl TypeParamIdx { } } -pub type TypeArgs = IndexVec>; +pub type RefinedTypeArgs = IndexVec>; +pub type TypeArgs = IndexVec>; /// A substitution for type parameters that maps type parameters to refinement types. #[derive(Debug, Clone)] @@ -71,6 +72,16 @@ impl Default for TypeParamSubst { impl From> for TypeParamSubst { fn from(params: TypeArgs) -> Self { + let subst = params + .into_iter_enumerated() + .map(|(idx, ty)| (idx, RefinedType::unrefined(ty))) + .collect(); + Self { subst } + } +} + +impl From> for TypeParamSubst { + fn from(params: RefinedTypeArgs) -> Self { let subst = params.into_iter_enumerated().collect(); Self { subst } } @@ -112,12 +123,12 @@ impl TypeParamSubst { } } - pub fn into_args(mut self, expected_len: usize, mut default: F) -> TypeArgs + pub fn into_args(mut self, expected_len: usize, mut default: F) -> RefinedTypeArgs where T: chc::Var, F: FnMut(TypeParamIdx) -> RefinedType, { - let mut args = TypeArgs::new(); + let mut args = RefinedTypeArgs::new(); for idx in 0..expected_len { let ty = self .subst From b4b1c9b7223d0174d92c0173bf4f9c23e5d5e96f Mon Sep 17 00:00:00 2001 From: coord_e Date: Sun, 9 Nov 2025 11:50:17 +0900 Subject: [PATCH 25/75] Fixes to allow nested polymorphic calls --- src/analyze.rs | 46 +++---- src/analyze/basic_block.rs | 29 +++-- src/analyze/crate_.rs | 191 +++------------------------ src/analyze/local_def.rs | 209 ++++++++++++++++++++++++++++++ tests/ui/fail/adt_poly_fn_poly.rs | 39 ++++++ tests/ui/pass/adt_poly_fn_poly.rs | 39 ++++++ 6 files changed, 339 insertions(+), 214 deletions(-) create mode 100644 tests/ui/fail/adt_poly_fn_poly.rs create mode 100644 tests/ui/pass/adt_poly_fn_poly.rs diff --git a/src/analyze.rs b/src/analyze.rs index fe8ca00..a1ebd34 100644 --- a/src/analyze.rs +++ b/src/analyze.rs @@ -11,7 +11,6 @@ use std::collections::HashMap; use std::rc::Rc; use rustc_hir::lang_items::LangItem; -use rustc_index::IndexVec; use rustc_middle::mir::{self, BasicBlock, Local}; use rustc_middle::ty::{self as mir_ty, TyCtxt}; use rustc_span::def_id::{DefId, LocalDefId}; @@ -105,14 +104,14 @@ impl<'tcx> ReplacePlacesVisitor<'tcx> { } #[derive(Debug, Clone)] -struct DeferredDefTy<'tcx> { - cache: Rc>, rty::RefinedType>>>, +struct DeferredDefTy { + cache: Rc>>, } #[derive(Debug, Clone)] -enum DefTy<'tcx> { +enum DefTy { Concrete(rty::RefinedType), - Deferred(DeferredDefTy<'tcx>), + Deferred(DeferredDefTy), } #[derive(Clone)] @@ -124,7 +123,7 @@ pub struct Analyzer<'tcx> { /// currently contains only local-def templates, /// but will be extended to contain externally known def's refinement types /// (at least for every defs referenced by local def bodies) - defs: HashMap>, + defs: HashMap, /// Resulting CHC system. system: Rc>, @@ -242,11 +241,8 @@ impl<'tcx> Analyzer<'tcx> { pub fn def_ty_with_args( &mut self, def_id: DefId, - args: mir_ty::GenericArgsRef<'tcx>, + rty_args: rty::TypeArgs, ) -> Option { - let type_builder = TypeBuilder::new(self.tcx, def_id); - let rty_args: IndexVec<_, _> = args.types().map(|ty| type_builder.build(ty)).collect(); - let deferred_ty = match self.defs.get(&def_id)? { DefTy::Concrete(rty) => { let mut def_ty = rty.clone(); @@ -262,28 +258,24 @@ impl<'tcx> Analyzer<'tcx> { DefTy::Deferred(deferred) => deferred, }; - let ty_args: Vec<_> = args.types().collect(); let deferred_ty_cache = Rc::clone(&deferred_ty.cache); // to cut reference to allow &mut self - if let Some(rty) = deferred_ty_cache.borrow().get(&ty_args) { + if let Some(rty) = deferred_ty_cache.borrow().get(&rty_args) { return Some(rty.clone()); } - let local_def_id = def_id.as_local()?; - - let sig = self - .tcx - .fn_sig(def_id) - .instantiate(self.tcx, args) - .skip_binder(); - let expected = self - .crate_analyzer() - .fn_def_ty_with_sig(local_def_id.to_def_id(), sig) - .unwrap(); - self.local_def_analyzer(local_def_id) - .type_builder(type_builder.with_param_mapper(move |ty| rty_args[ty.idx].clone())) - .run(&expected); + + let type_builder = TypeBuilder::new(self.tcx, def_id).with_param_mapper({ + let rty_args = rty_args.clone(); + move |ty: rty::ParamType| rty_args[ty.idx].clone() + }); + let mut analyzer = self.local_def_analyzer(def_id.as_local()?); + analyzer.type_builder(type_builder); + + let expected = analyzer.expected_ty(); deferred_ty_cache .borrow_mut() - .insert(ty_args, expected.clone()); + .insert(rty_args, expected.clone()); + + analyzer.run(&expected); Some(expected) } diff --git a/src/analyze/basic_block.rs b/src/analyze/basic_block.rs index 5719406..9f6ee49 100644 --- a/src/analyze/basic_block.rs +++ b/src/analyze/basic_block.rs @@ -263,12 +263,15 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { _ty, ) => { let func_ty = match operand.const_fn_def() { - Some((def_id, args)) => self - .ctx - .def_ty_with_args(def_id, args) - .expect("unknown def") - .ty - .clone(), + Some((def_id, args)) => { + let rty_args: IndexVec<_, _> = + args.types().map(|ty| self.type_builder.build(ty)).collect(); + self.ctx + .def_ty_with_args(def_id, rty_args) + .expect("unknown def") + .ty + .clone() + } _ => unimplemented!(), }; PlaceType::with_ty_and_term(func_ty.vacuous(), chc::Term::null()) @@ -468,12 +471,14 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { let ret = rty::RefinedType::new(rty::Type::unit(), ret_formula.into()); rty::FunctionType::new([param1, param2].into_iter().collect(), ret).into() } - Some((def_id, args)) => self - .ctx - .def_ty_with_args(def_id, args) - .expect("unknown def") - .ty - .vacuous(), + Some((def_id, args)) => { + let rty_args = args.types().map(|ty| self.type_builder.build(ty)).collect(); + self.ctx + .def_ty_with_args(def_id, rty_args) + .expect("unknown def") + .ty + .vacuous() + } _ => self.operand_type(func.clone()).ty, }; let expected_args: IndexVec<_, _> = args diff --git a/src/analyze/crate_.rs b/src/analyze/crate_.rs index 389647b..e15ca49 100644 --- a/src/analyze/crate_.rs +++ b/src/analyze/crate_.rs @@ -5,11 +5,9 @@ use std::collections::HashSet; use rustc_hir::def::DefKind; use rustc_index::IndexVec; use rustc_middle::ty::{self as mir_ty, TyCtxt}; -use rustc_span::def_id::DefId; -use rustc_span::symbol::Ident; +use rustc_span::def_id::{DefId, LocalDefId}; use crate::analyze; -use crate::annot::{self, AnnotFormula, AnnotParser, ResolverExt as _}; use crate::chc; use crate::refine::{self, TypeBuilder}; use crate::rty::{self, ClauseBuilderExt as _}; @@ -34,188 +32,31 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { fn refine_local_defs(&mut self) { for local_def_id in self.tcx.mir_keys(()) { if self.tcx.def_kind(*local_def_id).is_fn_like() { - self.refine_fn_def(local_def_id.to_def_id()); + self.refine_fn_def(*local_def_id); } } } - fn extract_require_annot( - &self, - resolver: T, - def_id: DefId, - ) -> Option> - where - T: annot::Resolver, - { - let mut require_annot = None; - for attrs in self - .tcx - .get_attrs_by_path(def_id, &analyze::annot::requires_path()) - { - if require_annot.is_some() { - unimplemented!(); - } - let ts = analyze::annot::extract_annot_tokens(attrs.clone()); - let require = AnnotParser::new(&resolver).parse_formula(ts).unwrap(); - require_annot = Some(require); - } - require_annot - } + #[tracing::instrument(skip(self), fields(def_id = %self.tcx.def_path_str(local_def_id)))] + fn refine_fn_def(&mut self, local_def_id: LocalDefId) { + let mut analyzer = self.ctx.local_def_analyzer(local_def_id); - fn extract_ensure_annot(&self, resolver: T, def_id: DefId) -> Option> - where - T: annot::Resolver, - { - let mut ensure_annot = None; - for attrs in self - .tcx - .get_attrs_by_path(def_id, &analyze::annot::ensures_path()) - { - if ensure_annot.is_some() { - unimplemented!(); - } - let ts = analyze::annot::extract_annot_tokens(attrs.clone()); - let ensure = AnnotParser::new(&resolver).parse_formula(ts).unwrap(); - ensure_annot = Some(ensure); + if analyzer.is_annotated_as_trusted() { + assert!(analyzer.is_fully_annotated()); + self.trusted.insert(local_def_id.to_def_id()); } - ensure_annot - } - fn extract_param_annots( - &self, - resolver: T, - def_id: DefId, - ) -> Vec<(Ident, rty::RefinedType)> - where - T: annot::Resolver, - { - let mut param_annots = Vec::new(); - for attrs in self + let sig = self .tcx - .get_attrs_by_path(def_id, &analyze::annot::param_path()) - { - let ts = analyze::annot::extract_annot_tokens(attrs.clone()); - let (ident, ts) = analyze::annot::split_param(&ts); - let param = AnnotParser::new(&resolver).parse_rty(ts).unwrap(); - param_annots.push((ident, param)); - } - param_annots - } - - fn extract_ret_annot( - &self, - resolver: T, - def_id: DefId, - ) -> Option> - where - T: annot::Resolver, - { - let mut ret_annot = None; - for attrs in self - .tcx - .get_attrs_by_path(def_id, &analyze::annot::ret_path()) - { - if ret_annot.is_some() { - unimplemented!(); - } - let ts = analyze::annot::extract_annot_tokens(attrs.clone()); - let ret = AnnotParser::new(&resolver).parse_rty(ts).unwrap(); - ret_annot = Some(ret); - } - ret_annot - } - - #[tracing::instrument(skip(self), fields(def_id = %self.tcx.def_path_str(def_id)))] - fn refine_fn_def(&mut self, def_id: DefId) { - let sig = self.tcx.fn_sig(def_id); - let sig = sig.instantiate_identity().skip_binder(); - if let Some(rty) = self.fn_def_ty_with_sig(def_id, sig) { - self.ctx.register_def(def_id, rty); - } else { - self.ctx.register_deferred_def(def_id); - } - } - - pub fn fn_def_ty_with_sig( - &mut self, - def_id: DefId, - sig: mir_ty::FnSig<'tcx>, - ) -> Option { - let mut param_resolver = analyze::annot::ParamResolver::default(); - for (input_ident, input_ty) in self.tcx.fn_arg_names(def_id).iter().zip(sig.inputs()) { - let input_ty = TypeBuilder::new(self.tcx, def_id).build(*input_ty); - param_resolver.push_param(input_ident.name, input_ty.to_sort()); - } - - let mut require_annot = self.extract_require_annot(¶m_resolver, def_id); - let mut ensure_annot = { - let output_ty = TypeBuilder::new(self.tcx, def_id).build(sig.output()); - let resolver = annot::StackedResolver::default() - .resolver(analyze::annot::ResultResolver::new(output_ty.to_sort())) - .resolver((¶m_resolver).map(rty::RefinedTypeVar::Free)); - self.extract_ensure_annot(resolver, def_id) - }; - let param_annots = self.extract_param_annots(¶m_resolver, def_id); - let ret_annot = self.extract_ret_annot(¶m_resolver, def_id); - - if self - .tcx - .get_attrs_by_path(def_id, &analyze::annot::callable_path()) - .next() - .is_some() - { - if require_annot.is_some() || ensure_annot.is_some() { - unimplemented!(); - } - - require_annot = Some(AnnotFormula::top()); - ensure_annot = Some(AnnotFormula::top()); - } - - assert!(require_annot.is_none() || param_annots.is_empty()); - assert!(ensure_annot.is_none() || ret_annot.is_none()); - - if self - .tcx - .get_attrs_by_path(def_id, &analyze::annot::trusted_path()) - .next() - .is_some() - { - assert!(require_annot.is_some() || !param_annots.is_empty()); - assert!(ensure_annot.is_some() || ret_annot.is_some()); - self.trusted.insert(def_id); - } - - let mut builder = - TypeBuilder::new(self.tcx, def_id).for_function_template(&mut self.ctx, sig); - if let Some(AnnotFormula::Formula(require)) = require_annot { - let formula = require.map_var(|idx| { - if idx.index() == sig.inputs().len() - 1 { - rty::RefinedTypeVar::Value - } else { - rty::RefinedTypeVar::Free(idx) - } - }); - builder.param_refinement(formula.into()); - } - if let Some(AnnotFormula::Formula(ensure)) = ensure_annot { - builder.ret_refinement(ensure.into()); - } - for (ident, annot_rty) in param_annots { - use annot::Resolver as _; - let (idx, _) = param_resolver.resolve(ident).expect("unknown param"); - builder.param_rty(idx, annot_rty); - } - if let Some(ret_rty) = ret_annot { - builder.ret_rty(ret_rty); - } - - // can't generate template with type parameter... + .fn_sig(local_def_id) + .instantiate_identity() + .skip_binder(); use mir_ty::TypeVisitableExt as _; - if builder.would_contain_template() && sig.has_param() { - None + if sig.has_param() && !analyzer.is_fully_annotated() { + self.ctx.register_deferred_def(local_def_id.to_def_id()); } else { - Some(rty::RefinedType::unrefined(builder.build().into())) + let expected = analyzer.expected_ty(); + self.ctx.register_def(local_def_id.to_def_id(), expected); } } diff --git a/src/analyze/local_def.rs b/src/analyze/local_def.rs index d0c4aa7..58e4b8b 100644 --- a/src/analyze/local_def.rs +++ b/src/analyze/local_def.rs @@ -7,8 +7,10 @@ use rustc_index::IndexVec; use rustc_middle::mir::{self, BasicBlock, Body, Local}; use rustc_middle::ty::{self as mir_ty, TyCtxt, TypeAndMut}; use rustc_span::def_id::LocalDefId; +use rustc_span::symbol::Ident; use crate::analyze; +use crate::annot::{self, AnnotFormula, AnnotParser, ResolverExt as _}; use crate::chc; use crate::pretty::PrettyDisplayExt as _; use crate::refine::{BasicBlockType, TypeBuilder}; @@ -30,6 +32,213 @@ pub struct Analyzer<'tcx, 'ctx> { } impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { + fn extract_require_annot(&self, resolver: T) -> Option> + where + T: annot::Resolver, + { + let mut require_annot = None; + for attrs in self.tcx.get_attrs_by_path( + self.local_def_id.to_def_id(), + &analyze::annot::requires_path(), + ) { + if require_annot.is_some() { + unimplemented!(); + } + let ts = analyze::annot::extract_annot_tokens(attrs.clone()); + let require = AnnotParser::new(&resolver).parse_formula(ts).unwrap(); + require_annot = Some(require); + } + require_annot + } + + fn extract_ensure_annot(&self, resolver: T) -> Option> + where + T: annot::Resolver, + { + let mut ensure_annot = None; + for attrs in self.tcx.get_attrs_by_path( + self.local_def_id.to_def_id(), + &analyze::annot::ensures_path(), + ) { + if ensure_annot.is_some() { + unimplemented!(); + } + let ts = analyze::annot::extract_annot_tokens(attrs.clone()); + let ensure = AnnotParser::new(&resolver).parse_formula(ts).unwrap(); + ensure_annot = Some(ensure); + } + ensure_annot + } + + fn extract_param_annots(&self, resolver: T) -> Vec<(Ident, rty::RefinedType)> + where + T: annot::Resolver, + { + let mut param_annots = Vec::new(); + for attrs in self + .tcx + .get_attrs_by_path(self.local_def_id.to_def_id(), &analyze::annot::param_path()) + { + let ts = analyze::annot::extract_annot_tokens(attrs.clone()); + let (ident, ts) = analyze::annot::split_param(&ts); + let param = AnnotParser::new(&resolver).parse_rty(ts).unwrap(); + param_annots.push((ident, param)); + } + param_annots + } + + fn extract_ret_annot(&self, resolver: T) -> Option> + where + T: annot::Resolver, + { + let mut ret_annot = None; + for attrs in self + .tcx + .get_attrs_by_path(self.local_def_id.to_def_id(), &analyze::annot::ret_path()) + { + if ret_annot.is_some() { + unimplemented!(); + } + let ts = analyze::annot::extract_annot_tokens(attrs.clone()); + let ret = AnnotParser::new(&resolver).parse_rty(ts).unwrap(); + ret_annot = Some(ret); + } + ret_annot + } + + pub fn is_annotated_as_trusted(&self) -> bool { + self.tcx + .get_attrs_by_path( + self.local_def_id.to_def_id(), + &analyze::annot::trusted_path(), + ) + .next() + .is_some() + } + + pub fn is_annotated_as_callable(&self) -> bool { + self.tcx + .get_attrs_by_path( + self.local_def_id.to_def_id(), + &analyze::annot::callable_path(), + ) + .next() + .is_some() + } + + // TODO: unify this logic with extraction functions above + pub fn is_fully_annotated(&self) -> bool { + let has_require = self + .tcx + .get_attrs_by_path( + self.local_def_id.to_def_id(), + &analyze::annot::requires_path(), + ) + .next() + .is_some(); + let has_ensure = self + .tcx + .get_attrs_by_path( + self.local_def_id.to_def_id(), + &analyze::annot::ensures_path(), + ) + .next() + .is_some(); + let annotated_params: Vec<_> = self + .tcx + .get_attrs_by_path(self.local_def_id.to_def_id(), &analyze::annot::param_path()) + .map(|attrs| { + let ts = analyze::annot::extract_annot_tokens(attrs.clone()); + let (ident, _) = analyze::annot::split_param(&ts); + ident + }) + .collect(); + let has_ret = self + .tcx + .get_attrs_by_path(self.local_def_id.to_def_id(), &analyze::annot::ret_path()) + .next() + .is_some(); + + let arg_names = self.tcx.fn_arg_names(self.local_def_id.to_def_id()); + let all_params_annotated = arg_names + .iter() + .all(|ident| annotated_params.contains(ident)); + self.is_annotated_as_callable() + || (has_require && has_ensure) + || (all_params_annotated && has_ret) + } + + pub fn expected_ty(&mut self) -> rty::RefinedType { + let sig = self.tcx.fn_sig(self.local_def_id); + let sig = sig.instantiate_identity().skip_binder(); + + let mut param_resolver = analyze::annot::ParamResolver::default(); + for (input_ident, input_ty) in self + .tcx + .fn_arg_names(self.local_def_id.to_def_id()) + .iter() + .zip(sig.inputs()) + { + let input_ty = self.type_builder.build(*input_ty); + param_resolver.push_param(input_ident.name, input_ty.to_sort()); + } + + let mut require_annot = self.extract_require_annot(¶m_resolver); + let mut ensure_annot = { + let output_ty = self.type_builder.build(sig.output()); + let resolver = annot::StackedResolver::default() + .resolver(analyze::annot::ResultResolver::new(output_ty.to_sort())) + .resolver((¶m_resolver).map(rty::RefinedTypeVar::Free)); + self.extract_ensure_annot(resolver) + }; + let param_annots = self.extract_param_annots(¶m_resolver); + let ret_annot = self.extract_ret_annot(¶m_resolver); + + if self.is_annotated_as_callable() { + if require_annot.is_some() || ensure_annot.is_some() { + unimplemented!(); + } + if !param_annots.is_empty() || ret_annot.is_some() { + unimplemented!(); + } + + require_annot = Some(AnnotFormula::top()); + ensure_annot = Some(AnnotFormula::top()); + } + + assert!(require_annot.is_none() || param_annots.is_empty()); + assert!(ensure_annot.is_none() || ret_annot.is_none()); + + let mut builder = self.type_builder.for_function_template(&mut self.ctx, sig); + if let Some(AnnotFormula::Formula(require)) = require_annot { + let formula = require.map_var(|idx| { + if idx.index() == sig.inputs().len() - 1 { + rty::RefinedTypeVar::Value + } else { + rty::RefinedTypeVar::Free(idx) + } + }); + builder.param_refinement(formula.into()); + } + if let Some(AnnotFormula::Formula(ensure)) = ensure_annot { + builder.ret_refinement(ensure.into()); + } + for (ident, annot_rty) in param_annots { + use annot::Resolver as _; + let (idx, _) = param_resolver.resolve(ident).expect("unknown param"); + builder.param_rty(idx, annot_rty); + } + if let Some(ret_rty) = ret_annot { + builder.ret_rty(ret_rty); + } + + // Note that we do not expect predicate variables to be generated here + // when type params are still present in the type. Callers should ensure either + // - type params are fully instantiated, or + // - the function is fully annotated + rty::RefinedType::unrefined(builder.build().into()) + } + fn is_mut_param(&self, param_idx: rty::FunctionParamIdx) -> bool { let param_local = analyze::local_of_function_param(param_idx); self.body.local_decls[param_local].mutability.is_mut() diff --git a/tests/ui/fail/adt_poly_fn_poly.rs b/tests/ui/fail/adt_poly_fn_poly.rs new file mode 100644 index 0000000..899f875 --- /dev/null +++ b/tests/ui/fail/adt_poly_fn_poly.rs @@ -0,0 +1,39 @@ +//@error-in-other-file: Unsat + +pub enum X { + A(T), + B(T), +} + +#[thrust::trusted] +#[thrust::requires(true)] +#[thrust::ensures(true)] +fn rand() -> X { unimplemented!() } + +fn is_a(x: &X) -> bool { + match x { + X::A(_) => true, + X::B(_) => false, + } +} + +fn inv(x: X) -> X { + match x { + X::A(i) => X::B(i), + X::B(i) => X::A(i), + } +} + +fn rand_a() -> X { + let x = rand(); + if !is_a(&x) { loop {} } + x +} + +#[thrust::callable] +fn check() { + let x = rand_a::(); + assert!(is_a(&inv(x))); +} + +fn main() {} diff --git a/tests/ui/pass/adt_poly_fn_poly.rs b/tests/ui/pass/adt_poly_fn_poly.rs new file mode 100644 index 0000000..d3b91f4 --- /dev/null +++ b/tests/ui/pass/adt_poly_fn_poly.rs @@ -0,0 +1,39 @@ +//@check-pass + +pub enum X { + A(T), + B(T), +} + +#[thrust::trusted] +#[thrust::requires(true)] +#[thrust::ensures(true)] +fn rand() -> X { unimplemented!() } + +fn is_a(x: &X) -> bool { + match x { + X::A(_) => true, + X::B(_) => false, + } +} + +fn inv(x: X) -> X { + match x { + X::A(i) => X::B(i), + X::B(i) => X::A(i), + } +} + +fn rand_a() -> X { + let x = rand(); + if !is_a(&x) { loop {} } + x +} + +#[thrust::callable] +fn check() { + let x = rand_a::(); + assert!(!is_a(&inv(x))); +} + +fn main() {} From cd004d2b72e480d665c489414e81d115f74a19cd Mon Sep 17 00:00:00 2001 From: coord_e Date: Mon, 10 Nov 2025 00:03:57 +0900 Subject: [PATCH 26/75] Add many more test cases (thanks Claude) --- tests/ui/fail/fn_poly_annot_complex.rs | 12 ++++++++++++ tests/ui/fail/fn_poly_annot_multi_inst.rs | 15 +++++++++++++++ tests/ui/fail/fn_poly_annot_nested.rs | 17 +++++++++++++++++ tests/ui/fail/fn_poly_annot_recursive.rs | 17 +++++++++++++++++ tests/ui/fail/fn_poly_annot_ref.rs | 13 +++++++++++++ tests/ui/fail/fn_poly_annot_stronger.rs | 13 +++++++++++++ tests/ui/fail/fn_poly_double_nested.rs | 21 +++++++++++++++++++++ tests/ui/fail/fn_poly_multiple_calls.rs | 13 +++++++++++++ tests/ui/fail/fn_poly_mut_ref.rs | 16 ++++++++++++++++ tests/ui/fail/fn_poly_nested_calls.rs | 17 +++++++++++++++++ tests/ui/fail/fn_poly_param_order.rs | 18 ++++++++++++++++++ tests/ui/fail/fn_poly_recursive.rs | 22 ++++++++++++++++++++++ tests/ui/fail/fn_poly_ref.rs | 15 +++++++++++++++ tests/ui/fail/fn_poly_unused_param.rs | 15 +++++++++++++++ tests/ui/pass/fn_poly_annot_complex.rs | 14 ++++++++++++++ tests/ui/pass/fn_poly_annot_multi_inst.rs | 18 ++++++++++++++++++ tests/ui/pass/fn_poly_annot_nested.rs | 17 +++++++++++++++++ tests/ui/pass/fn_poly_annot_recursive.rs | 17 +++++++++++++++++ tests/ui/pass/fn_poly_annot_ref.rs | 13 +++++++++++++ tests/ui/pass/fn_poly_annot_stronger.rs | 13 +++++++++++++ tests/ui/pass/fn_poly_double_nested.rs | 22 ++++++++++++++++++++++ tests/ui/pass/fn_poly_multiple_calls.rs | 15 +++++++++++++++ tests/ui/pass/fn_poly_mut_ref.rs | 16 ++++++++++++++++ tests/ui/pass/fn_poly_nested_calls.rs | 18 ++++++++++++++++++ tests/ui/pass/fn_poly_param_order.rs | 20 ++++++++++++++++++++ tests/ui/pass/fn_poly_recursive.rs | 23 +++++++++++++++++++++++ tests/ui/pass/fn_poly_ref.rs | 15 +++++++++++++++ tests/ui/pass/fn_poly_unused_param.rs | 15 +++++++++++++++ 28 files changed, 460 insertions(+) create mode 100644 tests/ui/fail/fn_poly_annot_complex.rs create mode 100644 tests/ui/fail/fn_poly_annot_multi_inst.rs create mode 100644 tests/ui/fail/fn_poly_annot_nested.rs create mode 100644 tests/ui/fail/fn_poly_annot_recursive.rs create mode 100644 tests/ui/fail/fn_poly_annot_ref.rs create mode 100644 tests/ui/fail/fn_poly_annot_stronger.rs create mode 100644 tests/ui/fail/fn_poly_double_nested.rs create mode 100644 tests/ui/fail/fn_poly_multiple_calls.rs create mode 100644 tests/ui/fail/fn_poly_mut_ref.rs create mode 100644 tests/ui/fail/fn_poly_nested_calls.rs create mode 100644 tests/ui/fail/fn_poly_param_order.rs create mode 100644 tests/ui/fail/fn_poly_recursive.rs create mode 100644 tests/ui/fail/fn_poly_ref.rs create mode 100644 tests/ui/fail/fn_poly_unused_param.rs create mode 100644 tests/ui/pass/fn_poly_annot_complex.rs create mode 100644 tests/ui/pass/fn_poly_annot_multi_inst.rs create mode 100644 tests/ui/pass/fn_poly_annot_nested.rs create mode 100644 tests/ui/pass/fn_poly_annot_recursive.rs create mode 100644 tests/ui/pass/fn_poly_annot_ref.rs create mode 100644 tests/ui/pass/fn_poly_annot_stronger.rs create mode 100644 tests/ui/pass/fn_poly_double_nested.rs create mode 100644 tests/ui/pass/fn_poly_multiple_calls.rs create mode 100644 tests/ui/pass/fn_poly_mut_ref.rs create mode 100644 tests/ui/pass/fn_poly_nested_calls.rs create mode 100644 tests/ui/pass/fn_poly_param_order.rs create mode 100644 tests/ui/pass/fn_poly_recursive.rs create mode 100644 tests/ui/pass/fn_poly_ref.rs create mode 100644 tests/ui/pass/fn_poly_unused_param.rs diff --git a/tests/ui/fail/fn_poly_annot_complex.rs b/tests/ui/fail/fn_poly_annot_complex.rs new file mode 100644 index 0000000..e37f5b8 --- /dev/null +++ b/tests/ui/fail/fn_poly_annot_complex.rs @@ -0,0 +1,12 @@ +//@error-in-other-file: Unsat + +#[thrust::requires((x.0 > 0) && (x.1 > 0))] +#[thrust::ensures((result.0 == x.1) && (result.1 == x.0))] +fn swap_positive(x: (i32, i32, T, U)) -> (i32, i32, U, T) { + (x.1, x.0, x.3, x.2) +} + +fn main() { + let result = swap_positive((-5, 10, true, 42)); + assert!(result.0 == 10); +} diff --git a/tests/ui/fail/fn_poly_annot_multi_inst.rs b/tests/ui/fail/fn_poly_annot_multi_inst.rs new file mode 100644 index 0000000..cf8331f --- /dev/null +++ b/tests/ui/fail/fn_poly_annot_multi_inst.rs @@ -0,0 +1,15 @@ +//@error-in-other-file: Unsat + +#[thrust::requires(true)] +#[thrust::ensures(result == x)] +fn id(x: T) -> T { + x +} + +fn main() { + let a = id(42); + assert!(a == 42); + + let b = id(true); + assert!(b == false); +} diff --git a/tests/ui/fail/fn_poly_annot_nested.rs b/tests/ui/fail/fn_poly_annot_nested.rs new file mode 100644 index 0000000..e1e5015 --- /dev/null +++ b/tests/ui/fail/fn_poly_annot_nested.rs @@ -0,0 +1,17 @@ +//@error-in-other-file: Unsat + +#[thrust::requires(true)] +#[thrust::ensures(result == x)] +fn id(x: T) -> T { + x +} + +#[thrust::requires(true)] +#[thrust::ensures(result != x)] +fn apply_twice(x: T) -> T { + id(id(x)) +} + +fn main() { + assert!(apply_twice(42) == 42); +} diff --git a/tests/ui/fail/fn_poly_annot_recursive.rs b/tests/ui/fail/fn_poly_annot_recursive.rs new file mode 100644 index 0000000..aa40960 --- /dev/null +++ b/tests/ui/fail/fn_poly_annot_recursive.rs @@ -0,0 +1,17 @@ +//@error-in-other-file: Unsat +//@compile-flags: -C debug-assertions=off + +#[thrust::requires(n >= 0)] +#[thrust::ensures(result == value)] +fn repeat(n: i32, value: T) -> T { + if n == 0 { + value + } else { + repeat(n - 1, value) + } +} + +fn main() { + let result = repeat(-1, 42); + assert!(result == 42); +} diff --git a/tests/ui/fail/fn_poly_annot_ref.rs b/tests/ui/fail/fn_poly_annot_ref.rs new file mode 100644 index 0000000..10ba6b5 --- /dev/null +++ b/tests/ui/fail/fn_poly_annot_ref.rs @@ -0,0 +1,13 @@ +//@error-in-other-file: Unsat + +#[thrust::requires(true)] +#[thrust::ensures(result != x)] +fn id_ref(x: &T) -> &T { + x +} + +fn main() { + let val = 42; + let r = id_ref(&val); + assert!(*r == 42); +} diff --git a/tests/ui/fail/fn_poly_annot_stronger.rs b/tests/ui/fail/fn_poly_annot_stronger.rs new file mode 100644 index 0000000..90738a8 --- /dev/null +++ b/tests/ui/fail/fn_poly_annot_stronger.rs @@ -0,0 +1,13 @@ +//@error-in-other-file: Unsat +//@compile-flags: -C debug-assertions=off + +#[thrust::requires(x > 0)] +#[thrust::ensures((result == x) && (result > 0))] +fn pass_positive(x: i32, _dummy: T) -> i32 { + x +} + +fn main() { + let result = pass_positive(-5, true); + assert!(result == -5); +} diff --git a/tests/ui/fail/fn_poly_double_nested.rs b/tests/ui/fail/fn_poly_double_nested.rs new file mode 100644 index 0000000..15040c5 --- /dev/null +++ b/tests/ui/fail/fn_poly_double_nested.rs @@ -0,0 +1,21 @@ +//@error-in-other-file: Unsat + +fn id(x: T) -> T { + x +} + +fn apply_twice(x: T) -> T { + id(id(x)) +} + +fn apply_thrice(x: T) -> T { + apply_twice(id(x)) +} + +fn apply_four(x: T) -> T { + apply_twice(apply_twice(x)) +} + +fn main() { + assert!(apply_four(42) == 43); +} diff --git a/tests/ui/fail/fn_poly_multiple_calls.rs b/tests/ui/fail/fn_poly_multiple_calls.rs new file mode 100644 index 0000000..f510eab --- /dev/null +++ b/tests/ui/fail/fn_poly_multiple_calls.rs @@ -0,0 +1,13 @@ +//@error-in-other-file: Unsat + +fn first(pair: (T, U)) -> T { + pair.0 +} + +fn main() { + let x = first((42, true)); + let y = first((true, 100)); + + assert!(x == 42); + assert!(y == false); +} diff --git a/tests/ui/fail/fn_poly_mut_ref.rs b/tests/ui/fail/fn_poly_mut_ref.rs new file mode 100644 index 0000000..1ab6ff9 --- /dev/null +++ b/tests/ui/fail/fn_poly_mut_ref.rs @@ -0,0 +1,16 @@ +//@error-in-other-file: Unsat + +fn update(x: &mut T, new_val: T) { + *x = new_val; +} + +fn chain_update(x: &mut T, temp: T, final_val: T) { + update(x, temp); + update(x, final_val); +} + +fn main() { + let mut val = 42; + chain_update(&mut val, 100, 200); + assert!(val == 42); +} diff --git a/tests/ui/fail/fn_poly_nested_calls.rs b/tests/ui/fail/fn_poly_nested_calls.rs new file mode 100644 index 0000000..609f136 --- /dev/null +++ b/tests/ui/fail/fn_poly_nested_calls.rs @@ -0,0 +1,17 @@ +//@error-in-other-file: Unsat + +fn id(x: T) -> T { + x +} + +fn apply_twice(x: T) -> T { + id(id(x)) +} + +fn apply_thrice(x: T) -> T { + id(apply_twice(x)) +} + +fn main() { + assert!(apply_thrice(42) == 43); +} diff --git a/tests/ui/fail/fn_poly_param_order.rs b/tests/ui/fail/fn_poly_param_order.rs new file mode 100644 index 0000000..93af596 --- /dev/null +++ b/tests/ui/fail/fn_poly_param_order.rs @@ -0,0 +1,18 @@ +//@error-in-other-file: Unsat + +fn select(a: T, b: U, c: V, which: i32) -> T { + if which == 0 { + a + } else { + a + } +} + +fn rotate(triple: (A, B, C)) -> (B, C, A) { + (triple.1, triple.2, triple.0) +} + +fn main() { + let x = rotate((1, true, 42)); + assert!(x.0 == false); +} diff --git a/tests/ui/fail/fn_poly_recursive.rs b/tests/ui/fail/fn_poly_recursive.rs new file mode 100644 index 0000000..8b2aaba --- /dev/null +++ b/tests/ui/fail/fn_poly_recursive.rs @@ -0,0 +1,22 @@ +//@error-in-other-file: Unsat +//@compile-flags: -C debug-assertions=off + +fn repeat(n: i32, value: T) -> T { + if n <= 1 { + value + } else { + repeat(n - 1, repeat(1, value)) + } +} + +fn identity_loop(depth: i32, x: T) -> T { + if depth == 0 { + x + } else { + identity_loop(depth - 1, identity_loop(0, x)) + } +} + +fn main() { + assert!(repeat(5, 42) == 43); +} diff --git a/tests/ui/fail/fn_poly_ref.rs b/tests/ui/fail/fn_poly_ref.rs new file mode 100644 index 0000000..671f8b0 --- /dev/null +++ b/tests/ui/fail/fn_poly_ref.rs @@ -0,0 +1,15 @@ +//@error-in-other-file: Unsat + +fn identity_ref(x: &T) -> &T { + x +} + +fn chain_ref(x: &T) -> &T { + identity_ref(identity_ref(x)) +} + +fn main() { + let val = 42; + let r = chain_ref(&val); + assert!(*r == 43); +} diff --git a/tests/ui/fail/fn_poly_unused_param.rs b/tests/ui/fail/fn_poly_unused_param.rs new file mode 100644 index 0000000..bde86f1 --- /dev/null +++ b/tests/ui/fail/fn_poly_unused_param.rs @@ -0,0 +1,15 @@ +//@error-in-other-file: Unsat + +fn project_first(triple: (T, U, V)) -> T { + triple.0 +} + +fn chain(x: A, _phantom_b: B, _phantom_c: C) -> A { + x +} + +fn main() { + let x = project_first((42, true, 100)); + let y = chain(x, (1, 2), false); + assert!(y == 43); +} diff --git a/tests/ui/pass/fn_poly_annot_complex.rs b/tests/ui/pass/fn_poly_annot_complex.rs new file mode 100644 index 0000000..79d6251 --- /dev/null +++ b/tests/ui/pass/fn_poly_annot_complex.rs @@ -0,0 +1,14 @@ +//@check-pass +//@compile-flags: -C debug-assertions=off + +#[thrust::requires((n > 0) && (m > 0))] +#[thrust::ensures((result.0 == m) && (result.1 == n))] +fn swap_pair(n: i32, m: i32, _phantom: T) -> (i32, i32) { + (m, n) +} + +fn main() { + let result = swap_pair(5, 10, true); + assert!(result.0 == 10); + assert!(result.1 == 5); +} diff --git a/tests/ui/pass/fn_poly_annot_multi_inst.rs b/tests/ui/pass/fn_poly_annot_multi_inst.rs new file mode 100644 index 0000000..372cd66 --- /dev/null +++ b/tests/ui/pass/fn_poly_annot_multi_inst.rs @@ -0,0 +1,18 @@ +//@check-pass + +#[thrust::requires(true)] +#[thrust::ensures(result == x)] +fn id(x: T) -> T { + x +} + +fn main() { + let a = id(42); + assert!(a == 42); + + let b = id(true); + assert!(b == true); + + let c = id((1, 2)); + assert!(c.0 == 1); +} diff --git a/tests/ui/pass/fn_poly_annot_nested.rs b/tests/ui/pass/fn_poly_annot_nested.rs new file mode 100644 index 0000000..927a393 --- /dev/null +++ b/tests/ui/pass/fn_poly_annot_nested.rs @@ -0,0 +1,17 @@ +//@check-pass + +#[thrust::requires(true)] +#[thrust::ensures(result == x)] +fn id(x: T) -> T { + x +} + +#[thrust::requires(true)] +#[thrust::ensures(result == x)] +fn apply_twice(x: T) -> T { + id(id(x)) +} + +fn main() { + assert!(apply_twice(42) == 42); +} diff --git a/tests/ui/pass/fn_poly_annot_recursive.rs b/tests/ui/pass/fn_poly_annot_recursive.rs new file mode 100644 index 0000000..f7dc255 --- /dev/null +++ b/tests/ui/pass/fn_poly_annot_recursive.rs @@ -0,0 +1,17 @@ +//@check-pass +//@compile-flags: -C debug-assertions=off + +#[thrust::requires(n >= 0)] +#[thrust::ensures(result == value)] +fn repeat(n: i32, value: T) -> T { + if n == 0 { + value + } else { + repeat(n - 1, value) + } +} + +fn main() { + let result = repeat(5, 42); + assert!(result == 42); +} diff --git a/tests/ui/pass/fn_poly_annot_ref.rs b/tests/ui/pass/fn_poly_annot_ref.rs new file mode 100644 index 0000000..adb6a14 --- /dev/null +++ b/tests/ui/pass/fn_poly_annot_ref.rs @@ -0,0 +1,13 @@ +//@check-pass + +#[thrust::requires(true)] +#[thrust::ensures(result == x)] +fn id_ref(x: &T) -> &T { + x +} + +fn main() { + let val = 42; + let r = id_ref(&val); + assert!(*r == 42); +} diff --git a/tests/ui/pass/fn_poly_annot_stronger.rs b/tests/ui/pass/fn_poly_annot_stronger.rs new file mode 100644 index 0000000..c3ea370 --- /dev/null +++ b/tests/ui/pass/fn_poly_annot_stronger.rs @@ -0,0 +1,13 @@ +//@check-pass +//@compile-flags: -C debug-assertions=off + +#[thrust::requires(x > 0)] +#[thrust::ensures((result == x) && (result > 0))] +fn pass_positive(x: i32, _dummy: T) -> i32 { + x +} + +fn main() { + let result = pass_positive(42, true); + assert!(result == 42); +} diff --git a/tests/ui/pass/fn_poly_double_nested.rs b/tests/ui/pass/fn_poly_double_nested.rs new file mode 100644 index 0000000..5c9b054 --- /dev/null +++ b/tests/ui/pass/fn_poly_double_nested.rs @@ -0,0 +1,22 @@ +//@check-pass + +fn id(x: T) -> T { + x +} + +fn apply_twice(x: T) -> T { + id(id(x)) +} + +fn apply_thrice(x: T) -> T { + apply_twice(id(x)) +} + +fn apply_four(x: T) -> T { + apply_twice(apply_twice(x)) +} + +fn main() { + assert!(apply_four(42) == 42); + assert!(apply_thrice(true) == true); +} diff --git a/tests/ui/pass/fn_poly_multiple_calls.rs b/tests/ui/pass/fn_poly_multiple_calls.rs new file mode 100644 index 0000000..aa83ba1 --- /dev/null +++ b/tests/ui/pass/fn_poly_multiple_calls.rs @@ -0,0 +1,15 @@ +//@check-pass + +fn first(pair: (T, U)) -> T { + pair.0 +} + +fn main() { + let x = first((42, true)); + let y = first((true, 100)); + let z = first(((1, 2), 3)); + + assert!(x == 42); + assert!(y == true); + assert!(z.0 == 1); +} diff --git a/tests/ui/pass/fn_poly_mut_ref.rs b/tests/ui/pass/fn_poly_mut_ref.rs new file mode 100644 index 0000000..4298a64 --- /dev/null +++ b/tests/ui/pass/fn_poly_mut_ref.rs @@ -0,0 +1,16 @@ +//@check-pass + +fn update(x: &mut T, new_val: T) { + *x = new_val; +} + +fn chain_update(x: &mut T, temp: T, final_val: T) { + update(x, temp); + update(x, final_val); +} + +fn main() { + let mut val = 42; + chain_update(&mut val, 100, 200); + assert!(val == 200); +} diff --git a/tests/ui/pass/fn_poly_nested_calls.rs b/tests/ui/pass/fn_poly_nested_calls.rs new file mode 100644 index 0000000..b2bef0b --- /dev/null +++ b/tests/ui/pass/fn_poly_nested_calls.rs @@ -0,0 +1,18 @@ +//@check-pass + +fn id(x: T) -> T { + x +} + +fn apply_twice(x: T) -> T { + id(id(x)) +} + +fn apply_thrice(x: T) -> T { + id(apply_twice(x)) +} + +fn main() { + assert!(apply_thrice(42) == 42); + assert!(apply_twice(true) == true); +} diff --git a/tests/ui/pass/fn_poly_param_order.rs b/tests/ui/pass/fn_poly_param_order.rs new file mode 100644 index 0000000..41191de --- /dev/null +++ b/tests/ui/pass/fn_poly_param_order.rs @@ -0,0 +1,20 @@ +//@check-pass + +fn select(a: T, b: U, c: V, which: i32) -> T { + if which == 0 { + a + } else { + a + } +} + +fn rotate(triple: (A, B, C)) -> (B, C, A) { + (triple.1, triple.2, triple.0) +} + +fn main() { + let x = rotate((1, true, 42)); + assert!(x.0 == true); + assert!(x.1 == 42); + assert!(x.2 == 1); +} diff --git a/tests/ui/pass/fn_poly_recursive.rs b/tests/ui/pass/fn_poly_recursive.rs new file mode 100644 index 0000000..cfc7e2e --- /dev/null +++ b/tests/ui/pass/fn_poly_recursive.rs @@ -0,0 +1,23 @@ +//@check-pass +//@compile-flags: -C debug-assertions=off + +fn repeat(n: i32, value: T) -> T { + if n <= 1 { + value + } else { + repeat(n - 1, repeat(1, value)) + } +} + +fn identity_loop(depth: i32, x: T) -> T { + if depth == 0 { + x + } else { + identity_loop(depth - 1, identity_loop(0, x)) + } +} + +fn main() { + assert!(repeat(5, 42) == 42); + assert!(identity_loop(3, true) == true); +} diff --git a/tests/ui/pass/fn_poly_ref.rs b/tests/ui/pass/fn_poly_ref.rs new file mode 100644 index 0000000..fae27ae --- /dev/null +++ b/tests/ui/pass/fn_poly_ref.rs @@ -0,0 +1,15 @@ +//@check-pass + +fn identity_ref(x: &T) -> &T { + x +} + +fn chain_ref(x: &T) -> &T { + identity_ref(identity_ref(x)) +} + +fn main() { + let val = 42; + let r = chain_ref(&val); + assert!(*r == 42); +} diff --git a/tests/ui/pass/fn_poly_unused_param.rs b/tests/ui/pass/fn_poly_unused_param.rs new file mode 100644 index 0000000..e2ec90e --- /dev/null +++ b/tests/ui/pass/fn_poly_unused_param.rs @@ -0,0 +1,15 @@ +//@check-pass + +fn project_first(triple: (T, U, V)) -> T { + triple.0 +} + +fn chain(x: A, _phantom_b: B, _phantom_c: C) -> A { + x +} + +fn main() { + let x = project_first((42, true, 100)); + let y = chain(x, (1, 2), false); + assert!(y == 42); +} From 2d4e198b8c160d68a5cc77268ad99685619f1382 Mon Sep 17 00:00:00 2001 From: coord_e Date: Tue, 9 Sep 2025 23:35:56 +0900 Subject: [PATCH 27/75] Include existential quantification in Formula --- src/chc.rs | 46 +++++++++++++++++++++++++++++++++++---- src/chc/format_context.rs | 1 + src/chc/smtlib2.rs | 9 ++++++++ src/chc/unbox.rs | 7 ++++++ 4 files changed, 59 insertions(+), 4 deletions(-) diff --git a/src/chc.rs b/src/chc.rs index 6f5db32..c2bd695 100644 --- a/src/chc.rs +++ b/src/chc.rs @@ -406,6 +406,8 @@ pub enum Term { TupleProj(Box>, usize), DatatypeCtor(DatatypeSort, DatatypeSymbol, Vec>), DatatypeDiscr(DatatypeSymbol, Box>), + /// Used in [`Formula`] to represent existentially quantified variables appearing in annotations. + FormulaExistentialVar(Sort, String), } impl<'a, 'b, D, V> Pretty<'a, D, termcolor::ColorSpec> for &'b Term @@ -475,6 +477,7 @@ where Term::DatatypeDiscr(_, t) => allocator .text("discriminant") .append(t.pretty(allocator).parens()), + Term::FormulaExistentialVar(_, name) => allocator.text(name.clone()), } } } @@ -521,6 +524,7 @@ impl Term { args.into_iter().map(|t| t.subst_var(&mut f)).collect(), ), Term::DatatypeDiscr(d_sym, t) => Term::DatatypeDiscr(d_sym, Box::new(t.subst_var(f))), + Term::FormulaExistentialVar(sort, name) => Term::FormulaExistentialVar(sort, name), } } @@ -562,15 +566,18 @@ impl Term { Term::TupleProj(t, i) => t.sort(var_sort).tuple_elem(*i), Term::DatatypeCtor(sort, _, _) => sort.clone().into(), Term::DatatypeDiscr(_, _) => Sort::int(), + Term::FormulaExistentialVar(sort, _) => sort.clone(), } } fn fv_impl(&self) -> Box + '_> { match self { Term::Var(v) => Box::new(std::iter::once(v)), - Term::Null | Term::Bool(_) | Term::Int(_) | Term::String(_) => { - Box::new(std::iter::empty()) - } + Term::Null + | Term::Bool(_) + | Term::Int(_) + | Term::String(_) + | Term::FormulaExistentialVar { .. } => Box::new(std::iter::empty()), Term::Box(t) => t.fv_impl(), Term::Mut(t1, t2) => Box::new(t1.fv_impl().chain(t2.fv_impl())), Term::BoxCurrent(t) => t.fv_impl(), @@ -1083,6 +1090,7 @@ pub enum Formula { Not(Box>), And(Vec>), Or(Vec>), + Exists(Vec<(String, Sort)>, Box>), } impl Default for Formula { @@ -1124,6 +1132,25 @@ where ); inner.group() } + Formula::Exists(vars, fo) => { + let vars = allocator.intersperse( + vars.iter().map(|(name, sort)| { + allocator + .text(name.clone()) + .append(allocator.text(":")) + .append(allocator.text(" ")) + .append(sort.pretty(allocator)) + }), + allocator.text(", ").append(allocator.line()), + ); + allocator + .text("∃") + .append(vars) + .append(allocator.text(".")) + .append(allocator.line()) + .append(fo.pretty(allocator).nest(2)) + .group() + } } } } @@ -1139,7 +1166,9 @@ impl Formula { D::Doc: Clone, { match self { - Formula::And(_) | Formula::Or(_) => self.pretty(allocator).parens(), + Formula::And(_) | Formula::Or(_) | Formula::Exists { .. } => { + self.pretty(allocator).parens() + } _ => self.pretty(allocator), } } @@ -1158,6 +1187,7 @@ impl Formula { Formula::Not(fo) => fo.is_bottom(), Formula::And(fs) => fs.iter().all(Formula::is_top), Formula::Or(fs) => fs.iter().any(Formula::is_top), + Formula::Exists(_, fo) => fo.is_top(), } } @@ -1167,6 +1197,7 @@ impl Formula { Formula::Not(fo) => fo.is_top(), Formula::And(fs) => fs.iter().any(Formula::is_bottom), Formula::Or(fs) => fs.iter().all(Formula::is_bottom), + Formula::Exists(_, fo) => fo.is_bottom(), } } @@ -1210,6 +1241,7 @@ impl Formula { Formula::And(fs.into_iter().map(|fo| fo.subst_var(&mut f)).collect()) } Formula::Or(fs) => Formula::Or(fs.into_iter().map(|fo| fo.subst_var(&mut f)).collect()), + Formula::Exists(vars, fo) => Formula::Exists(vars, Box::new(fo.subst_var(f))), } } @@ -1224,6 +1256,7 @@ impl Formula { Formula::Not(fo) => Formula::Not(Box::new(fo.map_var(&mut f))), Formula::And(fs) => Formula::And(fs.into_iter().map(|fo| fo.map_var(&mut f)).collect()), Formula::Or(fs) => Formula::Or(fs.into_iter().map(|fo| fo.map_var(&mut f)).collect()), + Formula::Exists(vars, fo) => Formula::Exists(vars, Box::new(fo.map_var(f))), } } @@ -1237,6 +1270,7 @@ impl Formula { Formula::Not(fo) => Box::new(fo.fv()), Formula::And(fs) => Box::new(fs.iter().flat_map(Formula::fv)), Formula::Or(fs) => Box::new(fs.iter().flat_map(Formula::fv)), + Formula::Exists(_, fo) => Box::new(fo.fv()), } } @@ -1250,6 +1284,7 @@ impl Formula { Formula::Not(fo) => Box::new(fo.iter_atoms()), Formula::And(fs) => Box::new(fs.iter().flat_map(Formula::iter_atoms)), Formula::Or(fs) => Box::new(fs.iter().flat_map(Formula::iter_atoms)), + Formula::Exists(_, fo) => Box::new(fo.iter_atoms()), } } @@ -1291,6 +1326,9 @@ impl Formula { *self = fs.pop().unwrap(); } } + Formula::Exists(_, fo) => { + fo.simplify(); + } } } } diff --git a/src/chc/format_context.rs b/src/chc/format_context.rs index ed82811..2123315 100644 --- a/src/chc/format_context.rs +++ b/src/chc/format_context.rs @@ -57,6 +57,7 @@ fn term_sorts(clause: &chc::Clause, t: &chc::Term, sorts: &mut BTreeSet term_sorts(clause, t, sorts), + chc::Term::FormulaExistentialVar(_, _) => {} } } diff --git a/src/chc/smtlib2.rs b/src/chc/smtlib2.rs index c708770..3cef75e 100644 --- a/src/chc/smtlib2.rs +++ b/src/chc/smtlib2.rs @@ -202,6 +202,7 @@ impl<'ctx, 'a> std::fmt::Display for Term<'ctx, 'a> { Term::new(self.ctx, self.clause, t) ) } + chc::Term::FormulaExistentialVar(_, name) => write!(f, "{}", name), } } } @@ -280,6 +281,14 @@ impl<'ctx, 'a> std::fmt::Display for Formula<'ctx, 'a> { let fs = List::open(fs.iter().map(|fo| Formula::new(self.ctx, self.clause, fo))); write!(f, "(or {})", fs) } + chc::Formula::Exists(vars, fo) => { + let vars = + List::closed(vars.iter().map(|(v, s)| { + List::closed([v.to_string(), self.ctx.fmt_sort(s).to_string()]) + })); + let fo = Formula::new(self.ctx, self.clause, fo); + write!(f, "(exists {vars} {fo})") + } } } } diff --git a/src/chc/unbox.rs b/src/chc/unbox.rs index 532f623..ffc4600 100644 --- a/src/chc/unbox.rs +++ b/src/chc/unbox.rs @@ -17,6 +17,9 @@ fn unbox_term(term: Term) -> Term { Term::DatatypeCtor(s1, s2, args.into_iter().map(unbox_term).collect()) } Term::DatatypeDiscr(sym, arg) => Term::DatatypeDiscr(sym, Box::new(unbox_term(*arg))), + Term::FormulaExistentialVar(sort, name) => { + Term::FormulaExistentialVar(unbox_sort(sort), name) + } } } @@ -52,6 +55,10 @@ fn unbox_formula(formula: Formula) -> Formula { Formula::Not(fo) => Formula::Not(Box::new(unbox_formula(*fo))), Formula::And(fs) => Formula::And(fs.into_iter().map(unbox_formula).collect()), Formula::Or(fs) => Formula::Or(fs.into_iter().map(unbox_formula).collect()), + Formula::Exists(vars, fo) => { + let vars = vars.into_iter().map(|(v, s)| (v, unbox_sort(s))).collect(); + Formula::Exists(vars, Box::new(unbox_formula(*fo))) + } } } From 4060e06e54b560d19d8e692e7d64905bec8cae63 Mon Sep 17 00:00:00 2001 From: coord_e Date: Sun, 14 Sep 2025 23:34:20 +0900 Subject: [PATCH 28/75] Parse existentials --- src/annot.rs | 119 +++++++++++++++++++++++++++++++++++++++++++++++++-- src/chc.rs | 4 ++ 2 files changed, 119 insertions(+), 4 deletions(-) diff --git a/src/annot.rs b/src/annot.rs index f559bb8..730b75e 100644 --- a/src/annot.rs +++ b/src/annot.rs @@ -12,6 +12,7 @@ use rustc_ast::token::{BinOpToken, Delimiter, LitKind, Token, TokenKind}; use rustc_ast::tokenstream::{RefTokenTreeCursor, Spacing, TokenStream, TokenTree}; use rustc_index::IndexVec; use rustc_span::symbol::Ident; +use std::collections::HashMap; use crate::chc; use crate::pretty::PrettyDisplayExt as _; @@ -185,6 +186,7 @@ impl FormulaOrTerm { struct Parser<'a, T> { resolver: T, cursor: RefTokenTreeCursor<'a>, + formula_existentials: HashMap, } impl<'a, T> Parser<'a, T> @@ -232,6 +234,15 @@ where } } + fn expect_next_ident(&mut self) -> Result { + let t = self.next_token("ident")?; + if let Some((ident, _)) = t.ident() { + Ok(ident) + } else { + Err(ParseAttrError::unexpected_token("ident", t.clone())) + } + } + fn consume(&mut self) { self.cursor.next().unwrap(); } @@ -296,6 +307,7 @@ where let mut parser = Parser { resolver: self.boxed_resolver(), cursor: s.trees(), + formula_existentials: self.formula_existentials.clone(), }; let formula_or_term = parser.parse_formula_or_term_or_tuple()?; parser.end_of_input()?; @@ -305,9 +317,17 @@ where }; let formula_or_term = if let Some((ident, _)) = t.ident() { - match ident.as_str() { - "true" => FormulaOrTerm::Formula(chc::Formula::top()), - "false" => FormulaOrTerm::Formula(chc::Formula::bottom()), + match ( + ident.as_str(), + self.formula_existentials.get(ident.name.as_str()), + ) { + ("true", _) => FormulaOrTerm::Formula(chc::Formula::top()), + ("false", _) => FormulaOrTerm::Formula(chc::Formula::bottom()), + (_, Some(sort)) => { + let var = + chc::Term::FormulaExistentialVar(sort.clone(), ident.name.to_string()); + FormulaOrTerm::Term(var, sort.clone()) + } _ => { let (v, sort) = self.resolve(ident)?; FormulaOrTerm::Term(chc::Term::var(v), sort) @@ -575,8 +595,94 @@ where Ok(formula_or_term) } + fn parse_exists(&mut self) -> Result> { + match self.look_ahead_token(0) { + Some(Token { + kind: TokenKind::Ident(sym, _), + .. + }) if sym.as_str() == "exists" => { + self.consume(); + let mut vars = Vec::new(); + loop { + let ident = self.expect_next_ident()?; + self.expect_next_token(TokenKind::Colon, ":")?; + let sort = self.parse_sort()?; + vars.push((ident.name.to_string(), sort)); + match self.next_token(". or ,")? { + Token { + kind: TokenKind::Comma, + .. + } => {} + Token { + kind: TokenKind::Dot, + .. + } => break, + t => return Err(ParseAttrError::unexpected_token(". or ,", t.clone())), + } + } + self.formula_existentials.extend(vars.iter().cloned()); + let formula = self + .parse_formula_or_term()? + .into_formula() + .ok_or_else(|| ParseAttrError::unexpected_term("in exists formula"))?; + for (name, _) in &vars { + self.formula_existentials.remove(name); + } + Ok(FormulaOrTerm::Formula(chc::Formula::exists(vars, formula))) + } + _ => self.parse_binop_5(), + } + } + fn parse_formula_or_term(&mut self) -> Result> { - self.parse_binop_5() + self.parse_exists() + } + + fn parse_sort(&mut self) -> Result { + let tt = self.next_token_tree("sort")?.clone(); + match tt { + TokenTree::Token( + Token { + kind: TokenKind::Ident(sym, _), + .. + }, + _, + ) => { + let sort = match sym.as_str() { + "bool" => chc::Sort::bool(), + "int" => chc::Sort::int(), + "string" => unimplemented!(), + "null" => chc::Sort::null(), + "fn" => unimplemented!(), + _ => unimplemented!(), + }; + Ok(sort) + } + TokenTree::Delimited(_, _, Delimiter::Parenthesis, ts) => { + let mut parser = Parser { + resolver: self.boxed_resolver(), + cursor: ts.trees(), + formula_existentials: self.formula_existentials.clone(), + }; + let mut sorts = Vec::new(); + loop { + sorts.push(parser.parse_sort()?); + match parser.look_ahead_token(0) { + Some(Token { + kind: TokenKind::Comma, + .. + }) => { + parser.consume(); + } + None => break, + Some(t) => return Err(ParseAttrError::unexpected_token(",", t.clone())), + } + } + parser.end_of_input()?; + Ok(chc::Sort::tuple(sorts)) + } + t => Err(ParseAttrError::unexpected_token_tree("sort", t.clone())), + } } fn parse_ty(&mut self) -> Result> { @@ -662,6 +768,7 @@ where let mut parser = Parser { resolver: self.boxed_resolver(), cursor: ts.trees(), + formula_existentials: self.formula_existentials.clone(), }; let mut rtys = Vec::new(); loop { @@ -697,6 +804,7 @@ where let mut parser = Parser { resolver: self.boxed_resolver(), cursor: ts.trees(), + formula_existentials: self.formula_existentials.clone(), }; let self_ident = if matches!( parser.look_ahead_token(1), @@ -720,6 +828,7 @@ where let mut parser = Parser { resolver: RefinementResolver::new(self.boxed_resolver()), cursor: parser.cursor, + formula_existentials: self.formula_existentials.clone(), }; if let Some(self_ident) = self_ident { parser.resolver.set_self(self_ident, ty.to_sort()); @@ -859,6 +968,7 @@ where let mut parser = Parser { resolver: &self.resolver, cursor: ts.trees(), + formula_existentials: Default::default(), }; let rty = parser.parse_rty()?; parser.end_of_input()?; @@ -869,6 +979,7 @@ where let mut parser = Parser { resolver: &self.resolver, cursor: ts.trees(), + formula_existentials: Default::default(), }; let formula = parser.parse_annot_formula()?; parser.end_of_input()?; diff --git a/src/chc.rs b/src/chc.rs index c2bd695..fbd301b 100644 --- a/src/chc.rs +++ b/src/chc.rs @@ -1228,6 +1228,10 @@ impl Formula { } } + pub fn exists(vars: Vec<(String, Sort)>, body: Self) -> Self { + Formula::Exists(vars, Box::new(body)) + } + pub fn subst_var(self, f: F) -> Formula where F: FnMut(V) -> Term, From 7270ca421b0de5d555b100edd17539739a589691 Mon Sep 17 00:00:00 2001 From: coord_e Date: Thu, 11 Dec 2025 01:07:29 +0900 Subject: [PATCH 29/75] Test parsing existentials --- .github/workflows/ci.yml | 24 ++++++++++++++++++++++++ tests/ui/fail/annot_exists.rs | 16 ++++++++++++++++ tests/ui/pass/annot_exists.rs | 16 ++++++++++++++++ 3 files changed, 56 insertions(+) create mode 100644 tests/ui/fail/annot_exists.rs create mode 100644 tests/ui/pass/annot_exists.rs diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7740817..a3c8103 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -24,9 +24,33 @@ jobs: runs-on: ubuntu-latest permissions: contents: read + env: + COAR_IMAGE: ghcr.io/hiroshi-unno/coar@sha256:73144ed27a02b163d1a71b41b58f3b5414f12e91326015600cfdca64ff19f011 steps: - uses: actions/checkout@v4 - uses: ./.github/actions/setup-z3 + - name: Setup thrust-pcsat-wrapper + run: | + docker pull "$COAR_IMAGE" + + cat <<"EOF" > thrust-pcsat-wrapper + #!/bin/bash + + smt2=$(mktemp -p . --suffix .smt2) + trap "rm -f $smt2" EXIT + cp "$1" "$smt2" + out=$( + docker run --rm -v "$PWD:/mnt" -w /root/coar "$COAR_IMAGE" \ + main.exe -c ./config/solver/pcsat_tbq_ar.json -p pcsp "/mnt/$smt2" + ) + exit_code=$? + echo "${out%,*}" + exit "$exit_code" + EOF + chmod +x thrust-pcsat-wrapper + + mkdir -p ~/.local/bin + mv thrust-pcsat-wrapper ~/.local/bin/thrust-pcsat-wrapper - run: rustup show - uses: Swatinem/rust-cache@v2 - run: cargo test diff --git a/tests/ui/fail/annot_exists.rs b/tests/ui/fail/annot_exists.rs new file mode 100644 index 0000000..888b2ec --- /dev/null +++ b/tests/ui/fail/annot_exists.rs @@ -0,0 +1,16 @@ +//@error-in-other-file: Unsat +//@compile-flags: -C debug-assertions=off +//@rustc-env: THRUST_SOLVER=thrust-pcsat-wrapper + +#[thrust::trusted] +#[thrust::callable] +fn rand() -> i32 { unimplemented!() } + +#[thrust::requires(true)] +#[thrust::ensures(exists x:int. result == 2 * x)] +fn f() -> i32 { + let x = rand(); + x + x + x +} + +fn main() {} diff --git a/tests/ui/pass/annot_exists.rs b/tests/ui/pass/annot_exists.rs new file mode 100644 index 0000000..95151e9 --- /dev/null +++ b/tests/ui/pass/annot_exists.rs @@ -0,0 +1,16 @@ +//@check-pass +//@compile-flags: -C debug-assertions=off +//@rustc-env: THRUST_SOLVER=thrust-pcsat-wrapper + +#[thrust::trusted] +#[thrust::callable] +fn rand() -> i32 { unimplemented!() } + +#[thrust::requires(true)] +#[thrust::ensures(exists x:int. result == 2 * x)] +fn f() -> i32 { + let x = rand(); + x + x +} + +fn main() {} From e4db83815b769fd3b51e815cfd454975da482046 Mon Sep 17 00:00:00 2001 From: coord_e Date: Sun, 23 Nov 2025 21:43:28 +0900 Subject: [PATCH 30/75] Enable to handle some promoted constants --- src/analyze/basic_block.rs | 102 ++++++++++++++++++++++++++++++++++++- src/refine/env.rs | 53 ++++--------------- 2 files changed, 110 insertions(+), 45 deletions(-) diff --git a/src/analyze/basic_block.rs b/src/analyze/basic_block.rs index 9f6ee49..b8982c5 100644 --- a/src/analyze/basic_block.rs +++ b/src/analyze/basic_block.rs @@ -125,11 +125,111 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { clauses } + fn const_bytes_ty( + &self, + ty: mir_ty::Ty<'tcx>, + alloc: mir::interpret::ConstAllocation, + range: std::ops::Range, + ) -> PlaceType { + let alloc = alloc.inner(); + let bytes = alloc.inspect_with_uninit_and_ptr_outside_interpreter(range); + match ty.kind() { + mir_ty::TyKind::Str => { + let content = std::str::from_utf8(bytes).unwrap(); + PlaceType::with_ty_and_term( + rty::Type::string(), + chc::Term::string(content.to_owned()), + ) + } + mir_ty::TyKind::Bool => { + PlaceType::with_ty_and_term(rty::Type::bool(), chc::Term::bool(bytes[0] != 0)) + } + mir_ty::TyKind::Int(_) => { + // TODO: see target endianness + let val = match bytes.len() { + 1 => i8::from_ne_bytes(bytes.try_into().unwrap()) as i64, + 2 => i16::from_ne_bytes(bytes.try_into().unwrap()) as i64, + 4 => i32::from_ne_bytes(bytes.try_into().unwrap()) as i64, + 8 => i64::from_ne_bytes(bytes.try_into().unwrap()), + _ => unimplemented!("const int bytes len: {}", bytes.len()), + }; + PlaceType::with_ty_and_term(rty::Type::int(), chc::Term::int(val)) + } + _ => unimplemented!("const bytes ty: {:?}", ty), + } + } + + fn const_value_ty(&self, val: &mir::ConstValue<'tcx>, ty: &mir_ty::Ty<'tcx>) -> PlaceType { + use mir::{interpret::Scalar, ConstValue, Mutability}; + match (ty.kind(), val) { + (mir_ty::TyKind::Int(_), ConstValue::Scalar(Scalar::Int(val))) => { + let val = val.try_to_int(val.size()).unwrap(); + PlaceType::with_ty_and_term( + rty::Type::int(), + chc::Term::int(val.try_into().unwrap()), + ) + } + (mir_ty::TyKind::Bool, ConstValue::Scalar(Scalar::Int(val))) => { + PlaceType::with_ty_and_term( + rty::Type::bool(), + chc::Term::bool(val.try_to_bool().unwrap()), + ) + } + ( + mir_ty::TyKind::Ref(_, elem, Mutability::Not), + ConstValue::Scalar(Scalar::Ptr(ptr, _)), + ) => { + // Pointer::into_parts is OK for CtfeProvenance + // in a later version of rustc it has prov_and_relative_offset that ensures this + let (prov, offset) = ptr.into_parts(); + let global_alloc = self.tcx.global_alloc(prov.alloc_id()); + match global_alloc { + mir::interpret::GlobalAlloc::Memory(alloc) => { + let layout = self + .tcx + .layout_of(mir_ty::ParamEnv::reveal_all().and(*elem)) + .unwrap(); + let size = layout.size; + let range = + offset.bytes() as usize..(offset.bytes() + size.bytes()) as usize; + self.const_bytes_ty(*elem, alloc, range).immut() + } + _ => unimplemented!("const ptr alloc: {:?}", global_alloc), + } + } + (mir_ty::TyKind::Ref(_, elem, Mutability::Not), ConstValue::Slice { data, meta }) => { + let end = (*meta).try_into().unwrap(); + self.const_bytes_ty(*elem, *data, 0..end).immut() + } + _ => unimplemented!("const: {:?}, ty: {:?}", val, ty), + } + } + + fn const_ty(&self, const_: &mir::Const<'tcx>) -> PlaceType { + match const_ { + mir::Const::Val(val, ty) => self.const_value_ty(val, ty), + mir::Const::Unevaluated(unevaluated, ty) => { + // since all constants are immutable in current setup, + // it should be okay to evaluate them here on-the-fly + let param_env = self.tcx.param_env(self.local_def_id); + let val = self + .tcx + .const_eval_resolve(param_env, *unevaluated, None) + .unwrap(); + self.const_value_ty(&val, ty) + } + _ => unimplemented!("const: {:?}", const_), + } + } + fn operand_type(&self, mut operand: Operand<'tcx>) -> PlaceType { if let Operand::Copy(p) | Operand::Move(p) = &mut operand { *p = self.elaborate_place(p); } - let ty = self.env.operand_type(operand.clone()); + let ty = match &operand { + Operand::Copy(place) | Operand::Move(place) => self.env.place_type(*place), + Operand::Constant(operand) => self.const_ty(&operand.const_), + }; tracing::debug!(operand = ?operand, ty = %ty.display(), "operand_type"); ty } diff --git a/src/refine/env.rs b/src/refine/env.rs index a1edbc1..f18dadb 100644 --- a/src/refine/env.rs +++ b/src/refine/env.rs @@ -2,8 +2,7 @@ use std::collections::{BTreeMap, HashMap}; use pretty::{termcolor, Pretty}; use rustc_index::IndexVec; -use rustc_middle::mir::{self, Local, Operand, Place, PlaceElem}; -use rustc_middle::ty as mir_ty; +use rustc_middle::mir::{Local, Place, PlaceElem}; use rustc_target::abi::{FieldIdx, VariantIdx}; use crate::chc; @@ -454,6 +453,14 @@ impl PlaceType { builder.build(ty, term) } + pub fn immut(self) -> PlaceType { + let mut builder = PlaceTypeBuilder::default(); + let (inner_ty, inner_term) = builder.subsume(self); + let ty = rty::PointerType::immut_to(inner_ty).into(); + let term = chc::Term::box_(inner_term); + builder.build(ty, term) + } + pub fn tuple(ptys: Vec) -> PlaceType { let mut builder = PlaceTypeBuilder::default(); let mut tys = Vec::new(); @@ -951,48 +958,6 @@ impl Env { self.var_type(local.into()) } - pub fn operand_type(&self, operand: Operand<'_>) -> PlaceType { - use mir::{interpret::Scalar, Const, ConstValue, Mutability}; - match operand { - Operand::Copy(place) | Operand::Move(place) => self.place_type(place), - Operand::Constant(operand) => { - let Const::Val(val, ty) = operand.const_ else { - unimplemented!("const: {:?}", operand.const_); - }; - match (ty.kind(), val) { - (mir_ty::TyKind::Int(_), ConstValue::Scalar(Scalar::Int(val))) => { - let val = val.try_to_int(val.size()).unwrap(); - PlaceType::with_ty_and_term( - rty::Type::int(), - chc::Term::int(val.try_into().unwrap()), - ) - } - (mir_ty::TyKind::Bool, ConstValue::Scalar(Scalar::Int(val))) => { - PlaceType::with_ty_and_term( - rty::Type::bool(), - chc::Term::bool(val.try_to_bool().unwrap()), - ) - } - ( - mir_ty::TyKind::Ref(_, elem, Mutability::Not), - ConstValue::Slice { data, meta }, - ) if matches!(elem.kind(), mir_ty::TyKind::Str) => { - let end = meta.try_into().unwrap(); - let content = data - .inner() - .inspect_with_uninit_and_ptr_outside_interpreter(0..end); - let content = std::str::from_utf8(content).unwrap(); - PlaceType::with_ty_and_term( - rty::PointerType::immut_to(rty::Type::string()).into(), - chc::Term::box_(chc::Term::string(content.to_owned())), - ) - } - _ => unimplemented!("const: {:?}, ty: {:?}", val, ty), - } - } - } - } - fn borrow_var(&mut self, var: Var, prophecy: TempVarIdx) -> PlaceType { match *self.flow_binding(var).expect("borrowing unbound var") { FlowBinding::Box(x) => { From d84871555a737bc67129a36b35e0d9784b4faacd Mon Sep 17 00:00:00 2001 From: coord_e Date: Sun, 14 Dec 2025 17:38:22 +0900 Subject: [PATCH 31/75] Instantiate body instead of using ParamTypeMapper --- src/analyze.rs | 58 +++++++++++++++++++++++++++----------- src/analyze/basic_block.rs | 29 ++++++++----------- src/analyze/crate_.rs | 53 ++++++++++++++++++++++++++++++++-- src/analyze/local_def.rs | 11 ++++++-- src/rty.rs | 2 +- 5 files changed, 115 insertions(+), 38 deletions(-) diff --git a/src/analyze.rs b/src/analyze.rs index a1ebd34..dc06cbe 100644 --- a/src/analyze.rs +++ b/src/analyze.rs @@ -104,14 +104,14 @@ impl<'tcx> ReplacePlacesVisitor<'tcx> { } #[derive(Debug, Clone)] -struct DeferredDefTy { - cache: Rc>>, +struct DeferredDefTy<'tcx> { + cache: Rc, rty::RefinedType>>>, } #[derive(Debug, Clone)] -enum DefTy { +enum DefTy<'tcx> { Concrete(rty::RefinedType), - Deferred(DeferredDefTy), + Deferred(DeferredDefTy<'tcx>), } #[derive(Clone)] @@ -123,7 +123,7 @@ pub struct Analyzer<'tcx> { /// currently contains only local-def templates, /// but will be extended to contain externally known def's refinement types /// (at least for every defs referenced by local def bodies) - defs: HashMap, + defs: HashMap>, /// Resulting CHC system. system: Rc>, @@ -241,15 +241,17 @@ impl<'tcx> Analyzer<'tcx> { pub fn def_ty_with_args( &mut self, def_id: DefId, - rty_args: rty::TypeArgs, + generic_args: mir_ty::GenericArgsRef<'tcx>, ) -> Option { + let type_builder = TypeBuilder::new(self.tcx, def_id); + let deferred_ty = match self.defs.get(&def_id)? { DefTy::Concrete(rty) => { let mut def_ty = rty.clone(); def_ty.instantiate_ty_params( - rty_args - .clone() - .into_iter() + generic_args + .types() + .map(|ty| type_builder.build(ty)) .map(rty::RefinedType::unrefined) .collect(), ); @@ -259,21 +261,19 @@ impl<'tcx> Analyzer<'tcx> { }; let deferred_ty_cache = Rc::clone(&deferred_ty.cache); // to cut reference to allow &mut self - if let Some(rty) = deferred_ty_cache.borrow().get(&rty_args) { + if let Some(rty) = deferred_ty_cache.borrow().get(&generic_args) { return Some(rty.clone()); } - let type_builder = TypeBuilder::new(self.tcx, def_id).with_param_mapper({ - let rty_args = rty_args.clone(); - move |ty: rty::ParamType| rty_args[ty.idx].clone() - }); let mut analyzer = self.local_def_analyzer(def_id.as_local()?); - analyzer.type_builder(type_builder); + analyzer + .type_builder(type_builder) + .generic_args(generic_args); let expected = analyzer.expected_ty(); deferred_ty_cache .borrow_mut() - .insert(rty_args, expected.clone()); + .insert(generic_args, expected.clone()); analyzer.run(&expected); Some(expected) @@ -340,4 +340,30 @@ impl<'tcx> Analyzer<'tcx> { self.tcx.dcx().err(format!("verification error: {:?}", err)); } } + + /// Computes the signature of the local function. + /// + /// This is a drop-in replacement of `self.tcx.fn_sig(local_def_id).instantiate_identity().skip_binder()`, + /// but extracts parameter and return types directly from the given `body` to obtain a signature that + /// reflects potential type instantiations happened after `optimized_mir`. + pub fn local_fn_sig_with_body( + &self, + local_def_id: LocalDefId, + body: &mir::Body<'tcx>, + ) -> mir_ty::FnSig<'tcx> { + let ty = self.tcx.type_of(local_def_id).instantiate_identity(); + let sig = if let mir_ty::TyKind::Closure(_, substs) = ty.kind() { + substs.as_closure().sig().skip_binder() + } else { + ty.fn_sig(self.tcx).skip_binder() + }; + + self.tcx.mk_fn_sig( + body.args_iter().map(|arg| body.local_decls[arg].ty), + body.return_ty(), + sig.c_variadic, + sig.unsafety, + sig.abi, + ) + } } diff --git a/src/analyze/basic_block.rs b/src/analyze/basic_block.rs index 9f6ee49..5719406 100644 --- a/src/analyze/basic_block.rs +++ b/src/analyze/basic_block.rs @@ -263,15 +263,12 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { _ty, ) => { let func_ty = match operand.const_fn_def() { - Some((def_id, args)) => { - let rty_args: IndexVec<_, _> = - args.types().map(|ty| self.type_builder.build(ty)).collect(); - self.ctx - .def_ty_with_args(def_id, rty_args) - .expect("unknown def") - .ty - .clone() - } + Some((def_id, args)) => self + .ctx + .def_ty_with_args(def_id, args) + .expect("unknown def") + .ty + .clone(), _ => unimplemented!(), }; PlaceType::with_ty_and_term(func_ty.vacuous(), chc::Term::null()) @@ -471,14 +468,12 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { let ret = rty::RefinedType::new(rty::Type::unit(), ret_formula.into()); rty::FunctionType::new([param1, param2].into_iter().collect(), ret).into() } - Some((def_id, args)) => { - let rty_args = args.types().map(|ty| self.type_builder.build(ty)).collect(); - self.ctx - .def_ty_with_args(def_id, rty_args) - .expect("unknown def") - .ty - .vacuous() - } + Some((def_id, args)) => self + .ctx + .def_ty_with_args(def_id, args) + .expect("unknown def") + .ty + .vacuous(), _ => self.operand_type(func.clone()).ty, }; let expected_args: IndexVec<_, _> = args diff --git a/src/analyze/crate_.rs b/src/analyze/crate_.rs index e15ca49..218fecc 100644 --- a/src/analyze/crate_.rs +++ b/src/analyze/crate_.rs @@ -76,8 +76,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { // check polymorphic function def by replacing type params with some opaque type // (and this is no-op if the function is mono) - let type_builder = TypeBuilder::new(self.tcx, local_def_id.to_def_id()) - .with_param_mapper(|_| rty::Type::int()); + let type_builder = TypeBuilder::new(self.tcx, local_def_id.to_def_id()); let mut expected = expected.clone(); let subst = rty::TypeParamSubst::new( expected @@ -87,13 +86,63 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { .collect(), ); expected.subst_ty_params(&subst); + let generic_args = self.placeholder_generic_args(*local_def_id); self.ctx .local_def_analyzer(*local_def_id) .type_builder(type_builder) + .generic_args(generic_args) .run(&expected); } } + fn placeholder_generic_args(&self, local_def_id: LocalDefId) -> mir_ty::GenericArgsRef<'tcx> { + let mut constrained_params = HashSet::new(); + let predicates = self.tcx.predicates_of(local_def_id); + let sized_trait = self.tcx.lang_items().sized_trait().unwrap(); + for (clause, _) in predicates.predicates { + let mir_ty::ClauseKind::Trait(pred) = clause.kind().skip_binder() else { + continue; + }; + if pred.def_id() == sized_trait { + continue; + }; + for arg in pred.trait_ref.args.iter().flat_map(|ty| ty.walk()) { + let Some(ty) = arg.as_type() else { + continue; + }; + let mir_ty::TyKind::Param(param_ty) = ty.kind() else { + continue; + }; + constrained_params.insert(param_ty.index); + } + } + + let mut args: Vec> = Vec::new(); + + let generics = self.tcx.generics_of(local_def_id); + for idx in 0..generics.count() { + let param = generics.param_at(idx, self.tcx); + let arg = match param.kind { + mir_ty::GenericParamDefKind::Type { .. } => { + if constrained_params.contains(¶m.index) { + panic!( + "unable to check generic function with constrained type parameter: {}", + self.tcx.def_path_str(local_def_id) + ); + } + self.tcx.types.i32.into() + } + mir_ty::GenericParamDefKind::Const { .. } => { + unimplemented!() + } + mir_ty::GenericParamDefKind::Lifetime { .. } => self.tcx.lifetimes.re_erased.into(), + }; + args.push(arg); + } + + self.tcx.mk_args(&args) + } + fn assert_callable_entry(&mut self) { if let Some((def_id, _)) = self.tcx.entry_fn(()) { // we want to assert entry function is safe to execute without any assumption diff --git a/src/analyze/local_def.rs b/src/analyze/local_def.rs index 58e4b8b..b13cc63 100644 --- a/src/analyze/local_def.rs +++ b/src/analyze/local_def.rs @@ -169,8 +169,9 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { } pub fn expected_ty(&mut self) -> rty::RefinedType { - let sig = self.tcx.fn_sig(self.local_def_id); - let sig = sig.instantiate_identity().skip_binder(); + let sig = self + .ctx + .local_fn_sig_with_body(self.local_def_id, &self.body); let mut param_resolver = analyze::annot::ParamResolver::default(); for (input_ident, input_ty) in self @@ -654,6 +655,12 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { self } + pub fn generic_args(&mut self, generic_args: mir_ty::GenericArgsRef<'tcx>) -> &mut Self { + self.body = + mir_ty::EarlyBinder::bind(self.body.clone()).instantiate(self.tcx, generic_args); + self + } + pub fn run(&mut self, expected: &rty::RefinedType) { let span = tracing::info_span!("def", def = %self.tcx.def_path_str(self.local_def_id)); let _guard = span.enter(); diff --git a/src/rty.rs b/src/rty.rs index ce6ef5e..1636dfa 100644 --- a/src/rty.rs +++ b/src/rty.rs @@ -55,7 +55,7 @@ mod subtyping; pub use subtyping::{relate_sub_closed_type, ClauseScope, Subtyping}; mod params; -pub use params::{RefinedTypeArgs, TypeArgs, TypeParamIdx, TypeParamSubst}; +pub use params::{RefinedTypeArgs, TypeParamIdx, TypeParamSubst}; rustc_index::newtype_index! { /// An index representing function parameter. From daba7b7b158088e3c3a18d6fbca0237c416943ad Mon Sep 17 00:00:00 2001 From: coord_e Date: Sun, 14 Dec 2025 13:43:08 +0900 Subject: [PATCH 32/75] Remove ParamTypeMapper and stop relaying TypeBuilder --- src/analyze.rs | 8 +++----- src/analyze/basic_block.rs | 5 ----- src/analyze/crate_.rs | 2 -- src/analyze/local_def.rs | 6 ------ src/refine/template.rs | 39 +------------------------------------- 5 files changed, 4 insertions(+), 56 deletions(-) diff --git a/src/analyze.rs b/src/analyze.rs index dc06cbe..fff56b4 100644 --- a/src/analyze.rs +++ b/src/analyze.rs @@ -243,10 +243,10 @@ impl<'tcx> Analyzer<'tcx> { def_id: DefId, generic_args: mir_ty::GenericArgsRef<'tcx>, ) -> Option { - let type_builder = TypeBuilder::new(self.tcx, def_id); - let deferred_ty = match self.defs.get(&def_id)? { DefTy::Concrete(rty) => { + let type_builder = TypeBuilder::new(self.tcx, def_id); + let mut def_ty = rty.clone(); def_ty.instantiate_ty_params( generic_args @@ -266,9 +266,7 @@ impl<'tcx> Analyzer<'tcx> { } let mut analyzer = self.local_def_analyzer(def_id.as_local()?); - analyzer - .type_builder(type_builder) - .generic_args(generic_args); + analyzer.generic_args(generic_args); let expected = analyzer.expected_ty(); deferred_ty_cache diff --git a/src/analyze/basic_block.rs b/src/analyze/basic_block.rs index 5719406..81f7705 100644 --- a/src/analyze/basic_block.rs +++ b/src/analyze/basic_block.rs @@ -983,11 +983,6 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { self } - pub fn type_builder(&mut self, type_builder: TypeBuilder<'tcx>) -> &mut Self { - self.type_builder = type_builder; - self - } - pub fn run(&mut self, expected: &BasicBlockType) { let span = tracing::info_span!("bb", bb = ?self.basic_block); let _guard = span.enter(); diff --git a/src/analyze/crate_.rs b/src/analyze/crate_.rs index 218fecc..d6e343d 100644 --- a/src/analyze/crate_.rs +++ b/src/analyze/crate_.rs @@ -76,7 +76,6 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { // check polymorphic function def by replacing type params with some opaque type // (and this is no-op if the function is mono) - let type_builder = TypeBuilder::new(self.tcx, local_def_id.to_def_id()); let mut expected = expected.clone(); let subst = rty::TypeParamSubst::new( expected @@ -89,7 +88,6 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { let generic_args = self.placeholder_generic_args(*local_def_id); self.ctx .local_def_analyzer(*local_def_id) - .type_builder(type_builder) .generic_args(generic_args) .run(&expected); } diff --git a/src/analyze/local_def.rs b/src/analyze/local_def.rs index b13cc63..50a4397 100644 --- a/src/analyze/local_def.rs +++ b/src/analyze/local_def.rs @@ -533,7 +533,6 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { .basic_block_analyzer(self.local_def_id, bb) .body(self.body.clone()) .drop_points(drop_points) - .type_builder(self.type_builder.clone()) .run(&rty); } } @@ -650,11 +649,6 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { } } - pub fn type_builder(&mut self, type_builder: TypeBuilder<'tcx>) -> &mut Self { - self.type_builder = type_builder; - self - } - pub fn generic_args(&mut self, generic_args: mir_ty::GenericArgsRef<'tcx>) -> &mut Self { self.body = mir_ty::EarlyBinder::bind(self.body.clone()).instantiate(self.tcx, generic_args); diff --git a/src/refine/template.rs b/src/refine/template.rs index c859419..446c450 100644 --- a/src/refine/template.rs +++ b/src/refine/template.rs @@ -59,19 +59,6 @@ where } } -trait ParamTypeMapper { - fn map_param_ty(&self, ty: rty::ParamType) -> rty::Type; -} - -impl ParamTypeMapper for F -where - F: Fn(rty::ParamType) -> rty::Type, -{ - fn map_param_ty(&self, ty: rty::ParamType) -> rty::Type { - self(ty) - } -} - /// Translates [`mir_ty::Ty`] to [`rty::Type`]. /// /// This struct implements a translation from Rust MIR types to Thrust types. @@ -87,9 +74,6 @@ pub struct TypeBuilder<'tcx> { /// mapped when we translate a [`mir_ty::ParamTy`] to [`rty::ParamType`]. /// See [`rty::TypeParamIdx`] for more details. param_idx_mapping: HashMap, - /// Optionally also want to further map rty::ParamType to other rty::Type before generating - /// templates. This is no-op by default. - param_type_mapper: std::rc::Rc, } impl<'tcx> TypeBuilder<'tcx> { @@ -109,25 +93,15 @@ impl<'tcx> TypeBuilder<'tcx> { Self { tcx, param_idx_mapping, - param_type_mapper: std::rc::Rc::new(|ty: rty::ParamType| ty.into()), } } - pub fn with_param_mapper(mut self, mapper: F) -> Self - where - F: Fn(rty::ParamType) -> rty::Type + 'static, - { - self.param_type_mapper = std::rc::Rc::new(mapper); - self - } - fn translate_param_type(&self, ty: &mir_ty::ParamTy) -> rty::Type { let index = *self .param_idx_mapping .get(&ty.index) .expect("unknown type param idx"); - let param_ty = rty::ParamType::new(index); - self.param_type_mapper.map_param_ty(param_ty) + rty::ParamType::new(index).into() } // TODO: consolidate two impls @@ -400,17 +374,6 @@ impl<'tcx, 'a, R> FunctionTemplateTypeBuilder<'tcx, 'a, R> { self.ret_rty = Some(rty); self } - - pub fn would_contain_template(&self) -> bool { - if self.param_tys.is_empty() { - return self.ret_rty.is_none(); - } - - let last_param_idx = rty::FunctionParamIdx::from(self.param_tys.len() - 1); - let param_annotated = - self.param_refinement.is_some() || self.param_rtys.contains_key(&last_param_idx); - self.ret_rty.is_none() || !param_annotated - } } impl<'tcx, 'a, R> FunctionTemplateTypeBuilder<'tcx, 'a, R> From ed03484569c1bdfd05e34880b06948291d8bc332 Mon Sep 17 00:00:00 2001 From: coord_e Date: Sun, 14 Dec 2025 17:42:05 +0900 Subject: [PATCH 33/75] Revert "Implement Eq and Hash for Type" This reverts commit dd67486c40b2a2337c44afc5f4a1b2f4a1a6df98. --- src/chc.rs | 8 ++++---- src/rty.rs | 20 ++++++++++---------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/src/chc.rs b/src/chc.rs index 6f5db32..8a3309f 100644 --- a/src/chc.rs +++ b/src/chc.rs @@ -389,7 +389,7 @@ impl Function { } /// A logical term. -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone)] pub enum Term { Null, Var(V), @@ -984,7 +984,7 @@ impl Pred { } /// An atom is a predicate applied to a list of terms. -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone)] pub struct Atom { pub pred: Pred, pub args: Vec>, @@ -1077,7 +1077,7 @@ impl Atom { /// While it allows arbitrary [`Atom`] in its `Atom` variant, we only expect atoms with known /// predicates (i.e., predicates other than `Pred::Var`) to appear in formulas. It is our TODO to /// enforce this restriction statically. Also see the definition of [`Body`]. -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone)] pub enum Formula { Atom(Atom), Not(Box>), @@ -1296,7 +1296,7 @@ impl Formula { } /// The body part of a clause, consisting of atoms and a formula. -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone)] pub struct Body { pub atoms: Vec>, /// NOTE: This doesn't contain predicate variables. Also see [`Formula`]. diff --git a/src/rty.rs b/src/rty.rs index 1636dfa..64b8dfb 100644 --- a/src/rty.rs +++ b/src/rty.rs @@ -88,7 +88,7 @@ where /// In Thrust, function types are closed. Because of that, function types, thus its parameters and /// return type only refer to the parameters of the function itself using [`FunctionParamIdx`] and /// do not accept other type of variables from the environment. -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone)] pub struct FunctionType { pub params: IndexVec>, pub ret: Box>, @@ -156,7 +156,7 @@ impl FunctionType { } /// The kind of a reference, which is either mutable or immutable. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum RefKind { Mut, Immut, @@ -181,7 +181,7 @@ where } /// The kind of a pointer, which is either a reference or an owned pointer. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum PointerKind { Ref(RefKind), Own, @@ -221,7 +221,7 @@ impl PointerKind { } /// A pointer type. -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone)] pub struct PointerType { pub kind: PointerKind, pub elem: Box>, @@ -334,7 +334,7 @@ impl PointerType { /// Note that the current implementation uses tuples to represent structs. See /// implementation in `crate::refine::template` module for details. /// It is our TODO to improve the struct representation. -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone)] pub struct TupleType { pub elems: Vec>, } @@ -458,7 +458,7 @@ impl EnumDatatypeDef { /// An enum type. /// /// An enum type includes its type arguments and the argument types can refer to outer variables `T`. -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone)] pub struct EnumType { pub symbol: chc::DatatypeSymbol, pub args: IndexVec>, @@ -560,7 +560,7 @@ impl EnumType { } /// A type parameter. -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone)] pub struct ParamType { pub idx: TypeParamIdx, } @@ -589,7 +589,7 @@ impl ParamType { } /// An underlying type of a refinement type. -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone)] pub enum Type { Int, Bool, @@ -995,7 +995,7 @@ impl ShiftExistential for RefinedTypeVar { /// A formula, potentially equipped with an existential quantifier. /// /// Note: This is not to be confused with [`crate::chc::Formula`] in the [`crate::chc`] module, which is a different notion. -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone)] pub struct Formula { pub existentials: IndexVec, pub body: chc::Body, @@ -1236,7 +1236,7 @@ impl Instantiator { } /// A refinement type. -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone)] pub struct RefinedType { pub ty: Type, pub refinement: Refinement, From 4c9e5d81b70ce317c6255dcb57395ae8bc32fc25 Mon Sep 17 00:00:00 2001 From: coord_e Date: Sun, 23 Nov 2025 21:44:19 +0900 Subject: [PATCH 34/75] Resolve trait method DefId in type_call --- src/analyze/basic_block.rs | 25 +++++++++++++++++++------ tests/ui/fail/trait.rs | 22 ++++++++++++++++++++++ tests/ui/fail/trait_param.rs | 26 ++++++++++++++++++++++++++ tests/ui/pass/trait.rs | 22 ++++++++++++++++++++++ tests/ui/pass/trait_param.rs | 26 ++++++++++++++++++++++++++ 5 files changed, 115 insertions(+), 6 deletions(-) create mode 100644 tests/ui/fail/trait.rs create mode 100644 tests/ui/fail/trait_param.rs create mode 100644 tests/ui/pass/trait.rs create mode 100644 tests/ui/pass/trait_param.rs diff --git a/src/analyze/basic_block.rs b/src/analyze/basic_block.rs index 8b21ca2..3038675 100644 --- a/src/analyze/basic_block.rs +++ b/src/analyze/basic_block.rs @@ -568,12 +568,25 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { let ret = rty::RefinedType::new(rty::Type::unit(), ret_formula.into()); rty::FunctionType::new([param1, param2].into_iter().collect(), ret).into() } - Some((def_id, args)) => self - .ctx - .def_ty_with_args(def_id, args) - .expect("unknown def") - .ty - .vacuous(), + Some((def_id, args)) => { + let param_env = self.tcx.param_env(self.local_def_id); + let instance = + mir_ty::Instance::resolve(self.tcx, param_env, def_id, args).unwrap(); + let resolved_def_id = if let Some(instance) = instance { + instance.def_id() + } else { + def_id + }; + if def_id != resolved_def_id { + tracing::info!(?def_id, ?resolved_def_id, "resolve",); + } + + self.ctx + .def_ty_with_args(resolved_def_id, args) + .expect("unknown def") + .ty + .vacuous() + } _ => self.operand_type(func.clone()).ty, }; let expected_args: IndexVec<_, _> = args diff --git a/tests/ui/fail/trait.rs b/tests/ui/fail/trait.rs new file mode 100644 index 0000000..8cd3881 --- /dev/null +++ b/tests/ui/fail/trait.rs @@ -0,0 +1,22 @@ +//@error-in-other-file: Unsat + +trait BoolLike { + fn truthy(&self) -> bool; +} + +impl BoolLike for bool { + fn truthy(&self) -> bool { + *self + } +} + +impl BoolLike for i32 { + fn truthy(&self) -> bool { + *self != 0 + } +} + +fn main() { + assert!(1_i32.truthy()); + assert!(false.truthy()); +} diff --git a/tests/ui/fail/trait_param.rs b/tests/ui/fail/trait_param.rs new file mode 100644 index 0000000..b0d6081 --- /dev/null +++ b/tests/ui/fail/trait_param.rs @@ -0,0 +1,26 @@ +//@error-in-other-file: Unsat +//@compile-flags: -C debug-assertions=off + +trait BoolLike { + fn truthy(&self) -> bool; +} + +impl BoolLike for bool { + fn truthy(&self) -> bool { + *self + } +} + +impl BoolLike for i32 { + fn truthy(&self) -> bool { + *self != 0 + } +} + +fn falsy(value: T) -> bool { + value.truthy() +} + +fn main() { + assert!(falsy(0_i32)); +} diff --git a/tests/ui/pass/trait.rs b/tests/ui/pass/trait.rs new file mode 100644 index 0000000..35dc65e --- /dev/null +++ b/tests/ui/pass/trait.rs @@ -0,0 +1,22 @@ +//@check-pass + +trait BoolLike { + fn truthy(&self) -> bool; +} + +impl BoolLike for bool { + fn truthy(&self) -> bool { + *self + } +} + +impl BoolLike for i32 { + fn truthy(&self) -> bool { + *self != 0 + } +} + +fn main() { + assert!(1_i32.truthy()); + assert!(true.truthy()); +} diff --git a/tests/ui/pass/trait_param.rs b/tests/ui/pass/trait_param.rs new file mode 100644 index 0000000..7cc4179 --- /dev/null +++ b/tests/ui/pass/trait_param.rs @@ -0,0 +1,26 @@ +//@check-pass +//@compile-flags: -C debug-assertions=off + +trait BoolLike { + fn truthy(&self) -> bool; +} + +impl BoolLike for bool { + fn truthy(&self) -> bool { + *self + } +} + +impl BoolLike for i32 { + fn truthy(&self) -> bool { + *self != 0 + } +} + +fn falsy(value: T) -> bool { + !value.truthy() +} + +fn main() { + assert!(falsy(0_i32)); +} From b03b58c7b8d76e6a97c392e55ead37215fdf40f5 Mon Sep 17 00:00:00 2001 From: coord_e Date: Sun, 23 Nov 2025 21:45:39 +0900 Subject: [PATCH 35/75] Enable to type lifted closure functions --- src/analyze.rs | 12 +++++++++++- src/analyze/crate_.rs | 7 ++----- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/src/analyze.rs b/src/analyze.rs index fff56b4..e237641 100644 --- a/src/analyze.rs +++ b/src/analyze.rs @@ -341,7 +341,7 @@ impl<'tcx> Analyzer<'tcx> { /// Computes the signature of the local function. /// - /// This is a drop-in replacement of `self.tcx.fn_sig(local_def_id).instantiate_identity().skip_binder()`, + /// This works like `self.tcx.fn_sig(local_def_id).instantiate_identity().skip_binder()`, /// but extracts parameter and return types directly from the given `body` to obtain a signature that /// reflects potential type instantiations happened after `optimized_mir`. pub fn local_fn_sig_with_body( @@ -364,4 +364,14 @@ impl<'tcx> Analyzer<'tcx> { sig.abi, ) } + + /// Computes the signature of the local function. + /// + /// This works like `self.tcx.fn_sig(local_def_id).instantiate_identity().skip_binder()`, + /// but extracts parameter and return types directly from [`mir::Body`] to obtain a signature that + /// reflects the actual type of lifted closure functions. + pub fn local_fn_sig(&self, local_def_id: LocalDefId) -> mir_ty::FnSig<'tcx> { + let body = self.tcx.optimized_mir(local_def_id); + self.local_fn_sig_with_body(local_def_id, body) + } } diff --git a/src/analyze/crate_.rs b/src/analyze/crate_.rs index d6e343d..2a17b11 100644 --- a/src/analyze/crate_.rs +++ b/src/analyze/crate_.rs @@ -39,6 +39,8 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { #[tracing::instrument(skip(self), fields(def_id = %self.tcx.def_path_str(local_def_id)))] fn refine_fn_def(&mut self, local_def_id: LocalDefId) { + let sig = self.ctx.local_fn_sig(local_def_id); + let mut analyzer = self.ctx.local_def_analyzer(local_def_id); if analyzer.is_annotated_as_trusted() { @@ -46,11 +48,6 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { self.trusted.insert(local_def_id.to_def_id()); } - let sig = self - .tcx - .fn_sig(local_def_id) - .instantiate_identity() - .skip_binder(); use mir_ty::TypeVisitableExt as _; if sig.has_param() && !analyzer.is_fully_annotated() { self.ctx.register_deferred_def(local_def_id.to_def_id()); From 47bf162334beb8e6c613ad762c1b4a219c1b2ea5 Mon Sep 17 00:00:00 2001 From: coord_e Date: Sun, 23 Nov 2025 21:46:12 +0900 Subject: [PATCH 36/75] Enable to type closures --- src/refine/template.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/refine/template.rs b/src/refine/template.rs index 446c450..c7f480c 100644 --- a/src/refine/template.rs +++ b/src/refine/template.rs @@ -163,6 +163,7 @@ impl<'tcx> TypeBuilder<'tcx> { unimplemented!("unsupported ADT: {:?}", ty); } } + mir_ty::TyKind::Closure(_, args) => self.build(args.as_closure().tupled_upvars_ty()), kind => unimplemented!("unrefined_ty: {:?}", kind), } } @@ -282,6 +283,7 @@ where unimplemented!("unsupported ADT: {:?}", ty); } } + mir_ty::TyKind::Closure(_, args) => self.build(args.as_closure().tupled_upvars_ty()), kind => unimplemented!("ty: {:?}", kind), } } From 8c533e0ddfcba8b426d794e7849649c168509b78 Mon Sep 17 00:00:00 2001 From: coord_e Date: Mon, 24 Nov 2025 00:30:18 +0900 Subject: [PATCH 37/75] Add rty::FunctionAbi and attach it to rty::FunctionType --- src/refine/template.rs | 10 ++++++- src/rty.rs | 63 ++++++++++++++++++++++++++++++++++++------ 2 files changed, 64 insertions(+), 9 deletions(-) diff --git a/src/refine/template.rs b/src/refine/template.rs index c7f480c..b57dae5 100644 --- a/src/refine/template.rs +++ b/src/refine/template.rs @@ -184,6 +184,11 @@ impl<'tcx> TypeBuilder<'tcx> { registry: &'a mut R, sig: mir_ty::FnSig<'tcx>, ) -> FunctionTemplateTypeBuilder<'tcx, 'a, R> { + let abi = match sig.abi { + rustc_target::spec::abi::Abi::Rust => rty::FunctionAbi::Rust, + rustc_target::spec::abi::Abi::RustCall => rty::FunctionAbi::RustCall, + _ => unimplemented!("unsupported function ABI: {:?}", sig.abi), + }; FunctionTemplateTypeBuilder { inner: self.clone(), registry, @@ -199,6 +204,7 @@ impl<'tcx> TypeBuilder<'tcx> { param_rtys: Default::default(), param_refinement: None, ret_rty: None, + abi, } } } @@ -318,6 +324,7 @@ where param_rtys: Default::default(), param_refinement: None, ret_rty: None, + abi: Default::default(), } .build(); BasicBlockType { ty, locals } @@ -333,6 +340,7 @@ pub struct FunctionTemplateTypeBuilder<'tcx, 'a, R> { param_refinement: Option>, param_rtys: HashMap>, ret_rty: Option>, + abi: rty::FunctionAbi, } impl<'tcx, 'a, R> FunctionTemplateTypeBuilder<'tcx, 'a, R> { @@ -441,6 +449,6 @@ where .with_scope(&builder) .build_refined(self.ret_ty) }); - rty::FunctionType::new(param_rtys, ret_rty) + rty::FunctionType::new(param_rtys, ret_rty).with_abi(self.abi) } } diff --git a/src/rty.rs b/src/rty.rs index 64b8dfb..b780d44 100644 --- a/src/rty.rs +++ b/src/rty.rs @@ -83,6 +83,36 @@ where } } +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)] +pub enum FunctionAbi { + #[default] + Rust, + RustCall, +} + +impl std::fmt::Display for FunctionAbi { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.write_str(self.name()) + } +} + +impl FunctionAbi { + pub fn name(&self) -> &'static str { + match self { + FunctionAbi::Rust => "rust", + FunctionAbi::RustCall => "rust-call", + } + } + + pub fn is_rust(&self) -> bool { + matches!(self, FunctionAbi::Rust) + } + + pub fn is_rust_call(&self) -> bool { + matches!(self, FunctionAbi::RustCall) + } +} + /// A function type. /// /// In Thrust, function types are closed. Because of that, function types, thus its parameters and @@ -92,6 +122,7 @@ where pub struct FunctionType { pub params: IndexVec>, pub ret: Box>, + pub abi: FunctionAbi, } impl<'a, 'b, D> Pretty<'a, D, termcolor::ColorSpec> for &'b FunctionType @@ -100,15 +131,25 @@ where D::Doc: Clone, { fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, termcolor::ColorSpec> { + let abi = match self.abi { + FunctionAbi::Rust => allocator.nil(), + abi => allocator + .text("extern") + .append(allocator.space()) + .append(allocator.as_string(abi)) + .append(allocator.space()), + }; let separator = allocator.text(",").append(allocator.line()); - allocator - .intersperse(self.params.iter().map(|ty| ty.pretty(allocator)), separator) - .parens() - .append(allocator.space()) - .append(allocator.text("→")) - .append(allocator.line()) - .append(self.ret.pretty(allocator)) - .group() + abi.append( + allocator + .intersperse(self.params.iter().map(|ty| ty.pretty(allocator)), separator) + .parens(), + ) + .append(allocator.space()) + .append(allocator.text("→")) + .append(allocator.line()) + .append(self.ret.pretty(allocator)) + .group() } } @@ -120,9 +161,15 @@ impl FunctionType { FunctionType { params, ret: Box::new(ret), + abi: FunctionAbi::Rust, } } + pub fn with_abi(mut self, abi: FunctionAbi) -> Self { + self.abi = abi; + self + } + /// Because function types are always closed in Thrust, we can convert this into /// [`Type`]. pub fn into_closed_ty(self) -> Type { From 53aeadd245c0fb60a8be7ff82ecd22c5903e8943 Mon Sep 17 00:00:00 2001 From: coord_e Date: Mon, 24 Nov 2025 00:30:44 +0900 Subject: [PATCH 38/75] Enable to type closure calls --- src/analyze/basic_block.rs | 47 +++++++++++++++++++++++++++++++++--- src/chc.rs | 4 +++ src/rty.rs | 26 ++++++++++++++++++++ tests/ui/fail/closure_mut.rs | 12 +++++++++ tests/ui/pass/closure_mut.rs | 12 +++++++++ 5 files changed, 97 insertions(+), 4 deletions(-) create mode 100644 tests/ui/fail/closure_mut.rs create mode 100644 tests/ui/pass/closure_mut.rs diff --git a/src/analyze/basic_block.rs b/src/analyze/basic_block.rs index 3038675..4c88566 100644 --- a/src/analyze/basic_block.rs +++ b/src/analyze/basic_block.rs @@ -84,16 +84,55 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { ) -> Vec { let mut clauses = Vec::new(); - if expected_args.is_empty() { - // elaboration: we need at least one predicate variable in parameter (see mir_function_ty_impl) - expected_args.push(rty::RefinedType::unrefined(rty::Type::unit()).vacuous()); - } tracing::debug!( got = %got.display(), expected = %crate::pretty::FunctionType::new(&expected_args, &expected_ret).display(), "fn_sub_type" ); + match got.abi { + rty::FunctionAbi::Rust => { + if expected_args.is_empty() { + // elaboration: we need at least one predicate variable in parameter (see mir_function_ty_impl) + expected_args.push(rty::RefinedType::unrefined(rty::Type::unit()).vacuous()); + } + } + rty::FunctionAbi::RustCall => { + // &Closure, { v: (own i32, own bool) | v = (<0>, ) } + // => + // &Closure, { v: i32 | (, _) = (<0>, ) }, { v: bool | (_, ) = (<0>, ) } + + let rty::RefinedType { ty, mut refinement } = + expected_args.pop().expect("rust-call last arg"); + let ty = ty.into_tuple().expect("rust-call last arg is tuple"); + let mut replacement_tuple = Vec::new(); // will be (, _) or (_, ) + for elem in &ty.elems { + let existential = refinement.existentials.push(elem.ty.to_sort()); + replacement_tuple.push(chc::Term::var(rty::RefinedTypeVar::Existential( + existential, + ))); + } + + for (i, elem) in ty.elems.into_iter().enumerate() { + // all tuple elements are boxed during the translation to rty::Type + let mut param_ty = elem.deref(); + param_ty + .refinement + .push_conj(refinement.clone().subst_value_var(|| { + let mut value_elems = replacement_tuple.clone(); + value_elems[i] = chc::Term::var(rty::RefinedTypeVar::Value).boxed(); + chc::Term::tuple(value_elems) + })); + expected_args.push(param_ty); + } + + tracing::info!( + expected = %crate::pretty::FunctionType::new(&expected_args, &expected_ret).display(), + "rust-call expanded", + ); + } + } + // TODO: check sty and length is equal let mut builder = self.env.build_clause(); for (param_idx, param_rty) in got.params.iter_enumerated() { diff --git a/src/chc.rs b/src/chc.rs index a4f70af..a5f046e 100644 --- a/src/chc.rs +++ b/src/chc.rs @@ -623,6 +623,10 @@ impl Term { Term::Mut(Box::new(t1), Box::new(t2)) } + pub fn boxed(self) -> Self { + Term::Box(Box::new(self)) + } + pub fn box_current(self) -> Self { Term::BoxCurrent(Box::new(self)) } diff --git a/src/rty.rs b/src/rty.rs index b780d44..34cfd6d 100644 --- a/src/rty.rs +++ b/src/rty.rs @@ -1351,6 +1351,32 @@ impl RefinedType { RefinedType { ty, refinement } } + /// Returns a dereferenced type of the immutable reference or owned pointer. + /// + /// e.g. `{ v: Box | φ } --> { v: T | φ[box v/v] }` + pub fn deref(self) -> Self { + let RefinedType { + ty, + refinement: outer_refinement, + } = self; + let inner_ty = ty.into_pointer().expect("invalid deref"); + if inner_ty.is_mut() { + // losing info about proph + panic!("invalid deref"); + } + let RefinedType { + ty: inner_ty, + refinement: mut inner_refinement, + } = *inner_ty.elem; + inner_refinement.push_conj( + outer_refinement.subst_value_var(|| chc::Term::var(RefinedTypeVar::Value).boxed()), + ); + RefinedType { + ty: inner_ty, + refinement: inner_refinement, + } + } + pub fn subst_var(self, mut f: F) -> RefinedType where F: FnMut(FV) -> chc::Term, diff --git a/tests/ui/fail/closure_mut.rs b/tests/ui/fail/closure_mut.rs new file mode 100644 index 0000000..0b0a8ea --- /dev/null +++ b/tests/ui/fail/closure_mut.rs @@ -0,0 +1,12 @@ +//@error-in-other-file: Unsat +//@compile-flags: -C debug-assertions=off + +fn main() { + let mut x = 1; + let mut incr = |by: i32| { + x += by; + }; + incr(5); + incr(5); + assert!(x == 10); +} diff --git a/tests/ui/pass/closure_mut.rs b/tests/ui/pass/closure_mut.rs new file mode 100644 index 0000000..d50c074 --- /dev/null +++ b/tests/ui/pass/closure_mut.rs @@ -0,0 +1,12 @@ +//@check-pass +//@compile-flags: -C debug-assertions=off + +fn main() { + let mut x = 1; + let mut incr = |by: i32| { + x += by; + }; + incr(5); + incr(5); + assert!(x == 11); +} From c5abc9d574660ade41fa0fa45a346c5e3c1ee40a Mon Sep 17 00:00:00 2001 From: coord_e Date: Mon, 24 Nov 2025 01:12:39 +0900 Subject: [PATCH 39/75] Enable to type unit via ZeroSized --- src/analyze/basic_block.rs | 6 ++++++ tests/ui/fail/closure_mut_0.rs | 12 ++++++++++++ tests/ui/fail/closure_param.rs | 14 ++++++++++++++ tests/ui/pass/closure_mut_0.rs | 11 +++++++++++ tests/ui/pass/closure_param.rs | 14 ++++++++++++++ 5 files changed, 57 insertions(+) create mode 100644 tests/ui/fail/closure_mut_0.rs create mode 100644 tests/ui/fail/closure_param.rs create mode 100644 tests/ui/pass/closure_mut_0.rs create mode 100644 tests/ui/pass/closure_param.rs diff --git a/src/analyze/basic_block.rs b/src/analyze/basic_block.rs index 4c88566..372e481 100644 --- a/src/analyze/basic_block.rs +++ b/src/analyze/basic_block.rs @@ -214,6 +214,12 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { chc::Term::bool(val.try_to_bool().unwrap()), ) } + (mir_ty::TyKind::Tuple(tys), _) if tys.is_empty() => { + PlaceType::with_ty_and_term(rty::Type::unit(), chc::Term::tuple(vec![])) + } + (mir_ty::TyKind::Closure(_, args), _) if args.as_closure().upvar_tys().is_empty() => { + PlaceType::with_ty_and_term(rty::Type::unit(), chc::Term::tuple(vec![])) + } ( mir_ty::TyKind::Ref(_, elem, Mutability::Not), ConstValue::Scalar(Scalar::Ptr(ptr, _)), diff --git a/tests/ui/fail/closure_mut_0.rs b/tests/ui/fail/closure_mut_0.rs new file mode 100644 index 0000000..a08196d --- /dev/null +++ b/tests/ui/fail/closure_mut_0.rs @@ -0,0 +1,12 @@ +//@error-in-other-file: Unsat +//@compile-flags: -C debug-assertions=off + +fn main() { + let mut x = 1; + x += 1; + let mut incr = || { + x += 1; + }; + incr(); + assert!(x == 2); +} diff --git a/tests/ui/fail/closure_param.rs b/tests/ui/fail/closure_param.rs new file mode 100644 index 0000000..bd9dc23 --- /dev/null +++ b/tests/ui/fail/closure_param.rs @@ -0,0 +1,14 @@ +//@error-in-other-file: Unsat +//@compile-flags: -C debug-assertions=off + +fn take_fn T>(f: F) -> T { + f(42) +} + +fn main() { + let y = take_fn(|x| { + assert!(x == 41); + x + 1 + }); + assert!(y == 42); +} diff --git a/tests/ui/pass/closure_mut_0.rs b/tests/ui/pass/closure_mut_0.rs new file mode 100644 index 0000000..5f39a0f --- /dev/null +++ b/tests/ui/pass/closure_mut_0.rs @@ -0,0 +1,11 @@ +//@check-pass +//@compile-flags: -C debug-assertions=off + +fn main() { + let mut x = 1; + let mut incr = || { + x += 1; + }; + incr(); + assert!(x == 2); +} diff --git a/tests/ui/pass/closure_param.rs b/tests/ui/pass/closure_param.rs new file mode 100644 index 0000000..0188ca9 --- /dev/null +++ b/tests/ui/pass/closure_param.rs @@ -0,0 +1,14 @@ +//@check-pass +//@compile-flags: -C debug-assertions=off + +fn take_fn T>(f: F) -> T { + f(41) +} + +fn main() { + let y = take_fn(|x| { + assert!(x == 41); + x + 1 + }); + assert!(y == 42); +} From c6d49c81776636185ebf8fcdee14546949adbab0 Mon Sep 17 00:00:00 2001 From: coord_e Date: Mon, 24 Nov 2025 01:31:18 +0900 Subject: [PATCH 40/75] Ensure bb fn type params are sorted by locals --- src/refine/template.rs | 5 ++++- tests/ui/fail/closure_no_capture.rs | 9 +++++++++ tests/ui/pass/closure_no_capture.rs | 9 +++++++++ 3 files changed, 22 insertions(+), 1 deletion(-) create mode 100644 tests/ui/fail/closure_no_capture.rs create mode 100644 tests/ui/pass/closure_no_capture.rs diff --git a/src/refine/template.rs b/src/refine/template.rs index b57dae5..da8e4b6 100644 --- a/src/refine/template.rs +++ b/src/refine/template.rs @@ -309,9 +309,12 @@ where where I: IntoIterator)>, { + // this is necessary for local_def::Analyzer::elaborate_unused_args + let mut live_locals: Vec<_> = live_locals.into_iter().collect(); + live_locals.sort_by_key(|(local, _)| *local); + let mut locals = IndexVec::::new(); let mut tys = Vec::new(); - // TODO: avoid two iteration and assumption of FunctionParamIdx match between locals and ty for (local, ty) in live_locals { locals.push((local, ty.mutbl)); tys.push(ty); diff --git a/tests/ui/fail/closure_no_capture.rs b/tests/ui/fail/closure_no_capture.rs new file mode 100644 index 0000000..edabc34 --- /dev/null +++ b/tests/ui/fail/closure_no_capture.rs @@ -0,0 +1,9 @@ +//@error-in-other-file: Unsat +//@compile-flags: -C debug-assertions=off + +fn main() { + let incr = |x| { + x + 1 + }; + assert!(incr(2) == 2); +} diff --git a/tests/ui/pass/closure_no_capture.rs b/tests/ui/pass/closure_no_capture.rs new file mode 100644 index 0000000..02f38aa --- /dev/null +++ b/tests/ui/pass/closure_no_capture.rs @@ -0,0 +1,9 @@ +//@check-pass +//@compile-flags: -C debug-assertions=off + +fn main() { + let incr = |x| { + x + 1 + }; + assert!(incr(1) == 2); +} From 4ee1d9aeacd899f58cfe3b8b92ccdef7587528c7 Mon Sep 17 00:00:00 2001 From: coord_e Date: Tue, 30 Dec 2025 11:48:56 +0900 Subject: [PATCH 41/75] Enable to add guard to Atom --- src/chc.rs | 77 ++++++++++++++++++++++++++++++++++++++-------- src/chc/smtlib2.rs | 7 +++++ src/chc/unbox.rs | 5 +-- 3 files changed, 74 insertions(+), 15 deletions(-) diff --git a/src/chc.rs b/src/chc.rs index a5f046e..a2c4b30 100644 --- a/src/chc.rs +++ b/src/chc.rs @@ -997,6 +997,12 @@ impl Pred { /// An atom is a predicate applied to a list of terms. #[derive(Debug, Clone)] pub struct Atom { + /// With `guard`, this represents `guard => pred(args)`. + /// + /// As long as there is no pvar in the `guard`, it forms a valid CHC. However, in that case, + /// it becomes odd to call this an `Atom`... It is our TODO to clean this up by either + /// getting rid of the `guard` or renaming `Atom`. + pub guard: Option>>, pub pred: Pred, pub args: Vec>, } @@ -1008,7 +1014,12 @@ where D::Doc: Clone, { fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, termcolor::ColorSpec> { - if self.pred.is_infix() { + let guard = if let Some(guard) = &self.guard { + guard.pretty(allocator).append(allocator.text(" ⇒")) + } else { + allocator.nil() + }; + let atom = if self.pred.is_infix() { self.args[0] .pretty_atom(allocator) .append(allocator.line()) @@ -1027,35 +1038,50 @@ where } else { p.append(allocator.line()).append(inner.nest(2)).group() } - } + }; + guard.append(allocator.line()).append(atom).group() } } impl Atom { pub fn new(pred: Pred, args: Vec>) -> Self { - Atom { pred, args } + Atom { + guard: None, + pred, + args, + } } - pub fn top() -> Self { + pub fn with_guard(guard: Formula, pred: Pred, args: Vec>) -> Self { Atom { - pred: KnownPred::TOP.into(), - args: vec![], + guard: Some(Box::new(guard)), + pred, + args, } } + pub fn top() -> Self { + Atom::new(KnownPred::TOP.into(), vec![]) + } + pub fn bottom() -> Self { - Atom { - pred: KnownPred::BOTTOM.into(), - args: vec![], - } + Atom::new(KnownPred::BOTTOM.into(), vec![]) } pub fn is_top(&self) -> bool { - self.pred.is_top() + if let Some(guard) = &self.guard { + guard.is_bottom() || self.pred.is_top() + } else { + self.pred.is_top() + } } pub fn is_bottom(&self) -> bool { - self.pred.is_bottom() + if let Some(guard) = &self.guard { + guard.is_top() && self.pred.is_bottom() + } else { + self.pred.is_bottom() + } } pub fn subst_var(self, mut f: F) -> Atom @@ -1063,6 +1089,7 @@ impl Atom { F: FnMut(V) -> Term, { Atom { + guard: self.guard.map(|fo| Box::new(fo.subst_var(&mut f))), pred: self.pred, args: self.args.into_iter().map(|t| t.subst_var(&mut f)).collect(), } @@ -1073,13 +1100,37 @@ impl Atom { F: FnMut(V) -> W, { Atom { + guard: self.guard.map(|fo| Box::new(fo.map_var(&mut f))), pred: self.pred, args: self.args.into_iter().map(|t| t.map_var(&mut f)).collect(), } } pub fn fv(&self) -> impl Iterator { - self.args.iter().flat_map(|t| t.fv()) + let guard_fvs: Box> = if let Some(guard) = &self.guard { + Box::new(guard.fv()) + } else { + Box::new(std::iter::empty()) + }; + self.args.iter().flat_map(|t| t.fv()).chain(guard_fvs) + } + + pub fn guarded(self, new_guard: Formula) -> Atom { + let Atom { + guard: self_guard, + pred, + args, + } = self; + let guard = if let Some(self_guard) = self_guard { + self_guard.and(new_guard) + } else { + new_guard + }; + Atom { + guard: Some(Box::new(guard)), + pred, + args, + } } } diff --git a/src/chc/smtlib2.rs b/src/chc/smtlib2.rs index 3cef75e..167d108 100644 --- a/src/chc/smtlib2.rs +++ b/src/chc/smtlib2.rs @@ -223,6 +223,10 @@ pub struct Atom<'ctx, 'a> { impl<'ctx, 'a> std::fmt::Display for Atom<'ctx, 'a> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if let Some(guard) = &self.inner.guard { + let guard = Formula::new(self.ctx, self.clause, guard); + write!(f, "(=> {} ", guard)?; + } if self.inner.pred.is_negative() { write!(f, "(not ")?; } @@ -244,6 +248,9 @@ impl<'ctx, 'a> std::fmt::Display for Atom<'ctx, 'a> { if self.inner.pred.is_negative() { write!(f, ")")?; } + if self.inner.guard.is_some() { + write!(f, ")")?; + } Ok(()) } } diff --git a/src/chc/unbox.rs b/src/chc/unbox.rs index ffc4600..667ab39 100644 --- a/src/chc/unbox.rs +++ b/src/chc/unbox.rs @@ -24,9 +24,10 @@ fn unbox_term(term: Term) -> Term { } fn unbox_atom(atom: Atom) -> Atom { - let Atom { pred, args } = atom; + let Atom { guard, pred, args } = atom; + let guard = guard.map(|fo| Box::new(unbox_formula(*fo))); let args = args.into_iter().map(unbox_term).collect(); - Atom { pred, args } + Atom { guard, pred, args } } fn unbox_datatype_sort(sort: DatatypeSort) -> DatatypeSort { From 933bd100cd3b0ee9e1452aa6b688927060a9a2bb Mon Sep 17 00:00:00 2001 From: coord_e Date: Tue, 30 Dec 2025 12:46:28 +0900 Subject: [PATCH 42/75] Add guard of discriminant when expanding variant field types --- src/chc.rs | 16 +++++++++++++++ src/refine/env.rs | 18 +++++++++++------ src/rty.rs | 22 +++++++++++++++++++++ tests/ui/fail/adt_variant_without_params.rs | 13 ++++++++++++ tests/ui/pass/adt_variant_without_params.rs | 13 ++++++++++++ 5 files changed, 76 insertions(+), 6 deletions(-) create mode 100644 tests/ui/fail/adt_variant_without_params.rs create mode 100644 tests/ui/pass/adt_variant_without_params.rs diff --git a/src/chc.rs b/src/chc.rs index a2c4b30..5543de4 100644 --- a/src/chc.rs +++ b/src/chc.rs @@ -1507,6 +1507,22 @@ impl Body { } } +impl Body +where + V: Var, +{ + pub fn guarded(self, guard: Formula) -> Body { + let Body { atoms, formula } = self; + Body { + atoms: atoms + .into_iter() + .map(|a| a.guarded(guard.clone())) + .collect(), + formula: guard.not().or(formula), + } + } +} + impl<'a, 'b, D, V> Pretty<'a, D, termcolor::ColorSpec> for &'b Body where V: Var, diff --git a/src/refine/env.rs b/src/refine/env.rs index f18dadb..3663d81 100644 --- a/src/refine/env.rs +++ b/src/refine/env.rs @@ -730,6 +730,12 @@ impl Env { let def = self.enum_defs[&ty.symbol].clone(); let matcher_pred = chc::MatcherPred::new(ty.symbol.clone(), ty.arg_sorts()); + let discr_var = self + .temp_vars + .push(TempVarBinding::Type(rty::RefinedType::unrefined( + rty::Type::int(), + ))); + let mut variants = IndexVec::new(); for variant_def in &def.variants { let mut fields = IndexVec::new(); @@ -738,7 +744,12 @@ impl Env { fields.push(x); let mut field_ty = rty::RefinedType::unrefined(field_ty.clone().vacuous()); field_ty.instantiate_ty_params(ty.args.clone()); - self.bind_impl(x.into(), field_ty.boxed(), depth); + let guarded_field_ty = field_ty.guarded( + chc::Term::var(discr_var.into()) + .equal_to(chc::Term::int(variant_def.discr as i64)) + .into(), + ); + self.bind_impl(x.into(), guarded_field_ty.boxed(), depth); } variants.push(FlowBindingVariant { fields }); } @@ -773,11 +784,6 @@ impl Env { assumption .body .push_conj(chc::Atom::new(matcher_pred.into(), pred_args)); - let discr_var = self - .temp_vars - .push(TempVarBinding::Type(rty::RefinedType::unrefined( - rty::Type::int(), - ))); assumption .body .push_conj( diff --git a/src/rty.rs b/src/rty.rs index 34cfd6d..4c3d2a1 100644 --- a/src/rty.rs +++ b/src/rty.rs @@ -1172,6 +1172,19 @@ impl Formula { } } +impl Formula +where + V: chc::Var, +{ + pub fn guarded(self, guard: chc::Formula) -> Self { + let Formula { existentials, body } = self; + Formula { + existentials, + body: body.guarded(guard), + } + } +} + impl Formula where V: ShiftExistential, @@ -1351,6 +1364,15 @@ impl RefinedType { RefinedType { ty, refinement } } + pub fn guarded(self, guard: chc::Formula) -> Self + where + FV: chc::Var, + { + let RefinedType { ty, refinement } = self; + let refinement = refinement.guarded(guard.map_var(RefinedTypeVar::Free)); + RefinedType { ty, refinement } + } + /// Returns a dereferenced type of the immutable reference or owned pointer. /// /// e.g. `{ v: Box | φ } --> { v: T | φ[box v/v] }` diff --git a/tests/ui/fail/adt_variant_without_params.rs b/tests/ui/fail/adt_variant_without_params.rs new file mode 100644 index 0000000..bd8c500 --- /dev/null +++ b/tests/ui/fail/adt_variant_without_params.rs @@ -0,0 +1,13 @@ +//@error-in-other-file: Unsat + +enum X { + None1, + None2, + Some(T), +} + +fn main() { + let mut opt: X = X::None1; + opt = X::None2; + assert!(matches!(opt, X::None1)); +} diff --git a/tests/ui/pass/adt_variant_without_params.rs b/tests/ui/pass/adt_variant_without_params.rs new file mode 100644 index 0000000..15bd1e4 --- /dev/null +++ b/tests/ui/pass/adt_variant_without_params.rs @@ -0,0 +1,13 @@ +//@check-pass + +enum X { + None1, + None2, + Some(T), +} + +fn main() { + let mut opt: X = X::None1; + opt = X::None2; + assert!(matches!(opt, X::None2)); +} From 70c87053ab35e948151a323cfb6fbfe1ade7c756 Mon Sep 17 00:00:00 2001 From: coord_e Date: Sun, 28 Dec 2025 15:03:38 +0900 Subject: [PATCH 43/75] Register enum_defs on-demand --- src/analyze.rs | 69 ++++++++++++++++++++++++++++++-------- src/analyze/basic_block.rs | 21 ++++++------ src/analyze/crate_.rs | 46 +------------------------ 3 files changed, 67 insertions(+), 69 deletions(-) diff --git a/src/analyze.rs b/src/analyze.rs index e237641..53bfadb 100644 --- a/src/analyze.rs +++ b/src/analyze.rs @@ -11,6 +11,7 @@ use std::collections::HashMap; use std::rc::Rc; use rustc_hir::lang_items::LangItem; +use rustc_index::IndexVec; use rustc_middle::mir::{self, BasicBlock, Local}; use rustc_middle::ty::{self as mir_ty, TyCtxt}; use rustc_span::def_id::{DefId, LocalDefId}; @@ -174,7 +175,58 @@ impl<'tcx> Analyzer<'tcx> { } } - pub fn register_enum_def(&mut self, def_id: DefId, enum_def: rty::EnumDatatypeDef) { + fn build_enum_def(&self, def_id: DefId) -> rty::EnumDatatypeDef { + let adt = self.tcx.adt_def(def_id); + + let name = refine::datatype_symbol(self.tcx, def_id); + let variants: IndexVec<_, _> = adt + .variants() + .iter() + .map(|variant| { + let name = refine::datatype_symbol(self.tcx, variant.def_id); + // TODO: consider using TyCtxt::tag_for_variant + let discr = resolve_discr(self.tcx, variant.discr); + let field_tys = variant + .fields + .iter() + .map(|field| { + let field_ty = self.tcx.type_of(field.did).instantiate_identity(); + TypeBuilder::new(self.tcx, def_id).build(field_ty) + }) + .collect(); + rty::EnumVariantDef { + name, + discr, + field_tys, + } + }) + .collect(); + + let generics = self.tcx.generics_of(def_id); + let ty_params = (0..generics.count()) + .filter(|idx| { + matches!( + generics.param_at(*idx, self.tcx).kind, + mir_ty::GenericParamDefKind::Type { .. } + ) + }) + .count(); + tracing::debug!(?def_id, ?name, ?ty_params, "ty_params count"); + + rty::EnumDatatypeDef { + name, + ty_params, + variants, + } + } + + pub fn get_or_register_enum_def(&self, def_id: DefId) -> rty::EnumDatatypeDef { + let mut enum_defs = self.enum_defs.borrow_mut(); + if let Some(enum_def) = enum_defs.get(def_id) { + return enum_def.clone(); + } + + let enum_def = self.build_enum_def(def_id); tracing::debug!(def_id = ?def_id, enum_def = ?enum_def, "register_enum_def"); let ctors = enum_def .variants @@ -199,21 +251,10 @@ impl<'tcx> Analyzer<'tcx> { params: enum_def.ty_params, ctors, }; - self.enum_defs.borrow_mut().insert(def_id, enum_def); + enum_defs.insert(def_id, enum_def.clone()); self.system.borrow_mut().datatypes.push(datatype); - } - pub fn find_enum_variant( - &self, - ty_sym: &chc::DatatypeSymbol, - v_sym: &chc::DatatypeSymbol, - ) -> Option { - self.enum_defs - .borrow() - .iter() - .find(|(_, d)| &d.name == ty_sym) - .and_then(|(_, d)| d.variants.iter().find(|v| &v.name == v_sym)) - .cloned() + enum_def } pub fn register_def(&mut self, def_id: DefId, rty: rty::RefinedType) { diff --git a/src/analyze/basic_block.rs b/src/analyze/basic_block.rs index 372e481..c381e1d 100644 --- a/src/analyze/basic_block.rs +++ b/src/analyze/basic_block.rs @@ -350,16 +350,12 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { .map(|operand| self.operand_type(operand).boxed()) .collect(); match *kind { - mir::AggregateKind::Adt(did, variant_id, args, _, _) + mir::AggregateKind::Adt(did, variant_idx, args, _, _) if self.tcx.def_kind(did) == DefKind::Enum => { - let adt = self.tcx.adt_def(did); - let ty_sym = refine::datatype_symbol(self.tcx, did); - let variant = adt.variant(variant_id); - let v_sym = refine::datatype_symbol(self.tcx, variant.def_id); - - let enum_variant_def = self.ctx.find_enum_variant(&ty_sym, &v_sym).unwrap(); - let variant_rtys = enum_variant_def + let enum_def = self.ctx.get_or_register_enum_def(did); + let variant_def = &enum_def.variants[variant_idx]; + let variant_rtys = variant_def .field_tys .clone() .into_iter() @@ -386,7 +382,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { let sort_args: Vec<_> = rty_args.iter().map(|rty| rty.ty.to_sort()).collect(); - let ty = rty::EnumType::new(ty_sym.clone(), rty_args).into(); + let ty = rty::EnumType::new(enum_def.name.clone(), rty_args).into(); let mut builder = PlaceTypeBuilder::default(); let mut field_terms = Vec::new(); @@ -396,7 +392,12 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { } builder.build( ty, - chc::Term::datatype_ctor(ty_sym, sort_args, v_sym, field_terms), + chc::Term::datatype_ctor( + enum_def.name, + sort_args, + variant_def.name.clone(), + field_terms, + ), ) } _ => PlaceType::tuple(field_tys), diff --git a/src/analyze/crate_.rs b/src/analyze/crate_.rs index 2a17b11..5467630 100644 --- a/src/analyze/crate_.rs +++ b/src/analyze/crate_.rs @@ -3,13 +3,11 @@ use std::collections::HashSet; use rustc_hir::def::DefKind; -use rustc_index::IndexVec; use rustc_middle::ty::{self as mir_ty, TyCtxt}; use rustc_span::def_id::{DefId, LocalDefId}; use crate::analyze; use crate::chc; -use crate::refine::{self, TypeBuilder}; use crate::rty::{self, ClauseBuilderExt as _}; /// An implementation of local crate analysis. @@ -173,49 +171,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { let DefKind::Enum = self.tcx.def_kind(local_def_id) else { continue; }; - let adt = self.tcx.adt_def(local_def_id); - - let name = refine::datatype_symbol(self.tcx, local_def_id.to_def_id()); - let variants: IndexVec<_, _> = adt - .variants() - .iter() - .map(|variant| { - let name = refine::datatype_symbol(self.tcx, variant.def_id); - // TODO: consider using TyCtxt::tag_for_variant - let discr = analyze::resolve_discr(self.tcx, variant.discr); - let field_tys = variant - .fields - .iter() - .map(|field| { - let field_ty = self.tcx.type_of(field.did).instantiate_identity(); - TypeBuilder::new(self.tcx, local_def_id.to_def_id()).build(field_ty) - }) - .collect(); - rty::EnumVariantDef { - name, - discr, - field_tys, - } - }) - .collect(); - - let generics = self.tcx.generics_of(local_def_id); - let ty_params = (0..generics.count()) - .filter(|idx| { - matches!( - generics.param_at(*idx, self.tcx).kind, - mir_ty::GenericParamDefKind::Type { .. } - ) - }) - .count(); - tracing::debug!(?local_def_id, ?name, ?ty_params, "ty_params count"); - - let def = rty::EnumDatatypeDef { - name, - ty_params, - variants, - }; - self.ctx.register_enum_def(local_def_id.to_def_id(), def); + self.ctx.register_enum_def(local_def_id.to_def_id()); } } } From 50b3f239ab3f91916522d809f1e6c2572a76a759 Mon Sep 17 00:00:00 2001 From: coord_e Date: Sun, 28 Dec 2025 15:40:44 +0900 Subject: [PATCH 44/75] Use latest enum_defs from Env --- src/analyze.rs | 39 ++++++++++++++++++++------- src/analyze/basic_block.rs | 9 ++----- src/refine.rs | 4 ++- src/refine/env.rs | 53 +++++++++++++++++++++++++++---------- tests/ui/fail/option_mut.rs | 10 +++++++ tests/ui/pass/option_mut.rs | 10 +++++++ 6 files changed, 94 insertions(+), 31 deletions(-) create mode 100644 tests/ui/fail/option_mut.rs create mode 100644 tests/ui/pass/option_mut.rs diff --git a/src/analyze.rs b/src/analyze.rs index 53bfadb..1f4853f 100644 --- a/src/analyze.rs +++ b/src/analyze.rs @@ -115,6 +115,33 @@ enum DefTy<'tcx> { Deferred(DeferredDefTy<'tcx>), } +#[derive(Debug, Clone, Default)] +pub struct EnumDefs { + defs: HashMap, +} + +impl EnumDefs { + pub fn find_by_name(&self, name: &chc::DatatypeSymbol) -> Option<&rty::EnumDatatypeDef> { + self.defs.values().find(|def| &def.name == name) + } + + pub fn get(&self, def_id: DefId) -> Option<&rty::EnumDatatypeDef> { + self.defs.get(&def_id) + } + + pub fn insert(&mut self, def_id: DefId, def: rty::EnumDatatypeDef) { + self.defs.insert(def_id, def); + } +} + +impl refine::EnumDefProvider for Rc> { + fn enum_def(&self, name: &chc::DatatypeSymbol) -> rty::EnumDatatypeDef { + self.borrow().find_by_name(name).unwrap().clone() + } +} + +pub type Env = refine::Env>>; + #[derive(Clone)] pub struct Analyzer<'tcx> { tcx: TyCtxt<'tcx>, @@ -132,7 +159,7 @@ pub struct Analyzer<'tcx> { basic_blocks: HashMap>, def_ids: did_cache::DefIdCache<'tcx>, - enum_defs: Rc>>, + enum_defs: Rc>, } impl<'tcx> crate::refine::TemplateRegistry for Analyzer<'tcx> { @@ -345,14 +372,8 @@ impl<'tcx> Analyzer<'tcx> { self.register_def(panic_def_id, rty::RefinedType::unrefined(panic_ty.into())); } - pub fn new_env(&self) -> refine::Env { - let defs = self - .enum_defs - .borrow() - .values() - .map(|def| (def.name.clone(), def.clone())) - .collect(); - refine::Env::new(defs) + pub fn new_env(&self) -> Env { + refine::Env::new(Rc::clone(&self.enum_defs)) } pub fn crate_analyzer(&mut self) -> crate_::Analyzer<'tcx, '_> { diff --git a/src/analyze/basic_block.rs b/src/analyze/basic_block.rs index c381e1d..6053489 100644 --- a/src/analyze/basic_block.rs +++ b/src/analyze/basic_block.rs @@ -34,7 +34,7 @@ pub struct Analyzer<'tcx, 'ctx> { body: Cow<'tcx, Body<'tcx>>, type_builder: TypeBuilder<'tcx>, - env: Env, + env: analyze::Env, local_decls: IndexVec>, // TODO: remove this prophecy_vars: HashMap, @@ -968,7 +968,7 @@ impl UnbindAtoms { self.existentials.extend(var_ty.existentials); } - pub fn unbind(mut self, env: &Env, ty: rty::RefinedType) -> rty::RefinedType { + pub fn unbind(mut self, env: &analyze::Env, ty: rty::RefinedType) -> rty::RefinedType { let rty::RefinedType { ty: src_ty, refinement, @@ -1137,11 +1137,6 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { self } - pub fn env(&mut self, env: Env) -> &mut Self { - self.env = env; - self - } - pub fn run(&mut self, expected: &BasicBlockType) { let span = tracing::info_span!("bb", bb = ?self.basic_block); let _guard = span.enter(); diff --git a/src/refine.rs b/src/refine.rs index 4736b39..7ef3886 100644 --- a/src/refine.rs +++ b/src/refine.rs @@ -14,7 +14,9 @@ mod basic_block; pub use basic_block::BasicBlockType; mod env; -pub use env::{Assumption, Env, PlaceType, PlaceTypeBuilder, PlaceTypeVar, TempVarIdx, Var}; +pub use env::{ + Assumption, EnumDefProvider, Env, PlaceType, PlaceTypeBuilder, PlaceTypeVar, TempVarIdx, Var, +}; use crate::chc::DatatypeSymbol; use rustc_middle::ty as mir_ty; diff --git a/src/refine/env.rs b/src/refine/env.rs index 3663d81..e8b0cb7 100644 --- a/src/refine/env.rs +++ b/src/refine/env.rs @@ -290,6 +290,16 @@ impl PlaceTypeBuilder { } } +pub trait EnumDefProvider { + fn enum_def(&self, name: &chc::DatatypeSymbol) -> rty::EnumDatatypeDef; +} + +impl<'a, T: EnumDefProvider> EnumDefProvider for &'a T { + fn enum_def(&self, name: &chc::DatatypeSymbol) -> rty::EnumDatatypeDef { + ::enum_def(self, name) + } +} + #[derive(Debug, Clone)] pub struct PlaceType { pub ty: rty::Type, @@ -392,16 +402,19 @@ impl PlaceType { builder.build(ty, term) } - pub fn downcast( + pub fn downcast( self, variant_idx: VariantIdx, field_idx: FieldIdx, - enum_defs: &HashMap, - ) -> PlaceType { + enum_defs: T, + ) -> PlaceType + where + T: EnumDefProvider, + { let mut builder = PlaceTypeBuilder::default(); let (inner_ty, inner_term) = builder.subsume(self); let inner_ty = inner_ty.into_enum().unwrap(); - let def = &enum_defs[&inner_ty.symbol]; + let def = enum_defs.enum_def(&inner_ty.symbol); let variant = &def.variants[variant_idx]; let mut field_terms = Vec::new(); @@ -510,18 +523,21 @@ impl PlaceType { pub type Assumption = rty::Formula; #[derive(Debug, Clone)] -pub struct Env { +pub struct Env { locals: BTreeMap>, flow_locals: BTreeMap, temp_vars: IndexVec, assumptions: Vec, - enum_defs: HashMap, + enum_defs: T, enum_expansion_depth_limit: usize, } -impl rty::ClauseScope for Env { +impl rty::ClauseScope for Env +where + T: EnumDefProvider, +{ fn build_clause(&self) -> chc::ClauseBuilder { let mut builder = chc::ClauseBuilder::default(); for (v, sort) in self.dependencies() { @@ -565,7 +581,10 @@ impl rty::ClauseScope for Env { } } -impl refine::TemplateScope for Env { +impl refine::TemplateScope for Env +where + T: EnumDefProvider, +{ type Var = Var; fn build_template(&self) -> rty::TemplateBuilder { let mut builder = rty::TemplateBuilder::default(); @@ -576,8 +595,11 @@ impl refine::TemplateScope for Env { } } -impl Env { - pub fn new(enum_defs: HashMap) -> Self { +impl Env +where + T: EnumDefProvider, +{ + pub fn new(enum_defs: T) -> Self { Env { locals: Default::default(), flow_locals: Default::default(), @@ -727,7 +749,7 @@ impl Env { assert_eq!(temp, self.temp_vars.push(TempVarBinding::Flow(dummy))); } - let def = self.enum_defs[&ty.symbol].clone(); + let def = self.enum_defs.enum_def(&ty.symbol); let matcher_pred = chc::MatcherPred::new(ty.symbol.clone(), ty.arg_sorts()); let discr_var = self @@ -939,7 +961,7 @@ impl Env { .collect(); let arg_rtys = { - let def = &self.enum_defs[sym]; + let def = self.enum_defs.enum_def(sym); let expected_tys = def .field_tys() .map(|ty| rty::RefinedType::unrefined(ty.clone().vacuous()).boxed()); @@ -1052,7 +1074,10 @@ impl Path { } } -impl Env { +impl Env +where + T: EnumDefProvider, +{ fn path_type(&self, path: &Path) -> PlaceType { match path { Path::PlaceTy(pty) => pty.clone(), @@ -1084,7 +1109,7 @@ impl Env { .map(|i| self.dropping_assumption(&path.clone().tuple_proj(i))) .collect() } else if let Some(ety) = ty.ty.as_enum() { - let enum_def = self.enum_defs[&ety.symbol].clone(); + let enum_def = self.enum_defs.enum_def(&ety.symbol); let matcher_pred = chc::MatcherPred::new(ety.symbol.clone(), ety.arg_sorts()); let PlaceType { diff --git a/tests/ui/fail/option_mut.rs b/tests/ui/fail/option_mut.rs new file mode 100644 index 0000000..827bc9c --- /dev/null +++ b/tests/ui/fail/option_mut.rs @@ -0,0 +1,10 @@ +//@error-in-other-file: Unsat +//@compile-flags: -C debug-assertions=off + +fn main() { + let mut m: Option = Some(1); + if let Some(i) = &mut m { + *i += 2; + } + assert!(matches!(m, Some(1))); +} diff --git a/tests/ui/pass/option_mut.rs b/tests/ui/pass/option_mut.rs new file mode 100644 index 0000000..37f7121 --- /dev/null +++ b/tests/ui/pass/option_mut.rs @@ -0,0 +1,10 @@ +//@check-pass +//@compile-flags: -C debug-assertions=off + +fn main() { + let mut m: Option = Some(1); + if let Some(i) = &mut m { + *i += 2; + } + assert!(matches!(m, Some(3))); +} From fade2d5e64ecb3dc51445405aabf6667ae3e98fe Mon Sep 17 00:00:00 2001 From: coord_e Date: Sun, 28 Dec 2025 19:04:41 +0900 Subject: [PATCH 45/75] Add more test cases --- tests/ui/fail/option_inc.rs | 17 +++++++++++++++++ tests/ui/fail/option_loop.rs | 14 ++++++++++++++ tests/ui/fail/result_mut.rs | 18 ++++++++++++++++++ tests/ui/fail/result_struct.rs | 22 ++++++++++++++++++++++ tests/ui/pass/option_inc.rs | 17 +++++++++++++++++ tests/ui/pass/option_loop.rs | 14 ++++++++++++++ tests/ui/pass/result_mut.rs | 18 ++++++++++++++++++ tests/ui/pass/result_struct.rs | 23 +++++++++++++++++++++++ 8 files changed, 143 insertions(+) create mode 100644 tests/ui/fail/option_inc.rs create mode 100644 tests/ui/fail/option_loop.rs create mode 100644 tests/ui/fail/result_mut.rs create mode 100644 tests/ui/fail/result_struct.rs create mode 100644 tests/ui/pass/option_inc.rs create mode 100644 tests/ui/pass/option_loop.rs create mode 100644 tests/ui/pass/result_mut.rs create mode 100644 tests/ui/pass/result_struct.rs diff --git a/tests/ui/fail/option_inc.rs b/tests/ui/fail/option_inc.rs new file mode 100644 index 0000000..4e98b06 --- /dev/null +++ b/tests/ui/fail/option_inc.rs @@ -0,0 +1,17 @@ +//@error-in-other-file: Unsat +//@compile-flags: -C debug-assertions=off + +fn maybe_inc(x: i32, do_it: bool) -> Option { + if do_it { + Some(x + 1) + } else { + None + } +} + +fn main() { + let res = maybe_inc(10, true); + if let Some(v) = res { + assert!(v == 12); + } +} diff --git a/tests/ui/fail/option_loop.rs b/tests/ui/fail/option_loop.rs new file mode 100644 index 0000000..a010af4 --- /dev/null +++ b/tests/ui/fail/option_loop.rs @@ -0,0 +1,14 @@ +//@error-in-other-file: Unsat +//@compile-flags: -C debug-assertions=off + +fn main() { + let mut opt = Some(5); + while let Some(x) = opt { + if x > 0 { + opt = Some(x - 1); + } else { + opt = None; + } + } + assert!(matches!(opt, Some(0))); +} diff --git a/tests/ui/fail/result_mut.rs b/tests/ui/fail/result_mut.rs new file mode 100644 index 0000000..7baa46b --- /dev/null +++ b/tests/ui/fail/result_mut.rs @@ -0,0 +1,18 @@ +//@error-in-other-file: Unsat +//@compile-flags: -C debug-assertions=off + +fn mutate_res(r: &mut Result) { + match r { + Ok(v) => *v += 1, + Err(e) => *e -= 1, + } +} + +fn main() { + let mut r = Ok(10); + mutate_res(&mut r); + match r { + Ok(v) => assert!(v == 10), + Err(_) => unreachable!(), + } +} diff --git a/tests/ui/fail/result_struct.rs b/tests/ui/fail/result_struct.rs new file mode 100644 index 0000000..cb16420 --- /dev/null +++ b/tests/ui/fail/result_struct.rs @@ -0,0 +1,22 @@ +//@error-in-other-file: Unsat +//@compile-flags: -C debug-assertions=off + +struct Point { + x: i32, + y: i32, +} + +fn make_point(x: i32, y: i32) -> Result { + if x >= 0 && y >= 0 { + Ok(Point { x, y }) + } else { + Err(()) + } +} + +fn main() { + let p = make_point(1, 2); + if let Ok(pt) = p { + assert!(pt.x > 1); + } +} diff --git a/tests/ui/pass/option_inc.rs b/tests/ui/pass/option_inc.rs new file mode 100644 index 0000000..aa0530b --- /dev/null +++ b/tests/ui/pass/option_inc.rs @@ -0,0 +1,17 @@ +//@check-pass +//@compile-flags: -C debug-assertions=off + +fn maybe_inc(x: i32, do_it: bool) -> Option { + if do_it { + Some(x + 1) + } else { + None + } +} + +fn main() { + let res = maybe_inc(10, true); + if let Some(v) = res { + assert!(v == 11); + } +} diff --git a/tests/ui/pass/option_loop.rs b/tests/ui/pass/option_loop.rs new file mode 100644 index 0000000..ea3697f --- /dev/null +++ b/tests/ui/pass/option_loop.rs @@ -0,0 +1,14 @@ +//@check-pass +//@compile-flags: -C debug-assertions=off + +fn main() { + let mut opt = Some(5); + while let Some(x) = opt { + if x > 0 { + opt = Some(x - 1); + } else { + opt = None; + } + } + assert!(matches!(opt, None)); +} diff --git a/tests/ui/pass/result_mut.rs b/tests/ui/pass/result_mut.rs new file mode 100644 index 0000000..1a5e218 --- /dev/null +++ b/tests/ui/pass/result_mut.rs @@ -0,0 +1,18 @@ +//@check-pass +//@compile-flags: -C debug-assertions=off + +fn mutate_res(r: &mut Result) { + match r { + Ok(v) => *v += 1, + Err(e) => *e -= 1, + } +} + +fn main() { + let mut r = Ok(10); + mutate_res(&mut r); + match r { + Ok(v) => assert!(v == 11), + Err(_) => unreachable!(), + } +} diff --git a/tests/ui/pass/result_struct.rs b/tests/ui/pass/result_struct.rs new file mode 100644 index 0000000..6e80f99 --- /dev/null +++ b/tests/ui/pass/result_struct.rs @@ -0,0 +1,23 @@ +//@check-pass +//@compile-flags: -C debug-assertions=off + +struct Point { + x: i32, + y: i32, +} + +fn make_point(x: i32, y: i32) -> Result { + if x >= 0 && y >= 0 { + Ok(Point { x, y }) + } else { + Err(()) + } +} + +fn main() { + let p = make_point(1, 2); + if let Ok(pt) = p { + assert!(pt.x >= 0); + assert!(pt.y >= 0); + } +} From 39954dfbcb0e1ac0cc02a1a8bc87b4848f6b1b62 Mon Sep 17 00:00:00 2001 From: coord_e Date: Tue, 30 Dec 2025 01:22:55 +0900 Subject: [PATCH 46/75] Populate enum_defs before analyzing BB --- src/analyze/basic_block.rs | 30 ++++++++++++++++++++++++++++-- src/analyze/crate_.rs | 11 ----------- 2 files changed, 28 insertions(+), 13 deletions(-) diff --git a/src/analyze/basic_block.rs b/src/analyze/basic_block.rs index 6053489..86fa58f 100644 --- a/src/analyze/basic_block.rs +++ b/src/analyze/basic_block.rs @@ -13,8 +13,8 @@ use crate::analyze; use crate::chc; use crate::pretty::PrettyDisplayExt as _; use crate::refine::{ - self, Assumption, BasicBlockType, Env, PlaceType, PlaceTypeBuilder, PlaceTypeVar, TempVarIdx, - TypeBuilder, Var, + Assumption, BasicBlockType, PlaceType, PlaceTypeBuilder, PlaceTypeVar, TempVarIdx, TypeBuilder, + Var, }; use crate::rty::{ self, ClauseBuilderExt as _, ClauseScope as _, ShiftExistential as _, Subtyping as _, @@ -925,6 +925,31 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { } } } + + fn register_enum_defs(&mut self) { + for local_decl in &self.local_decls { + use mir_ty::{TypeSuperVisitable as _, TypeVisitable as _}; + #[derive(Default)] + struct EnumCollector { + enums: std::collections::HashSet, + } + impl<'tcx> mir_ty::TypeVisitor> for EnumCollector { + fn visit_ty(&mut self, ty: mir_ty::Ty<'tcx>) { + if let mir_ty::TyKind::Adt(adt_def, _) = ty.kind() { + if adt_def.is_enum() { + self.enums.insert(adt_def.did()); + } + } + ty.super_visit_with(self); + } + } + let mut visitor = EnumCollector::default(); + local_decl.ty.visit_with(&mut visitor); + for def_id in visitor.enums { + self.ctx.get_or_register_enum_def(def_id); + } + } + } } /// Turns [`rty::RefinedType`] into [`rty::RefinedType`]. @@ -1140,6 +1165,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { pub fn run(&mut self, expected: &BasicBlockType) { let span = tracing::info_span!("bb", bb = ?self.basic_block); let _guard = span.enter(); + self.register_enum_defs(); let params = expected.as_ref().params.clone(); self.bind_locals(¶ms); diff --git a/src/analyze/crate_.rs b/src/analyze/crate_.rs index 5467630..d707578 100644 --- a/src/analyze/crate_.rs +++ b/src/analyze/crate_.rs @@ -2,7 +2,6 @@ use std::collections::HashSet; -use rustc_hir::def::DefKind; use rustc_middle::ty::{self as mir_ty, TyCtxt}; use rustc_span::def_id::{DefId, LocalDefId}; @@ -165,15 +164,6 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { } } } - - fn register_enum_defs(&mut self) { - for local_def_id in self.tcx.iter_local_def_id() { - let DefKind::Enum = self.tcx.def_kind(local_def_id) else { - continue; - }; - self.ctx.register_enum_def(local_def_id.to_def_id()); - } - } } impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { @@ -187,7 +177,6 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { let span = tracing::debug_span!("crate", krate = %self.tcx.crate_name(rustc_span::def_id::LOCAL_CRATE)); let _guard = span.enter(); - self.register_enum_defs(); self.refine_local_defs(); self.analyze_local_defs(); self.assert_callable_entry(); From 04b379c70814bc7d3510736ca00e4a4184338036 Mon Sep 17 00:00:00 2001 From: coord_e Date: Tue, 30 Dec 2025 01:23:13 +0900 Subject: [PATCH 47/75] Don't attach value formula when ty is singleton --- src/refine/env.rs | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/refine/env.rs b/src/refine/env.rs index e8b0cb7..59e5024 100644 --- a/src/refine/env.rs +++ b/src/refine/env.rs @@ -317,9 +317,11 @@ impl From for rty::RefinedType { formula, } = ty; let mut body = formula.map_var(Into::into); - body.push_conj( - chc::Term::var(rty::RefinedTypeVar::Value).equal_to(term.map_var(Into::into)), - ); + if !ty.to_sort().is_singleton() { + body.push_conj( + chc::Term::var(rty::RefinedTypeVar::Value).equal_to(term.map_var(Into::into)), + ); + } let refinement = rty::Refinement::new(existentials, body); rty::RefinedType::new(ty, refinement) } From 2ca04c664b16203cb4080655624e3c3c43db03a7 Mon Sep 17 00:00:00 2001 From: coord_e Date: Tue, 30 Dec 2025 11:02:53 +0900 Subject: [PATCH 48/75] Fix missing unbox --- src/chc/unbox.rs | 29 ++++++++++++++++++++++++++--- 1 file changed, 26 insertions(+), 3 deletions(-) diff --git a/src/chc/unbox.rs b/src/chc/unbox.rs index 667ab39..5be1240 100644 --- a/src/chc/unbox.rs +++ b/src/chc/unbox.rs @@ -13,9 +13,11 @@ fn unbox_term(term: Term) -> Term { Term::App(fun, args) => Term::App(fun, args.into_iter().map(unbox_term).collect()), Term::Tuple(ts) => Term::Tuple(ts.into_iter().map(unbox_term).collect()), Term::TupleProj(t, i) => Term::TupleProj(Box::new(unbox_term(*t)), i), - Term::DatatypeCtor(s1, s2, args) => { - Term::DatatypeCtor(s1, s2, args.into_iter().map(unbox_term).collect()) - } + Term::DatatypeCtor(s1, s2, args) => Term::DatatypeCtor( + unbox_datatype_sort(s1), + s2, + args.into_iter().map(unbox_term).collect(), + ), Term::DatatypeDiscr(sym, arg) => Term::DatatypeDiscr(sym, Box::new(unbox_term(*arg))), Term::FormulaExistentialVar(sort, name) => { Term::FormulaExistentialVar(unbox_sort(sort), name) @@ -23,9 +25,30 @@ fn unbox_term(term: Term) -> Term { } } +fn unbox_matcher_pred(pred: MatcherPred) -> Pred { + let MatcherPred { + datatype_symbol, + datatype_args, + } = pred; + let datatype_args = datatype_args.into_iter().map(unbox_sort).collect(); + Pred::Matcher(MatcherPred { + datatype_symbol, + datatype_args, + }) +} + +fn unbox_pred(pred: Pred) -> Pred { + match pred { + Pred::Known(pred) => Pred::Known(pred), + Pred::Var(pred) => Pred::Var(pred), + Pred::Matcher(pred) => unbox_matcher_pred(pred), + } +} + fn unbox_atom(atom: Atom) -> Atom { let Atom { guard, pred, args } = atom; let guard = guard.map(|fo| Box::new(unbox_formula(*fo))); + let pred = unbox_pred(pred); let args = args.into_iter().map(unbox_term).collect(); Atom { guard, pred, args } } From 5101b46207eeec83571f1e59b50952b86cc17275 Mon Sep 17 00:00:00 2001 From: coord_e Date: Sun, 14 Dec 2025 17:20:52 +0900 Subject: [PATCH 49/75] Enable to parse enum constructors in annotations --- src/annot.rs | 188 ++++++++++++++++++++++++++--- tests/ui/fail/annot_enum_simple.rs | 20 +++ tests/ui/pass/annot_enum_simple.rs | 20 +++ 3 files changed, 214 insertions(+), 14 deletions(-) create mode 100644 tests/ui/fail/annot_enum_simple.rs create mode 100644 tests/ui/pass/annot_enum_simple.rs diff --git a/src/annot.rs b/src/annot.rs index 730b75e..bb4c118 100644 --- a/src/annot.rs +++ b/src/annot.rs @@ -33,6 +33,62 @@ impl AnnotFormula { } } +/// A path in an annotation. +/// +/// Handling of paths in Thrust annotations is intentionally limited now. We plan to re-implement +/// annotation parsing while letting rustc handle path resolution in the future. +#[derive(Debug, Clone)] +pub struct AnnotPath { + pub segments: Vec, +} + +impl AnnotPath { + /// Convert the path to a datatype constructor term with the given arguments. + pub fn to_datatype_ctor(&self, ctor_args: Vec>) -> (chc::Term, chc::Sort) { + let mut segments = self.segments.clone(); + + let ctor = segments.pop().unwrap(); + if !ctor.generic_args.is_empty() { + unimplemented!("generic arguments in datatype constructor"); + } + let Some(ty_last_segment) = segments.last_mut() else { + unimplemented!("single segment path"); + }; + let generic_args: Vec<_> = ty_last_segment.generic_args.drain(..).collect(); + let ty_path_idents: Vec<_> = segments + .into_iter() + .map(|segment| { + if !segment.generic_args.is_empty() { + unimplemented!("generic arguments in datatype constructor"); + } + segment.ident.to_string() + }) + .collect(); + // see refine::datatype_symbol + let d_sym = chc::DatatypeSymbol::new(ty_path_idents.join(".")); + let v_sym = chc::DatatypeSymbol::new(format!("{}.{}", d_sym, ctor.ident.as_str())); + let term = chc::Term::datatype_ctor(d_sym.clone(), generic_args.clone(), v_sym, ctor_args); + let sort = chc::Sort::datatype(d_sym, generic_args); + (term, sort) + } + + /// If the path consists of a single segment without generic arguments, return that identifier. + pub fn single_segment_ident(&self) -> Option<&Ident> { + if self.segments.len() == 1 && self.segments[0].generic_args.is_empty() { + Some(&self.segments[0].ident) + } else { + None + } + } +} + +/// A segment of a path in an annotation. +#[derive(Debug, Clone)] +pub struct AnnotPathSegment { + pub ident: Ident, + pub generic_args: Vec, +} + /// A trait for resolving variables in annotations to their logical representation and their sorts. pub trait Resolver { type Output; @@ -298,6 +354,88 @@ where } } + fn parse_path_tail(&mut self, head: Ident) -> Result { + let mut segments: Vec = Vec::new(); + segments.push(AnnotPathSegment { + ident: head, + generic_args: Vec::new(), + }); + while let Some(Token { + kind: TokenKind::ModSep, + .. + }) = self.look_ahead_token(0) + { + self.consume(); + match self.next_token("ident or <")? { + t @ Token { + kind: TokenKind::Lt, + .. + } => { + if segments.is_empty() { + return Err(ParseAttrError::unexpected_token( + "path segment before <", + t.clone(), + )); + } + let mut generic_args = Vec::new(); + loop { + let sort = self.parse_sort()?; + generic_args.push(sort); + match self.next_token(", or >")? { + Token { + kind: TokenKind::Comma, + .. + } => {} + Token { + kind: TokenKind::Gt, + .. + } => break, + t => return Err(ParseAttrError::unexpected_token(", or >", t.clone())), + } + } + segments.last_mut().unwrap().generic_args = generic_args; + } + t @ Token { + kind: TokenKind::Ident(_, _), + .. + } => { + let (ident, _) = t.ident().unwrap(); + segments.push(AnnotPathSegment { + ident, + generic_args: Vec::new(), + }); + } + t => return Err(ParseAttrError::unexpected_token("ident or <", t.clone())), + } + } + Ok(AnnotPath { segments }) + } + + fn parse_datatype_ctor_args(&mut self) -> Result>> { + if self.look_ahead_token(0).is_none() { + return Ok(Vec::new()); + } + + let mut terms = Vec::new(); + loop { + let formula_or_term = self.parse_formula_or_term()?; + let (t, _) = formula_or_term.into_term().ok_or_else(|| { + ParseAttrError::unexpected_formula("in datatype constructor arguments") + })?; + terms.push(t); + if let Some(Token { + kind: TokenKind::Comma, + .. + }) = self.look_ahead_token(0) + { + self.consume(); + } else { + break; + } + } + Ok(terms) + } + fn parse_atom(&mut self) -> Result> { let tt = self.next_token_tree("term or formula")?.clone(); @@ -317,21 +455,43 @@ where }; let formula_or_term = if let Some((ident, _)) = t.ident() { - match ( - ident.as_str(), - self.formula_existentials.get(ident.name.as_str()), - ) { - ("true", _) => FormulaOrTerm::Formula(chc::Formula::top()), - ("false", _) => FormulaOrTerm::Formula(chc::Formula::bottom()), - (_, Some(sort)) => { - let var = - chc::Term::FormulaExistentialVar(sort.clone(), ident.name.to_string()); - FormulaOrTerm::Term(var, sort.clone()) - } - _ => { - let (v, sort) = self.resolve(ident)?; - FormulaOrTerm::Term(chc::Term::var(v), sort) + let path = self.parse_path_tail(ident)?; + if let Some(ident) = path.single_segment_ident() { + match ( + ident.as_str(), + self.formula_existentials.get(ident.name.as_str()), + ) { + ("true", _) => FormulaOrTerm::Formula(chc::Formula::top()), + ("false", _) => FormulaOrTerm::Formula(chc::Formula::bottom()), + (_, Some(sort)) => { + let var = + chc::Term::FormulaExistentialVar(sort.clone(), ident.name.to_string()); + FormulaOrTerm::Term(var, sort.clone()) + } + _ => { + let (v, sort) = self.resolve(*ident)?; + FormulaOrTerm::Term(chc::Term::var(v), sort) + } } + } else { + let next_tt = self + .next_token_tree("arguments for datatype constructor")? + .clone(); + let TokenTree::Delimited(_, _, Delimiter::Parenthesis, s) = next_tt else { + return Err(ParseAttrError::unexpected_token_tree( + "arguments for datatype constructor", + next_tt.clone(), + )); + }; + let mut parser = Parser { + resolver: self.boxed_resolver(), + cursor: s.trees(), + formula_existentials: self.formula_existentials.clone(), + }; + let args = parser.parse_datatype_ctor_args()?; + parser.end_of_input()?; + let (term, sort) = path.to_datatype_ctor(args); + FormulaOrTerm::Term(term, sort) } } else { match t.kind { diff --git a/tests/ui/fail/annot_enum_simple.rs b/tests/ui/fail/annot_enum_simple.rs new file mode 100644 index 0000000..b0f9077 --- /dev/null +++ b/tests/ui/fail/annot_enum_simple.rs @@ -0,0 +1,20 @@ +//@error-in-other-file: Unsat + +pub enum X { + A(i64), + B(bool), +} + +#[thrust::requires(x == X::A(1))] +#[thrust::ensures(true)] +fn test(x: X) { + if let X::A(i) = x { + assert!(i == 2); + } else { + loop {} + } +} + +fn main() { + test(X::A(1)); +} diff --git a/tests/ui/pass/annot_enum_simple.rs b/tests/ui/pass/annot_enum_simple.rs new file mode 100644 index 0000000..70eed4c --- /dev/null +++ b/tests/ui/pass/annot_enum_simple.rs @@ -0,0 +1,20 @@ +//@check-pass + +pub enum X { + A(i64), + B(bool), +} + +#[thrust::requires(x == X::A(1))] +#[thrust::ensures(true)] +fn test(x: X) { + if let X::A(i) = x { + assert!(i == 1); + } else { + loop {} + } +} + +fn main() { + test(X::A(1)); +} From a48be415fbc2abd3ec84b08cde817eafeb07959c Mon Sep 17 00:00:00 2001 From: coord_e Date: Mon, 24 Nov 2025 13:03:30 +0900 Subject: [PATCH 50/75] Implement #[thrust::extern_spec_fn] --- src/analyze/annot.rs | 4 +++ src/analyze/crate_.rs | 12 ++++++- src/analyze/local_def.rs | 54 ++++++++++++++++++++++++++++++- tests/ui/fail/extern_spec_take.rs | 14 ++++++++ tests/ui/pass/extern_spec_take.rs | 15 +++++++++ 5 files changed, 97 insertions(+), 2 deletions(-) create mode 100644 tests/ui/fail/extern_spec_take.rs create mode 100644 tests/ui/pass/extern_spec_take.rs diff --git a/src/analyze/annot.rs b/src/analyze/annot.rs index 91dd209..2dbb9ea 100644 --- a/src/analyze/annot.rs +++ b/src/analyze/annot.rs @@ -33,6 +33,10 @@ pub fn callable_path() -> [Symbol; 2] { [Symbol::intern("thrust"), Symbol::intern("callable")] } +pub fn extern_spec_fn_path() -> [Symbol; 2] { + [Symbol::intern("thrust"), Symbol::intern("extern_spec_fn")] +} + /// A [`annot::Resolver`] implementation for resolving function parameters. /// /// The parameter names and their sorts needs to be configured via diff --git a/src/analyze/crate_.rs b/src/analyze/crate_.rs index 2a17b11..7e3d420 100644 --- a/src/analyze/crate_.rs +++ b/src/analyze/crate_.rs @@ -48,12 +48,22 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { self.trusted.insert(local_def_id.to_def_id()); } + if analyzer.is_annotated_as_extern_spec_fn() { + assert!(analyzer.is_fully_annotated()); + self.trusted.insert(local_def_id.to_def_id()); + } + use mir_ty::TypeVisitableExt as _; if sig.has_param() && !analyzer.is_fully_annotated() { self.ctx.register_deferred_def(local_def_id.to_def_id()); } else { let expected = analyzer.expected_ty(); - self.ctx.register_def(local_def_id.to_def_id(), expected); + let target_def_id = if analyzer.is_annotated_as_extern_spec_fn() { + analyzer.extern_spec_fn_target_def_id() + } else { + local_def_id.to_def_id() + }; + self.ctx.register_def(target_def_id, expected); } } diff --git a/src/analyze/local_def.rs b/src/analyze/local_def.rs index 50a4397..d556ef0 100644 --- a/src/analyze/local_def.rs +++ b/src/analyze/local_def.rs @@ -6,7 +6,7 @@ use rustc_index::bit_set::BitSet; use rustc_index::IndexVec; use rustc_middle::mir::{self, BasicBlock, Body, Local}; use rustc_middle::ty::{self as mir_ty, TyCtxt, TypeAndMut}; -use rustc_span::def_id::LocalDefId; +use rustc_span::def_id::{DefId, LocalDefId}; use rustc_span::symbol::Ident; use crate::analyze; @@ -126,6 +126,16 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { .is_some() } + pub fn is_annotated_as_extern_spec_fn(&self) -> bool { + self.tcx + .get_attrs_by_path( + self.local_def_id.to_def_id(), + &analyze::annot::extern_spec_fn_path(), + ) + .next() + .is_some() + } + // TODO: unify this logic with extraction functions above pub fn is_fully_annotated(&self) -> bool { let has_require = self @@ -240,6 +250,48 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { rty::RefinedType::unrefined(builder.build().into()) } + /// Extract the target DefId from `#[thrust::extern_spec_fn]` function. + pub fn extern_spec_fn_target_def_id(&self) -> DefId { + struct ExtractDefId<'tcx> { + tcx: TyCtxt<'tcx>, + outer_def_id: LocalDefId, + inner_def_id: Option, + } + + impl<'tcx> rustc_hir::intravisit::Visitor<'tcx> for ExtractDefId<'tcx> { + type NestedFilter = rustc_middle::hir::nested_filter::OnlyBodies; + + fn nested_visit_map(&mut self) -> Self::Map { + self.tcx.hir() + } + + fn visit_qpath( + &mut self, + qpath: &rustc_hir::QPath<'tcx>, + hir_id: rustc_hir::HirId, + _span: rustc_span::Span, + ) { + let typeck_result = self.tcx.typeck(self.outer_def_id); + if let rustc_hir::def::Res::Def(_, def_id) = typeck_result.qpath_res(qpath, hir_id) + { + assert!(self.inner_def_id.is_none(), "invalid extern_spec_fn"); + self.inner_def_id = Some(def_id); + } + } + } + + use rustc_hir::intravisit::Visitor as _; + let mut visitor = ExtractDefId { + tcx: self.tcx, + outer_def_id: self.local_def_id, + inner_def_id: None, + }; + if let rustc_hir::Node::Item(item) = self.tcx.hir_node_by_def_id(self.local_def_id) { + visitor.visit_item(item); + } + visitor.inner_def_id.expect("invalid extern_spec_fn") + } + fn is_mut_param(&self, param_idx: rty::FunctionParamIdx) -> bool { let param_local = analyze::local_of_function_param(param_idx); self.body.local_decls[param_local].mutability.is_mut() diff --git a/tests/ui/fail/extern_spec_take.rs b/tests/ui/fail/extern_spec_take.rs new file mode 100644 index 0000000..5687edd --- /dev/null +++ b/tests/ui/fail/extern_spec_take.rs @@ -0,0 +1,14 @@ +//@error-in-other-file: Unsat + +#[thrust::extern_spec_fn] +#[thrust::requires(true)] +#[thrust::ensures(result == *dest && ^dest == 0)] +fn _extern_spec_take(dest: &mut i32) -> i32 { + std::mem::take(dest) +} + +fn main() { + let mut x = 42; + let old = std::mem::take(&mut x); + assert!(x == 42); +} diff --git a/tests/ui/pass/extern_spec_take.rs b/tests/ui/pass/extern_spec_take.rs new file mode 100644 index 0000000..a72df6d --- /dev/null +++ b/tests/ui/pass/extern_spec_take.rs @@ -0,0 +1,15 @@ +//@check-pass + +#[thrust::extern_spec_fn] +#[thrust::requires(true)] +#[thrust::ensures(result == *dest && ^dest == 0)] +fn _extern_spec_take(dest: &mut i32) -> i32 { + std::mem::take(dest) +} + +fn main() { + let mut x = 42; + let old = std::mem::take(&mut x); + assert!(old == 42); + assert!(x == 0); +} From 8089bfb299ca654098c12c1ecc582a43c017975e Mon Sep 17 00:00:00 2001 From: coord_e Date: Tue, 30 Dec 2025 15:43:47 +0900 Subject: [PATCH 51/75] Parse and in annotations --- src/annot.rs | 25 +++++++++++++++++++++++++ tests/ui/fail/annot_box_term.rs | 17 +++++++++++++++++ tests/ui/fail/annot_mut_term.rs | 13 +++++++++++++ tests/ui/pass/annot_box_term.rs | 17 +++++++++++++++++ tests/ui/pass/annot_mut_term.rs | 13 +++++++++++++ 5 files changed, 85 insertions(+) create mode 100644 tests/ui/fail/annot_box_term.rs create mode 100644 tests/ui/fail/annot_mut_term.rs create mode 100644 tests/ui/pass/annot_box_term.rs create mode 100644 tests/ui/pass/annot_mut_term.rs diff --git a/src/annot.rs b/src/annot.rs index bb4c118..42266ac 100644 --- a/src/annot.rs +++ b/src/annot.rs @@ -502,6 +502,31 @@ where ), _ => unimplemented!(), }, + TokenKind::Lt => { + let (t1, s1) = self + .parse_binop_2()? + .into_term() + .ok_or_else(|| ParseAttrError::unexpected_formula("in box/mut term"))?; + + match self.next_token("> or ,")? { + Token { + kind: TokenKind::Gt, + .. + } => FormulaOrTerm::Term(chc::Term::box_(t1), chc::Sort::box_(s1)), + Token { + kind: TokenKind::Comma, + .. + } => { + let (t2, _s2) = self + .parse_binop_2()? + .into_term() + .ok_or_else(|| ParseAttrError::unexpected_formula("in mut term"))?; + self.expect_next_token(TokenKind::Gt, ">")?; + FormulaOrTerm::Term(chc::Term::mut_(t1, t2), chc::Sort::mut_(s1)) + } + t => return Err(ParseAttrError::unexpected_token("> or ,", t.clone())), + } + } _ => { return Err(ParseAttrError::unexpected_token( "identifier, or literal", diff --git a/tests/ui/fail/annot_box_term.rs b/tests/ui/fail/annot_box_term.rs new file mode 100644 index 0000000..cdb3eb6 --- /dev/null +++ b/tests/ui/fail/annot_box_term.rs @@ -0,0 +1,17 @@ +//@error-in-other-file: Unsat +//@compile-flags: -C debug-assertions=off + +#[thrust::sig(fn(x: int) -> {r: Box | r == })] +fn box_create(x: i64) -> Box { + Box::new(x) +} + +#[thrust::requires(b == )] +fn box_consume(b: Box, v: i64) { + assert!(*b == v); +} + +fn main() { + let b = box_create(42); + box_consume(b, 10); +} diff --git a/tests/ui/fail/annot_mut_term.rs b/tests/ui/fail/annot_mut_term.rs new file mode 100644 index 0000000..f1037dd --- /dev/null +++ b/tests/ui/fail/annot_mut_term.rs @@ -0,0 +1,13 @@ +//@error-in-other-file: Unsat + +#[thrust::requires(true)] +#[thrust::ensures(x == <*x, y>)] +fn f(x: &mut i64, y: i64) { + *x = y; +} + +fn main() { + let mut a = 1; + f(&mut a, 2); + assert!(a == 1); +} diff --git a/tests/ui/pass/annot_box_term.rs b/tests/ui/pass/annot_box_term.rs new file mode 100644 index 0000000..95d4775 --- /dev/null +++ b/tests/ui/pass/annot_box_term.rs @@ -0,0 +1,17 @@ +//@check-pass +//@compile-flags: -C debug-assertions=off + +#[thrust::sig(fn(x: int) -> {r: Box | r == })] +fn box_create(x: i64) -> Box { + Box::new(x) +} + +#[thrust::requires(b == )] +fn box_consume(b: Box, v: i64) { + assert!(*b == v); +} + +fn main() { + let b = box_create(42); + box_consume(b, 42); +} diff --git a/tests/ui/pass/annot_mut_term.rs b/tests/ui/pass/annot_mut_term.rs new file mode 100644 index 0000000..172a1df --- /dev/null +++ b/tests/ui/pass/annot_mut_term.rs @@ -0,0 +1,13 @@ +//@check-pass + +#[thrust::requires(true)] +#[thrust::ensures(x == <*x, y>)] +fn f(x: &mut i64, y: i64) { + *x = y; +} + +fn main() { + let mut a = 1; + f(&mut a, 2); + assert!(a == 2); +} From e0e13f34d1d99b039d8d4382336fa4479b441f0e Mon Sep 17 00:00:00 2001 From: coord_e Date: Tue, 30 Dec 2025 17:14:56 +0900 Subject: [PATCH 52/75] Parse sort params in ad-hoc way --- src/annot.rs | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/annot.rs b/src/annot.rs index 42266ac..a1cabe4 100644 --- a/src/annot.rs +++ b/src/annot.rs @@ -839,7 +839,16 @@ where "string" => unimplemented!(), "null" => chc::Sort::null(), "fn" => unimplemented!(), - _ => unimplemented!(), + name => { + // TODO: ad-hoc + if let Some(i) = + name.strip_prefix("T").and_then(|s| s.parse::().ok()) + { + chc::Sort::param(i) + } else { + unimplemented!(); + } + } }; Ok(sort) } From c3a550aacf98eeb226fc4d7a51fe7a491827d96d Mon Sep 17 00:00:00 2001 From: coord_e Date: Wed, 31 Dec 2025 16:35:18 +0900 Subject: [PATCH 53/75] Parse boolean literals in annotation --- src/annot.rs | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/annot.rs b/src/annot.rs index a1cabe4..fe55c4a 100644 --- a/src/annot.rs +++ b/src/annot.rs @@ -196,6 +196,7 @@ enum FormulaOrTerm { Term(chc::Term, chc::Sort), BinOp(chc::Term, AmbiguousBinOp, chc::Term), Not(Box>), + Literal(bool), } impl FormulaOrTerm { @@ -215,6 +216,13 @@ impl FormulaOrTerm { chc::Formula::Atom(chc::Atom::new(pred.into(), vec![lhs, rhs])) } FormulaOrTerm::Not(formula_or_term) => formula_or_term.into_formula()?.not(), + FormulaOrTerm::Literal(b) => { + if b { + chc::Formula::top() + } else { + chc::Formula::bottom() + } + } }; Some(fo) } @@ -233,6 +241,7 @@ impl FormulaOrTerm { let (t, _) = formula_or_term.into_term()?; (t.not(), chc::Sort::bool()) } + FormulaOrTerm::Literal(b) => (chc::Term::bool(b), chc::Sort::bool()), }; Some((t, s)) } @@ -461,8 +470,8 @@ where ident.as_str(), self.formula_existentials.get(ident.name.as_str()), ) { - ("true", _) => FormulaOrTerm::Formula(chc::Formula::top()), - ("false", _) => FormulaOrTerm::Formula(chc::Formula::bottom()), + ("true", _) => FormulaOrTerm::Literal(true), + ("false", _) => FormulaOrTerm::Literal(false), (_, Some(sort)) => { let var = chc::Term::FormulaExistentialVar(sort.clone(), ident.name.to_string()); From aec64940867cc180e0e530dcd9a107acebc4657a Mon Sep 17 00:00:00 2001 From: coord_e Date: Wed, 31 Dec 2025 16:35:36 +0900 Subject: [PATCH 54/75] Parse pointer sorts in annotation --- src/annot.rs | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/src/annot.rs b/src/annot.rs index fe55c4a..0e6bcd2 100644 --- a/src/annot.rs +++ b/src/annot.rs @@ -835,6 +835,26 @@ where fn parse_sort(&mut self) -> Result { let tt = self.next_token_tree("sort")?.clone(); match tt { + TokenTree::Token( + Token { + kind: TokenKind::BinOp(BinOpToken::And), + .. + }, + _, + ) => match self.look_ahead_token(0) { + Some(Token { + kind: TokenKind::Ident(sym, _), + .. + }) if sym.as_str() == "mut" => { + self.consume(); + let inner_sort = self.parse_sort()?; + Ok(chc::Sort::mut_(inner_sort)) + } + _ => { + let inner_sort = self.parse_sort()?; + Ok(chc::Sort::box_(inner_sort)) + } + }, TokenTree::Token( Token { kind: TokenKind::Ident(sym, _), From 53a77c6a3281a40362e47c2ef285ba7eda395b5f Mon Sep 17 00:00:00 2001 From: coord_e Date: Wed, 31 Dec 2025 16:41:49 +0900 Subject: [PATCH 55/75] fixup! Parse sort params in ad-hoc way --- src/annot.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/annot.rs b/src/annot.rs index 0e6bcd2..128c289 100644 --- a/src/annot.rs +++ b/src/annot.rs @@ -871,7 +871,7 @@ where name => { // TODO: ad-hoc if let Some(i) = - name.strip_prefix("T").and_then(|s| s.parse::().ok()) + name.strip_prefix('T').and_then(|s| s.parse::().ok()) { chc::Sort::param(i) } else { From 7678fac921c391dae0948b401a3dc9a939132fcb Mon Sep 17 00:00:00 2001 From: coeff-aij <175928954+coeff-aij@users.noreply.github.com> Date: Fri, 9 Jan 2026 13:08:23 +0000 Subject: [PATCH 56/75] add: test for raw_define attribute --- tests/ui/pass/annot_raw_define.rs | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) create mode 100644 tests/ui/pass/annot_raw_define.rs diff --git a/tests/ui/pass/annot_raw_define.rs b/tests/ui/pass/annot_raw_define.rs new file mode 100644 index 0000000..ed8ee54 --- /dev/null +++ b/tests/ui/pass/annot_raw_define.rs @@ -0,0 +1,24 @@ +//@check-pass +//@compile-flags: -Adead_code -C debug-assertions=off + +// Insert definitions written in SMT-LIB2 format into .smt file directly. +// This feature is intended for debug or experiment purpose. +#![thrust::raw_define("(define-fun is_double ((x Int) (doubled_x Int)) Bool + (=( + (* (x 2)) + doubled_x + )) +)")] + +#[thrust::requires(true)] +#[thrust::ensures(result == 2 * x)] +// #[thrust::ensures(is_double(x, result))] +fn double(x: i64) -> i64 { + x + x +} + +fn main() { + let a = 3; + assert!(double(a) == 6); + assert!(is_double(a, double(a))); +} From c5b9b5d9c81b1c92b997f818ad4eda9a91c20678 Mon Sep 17 00:00:00 2001 From: coeff-aij <175928954+coeff-aij@users.noreply.github.com> Date: Sat, 10 Jan 2026 07:42:25 +0000 Subject: [PATCH 57/75] add: RawDefinition for System --- src/chc.rs | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/chc.rs b/src/chc.rs index 5543de4..eefdc4f 100644 --- a/src/chc.rs +++ b/src/chc.rs @@ -1606,6 +1606,14 @@ impl Clause { } } +/// A definition specified using #![thrust::define_raw()] +/// +/// Those will be directly inserted into the generated SMT-LIB2 file. +#[derive(Debug, Clone)] +pub struct RawDefinition { + pub definition: String, +} + /// A selector for a datatype constructor. /// /// A selector is a function that extracts a field from a datatype value. @@ -1655,6 +1663,7 @@ pub struct PredVarDef { /// A CHC system. #[derive(Debug, Clone, Default)] pub struct System { + pub raw_definitions: Vec, pub datatypes: Vec, pub clauses: IndexVec, pub pred_vars: IndexVec, @@ -1665,6 +1674,10 @@ impl System { self.pred_vars.push(PredVarDef { sig, debug_info }) } + pub fn push_raw_definition(&mut self, raw_definition: RawDefinition) { + self.raw_definitions.push(raw_definition) + } + pub fn push_clause(&mut self, clause: Clause) -> Option { if clause.is_nop() { return None; From 373f7d9d55a4cf9afc9f51988cb00c596978793e Mon Sep 17 00:00:00 2001 From: coeff-aij <175928954+coeff-aij@users.noreply.github.com> Date: Sat, 10 Jan 2026 07:43:18 +0000 Subject: [PATCH 58/75] add: formatting for RawDefinition --- src/chc/format_context.rs | 8 +++++++- src/chc/smtlib2.rs | 29 +++++++++++++++++++++++++++++ 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/src/chc/format_context.rs b/src/chc/format_context.rs index 2123315..1c75215 100644 --- a/src/chc/format_context.rs +++ b/src/chc/format_context.rs @@ -21,6 +21,7 @@ use crate::chc::{self, hoice::HoiceDatatypeRenamer}; pub struct FormatContext { renamer: HoiceDatatypeRenamer, datatypes: Vec, + raw_definitions: Vec, } // FIXME: this is obviously ineffective and should be replaced @@ -273,13 +274,18 @@ impl FormatContext { .filter(|d| d.params == 0) .collect(); let renamer = HoiceDatatypeRenamer::new(&datatypes); - FormatContext { renamer, datatypes } + let raw_definitions = system.raw_definitions.clone(); + FormatContext { renamer, datatypes, raw_definitions } } pub fn datatypes(&self) -> &[chc::Datatype] { &self.datatypes } + pub fn raw_definitions(&self) -> &[chc::RawDefinition] { + &self.raw_definitions + } + pub fn box_ctor(&self, sort: &chc::Sort) -> impl std::fmt::Display { let ss = Sort::new(sort).sorts(); format!("box{ss}") diff --git a/src/chc/smtlib2.rs b/src/chc/smtlib2.rs index 167d108..0e0f2e9 100644 --- a/src/chc/smtlib2.rs +++ b/src/chc/smtlib2.rs @@ -370,6 +370,30 @@ impl<'ctx, 'a> Clause<'ctx, 'a> { } } +/// A wrapper around a [`chc::RawDefinition`] that provides a [`std::fmt::Display`] implementation in SMT-LIB2 format. +#[derive(Debug, Clone)] +pub struct RawDefinition<'a> { + inner: &'a chc::RawDefinition, +} + +impl<'a> std::fmt::Display for RawDefinition<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}", + self.inner.definition, + ) + } +} + +impl<'a> RawDefinition<'a> { + pub fn new(inner: &'a chc::RawDefinition) -> Self { + Self { + inner + } + } +} + /// A wrapper around a [`chc::DatatypeSelector`] that provides a [`std::fmt::Display`] implementation in SMT-LIB2 format. #[derive(Debug, Clone)] pub struct DatatypeSelector<'ctx, 'a> { @@ -555,6 +579,11 @@ impl<'a> std::fmt::Display for System<'a> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { writeln!(f, "(set-logic HORN)\n")?; + // insert definition from #![thrust::define_chc()] here + for raw_def in self.ctx.raw_definitions() { + writeln!(f, "{}\n", RawDefinition::new(raw_def))?; + } + writeln!(f, "{}\n", Datatypes::new(&self.ctx, self.ctx.datatypes()))?; for datatype in self.ctx.datatypes() { writeln!(f, "{}", DatatypeDiscrFun::new(&self.ctx, datatype))?; From 6c58785665425c4d400bfeb263f0ec11db93e4f2 Mon Sep 17 00:00:00 2001 From: coeff-aij <175928954+coeff-aij@users.noreply.github.com> Date: Sun, 11 Jan 2026 16:07:25 +0000 Subject: [PATCH 59/75] add: raw_define path --- src/analyze/annot.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/analyze/annot.rs b/src/analyze/annot.rs index 2dbb9ea..30e70d3 100644 --- a/src/analyze/annot.rs +++ b/src/analyze/annot.rs @@ -37,6 +37,10 @@ pub fn extern_spec_fn_path() -> [Symbol; 2] { [Symbol::intern("thrust"), Symbol::intern("extern_spec_fn")] } +pub fn raw_define_path() -> [Symbol; 2] { + [Symbol::intern("thrust"), Symbol::intern("raw_define")] +} + /// A [`annot::Resolver`] implementation for resolving function parameters. /// /// The parameter names and their sorts needs to be configured via From 4c1bfdef2d7eaab2db8e54cd6e951a1632ae39b9 Mon Sep 17 00:00:00 2001 From: coeff-aij <175928954+coeff-aij@users.noreply.github.com> Date: Sun, 11 Jan 2026 16:07:50 +0000 Subject: [PATCH 60/75] add: parse raw definitions --- src/annot.rs | 39 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 38 insertions(+), 1 deletion(-) diff --git a/src/annot.rs b/src/annot.rs index 128c289..6541fee 100644 --- a/src/annot.rs +++ b/src/annot.rs @@ -8,7 +8,7 @@ //! The main entry point is [`AnnotParser`], which parses a [`TokenStream`] into a //! [`rty::RefinedType`] or a [`chc::Formula`]. -use rustc_ast::token::{BinOpToken, Delimiter, LitKind, Token, TokenKind}; +use rustc_ast::token::{BinOpToken, Delimiter, LitKind, Lit, Token, TokenKind}; use rustc_ast::tokenstream::{RefTokenTreeCursor, Spacing, TokenStream, TokenTree}; use rustc_index::IndexVec; use rustc_span::symbol::Ident; @@ -1076,6 +1076,32 @@ where .ok_or_else(|| ParseAttrError::unexpected_term("in annotation"))?; Ok(AnnotFormula::Formula(formula)) } + + pub fn parse_annot_raw_definition(&mut self) -> Result { + let t = self.next_token("raw CHC definition")?; + + match t { + Token { + kind: TokenKind::Literal( + Lit { kind, symbol, .. } + ), + .. + } => { + match kind { + LitKind::Str => { + let definition = symbol.to_string(); + Ok(chc::RawDefinition{ definition }) + }, + _ => Err(ParseAttrError::unexpected_token( + "string literal", t.clone() + )) + } + }, + _ => Err(ParseAttrError::unexpected_token( + "string literal", t.clone() + )) + } + } } /// A [`Resolver`] implementation for resolving specific variable as [`rty::RefinedTypeVar::Value`]. @@ -1208,4 +1234,15 @@ where parser.end_of_input()?; Ok(formula) } + + pub fn parse_raw_definition(&self, ts: TokenStream) -> Result { + let mut parser = Parser { + resolver: &self.resolver, + cursor: ts.trees(), + formula_existentials: Default::default(), + }; + let raw_definition = parser.parse_annot_raw_definition()?; + parser.end_of_input()?; + Ok(raw_definition) + } } From 4016ac4293e1108a0d90d80ab2a49349fc1b022a Mon Sep 17 00:00:00 2001 From: coeff-aij <175928954+coeff-aij@users.noreply.github.com> Date: Sun, 11 Jan 2026 16:45:20 +0000 Subject: [PATCH 61/75] fix: invalid SMT-LIB2 format --- tests/ui/pass/annot_raw_define.rs | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/ui/pass/annot_raw_define.rs b/tests/ui/pass/annot_raw_define.rs index ed8ee54..c1c47a7 100644 --- a/tests/ui/pass/annot_raw_define.rs +++ b/tests/ui/pass/annot_raw_define.rs @@ -3,11 +3,12 @@ // Insert definitions written in SMT-LIB2 format into .smt file directly. // This feature is intended for debug or experiment purpose. +#![feature(custom_inner_attributes)] #![thrust::raw_define("(define-fun is_double ((x Int) (doubled_x Int)) Bool - (=( - (* (x 2)) + (= + (* x 2) doubled_x - )) + ) )")] #[thrust::requires(true)] @@ -20,5 +21,5 @@ fn double(x: i64) -> i64 { fn main() { let a = 3; assert!(double(a) == 6); - assert!(is_double(a, double(a))); + // assert!(is_double(a, double(a))); } From 15e86002323e05dcace40f640bb187eac7d2fa9b Mon Sep 17 00:00:00 2001 From: coeff-aij <175928954+coeff-aij@users.noreply.github.com> Date: Sun, 11 Jan 2026 16:46:15 +0000 Subject: [PATCH 62/75] add: analyze inner-attribute #[raw_define()] for the given crate --- src/analyze/crate_.rs | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/src/analyze/crate_.rs b/src/analyze/crate_.rs index 9dd85f9..c71021e 100644 --- a/src/analyze/crate_.rs +++ b/src/analyze/crate_.rs @@ -2,12 +2,14 @@ use std::collections::HashSet; +use rustc_hir::def_id::CRATE_DEF_ID; use rustc_middle::ty::{self as mir_ty, TyCtxt}; use rustc_span::def_id::{DefId, LocalDefId}; use crate::analyze; use crate::chc; use crate::rty::{self, ClauseBuilderExt as _}; +use crate::annot; /// An implementation of local crate analysis. /// @@ -26,6 +28,31 @@ pub struct Analyzer<'tcx, 'ctx> { } impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { + // fn is_annotated_as_raw_define(&self) -> bool { + // self.tcx + // .get_attrs_by_path( + // CRATE_DEF_ID.to_def_id(), + // &analyze::annot::raw_define_path(), + // ) + // .next() + // .is_some() + // } + + fn analyze_raw_define_annot(&mut self) { + for attrs in self.tcx.get_attrs_by_path( + CRATE_DEF_ID.to_def_id(), + &analyze::annot::raw_define_path(), + ) { + let ts = analyze::annot::extract_annot_tokens(attrs.clone()); + let parser = annot::AnnotParser::new( + // TODO: this resolver is not actually used. + analyze::annot::ParamResolver::default() + ); + let raw_definition = parser.parse_raw_definition(ts).unwrap(); + self.ctx.system.borrow_mut().push_raw_definition(raw_definition); + } + } + fn refine_local_defs(&mut self) { for local_def_id in self.tcx.mir_keys(()) { if self.tcx.def_kind(*local_def_id).is_fn_like() { @@ -187,6 +214,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { let span = tracing::debug_span!("crate", krate = %self.tcx.crate_name(rustc_span::def_id::LOCAL_CRATE)); let _guard = span.enter(); + self.analyze_raw_define_annot(); self.refine_local_defs(); self.analyze_local_defs(); self.assert_callable_entry(); From 38580ddb8de69030249bc85378835ccef50de30c Mon Sep 17 00:00:00 2001 From: coeff-aij <175928954+coeff-aij@users.noreply.github.com> Date: Sun, 11 Jan 2026 16:48:01 +0000 Subject: [PATCH 63/75] fix: error relate to new raw_definitions field of System --- src/chc/unbox.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/chc/unbox.rs b/src/chc/unbox.rs index 5be1240..40c2e6a 100644 --- a/src/chc/unbox.rs +++ b/src/chc/unbox.rs @@ -161,6 +161,7 @@ pub fn unbox(system: System) -> System { clauses, pred_vars, datatypes, + raw_definitions, } = system; let datatypes = datatypes.into_iter().map(unbox_datatype).collect(); let clauses = clauses.into_iter().map(unbox_clause).collect(); @@ -169,5 +170,6 @@ pub fn unbox(system: System) -> System { clauses, pred_vars, datatypes, + raw_definitions, } } From 047954d52d247b86734c83ef57b887ce4afb0e5f Mon Sep 17 00:00:00 2001 From: coeff-aij <175928954+coeff-aij@users.noreply.github.com> Date: Mon, 12 Jan 2026 02:16:41 +0900 Subject: [PATCH 64/75] remove: unused code --- src/analyze/crate_.rs | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/src/analyze/crate_.rs b/src/analyze/crate_.rs index c71021e..4172f94 100644 --- a/src/analyze/crate_.rs +++ b/src/analyze/crate_.rs @@ -28,16 +28,6 @@ pub struct Analyzer<'tcx, 'ctx> { } impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { - // fn is_annotated_as_raw_define(&self) -> bool { - // self.tcx - // .get_attrs_by_path( - // CRATE_DEF_ID.to_def_id(), - // &analyze::annot::raw_define_path(), - // ) - // .next() - // .is_some() - // } - fn analyze_raw_define_annot(&mut self) { for attrs in self.tcx.get_attrs_by_path( CRATE_DEF_ID.to_def_id(), From 5a904bd801f835b51fa25c9dcefb0ecf0e6171a0 Mon Sep 17 00:00:00 2001 From: coeff-aij <175928954+coeff-aij@users.noreply.github.com> Date: Mon, 12 Jan 2026 05:43:54 +0000 Subject: [PATCH 65/75] add: positiive test for multiple raw_define annotations --- tests/ui/pass/annot_raw_define_multi.rs | 33 +++++++++++++++++++++++++ 1 file changed, 33 insertions(+) create mode 100644 tests/ui/pass/annot_raw_define_multi.rs diff --git a/tests/ui/pass/annot_raw_define_multi.rs b/tests/ui/pass/annot_raw_define_multi.rs new file mode 100644 index 0000000..ae9e0eb --- /dev/null +++ b/tests/ui/pass/annot_raw_define_multi.rs @@ -0,0 +1,33 @@ +//@check-pass +//@compile-flags: -Adead_code -C debug-assertions=off + +// Insert definitions written in SMT-LIB2 format into .smt file directly. +// This feature is intended for debug or experiment purpose. +#![feature(custom_inner_attributes)] +#![thrust::raw_define("(define-fun is_double ((x Int) (doubled_x Int)) Bool + (= + (* x 2) + doubled_x + ) +)")] + +// multiple raw definitions can be inserted. +#![thrust::raw_define("(define-fun is_triple ((x Int) (tripled_x Int)) Bool + (= + (* x 3) + tripled_x + ) +)")] + +#[thrust::requires(true)] +#[thrust::ensures(result == 2 * x)] +// #[thrust::ensures(is_double(x, result))] +fn double(x: i64) -> i64 { + x + x +} + +fn main() { + let a = 3; + assert!(double(a) == 6); + // assert!(is_double(a, double(a))); +} From a96ddf69aea3ce848083bcaac2bf08854d234fcf Mon Sep 17 00:00:00 2001 From: coeff-aij <175928954+coeff-aij@users.noreply.github.com> Date: Mon, 12 Jan 2026 05:48:47 +0000 Subject: [PATCH 66/75] add: negative tests for raw_define --- tests/ui/fail/annot_raw_define.rs | 20 +++++++++++++++++++ .../fail/annot_raw_define_without_params.rs | 20 +++++++++++++++++++ 2 files changed, 40 insertions(+) create mode 100644 tests/ui/fail/annot_raw_define.rs create mode 100644 tests/ui/fail/annot_raw_define_without_params.rs diff --git a/tests/ui/fail/annot_raw_define.rs b/tests/ui/fail/annot_raw_define.rs new file mode 100644 index 0000000..346a158 --- /dev/null +++ b/tests/ui/fail/annot_raw_define.rs @@ -0,0 +1,20 @@ +//@error-in-other-file: UnexpectedToken +//@compile-flags: -Adead_code -C debug-assertions=off + +// Insert definitions written in SMT-LIB2 format into .smt file directly. +// This feature is intended for debug or experiment purpose. +#![feature(custom_inner_attributes)] +#![thrust::raw_define(true)] // argument must be single string literal + +#[thrust::requires(true)] +#[thrust::ensures(result == 2 * x)] +// #[thrust::ensures(is_double(x, result))] +fn double(x: i64) -> i64 { + x + x +} + +fn main() { + let a = 3; + assert!(double(a) == 6); + // assert!(is_double(a, double(a))); +} diff --git a/tests/ui/fail/annot_raw_define_without_params.rs b/tests/ui/fail/annot_raw_define_without_params.rs new file mode 100644 index 0000000..d6683a8 --- /dev/null +++ b/tests/ui/fail/annot_raw_define_without_params.rs @@ -0,0 +1,20 @@ +//@error-in-other-file: invalid attribute +//@compile-flags: -Adead_code -C debug-assertions=off + +// Insert definitions written in SMT-LIB2 format into .smt file directly. +// This feature is intended for debug or experiment purpose. +#![feature(custom_inner_attributes)] +#![thrust::raw_define] // argument must be single string literal + +#[thrust::requires(true)] +#[thrust::ensures(result == 2 * x)] +// #[thrust::ensures(is_double(x, result))] +fn double(x: i64) -> i64 { + x + x +} + +fn main() { + let a = 3; + assert!(double(a) == 6); + // assert!(is_double(a, double(a))); +} From 96a6e3745494ca6e8badbf03f9d8877a3377653d Mon Sep 17 00:00:00 2001 From: coeff-aij <175928954+coeff-aij@users.noreply.github.com> Date: Thu, 25 Dec 2025 08:24:37 +0000 Subject: [PATCH 67/75] add: test for annotations of predicates --- tests/ui/pass/annot_preds.rs | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) create mode 100644 tests/ui/pass/annot_preds.rs diff --git a/tests/ui/pass/annot_preds.rs b/tests/ui/pass/annot_preds.rs new file mode 100644 index 0000000..79bf978 --- /dev/null +++ b/tests/ui/pass/annot_preds.rs @@ -0,0 +1,24 @@ +//@check-pass +//@compile-flags: -Adead_code -C debug-assertions=off + +#[thrust::predicate] +fn is_double(x: i64, doubled_x: i64) -> bool { + x * 2 == doubled_x + // "(=( + // (* (x 2)) + // doubled_x + // ))" +} + +#[thrust::requires(true)] +#[thrust::ensures(result == 2 * x)] +// #[thrust::ensures(is_double(x, result))] +fn double(x: i64) -> i64 { + x + x +} + +fn main() { + let a = 3; + assert!(double(a) == 6); + assert!(is_double(a, double(a))); +} From cff43b79721970daeab4ed76559add7f6e0747a3 Mon Sep 17 00:00:00 2001 From: coeff-aij <175928954+coeff-aij@users.noreply.github.com> Date: Mon, 12 Jan 2026 16:10:37 +0900 Subject: [PATCH 68/75] Merge main --- src/analyze/local_def.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/analyze/local_def.rs b/src/analyze/local_def.rs index d556ef0..61cb80d 100644 --- a/src/analyze/local_def.rs +++ b/src/analyze/local_def.rs @@ -135,7 +135,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { .next() .is_some() } - + // TODO: unify this logic with extraction functions above pub fn is_fully_annotated(&self) -> bool { let has_require = self From f864b8b929037ba284c2b729fc97196b96f9bf5e Mon Sep 17 00:00:00 2001 From: coeff-aij <175928954+coeff-aij@users.noreply.github.com> Date: Fri, 9 Jan 2026 05:58:01 +0000 Subject: [PATCH 69/75] add: definition for user-defined predicates in CHC --- src/chc.rs | 82 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 82 insertions(+) diff --git a/src/chc.rs b/src/chc.rs index 5543de4..972cdb3 100644 --- a/src/chc.rs +++ b/src/chc.rs @@ -902,12 +902,87 @@ impl MatcherPred { } } +// TODO: DatatypeSymbolをほぼそのままコピーする形になっているので、エイリアスなどで共通化すべき? +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct UserDefinedPredSymbol { + inner: String, +} + +impl std::fmt::Display for UserDefinedPredSymbol { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + self.inner.fmt(f) + } +} + +impl<'a, 'b, D> Pretty<'a, D, termcolor::ColorSpec> for &'b UserDefinedPredSymbol +where + D: pretty::DocAllocator<'a, termcolor::ColorSpec>, +{ + fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, termcolor::ColorSpec> { + allocator.text(self.inner.clone()) + } +} + +impl UserDefinedPredSymbol { + pub fn new(inner: String) -> Self { + Self { inner } + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct UserDefinedPred { + symbol: UserDefinedPredSymbol, + args: Vec, +} + +impl<'a> std::fmt::Display for UserDefinedPred { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "{}", self.symbol.inner) + } +} + +impl<'a, 'b, D> Pretty<'a, D, termcolor::ColorSpec> for &'b UserDefinedPred +where + D: pretty::DocAllocator<'a, termcolor::ColorSpec>, + D::Doc: Clone, +{ + fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, termcolor::ColorSpec> { + let args = allocator.intersperse( + self.args.iter().map(|a| a.pretty(allocator)), + allocator.text(", "), + ); + allocator + .text("user_defined_pred") + .append( + allocator + .text(self.symbol.inner.clone()) + .append(args.angles()) + .angles(), + ) + .group() + } +} + +impl UserDefinedPred { + pub fn new(symbol: UserDefinedPredSymbol, args: Vec) -> Self { + Self { + symbol, + args, + } + } + + pub fn name(&self) -> &str { + &self.symbol.inner + } +} + /// A predicate. #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum Pred { Known(KnownPred), Var(PredVarId), Matcher(MatcherPred), + UserDefined(UserDefinedPredSymbol), } impl std::fmt::Display for Pred { @@ -916,6 +991,7 @@ impl std::fmt::Display for Pred { Pred::Known(p) => p.fmt(f), Pred::Var(p) => p.fmt(f), Pred::Matcher(p) => p.fmt(f), + Pred::UserDefined(p) => p.fmt(f), } } } @@ -930,6 +1006,7 @@ where Pred::Known(p) => p.pretty(allocator), Pred::Var(p) => p.pretty(allocator), Pred::Matcher(p) => p.pretty(allocator), + Pred::UserDefined(p) => p.pretty(allocator), } } } @@ -958,6 +1035,7 @@ impl Pred { Pred::Known(p) => p.name().into(), Pred::Var(p) => p.to_string().into(), Pred::Matcher(p) => p.name().into(), + Pred::UserDefined(p) => p.inner.clone().into(), } } @@ -966,6 +1044,7 @@ impl Pred { Pred::Known(p) => p.is_negative(), Pred::Var(_) => false, Pred::Matcher(_) => false, + Pred::UserDefined(_) => false, } } @@ -974,6 +1053,7 @@ impl Pred { Pred::Known(p) => p.is_infix(), Pred::Var(_) => false, Pred::Matcher(_) => false, + Pred::UserDefined(_) => false, } } @@ -982,6 +1062,7 @@ impl Pred { Pred::Known(p) => p.is_top(), Pred::Var(_) => false, Pred::Matcher(_) => false, + Pred::UserDefined(_) => false, } } @@ -990,6 +1071,7 @@ impl Pred { Pred::Known(p) => p.is_bottom(), Pred::Var(_) => false, Pred::Matcher(_) => false, + Pred::UserDefined(_) => false, } } } From 56bb7e15ff4663c94e84210893c70b5931430b07 Mon Sep 17 00:00:00 2001 From: coeff-aij <175928954+coeff-aij@users.noreply.github.com> Date: Fri, 9 Jan 2026 06:17:40 +0000 Subject: [PATCH 70/75] add: an implementation for unboxing user-defined predicates --- src/chc/unbox.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/chc/unbox.rs b/src/chc/unbox.rs index 5be1240..0c3308d 100644 --- a/src/chc/unbox.rs +++ b/src/chc/unbox.rs @@ -42,6 +42,7 @@ fn unbox_pred(pred: Pred) -> Pred { Pred::Known(pred) => Pred::Known(pred), Pred::Var(pred) => Pred::Var(pred), Pred::Matcher(pred) => unbox_matcher_pred(pred), + Pred::UserDefined(pred) => Pred::UserDefined(pred), } } From c9fb50b499dc0829bdb9becb0e77b5c26488770f Mon Sep 17 00:00:00 2001 From: coeff-aij <175928954+coeff-aij@users.noreply.github.com> Date: Mon, 12 Jan 2026 17:27:05 +0900 Subject: [PATCH 71/75] add: parse single-path identifier followed by parenthesized arguments as a user-defined predicate call --- src/annot.rs | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/annot.rs b/src/annot.rs index 128c289..dacf709 100644 --- a/src/annot.rs +++ b/src/annot.rs @@ -478,6 +478,17 @@ where FormulaOrTerm::Term(var, sort.clone()) } _ => { + if let Some(TokenTree::Delimited(_, _, Delimiter::Parenthesis, _args)) = self.look_ahead_token_tree(1) { + self.consume(); + let pred_symbol = chc::UserDefinedPredSymbol::new(ident.name.to_string()); + let pred = chc::Pred::UserDefined(pred_symbol); + + let args = self.parse_datatype_ctor_args()?; + + let atom = chc::Atom::new(pred, args); + let formula = chc::Formula::Atom(atom); + return Ok(FormulaOrTerm::Formula(formula)); + } let (v, sort) = self.resolve(*ident)?; FormulaOrTerm::Term(chc::Term::var(v), sort) } From 0da1bd9dbe92bbdcc67cacd0f08ff8767b780542 Mon Sep 17 00:00:00 2001 From: coeff-aij <175928954+coeff-aij@users.noreply.github.com> Date: Mon, 12 Jan 2026 19:00:11 +0900 Subject: [PATCH 72/75] fix: wrong implementation of parser for predicate call arguments --- src/annot.rs | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/src/annot.rs b/src/annot.rs index 76f0608..94fecf8 100644 --- a/src/annot.rs +++ b/src/annot.rs @@ -477,14 +477,23 @@ where chc::Term::FormulaExistentialVar(sort.clone(), ident.name.to_string()); FormulaOrTerm::Term(var, sort.clone()) } - _ => { - if let Some(TokenTree::Delimited(_, _, Delimiter::Parenthesis, _args)) = self.look_ahead_token_tree(1) { + _ => { + let next_tt = self.look_ahead_token_tree(0); + + if let Some(TokenTree::Delimited(_, _, Delimiter::Parenthesis, args)) = next_tt { + let args = args.clone(); self.consume(); + let pred_symbol = chc::UserDefinedPredSymbol::new(ident.name.to_string()); let pred = chc::Pred::UserDefined(pred_symbol); - let args = self.parse_datatype_ctor_args()?; - + let mut parser = Parser { + resolver: self.boxed_resolver(), + cursor: args.trees(), + formula_existentials: self.formula_existentials.clone(), + }; + let args = parser.parse_datatype_ctor_args()?; + let atom = chc::Atom::new(pred, args); let formula = chc::Formula::Atom(atom); return Ok(FormulaOrTerm::Formula(formula)); From 4cd0a346625c603e89cd2e2d3606b561eb3905ac Mon Sep 17 00:00:00 2001 From: coeff-aij <175928954+coeff-aij@users.noreply.github.com> Date: Mon, 12 Jan 2026 19:04:55 +0900 Subject: [PATCH 73/75] change: use raw_define to define user-defined predicates for now --- tests/ui/pass/annot_preds.rs | 24 ------------------------ tests/ui/pass/annot_preds_raw_define.rs | 25 +++++++++++++++++++++++++ 2 files changed, 25 insertions(+), 24 deletions(-) delete mode 100644 tests/ui/pass/annot_preds.rs create mode 100644 tests/ui/pass/annot_preds_raw_define.rs diff --git a/tests/ui/pass/annot_preds.rs b/tests/ui/pass/annot_preds.rs deleted file mode 100644 index 79bf978..0000000 --- a/tests/ui/pass/annot_preds.rs +++ /dev/null @@ -1,24 +0,0 @@ -//@check-pass -//@compile-flags: -Adead_code -C debug-assertions=off - -#[thrust::predicate] -fn is_double(x: i64, doubled_x: i64) -> bool { - x * 2 == doubled_x - // "(=( - // (* (x 2)) - // doubled_x - // ))" -} - -#[thrust::requires(true)] -#[thrust::ensures(result == 2 * x)] -// #[thrust::ensures(is_double(x, result))] -fn double(x: i64) -> i64 { - x + x -} - -fn main() { - let a = 3; - assert!(double(a) == 6); - assert!(is_double(a, double(a))); -} diff --git a/tests/ui/pass/annot_preds_raw_define.rs b/tests/ui/pass/annot_preds_raw_define.rs new file mode 100644 index 0000000..bfcde39 --- /dev/null +++ b/tests/ui/pass/annot_preds_raw_define.rs @@ -0,0 +1,25 @@ +//@check-pass +//@compile-flags: -Adead_code -C debug-assertions=off + +// Insert definitions written in SMT-LIB2 format into .smt file directly. +// This feature is intended for debug or experiment purpose. +#![feature(custom_inner_attributes)] +#![thrust::raw_define("(define-fun is_double ((x Int) (doubled_x Int)) Bool + (= + (* x 2) + doubled_x + ) +)")] + +#[thrust::requires(true)] +// #[thrust::ensures(result == 2 * x)] +#[thrust::ensures(is_double(x, result))] +fn double(x: i64) -> i64 { + x + x +} + +fn main() { + let a = 3; + assert!(double(a) == 6); + // assert!(is_double(a, double(a))); +} From f1646242b564936741366d79532f484a0ff2637f Mon Sep 17 00:00:00 2001 From: coeff-aij <175928954+coeff-aij@users.noreply.github.com> Date: Mon, 12 Jan 2026 19:19:18 +0900 Subject: [PATCH 74/75] fix: translate comment into English --- src/chc.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/chc.rs b/src/chc.rs index 1c4dfc9..5000563 100644 --- a/src/chc.rs +++ b/src/chc.rs @@ -902,7 +902,7 @@ impl MatcherPred { } } -// TODO: DatatypeSymbolをほぼそのままコピーする形になっているので、エイリアスなどで共通化すべき? +// TODO: This struct is almost copy of `DatatypeSymbol`. Two traits maight be unified with aliases. #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct UserDefinedPredSymbol { inner: String, From 45a5f09c7b8286971c59f27b090a53a583ec9b1d Mon Sep 17 00:00:00 2001 From: coeff-aij <175928954+coeff-aij@users.noreply.github.com> Date: Mon, 12 Jan 2026 20:18:13 +0900 Subject: [PATCH 75/75] add: more test for user-defined predicate calls --- tests/ui/pass/annot_preds_raw_define_multi.rs | 36 +++++++++++++++++++ 1 file changed, 36 insertions(+) create mode 100644 tests/ui/pass/annot_preds_raw_define_multi.rs diff --git a/tests/ui/pass/annot_preds_raw_define_multi.rs b/tests/ui/pass/annot_preds_raw_define_multi.rs new file mode 100644 index 0000000..c0419f9 --- /dev/null +++ b/tests/ui/pass/annot_preds_raw_define_multi.rs @@ -0,0 +1,36 @@ +//@check-pass +//@compile-flags: -Adead_code -C debug-assertions=off + +#![feature(custom_inner_attributes)] +#![thrust::raw_define("(define-fun is_double ((x Int) (doubled_x Int)) Bool + (= + (* x 2) + doubled_x + ) +)")] + +#![thrust::raw_define("(define-fun is_triple ((x Int) (tripled_x Int)) Bool + (= + (* x 3) + tripled_x + ) +)")] + +#[thrust::requires(true)] +#[thrust::ensures(is_double(x, result))] +fn double(x: i64) -> i64 { + x + x +} + +#[thrust::require(is_double(x, doubled_x))] +#[thrust::ensures(is_triple(x, result))] +fn triple(x: i64, doubled_x: i64) -> i64 { + x + doubled_x +} + +fn main() { + let a = 3; + let double_a = double(a); + assert!(double_a == 6); + assert!(triple(a, double_a) == 9); +}