diff --git a/rustler_codegen/src/attrs.rs b/rustler_codegen/src/attrs.rs new file mode 100644 index 000000000..504051e66 --- /dev/null +++ b/rustler_codegen/src/attrs.rs @@ -0,0 +1,120 @@ +use std::fmt::Display; + +use syn::{Ident, Lit, Meta}; + +pub(crate) trait TryFromRustlerNestedAttr: Sized { + fn collect_attrs_for_ident(ident: &Ident, meta: &Meta) -> Option>; + fn try_from_rustler_nested_attr(ident: &Ident) -> Option; + fn parse_failure_message() -> impl Display; + + fn parse_rustler(meta: &Meta) -> Vec { + if let Meta::List(ref list) = meta { + let mut attrs: Vec = vec![]; + let _ = list.parse_nested_meta(|nested_meta| { + let parsed_attr = nested_meta + .path + .get_ident() + .and_then(T::try_from_rustler_nested_attr); + + match parsed_attr { + None => Err(nested_meta.error(T::parse_failure_message())), + Some(attr) => { + attrs.push(attr); + Ok(()) + } + } + }); + + return attrs; + } + + panic!("Expected nested attributes inside the rustler attribute"); + } +} + +#[derive(Debug)] +pub(crate) enum RustlerAttr { + Encode, + Decode, + OptionalDecode, + Module(String), + Tag(String), +} + +impl RustlerAttr { + fn try_parse_tag(meta: &Meta) -> Option> { + if let Meta::NameValue(ref name_value) = meta { + let expr = &name_value.value; + + if let syn::Expr::Lit(lit_expr) = expr { + if let Lit::Str(ref tag) = lit_expr.lit { + return Some(vec![RustlerAttr::Tag(tag.value())]); + } + } + } + panic!("Cannot parse tag") + } + + fn try_parse_module(meta: &Meta) -> Option> { + if let Meta::NameValue(name_value) = meta { + let expr = &name_value.value; + + if let syn::Expr::Lit(lit_expr) = expr { + if let Lit::Str(ref module) = lit_expr.lit { + let ident = format!("Elixir.{}", module.value()); + return Some(vec![RustlerAttr::Module(ident)]); + } + } + } + panic!("Cannot parse module") + } +} + +impl TryFromRustlerNestedAttr for RustlerAttr { + fn parse_failure_message() -> impl Display { + "Expected encode, decode and/or optional_decode in rustler attribute" + } + + fn collect_attrs_for_ident(ident: &Ident, meta: &Meta) -> Option> { + match ident.to_string().as_ref() { + "rustler" => Some(Self::parse_rustler(meta)), + "tag" => Self::try_parse_tag(meta), + "module" => Self::try_parse_module(meta), + _ => None, + } + } + + fn try_from_rustler_nested_attr(ident: &Ident) -> Option { + match ident.to_string().as_ref() { + "encode" => Some(Self::Encode), + "decode" => Some(Self::Decode), + "optional_decode" => Some(Self::OptionalDecode), + _ => None, + } + } +} + +#[derive(Debug)] +pub(crate) enum RustlerFieldAttr { + OptionalDecode, +} + +impl TryFromRustlerNestedAttr for RustlerFieldAttr { + fn parse_failure_message() -> impl Display { + "Expected optional_decode in rustler field attribute" + } + + fn collect_attrs_for_ident(ident: &Ident, meta: &Meta) -> Option> { + match ident.to_string().as_ref() { + "rustler" => Some(Self::parse_rustler(meta)), + _ => None, + } + } + + fn try_from_rustler_nested_attr(ident: &Ident) -> Option { + match ident.to_string().as_ref() { + "optional_decode" => Some(Self::OptionalDecode), + _ => None, + } + } +} diff --git a/rustler_codegen/src/context.rs b/rustler_codegen/src/context.rs index 843e11ab1..f0b6e54c5 100644 --- a/rustler_codegen/src/context.rs +++ b/rustler_codegen/src/context.rs @@ -1,9 +1,30 @@ use heck::ToSnakeCase; use proc_macro2::{Span, TokenStream}; use quote::quote; -use syn::{Data, Field, Fields, Ident, Lifetime, Lit, Meta, TypeParam, Variant}; +use syn::{Data, Field, Fields, Ident, Lifetime, Type, TypeParam, Variant}; -use super::RustlerAttr; +use crate::attrs::{RustlerAttr, RustlerFieldAttr, TryFromRustlerNestedAttr}; + +/// +/// A helper struct to make it easier to access field attributes. +/// +/// `StructField` holds a reference to the field itself as well as any parsed attributes declared on that field. +/// +pub(crate) struct StructField<'a> { + pub field: &'a Field, + pub attrs: Vec, + + /// `true` if we identified that this field is of type `Option`. + pub is_option_type: bool, +} + +impl<'a> StructField<'a> { + pub fn optional_decode(&self) -> bool { + self.attrs + .iter() + .any(|attr| matches!(attr, RustlerFieldAttr::OptionalDecode)) + } +} /// /// A parsing context struct. @@ -17,7 +38,7 @@ pub(crate) struct Context<'a> { pub lifetimes: Vec, pub type_parameters: Vec, pub variants: Option>, - pub struct_fields: Option>, + pub struct_fields: Option>>, pub is_tuple_struct: bool, } @@ -43,7 +64,13 @@ impl<'a> Context<'a> { }; let struct_fields = match ast.data { - Data::Struct(ref data_struct) => Some(data_struct.fields.iter().collect()), + Data::Struct(ref data_struct) => Some( + data_struct + .fields + .iter() + .map(Self::struct_field_with_parsed_attrs) + .collect(), + ), _ => None, }; @@ -104,14 +131,20 @@ impl<'a> Context<'a> { .any(|attr| matches!(attr, RustlerAttr::Decode)) } + pub fn optional_decode(&self) -> bool { + self.attrs + .iter() + .any(|attr| matches!(attr, RustlerAttr::OptionalDecode)) + } + pub fn field_atoms(&self) -> Option> { self.struct_fields.as_ref().map(|struct_fields| { struct_fields .iter() .map(|field| { - let atom_fun = Self::field_to_atom_fun(field); + let atom_fun = Self::field_to_atom_fun(field.field); - let ident = field.ident.as_ref().unwrap(); + let ident = field.field.ident.as_ref().unwrap(); let ident_str = ident.to_string(); let ident_str = Self::remove_raw(&ident_str); @@ -135,6 +168,32 @@ impl<'a> Context<'a> { Ident::new(&format!("atom_{ident_str}"), Span::call_site()) } + fn struct_field_with_parsed_attrs(field: &'a Field) -> StructField<'a> { + let attrs: Vec<_> = field + .attrs + .iter() + .flat_map(Self::get_rustler_field_attrs) + .collect(); + + StructField { + field, + attrs, + is_option_type: Self::is_option_type(&field.ty), + } + } + + fn is_option_type(t: &Type) -> bool { + match t { + Type::Path(type_path) => { + // Has a chance of returning false negatives (in case Option was aliased to another name and the field uses the other name as the type) and false positives (if some other module has an `Option` type - we only check that the name is `Option`). + let type_name = type_path.path.segments.last().unwrap(); + type_name.ident == "Option" + } + Type::Paren(p) => Self::is_option_type(&p.elem), + _ => false, + } + } + pub fn escape_ident_with_index(ident_str: &str, index: usize, infix: &str) -> Ident { Ident::new( &format!( @@ -168,67 +227,22 @@ impl<'a> Context<'a> { } fn get_rustler_attrs(attr: &syn::Attribute) -> Vec { + Self::parse_attr::(attr) + } + + fn get_rustler_field_attrs(attr: &syn::Attribute) -> Vec { + Self::parse_attr::(attr) + } + + fn parse_attr(attr: &syn::Attribute) -> Vec { attr.path() .segments .iter() .filter_map(|segment| { let meta = &attr.meta; - match segment.ident.to_string().as_ref() { - "rustler" => Some(Context::parse_rustler(meta)), - "tag" => Context::try_parse_tag(meta), - "module" => Context::try_parse_module(meta), - _ => None, - } + T::collect_attrs_for_ident(&segment.ident, meta) }) .flatten() .collect() } - - fn parse_rustler(meta: &Meta) -> Vec { - if let Meta::List(ref list) = meta { - let mut attrs: Vec = vec![]; - let _ = list.parse_nested_meta(|nested_meta| { - if nested_meta.path.is_ident("encode") { - attrs.push(RustlerAttr::Encode); - Ok(()) - } else if nested_meta.path.is_ident("decode") { - attrs.push(RustlerAttr::Decode); - Ok(()) - } else { - Err(nested_meta.error("Expected encode and/or decode in rustler attribute")) - } - }); - - return attrs; - } - - panic!("Expected encode and/or decode in rustler attribute"); - } - - fn try_parse_tag(meta: &Meta) -> Option> { - if let Meta::NameValue(ref name_value) = meta { - let expr = &name_value.value; - - if let syn::Expr::Lit(lit_expr) = expr { - if let Lit::Str(ref tag) = lit_expr.lit { - return Some(vec![RustlerAttr::Tag(tag.value())]); - } - } - } - panic!("Cannot parse tag") - } - - fn try_parse_module(meta: &Meta) -> Option> { - if let Meta::NameValue(name_value) = meta { - let expr = &name_value.value; - - if let syn::Expr::Lit(lit_expr) = expr { - if let Lit::Str(ref module) = lit_expr.lit { - let ident = format!("Elixir.{}", module.value()); - return Some(vec![RustlerAttr::Module(ident)]); - } - } - } - panic!("Cannot parse module") - } } diff --git a/rustler_codegen/src/ex_struct.rs b/rustler_codegen/src/ex_struct.rs index 71cff834b..2a8e36de7 100644 --- a/rustler_codegen/src/ex_struct.rs +++ b/rustler_codegen/src/ex_struct.rs @@ -1,10 +1,11 @@ use proc_macro2::{Span, TokenStream}; use quote::{quote, quote_spanned}; -use syn::{self, spanned::Spanned, Field, Ident}; +use syn::{self, spanned::Spanned, Ident}; use super::context::Context; -use super::RustlerAttr; +use crate::attrs::RustlerAttr; +use crate::context::StructField; pub fn transcoder_decorator(ast: &syn::DeriveInput, add_exception: bool) -> TokenStream { let ctx = Context::from_ast(ast); @@ -65,13 +66,13 @@ pub fn transcoder_decorator(ast: &syn::DeriveInput, add_exception: bool) -> Toke gen } -fn gen_decoder(ctx: &Context, fields: &[&Field], atoms_module_name: &Ident) -> TokenStream { +fn gen_decoder(ctx: &Context, fields: &[StructField], atoms_module_name: &Ident) -> TokenStream { let struct_name = ctx.ident; let struct_name_str = struct_name.to_string(); let idents: Vec<_> = fields .iter() - .map(|field| field.ident.as_ref().unwrap()) + .map(|field| field.field.ident.as_ref().unwrap()) .collect(); let (assignments, field_defs): (Vec, Vec) = fields @@ -79,10 +80,10 @@ fn gen_decoder(ctx: &Context, fields: &[&Field], atoms_module_name: &Ident) -> T .zip(idents.iter()) .enumerate() .map(|(index, (field, ident))| { - let atom_fun = Context::field_to_atom_fun(field); + let atom_fun = Context::field_to_atom_fun(field.field); let variable = Context::escape_ident_with_index(&ident.to_string(), index, "struct"); - let assignment = quote_spanned! { field.span() => + let assignment = quote_spanned! { field.field.span() => let #variable = try_decode_field(term, #atom_fun())?; }; @@ -131,7 +132,7 @@ fn gen_decoder(ctx: &Context, fields: &[&Field], atoms_module_name: &Ident) -> T fn gen_encoder( ctx: &Context, - fields: &[&Field], + fields: &[StructField], atoms_module_name: &Ident, add_exception: bool, ) -> TokenStream { @@ -144,8 +145,8 @@ fn gen_encoder( let (mut data_keys, mut data_values): (Vec<_>, Vec<_>) = fields .iter() .map(|field| { - let field_ident = field.ident.as_ref().unwrap(); - let atom_fun = Context::field_to_atom_fun(field); + let field_ident = field.field.ident.as_ref().unwrap(); + let atom_fun = Context::field_to_atom_fun(field.field); ( quote! { ::rustler::Encoder::encode(&#atom_fun(), env) }, quote! { ::rustler::Encoder::encode(&self.#field_ident, env) }, diff --git a/rustler_codegen/src/lib.rs b/rustler_codegen/src/lib.rs index 9792a3f68..c04f8a1a6 100644 --- a/rustler_codegen/src/lib.rs +++ b/rustler_codegen/src/lib.rs @@ -2,6 +2,7 @@ use proc_macro::TokenStream; +pub(crate) mod attrs; mod context; mod encode_decode_templates; mod ex_struct; @@ -15,14 +16,6 @@ mod tuple; mod unit_enum; mod untagged_enum; -#[derive(Debug)] -enum RustlerAttr { - Encode, - Decode, - Module(String), - Tag(String), -} - /// Initialise the Native Implemented Function (NIF) environment /// and register NIF functions in an Elixir module. /// @@ -192,7 +185,77 @@ pub fn nif_exception(input: TokenStream) -> TokenStream { /// ``` /// /// And vice versa, decoding this map would result in `value`. -#[proc_macro_derive(NifMap, attributes(rustler))] +/// +/// By default, any fields with type `Option` need to be present in the map, otherwise decoding will fail: +/// +/// ```ignore +/// #[derive(NifMap)] +/// struct MapWithOption { +/// x: i32, +/// y: Option, +/// z: Option, +/// } +/// ``` +/// +/// ```elixir +/// # Both of these will successfully decode into MapWithOption: +/// %{x: 1, y: 2, z: 3} +/// %{x: 1, y: nil, z: 3} +/// # These will not: +/// %{x: 1} +/// %{x: 1, y: 2} +/// %{x: 1, y: nil} +/// %{x: 1, z: 3} +/// %{x: 1, z: nil} +/// ``` +/// +/// If you wish to treat missing keys in an Elixir map as `None` in the struct, use the `#[rustler(optional_decode)]` attribute: +/// +/// ```ignore +/// #[derive(NifMap)] +/// #[rustler(optional_decode)] // Allows any fields with type Option to be missing in the map. +/// struct MapWithOption { +/// x: i32, +/// y: Option, +/// z: Option, +/// } +/// ``` +/// +/// ```elixir +/// # All of these will successfully decode into MapWithOption: +/// %{x: 1, y: 2, z: 3} +/// %{x: 1, y: nil, z: 3} +/// %{x: 1} +/// %{x: 1, y: 2} +/// %{x: 1, y: nil} +/// %{x: 1, z: 3} +/// %{x: 1, z: nil} +/// ``` +/// +/// `#[rustler(optional_decode)]` can also be applied only to specific fields: +/// +/// ```ignore +/// #[derive(NifMap)] +/// struct MapWithOption { +/// x: i32, +/// #[rustler(optional_decode)] +/// y: Option, +/// z: Option, +/// } +/// ``` +/// +/// ```elixir +/// # These will successfully decode into MapWithOption: +/// %{x: 1, y: 2, z: 3} +/// %{x: 1, y: nil, z: 3} +/// %{x: 1, z: 3} +/// %{x: 1, z: nil} +/// # These will not: +/// %{x: 1} +/// %{x: 1, y: 2} +/// %{x: 1, y: nil} +/// ``` +#[proc_macro_derive(NifMap, attributes(rustler, optional_decode))] pub fn nif_map(input: TokenStream) -> TokenStream { let ast = syn::parse(input).unwrap(); map::transcoder_decorator(&ast).into() diff --git a/rustler_codegen/src/map.rs b/rustler_codegen/src/map.rs index 51c4ef13e..73bf43083 100644 --- a/rustler_codegen/src/map.rs +++ b/rustler_codegen/src/map.rs @@ -1,7 +1,9 @@ use proc_macro2::{Span, TokenStream}; use quote::{quote, quote_spanned}; -use syn::{self, spanned::Spanned, Field, Ident}; +use syn::{self, spanned::Spanned, Ident}; + +use crate::context::StructField; use super::context::Context; @@ -50,24 +52,28 @@ pub fn transcoder_decorator(ast: &syn::DeriveInput) -> TokenStream { gen } -fn gen_decoder(ctx: &Context, fields: &[&Field], atoms_module_name: &Ident) -> TokenStream { +fn gen_decoder(ctx: &Context, fields: &[StructField], atoms_module_name: &Ident) -> TokenStream { let struct_name = ctx.ident; - let idents: Vec<_> = fields - .iter() - .map(|field| field.ident.as_ref().unwrap()) - .collect(); - + let has_global_optional = ctx.optional_decode(); let (assignments, field_defs): (Vec, Vec) = fields .iter() - .zip(idents.iter()) .enumerate() - .map(|(index, (field, ident))| { - let atom_fun = Context::field_to_atom_fun(field); + .map(|(index, field)| { + let ident = field.field.ident.as_ref().unwrap(); + let is_optional = has_global_optional || field.optional_decode(); + + let atom_fun = Context::field_to_atom_fun(field.field); let variable = Context::escape_ident_with_index(&ident.to_string(), index, "map"); - let assignment = quote_spanned! { field.span() => - let #variable = try_decode_field(term, #atom_fun())?; + let assignment = if field.is_option_type { + quote_spanned! { field.field.span() => + let #variable = try_decode_field_optional(term, #atom_fun(), #is_optional)?; + } + } else { + quote_spanned! { field.field.span() => + let #variable = try_decode_field(term, #atom_fun())?; + } }; let field_def = quote! { @@ -82,22 +88,38 @@ fn gen_decoder(ctx: &Context, fields: &[&Field], atoms_module_name: &Ident) -> T quote! { use #atoms_module_name::*; + fn try_decode_field_optional<'a, T>( + term: ::rustler::Term<'a>, + field: ::rustler::Atom, + is_optional: bool, + ) -> ::rustler::NifResult> + where + T: ::rustler::Decoder<'a>, + { + let field_is_missing = term.map_get(&field).is_err(); + if is_optional && field_is_missing { + Ok(None) + } else { + try_decode_field(term, field) + } + } + fn try_decode_field<'a, T>( - term: rustler::Term<'a>, - field: rustler::Atom, - ) -> ::rustler::NifResult - where - T: rustler::Decoder<'a>, - { - use rustler::Encoder; - match ::rustler::Decoder::decode(term.map_get(&field)?) { - Err(_) => Err(::rustler::Error::RaiseTerm(Box::new(format!( - "Could not decode field :{:?} on %{{}}", - field - )))), - Ok(value) => Ok(value), - } - }; + term: ::rustler::Term<'a>, + field: ::rustler::Atom, + ) -> ::rustler::NifResult + where + T: ::rustler::Decoder<'a>, + { + use rustler::Encoder; + match ::rustler::Decoder::decode(term.map_get(&field)?) { + Err(_) => Err(::rustler::Error::RaiseTerm(Box::new(format!( + "Could not decode field :{:?} on %{{}}", + field + )))), + Ok(value) => Ok(value), + } + } #(#assignments);* @@ -106,12 +128,12 @@ fn gen_decoder(ctx: &Context, fields: &[&Field], atoms_module_name: &Ident) -> T ) } -fn gen_encoder(ctx: &Context, fields: &[&Field], atoms_module_name: &Ident) -> TokenStream { +fn gen_encoder(ctx: &Context, fields: &[StructField], atoms_module_name: &Ident) -> TokenStream { let (keys, values): (Vec<_>, Vec<_>) = fields .iter() .map(|field| { - let field_ident = field.ident.as_ref().unwrap(); - let atom_fun = Context::field_to_atom_fun(field); + let field_ident = field.field.ident.as_ref().unwrap(); + let atom_fun = Context::field_to_atom_fun(field.field); ( quote! { ::rustler::Encoder::encode(&#atom_fun(), env) }, quote! { ::rustler::Encoder::encode(&self.#field_ident, env) }, diff --git a/rustler_codegen/src/record.rs b/rustler_codegen/src/record.rs index 6d418e879..852fd6034 100644 --- a/rustler_codegen/src/record.rs +++ b/rustler_codegen/src/record.rs @@ -1,10 +1,11 @@ use proc_macro2::{Span, TokenStream}; use quote::{quote, quote_spanned}; -use syn::{self, spanned::Spanned, Field, Ident, Index}; +use syn::{self, spanned::Spanned, Ident, Index}; use super::context::Context; -use super::RustlerAttr; +use crate::attrs::RustlerAttr; +use crate::context::StructField; pub fn transcoder_decorator(ast: &syn::DeriveInput) -> TokenStream { let ctx = Context::from_ast(ast); @@ -49,7 +50,7 @@ pub fn transcoder_decorator(ast: &syn::DeriveInput) -> TokenStream { gen } -fn gen_decoder(ctx: &Context, fields: &[&Field], atoms_module_name: &Ident) -> TokenStream { +fn gen_decoder(ctx: &Context, fields: &[StructField], atoms_module_name: &Ident) -> TokenStream { let struct_name = ctx.ident; // Make a decoder for each of the fields in the struct. @@ -57,7 +58,7 @@ fn gen_decoder(ctx: &Context, fields: &[&Field], atoms_module_name: &Ident) -> T .iter() .enumerate() .map(|(index, field)| { - let ident = field.ident.as_ref(); + let ident = field.field.ident.as_ref(); let pos_in_struct = if let Some(ident) = ident { ident.to_string() } else { @@ -67,7 +68,7 @@ fn gen_decoder(ctx: &Context, fields: &[&Field], atoms_module_name: &Ident) -> T let variable = Context::escape_ident(&pos_in_struct, "record"); - let assignment = quote_spanned! { field.span() => + let assignment = quote_spanned! { field.field.span() => let #variable = try_decode_index(&terms, #pos_in_struct, #actual_index)?; }; @@ -135,19 +136,19 @@ fn gen_decoder(ctx: &Context, fields: &[&Field], atoms_module_name: &Ident) -> T ) } -fn gen_encoder(ctx: &Context, fields: &[&Field], atoms_module_name: &Ident) -> TokenStream { +fn gen_encoder(ctx: &Context, fields: &[StructField], atoms_module_name: &Ident) -> TokenStream { // Make a field encoder expression for each of the items in the struct. let field_encoders: Vec = fields .iter() .enumerate() .map(|(index, field)| { let literal_index = Index::from(index); - let field_source = match field.ident.as_ref() { + let field_source = match field.field.ident.as_ref() { None => quote! { self.#literal_index }, Some(ident) => quote! { self.#ident }, }; - quote_spanned! { field.span() => ::rustler::Encoder::encode(&#field_source, env) } + quote_spanned! { field.field.span() => ::rustler::Encoder::encode(&#field_source, env) } }) .collect(); diff --git a/rustler_codegen/src/tuple.rs b/rustler_codegen/src/tuple.rs index 900f5608b..c0c50ff8c 100644 --- a/rustler_codegen/src/tuple.rs +++ b/rustler_codegen/src/tuple.rs @@ -1,7 +1,9 @@ use proc_macro2::TokenStream; use quote::{quote, quote_spanned}; -use syn::{self, spanned::Spanned, Field, Index}; +use syn::{self, spanned::Spanned, Index}; + +use crate::context::StructField; use super::context::Context; @@ -33,7 +35,7 @@ pub fn transcoder_decorator(ast: &syn::DeriveInput) -> TokenStream { gen } -fn gen_decoder(ctx: &Context, fields: &[&Field]) -> TokenStream { +fn gen_decoder(ctx: &Context, fields: &[StructField]) -> TokenStream { let struct_name = ctx.ident; let struct_name_str = struct_name.to_string(); @@ -42,7 +44,7 @@ fn gen_decoder(ctx: &Context, fields: &[&Field]) -> TokenStream { .iter() .enumerate() .map(|(index, field)| { - let ident = field.ident.as_ref(); + let ident = field.field.ident.as_ref(); let pos_in_struct = if let Some(ident) = ident { ident.to_string() } else { @@ -51,7 +53,7 @@ fn gen_decoder(ctx: &Context, fields: &[&Field]) -> TokenStream { let variable = Context::escape_ident(&pos_in_struct, "struct"); - let assignment = quote_spanned! { field.span() => + let assignment = quote_spanned! { field.field.span() => let #variable = try_decode_index(&terms, #pos_in_struct, #index)?; }; @@ -105,19 +107,19 @@ fn gen_decoder(ctx: &Context, fields: &[&Field]) -> TokenStream { ) } -fn gen_encoder(ctx: &Context, fields: &[&Field]) -> TokenStream { +fn gen_encoder(ctx: &Context, fields: &[StructField]) -> TokenStream { // Make a field encoder expression for each of the items in the struct. let field_encoders: Vec = fields .iter() .enumerate() .map(|(index, field)| { let literal_index = Index::from(index); - let field_source = match field.ident.as_ref() { + let field_source = match field.field.ident.as_ref() { None => quote! { self.#literal_index }, Some(ident) => quote! { self.#ident }, }; - quote_spanned! { field.span() => ::rustler::Encoder::encode(&#field_source, env) } + quote_spanned! { field.field.span() => ::rustler::Encoder::encode(&#field_source, env) } }) .collect(); diff --git a/rustler_tests/lib/rustler_test.ex b/rustler_tests/lib/rustler_test.ex index 82192335b..8b4c67dab 100644 --- a/rustler_tests/lib/rustler_test.ex +++ b/rustler_tests/lib/rustler_test.ex @@ -115,6 +115,8 @@ defmodule RustlerTest do def tuple_echo(_), do: err() def record_echo(_), do: err() def map_echo(_), do: err() + def map_with_optional_echo(_), do: err() + def map_with_optional_field_echo(_), do: err() def exception_echo(_), do: err() def struct_echo(_), do: err() def unit_enum_echo(_), do: err() diff --git a/rustler_tests/native/rustler_compile_tests/src/lib.rs b/rustler_tests/native/rustler_compile_tests/src/lib.rs index 1ced56501..e8861da96 100644 --- a/rustler_tests/native/rustler_compile_tests/src/lib.rs +++ b/rustler_tests/native/rustler_compile_tests/src/lib.rs @@ -38,4 +38,19 @@ pub mod lifetimes { pub i: Binary<'a>, pub j: Binary<'b>, } + + #[derive(NifMap)] + #[rustler(optional_decode)] + pub struct GenericMapWithOptional<'a, 'b> { + pub i: Binary<'a>, + pub j: Binary<'b>, + } + + #[derive(NifMap)] + pub struct GenericMapWithOptionalField<'a, 'b> { + #[rustler(optional_decode)] + pub i: Option>, + pub j: Binary<'b>, + pub k: Option>, + } } diff --git a/rustler_tests/native/rustler_test/src/test_codegen.rs b/rustler_tests/native/rustler_test/src/test_codegen.rs index aa72678ec..c2d916258 100644 --- a/rustler_tests/native/rustler_test/src/test_codegen.rs +++ b/rustler_tests/native/rustler_test/src/test_codegen.rs @@ -54,6 +54,31 @@ pub fn map_echo(map: AddMap) -> AddMap { map } +#[derive(NifMap)] +#[rustler(optional_decode)] +pub struct MapWithOptional { + x: i32, + y: Option, +} + +#[rustler::nif] +pub fn map_with_optional_echo(map: MapWithOptional) -> MapWithOptional { + map +} + +#[derive(NifMap)] +pub struct MapWithOptionalField { + x: i32, + #[rustler(optional_decode)] + y: Option, + z: Option, +} + +#[rustler::nif] +pub fn map_with_optional_field_echo(map: MapWithOptionalField) -> MapWithOptionalField { + map +} + #[derive(Debug, NifStruct)] #[must_use] // Added to test Issue #152 #[module = "AddStruct"] diff --git a/rustler_tests/test/codegen_test.exs b/rustler_tests/test/codegen_test.exs index 069f8471e..71feb4cbe 100644 --- a/rustler_tests/test/codegen_test.exs +++ b/rustler_tests/test/codegen_test.exs @@ -45,6 +45,29 @@ defmodule RustlerTest.CodegenTest do assert value == RustlerTest.map_echo(value) end + test "transcoder 2" do + value = %{x: 1, y: 2} + assert value == RustlerTest.map_with_optional_echo(value) + + value = %{x: 1} + expected_value = %{x: 1, y: nil} + assert expected_value == RustlerTest.map_with_optional_echo(value) + end + + test "transcoder 3" do + value = %{x: 1, y: 2, z: nil} + assert value == RustlerTest.map_with_optional_field_echo(value) + value = %{x: 1, y: 2, z: 3} + assert value == RustlerTest.map_with_optional_field_echo(value) + + value = %{x: 1, z: nil} + expected_value = %{x: 1, y: nil, z: nil} + assert expected_value == RustlerTest.map_with_optional_field_echo(value) + value = %{x: 1, z: 3} + expected_value = %{x: 1, y: nil, z: 3} + assert expected_value == RustlerTest.map_with_optional_field_echo(value) + end + test "with invalid map" do value = %{lhs: "invalid", rhs: 2, loc: {57, 15}} @@ -52,6 +75,14 @@ defmodule RustlerTest.CodegenTest do assert value == RustlerTest.map_echo(value) end end + + test "with invalid map with optional field" do + value = %{x: 1} + + assert_raise ArgumentError, fn -> + assert value == RustlerTest.map_with_optional_field_echo(value) + end + end end describe "struct" do