ink_macro/sol/
error.rs

1// Copyright (C) Use Ink (UK) Ltd.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use proc_macro2::TokenStream as TokenStream2;
16use quote::{
17    format_ident,
18    quote,
19};
20use syn::{
21    spanned::Spanned,
22    Fields,
23};
24
25use super::utils;
26
27/// Derives the `ink::sol::SolErrorDecode` trait for the given `struct` or `enum`.
28pub fn sol_error_decode_derive(s: synstructure::Structure) -> TokenStream2 {
29    match s.ast().data {
30        syn::Data::Struct(_) => {
31            sol_error_decode_derive_struct(s).unwrap_or_else(|err| err.to_compile_error())
32        }
33        syn::Data::Enum(_) => {
34            sol_error_decode_derive_enum(s).unwrap_or_else(|err| err.to_compile_error())
35        }
36        _ => {
37            syn::Error::new(
38                s.ast().span(),
39                "can only derive `SolErrorDecode` for Rust `struct` and `enum` items",
40            )
41            .to_compile_error()
42        }
43    }
44}
45
46/// Derives the `ink::sol::SolErrorEncode` trait for the given `struct` or `enum`.
47pub fn sol_error_encode_derive(s: synstructure::Structure) -> TokenStream2 {
48    match s.ast().data {
49        syn::Data::Struct(_) => {
50            sol_error_encode_derive_struct(s).unwrap_or_else(|err| err.to_compile_error())
51        }
52        syn::Data::Enum(_) => {
53            sol_error_encode_derive_enum(s).unwrap_or_else(|err| err.to_compile_error())
54        }
55        _ => {
56            syn::Error::new(
57                s.ast().span(),
58                "can only derive `SolErrorEncode` for Rust `struct` and `enum` items",
59            )
60            .to_compile_error()
61        }
62    }
63}
64
65/// Derives the `ink::sol::SolErrorDecode` trait for the given `struct`.
66fn sol_error_decode_derive_struct(
67    s: synstructure::Structure,
68) -> syn::Result<TokenStream2> {
69    ensure_no_generics(&s, "SolErrorDecode")?;
70
71    let Some(variant) = s.variants().first() else {
72        return Err(syn::Error::new(
73            s.ast().span(),
74            "can only derive `SolErrorDecode` for Rust `struct` items",
75        ));
76    };
77
78    let name = &s.ast().ident.to_string();
79    let fields = variant.ast().fields;
80    let params_tys = fields.iter().map(|field| &field.ty);
81    let params_tuple_ty = quote! {
82        ( #( #params_tys, )* )
83    };
84    let self_body = utils::body_from_fields(fields, None);
85
86    Ok(s.bound_impl(
87        quote!(::ink::sol::SolErrorDecode),
88        quote! {
89            fn decode(data: &[::core::primitive::u8]) -> ::core::result::Result<Self, ::ink::sol::Error>
90            where
91                Self: Sized,
92            {
93                const SELECTOR: [::core::primitive::u8; 4] = ::ink::sol_error_selector!(#name, #params_tuple_ty);
94                if data[..4] == SELECTOR {
95                    <#params_tuple_ty as ::ink::sol::SolParamsDecode>::decode(
96                        &data[4..],
97                    )
98                    .map(|value| {
99                        Self #self_body
100                    })
101                } else {
102                    Err(::ink::sol::Error)
103                }
104            }
105        },
106    ))
107}
108
109/// Derives the `ink::sol::SolErrorEncode` trait for the given `struct`.
110fn sol_error_encode_derive_struct(
111    s: synstructure::Structure,
112) -> syn::Result<TokenStream2> {
113    ensure_no_generics(&s, "SolErrorEncode")?;
114
115    let Some(variant) = s.variants().first() else {
116        return Err(syn::Error::new(
117            s.ast().span(),
118            "can only derive `SolErrorEncode` for Rust `struct` items",
119        ));
120    };
121
122    let name = &s.ast().ident.to_string();
123    let fields = variant.ast().fields;
124    let selector_params_tys = fields.iter().map(|field| &field.ty);
125    let encode_params_tys = fields.iter().map(|field| {
126        let ty = &field.ty;
127        quote!( &#ty )
128    });
129    let params_elems = utils::tuple_elems_from_fields(fields, None);
130
131    Ok(s.bound_impl(
132        quote!(::ink::sol::SolErrorEncode),
133        quote! {
134            fn encode(&self) -> ::ink::prelude::vec::Vec<::core::primitive::u8> {
135                let mut results = ::ink::prelude::vec::Vec::from(
136                    ::ink::sol_error_selector!(
137                        #name,
138                        ( #( #selector_params_tys, )* )
139                    )
140                );
141                results.extend(
142                    <( #( #encode_params_tys, )* ) as ::ink::sol::SolParamsEncode>::encode(
143                        &#params_elems,
144                    ),
145                );
146                results
147            }
148        },
149    ))
150}
151
152/// Derives the `ink::sol::SolErrorDecode` trait for the given `enum`.
153fn sol_error_decode_derive_enum(s: synstructure::Structure) -> syn::Result<TokenStream2> {
154    ensure_no_generics(&s, "SolErrorDecode")?;
155    utils::ensure_non_empty_enum(&s, "SolErrorDecode")?;
156
157    let variant_selector_ident = |idx: usize| format_ident!("VARIANT_{}", idx);
158    let variant_selectors = s.variants().iter().enumerate().map(|(idx, variant)| {
159        let selector_ident = variant_selector_ident(idx);
160        let variant_name = variant.ast().ident.to_string();
161        let fields = variant.ast().fields;
162        let param_tys = fields.iter().map(|field| &field.ty);
163        quote! {
164            const #selector_ident: [::core::primitive::u8; 4] = ::ink::sol_error_selector!(
165                #variant_name, ( #( #param_tys, )* )
166            );
167        }
168    });
169    let variants_match = s.variants().iter().enumerate().map(|(idx, variant)| {
170        let variant_ident = variant.ast().ident;
171        let selector_ident = variant_selector_ident(idx);
172        let fields = variant.ast().fields;
173        let param_tys = fields.iter().map(|field| &field.ty);
174        let variant_body = utils::body_from_fields(fields, None);
175        quote! {
176            #selector_ident => {
177                <( #( #param_tys, )* ) as ::ink::sol::SolParamsDecode>::decode(
178                    &data[4..],
179                )
180                .map(|value| {
181                    Self:: #variant_ident #variant_body
182                })
183            }
184        }
185    });
186
187    Ok(s.bound_impl(
188        quote!(::ink::sol::SolErrorDecode),
189        quote! {
190            fn decode(data: &[::core::primitive::u8]) -> ::core::result::Result<Self, ::ink::sol::Error>
191            where
192                Self: Sized,
193            {
194                let selector: [::core::primitive::u8; 4] = data[..4].try_into().map_err(|_| ::ink::sol::Error)?;
195
196                #( #variant_selectors )*
197
198                match selector {
199                    #( #variants_match )*
200                    _ => Err(::ink::sol::Error),
201                }
202            }
203        },
204    ))
205}
206
207/// Derives the `ink::sol::SolErrorEncode` trait for the given `enum`.
208fn sol_error_encode_derive_enum(s: synstructure::Structure) -> syn::Result<TokenStream2> {
209    ensure_no_generics(&s, "SolErrorEncode")?;
210    utils::ensure_non_empty_enum(&s, "SolErrorEncode")?;
211
212    let variants_match = s.variants().iter().map(|variant| {
213        let variant_ident = variant.ast().ident;
214        let variant_name = variant_ident.to_string();
215        let fields = variant.ast().fields;
216        let selector_params_tys = fields.iter().map(|field| &field.ty);
217        let encode_params_tys = fields.iter().map(|field| {
218            let ty = &field.ty;
219            quote!( &#ty )
220        });
221        let bindings = || {
222            variant.bindings().iter().map(|info| {
223                // var is either a field name, or generated "binding_*" name for tuple
224                // elements.
225                let var_name = info
226                    .ast()
227                    .ident
228                    .as_ref()
229                    .map(|ident| quote!(#ident))
230                    .unwrap_or_else(|| {
231                        let binding = &info.binding;
232                        quote!(#binding)
233                    });
234                var_name
235            })
236        };
237        let (variant_bindings, params_elems) = match fields {
238            // Handles named fields.
239            Fields::Named(_) => {
240                let variant_fields = bindings();
241                let params_elems = quote! {
242                    #( #variant_fields, )*
243                };
244                (
245                    quote!(
246                        {
247                            #params_elems
248                        }
249                    ),
250                    params_elems,
251                )
252            }
253            // Handles tuple elements.
254            Fields::Unnamed(_) => {
255                let variant_elems = bindings();
256                let params_elems = quote! {
257                    #( #variant_elems, )*
258                };
259                (
260                    quote! {
261                        ( #params_elems )
262                    },
263                    params_elems,
264                )
265            }
266            // Handles unit variants.
267            Fields::Unit => (quote!(), quote!()),
268        };
269
270        quote! {
271            Self:: #variant_ident #variant_bindings => {
272                let mut results = ::ink::prelude::vec::Vec::from(
273                    ::ink::sol_error_selector!(
274                        #variant_name,
275                        ( #( #selector_params_tys, )* )
276                    )
277                );
278                results.extend(
279                    <( #( #encode_params_tys, )* ) as ::ink::sol::SolParamsEncode>::encode(
280                        &( #params_elems ),
281                    ),
282                );
283                results
284            }
285        }
286    });
287
288    Ok(s.bound_impl(
289        quote!(::ink::sol::SolErrorEncode),
290        quote! {
291            fn encode(&self) -> ::ink::prelude::vec::Vec<::core::primitive::u8> {
292                match self {
293                    #( #variants_match )*
294                }
295            }
296        },
297    ))
298}
299
300/// Ensures that the given item has no generics.
301fn ensure_no_generics(s: &synstructure::Structure, trait_name: &str) -> syn::Result<()> {
302    if s.ast().generics.params.is_empty() {
303        Ok(())
304    } else {
305        Err(syn::Error::new(
306            s.ast().generics.params.span(),
307            format!(
308                "can only derive `{trait_name}` for Rust `struct` or `enum` \
309                items without generics"
310            ),
311        ))
312    }
313}