├── .github └── workflows │ └── rust.yml ├── .gitignore ├── Cargo.toml ├── README.md ├── src ├── lib.rs ├── normalize.rs └── substitute.rs └── tests ├── large.rs ├── large2.rs ├── no_self_arg.rs └── test.rs /.github/workflows/rust.yml: -------------------------------------------------------------------------------- 1 | name: Rust 2 | 3 | on: 4 | push: 5 | branches: [ "main" ] 6 | pull_request: 7 | branches: [ "main" ] 8 | 9 | env: 10 | CARGO_TERM_COLOR: always 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ubuntu-latest 16 | 17 | steps: 18 | - uses: actions/checkout@v4 19 | - name: Build 20 | run: cargo build --verbose 21 | - name: Run tests 22 | run: cargo test --verbose 23 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | Cargo.lock 3 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "min-specialization" 3 | version = "0.1.2" 4 | description = "Experimental implementation of specialization" 5 | license = "MIT" 6 | authors = ["Yasuo Ozu "] 7 | keywords = ["macros", "zerocost", "specialization"] 8 | categories = ["rust-patterns"] 9 | repository = "https://github.com/yasuo-ozu/min_specialization" 10 | edition = "2021" 11 | 12 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 13 | [lib] 14 | proc-macro = true 15 | 16 | [dependencies] 17 | proc-macro2 = "1.0" 18 | template-quote = "0.4" 19 | proc-macro-error = "1.0" 20 | derive-syn-parse = "0.1.5" 21 | 22 | [dependencies.syn] 23 | version = "2.0" 24 | features = [ "full", "derive", "printing", "extra-traits", "visit-mut", "visit"] 25 | -------------------------------------------------------------------------------- /README.md: 
-------------------------------------------------------------------------------- 1 | # min-specialization [![Latest Version]][crates.io] [![Documentation]][docs.rs] [![GitHub Actions]][actions] 2 | 3 | [Latest Version]: https://img.shields.io/crates/v/min-specialization.svg 4 | [crates.io]: https://crates.io/crates/min-specialization 5 | [Documentation]: https://img.shields.io/docsrs/min-specialization 6 | [docs.rs]: https://docs.rs/min-specialization/latest/min_specialization/ 7 | [GitHub Actions]: https://github.com/yasuo-ozu/min_specialization/actions/workflows/rust.yml/badge.svg 8 | [actions]: https://github.com/yasuo-ozu/min_specialization/actions/workflows/rust.yml 9 | 10 | Rust's specialization feature allows you to provide a default implementation of a trait for generic types and then specialize it for specific types. This feature is currently unstable and only available on nightly Rust. 11 | 12 | This crate emulates Rust's unstable `#![feature(min_specialization)]` on stable Rust. 13 | 14 | # Example 15 | 16 | ``` 17 | # use min_specialization::specialization; 18 | #[specialization] 19 | mod inner { 20 | #[allow(unused)] 21 | trait Trait { 22 | type Ty; 23 | fn number(_: U) -> Self::Ty; 24 | } 25 | 26 | impl Trait for T { 27 | type Ty = usize; 28 | default fn number(_: U) -> Self::Ty { 29 | 0 30 | } 31 | } 32 | 33 | impl Trait for () { 34 | fn number(_: U) -> Self::Ty { 35 | 1 36 | } 37 | } 38 | } 39 | ``` 40 | 41 | See the `tests` directory for more examples.
42 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | #![doc = include_str!("../README.md")] 2 | 3 | mod normalize; 4 | mod substitute; 5 | 6 | use normalize::WherePredicateBinding; 7 | use proc_macro::TokenStream as TokenStream1; 8 | use proc_macro2::{Span, TokenStream}; 9 | use proc_macro_error::{abort, proc_macro_error}; 10 | use std::collections::{HashMap, HashSet}; 11 | use substitute::{Substitute, SubstituteEnvironment}; 12 | use syn::punctuated::Punctuated; 13 | use syn::spanned::Spanned; 14 | use syn::visit_mut::VisitMut; 15 | use syn::*; 16 | use template_quote::quote; 17 | /// Replaces every occurrence of the type `from` with `to` inside a trait method declaration. 18 | fn replace_type_of_trait_item_fn(mut ty: TraitItemFn, from: &Type, to: &Type) -> TraitItemFn { 19 | use syn::visit_mut::VisitMut; 20 | struct Visitor<'a>(&'a Type, &'a Type); 21 | impl<'a> VisitMut for Visitor<'a> { 22 | fn visit_type_mut(&mut self, ty: &mut Type) { 23 | if &ty == &self.0 { 24 | *ty = self.1.clone(); 25 | } 26 | syn::visit_mut::visit_type_mut(self, ty) 27 | } 28 | } 29 | let mut visitor = Visitor(from, to); 30 | visitor.visit_trait_item_fn_mut(&mut ty); 31 | ty 32 | } 33 | /// Returns `None` if `default` is used where it is unsupported (on the impl itself, a const or a type); otherwise `Some(true)` if some method is `default` and `Some(false)` if none is. 34 | fn check_defaultness(item_impl: &ItemImpl) -> Option { 35 | let mut ret = false; 36 | // does not support impl-level default keyword 37 | if item_impl.defaultness.is_some() { 38 | return None; 39 | } 40 | for item in item_impl.items.iter() { 41 | match item { 42 | ImplItem::Const(item_const) if item_const.defaultness.is_some() => { 43 | return None; 44 | } 45 | ImplItem::Fn(item_method) if item_method.defaultness.is_some() => { 46 | ret = true; 47 | } 48 | ImplItem::Type(item_type) if item_type.defaultness.is_some() => { 49 | return None; 50 | } 51 | _ => (), 52 | } 53 | } 54 | Some(ret) 55 | } 56 | /// Normalizes an impl's generics into bound-free params plus single-bound predicates gathered from both the param list and the where clause. 57 | fn normalize_params_and_predicates( 58 | impl_: &ItemImpl, 59 | ) -> (HashSet, HashSet) { 60 | let (mut gps, mut wps) = (HashSet::new(), HashSet::new()); 61 | for
gp in impl_.generics.params.iter() { 62 | let (gp, nwps) = normalize::normalize_generic_param(gp.clone()); 63 | gps.insert(gp); 64 | wps.extend(nwps); 65 | } 66 | if let Some(wc) = &impl_.generics.where_clause { 67 | for p in wc.predicates.iter() { 68 | let nwps = normalize::normalize_where_predicate(p.clone()); 69 | wps.extend(nwps); 70 | } 71 | } 72 | (gps, wps) 73 | } 74 | /// Returns the identifier of a type parameter; lifetime and const params yield `None`. 75 | fn get_param_ident(p: GenericParam) -> Option { 76 | match p { 77 | GenericParam::Type(tp) => Some(tp.ident), 78 | _ => None, 79 | } 80 | } 81 | /// Returns the identifier when `ty` is a plain single-identifier path such as `T` (no qualified self). 82 | fn get_type_ident(ty: Type) -> Option { 83 | match ty { 84 | Type::Path(tp) if tp.qself.is_none() => tp.path.get_ident().cloned(), 85 | _ => None, 86 | } 87 | } 88 | /// Reports whether the bare identifier type `ident` occurs anywhere inside `ty`. 89 | fn find_type_ident(ty: &Type, ident: &Ident) -> bool { 90 | use syn::visit::Visit; 91 | struct Visitor<'a>(&'a Ident, bool); 92 | impl<'ast, 'a> Visit<'ast> for Visitor<'a> { 93 | fn visit_type(&mut self, i: &'ast Type) { 94 | match i { 95 | Type::Path(tp) if tp.qself.is_none() && tp.path.get_ident() == Some(&self.0) => { 96 | self.1 = true; 97 | } 98 | _ => { 99 | syn::visit::visit_type(self, i); 100 | } 101 | } 102 | } 103 | } 104 | let mut vis = Visitor(ident, false); 105 | vis.visit_type(ty); 106 | vis.1 107 | } 108 | /// Collects `(key, special param)` pairs for the entries of `substitution` whose replacement is exactly one of `special_params`. 109 | fn get_trivial_substitutions( 110 | special_params: &HashSet, 111 | substitution: &HashMap, 112 | ) -> Vec<(Ident, Ident)> { 113 | substitution 114 | .iter() 115 | .filter_map(|(d, s)| { 116 | get_type_ident(s.clone()) 117 | .and_then(|i| special_params.iter().find(|ii| &&i == ii).cloned()) 118 | .map(|s| (d.clone(), s)) 119 | }) 120 | .collect() 121 | } 122 | /// Lists the ways the default impl's type params can be instantiated so that it matches the special impl, each paired with its number of non-trivial substitutions (`specialize_trait` picks the smallest). 123 | fn substitute_impl( 124 | default_impl: &ItemImpl, 125 | special_impl: &ItemImpl, 126 | ) -> Vec<(HashMap, usize)> { 127 | let (d_ps, d_ws) = normalize_params_and_predicates(default_impl); 128 | let (s_ps, s_ws) = normalize_params_and_predicates(special_impl); 129 | // Replace `Self` in each impl's predicates with that impl's concrete self type 130 | let self_ident = Ident::new("Self", Span::call_site()); 131 | let d_ws = d_ws 132 | .into_iter()
.map(|w| { 134 | w.replace_type_params( 135 | core::iter::once((self_ident.clone(), default_impl.self_ty.as_ref().clone())) 136 | .collect(), 137 | ) 138 | }) 139 | .collect::>(); 140 | let s_ws = s_ws 141 | .into_iter() 142 | .map(|w| { 143 | w.replace_type_params( 144 | core::iter::once((self_ident.clone(), special_impl.self_ty.as_ref().clone())) 145 | .collect(), 146 | ) 147 | }) 148 | .collect::>(); 149 | let s_ps: HashSet<_> = s_ps.into_iter().filter_map(get_param_ident).collect(); 150 | let env = SubstituteEnvironment { 151 | general_params: d_ps.into_iter().filter_map(get_param_ident).collect(), 152 | }; 153 | let s = env.substitute(&d_ws, &s_ws) 154 | * env.substitute( 155 | &default_impl.trait_.as_ref().unwrap().1, 156 | &special_impl.trait_.as_ref().unwrap().1, 157 | ) 158 | * env.substitute(&*default_impl.self_ty, &*special_impl.self_ty); 159 | // Keep only substitutions whose replacement types are either exactly a special-impl type param or do not mention any of them 160 | s.0.into_iter() 161 | .filter(|m| { 162 | m.iter().all(|(_, ty)| { 163 | s_ps.iter().all(|i| { 164 | &get_type_ident(ty.clone()).as_ref() == &Some(i) || !find_type_ident(ty, &i) 165 | }) 166 | }) 167 | }) 168 | .map(|r| { 169 | ( 170 | r.clone(), 171 | r.len() - get_trivial_substitutions(&s_ps, &r).len(), 172 | ) 173 | }) 174 | .collect() 175 | } 176 | /// Replaces type parameters (keyed by identifier) with concrete types throughout a syntax node. 177 | trait ReplaceTypeParams { 178 | fn replace_type_params(self, map: HashMap) -> Self; 179 | } 180 | // Shared visitor for the `ReplaceTypeParams` impls; type params re-declared by nested items shadow the substitution map. 181 | const _: () = { 182 | fn filter_map_with_generics( 183 | map: &HashMap, 184 | generics: &Generics, 185 | ) -> HashMap { 186 | map.clone() 187 | .into_iter() 188 | .filter(|(k, _)| { 189 | generics 190 | .params 191 | .iter() 192 | .filter_map(|o| { 193 | if let GenericParam::Type(pt) = o { 194 | Some(&pt.ident) 195 | } else { 196 | None 197 | } 198 | }) 199 | .all(|id| k != id) 200 | }) 201 | .collect() 202 | } 203 | #[derive(Clone)] 204 | struct Visitor(HashMap); 205 | impl VisitMut for Visitor { 206 | fn visit_type_mut(&mut self, i: &mut Type) { 207 | if let Type::Path(tp) = i { // Only unqualified paths: in `<X>::T`, `T` names an associated item, not a type param 208 |
if let (None, Some(id)) = (&tp.qself, tp.path.get_ident()) { 209 | if let Some(replaced) = self.0.get(id) { 210 | *i = replaced.clone(); 211 | return; 212 | } 213 | } 214 | } 215 | syn::visit_mut::visit_type_mut(self, i) 216 | } 217 | fn visit_item_fn_mut(&mut self, i: &mut ItemFn) { 218 | let mut this = Visitor(filter_map_with_generics(&self.0, &i.sig.generics)); 219 | syn::visit_mut::visit_item_fn_mut(&mut this, i); 220 | } 221 | fn visit_item_impl_mut(&mut self, i: &mut ItemImpl) { 222 | let mut this = Visitor(filter_map_with_generics(&self.0, &i.generics)); 223 | syn::visit_mut::visit_item_impl_mut(&mut this, i); 224 | } 225 | fn visit_item_trait_mut(&mut self, i: &mut ItemTrait) { 226 | let mut this = Visitor(filter_map_with_generics(&self.0, &i.generics)); 227 | syn::visit_mut::visit_item_trait_mut(&mut this, i); 228 | } 229 | fn visit_item_struct_mut(&mut self, i: &mut ItemStruct) { 230 | let mut this = Visitor(filter_map_with_generics(&self.0, &i.generics)); 231 | syn::visit_mut::visit_item_struct_mut(&mut this, i); 232 | } 233 | fn visit_item_enum_mut(&mut self, i: &mut ItemEnum) { 234 | let mut this = Visitor(filter_map_with_generics(&self.0, &i.generics)); 235 | syn::visit_mut::visit_item_enum_mut(&mut this, i); 236 | } 237 | fn visit_item_type_mut(&mut self, i: &mut ItemType) { 238 | let mut this = Visitor(filter_map_with_generics(&self.0, &i.generics)); 239 | syn::visit_mut::visit_item_type_mut(&mut this, i); 240 | } 241 | fn visit_item_union_mut(&mut self, i: &mut ItemUnion) { 242 | let mut this = Visitor(filter_map_with_generics(&self.0, &i.generics)); 243 | syn::visit_mut::visit_item_union_mut(&mut this, i); 244 | } 245 | } 246 | 247 | impl ReplaceTypeParams for WherePredicateBinding { 248 | fn replace_type_params(self, map: HashMap) -> Self { 249 | match self { 250 | WherePredicateBinding::Lifetime(lt) => { 251 | WherePredicateBinding::Lifetime(lt.replace_type_params(map)) 252 | } 253 | WherePredicateBinding::Type(pt) => { 254 |
WherePredicateBinding::Type(pt.replace_type_params(map)) 255 | } 256 | WherePredicateBinding::Eq { 257 | lhs_ty, 258 | eq_token, 259 | rhs_ty, 260 | } => WherePredicateBinding::Eq { 261 | lhs_ty: lhs_ty.replace_type_params(map.clone()), 262 | eq_token, 263 | rhs_ty: rhs_ty.replace_type_params(map), 264 | }, 265 | } 266 | } 267 | } 268 | impl ReplaceTypeParams for PredicateType { 269 | fn replace_type_params(mut self, map: HashMap) -> Self { 270 | let mut visitor = Visitor(map); 271 | visitor.visit_predicate_type_mut(&mut self); 272 | self 273 | } 274 | } 275 | impl ReplaceTypeParams for PredicateLifetime { 276 | fn replace_type_params(mut self, map: HashMap) -> Self { 277 | let mut visitor = Visitor(map); 278 | visitor.visit_predicate_lifetime_mut(&mut self); 279 | self 280 | } 281 | } 282 | impl ReplaceTypeParams for ImplItemFn { 283 | fn replace_type_params(mut self, map: HashMap) -> Self { 284 | let mut visitor = Visitor(map); 285 | visitor.visit_impl_item_fn_mut(&mut self); 286 | self 287 | } 288 | } 289 | impl ReplaceTypeParams for Type { 290 | fn replace_type_params(mut self, map: HashMap) -> Self { 291 | let mut visitor = Visitor(map); 292 | visitor.visit_type_mut(&mut self); 293 | self 294 | } 295 | } 296 | }; 297 | /// Reports whether `param` is referenced in `ty`, either as a lifetime or as a type/const param written as a bare path. 298 | fn contains_generics_param(param: &GenericParam, ty: &Type) -> bool { 299 | use syn::visit::Visit; 300 | struct Visitor<'a>(&'a GenericParam, bool); 301 | impl<'ast, 'a> Visit<'ast> for Visitor<'a> { 302 | fn visit_lifetime(&mut self, i: &Lifetime) { 303 | if matches!(&self.0, GenericParam::Lifetime(l) if &l.lifetime == i) { 304 | self.1 = true; 305 | } 306 | } 307 | fn visit_type_path(&mut self, i: &TypePath) { 308 | if matches!( 309 | (&self.0, &i.qself, i.path.get_ident()), 310 | (GenericParam::Type(TypeParam {ident, ..}), &None, Some(id)) | 311 | (GenericParam::Const(ConstParam {ident, ..}), &None, Some(id)) 312 | if ident == id 313 | ) { 314 | self.1 = true; 315 | } else { 316 | syn::visit::visit_type_path(self, i) 317 | } 318 | } 319 |
} 320 | let mut visitor = Visitor(param, false); 321 | visitor.visit_type(ty); 322 | visitor.1 323 | } 324 | /// Emits a private helper trait that declares `fn_ident` (with `impl_`'s self type rewritten to `Self`) and implements it for `self_ty`; also returns the helper trait's generic params. 325 | fn specialize_item_fn_trait( 326 | impl_: &ItemImpl, 327 | ident: &Ident, 328 | fn_ident: &Ident, 329 | impl_item_fn: &ImplItemFn, 330 | needs_sized_bound: bool, 331 | self_ty: &Type, 332 | ) -> (TokenStream, Punctuated) { 333 | let trait_path = &impl_.trait_.as_ref().unwrap().1; 334 | let impl_generics: Punctuated<_, Token![,]> = impl_ 335 | .generics 336 | .params 337 | .iter() 338 | .filter(|p| { 339 | contains_generics_param( 340 | p, 341 | &Type::Path(TypePath { 342 | qself: None, 343 | path: trait_path.clone(), 344 | }), 345 | ) || contains_generics_param(p, self_ty) 346 | }) 347 | .cloned() 348 | .collect(); 349 | let ty_generics: Punctuated<_, Token![,]> = impl_ 350 | .generics 351 | .params 352 | .iter() 353 | .filter(|p| { 354 | contains_generics_param( 355 | p, 356 | &Type::Path(TypePath { 357 | qself: None, 358 | path: trait_path.clone(), 359 | }), 360 | ) 361 | }) 362 | .map(|p| { 363 | let mut p = p.clone(); 364 | match &mut p { 365 | GenericParam::Lifetime(p) => { 366 | p.attrs = Vec::new(); 367 | p.colon_token = None; 368 | p.bounds = Punctuated::new(); 369 | } 370 | GenericParam::Type(t) => { 371 | t.attrs = Vec::new(); 372 | t.colon_token = None; 373 | t.bounds = Punctuated::new(); 374 | t.eq_token = None; 375 | t.default = None; 376 | } 377 | GenericParam::Const(c) => { 378 | c.attrs = Vec::new(); 379 | c.eq_token = None; 380 | c.default = None; 381 | } 382 | } 383 | p 384 | }) 385 | .collect(); 386 | let mut item_fn = replace_type_of_trait_item_fn( 387 | TraitItemFn { 388 | attrs: vec![], 389 | sig: impl_item_fn.sig.clone(), 390 | default: None, 391 | semi_token: Some(Default::default()), 392 | }, 393 | &impl_.self_ty, 394 | &parse_quote!(Self), 395 | ); 396 | item_fn.sig.ident = fn_ident.clone(); 397 | let mut impl_item_fn = impl_item_fn.clone(); 398 | impl_item_fn.defaultness = None; 399 | impl_item_fn.sig.ident =
fn_ident.clone(); 400 | let out = quote! { 401 | trait #ident<#ty_generics>: #trait_path 402 | #(if needs_sized_bound) { + ::core::marker::Sized } 403 | { 404 | #item_fn 405 | } 406 | impl<#impl_generics> #ident<#ty_generics> for #self_ty 407 | #{&impl_.generics.where_clause} 408 | { 409 | #impl_item_fn 410 | } 411 | }; 412 | (out, ty_generics) 413 | } 414 | /// Gives every `_` argument the name `_min_specialization_v{n}` (n = argument index) so it can be forwarded to the generated helper calls. 415 | fn set_argument_named(sig: &mut Signature) { 416 | for (n, arg) in sig.inputs.iter_mut().enumerate() { 417 | if let FnArg::Typed(PatType { pat, .. }) = arg { 418 | if let Pat::Wild(_) = &**pat { 419 | *pat = Box::new(Pat::Ident(PatIdent { 420 | attrs: Vec::new(), 421 | by_ref: None, 422 | mutability: None, 423 | ident: Ident::new(&format!("_min_specialization_v{}", n), pat.span()), 424 | subpat: None, 425 | })); 426 | } 427 | } 428 | } 429 | } 430 | /// Rewrites a method so it runs the first specialized body whose runtime type-identity check (a function-pointer comparison) holds, falling back to the default body; each body is moved into its own helper trait. 431 | fn specialize_item_fn( 432 | default_impl: &ItemImpl, 433 | mut ifn: ImplItemFn, 434 | specials: Vec<(HashMap, ItemImpl, ImplItemFn)>, 435 | needs_sized_bound: bool, 436 | ) -> ImplItemFn { 437 | let itrait_name = Ident::new("__MinSpecialization_InnerTrait", Span::call_site()); 438 | let ifn_name = Ident::new("__min_specialization__inner_fn", Span::call_site()); 439 | set_argument_named(&mut ifn.sig); 440 | let specials_out = specials 441 | .into_iter() 442 | .enumerate() 443 | .map(|(n, (m, simpl, mut sfn))| { 444 | let strait_name = Ident::new( 445 | &format!("__MinSpecialization_InnerTrait_{}", n), 446 | Span::call_site(), 447 | ); 448 | let sfn_name = Ident::new( 449 | &format!("__min_specialization__inner_fn_{}", n), 450 | Span::call_site(), 451 | ); 452 | sfn.sig.ident = sfn_name.clone(); 453 | let mut condition = quote!
{true}; 454 | let mut replacement = HashMap::new(); 455 | for (lhs, rhs) in m.iter() { 456 | if let Some(rhs) = get_type_ident(rhs.clone()) { 457 | if simpl 458 | .generics 459 | .params 460 | .iter() 461 | .filter_map(|p| { 462 | if let GenericParam::Type(p) = p { 463 | Some(&p.ident) 464 | } else { 465 | None 466 | } 467 | }) 468 | .any(|p| p == &rhs) 469 | { 470 | let lhs = Type::Path(TypePath { 471 | qself: None, 472 | path: Path { 473 | leading_colon: None, 474 | segments: Some(PathSegment { 475 | ident: lhs.clone(), 476 | arguments: PathArguments::None, 477 | }) 478 | .into_iter() 479 | .collect(), 480 | }, 481 | }); 482 | replacement.insert(rhs, lhs); 483 | continue; 484 | } 485 | } 486 | condition.extend(quote! { 487 | && __min_specialization_id::<#lhs> as *const () 488 | == __min_specialization_id::<#rhs> as *const () 489 | }); 490 | } 491 | let sfn = sfn.replace_type_params(replacement.clone()); 492 | let replaced_self_ty = default_impl.self_ty.clone().replace_type_params(m.clone()); 493 | let (special_trait_impl, special_trait_params) = specialize_item_fn_trait( 494 | default_impl, 495 | &strait_name, 496 | &sfn_name, 497 | &sfn, 498 | needs_sized_bound, 499 | &replaced_self_ty, 500 | ); 501 | quote!
{ 502 | if #condition { 503 | #special_trait_impl 504 | __min_specialization_transmute( 505 | <#replaced_self_ty as #strait_name< 506 | #(for par in &special_trait_params), { 507 | #(if let GenericParam::Type(TypeParam{ident, ..}) = par) { 508 | #(if let Some(ident) = replacement.get(ident)) { 509 | #ident 510 | } 511 | #(else) { 512 | #ident 513 | } 514 | } #(else) { 515 | #par 516 | } 517 | } 518 | >>::#sfn_name( 519 | #(for arg in &ifn.sig.inputs), { 520 | #(if let FnArg::Receiver(_) = arg) { 521 | __min_specialization_transmute(self) 522 | } 523 | #(if let FnArg::Typed(pt) = arg) { 524 | __min_specialization_transmute(#{&pt.pat}) 525 | } 526 | } 527 | ) 528 | ) 529 | } else 530 | } 531 | }) 532 | .collect::>(); 533 | let (default_trait_impl, default_trait_params) = specialize_item_fn_trait( 534 | default_impl, 535 | &itrait_name, 536 | &ifn_name, 537 | &ifn, 538 | needs_sized_bound, 539 | &default_impl.self_ty, 540 | ); 541 | let inner = quote! { 542 | #(for attr in &ifn.attrs) {#attr} 543 | #{&ifn.vis} 544 | #{&ifn.sig} 545 | { 546 | fn __min_specialization_id(input: &T) -> !
{ 547 | unsafe { 548 | let _ = ::core::mem::MaybeUninit::new( 549 | ::core::ptr::read_volatile(input as *const _) 550 | ); 551 | } 552 | ::core::panic!() 553 | } 554 | fn __min_specialization_transmute(input: T) -> U { 555 | ::core::assert_eq!( 556 | ::core::mem::size_of::(), 557 | ::core::mem::size_of::() 558 | ); 559 | ::core::assert_eq!( 560 | ::core::mem::align_of::(), 561 | ::core::mem::align_of::() 562 | ); 563 | let mut rhs = ::core::mem::MaybeUninit::new(input); 564 | let mut lhs = ::core::mem::MaybeUninit::::uninit(); 565 | unsafe { 566 | let rhs = ::core::mem::transmute::< 567 | _, &mut ::core::mem::MaybeUninit 568 | >(&mut rhs); 569 | ::core::ptr::swap(lhs.as_mut_ptr(), rhs.as_mut_ptr()); 570 | lhs.assume_init() 571 | } 572 | } 573 | #( #specials_out)* 574 | { 575 | #default_trait_impl 576 | <#{&default_impl.self_ty} as #itrait_name<#default_trait_params>>::#ifn_name( 577 | #(for arg in &ifn.sig.inputs),{ 578 | #(if let FnArg::Receiver(Receiver{self_token, ..}) = arg) { 579 | #self_token 580 | } 581 | #(if let FnArg::Typed(PatType{pat, ..}) = arg) { 582 | #pat 583 | } 584 | } 585 | ) 586 | } 587 | } 588 | }; 589 | parse2(inner).unwrap_or_else(|e| abort!(e.span(), "min_specialization: failed to parse the generated method: {}", e)) // e.g. a `mut x` argument pattern forwarded as an expression; report it as a spanned error instead of an opaque panic 590 | } 591 | /// Reports whether any method has a parameter or return type that is exactly `Self` or the impl's self type; if so, the generated helper traits get a `Sized` supertrait bound. 592 | fn check_needs_sized_bound(impl_: &ItemImpl) -> bool { 593 | impl_ 594 | .items 595 | .iter() 596 | .filter_map(|item| { 597 | if let ImplItem::Fn(item) = item { 598 | Some(item) 599 | } else { 600 | None 601 | } 602 | }) 603 | .any(|item| { 604 | item.sig 605 | .inputs 606 | .iter() 607 | .filter_map(|item| { 608 | if let FnArg::Typed(PatType { ty, ..
}) = item { 609 | Some(&*ty) 610 | } else { 611 | None 612 | } 613 | }) 614 | .chain(if let ReturnType::Type(_, ty) = &item.sig.output { 615 | Some(&*ty) 616 | } else { 617 | None 618 | }) 619 | .any(|ty| ty == &impl_.self_ty || ty == &parse_quote!(Self)) 620 | }) 621 | } 622 | /// Rewrites every method of `default_impl` to dispatch to the same-named methods of `special_impls`; aborts if a special impl contains a non-method item. 623 | fn specialize_impl( 624 | mut default_impl: ItemImpl, 625 | special_impls: Vec<(ItemImpl, HashMap)>, 626 | ) -> ItemImpl { 627 | if special_impls.is_empty() { 628 | return default_impl; 629 | } 630 | let needs_sized_bound = check_needs_sized_bound(&default_impl); 631 | let mut fn_map = HashMap::new(); 632 | for (simpl, ssub) in special_impls.into_iter() { 633 | for item in simpl.items.iter() { 634 | match item { 635 | ImplItem::Fn(ifn) => { 636 | fn_map 637 | .entry(ifn.sig.ident.clone()) 638 | .or_insert(Vec::new()) 639 | .push((ssub.clone(), simpl.clone(), ifn.clone())); 640 | } 641 | o => abort!(o.span(), "This item cannot be specialized"), 642 | } 643 | } 644 | } 645 | let mut out = Vec::new(); 646 | for item in &default_impl.items { 647 | match item { 648 | ImplItem::Fn(ifn) => { 649 | let specials = fn_map.get(&ifn.sig.ident).cloned().unwrap_or_default(); 650 | out.push(ImplItem::Fn(specialize_item_fn( 651 | &default_impl, 652 | ifn.clone(), 653 | specials, 654 | needs_sized_bound, 655 | ))); 656 | } 657 | o => out.push(o.clone()), 658 | } 659 | } 660 | default_impl.items = out; 661 | default_impl 662 | } 663 | /// Attaches each special impl to the default impl it matches with the fewest non-trivial substitutions; special impls that match no default impl are returned unchanged as orphans. 664 | fn specialize_trait( 665 | default_impls: HashSet, 666 | special_impls: HashSet, 667 | ) -> (Vec, Vec) { 668 | let mut default_map: HashMap<_, _> = default_impls 669 | .iter() 670 | .cloned() 671 | .map(|d| (d, Vec::new())) 672 | .collect(); 673 | let mut orphan_impls = Vec::new(); 674 | for s in special_impls.into_iter() { 675 | if let Some((d, a, _)) = default_impls 676 | .iter() 677 | .map(|d| { 678 | substitute_impl(d, &s) 679 | .into_iter() 680 | .map(move |(sub, n)| (d, sub, n)) 681 | }) 682 | .flatten() 683 | .min_by_key(|(_, _, n)| *n) 684 | { 685 |
default_map 686 | .get_mut(d) 687 | .expect("`default_map` is pre-filled with every default impl") 688 | .push((s, a)); 689 | } else { 690 | orphan_impls.push(s); 691 | } 692 | } 693 | ( 694 | default_map 695 | .into_iter() 696 | .map(|(d, s)| specialize_impl(d, s)) 697 | .collect(), 698 | orphan_impls, 699 | ) 700 | } 701 | /// Expands the module: trait impls with a `default fn` become default impls, other trait impls become specialization candidates, and all remaining items (including impls with an unsupported `default`) are emitted unchanged. 702 | fn specialization_mod(module: ItemMod) -> TokenStream { 703 | let (_, content) = if let Some(inner) = module.content { 704 | inner 705 | } else { 706 | abort!(module.span(), "Require mod content") 707 | }; 708 | let (mut defaults, mut specials): (HashSet<_>, HashSet<_>) = Default::default(); 709 | let mut generated_content = Vec::new(); 710 | for item in content.into_iter() { 711 | if let Item::Impl(item_impl) = &item { 712 | if item_impl.trait_.is_some() { 713 | if let Some(defaultness) = check_defaultness(item_impl) { 714 | if defaultness { 715 | defaults.insert(item_impl.clone()); 716 | } else { 717 | specials.insert(item_impl.clone()); 718 | } 719 | continue; 720 | } 721 | } 722 | } 723 | generated_content.push(item); 724 | } 725 | let (impls, orphans) = specialize_trait(defaults, specials); 726 | generated_content.extend(impls.into_iter().map(Item::Impl)); 727 | generated_content.extend(orphans.into_iter().map(Item::Impl)); 728 | 729 | quote!
{ 730 | #(for attr in &module.attrs) { #attr } 731 | #{&module.vis} 732 | #{&module.mod_token} 733 | #{&module.ident} 734 | { 735 | #(#generated_content)* 736 | } 737 | } 738 | } 739 | /// Emulates `#![feature(min_specialization)]` for the trait impls inside the annotated inline module (see the crate-level docs for an example). 740 | #[proc_macro_error] 741 | #[proc_macro_attribute] 742 | pub fn specialization(_attr: TokenStream1, input: TokenStream1) -> TokenStream1 { 743 | let module = parse_macro_input!(input); 744 | specialization_mod(module).into() 745 | } 746 | -------------------------------------------------------------------------------- /src/normalize.rs: -------------------------------------------------------------------------------- 1 | use syn::punctuated::Punctuated; 2 | use syn::*; 3 | /// A where-clause predicate reduced to a single bound, or to a single associated-type equality. 4 | #[derive(Debug, Clone, PartialEq, Eq, Hash)] 5 | pub enum WherePredicateBinding { 6 | Lifetime(PredicateLifetime), 7 | Type(PredicateType), 8 | Eq { 9 | lhs_ty: Type, 10 | eq_token: Token![=], 11 | rhs_ty: Type, 12 | }, 13 | } 14 | /// Splits a where-predicate into single-bound entries; an associated-type binding `T: Tr<A = X>` becomes an `Eq` entry on `<T as Tr>::A`, and a constraint `T: Tr<A: B>` becomes a bound on it. 15 | pub fn normalize_where_predicate(pred: WherePredicate) -> Vec { 16 | let mut ret = Vec::new(); 17 | match pred { 18 | WherePredicate::Type(pt) => { 19 | for mut bound in pt.bounds.into_iter() { 20 | let additives = if let TypeParamBound::Trait(tb) = &mut bound { 21 | remove_path_predicates(&mut tb.path) 22 | } else { 23 | Vec::new() 24 | }; 25 | for (bound_trait, bindings, constraints) in additives.into_iter() { 26 | for AssocType { 27 | ident, 28 | ty, 29 | eq_token, 30 | .. 31 | } in bindings.into_iter() 32 | { 33 | let mut path = bound_trait.clone(); 34 | path.segments.push(ident.into()); 35 | let lhs_ty = Type::Path(TypePath { 36 | qself: Some(QSelf { 37 | lt_token: Default::default(), 38 | ty: Box::new(pt.bounded_ty.clone()), 39 | position: bound_trait.segments.len(), // number of trait segments = index of the associated item in `path` 40 | as_token: Default::default(), 41 | gt_token: Default::default(), 42 | }), 43 | path, 44 | }); 45 | ret.push(WherePredicateBinding::Eq { 46 | lhs_ty, 47 | eq_token, 48 | rhs_ty: ty, 49 | }); 50 | } 51 | for Constraint { 52 | ident, 53 | bounds, 54 | colon_token, 55 | ..
56 | } in constraints.into_iter() 57 | { 58 | for bound in bounds.into_iter() { 59 | let mut path = bound_trait.clone(); 60 | path.segments.push(ident.clone().into()); 61 | let bounded_ty = Type::Path(TypePath { 62 | qself: Some(QSelf { 63 | lt_token: Default::default(), 64 | ty: Box::new(pt.bounded_ty.clone()), 65 | position: bound_trait.segments.len(), // number of trait segments = index of the associated item in `path` 66 | as_token: Default::default(), 67 | gt_token: Default::default(), 68 | }), 69 | path, 70 | }); 71 | ret.push(WherePredicateBinding::Type(PredicateType { 72 | lifetimes: Default::default(), 73 | bounded_ty, 74 | colon_token: colon_token.clone(), 75 | bounds: Some(bound).into_iter().collect(), 76 | })); 77 | } 78 | } 79 | } 80 | ret.push(WherePredicateBinding::Type(PredicateType { 81 | lifetimes: pt.lifetimes.clone(), 82 | bounded_ty: pt.bounded_ty.clone(), 83 | colon_token: pt.colon_token.clone(), 84 | bounds: Some(bound.clone()).into_iter().collect(), 85 | })); 86 | } 87 | } 88 | WherePredicate::Lifetime(lt) => { 89 | for bound in lt.bounds.iter() { 90 | ret.push(WherePredicateBinding::Lifetime(PredicateLifetime { 91 | lifetime: lt.lifetime.clone(), 92 | colon_token: lt.colon_token.clone(), 93 | bounds: Some(bound.clone()).into_iter().collect(), 94 | })); 95 | } 96 | } 97 | // syn 2 has no `WherePredicate::Eq` variant: equality constraints 98 | // can only arise from associated-type bindings such as 99 | // `T: Trait<Assoc = X>`, which the `Type` arm above already lowers 100 | // to `WherePredicateBinding::Eq`. 101 | // Any other predicate kind is not supported and aborts the expansion 102 | // with an explicit message. 103 | _ => panic!("min_specialization: unsupported where-clause predicate"), 104 | } 105 | ret 106 | } 107 | /// Strips the inline bounds from a generic param, returning the bare param and its bounds as normalized predicates; const params pass through unchanged. 108 | pub fn normalize_generic_param(param: GenericParam) -> (GenericParam, Vec) { 109 | match param { 110 | GenericParam::Type(mut pt) => { 111 | // Delegate to the where-clause normalizer so that inline bounds 112 | // carrying associated-type bindings (e.g. `T: Iterator<Item = u8>`) 113 | // are lowered to `Eq` entries exactly like the equivalent 114 | // `where T: Iterator<Item = u8>` clause. 115 | let preds = if let Some(colon_token) = pt.colon_token { 116 | normalize_where_predicate(WherePredicate::Type(PredicateType { 117 | lifetimes: None, 118 | bounded_ty: Type::Path(TypePath { 119 | path: pt.ident.clone().into(),
120 | qself: None, 121 | }), 122 | colon_token: colon_token.clone(), 123 | bounds: core::mem::take(&mut pt.bounds), 124 | })) 125 | } else { 126 | Vec::new() 127 | }; 128 | // The bounds now live in `preds`; the returned param is bound-free. 129 | pt.colon_token = None; 130 | pt.bounds = Punctuated::new(); 131 | (GenericParam::Type(pt), preds) 132 | } 133 | GenericParam::Lifetime(mut pl) => { 134 | let mut preds = Vec::new(); 135 | if let Some(colon_token) = pl.colon_token { 136 | for bound in pl.bounds.into_iter() { 137 | preds.push(WherePredicateBinding::Lifetime(PredicateLifetime { 138 | lifetime: pl.lifetime.clone(), 139 | colon_token: colon_token.clone(), 140 | bounds: Some(bound).into_iter().collect(), 141 | })); 142 | } 143 | } 144 | pl.colon_token = None; 145 | pl.bounds = Punctuated::new(); 146 | (GenericParam::Lifetime(pl), preds) 147 | } 148 | o => (o, Vec::new()), 149 | } 150 | } 151 | /// Strips associated-type bindings and constraints from the generic arguments of `path`, returning each affected path prefix together with what was removed. 152 | fn remove_path_predicates(path: &mut Path) -> Vec<(Path, Vec, Vec)> { 153 | /// Replaces `*self` in place with `closure` applied to its previous value. 154 | trait Take: Sized { 155 | fn take_owned(&mut self, closure: impl FnOnce(Self) -> Self) -> &mut Self; 156 | } 157 | 158 | impl<T: Default> Take for T { 159 | fn take_owned(&mut self, closure: impl FnOnce(Self) -> Self) -> &mut Self { 160 | // `mem::take` leaves `Default::default()` in `*self` while `closure` 161 | // runs, so a panic inside `closure` (e.g. `PathVisitor`'s assertion) 162 | // unwinds normally and surfaces as a proc-macro panic message; no 163 | // `unsafe` or `process::abort` (which would kill rustc) is needed. 164 | let oldval = core::mem::take(self); 165 | *self = closure(oldval); 166 | self 167 | } 168 | } 169 | 170 | /// Visits paths nested inside generic arguments and panics if any of them carries associated-type bindings or constraints. 171 | struct PathVisitor; 172 | use syn::visit_mut::VisitMut; 173 | 174 | impl VisitMut for PathVisitor { 175 | fn visit_path_mut(&mut self, i: &mut Path) { 176 | let ret = remove_path_predicates(i); 177 | assert!(ret.is_empty(), "min_specialization: associated-type bindings nested inside generic arguments are not supported"); 178 | syn::visit_mut::visit_path_mut(self, i); 179 | } 180 | } 181 | let mut ret = Vec::new(); 182 | let mut current_path = Path { 183 | leading_colon: path.leading_colon.clone(), 184 | segments: Punctuated::new(), 185 | }; 186 | for seg in path.segments.iter_mut() { 187 |
let mut bindings = Vec::new(); 188 | let mut constraints = Vec::new(); 189 | match &mut seg.arguments { 190 | PathArguments::AngleBracketed(args) => { 191 | args.args.take_owned(|args| { 192 | let mut new_args = Punctuated::new(); 193 | for mut arg in args.into_iter() { 194 | PathVisitor.visit_generic_argument_mut(&mut arg); 195 | match arg { 196 | GenericArgument::AssocType(binding) => bindings.push(binding), 197 | GenericArgument::Constraint(constraint) => constraints.push(constraint), 198 | o => new_args.push(o), 199 | } 200 | } 201 | new_args 202 | }); 203 | } 204 | o => PathVisitor.visit_path_arguments_mut(o), 205 | } 206 | current_path.segments.push(seg.clone()); 207 | if !bindings.is_empty() || !constraints.is_empty() { 208 | ret.push((current_path.clone(), bindings, constraints)); 209 | } 210 | } 211 | ret 212 | } 213 | -------------------------------------------------------------------------------- /src/substitute.rs: -------------------------------------------------------------------------------- 1 | use crate::normalize::WherePredicateBinding; 2 | use std::collections::{HashMap, HashSet}; 3 | use syn::*; 4 | 5 | // For now, the right hand side should not contain any unbounded type parameter.
6 | /* Matching of a general pattern against a special one: `Substitute` computes every binding of the general side's type params that makes the two agree. NOTE(review): some `<...>` generic arguments look stripped from this dump (e.g. `Vec>` below, presumably `Vec<HashMap<Ident, Type>>`) -- confirm against the repository. */ #[derive(Clone, Debug, PartialEq, Eq)] 7 | pub struct Substitution(pub Vec>); /* A disjunction of alternative bindings (general param -> special type). No alternatives (`empty()`) means the match failed; one empty binding (`some()`) means it succeeded without binding anything. */ 8 | 9 | impl Substitution { 10 | /* no alternatives: the match failed */ pub fn empty() -> Self { 11 | Self(Vec::new()) 12 | } 13 | 14 | /* one empty binding: a trivial match, and the identity for `*` */ pub fn some() -> Self { 15 | Self(vec![HashMap::new()]) 16 | } 17 | 18 | /* exactly one alternative, binding `param` to `ty` */ pub fn new(param: Ident, ty: Type) -> Self { 19 | Self(vec![Some((param, ty)).into_iter().collect()]) 20 | } 21 | 22 | /* Union of two bindings; None if they bind the same param to different types. */ fn merge( 23 | mut lhs: HashMap, 24 | rhs: HashMap, 25 | ) -> Option> { 26 | for (param, ty) in rhs.into_iter() { 27 | if let Some(l_ty) = lhs.get(¶m) { /* NOTE(review): `¶m` looks like a mis-decoded `&param` (`&para` read as an HTML entity) in this dump */ 28 | if l_ty != &ty { 29 | return None; 30 | } 31 | } else { 32 | lhs.insert(param, ty); 33 | } 34 | } 35 | Some(lhs) 36 | } 37 | } 38 | 39 | impl core::ops::Add for Substitution { 40 | type Output = Self; 41 | 42 | /* `+` is disjunction: append the alternatives of `rhs` that are not already present */ fn add(mut self, rhs: Self) -> Self::Output { 43 | for record in rhs.0.into_iter() { 44 | if !self.0.contains(&record) { 45 | self.0.push(record); 46 | } 47 | } 48 | self 49 | } 50 | } 51 | 52 | impl core::ops::AddAssign for Substitution { 53 | fn add_assign(&mut self, rhs: Self) { 54 | core::mem::swap(self, &mut (self.clone() + rhs)) /* NOTE(review): clones `self` only to combine it; `*self = core::mem::replace(self, Substitution::empty()) + rhs` would avoid the copy (same for `mul_assign` below) */ 55 | } 56 | } 57 | 58 | impl core::ops::Mul for Substitution { 59 | type Output = Self; 60 | 61 | /* `*` is conjunction: merge every pair of alternatives, dropping pairs that conflict */ fn mul(self, rhs: Self) -> Self::Output { 62 | let mut ret = Vec::new(); 63 | for l in self.0.into_iter() { 64 | for r in rhs.0.iter() { 65 | if let Some(item) = Self::merge(l.clone(), r.clone()) { 66 | ret.push(item); 67 | } 68 | } 69 | } 70 | Self(ret) 71 | } 72 | } 73 | 74 | impl core::ops::MulAssign for Substitution { 75 | fn mul_assign(&mut self, rhs: Self) { 76 | core::mem::swap(self, &mut (self.clone() * rhs)) 77 | } 78 | } 79 | 80 | impl core::iter::Product for Substitution { 81 | fn product>(iter: I) -> Self { 82 | let mut ret = Substitution::some(); /* start from the identity, so an empty iterator yields a trivial match */ 83 | for item in iter { 84 | ret *= item; 85 | } 86 | ret 87 | } 88 | } 89 | 90 | /* Matches `general` (which may mention general-side type params) against `special`, returning every consistent binding of those params. */ pub trait Substitute { 91 | fn substitute(&self, general: &T, special: &T) -> Substitution; 92 | } 93 | 94 | /* The general-side type params that matching is allowed to bind. */ pub struct SubstituteEnvironment { 95 | pub general_params:
HashSet, 96 | } 97 | 98 | /* Generic arguments: types, `Name = Ty` bindings and `Name: Bounds` constraints are matched structurally; any other pair must be equal. */ impl Substitute for SubstituteEnvironment { 99 | fn substitute(&self, general: &GenericArgument, special: &GenericArgument) -> Substitution { 100 | match (general, special) { 101 | (GenericArgument::Type(g_ty), GenericArgument::Type(s_ty)) => { 102 | self.substitute(g_ty, s_ty) 103 | } 104 | (GenericArgument::AssocType(g_bind), GenericArgument::AssocType(s_bind)) => { /* names must agree, then the bound types are matched; NOTE(review): the binding's own `generics` are not compared */ 105 | if &g_bind.ident != &s_bind.ident { 106 | return Substitution::empty(); 107 | } 108 | self.substitute(&g_bind.ty, &s_bind.ty) 109 | } 110 | (GenericArgument::Constraint(g_ct), GenericArgument::Constraint(s_ct)) => { /* names must agree, then the bounds are matched position by position */ 111 | if &g_ct.ident != &s_ct.ident { 112 | return Substitution::empty(); 113 | } 114 | self.substitute( 115 | g_ct.bounds.iter().cloned().collect::>().as_slice(), 116 | s_ct.bounds.iter().cloned().collect::>().as_slice(), 117 | ) 118 | } 119 | (g, s) => { /* any other pair must be syntactically equal */ 120 | if g == s { 121 | Substitution::some() 122 | } else { 123 | Substitution::empty() 124 | } 125 | } 126 | } 127 | } 128 | } 129 | 130 | /* Paths: a bare general param binds the whole special path; otherwise the leading `::`, segment count, segment names and arguments must agree. */ impl Substitute for SubstituteEnvironment { 131 | fn substitute(&self, general: &Path, special: &Path) -> Substitution { 132 | if let Some(g_ident) = general.get_ident() { 133 | if self.general_params.contains(g_ident) { 134 | return Substitution::new( 135 | g_ident.clone(), 136 | Type::Path(TypePath { 137 | qself: None, 138 | path: special.clone(), 139 | }), 140 | ); 141 | } 142 | } 143 | if &general.leading_colon != &special.leading_colon 144 | || general.segments.len() != special.segments.len() 145 | { 146 | return Substitution::empty(); 147 | } 148 | 149 | let mut subst = Substitution::some(); 150 | for (i, (g_seg, s_seg)) in general.segments.iter().zip(&special.segments).enumerate() { 151 | /* a leading general param without arguments (as in `T::Assoc`) binds to the corresponding special segment */ if i == 0 152 | && &g_seg.arguments == &PathArguments::None 153 | && self.general_params.contains(&g_seg.ident) 154 | { 155 | subst *= Substitution::new( 156 | g_seg.ident.clone(), 157 | Type::Path(TypePath { 158 | qself: None, 159 | path: Path { 160 | leading_colon: None, 161
| segments: Some(s_seg.clone()).into_iter().collect(), 162 | }, 163 | }), 164 | ); 165 | } else { 166 | if &g_seg.ident != &s_seg.ident { 167 | return Substitution::empty(); 168 | } 169 | match (&g_seg.arguments, &s_seg.arguments) { 170 | (PathArguments::None, PathArguments::None) => (), 171 | ( 172 | PathArguments::AngleBracketed(g_args), 173 | PathArguments::AngleBracketed(s_args), 174 | ) => { 175 | // TODO: consider order (arguments are currently matched position by position) 176 | if g_args.args.len() != s_args.args.len() { 177 | return Substitution::empty(); 178 | } 179 | for (g_arg, s_arg) in g_args.args.iter().zip(&s_args.args) { 180 | subst *= self.substitute(g_arg, s_arg); 181 | } 182 | } 183 | ( 184 | PathArguments::Parenthesized(g_args), 185 | PathArguments::Parenthesized(s_args), 186 | ) => { /* `Fn(A, B) -> C` sugar: inputs and output are matched together as one list */ 187 | let mut g_tys: Vec<_> = g_args.inputs.iter().cloned().collect(); 188 | let mut s_tys: Vec<_> = s_args.inputs.iter().cloned().collect(); 189 | match (&g_args.output, &s_args.output) { 190 | (ReturnType::Default, ReturnType::Default) => (), 191 | (ReturnType::Type(_, g_ty), ReturnType::Type(_, s_ty)) => { 192 | g_tys.push(g_ty.as_ref().clone()); 193 | s_tys.push(s_ty.as_ref().clone()); 194 | } 195 | _ => { 196 | return Substitution::empty(); 197 | } 198 | } 199 | 200 | subst *= self.substitute(g_tys.as_slice(), s_tys.as_slice()); 201 | } 202 | _ => { 203 | return Substitution::empty(); 204 | } 205 | } 206 | } 207 | } 208 | subst 209 | } 210 | } 211 | 212 | /* Slices match element-wise and must have equal length. */ impl Substitute<[T]> for SubstituteEnvironment 213 | where 214 | Self: Substitute, 215 | { 216 | fn substitute(&self, general: &[T], special: &[T]) -> Substitution { 217 | if general.len() != special.len() { 218 | return Substitution::empty(); 219 | } 220 | general 221 | .iter() 222 | .zip(special) 223 | .map(|(g, s)| self.substitute(g, s)) 224 | .product() 225 | } 226 | } 227 | 228 | /* Trait bounds: parentheses, the `?` modifier and `for<...>` lifetimes must be identical, then the paths are matched. */ impl Substitute for SubstituteEnvironment { 229 | fn substitute(&self, general: &TraitBound, special: &TraitBound) -> Substitution { 230 | // TODO: consider lifetime order
(assignment) 231 | if &general.paren_token != &special.paren_token 232 | || &general.modifier != &special.modifier 233 | || &general.lifetimes != &special.lifetimes 234 | { 235 | return Substitution::empty(); 236 | } 237 | self.substitute(&general.path, &special.path) 238 | } 239 | } 240 | 241 | /* Every general element must match at least one special element (extra special elements are allowed); the alternatives from every such choice are combined. */ fn substitute_by_set( 242 | env: &SubstituteEnvironment, 243 | general: &HashSet, 244 | special: &HashSet, 245 | ) -> Substitution 246 | where 247 | T: PartialEq + Eq + core::hash::Hash, 248 | SubstituteEnvironment: Substitute, 249 | { 250 | let mut subst = Substitution::some(); 251 | for g in general.iter() { 252 | let mut next_subst = Substitution::empty(); 253 | for s in special.iter() { 254 | next_subst += subst.clone() * env.substitute(g, s); 255 | } 256 | subst = next_subst; 257 | } 258 | subst 259 | } 260 | 261 | impl Substitute> for SubstituteEnvironment 262 | where 263 | Self: Substitute, 264 | { 265 | fn substitute(&self, general: &HashSet, special: &HashSet) -> Substitution { 266 | substitute_by_set(self, general, special) 267 | } 268 | } 269 | 270 | impl Substitute for SubstituteEnvironment { 271 | fn substitute(&self, general: &Lifetime, special: &Lifetime) -> Substitution { 272 | if general == special { 273 | Substitution::some() 274 | } else { 275 | Substitution::empty() 276 | } 277 | } 278 | } 279 | 280 | impl Substitute for SubstituteEnvironment { 281 | fn substitute(&self, general: &TypeParamBound, special: &TypeParamBound) -> Substitution { 282 | match (general, special) { 283 | (TypeParamBound::Trait(g_tb), TypeParamBound::Trait(s_tb)) => { 284 | self.substitute(g_tb, s_tb) 285 | } 286 | (TypeParamBound::Lifetime(g_l), TypeParamBound::Lifetime(s_l)) => { 287 | self.substitute(g_l, s_l) 288 | } 289 | _ => Substitution::empty(), /* NOTE(review): unlike the `g == s` fallback used elsewhere, identical bounds of any other kind (e.g. `Verbatim`) are rejected here */ 290 | } 291 | } 292 | } 293 | 294 | /* Qualified-self types (`<T as Trait>::...`): position and `as` must agree, then the self types are matched. */ impl Substitute for SubstituteEnvironment { 295 | fn substitute(&self, general: &QSelf, special: &QSelf) -> Substitution { 296 | if (general.position, &general.as_token) != (special.position,
&special.as_token) { 297 | return Substitution::empty(); 298 | } 299 | self.substitute(general.ty.as_ref(), special.ty.as_ref()) 300 | } 301 | } 302 | 303 | /* Optional parts: both absent is a trivial match, both present are matched, otherwise there is no match. */ impl Substitute> for SubstituteEnvironment 304 | where 305 | Self: Substitute, 306 | { 307 | fn substitute(&self, general: &Option, special: &Option) -> Substitution { 308 | match (general, special) { 309 | (None, None) => Substitution::some(), 310 | (Some(g), Some(s)) => self.substitute(g, s), 311 | _ => Substitution::empty(), 312 | } 313 | } 314 | } 315 | 316 | /* Structural match on type syntax; a bare general param binds whatever special type it faces. */ impl Substitute for SubstituteEnvironment { 317 | fn substitute(&self, general: &Type, special: &Type) -> Substitution { 318 | match (general, special) { 319 | ( 320 | Type::Array(TypeArray { 321 | elem: g_elem, 322 | len: g_len, 323 | .. 324 | }), 325 | Type::Array(TypeArray { 326 | elem: s_elem, 327 | len: s_len, 328 | .. 329 | }), 330 | ) => { 331 | if g_len != s_len { /* lengths are compared as syntax, not evaluated */ 332 | return Substitution::empty(); 333 | } 334 | self.substitute(g_elem.as_ref(), s_elem.as_ref()) 335 | } 336 | (Type::BareFn(g_fn), Type::BareFn(s_fn)) => { 337 | if (&g_fn.lifetimes, &g_fn.unsafety, &g_fn.abi, &g_fn.variadic) 338 | != (&s_fn.lifetimes, &s_fn.unsafety, &s_fn.abi, &s_fn.variadic) 339 | || g_fn.inputs.len() != s_fn.inputs.len() 340 | { 341 | return Substitution::empty(); 342 | } 343 | let mut subst = g_fn 344 | .inputs 345 | .iter() 346 | .zip(&s_fn.inputs) 347 | .map(|(g_arg, s_arg)| { 348 | if &g_arg.attrs != &s_arg.attrs { 349 | Substitution::empty() 350 | } else { 351 | self.substitute(&g_arg.ty, &s_arg.ty) 352 | } 353 | }) 354 | .product(); 355 | 356 | match (&g_fn.output, &s_fn.output) { 357 | (ReturnType::Default, ReturnType::Default) => (), 358 | (ReturnType::Type(_, g_ty), ReturnType::Type(_, s_ty)) => { 359 | subst *= self.substitute(g_ty.as_ref(), s_ty.as_ref()); 360 | } 361 | _ => { 362 | return Substitution::empty(); 363 | } 364 | } 365 | subst 366 | } 367 | (Type::Group(g_gr), Type::Group(s_gr)) => { 368 | self.substitute(g_gr.elem.as_ref(),
s_gr.elem.as_ref()) 369 | } 370 | /* `impl Trait` bounds are matched position by position */ (Type::ImplTrait(g_it), Type::ImplTrait(s_it)) => self.substitute( 371 | g_it.bounds.iter().cloned().collect::>().as_slice(), 372 | s_it.bounds.iter().cloned().collect::>().as_slice(), 373 | ), 374 | (Type::Paren(g_p), Type::Paren(s_p)) => { 375 | self.substitute(g_p.elem.as_ref(), s_p.elem.as_ref()) 376 | } 377 | /* a bare general param binds the whole special type, whatever its kind */ (Type::Path(g_p), s) 378 | if g_p 379 | .path 380 | .get_ident() 381 | .map(|id| self.general_params.contains(id)) 382 | == Some(true) => 383 | { 384 | Substitution::new(g_p.path.get_ident().unwrap().clone(), s.clone()) 385 | } 386 | (Type::Path(g_path), Type::Path(s_path)) => { 387 | self.substitute(&g_path.qself, &s_path.qself) 388 | * self.substitute(&g_path.path, &s_path.path) 389 | } 390 | (Type::Ptr(g_ptr), Type::Ptr(s_ptr)) => { 391 | if &g_ptr.mutability != &s_ptr.mutability { 392 | Substitution::empty() 393 | } else { 394 | self.substitute(g_ptr.elem.as_ref(), s_ptr.elem.as_ref()) 395 | } 396 | } 397 | (Type::Reference(g_ref), Type::Reference(s_ref)) => { 398 | if (&g_ref.lifetime, &g_ref.mutability) != (&s_ref.lifetime, &s_ref.mutability) { 399 | Substitution::empty() 400 | } else { 401 | self.substitute(g_ref.elem.as_ref(), s_ref.elem.as_ref()) 402 | } 403 | } 404 | (Type::Slice(g_slice), Type::Slice(s_slice)) => { 405 | self.substitute(g_slice.elem.as_ref(), s_slice.elem.as_ref()) 406 | } 407 | (Type::TraitObject(g_to), Type::TraitObject(s_to)) => { 408 | // TODO: consider freedom of the order (bounds are currently matched position by position) 409 | self.substitute( 410 | g_to.bounds.iter().cloned().collect::>().as_slice(), 411 | s_to.bounds.iter().cloned().collect::>().as_slice(), 412 | ) 413 | } 414 | (Type::Tuple(g_tup), Type::Tuple(s_tup)) => self.substitute( 415 | g_tup.elems.iter().cloned().collect::>().as_slice(), 416 | s_tup.elems.iter().cloned().collect::>().as_slice(), 417 | ), 418 | (g, s) => { /* remaining kinds (e.g. `!`, `_`) and mismatched kinds must be syntactically equal */ 419 | if &g == &s { 420 | Substitution::some() 421 | } else { 422 | Substitution::empty() 423 | } 424 | } 425 | } 426 | } 427 | } 428 | 429 | impl Substitute
for SubstituteEnvironment { /* Where-predicates: higher-ranked `lifetimes` binders must be identical, then bounded types and bounds are matched. */ 430 | fn substitute( 431 | &self, 432 | general: &WherePredicateBinding, 433 | special: &WherePredicateBinding, 434 | ) -> Substitution { 435 | match (general, special) { 436 | ( 437 | WherePredicateBinding::Type(PredicateType { 438 | lifetimes: g_lifetimes, 439 | bounded_ty: g_bounded_ty, 440 | bounds: g_bounds, 441 | .. 442 | }), 443 | WherePredicateBinding::Type(PredicateType { 444 | lifetimes: s_lifetimes, 445 | bounded_ty: s_bounded_ty, 446 | bounds: s_bounds, 447 | .. 448 | }), 449 | ) => { 450 | if &g_lifetimes != &s_lifetimes { 451 | return Substitution::empty(); 452 | } 453 | self.substitute(g_bounded_ty, s_bounded_ty) 454 | * self.substitute( 455 | &g_bounds.into_iter().cloned().collect::>(), 456 | &s_bounds.into_iter().cloned().collect::>(), 457 | ) 458 | } 459 | /* `lhs = rhs` predicates: both sides are matched */ ( 460 | WherePredicateBinding::Eq { 461 | lhs_ty: g_lhs_ty, 462 | rhs_ty: g_rhs_ty, 463 | .. 464 | }, 465 | WherePredicateBinding::Eq { 466 | lhs_ty: s_lhs_ty, 467 | rhs_ty: s_rhs_ty, 468 | .. 469 | }, 470 | ) => self.substitute(g_lhs_ty, s_lhs_ty) * self.substitute(g_rhs_ty, s_rhs_ty), 471 | (g, s) => { 472 | if g == s { 473 | Substitution::some() 474 | } else { 475 | Substitution::empty() 476 | } 477 | } 478 | } 479 | } 480 | } 481 | 482 | #[test] 483 | fn test_unit() { /* Table-driven checks of `Substitute<Type>` through the `unittest!` macro below. */ 484 | use proc_macro2::Span; 485 | 486 | // Idents named "A*" are the unbound type params on the general side 487 | let env = SubstituteEnvironment { 488 | general_params: vec![Ident::new("A0", Span::call_site())] 489 | .into_iter() 490 | .collect(), 491 | }; 492 | 493 | macro_rules! unittest { 494 | ($env:ident [$typ:ty] { $($t0:tt)* } { $($t1:tt)* } None) => { 495 | assert_eq!( 496 | <_ as Substitute<$typ>>::substitute( 497 | &$env, 498 | &parse_quote!($($t0)*), 499 | &parse_quote!($($t1)*), 500 | ), 501 | Substitution::empty() 502 | ); 503 | }; 504 | ($env:ident [$typ:ty] { $($t0:tt)* } { $($t1:tt)* } Some [ $($name:expr => { $($t2:tt)* }),* $(,)?
]) => { 505 | assert_eq!( 506 | <_ as Substitute<$typ>>::substitute( 507 | &$env, 508 | &parse_quote!($($t0)*), 509 | &parse_quote!($($t1)*), 510 | ), 511 | Substitution( 512 | vec![ 513 | vec![$((Ident::new($name, Span::call_site()), parse_quote!($($t2)*)))*] /* NOTE(review): this repetition has no `,` separator, so two or more expected bindings would expand to `(a, b) (c, d)` and fail to compile; the matcher accepts commas, so the transcriber presumably wants a `,` separator as well */ 514 | .into_iter() 515 | .collect() 516 | ] 517 | ) 518 | ); 519 | }; 520 | } 521 | 522 | /* only A0 is a general param (see `env`); B0 acts as a concrete type */ unittest! { env [Type] { [A0; 32] } { [A0; 32] } Some ["A0" => {A0}] } 523 | unittest! { env [Type] { [B0; 32] } { [B0; 32] } Some [] } 524 | unittest! { env [Type] { [A0; 32] } { [B0; 32] } Some ["A0" => {B0}] } 525 | unittest! { env [Type] { [A0; 32] } { [B0; 48] } None } 526 | unittest! { env [Type] { 527 | *mut ([A0; 1], (Option, &mut [(A0, B0)])) 528 | } { 529 | *mut ([Vec; 1], (Option>, &mut [(Vec, B0)])) 530 | } Some [ 531 | "A0" => { Vec } 532 | ]} 533 | unittest! { env [Type] { 534 | *mut ([A0; 1], (Option, &mut [(A0, B0)])) 535 | } { 536 | *mut ([Vec; 1], (Option>, &mut [(BTreeSet, B0)])) 537 | } None } 538 | unittest! { env [Type] { 539 | (,)>>::Item) 540 | } { 541 | (,)>>::Item) 542 | } Some [ 543 | "A0" => { B0 } 544 | ] } 545 | unittest!
{ env [Type] { 546 | (,)>>::Item) 547 | } { 548 | (,)>>::Item) 549 | } None } 550 | } 551 | -------------------------------------------------------------------------------- /tests/large.rs: -------------------------------------------------------------------------------- 1 | use min_specialization::specialization; 2 | 3 | #[specialization] 4 | mod test_mod { 5 | pub trait DataSize { 6 | fn size(&self) -> usize; 7 | } 8 | 9 | impl DataSize for T { 10 | default fn size(&self) -> usize { 11 | std::mem::size_of::() 12 | } 13 | } 14 | 15 | impl DataSize for &str { 16 | fn size(&self) -> usize { 17 | self.len() 18 | } 19 | } 20 | 21 | impl DataSize for i32 { 22 | fn size(&self) -> usize { 23 | 4 24 | } 25 | } 26 | 27 | // Specialized implementation: the f64 case 28 | impl DataSize for f64 { 29 | fn size(&self) -> usize { 30 | 8 // f64 is always 8 bytes 31 | } 32 | } 33 | } 34 | 35 | use test_mod::DataSize; 36 | 37 | fn main() { /* NOTE(review): Cargo builds integration tests with the libtest harness, which does not call a user `main`; this body only has to compile. Mark it `#[test]` (and assert the sizes) for it to run. */ 38 | let integer = 42; 39 | let floating_point = 3.14; 40 | let text = "Hello, Rust!"; 41 | 42 | println!("Size of i32: {}", integer.size()); 43 | println!("Size of f64: {}", floating_point.size()); 44 | println!("Size of &str: {}", text.size()); 45 | println!("Size of f64 default: {}", (5.0f64).size()); 46 | } 47 | -------------------------------------------------------------------------------- /tests/large2.rs: -------------------------------------------------------------------------------- 1 | use min_specialization::specialization; 2 | #[specialization] 3 | mod test_mod { 4 | pub trait Serialize { 5 | fn serialize(&self) -> String; 6 | } 7 | 8 | impl Serialize for T { 9 | default fn serialize(&self) -> String { 10 | format!("Generic serialization: {:?}", self) 11 | } 12 | } 13 | 14 | impl Serialize for i32 15 | where 16 | Self: core::fmt::Debug, 17 | { 18 | fn serialize(&self) -> String { 19 | format!("Integer serialization: {}", self) 20 | } 21 | } 22 | 23 | impl Serialize for &str 24 | where 25 | Self: core::fmt::Debug, 26 | { 27 | fn serialize(&self) -> String { 28 |
format!("String serialization: '{}'", self) 29 | } 30 | } 31 | } 32 | 33 | use test_mod::Serialize; 34 | 35 | fn main() { /* NOTE(review): Cargo builds integration tests with the libtest harness, which does not call a user `main`, so these asserts never execute; mark the function `#[test]` so they run. */ 36 | let x = 42; 37 | let y = "Hello, world!"; 38 | let z = (); 39 | assert!(x.serialize().starts_with("Integer")); 40 | assert!(y.serialize().starts_with("String")); 41 | assert!(z.serialize().starts_with("Generic")); 42 | } 43 | -------------------------------------------------------------------------------- /tests/no_self_arg.rs: -------------------------------------------------------------------------------- 1 | #[min_specialization::specialization] 2 | mod test1 { 3 | #[allow(unused)] 4 | trait MyTrait { 5 | fn f(a: Self) -> Self; 6 | } 7 | 8 | impl MyTrait for T { 9 | default fn f(a: T) -> T { 10 | a 11 | } 12 | } 13 | 14 | impl MyTrait for () { 15 | fn f(_: ()) -> () { 16 | () 17 | } 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /tests/test.rs: -------------------------------------------------------------------------------- 1 | use min_specialization::specialization; 2 | 3 | #[specialization] 4 | mod test_mod { 5 | #[allow(unused)] 6 | trait Trait { 7 | type Ty; 8 | fn number(_: U) -> Self::Ty; 9 | } 10 | 11 | impl Trait for T { 12 | type Ty = usize; 13 | default fn number(_: U) -> Self::Ty { 14 | 0 15 | } 16 | } 17 | 18 | impl Trait for () { 19 | fn number(_: U) -> Self::Ty { 20 | 1 21 | } 22 | } 23 | 24 | #[allow(unused)] 25 | struct S(T); 26 | impl core::ops::AddAssign for S { 27 | default fn add_assign(&mut self, _rhs: Self) {} 28 | } 29 | impl core::ops::AddAssign for S { 30 | fn add_assign(&mut self, _rhs: Self) {} 31 | } 32 | } 33 | --------------------------------------------------------------------------------