From 1b6f240e8e3d73b0e3c87a80067a878815676db6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABlle=20Huisman?= Date: Tue, 30 Dec 2025 13:30:03 +0100 Subject: [PATCH] fix: improve type heuristic and related tests --- .../fortifier-macros/src/validate/type.rs | 565 +++++++++++++----- 1 file changed, 428 insertions(+), 137 deletions(-) diff --git a/packages/fortifier-macros/src/validate/type.rs b/packages/fortifier-macros/src/validate/type.rs index d73b3b5..b95de34 100644 --- a/packages/fortifier-macros/src/validate/type.rs +++ b/packages/fortifier-macros/src/validate/type.rs @@ -167,7 +167,10 @@ fn path_to_string(path: &Path) -> String { fn is_validate_path(path: &Path) -> bool { let path_string = path_to_string(path); - path_string == "Validate" || path_string == "fortifier::Validate" + path_string == "Validate" + || path_string == "ValidateWithContext" + || path_string == "fortifier::Validate" + || path_string == "fortifier::ValidateWithContext" } fn should_validate_generic_argument( @@ -274,7 +277,7 @@ fn should_validate_path(generics: &Generics, path: &Path) -> Option Option { Known(T), Unknown, } +impl KnownOrUnknown { + pub fn map(self, f: F) -> KnownOrUnknown + where + F: FnOnce(T) -> U, + { + match self { + KnownOrUnknown::Known(value) => KnownOrUnknown::Known(f(value)), + KnownOrUnknown::Unknown => KnownOrUnknown::Unknown, + } + } +} + pub fn should_validate_type( generics: &Generics, r#type: &Type, ) -> Option> { match r#type { - Type::Array(r#type) => should_validate_type(generics, &r#type.elem), + Type::Array(r#type) => { + should_validate_type(generics, &r#type.elem).map(|error_type| { + error_type.map(|error_type| quote!(::fortifier::IndexedValidationError<#error_type>)) + }) + }, Type::BareFn(_) => None, Type::Group(r#type) => should_validate_type(generics, &r#type.elem), - Type::ImplTrait(r#type) => r#type.bounds.iter().any( - |bound| matches!(bound, TypeParamBound::Trait(bound) if is_validate_path(&bound.path)), - ).then_some(KnownOrUnknown::Unknown), + Type::ImplTrait(r#type) => { + r#type.bounds + .iter() + .any(|bound| matches!(bound, TypeParamBound::Trait(bound) if is_validate_path(&bound.path))) + .then_some(KnownOrUnknown::Unknown) + }, Type::Infer(_) => Some(KnownOrUnknown::Unknown), Type::Macro(_) => Some(KnownOrUnknown::Unknown), Type::Never(_) => None, @@ -305,7 +328,11 @@ pub fn should_validate_type( Type::Path(r#type) => should_validate_path(generics, &r#type.path), Type::Ptr(r#type) => should_validate_type(generics, &r#type.elem), Type::Reference(r#type) => should_validate_type(generics,&r#type.elem), - Type::Slice(r#type) => should_validate_type(generics, &r#type.elem), + Type::Slice(r#type) => { + should_validate_type(generics, &r#type.elem).map(|error_type| { + error_type.map(|error_type| quote!(::fortifier::IndexedValidationError<#error_type>)) + }) + }, Type::TraitObject(r#type) => should_validate_type_param_bounds(r#type.bounds.iter()), Type::Tuple(r#type) => { (!r#type.elems.is_empty() && r#type.elems.iter().all(|r#type| should_validate_type(generics, r#type).is_some())).then_some(KnownOrUnknown::Unknown) @@ -321,155 +348,419 @@ mod tests { use quote::quote; use syn::{GenericParam, Generics, punctuated::Punctuated}; + use crate::validate::r#type::KnownOrUnknown; + use super::should_validate_type; - fn validate(tokens: TokenStream) -> bool { - should_validate_type( - &Generics::default(), - &syn::parse2(tokens).expect("valid type"), - ) - .is_some() + fn validate(tokens: TokenStream) -> Option> { + validate_with_generics(tokens, Generics::default()) } - fn validate_with_generics(tokens: TokenStream, generics: Generics) -> bool { - should_validate_type(&generics, &syn::parse2(tokens).expect("valid type")).is_some() + fn validate_with_generics( + tokens: TokenStream, + generics: Generics, + ) -> Option> { + should_validate_type(&generics, &syn::parse2(tokens).expect("valid type")) + .map(|value| value.map(|value| value.to_string().replace(' ', ""))) } #[test] fn should_validate() { - assert!(validate(quote!(&T))); - assert!(validate(quote!(T))); - - assert!(validate(quote!((T, T)))); - assert!(validate(quote!((A, B, C)))); - - assert!(validate(quote!([T]))); - assert!(validate(quote!([T; 3]))); - assert!(validate(quote!([&T]))); - assert!(validate(quote!([&T; 3]))); - assert!(validate(quote!(&[T]))); - assert!(validate(quote!(&[T; 3]))); - - assert!(validate(quote!(Arc))); - assert!(validate(quote!(BTreeSet))); - assert!(validate(quote!(BTreeMap))); - assert!(validate(quote!(HashSet))); - assert!(validate(quote!(HashMap))); - assert!(validate(quote!(LinkedList))); - assert!(validate(quote!(Option))); - assert!(validate(quote!(Option>))); - assert!(validate(quote!(Rc))); - assert!(validate(quote!(Vec))); - assert!(validate(quote!(VecDeque))); - - assert!(validate(quote!(impl Validate))); - assert!(validate(quote!(impl fortifier::Validate))); - assert!(validate(quote!(dyn Validate))); - assert!(validate(quote!(dyn ::fortifier::Validate))); + // TODO: Keyed error types. + + assert_eq!( + validate(quote!(&T)), + Some(KnownOrUnknown::Known("TValidationError".to_owned())) + ); + assert_eq!( + validate(quote!(T)), + Some(KnownOrUnknown::Known("TValidationError".to_owned())) + ); + assert_eq!( + validate(quote!(T)), + Some(KnownOrUnknown::Known("TValidationError".to_owned())) + ); + assert_eq!( + validate(quote!(T)), + Some(KnownOrUnknown::Known("TValidationError".to_owned())) + ); + + assert_eq!(validate(quote!((T, T))), Some(KnownOrUnknown::Unknown)); + assert_eq!(validate(quote!((A, B, C))), Some(KnownOrUnknown::Unknown)); + + assert_eq!( + validate(quote!([T])), + Some(KnownOrUnknown::Known( + "::fortifier::IndexedValidationError".to_owned() + )) + ); + assert_eq!( + validate(quote!([T; 3])), + Some(KnownOrUnknown::Known( + "::fortifier::IndexedValidationError".to_owned() + )) + ); + assert_eq!( + validate(quote!([&T])), + Some(KnownOrUnknown::Known( + "::fortifier::IndexedValidationError".to_owned() + )) + ); + assert_eq!( + validate(quote!([&T; 3])), + Some(KnownOrUnknown::Known( + "::fortifier::IndexedValidationError".to_owned() + )) + ); + assert_eq!( + validate(quote!(&[T])), + Some(KnownOrUnknown::Known( + "::fortifier::IndexedValidationError".to_owned() + )) + ); + assert_eq!( + validate(quote!(&[T; 3])), + Some(KnownOrUnknown::Known( + "::fortifier::IndexedValidationError".to_owned() + )) + ); + + assert_eq!( + validate(quote!(Arc)), + Some(KnownOrUnknown::Known("TValidationError".to_owned())) + ); + assert_eq!( + validate(quote!(BTreeSet)), + Some(KnownOrUnknown::Known( + "::fortifier::IndexedValidationError".to_owned() + )) + ); + // assert_eq!( + // validate(quote!(BTreeMap)), + // Some(KnownOrUnknown::Known( + // "::fortifier::KeyedValidationError".to_owned() + // )) + // ); + assert_eq!( + validate(quote!(IndexSet)), + Some(KnownOrUnknown::Known( + "::fortifier::IndexedValidationError".to_owned() + )) + ); + // assert_eq!( + // validate(quote!(IndexMap)), + // Some(KnownOrUnknown::Known( + // "::fortifier::KeyedValidationError".to_owned() + // )) + // ); + assert_eq!( + validate(quote!(HashSet)), + Some(KnownOrUnknown::Known( + "::fortifier::IndexedValidationError".to_owned() + )) + ); + // assert_eq!( + // validate(quote!(HashMap)), + // Some(KnownOrUnknown::Known( + // "::fortifier::KeyedValidationError".to_owned() + // )) + // ); + assert_eq!( + validate(quote!(LinkedList)), + Some(KnownOrUnknown::Known( + "::fortifier::IndexedValidationError".to_owned() + )) + ); + assert_eq!( + validate(quote!(Option)), + Some(KnownOrUnknown::Known("TValidationError".to_owned())) + ); + assert_eq!( + validate(quote!(Option>)), + Some(KnownOrUnknown::Known("TValidationError".to_owned())) + ); + assert_eq!( + validate(quote!(Rc)), + Some(KnownOrUnknown::Known("TValidationError".to_owned())) + ); + assert_eq!( + validate(quote!(Vec)), + Some(KnownOrUnknown::Known( + "::fortifier::IndexedValidationError".to_owned() + )) + ); + assert_eq!( + validate(quote!(VecDeque)), + Some(KnownOrUnknown::Known( + "::fortifier::IndexedValidationError".to_owned() + )) + ); + + assert_eq!( + validate(quote!(impl Validate)), + Some(KnownOrUnknown::Unknown) + ); + assert_eq!( + validate(quote!(impl ValidateWithContext)), + Some(KnownOrUnknown::Unknown) + ); + assert_eq!( + validate(quote!(impl ValidateWithContext)), + Some(KnownOrUnknown::Unknown) + ); + assert_eq!( + validate(quote!(impl fortifier::Validate)), + Some(KnownOrUnknown::Unknown) + ); + assert_eq!( + validate(quote!(impl fortifier::ValidateWithContext)), + Some(KnownOrUnknown::Unknown) + ); + assert_eq!( + validate(quote!(impl fortifier::ValidateWithContext)), + Some(KnownOrUnknown::Unknown) + ); + assert_eq!( + validate(quote!(dyn Validate)), + Some(KnownOrUnknown::Unknown) + ); + assert_eq!( + validate(quote!(dyn ValidateWithContext)), + Some(KnownOrUnknown::Unknown) + ); + assert_eq!( + validate(quote!(dyn ValidateWithContext)), + Some(KnownOrUnknown::Unknown) + ); + assert_eq!( + validate(quote!(dyn ::fortifier::Validate)), + Some(KnownOrUnknown::Unknown) + ); + assert_eq!( + validate(quote!(dyn ::fortifier::ValidateWithContext)), + Some(KnownOrUnknown::Unknown) + ); + assert_eq!( + validate(quote!(dyn ::fortifier::ValidateWithContext)), + Some(KnownOrUnknown::Unknown) + ); } #[test] fn should_not_validate() { - assert!(!validate(quote!(bool))); - assert!(!validate(quote!(i8))); - assert!(!validate(quote!(i16))); - assert!(!validate(quote!(i32))); - assert!(!validate(quote!(i64))); - assert!(!validate(quote!(i128))); - assert!(!validate(quote!(isize))); - assert!(!validate(quote!(u8))); - assert!(!validate(quote!(u16))); - assert!(!validate(quote!(u32))); - assert!(!validate(quote!(u64))); - assert!(!validate(quote!(u128))); - assert!(!validate(quote!(usize))); - assert!(!validate(quote!(f32))); - assert!(!validate(quote!(f64))); - assert!(!validate(quote!(char))); - assert!(!validate(quote!(&str))); - assert!(!validate(quote!(String))); - - assert!(!validate(quote!(()))); - assert!(!validate(quote!((bool, bool)))); - assert!(!validate(quote!((usize, usize, usize)))); - assert!(!validate(quote!((usize, &str)))); - - assert!(!validate(quote!([isize]))); - assert!(!validate(quote!([&str; 3]))); - assert!(!validate(quote!(&[isize]))); - assert!(!validate(quote!(&[&str; 3]))); - - assert!(!validate(quote!(Arc<&str>))); - assert!(!validate(quote!(BTreeSet))); - assert!(!validate(quote!(BTreeMap))); - assert!(!validate(quote!(HashSet<&str>))); - assert!(!validate(quote!(HashMap<&str, &str>))); - assert!(!validate(quote!(LinkedList))); - assert!(!validate(quote!(Option))); - assert!(!validate(quote!(Option>))); - assert!(!validate(quote!(Rc<&str>))); - assert!(!validate(quote!(Vec))); - assert!(!validate(quote!(VecDeque))); - - assert!(!validate(quote!(impl Serialize))); - assert!(!validate(quote!(dyn Serialize))); + assert_eq!(validate(quote!(bool)), None); + assert_eq!(validate(quote!(i8)), None); + assert_eq!(validate(quote!(i16)), None); + assert_eq!(validate(quote!(i32)), None); + assert_eq!(validate(quote!(i64)), None); + assert_eq!(validate(quote!(i128)), None); + assert_eq!(validate(quote!(isize)), None); + assert_eq!(validate(quote!(u8)), None); + assert_eq!(validate(quote!(u16)), None); + assert_eq!(validate(quote!(u32)), None); + assert_eq!(validate(quote!(u64)), None); + assert_eq!(validate(quote!(u128)), None); + assert_eq!(validate(quote!(usize)), None); + assert_eq!(validate(quote!(f32)), None); + assert_eq!(validate(quote!(f64)), None); + assert_eq!(validate(quote!(char)), None); + assert_eq!(validate(quote!(&str)), None); + assert_eq!(validate(quote!(String)), None); + + assert_eq!(validate(quote!(())), None); + assert_eq!(validate(quote!((bool, bool))), None); + assert_eq!(validate(quote!((usize, usize, usize))), None); + assert_eq!(validate(quote!((usize, &str))), None); + + assert_eq!(validate(quote!([isize])), None); + assert_eq!(validate(quote!([&str; 3])), None); + assert_eq!(validate(quote!(&[isize])), None); + assert_eq!(validate(quote!(&[&str; 3])), None); + + assert_eq!(validate(quote!(Arc<&str>)), None); + assert_eq!(validate(quote!(BTreeSet)), None); + assert_eq!(validate(quote!(BTreeMap)), None); + assert_eq!(validate(quote!(IndexSet<&str>)), None); + assert_eq!(validate(quote!(IndexMap<&str, &str>)), None); + assert_eq!(validate(quote!(HashSet<&str>)), None); + assert_eq!(validate(quote!(HashMap<&str, &str>)), None); + assert_eq!(validate(quote!(LinkedList)), None); + assert_eq!(validate(quote!(Option)), None); + assert_eq!(validate(quote!(Option>)), None); + assert_eq!(validate(quote!(Rc<&str>)), None); + assert_eq!(validate(quote!(Vec)), None); + assert_eq!(validate(quote!(VecDeque)), None); + + assert_eq!(validate(quote!(impl Serialize)), None); + assert_eq!(validate(quote!(dyn Serialize)), None); } #[test] fn should_validate_with_generics() { - assert!(validate_with_generics( - quote!(T), - Generics { - lt_token: Default::default(), - params: Punctuated::from_iter([ - syn::parse2::(quote!(T: Validate)).expect("valid generic param") - ]), - gt_token: Default::default(), - where_clause: None - } - )); - - assert!(validate_with_generics( - quote!(T), - Generics { - lt_token: Default::default(), - params: Punctuated::from_iter([ - syn::parse2::(quote!(T)).expect("valid generic param") - ]), - gt_token: Default::default(), - where_clause: Some( - syn::parse2(quote!(where T: Validate)).expect("valid where clause") - ) - } - )); + // TODO: Output error type as `::Error` if possible. + + assert_eq!( + validate_with_generics( + quote!(T), + Generics { + lt_token: Default::default(), + params: Punctuated::from_iter([syn::parse2::( + quote!(T: Validate) + ) + .expect("valid generic param")]), + gt_token: Default::default(), + where_clause: None + } + ), + Some(KnownOrUnknown::Unknown) + ); + assert_eq!( + validate_with_generics( + quote!([T]), + Generics { + lt_token: Default::default(), + params: Punctuated::from_iter([syn::parse2::( + quote!(T: Validate) + ) + .expect("valid generic param")]), + gt_token: Default::default(), + where_clause: None + } + ), + Some(KnownOrUnknown::Unknown) + ); + assert_eq!( + validate_with_generics( + quote!(T), + Generics { + lt_token: Default::default(), + params: Punctuated::from_iter([syn::parse2::( + quote!(T: ValidateWithContext) + ) + .expect("valid generic param")]), + gt_token: Default::default(), + where_clause: None + } + ), + Some(KnownOrUnknown::Unknown) + ); + assert_eq!( + validate_with_generics( + quote!(T), + Generics { + lt_token: Default::default(), + params: Punctuated::from_iter([syn::parse2::( + quote!(T: ValidateWithContext) + ) + .expect("valid generic param")]), + gt_token: Default::default(), + where_clause: None + } + ), + Some(KnownOrUnknown::Unknown) + ); + + assert_eq!( + validate_with_generics( + quote!(T), + Generics { + lt_token: Default::default(), + params: Punctuated::from_iter([ + syn::parse2::(quote!(T)).expect("valid generic param") + ]), + gt_token: Default::default(), + where_clause: Some( + syn::parse2(quote!(where T: Validate)).expect("valid where clause") + ) + } + ), + Some(KnownOrUnknown::Unknown) + ); + assert_eq!( + validate_with_generics( + quote!([T]), + Generics { + lt_token: Default::default(), + params: Punctuated::from_iter([ + syn::parse2::(quote!(T)).expect("valid generic param") + ]), + gt_token: Default::default(), + where_clause: Some( + syn::parse2(quote!(where T: Validate)).expect("valid where clause") + ) + } + ), + Some(KnownOrUnknown::Unknown) + ); + assert_eq!( + validate_with_generics( + quote!(T), + Generics { + lt_token: Default::default(), + params: Punctuated::from_iter([ + syn::parse2::(quote!(T)).expect("valid generic param") + ]), + gt_token: Default::default(), + where_clause: Some( + syn::parse2(quote!(where T: ValidateWithContext)) + .expect("valid where clause") + ) + } + ), + Some(KnownOrUnknown::Unknown) + ); + assert_eq!( + validate_with_generics( + quote!(T), + Generics { + lt_token: Default::default(), + params: Punctuated::from_iter([ + syn::parse2::(quote!(T)).expect("valid generic param") + ]), + gt_token: Default::default(), + where_clause: Some( + syn::parse2(quote!(where T: ValidateWithContext)) + .expect("valid where clause") + ) + } + ), + Some(KnownOrUnknown::Unknown) + ); } #[test] fn should_not_validate_with_generics() { - assert!(!validate_with_generics( - quote!(T), - Generics { - lt_token: Default::default(), - params: Punctuated::from_iter([ - syn::parse2::(quote!(T: PartialEq)).expect("valid generic param") - ]), - gt_token: Default::default(), - where_clause: None - } - )); - - assert!(!validate_with_generics( - quote!(T), - Generics { - lt_token: Default::default(), - params: Punctuated::from_iter([ - syn::parse2::(quote!(T)).expect("valid generic param") - ]), - gt_token: Default::default(), - where_clause: Some( - syn::parse2(quote!(where T: PartialEq)).expect("valid where clause") - ) - } - )); + assert_eq!( + validate_with_generics( + quote!(T), + Generics { + lt_token: Default::default(), + params: Punctuated::from_iter([syn::parse2::( + quote!(T: Serialize) + ) + .expect("valid generic param")]), + gt_token: Default::default(), + where_clause: None + } + ), + None + ); + + assert_eq!( + validate_with_generics( + quote!(T), + Generics { + lt_token: Default::default(), + params: Punctuated::from_iter([ + syn::parse2::(quote!(T)).expect("valid generic param") + ]), + gt_token: Default::default(), + where_clause: Some( + syn::parse2(quote!(where T: Serialize)).expect("valid where clause") + ) + } + ), + None + ); } }