1use proc_macro2::TokenStream as TokenStream2;
16use quote::quote;
17use syn::{
18 spanned::Spanned,
19 Expr,
20 Field,
21 Fields,
22 GenericParam,
23 Lit,
24};
25use synstructure::VariantInfo;
26
27use super::utils;
28
29pub fn sol_decode_derive(s: synstructure::Structure) -> TokenStream2 {
31 match s.ast().data {
32 syn::Data::Struct(_) => {
33 sol_decode_derive_struct(s).unwrap_or_else(|err| err.to_compile_error())
34 }
35 syn::Data::Enum(_) => {
36 sol_decode_derive_enum(s).unwrap_or_else(|err| err.to_compile_error())
37 }
38 _ => {
39 syn::Error::new(
40 s.ast().span(),
41 "can only derive `SolDecode` for Rust `struct` and `enum` items",
42 )
43 .to_compile_error()
44 }
45 }
46}
47
48pub fn sol_encode_derive(s: synstructure::Structure) -> TokenStream2 {
50 match s.ast().data {
51 syn::Data::Struct(_) => {
52 sol_encode_derive_struct(s).unwrap_or_else(|err| err.to_compile_error())
53 }
54 syn::Data::Enum(_) => {
55 sol_encode_derive_enum(s).unwrap_or_else(|err| err.to_compile_error())
56 }
57 _ => {
58 syn::Error::new(
59 s.ast().span(),
60 "can only derive `SolEncode` for Rust `struct` and `enum` items",
61 )
62 .to_compile_error()
63 }
64 }
65}
66
67fn sol_decode_derive_struct(s: synstructure::Structure) -> syn::Result<TokenStream2> {
69 let Some(variant) = s.variants().first() else {
70 return Err(syn::Error::new(
71 s.ast().span(),
72 "can only derive `SolDecode` for Rust `struct` items",
73 ));
74 };
75
76 let fields = variant.ast().fields;
77 let sol_tys = fields.iter().map(|field| {
78 let ty = &field.ty;
79 quote! {
80 <#ty as ::ink::SolDecode>::SolType
81 }
82 });
83 fn from_sol_type(value: TokenStream2, field: &Field) -> TokenStream2 {
84 let ty = &field.ty;
85 quote! {
86 <#ty as ::ink::SolDecode>::from_sol_type(#value)?
87 }
88 }
89 let self_body = utils::body_from_fields(fields, Some(from_sol_type));
90
91 Ok(s.bound_impl(
92 quote!(::ink::SolDecode),
93 quote! {
94 type SolType = ( #( #sol_tys, )* );
95
96 fn from_sol_type(value: Self::SolType) -> ::core::result::Result<Self, ::ink::sol::Error> {
97 Ok(Self #self_body)
98 }
99 },
100 ))
101}
102
103fn sol_encode_derive_struct(mut s: synstructure::Structure) -> syn::Result<TokenStream2> {
105 let Some(variant) = s.variants().first() else {
106 return Err(syn::Error::new(
107 s.ast().span(),
108 "can only derive `SolEncode` for Rust `struct` items",
109 ));
110 };
111
112 let fields = variant.ast().fields;
113 let sol_tys = fields.iter().map(|field| {
114 let ty = &field.ty;
115 quote!( <#ty as ::ink::SolEncode<'a>>::SolType )
116 });
117 fn to_sol_type(value: TokenStream2, field: &Field) -> TokenStream2 {
118 let ty = &field.ty;
119 quote! {
120 <#ty as ::ink::SolEncode<'_>>::to_sol_type(#value)
121 }
122 }
123 let sol_ty_tuple = utils::tuple_elems_from_fields(fields, Some(to_sol_type));
124
125 let lifetime: GenericParam = syn::parse_quote!('a);
126 s.add_impl_generic(lifetime);
127 Ok(s.bound_impl(
128 quote!(::ink::SolEncode<'a>),
129 quote! {
130 type SolType = ( #( #sol_tys, )* );
131
132 fn to_sol_type(&'a self) -> Self::SolType {
133 #sol_ty_tuple
134 }
135 },
136 ))
137}
138
139fn sol_decode_derive_enum(s: synstructure::Structure) -> syn::Result<TokenStream2> {
141 utils::ensure_non_empty_enum(&s, "SolDecode")?;
142 ensure_empty_variants(&s, "SolDecode")?;
143 ensure_u8_max_variants(&s, "SolDecode")?;
144 ensure_consistent_variant_int_repr(&s, "SolDecode")?;
145 ensure_valid_discriminant_values(&s, "SolDecode")?;
146
147 let variants_match = s.variants().iter().enumerate().map(|(idx, variant)| {
148 let variant_ident = variant.ast().ident;
149 let int_repr = variant_int_repr(variant, idx as u8);
150 let field_delimiters = variant_field_delimiters(variant);
151 quote! {
152 #int_repr => {
153 ::core::result::Result::Ok(Self:: #variant_ident #field_delimiters)
154 }
155 }
156 });
157
158 Ok(s.bound_impl(
159 quote!(::ink::SolDecode),
160 quote! {
161 type SolType = ::core::primitive::u8;
162
163 fn from_sol_type(value: Self::SolType) -> ::core::result::Result<Self, ::ink::sol::Error> {
164 match value {
165 #( #variants_match )*
166 _ => ::core::result::Result::Err(::ink::sol::Error)
167 }
168 }
169 },
170 ))
171}
172
173fn sol_encode_derive_enum(mut s: synstructure::Structure) -> syn::Result<TokenStream2> {
175 utils::ensure_non_empty_enum(&s, "SolEncode")?;
176 ensure_empty_variants(&s, "SolEncode")?;
177 ensure_u8_max_variants(&s, "SolEncode")?;
178 ensure_consistent_variant_int_repr(&s, "SolEncode")?;
179 ensure_valid_discriminant_values(&s, "SolEncode")?;
180
181 let lifetime: GenericParam = syn::parse_quote!('a);
182 s.add_impl_generic(lifetime);
183
184 let variants_match = s.variants().iter().enumerate().map(|(idx, variant)| {
185 let variant_ident = variant.ast().ident;
186 let int_repr = variant_int_repr(variant, idx as u8);
187 let field_delimiters = variant_field_delimiters(variant);
188 quote! {
189 Self:: #variant_ident #field_delimiters => #int_repr,
190 }
191 });
192
193 Ok(s.bound_impl(
194 quote!(::ink::SolEncode<'a>),
195 quote! {
196 type SolType = ::core::primitive::u8;
197
198 fn to_sol_type(&'a self) -> Self::SolType {
199 match self {
200 #( #variants_match )*
201 }
202 }
203 },
204 ))
205}
206
207fn ensure_empty_variants(
209 s: &synstructure::Structure,
210 trait_name: &str,
211) -> syn::Result<()> {
212 let has_non_empty_variants = s
213 .variants()
214 .iter()
215 .any(|variant| !variant.ast().fields.is_empty());
216 if has_non_empty_variants {
217 Err(syn::Error::new(
218 s.ast().span(),
219 format!(
220 "can only derive `{trait_name}` for Rust `enum` items with \
221 only unit-only or field-less variants"
222 ),
223 ))
224 } else {
225 Ok(())
226 }
227}
228
229fn ensure_u8_max_variants(
242 s: &synstructure::Structure,
243 trait_name: &str,
244) -> syn::Result<()> {
245 if s.variants().len() > u8::MAX as usize {
246 Err(syn::Error::new(
247 s.ast().span(),
248 format!(
249 "can only derive `{trait_name}` for Rust `enum` items \
250 with at most `u8::MAX` variants"
251 ),
252 ))
253 } else {
254 Ok(())
255 }
256}
257
258fn ensure_consistent_variant_int_repr(
267 s: &synstructure::Structure,
268 trait_name: &str,
269) -> syn::Result<()> {
270 let n_variants_with_discriminants = s
271 .variants()
272 .iter()
273 .filter(|variant| variant.ast().discriminant.is_some())
274 .count();
275 if n_variants_with_discriminants > 0
276 && n_variants_with_discriminants != s.variants().len()
277 {
278 Err(syn::Error::new(
279 s.ast().span(),
280 format!(
281 "can only derive `{trait_name}` for Rust `enum` items that \
282 either have no variants with explicitly specified discriminants, \
283 or have explicitly specified discriminants for all variants"
284 ),
285 ))
286 } else {
287 Ok(())
288 }
289}
290
291fn ensure_valid_discriminant_values(
312 s: &synstructure::Structure,
313 trait_name: &str,
314) -> syn::Result<()> {
315 let offending_span = s.variants().iter().find_map(|variant| {
316 variant.ast().discriminant.as_ref().and_then(|(_, expr)| {
317 match expr {
318 Expr::Lit(expr) => {
319 match &expr.lit {
320 Lit::Int(value) => {
321 let discr = value.base10_parse::<usize>().ok()?;
322 (discr > u8::MAX as usize).then_some(expr.span())
323 }
324 _ => None,
325 }
326 }
327 _ => None,
328 }
329 })
330 });
331 if let Some(span) = offending_span {
332 Err(syn::Error::new(
333 span,
334 format!(
335 "can only derive `{trait_name}` for Rust `enum` items \
336 with discriminant values (if explicitly specified) \
337 not larger than `u8::MAX`"
338 ),
339 ))
340 } else {
341 Ok(())
342 }
343}
344
345fn variant_int_repr(variant: &VariantInfo, idx: u8) -> TokenStream2 {
347 variant
348 .ast()
349 .discriminant
350 .as_ref()
351 .map(|(_, expr)| quote!( #expr ))
352 .unwrap_or_else(|| quote! ( #idx ))
353}
354
355fn variant_field_delimiters(variant: &VariantInfo) -> TokenStream2 {
357 match variant.ast().fields {
358 Fields::Named(_) => quote!({}),
359 Fields::Unnamed(_) => quote! { () },
360 Fields::Unit => quote!(),
361 }
362}