ink_macro/sol/
error.rs

1// Copyright (C) Use Ink (UK) Ltd.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use ink_ir::IsDocAttribute;
16use proc_macro2::{
17    Ident,
18    TokenStream as TokenStream2,
19};
20use quote::{
21    format_ident,
22    quote,
23};
24use syn::{
25    Attribute,
26    Field,
27    Fields,
28    spanned::Spanned,
29};
30
31use super::utils;
32
33/// Derives the `ink::sol::SolErrorDecode` trait for the given `struct` or `enum`.
34pub fn sol_error_decode_derive(s: synstructure::Structure) -> TokenStream2 {
35    match s.ast().data {
36        syn::Data::Struct(_) => {
37            sol_error_decode_derive_struct(s).unwrap_or_else(|err| err.to_compile_error())
38        }
39        syn::Data::Enum(_) => {
40            sol_error_decode_derive_enum(s).unwrap_or_else(|err| err.to_compile_error())
41        }
42        _ => {
43            syn::Error::new(
44                s.ast().span(),
45                "can only derive `SolErrorDecode` for Rust `struct` and `enum` items",
46            )
47            .to_compile_error()
48        }
49    }
50}
51
52/// Derives the `ink::sol::SolErrorEncode` trait for the given `struct` or `enum`.
53pub fn sol_error_encode_derive(s: synstructure::Structure) -> TokenStream2 {
54    match s.ast().data {
55        syn::Data::Struct(_) => {
56            sol_error_encode_derive_struct(s).unwrap_or_else(|err| err.to_compile_error())
57        }
58        syn::Data::Enum(_) => {
59            sol_error_encode_derive_enum(s).unwrap_or_else(|err| err.to_compile_error())
60        }
61        _ => {
62            syn::Error::new(
63                s.ast().span(),
64                "can only derive `SolErrorEncode` for Rust `struct` and `enum` items",
65            )
66            .to_compile_error()
67        }
68    }
69}
70
71/// Derives the `ink::metadata::sol::SolErrorMetadata` trait for the given `struct` or
72/// `enum`.
73pub fn sol_error_metadata_derive(s: synstructure::Structure) -> TokenStream2 {
74    match s.ast().data {
75        syn::Data::Struct(_) => {
76            sol_error_metadata_derive_struct(s)
77                .unwrap_or_else(|err| err.to_compile_error())
78        }
79        syn::Data::Enum(_) => {
80            sol_error_metadata_derive_enum(s).unwrap_or_else(|err| err.to_compile_error())
81        }
82        _ => {
83            syn::Error::new(
84                s.ast().span(),
85                "can only derive `SolErrorEncode` for Rust `struct` and `enum` items",
86            )
87            .to_compile_error()
88        }
89    }
90}
91
92/// Derives the `ink::sol::SolErrorDecode` trait for the given `struct`.
93fn sol_error_decode_derive_struct(
94    s: synstructure::Structure,
95) -> syn::Result<TokenStream2> {
96    ensure_no_generics(&s, "SolErrorDecode")?;
97
98    let Some(variant) = s.variants().first() else {
99        return Err(syn::Error::new(
100            s.ast().span(),
101            "can only derive `SolErrorDecode` for Rust `struct` items",
102        ));
103    };
104
105    let name = &s.ast().ident.to_string();
106    let fields = variant.ast().fields;
107    let params_tys = fields.iter().map(|field| &field.ty);
108    let params_tuple_ty = quote! {
109        ( #( #params_tys, )* )
110    };
111    let self_body = utils::body_from_fields(fields, None);
112
113    Ok(s.bound_impl(
114        quote!(::ink::sol::SolErrorDecode),
115        quote! {
116            fn decode(data: &[::core::primitive::u8]) -> ::core::result::Result<Self, ::ink::sol::Error>
117            where
118                Self: Sized,
119            {
120                const SELECTOR: [::core::primitive::u8; 4] = ::ink::sol_error_selector!(#name, #params_tuple_ty);
121                if data[..4] == SELECTOR {
122                    <#params_tuple_ty as ::ink::sol::SolParamsDecode>::decode(
123                        &data[4..],
124                    )
125                    .map(|value| {
126                        Self #self_body
127                    })
128                } else {
129                    Err(::ink::sol::Error)
130                }
131            }
132        },
133    ))
134}
135
136/// Derives the `ink::sol::SolErrorEncode` trait for the given `struct`.
137fn sol_error_encode_derive_struct(
138    s: synstructure::Structure,
139) -> syn::Result<TokenStream2> {
140    ensure_no_generics(&s, "SolErrorEncode")?;
141
142    let Some(variant) = s.variants().first() else {
143        return Err(syn::Error::new(
144            s.ast().span(),
145            "can only derive `SolErrorEncode` for Rust `struct` items",
146        ));
147    };
148
149    let name = &s.ast().ident.to_string();
150    let fields = variant.ast().fields;
151    let selector_params_tys = fields.iter().map(|field| &field.ty);
152    let encode_params_tys = fields.iter().map(|field| {
153        let ty = &field.ty;
154        quote!( &#ty )
155    });
156    let params_elems = utils::tuple_elems_from_fields(fields, None);
157
158    Ok(s.bound_impl(
159        quote!(::ink::sol::SolErrorEncode),
160        quote! {
161            fn encode(&self) -> ::ink::prelude::vec::Vec<::core::primitive::u8> {
162                let mut results = ::ink::prelude::vec::Vec::from(
163                    ::ink::sol_error_selector!(
164                        #name,
165                        ( #( #selector_params_tys, )* )
166                    )
167                );
168                results.extend(
169                    <( #( #encode_params_tys, )* ) as ::ink::sol::SolParamsEncode>::encode(
170                        &#params_elems,
171                    ),
172                );
173                results
174            }
175        },
176    ))
177}
178
179/// Derives the `ink::metadata::sol::SolErrorMetadata` trait for the given `struct`.
180fn sol_error_metadata_derive_struct(
181    s: synstructure::Structure,
182) -> syn::Result<TokenStream2> {
183    ensure_no_generics(&s, "SolErrorMetadata")?;
184
185    let Some(variant) = s.variants().first() else {
186        return Err(syn::Error::new(
187            s.ast().span(),
188            "can only derive `SolErrorMetadata` for Rust `struct` items",
189        ));
190    };
191
192    let ident = &s.ast().ident;
193    let name = ident.to_string();
194    let params = variant.ast().fields.iter().map(param_metadata_from_field);
195    let docs = extract_docs(s.ast().attrs.as_slice());
196    let metadata_linker = register_metadata(ident);
197
198    Ok(s.bound_impl(
199        quote!(::ink::metadata::sol::SolErrorMetadata),
200        quote! {
201            fn error_specs() -> ::ink::prelude::vec::Vec<::ink::metadata::sol::ErrorMetadata> {
202                #metadata_linker
203
204                vec![
205                    ::ink::metadata::sol::ErrorMetadata {
206                        name: #name.into(),
207                        params: vec![ #( #params ),* ],
208                        docs: #docs.into(),
209                    }
210                ]
211            }
212        },
213    ))
214}
215
216/// Derives the `ink::sol::SolErrorDecode` trait for the given `enum`.
217fn sol_error_decode_derive_enum(s: synstructure::Structure) -> syn::Result<TokenStream2> {
218    ensure_no_generics(&s, "SolErrorDecode")?;
219    utils::ensure_non_empty_enum(&s, "SolErrorDecode")?;
220
221    let variant_selector_ident = |idx: usize| format_ident!("VARIANT_{}", idx);
222    let variant_selectors = s.variants().iter().enumerate().map(|(idx, variant)| {
223        let selector_ident = variant_selector_ident(idx);
224        let variant_name = variant.ast().ident.to_string();
225        let fields = variant.ast().fields;
226        let param_tys = fields.iter().map(|field| &field.ty);
227        quote! {
228            const #selector_ident: [::core::primitive::u8; 4] = ::ink::sol_error_selector!(
229                #variant_name, ( #( #param_tys, )* )
230            );
231        }
232    });
233    let variants_match = s.variants().iter().enumerate().map(|(idx, variant)| {
234        let variant_ident = variant.ast().ident;
235        let selector_ident = variant_selector_ident(idx);
236        let fields = variant.ast().fields;
237        let param_tys = fields.iter().map(|field| &field.ty);
238        let variant_body = utils::body_from_fields(fields, None);
239        quote! {
240            #selector_ident => {
241                <( #( #param_tys, )* ) as ::ink::sol::SolParamsDecode>::decode(
242                    &data[4..],
243                )
244                .map(|value| {
245                    Self:: #variant_ident #variant_body
246                })
247            }
248        }
249    });
250
251    Ok(s.bound_impl(
252        quote!(::ink::sol::SolErrorDecode),
253        quote! {
254            fn decode(data: &[::core::primitive::u8]) -> ::core::result::Result<Self, ::ink::sol::Error>
255            where
256                Self: Sized,
257            {
258                let selector: [::core::primitive::u8; 4] = data[..4].try_into().map_err(|_| ::ink::sol::Error)?;
259
260                #( #variant_selectors )*
261
262                match selector {
263                    #( #variants_match )*
264                    _ => Err(::ink::sol::Error),
265                }
266            }
267        },
268    ))
269}
270
271/// Derives the `ink::sol::SolErrorEncode` trait for the given `enum`.
272fn sol_error_encode_derive_enum(s: synstructure::Structure) -> syn::Result<TokenStream2> {
273    ensure_no_generics(&s, "SolErrorEncode")?;
274    utils::ensure_non_empty_enum(&s, "SolErrorEncode")?;
275
276    let variants_match = s.variants().iter().map(|variant| {
277        let variant_ident = variant.ast().ident;
278        let variant_name = variant_ident.to_string();
279        let fields = variant.ast().fields;
280        let selector_params_tys = fields.iter().map(|field| &field.ty);
281        let encode_params_tys = fields.iter().map(|field| {
282            let ty = &field.ty;
283            quote!( &#ty )
284        });
285        let bindings = || {
286            variant.bindings().iter().map(|info| {
287                // var is either a field name, or generated "binding_*" name for tuple
288                // elements.
289                info
290                    .ast()
291                    .ident
292                    .as_ref()
293                    .map(|ident| quote!(#ident))
294                    .unwrap_or_else(|| {
295                        let binding = &info.binding;
296                        quote!(#binding)
297                    })
298            })
299        };
300        let (variant_bindings, params_elems) = match fields {
301            // Handles named fields.
302            Fields::Named(_) => {
303                let variant_fields = bindings();
304                let params_elems = quote! {
305                    #( #variant_fields, )*
306                };
307                (
308                    quote!(
309                        {
310                            #params_elems
311                        }
312                    ),
313                    params_elems,
314                )
315            }
316            // Handles tuple elements.
317            Fields::Unnamed(_) => {
318                let variant_elems = bindings();
319                let params_elems = quote! {
320                    #( #variant_elems, )*
321                };
322                (
323                    quote! {
324                        ( #params_elems )
325                    },
326                    params_elems,
327                )
328            }
329            // Handles unit variants.
330            Fields::Unit => (quote!(), quote!()),
331        };
332
333        quote! {
334            Self:: #variant_ident #variant_bindings => {
335                let mut results = ::ink::prelude::vec::Vec::from(
336                    ::ink::sol_error_selector!(
337                        #variant_name,
338                        ( #( #selector_params_tys, )* )
339                    )
340                );
341                results.extend(
342                    <( #( #encode_params_tys, )* ) as ::ink::sol::SolParamsEncode>::encode(
343                        &( #params_elems ),
344                    ),
345                );
346                results
347            }
348        }
349    });
350
351    Ok(s.bound_impl(
352        quote!(::ink::sol::SolErrorEncode),
353        quote! {
354            fn encode(&self) -> ::ink::prelude::vec::Vec<::core::primitive::u8> {
355                match self {
356                    #( #variants_match )*
357                }
358            }
359        },
360    ))
361}
362
363/// Derives the `ink::metadata::sol::SolErrorMetadata` trait for the given `enum`.
364fn sol_error_metadata_derive_enum(
365    s: synstructure::Structure,
366) -> syn::Result<TokenStream2> {
367    ensure_no_generics(&s, "SolErrorMetadata")?;
368    utils::ensure_non_empty_enum(&s, "SolErrorMetadata")?;
369
370    let error_variants = s.variants().iter().map(|variant| {
371        let variant_ident = variant.ast().ident;
372        let variant_name = variant_ident.to_string();
373        let params = variant.ast().fields.iter().map(param_metadata_from_field);
374        let docs = extract_docs(variant.ast().attrs);
375
376        quote! {
377            ::ink::metadata::sol::ErrorMetadata {
378                name: #variant_name.into(),
379                params: vec![ #( #params ),* ],
380                docs: #docs.into(),
381            }
382        }
383    });
384    let metadata_linker = register_metadata(&s.ast().ident);
385
386    Ok(s.bound_impl(
387        quote!(::ink::metadata::sol::SolErrorMetadata),
388        quote! {
389            fn error_specs() -> ::ink::prelude::vec::Vec<::ink::metadata::sol::ErrorMetadata> {
390                #metadata_linker
391
392                vec![ #( #error_variants ),* ]
393            }
394        },
395    ))
396}
397
398/// Ensures that the given item has no generics.
399fn ensure_no_generics(s: &synstructure::Structure, trait_name: &str) -> syn::Result<()> {
400    if s.ast().generics.params.is_empty() {
401        Ok(())
402    } else {
403        Err(syn::Error::new(
404            s.ast().generics.params.span(),
405            format!(
406                "can only derive `{trait_name}` for Rust `struct` or `enum` \
407                items without generics"
408            ),
409        ))
410    }
411}
412
413/// Register an error metadata function in the distributed slice for combining all
414/// errors referenced in the contract binary.
415fn register_metadata(ident: &Ident) -> TokenStream2 {
416    quote! {
417       #[::ink::linkme::distributed_slice(::ink::CONTRACT_ERRORS_SOL)]
418       #[linkme(crate = ::ink::linkme)]
419       static ERROR_METADATA: fn() -> ::ink::prelude::vec::Vec<::ink::metadata::sol::ErrorMetadata> =
420           <#ident as ::ink::metadata::sol::SolErrorMetadata>::error_specs;
421    }
422}
423
424/// Returns the error parameter from the given field.
425fn param_metadata_from_field(field: &Field) -> TokenStream2 {
426    let ty = &field.ty;
427    let name = field
428        .ident
429        .as_ref()
430        .map(ToString::to_string)
431        .unwrap_or_default();
432    let docs = extract_docs(field.attrs.as_slice());
433    let sol_ty = quote! {
434        <#ty as ::ink::SolEncode>::SOL_NAME
435    };
436    quote! {
437        ::ink::metadata::sol::ErrorParamMetadata {
438            name: #name.into(),
439            ty: #sol_ty.into(),
440            docs: #docs.into(),
441        }
442    }
443}
444
445/// Returns the rustdoc string from the given item attributes.
446fn extract_docs(attrs: &[Attribute]) -> String {
447    attrs
448        .iter()
449        .filter_map(|attr| attr.extract_docs())
450        .collect::<Vec<_>>()
451        .join("\n")
452}