File size: 3,621 Bytes
2409829
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
use crate::helpers::call_site_ident;
use proc_macro2::{Ident, Span, TokenStream};
use syn::spanned::Spanned;
use syn::{Attribute, Data, DeriveInput, Field, Fields, ItemEnum, MetaList};

pub fn derive_discriminant_impl(input_item: TokenStream) -> syn::Result<TokenStream> {
	let input = syn::parse2::<DeriveInput>(input_item).unwrap();

	let mut data = match input.data {
		Data::Enum(data) => data,
		_ => return Err(syn::Error::new(Span::call_site(), "Tried to derive a discriminant for non-enum")),
	};

	let mut is_sub_discriminant = vec![];
	let mut attr_errs = vec![];

	for var in &mut data.variants {
		if var.attrs.iter().any(|a| a.path().is_ident("sub_discriminant")) {
			match var.fields.len() {
				1 => {
					let Field { ty, .. } = var.fields.iter_mut().next().unwrap();
					*ty = syn::parse_quote! {
						<#ty as ToDiscriminant>::Discriminant
					};
					is_sub_discriminant.push(true);
				}
				n => unimplemented!("#[sub_discriminant] on variants with {n} fields is not supported (for now)"),
			}
		} else {
			var.fields = Fields::Unit;
			is_sub_discriminant.push(false);
		}
		let mut retain = vec![];
		for (i, a) in var.attrs.iter_mut().enumerate() {
			if a.path().is_ident("discriminant_attr") {
				match a.meta.require_list() {
					Ok(MetaList { tokens, .. }) => {
						let attr: Attribute = syn::parse_quote! {
							#[#tokens]
						};
						*a = attr;
						retain.push(i);
					}
					Err(e) => {
						attr_errs.push(syn::Error::new(a.span(), e));
					}
				}
			}
		}
		var.attrs = var.attrs.iter().enumerate().filter(|(i, _)| retain.contains(i)).map(|(_, x)| x.clone()).collect();
	}

	let attrs = input
		.attrs
		.iter()
		.cloned()
		.filter_map(|a| {
			let a_span = a.span();
			a.path()
				.is_ident("discriminant_attr")
				.then(|| match a.meta.require_list() {
					Ok(MetaList { tokens, .. }) => {
						let attr: Attribute = syn::parse_quote! {
							#[#tokens]
						};
						Some(attr)
					}
					Err(e) => {
						attr_errs.push(syn::Error::new(a_span, e));
						None
					}
				})
				.and_then(|opt| opt)
		})
		.collect::<Vec<Attribute>>();

	if !attr_errs.is_empty() {
		return Err(attr_errs
			.into_iter()
			.reduce(|mut l, r| {
				l.combine(r);
				l
			})
			.unwrap());
	}

	let discriminant = ItemEnum {
		attrs,
		vis: input.vis,
		enum_token: data.enum_token,
		ident: call_site_ident(format!("{}Discriminant", input.ident)),
		generics: input.generics,
		brace_token: data.brace_token,
		variants: data.variants,
	};

	let input_type = &input.ident;
	let discriminant_type = &discriminant.ident;
	let variant = &discriminant.variants.iter().map(|var| &var.ident).collect::<Vec<&Ident>>();

	let (pattern, value) = is_sub_discriminant
		.into_iter()
		.map(|b| {
			(
				if b {
					quote::quote! { (x) }
				} else {
					quote::quote! { { .. } }
				},
				b.then(|| quote::quote! { (x.to_discriminant()) }).unwrap_or_default(),
			)
		})
		.unzip::<_, _, Vec<_>, Vec<_>>();
	#[cfg(feature = "serde-discriminant")]
	let serde = quote::quote! {
		#[derive(serde::Serialize, serde::Deserialize)]
	};

	#[cfg(not(feature = "serde-discriminant"))]
	let serde = quote::quote! {};

	let res = quote::quote! {
		#serde
		#discriminant

		impl ToDiscriminant for #input_type {
			type Discriminant = #discriminant_type;

			fn to_discriminant(&self) -> #discriminant_type {
				match self {
					#(
						#input_type::#variant #pattern => #discriminant_type::#variant #value
					),*
				}
			}
		}

		impl From<&#input_type> for #discriminant_type {
			fn from(x: &#input_type) -> #discriminant_type {
				x.to_discriminant()
			}
		}
	};

	Ok(res)
}