From 4b655b96c29f86a4cef7ea4b43081c5842931692 Mon Sep 17 00:00:00 2001 From: DS Date: Fri, 12 Jun 2026 04:11:35 -0700 Subject: [PATCH 1/3] Refactor rustler attr parsing into its own module This is done in preparation to add field attributes to rustler, which will reuse some existing logic. --- rustler_codegen/src/attrs.rs | 93 ++++++++++++++++++++++++++++++++++ rustler_codegen/src/context.rs | 63 +++-------------------- rustler_codegen/src/lib.rs | 9 +--- 3 files changed, 101 insertions(+), 64 deletions(-) create mode 100644 rustler_codegen/src/attrs.rs diff --git a/rustler_codegen/src/attrs.rs b/rustler_codegen/src/attrs.rs new file mode 100644 index 00000000..9adca154 --- /dev/null +++ b/rustler_codegen/src/attrs.rs @@ -0,0 +1,93 @@ +use std::fmt::Display; + +use syn::{Ident, Lit, Meta}; + +pub(crate) trait TryFromRustlerNestedAttr: Sized { + fn collect_attrs_for_ident(ident: &Ident, meta: &Meta) -> Option>; + fn try_from_rustler_nested_attr(ident: &Ident) -> Option; + fn parse_failure_message() -> impl Display; + + fn parse_rustler(meta: &Meta) -> Vec { + if let Meta::List(ref list) = meta { + let mut attrs: Vec = vec![]; + let _ = list.parse_nested_meta(|nested_meta| { + let parsed_attr = nested_meta + .path + .get_ident() + .and_then(T::try_from_rustler_nested_attr); + + match parsed_attr { + None => Err(nested_meta.error(T::parse_failure_message())), + Some(attr) => { + attrs.push(attr); + Ok(()) + } + } + }); + + return attrs; + } + + panic!("Expected nested attributes inside the rustler attribute"); + } +} + +#[derive(Debug)] +pub(crate) enum RustlerAttr { + Encode, + Decode, + Module(String), + Tag(String), +} + +impl RustlerAttr { + fn try_parse_tag(meta: &Meta) -> Option> { + if let Meta::NameValue(ref name_value) = meta { + let expr = &name_value.value; + + if let syn::Expr::Lit(lit_expr) = expr { + if let Lit::Str(ref tag) = lit_expr.lit { + return Some(vec![RustlerAttr::Tag(tag.value())]); + } + } + } + panic!("Cannot parse tag") + } + + fn try_parse_module(meta: &Meta) -> Option> { + if let Meta::NameValue(name_value) = meta { + let expr = &name_value.value; + + if let syn::Expr::Lit(lit_expr) = expr { + if let Lit::Str(ref module) = lit_expr.lit { + let ident = format!("Elixir.{}", module.value()); + return Some(vec![RustlerAttr::Module(ident)]); + } + } + } + panic!("Cannot parse module") + } +} + +impl TryFromRustlerNestedAttr for RustlerAttr { + fn parse_failure_message() -> impl Display { + "Expected encode, decode and/or optional in rustler attribute" + } + + fn collect_attrs_for_ident(ident: &Ident, meta: &Meta) -> Option> { + match ident.to_string().as_ref() { + "rustler" => Some(Self::parse_rustler(meta)), + "tag" => Self::try_parse_tag(meta), + "module" => Self::try_parse_module(meta), + _ => None, + } + } + + fn try_from_rustler_nested_attr(ident: &Ident) -> Option { + match ident.to_string().as_ref() { + "encode" => Some(Self::Encode), + "decode" => Some(Self::Decode), + _ => None, + } + } +} diff --git a/rustler_codegen/src/context.rs b/rustler_codegen/src/context.rs index 843e11ab..14432014 100644 --- a/rustler_codegen/src/context.rs +++ b/rustler_codegen/src/context.rs @@ -1,9 +1,9 @@ use heck::ToSnakeCase; use proc_macro2::{Span, TokenStream}; use quote::quote; -use syn::{Data, Field, Fields, Ident, Lifetime, Lit, Meta, TypeParam, Variant}; +use syn::{Data, Field, Fields, Ident, Lifetime, TypeParam, Variant}; -use super::RustlerAttr; +use crate::attrs::{RustlerAttr, TryFromRustlerNestedAttr}; /// /// A parsing context struct. @@ -168,67 +168,18 @@ impl<'a> Context<'a> { } fn get_rustler_attrs(attr: &syn::Attribute) -> Vec { + Self::parse_attr::(attr) + } + + fn parse_attr(attr: &syn::Attribute) -> Vec { attr.path() .segments .iter() .filter_map(|segment| { let meta = &attr.meta; - match segment.ident.to_string().as_ref() { - "rustler" => Some(Context::parse_rustler(meta)), - "tag" => Context::try_parse_tag(meta), - "module" => Context::try_parse_module(meta), - _ => None, - } + T::collect_attrs_for_ident(&segment.ident, meta) }) .flatten() .collect() } - - fn parse_rustler(meta: &Meta) -> Vec { - if let Meta::List(ref list) = meta { - let mut attrs: Vec = vec![]; - let _ = list.parse_nested_meta(|nested_meta| { - if nested_meta.path.is_ident("encode") { - attrs.push(RustlerAttr::Encode); - Ok(()) - } else if nested_meta.path.is_ident("decode") { - attrs.push(RustlerAttr::Decode); - Ok(()) - } else { - Err(nested_meta.error("Expected encode and/or decode in rustler attribute")) - } - }); - - return attrs; - } - - panic!("Expected encode and/or decode in rustler attribute"); - } - - fn try_parse_tag(meta: &Meta) -> Option> { - if let Meta::NameValue(ref name_value) = meta { - let expr = &name_value.value; - - if let syn::Expr::Lit(lit_expr) = expr { - if let Lit::Str(ref tag) = lit_expr.lit { - return Some(vec![RustlerAttr::Tag(tag.value())]); - } - } - } - panic!("Cannot parse tag") - } - - fn try_parse_module(meta: &Meta) -> Option> { - if let Meta::NameValue(name_value) = meta { - let expr = &name_value.value; - - if let syn::Expr::Lit(lit_expr) = expr { - if let Lit::Str(ref module) = lit_expr.lit { - let ident = format!("Elixir.{}", module.value()); - return Some(vec![RustlerAttr::Module(ident)]); - } - } - } - panic!("Cannot parse module") - } } diff --git a/rustler_codegen/src/lib.rs b/rustler_codegen/src/lib.rs index 9792a3f6..2b4a8bf9 100644 --- a/rustler_codegen/src/lib.rs +++ b/rustler_codegen/src/lib.rs @@ -2,6 +2,7 @@ use proc_macro::TokenStream; +pub(crate) mod attrs; mod context; mod encode_decode_templates; mod ex_struct; @@ -15,14 +16,6 @@ mod tuple; mod unit_enum; mod untagged_enum; -#[derive(Debug)] -enum RustlerAttr { - Encode, - Decode, - Module(String), - Tag(String), -} - /// Initialise the Native Implemented Function (NIF) environment /// and register NIF functions in an Elixir module. /// From af89764a88232af9b3b2a51e380bcede9bc33c08 Mon Sep 17 00:00:00 2001 From: DS Date: Fri, 12 Jun 2026 04:11:35 -0700 Subject: [PATCH 2/3] Introduce rustler field attributes and optional_decode Since field attributes might change the behaviour of code generated for a specific field, we introduce an extra struct to hold information about the field and which attributes it has to make it easier for future code to work with them. --- rustler_codegen/src/attrs.rs | 29 +++++++++++- rustler_codegen/src/context.rs | 75 +++++++++++++++++++++++++++++--- rustler_codegen/src/ex_struct.rs | 19 ++++---- rustler_codegen/src/lib.rs | 2 +- rustler_codegen/src/map.rs | 18 ++++---- rustler_codegen/src/record.rs | 17 ++++---- rustler_codegen/src/tuple.rs | 16 ++++--- 7 files changed, 136 insertions(+), 40 deletions(-) diff --git a/rustler_codegen/src/attrs.rs b/rustler_codegen/src/attrs.rs index 9adca154..504051e6 100644 --- a/rustler_codegen/src/attrs.rs +++ b/rustler_codegen/src/attrs.rs @@ -36,6 +36,7 @@ pub(crate) trait TryFromRustlerNestedAttr: Sized { pub(crate) enum RustlerAttr { Encode, Decode, + OptionalDecode, Module(String), Tag(String), } @@ -71,7 +72,7 @@ impl RustlerAttr { impl TryFromRustlerNestedAttr for RustlerAttr { fn parse_failure_message() -> impl Display { - "Expected encode, decode and/or optional in rustler attribute" + "Expected encode, decode and/or optional_decode in rustler attribute" } fn collect_attrs_for_ident(ident: &Ident, meta: &Meta) -> Option> { @@ -87,6 +88,32 @@ impl TryFromRustlerNestedAttr for RustlerAttr { match ident.to_string().as_ref() { "encode" => Some(Self::Encode), "decode" => Some(Self::Decode), + "optional_decode" => Some(Self::OptionalDecode), + _ => None, + } + } +} + +#[derive(Debug)] +pub(crate) enum RustlerFieldAttr { + OptionalDecode, +} + +impl TryFromRustlerNestedAttr for RustlerFieldAttr { + fn parse_failure_message() -> impl Display { + "Expected optional_decode in rustler field attribute" + } + + fn collect_attrs_for_ident(ident: &Ident, meta: &Meta) -> Option> { + match ident.to_string().as_ref() { + "rustler" => Some(Self::parse_rustler(meta)), + _ => None, + } + } + + fn try_from_rustler_nested_attr(ident: &Ident) -> Option { + match ident.to_string().as_ref() { + "optional_decode" => Some(Self::OptionalDecode), _ => None, } } diff --git a/rustler_codegen/src/context.rs b/rustler_codegen/src/context.rs index 14432014..f0b6e54c 100644 --- a/rustler_codegen/src/context.rs +++ b/rustler_codegen/src/context.rs @@ -1,9 +1,30 @@ use heck::ToSnakeCase; use proc_macro2::{Span, TokenStream}; use quote::quote; -use syn::{Data, Field, Fields, Ident, Lifetime, TypeParam, Variant}; +use syn::{Data, Field, Fields, Ident, Lifetime, Type, TypeParam, Variant}; -use crate::attrs::{RustlerAttr, TryFromRustlerNestedAttr}; +use crate::attrs::{RustlerAttr, RustlerFieldAttr, TryFromRustlerNestedAttr}; + +/// +/// A helper struct to make it easier to access field attributes. +/// +/// `StructField` holds a reference to the field itself as well as any parsed attributes declared on that field. +/// +pub(crate) struct StructField<'a> { + pub field: &'a Field, + pub attrs: Vec, + + /// `true` if we identified that this field is of type `Option`. + pub is_option_type: bool, +} + +impl<'a> StructField<'a> { + pub fn optional_decode(&self) -> bool { + self.attrs + .iter() + .any(|attr| matches!(attr, RustlerFieldAttr::OptionalDecode)) + } +} /// /// A parsing context struct. @@ -17,7 +38,7 @@ pub(crate) struct Context<'a> { pub lifetimes: Vec, pub type_parameters: Vec, pub variants: Option>, - pub struct_fields: Option>, + pub struct_fields: Option>>, pub is_tuple_struct: bool, } @@ -43,7 +64,13 @@ impl<'a> Context<'a> { }; let struct_fields = match ast.data { - Data::Struct(ref data_struct) => Some(data_struct.fields.iter().collect()), + Data::Struct(ref data_struct) => Some( + data_struct + .fields + .iter() + .map(Self::struct_field_with_parsed_attrs) + .collect(), + ), _ => None, }; @@ -104,14 +131,20 @@ impl<'a> Context<'a> { .any(|attr| matches!(attr, RustlerAttr::Decode)) } + pub fn optional_decode(&self) -> bool { + self.attrs + .iter() + .any(|attr| matches!(attr, RustlerAttr::OptionalDecode)) + } + pub fn field_atoms(&self) -> Option> { self.struct_fields.as_ref().map(|struct_fields| { struct_fields .iter() .map(|field| { - let atom_fun = Self::field_to_atom_fun(field); + let atom_fun = Self::field_to_atom_fun(field.field); - let ident = field.ident.as_ref().unwrap(); + let ident = field.field.ident.as_ref().unwrap(); let ident_str = ident.to_string(); let ident_str = Self::remove_raw(&ident_str); @@ -135,6 +168,32 @@ impl<'a> Context<'a> { Ident::new(&format!("atom_{ident_str}"), Span::call_site()) } + fn struct_field_with_parsed_attrs(field: &'a Field) -> StructField<'a> { + let attrs: Vec<_> = field + .attrs + .iter() + .flat_map(Self::get_rustler_field_attrs) + .collect(); + + StructField { + field, + attrs, + is_option_type: Self::is_option_type(&field.ty), + } + } + + fn is_option_type(t: &Type) -> bool { + match t { + Type::Path(type_path) => { + // Has a chance of returning false negatives (in case Option was aliased to another name and the field uses the other name as the type) and false positives (if some other module has an `Option` type - we only check that the name is `Option`). + let type_name = type_path.path.segments.last().unwrap(); + type_name.ident == "Option" + } + Type::Paren(p) => Self::is_option_type(&p.elem), + _ => false, + } + } + pub fn escape_ident_with_index(ident_str: &str, index: usize, infix: &str) -> Ident { Ident::new( &format!( @@ -171,6 +230,10 @@ impl<'a> Context<'a> { Self::parse_attr::(attr) } + fn get_rustler_field_attrs(attr: &syn::Attribute) -> Vec { + Self::parse_attr::(attr) + } + fn parse_attr(attr: &syn::Attribute) -> Vec { attr.path() .segments diff --git a/rustler_codegen/src/ex_struct.rs b/rustler_codegen/src/ex_struct.rs index 71cff834..2a8e36de 100644 --- a/rustler_codegen/src/ex_struct.rs +++ b/rustler_codegen/src/ex_struct.rs @@ -1,10 +1,11 @@ use proc_macro2::{Span, TokenStream}; use quote::{quote, quote_spanned}; -use syn::{self, spanned::Spanned, Field, Ident}; +use syn::{self, spanned::Spanned, Ident}; use super::context::Context; -use super::RustlerAttr; +use crate::attrs::RustlerAttr; +use crate::context::StructField; pub fn transcoder_decorator(ast: &syn::DeriveInput, add_exception: bool) -> TokenStream { let ctx = Context::from_ast(ast); @@ -65,13 +66,13 @@ pub fn transcoder_decorator(ast: &syn::DeriveInput, add_exception: bool) -> Toke gen } -fn gen_decoder(ctx: &Context, fields: &[&Field], atoms_module_name: &Ident) -> TokenStream { +fn gen_decoder(ctx: &Context, fields: &[StructField], atoms_module_name: &Ident) -> TokenStream { let struct_name = ctx.ident; let struct_name_str = struct_name.to_string(); let idents: Vec<_> = fields .iter() - .map(|field| field.ident.as_ref().unwrap()) + .map(|field| field.field.ident.as_ref().unwrap()) .collect(); let (assignments, field_defs): (Vec, Vec) = fields @@ -79,10 +80,10 @@ fn gen_decoder(ctx: &Context, fields: &[&Field], atoms_module_name: &Ident) -> T .zip(idents.iter()) .enumerate() .map(|(index, (field, ident))| { - let atom_fun = Context::field_to_atom_fun(field); + let atom_fun = Context::field_to_atom_fun(field.field); let variable = Context::escape_ident_with_index(&ident.to_string(), index, "struct"); - let assignment = quote_spanned! { field.span() => + let assignment = quote_spanned! { field.field.span() => let #variable = try_decode_field(term, #atom_fun())?; }; @@ -131,7 +132,7 @@ fn gen_decoder(ctx: &Context, fields: &[&Field], atoms_module_name: &Ident) -> T fn gen_encoder( ctx: &Context, - fields: &[&Field], + fields: &[StructField], atoms_module_name: &Ident, add_exception: bool, ) -> TokenStream { @@ -144,8 +145,8 @@ fn gen_encoder( let (mut data_keys, mut data_values): (Vec<_>, Vec<_>) = fields .iter() .map(|field| { - let field_ident = field.ident.as_ref().unwrap(); - let atom_fun = Context::field_to_atom_fun(field); + let field_ident = field.field.ident.as_ref().unwrap(); + let atom_fun = Context::field_to_atom_fun(field.field); ( quote! { ::rustler::Encoder::encode(&#atom_fun(), env) }, quote! { ::rustler::Encoder::encode(&self.#field_ident, env) }, diff --git a/rustler_codegen/src/lib.rs b/rustler_codegen/src/lib.rs index 2b4a8bf9..58d7b389 100644 --- a/rustler_codegen/src/lib.rs +++ b/rustler_codegen/src/lib.rs @@ -185,7 +185,7 @@ pub fn nif_exception(input: TokenStream) -> TokenStream { /// ``` /// /// And vice versa, decoding this map would result in `value`. -#[proc_macro_derive(NifMap, attributes(rustler))] +#[proc_macro_derive(NifMap, attributes(rustler, optional_decode))] pub fn nif_map(input: TokenStream) -> TokenStream { let ast = syn::parse(input).unwrap(); map::transcoder_decorator(&ast).into() diff --git a/rustler_codegen/src/map.rs b/rustler_codegen/src/map.rs index 51c4ef13..159363b0 100644 --- a/rustler_codegen/src/map.rs +++ b/rustler_codegen/src/map.rs @@ -1,7 +1,9 @@ use proc_macro2::{Span, TokenStream}; use quote::{quote, quote_spanned}; -use syn::{self, spanned::Spanned, Field, Ident}; +use syn::{self, spanned::Spanned, Ident}; + +use crate::context::StructField; use super::context::Context; @@ -50,12 +52,12 @@ pub fn transcoder_decorator(ast: &syn::DeriveInput) -> TokenStream { gen } -fn gen_decoder(ctx: &Context, fields: &[&Field], atoms_module_name: &Ident) -> TokenStream { +fn gen_decoder(ctx: &Context, fields: &[StructField], atoms_module_name: &Ident) -> TokenStream { let struct_name = ctx.ident; let idents: Vec<_> = fields .iter() - .map(|field| field.ident.as_ref().unwrap()) + .map(|field| field.field.ident.as_ref().unwrap()) .collect(); let (assignments, field_defs): (Vec, Vec) = fields @@ -63,10 +65,10 @@ fn gen_decoder(ctx: &Context, fields: &[&Field], atoms_module_name: &Ident) -> T .zip(idents.iter()) .enumerate() .map(|(index, (field, ident))| { - let atom_fun = Context::field_to_atom_fun(field); + let atom_fun = Context::field_to_atom_fun(field.field); let variable = Context::escape_ident_with_index(&ident.to_string(), index, "map"); - let assignment = quote_spanned! { field.span() => + let assignment = quote_spanned! { field.field.span() => let #variable = try_decode_field(term, #atom_fun())?; }; @@ -106,12 +108,12 @@ fn gen_decoder(ctx: &Context, fields: &[&Field], atoms_module_name: &Ident) -> T ) } -fn gen_encoder(ctx: &Context, fields: &[&Field], atoms_module_name: &Ident) -> TokenStream { +fn gen_encoder(ctx: &Context, fields: &[StructField], atoms_module_name: &Ident) -> TokenStream { let (keys, values): (Vec<_>, Vec<_>) = fields .iter() .map(|field| { - let field_ident = field.ident.as_ref().unwrap(); - let atom_fun = Context::field_to_atom_fun(field); + let field_ident = field.field.ident.as_ref().unwrap(); + let atom_fun = Context::field_to_atom_fun(field.field); ( quote! { ::rustler::Encoder::encode(&#atom_fun(), env) }, quote! { ::rustler::Encoder::encode(&self.#field_ident, env) }, diff --git a/rustler_codegen/src/record.rs b/rustler_codegen/src/record.rs index 6d418e87..852fd603 100644 --- a/rustler_codegen/src/record.rs +++ b/rustler_codegen/src/record.rs @@ -1,10 +1,11 @@ use proc_macro2::{Span, TokenStream}; use quote::{quote, quote_spanned}; -use syn::{self, spanned::Spanned, Field, Ident, Index}; +use syn::{self, spanned::Spanned, Ident, Index}; use super::context::Context; -use super::RustlerAttr; +use crate::attrs::RustlerAttr; +use crate::context::StructField; pub fn transcoder_decorator(ast: &syn::DeriveInput) -> TokenStream { let ctx = Context::from_ast(ast); @@ -49,7 +50,7 @@ pub fn transcoder_decorator(ast: &syn::DeriveInput) -> TokenStream { gen } -fn gen_decoder(ctx: &Context, fields: &[&Field], atoms_module_name: &Ident) -> TokenStream { +fn gen_decoder(ctx: &Context, fields: &[StructField], atoms_module_name: &Ident) -> TokenStream { let struct_name = ctx.ident; // Make a decoder for each of the fields in the struct. @@ -57,7 +58,7 @@ fn gen_decoder(ctx: &Context, fields: &[&Field], atoms_module_name: &Ident) -> T .iter() .enumerate() .map(|(index, field)| { - let ident = field.ident.as_ref(); + let ident = field.field.ident.as_ref(); let pos_in_struct = if let Some(ident) = ident { ident.to_string() } else { @@ -67,7 +68,7 @@ fn gen_decoder(ctx: &Context, fields: &[&Field], atoms_module_name: &Ident) -> T let variable = Context::escape_ident(&pos_in_struct, "record"); - let assignment = quote_spanned! { field.span() => + let assignment = quote_spanned! { field.field.span() => let #variable = try_decode_index(&terms, #pos_in_struct, #actual_index)?; }; @@ -135,19 +136,19 @@ fn gen_decoder(ctx: &Context, fields: &[&Field], atoms_module_name: &Ident) -> T ) } -fn gen_encoder(ctx: &Context, fields: &[&Field], atoms_module_name: &Ident) -> TokenStream { +fn gen_encoder(ctx: &Context, fields: &[StructField], atoms_module_name: &Ident) -> TokenStream { // Make a field encoder expression for each of the items in the struct. let field_encoders: Vec = fields .iter() .enumerate() .map(|(index, field)| { let literal_index = Index::from(index); - let field_source = match field.ident.as_ref() { + let field_source = match field.field.ident.as_ref() { None => quote! { self.#literal_index }, Some(ident) => quote! { self.#ident }, }; - quote_spanned! { field.span() => ::rustler::Encoder::encode(&#field_source, env) } + quote_spanned! { field.field.span() => ::rustler::Encoder::encode(&#field_source, env) } }) .collect(); diff --git a/rustler_codegen/src/tuple.rs b/rustler_codegen/src/tuple.rs index 900f5608..c0c50ff8 100644 --- a/rustler_codegen/src/tuple.rs +++ b/rustler_codegen/src/tuple.rs @@ -1,7 +1,9 @@ use proc_macro2::TokenStream; use quote::{quote, quote_spanned}; -use syn::{self, spanned::Spanned, Field, Index}; +use syn::{self, spanned::Spanned, Index}; + +use crate::context::StructField; use super::context::Context; @@ -33,7 +35,7 @@ pub fn transcoder_decorator(ast: &syn::DeriveInput) -> TokenStream { gen } -fn gen_decoder(ctx: &Context, fields: &[&Field]) -> TokenStream { +fn gen_decoder(ctx: &Context, fields: &[StructField]) -> TokenStream { let struct_name = ctx.ident; let struct_name_str = struct_name.to_string(); @@ -42,7 +44,7 @@ fn gen_decoder(ctx: &Context, fields: &[&Field]) -> TokenStream { .iter() .enumerate() .map(|(index, field)| { - let ident = field.ident.as_ref(); + let ident = field.field.ident.as_ref(); let pos_in_struct = if let Some(ident) = ident { ident.to_string() } else { @@ -51,7 +53,7 @@ fn gen_decoder(ctx: &Context, fields: &[&Field]) -> TokenStream { let variable = Context::escape_ident(&pos_in_struct, "struct"); - let assignment = quote_spanned! { field.span() => + let assignment = quote_spanned! { field.field.span() => let #variable = try_decode_index(&terms, #pos_in_struct, #index)?; }; @@ -105,19 +107,19 @@ fn gen_decoder(ctx: &Context, fields: &[&Field]) -> TokenStream { ) } -fn gen_encoder(ctx: &Context, fields: &[&Field]) -> TokenStream { +fn gen_encoder(ctx: &Context, fields: &[StructField]) -> TokenStream { // Make a field encoder expression for each of the items in the struct. let field_encoders: Vec = fields .iter() .enumerate() .map(|(index, field)| { let literal_index = Index::from(index); - let field_source = match field.ident.as_ref() { + let field_source = match field.field.ident.as_ref() { None => quote! { self.#literal_index }, Some(ident) => quote! { self.#ident }, }; - quote_spanned! { field.span() => ::rustler::Encoder::encode(&#field_source, env) } + quote_spanned! { field.field.span() => ::rustler::Encoder::encode(&#field_source, env) } }) .collect(); From b9653000e10a9bd529726148c73549aa79700413 Mon Sep 17 00:00:00 2001 From: DS Date: Fri, 12 Jun 2026 07:28:17 -0700 Subject: [PATCH 3/3] Implement optional_decode for NifMap This allows Elixir maps with missing keys to decode as None variants for fields that are Options, as long as the attribute is applied to them. --- rustler_codegen/src/lib.rs | 70 +++++++++++++++++++ rustler_codegen/src/map.rs | 68 +++++++++++------- rustler_tests/lib/rustler_test.ex | 2 + .../native/rustler_compile_tests/src/lib.rs | 15 ++++ .../native/rustler_test/src/test_codegen.rs | 25 +++++++ rustler_tests/test/codegen_test.exs | 31 ++++++++ 6 files changed, 187 insertions(+), 24 deletions(-) diff --git a/rustler_codegen/src/lib.rs b/rustler_codegen/src/lib.rs index 58d7b389..c04f8a1a 100644 --- a/rustler_codegen/src/lib.rs +++ b/rustler_codegen/src/lib.rs @@ -185,6 +185,76 @@ pub fn nif_exception(input: TokenStream) -> TokenStream { /// ``` /// /// And vice versa, decoding this map would result in `value`. +/// +/// By default, any fields with type `Option` need to be present in the map, otherwise decoding will fail: +/// +/// ```ignore +/// #[derive(NifMap)] +/// struct MapWithOption { +/// x: i32, +/// y: Option, +/// z: Option, +/// } +/// ``` +/// +/// ```elixir +/// # Both of these will successfully decode into MapWithOption: +/// %{x: 1, y: 2, z: 3} +/// %{x: 1, y: nil, z: 3} +/// # These will not: +/// %{x: 1} +/// %{x: 1, y: 2} +/// %{x: 1, y: nil} +/// %{x: 1, z: 3} +/// %{x: 1, z: nil} +/// ``` +/// +/// If you wish to treat missing keys in an Elixir map as `None` in the struct, use the `#[rustler(optional_decode)]` attribute: +/// +/// ```ignore +/// #[derive(NifMap)] +/// #[rustler(optional_decode)] // Allows any fields with type Option to be missing in the map. +/// struct MapWithOption { +/// x: i32, +/// y: Option, +/// z: Option, +/// } +/// ``` +/// +/// ```elixir +/// # All of these will successfully decode into MapWithOption: +/// %{x: 1, y: 2, z: 3} +/// %{x: 1, y: nil, z: 3} +/// %{x: 1} +/// %{x: 1, y: 2} +/// %{x: 1, y: nil} +/// %{x: 1, z: 3} +/// %{x: 1, z: nil} +/// ``` +/// +/// `#[rustler(optional_decode)]` can also be applied only to specific fields: +/// +/// ```ignore +/// #[derive(NifMap)] +/// struct MapWithOption { +/// x: i32, +/// #[rustler(optional_decode)] +/// y: Option, +/// z: Option, +/// } +/// ``` +/// +/// ```elixir +/// # These will successfully decode into MapWithOption: +/// %{x: 1, y: 2, z: 3} +/// %{x: 1, y: nil, z: 3} +/// %{x: 1, z: 3} +/// %{x: 1, z: nil} +/// # These will not: +/// %{x: 1} +/// %{x: 1, y: 2} +/// %{x: 1, y: nil} +/// ``` #[proc_macro_derive(NifMap, attributes(rustler, optional_decode))] pub fn nif_map(input: TokenStream) -> TokenStream { let ast = syn::parse(input).unwrap(); diff --git a/rustler_codegen/src/map.rs b/rustler_codegen/src/map.rs index 159363b0..73bf4308 100644 --- a/rustler_codegen/src/map.rs +++ b/rustler_codegen/src/map.rs @@ -55,21 +55,25 @@ pub fn transcoder_decorator(ast: &syn::DeriveInput) -> TokenStream { fn gen_decoder(ctx: &Context, fields: &[StructField], atoms_module_name: &Ident) -> TokenStream { let struct_name = ctx.ident; - let idents: Vec<_> = fields - .iter() - .map(|field| field.field.ident.as_ref().unwrap()) - .collect(); - + let has_global_optional = ctx.optional_decode(); let (assignments, field_defs): (Vec, Vec) = fields .iter() - .zip(idents.iter()) .enumerate() - .map(|(index, (field, ident))| { + .map(|(index, field)| { + let ident = field.field.ident.as_ref().unwrap(); + let is_optional = has_global_optional || field.optional_decode(); + let atom_fun = Context::field_to_atom_fun(field.field); let variable = Context::escape_ident_with_index(&ident.to_string(), index, "map"); - let assignment = quote_spanned! { field.field.span() => - let #variable = try_decode_field(term, #atom_fun())?; + let assignment = if field.is_option_type { + quote_spanned! { field.field.span() => + let #variable = try_decode_field_optional(term, #atom_fun(), #is_optional)?; + } + } else { + quote_spanned! { field.field.span() => + let #variable = try_decode_field(term, #atom_fun())?; + } }; let field_def = quote! { @@ -84,22 +88,38 @@ fn gen_decoder(ctx: &Context, fields: &[StructField], atoms_module_name: &Ident) quote! { use #atoms_module_name::*; + fn try_decode_field_optional<'a, T>( + term: ::rustler::Term<'a>, + field: ::rustler::Atom, + is_optional: bool, + ) -> ::rustler::NifResult> + where + T: ::rustler::Decoder<'a>, + { + let field_is_missing = term.map_get(&field).is_err(); + if is_optional && field_is_missing { + Ok(None) + } else { + try_decode_field(term, field) + } + } + fn try_decode_field<'a, T>( - term: rustler::Term<'a>, - field: rustler::Atom, - ) -> ::rustler::NifResult - where - T: rustler::Decoder<'a>, - { - use rustler::Encoder; - match ::rustler::Decoder::decode(term.map_get(&field)?) { - Err(_) => Err(::rustler::Error::RaiseTerm(Box::new(format!( - "Could not decode field :{:?} on %{{}}", - field - )))), - Ok(value) => Ok(value), - } - }; + term: ::rustler::Term<'a>, + field: ::rustler::Atom, + ) -> ::rustler::NifResult + where + T: ::rustler::Decoder<'a>, + { + use rustler::Encoder; + match ::rustler::Decoder::decode(term.map_get(&field)?) { + Err(_) => Err(::rustler::Error::RaiseTerm(Box::new(format!( + "Could not decode field :{:?} on %{{}}", + field + )))), + Ok(value) => Ok(value), + } + } #(#assignments);* diff --git a/rustler_tests/lib/rustler_test.ex b/rustler_tests/lib/rustler_test.ex index 82192335..8b4c67da 100644 --- a/rustler_tests/lib/rustler_test.ex +++ b/rustler_tests/lib/rustler_test.ex @@ -115,6 +115,8 @@ defmodule RustlerTest do def tuple_echo(_), do: err() def record_echo(_), do: err() def map_echo(_), do: err() + def map_with_optional_echo(_), do: err() + def map_with_optional_field_echo(_), do: err() def exception_echo(_), do: err() def struct_echo(_), do: err() def unit_enum_echo(_), do: err() diff --git a/rustler_tests/native/rustler_compile_tests/src/lib.rs b/rustler_tests/native/rustler_compile_tests/src/lib.rs index 1ced5650..e8861da9 100644 --- a/rustler_tests/native/rustler_compile_tests/src/lib.rs +++ b/rustler_tests/native/rustler_compile_tests/src/lib.rs @@ -38,4 +38,19 @@ pub mod lifetimes { pub i: Binary<'a>, pub j: Binary<'b>, } + + #[derive(NifMap)] + #[rustler(optional_decode)] + pub struct GenericMapWithOptional<'a, 'b> { + pub i: Binary<'a>, + pub j: Binary<'b>, + } + + #[derive(NifMap)] + pub struct GenericMapWithOptionalField<'a, 'b> { + #[rustler(optional_decode)] + pub i: Option>, + pub j: Binary<'b>, + pub k: Option>, + } } diff --git a/rustler_tests/native/rustler_test/src/test_codegen.rs b/rustler_tests/native/rustler_test/src/test_codegen.rs index aa72678e..c2d91625 100644 --- a/rustler_tests/native/rustler_test/src/test_codegen.rs +++ b/rustler_tests/native/rustler_test/src/test_codegen.rs @@ -54,6 +54,31 @@ pub fn map_echo(map: AddMap) -> AddMap { map } +#[derive(NifMap)] +#[rustler(optional_decode)] +pub struct MapWithOptional { + x: i32, + y: Option, +} + +#[rustler::nif] +pub fn map_with_optional_echo(map: MapWithOptional) -> MapWithOptional { + map +} + +#[derive(NifMap)] +pub struct MapWithOptionalField { + x: i32, + #[rustler(optional_decode)] + y: Option, + z: Option, +} + +#[rustler::nif] +pub fn map_with_optional_field_echo(map: MapWithOptionalField) -> MapWithOptionalField { + map +} + #[derive(Debug, NifStruct)] #[must_use] // Added to test Issue #152 #[module = "AddStruct"] diff --git a/rustler_tests/test/codegen_test.exs b/rustler_tests/test/codegen_test.exs index 069f8471..71feb4cb 100644 --- a/rustler_tests/test/codegen_test.exs +++ b/rustler_tests/test/codegen_test.exs @@ -45,6 +45,29 @@ defmodule RustlerTest.CodegenTest do assert value == RustlerTest.map_echo(value) end + test "transcoder 2" do + value = %{x: 1, y: 2} + assert value == RustlerTest.map_with_optional_echo(value) + + value = %{x: 1} + expected_value = %{x: 1, y: nil} + assert expected_value == RustlerTest.map_with_optional_echo(value) + end + + test "transcoder 3" do + value = %{x: 1, y: 2, z: nil} + assert value == RustlerTest.map_with_optional_field_echo(value) + value = %{x: 1, y: 2, z: 3} + assert value == RustlerTest.map_with_optional_field_echo(value) + + value = %{x: 1, z: nil} + expected_value = %{x: 1, y: nil, z: nil} + assert expected_value == RustlerTest.map_with_optional_field_echo(value) + value = %{x: 1, z: 3} + expected_value = %{x: 1, y: nil, z: 3} + assert expected_value == RustlerTest.map_with_optional_field_echo(value) + end + test "with invalid map" do value = %{lhs: "invalid", rhs: 2, loc: {57, 15}} @@ -52,6 +75,14 @@ defmodule RustlerTest.CodegenTest do assert value == RustlerTest.map_echo(value) end end + + test "with invalid map with optional field" do + value = %{x: 1} + + assert_raise ArgumentError, fn -> + assert value == RustlerTest.map_with_optional_field_echo(value) + end + end end describe "struct" do