ink_macro/sol/
codec.rs

1// Copyright (C) Use Ink (UK) Ltd.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use proc_macro2::TokenStream as TokenStream2;
16use quote::quote;
17use syn::{
18    spanned::Spanned,
19    Expr,
20    Field,
21    Fields,
22    GenericParam,
23    Lit,
24};
25use synstructure::VariantInfo;
26
27use super::utils;
28
29/// Derives the `ink::SolDecode` trait for the given `struct` or `enum`.
30pub fn sol_decode_derive(s: synstructure::Structure) -> TokenStream2 {
31    match s.ast().data {
32        syn::Data::Struct(_) => {
33            sol_decode_derive_struct(s).unwrap_or_else(|err| err.to_compile_error())
34        }
35        syn::Data::Enum(_) => {
36            sol_decode_derive_enum(s).unwrap_or_else(|err| err.to_compile_error())
37        }
38        _ => {
39            syn::Error::new(
40                s.ast().span(),
41                "can only derive `SolDecode` for Rust `struct` and `enum` items",
42            )
43            .to_compile_error()
44        }
45    }
46}
47
48/// Derives the `ink::SolEncode` trait for the given `struct` or `enum`.
49pub fn sol_encode_derive(s: synstructure::Structure) -> TokenStream2 {
50    match s.ast().data {
51        syn::Data::Struct(_) => {
52            sol_encode_derive_struct(s).unwrap_or_else(|err| err.to_compile_error())
53        }
54        syn::Data::Enum(_) => {
55            sol_encode_derive_enum(s).unwrap_or_else(|err| err.to_compile_error())
56        }
57        _ => {
58            syn::Error::new(
59                s.ast().span(),
60                "can only derive `SolEncode` for Rust `struct` and `enum` items",
61            )
62            .to_compile_error()
63        }
64    }
65}
66
67/// Derives the `ink::SolDecode` trait for the given `struct`.
68fn sol_decode_derive_struct(s: synstructure::Structure) -> syn::Result<TokenStream2> {
69    let Some(variant) = s.variants().first() else {
70        return Err(syn::Error::new(
71            s.ast().span(),
72            "can only derive `SolDecode` for Rust `struct` items",
73        ));
74    };
75
76    let fields = variant.ast().fields;
77    let sol_tys = fields.iter().map(|field| {
78        let ty = &field.ty;
79        quote! {
80            <#ty as ::ink::SolDecode>::SolType
81        }
82    });
83    fn from_sol_type(value: TokenStream2, field: &Field) -> TokenStream2 {
84        let ty = &field.ty;
85        quote! {
86            <#ty as ::ink::SolDecode>::from_sol_type(#value)?
87        }
88    }
89    let self_body = utils::body_from_fields(fields, Some(from_sol_type));
90
91    Ok(s.bound_impl(
92        quote!(::ink::SolDecode),
93        quote! {
94            type SolType = ( #( #sol_tys, )* );
95
96            fn from_sol_type(value: Self::SolType) -> ::core::result::Result<Self, ::ink::sol::Error> {
97                Ok(Self #self_body)
98            }
99        },
100    ))
101}
102
103/// Derives the `ink::SolEncode` trait for the given `struct`.
104fn sol_encode_derive_struct(mut s: synstructure::Structure) -> syn::Result<TokenStream2> {
105    let Some(variant) = s.variants().first() else {
106        return Err(syn::Error::new(
107            s.ast().span(),
108            "can only derive `SolEncode` for Rust `struct` items",
109        ));
110    };
111
112    let fields = variant.ast().fields;
113    let sol_tys = fields.iter().map(|field| {
114        let ty = &field.ty;
115        quote!( <#ty as ::ink::SolEncode<'a>>::SolType )
116    });
117    fn to_sol_type(value: TokenStream2, field: &Field) -> TokenStream2 {
118        let ty = &field.ty;
119        quote! {
120            <#ty as ::ink::SolEncode<'_>>::to_sol_type(#value)
121        }
122    }
123    let sol_ty_tuple = utils::tuple_elems_from_fields(fields, Some(to_sol_type));
124
125    let lifetime: GenericParam = syn::parse_quote!('a);
126    s.add_impl_generic(lifetime);
127    Ok(s.bound_impl(
128        quote!(::ink::SolEncode<'a>),
129        quote! {
130            type SolType = ( #( #sol_tys, )* );
131
132            fn to_sol_type(&'a self) -> Self::SolType {
133                #sol_ty_tuple
134            }
135        },
136    ))
137}
138
139/// Derives the `ink::SolDecode` trait for the given `enum`.
140fn sol_decode_derive_enum(s: synstructure::Structure) -> syn::Result<TokenStream2> {
141    utils::ensure_non_empty_enum(&s, "SolDecode")?;
142    ensure_empty_variants(&s, "SolDecode")?;
143    ensure_u8_max_variants(&s, "SolDecode")?;
144    ensure_consistent_variant_int_repr(&s, "SolDecode")?;
145    ensure_valid_discriminant_values(&s, "SolDecode")?;
146
147    let variants_match = s.variants().iter().enumerate().map(|(idx, variant)| {
148        let variant_ident = variant.ast().ident;
149        let int_repr = variant_int_repr(variant, idx as u8);
150        let field_delimiters = variant_field_delimiters(variant);
151        quote! {
152            #int_repr => {
153                ::core::result::Result::Ok(Self:: #variant_ident #field_delimiters)
154            }
155        }
156    });
157
158    Ok(s.bound_impl(
159        quote!(::ink::SolDecode),
160        quote! {
161            type SolType = ::core::primitive::u8;
162
163            fn from_sol_type(value: Self::SolType) -> ::core::result::Result<Self, ::ink::sol::Error> {
164                match value {
165                    #( #variants_match )*
166                    _ => ::core::result::Result::Err(::ink::sol::Error)
167                }
168            }
169        },
170    ))
171}
172
173/// Derives the `ink::SolEncode` trait for the given `enum`.
174fn sol_encode_derive_enum(mut s: synstructure::Structure) -> syn::Result<TokenStream2> {
175    utils::ensure_non_empty_enum(&s, "SolEncode")?;
176    ensure_empty_variants(&s, "SolEncode")?;
177    ensure_u8_max_variants(&s, "SolEncode")?;
178    ensure_consistent_variant_int_repr(&s, "SolEncode")?;
179    ensure_valid_discriminant_values(&s, "SolEncode")?;
180
181    let lifetime: GenericParam = syn::parse_quote!('a);
182    s.add_impl_generic(lifetime);
183
184    let variants_match = s.variants().iter().enumerate().map(|(idx, variant)| {
185        let variant_ident = variant.ast().ident;
186        let int_repr = variant_int_repr(variant, idx as u8);
187        let field_delimiters = variant_field_delimiters(variant);
188        quote! {
189            Self:: #variant_ident #field_delimiters => #int_repr,
190        }
191    });
192
193    Ok(s.bound_impl(
194        quote!(::ink::SolEncode<'a>),
195        quote! {
196            type SolType = ::core::primitive::u8;
197
198            fn to_sol_type(&'a self) -> Self::SolType {
199                match self {
200                    #( #variants_match )*
201                }
202            }
203        },
204    ))
205}
206
207/// Ensures that the given item has only unit-only or field-less variant.
208fn ensure_empty_variants(
209    s: &synstructure::Structure,
210    trait_name: &str,
211) -> syn::Result<()> {
212    let has_non_empty_variants = s
213        .variants()
214        .iter()
215        .any(|variant| !variant.ast().fields.is_empty());
216    if has_non_empty_variants {
217        Err(syn::Error::new(
218            s.ast().span(),
219            format!(
220                "can only derive `{trait_name}` for Rust `enum` items with \
221                only unit-only or field-less variants"
222            ),
223        ))
224    } else {
225        Ok(())
226    }
227}
228
229/// Ensures that the given item has at most `u8::MAX` variants.
230///
231/// # Note
232///
233/// Rust doesn't have an explicit limit on the number of allowed enum variants, however,
234/// the practical limit can be understood to `isize::MAX` as a `rustc` implementation
235/// detail.
236///
237/// References:
238///
239/// - <https://doc.rust-lang.org/reference/items/enumerations.html#r-items.enum.discriminant.repr-rust>
240/// - <https://github.com/rust-lang/rust/blob/f63685ddf3d3c92a61158cd55d44bde17c2b024f/compiler/rustc_ast/src/ast.rs#L3270>
241fn ensure_u8_max_variants(
242    s: &synstructure::Structure,
243    trait_name: &str,
244) -> syn::Result<()> {
245    if s.variants().len() > u8::MAX as usize {
246        Err(syn::Error::new(
247            s.ast().span(),
248            format!(
249                "can only derive `{trait_name}` for Rust `enum` items \
250                with at most `u8::MAX` variants"
251            ),
252        ))
253    } else {
254        Ok(())
255    }
256}
257
258/// Ensures that the given item will yield a consistent integer representation for all its
259/// variants.
260///
261/// # Note
262///
263/// This check only succeeds if one of the following conditions is met:
264/// - No variant has an explicitly set discriminant
265/// - All variants have an explicitly set discriminant
266fn ensure_consistent_variant_int_repr(
267    s: &synstructure::Structure,
268    trait_name: &str,
269) -> syn::Result<()> {
270    let n_variants_with_discriminants = s
271        .variants()
272        .iter()
273        .filter(|variant| variant.ast().discriminant.is_some())
274        .count();
275    if n_variants_with_discriminants > 0
276        && n_variants_with_discriminants != s.variants().len()
277    {
278        Err(syn::Error::new(
279            s.ast().span(),
280            format!(
281                "can only derive `{trait_name}` for Rust `enum` items that \
282                either have no variants with explicitly specified discriminants, \
283                or have explicitly specified discriminants for all variants"
284            ),
285        ))
286    } else {
287        Ok(())
288    }
289}
290
291/// Ensures that the given item has only valid discriminant values, if any are explicitly
292/// specified.
293///
294/// # Note
295///
296/// Enums are encoded as `u8` (i.e. `uint8` in Solidity ABI encoding), so the maximum
297/// allowed discriminant value is `u8::MAX` (i.e. `255`)
298///
299/// Rust does NOT have an explicit limit on the number of allowed enum variants, however,
300/// the practical limit can be understood to `isize::MAX` as a `rustc` implementation
301/// detail.
302///
303/// For unit-only enums, `rustc` **currently** enforces that any explicitly specified enum
304/// discriminant is not larger than `u8::MAX` if (and only if) the number of variants is
305/// also less than `u8::MAX`.
306///
307/// References:
308///
309/// - <https://doc.rust-lang.org/reference/items/enumerations.html#r-items.enum.discriminant.repr-rust>
310/// - <https://github.com/rust-lang/rust/blob/f63685ddf3d3c92a61158cd55d44bde17c2b024f/compiler/rustc_ast/src/ast.rs#L3270>
311fn ensure_valid_discriminant_values(
312    s: &synstructure::Structure,
313    trait_name: &str,
314) -> syn::Result<()> {
315    let offending_span = s.variants().iter().find_map(|variant| {
316        variant.ast().discriminant.as_ref().and_then(|(_, expr)| {
317            match expr {
318                Expr::Lit(expr) => {
319                    match &expr.lit {
320                        Lit::Int(value) => {
321                            let discr = value.base10_parse::<usize>().ok()?;
322                            (discr > u8::MAX as usize).then_some(expr.span())
323                        }
324                        _ => None,
325                    }
326                }
327                _ => None,
328            }
329        })
330    });
331    if let Some(span) = offending_span {
332        Err(syn::Error::new(
333            span,
334            format!(
335                "can only derive `{trait_name}` for Rust `enum` items \
336                with discriminant values (if explicitly specified) \
337                not larger than `u8::MAX`"
338            ),
339        ))
340    } else {
341        Ok(())
342    }
343}
344
345/// Returns an integer representation of an enum variant.
346fn variant_int_repr(variant: &VariantInfo, idx: u8) -> TokenStream2 {
347    variant
348        .ast()
349        .discriminant
350        .as_ref()
351        .map(|(_, expr)| quote!( #expr ))
352        .unwrap_or_else(|| quote! ( #idx ))
353}
354
355/// Returns the field delimiters for given a variant.
356fn variant_field_delimiters(variant: &VariantInfo) -> TokenStream2 {
357    match variant.ast().fields {
358        Fields::Named(_) => quote!({}),
359        Fields::Unnamed(_) => quote! { () },
360        Fields::Unit => quote!(),
361    }
362}