Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 59 additions & 12 deletions rustler_codegen/src/context.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use heck::ToSnakeCase;
use proc_macro2::{Span, TokenStream};
use quote::quote;
use syn::{Data, Field, Fields, Ident, Lifetime, Lit, Meta, TypeParam, Variant};
use syn::{Attribute, Data, Field, Fields, Ident, Lifetime, Lit, Meta, TypeParam, Variant};

use super::RustlerAttr;

Expand Down Expand Up @@ -110,29 +110,49 @@ impl<'a> Context<'a> {
.iter()
.map(|field| {
let atom_fun = Self::field_to_atom_fun(field);

let ident = field.ident.as_ref().unwrap();
let ident_str = ident.to_string();
let ident_str = Self::remove_raw(&ident_str);
let atom_name = Self::field_atom_name(field);

quote! {
#atom_fun = #ident_str,
#atom_fun = #atom_name,
}
})
.collect()
})
}

pub fn field_to_atom_fun(field: &Field) -> Ident {
let ident = field.ident.as_ref().unwrap();
Self::ident_to_atom_fun(ident)
Self::atom_fun(&Self::field_atom_name(field))
}

pub fn field_atom_name(field: &Field) -> String {
Self::rename_attr(&field.attrs).unwrap_or_else(|| {
let ident = field.ident.as_ref().unwrap();
Self::ident_to_atom_name(ident)
})
}

pub fn variant_to_atom_fun(variant: &Variant) -> Ident {
Self::atom_fun(&Self::variant_atom_name(variant))
}

pub fn ident_to_atom_fun(ident: &Ident) -> Ident {
let ident_str = ident.to_string().to_snake_case();
let ident_str = Self::remove_raw(&ident_str);
pub fn variant_atom_name(variant: &Variant) -> String {
Self::rename_attr(&variant.attrs)
.unwrap_or_else(|| Self::ident_to_atom_name(&variant.ident))
}

pub fn ident_to_atom_name(ident: &Ident) -> String {
let ident_str = ident.to_string();
Self::remove_raw(&ident_str).to_snake_case()
}

fn atom_fun(atom_name: &str) -> Ident {
let suffix = atom_name
.as_bytes()
.iter()
.map(|byte| format!("{byte:02x}"))
.collect::<String>();

Ident::new(&format!("atom_{ident_str}"), Span::call_site())
Ident::new(&format!("atom_{}", suffix), Span::call_site())
}

pub fn escape_ident_with_index(ident_str: &str, index: usize, infix: &str) -> Ident {
Expand Down Expand Up @@ -161,6 +181,33 @@ impl<'a> Context<'a> {
.expect("split has always at least one element")
}

fn rename_attr(attrs: &[Attribute]) -> Option<String> {
attrs.iter().find_map(|attr| {
if !attr.path().is_ident("rustler") {
return None;
}

let Meta::List(list) = &attr.meta else {
return None;
};

let mut rename = None;
list.parse_nested_meta(|nested_meta| {
if nested_meta.path.is_ident("rename") {
let value = nested_meta.value()?;
let lit: syn::LitStr = value.parse()?;
rename = Some(lit.value());
Ok(())
} else {
Err(nested_meta.error("Expected rename in rustler attribute"))
}
})
.unwrap_or_else(|err| panic!("{}", err));

rename
})
}

fn encode_decode_attr_set(attrs: &[RustlerAttr]) -> bool {
attrs
.iter()
Expand Down
28 changes: 12 additions & 16 deletions rustler_codegen/src/tagged_enum.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use proc_macro2::{Span, TokenStream};
use quote::{quote, quote_spanned};

use heck::ToSnakeCase;
use std::collections::HashMap;
use syn::{self, spanned::Spanned, Field, Fields, FieldsNamed, FieldsUnnamed, Ident, Variant};

Expand All @@ -18,28 +17,25 @@ pub fn transcoder_decorator(ast: &syn::DeriveInput) -> TokenStream {
let atoms = variants
.iter()
.flat_map(|variant| {
let mut ret = if let Fields::Named(fields) = &variant.fields {
let fields = if let Fields::Named(fields) = &variant.fields {
fields
.named
.iter()
.map(|field| {
field
.ident
.as_ref()
.expect("Named fields must have an ident.")
(
Context::field_atom_name(field),
Context::field_to_atom_fun(field),
)
})
.collect::<Vec<_>>()
} else {
vec![]
};

ret.push(&variant.ident);
ret
})
.map(|atom_ident| {
let atom_str = atom_ident.to_string().to_snake_case();
let atom_fn = Context::ident_to_atom_fun(atom_ident);
(atom_str, atom_fn)
fields.into_iter().chain(std::iter::once((
Context::variant_atom_name(variant),
Context::variant_to_atom_fun(variant),
)))
})
.collect::<HashMap<_, _>>()
.into_iter()
Expand Down Expand Up @@ -89,7 +85,7 @@ fn gen_decoder(ctx: &Context, variants: &[&Variant], atoms_module_name: &Ident)
.iter()
.filter_map(|variant| {
let variant_ident = &variant.ident;
let atom_fn = Context::ident_to_atom_fun(variant_ident);
let atom_fn = Context::variant_to_atom_fun(variant);
match &variant.fields {
Fields::Unit => Some(gen_unit_decoder(enum_name, variant_ident, atom_fn)),
_ => None,
Expand All @@ -100,7 +96,7 @@ fn gen_decoder(ctx: &Context, variants: &[&Variant], atoms_module_name: &Ident)
.iter()
.filter_map(|variant| {
let variant_ident = &variant.ident;
let atom_fn = Context::ident_to_atom_fun(variant_ident);
let atom_fn = Context::variant_to_atom_fun(variant);
match &variant.fields {
Fields::Unnamed(fields) => Some(gen_unnamed_decoder(
enum_name,
Expand Down Expand Up @@ -161,7 +157,7 @@ fn gen_encoder(ctx: &Context, variants: &[&Variant], atoms_module_name: &Ident)
.iter()
.map(|variant| {
let variant_ident = &variant.ident;
let atom_fn = Context::ident_to_atom_fun(variant_ident);
let atom_fn = Context::variant_to_atom_fun(variant);

match &variant.fields {
Fields::Unit => gen_unit_encoder(enum_name, variant_ident, atom_fn),
Expand Down
9 changes: 9 additions & 0 deletions rustler_codegen/tests/ui/derive-rename-invalid.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
use rustler_codegen::NifMap;

#[derive(NifMap)]
struct InvalidRename {
#[rustler(rename)]
value: i32,
}

fn main() {}
7 changes: 7 additions & 0 deletions rustler_codegen/tests/ui/derive-rename-invalid.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
error: proc-macro derive panicked
--> tests/ui/derive-rename-invalid.rs:3:10
|
3 | #[derive(NifMap)]
| ^^^^^^
|
= help: message: expected `=`
3 changes: 3 additions & 0 deletions rustler_tests/lib/rustler_test.ex
Original file line number Diff line number Diff line change
Expand Up @@ -110,13 +110,16 @@ defmodule RustlerTest do
def tuple_echo(_), do: err()
def record_echo(_), do: err()
def map_echo(_), do: err()
def renamed_map_echo(_), do: err()
def unicode_renamed_map_echo(_), do: err()
def exception_echo(_), do: err()
def struct_echo(_), do: err()
def unit_enum_echo(_), do: err()
def tagged_enum_1_echo(_), do: err()
def tagged_enum_2_echo(_), do: err()
def tagged_enum_3_echo(_), do: err()
def tagged_enum_4_echo(_), do: err()
def renamed_tagged_enum_echo(_), do: err()
def untagged_enum_echo(_), do: err()
def untagged_enum_with_truthy(_), do: err()
def untagged_enum_for_issue_370(_), do: err()
Expand Down
44 changes: 44 additions & 0 deletions rustler_tests/native/rustler_test/src/test_codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,35 @@ pub fn map_echo(map: AddMap) -> AddMap {
map
}

#[derive(NifMap)]
pub struct RenamedMap {
#[rustler(rename = "type")]
type_: rustler::Atom,
start: u32,
#[rustler(rename = "end")]
end_: u32,
#[rustler(rename = "async")]
async_: bool,
}

#[rustler::nif]
pub fn renamed_map_echo(map: RenamedMap) -> RenamedMap {
map
}

#[derive(NifMap)]
pub struct UnicodeRenamedMap {
#[rustler(rename = "name_ä")]
name_a_umlaut: u32,
#[rustler(rename = "name_ö")]
name_o_umlaut: u32,
}

#[rustler::nif]
pub fn unicode_renamed_map_echo(map: UnicodeRenamedMap) -> UnicodeRenamedMap {
map
}

#[derive(Debug, NifStruct)]
#[must_use] // Added to test Issue #152
#[module = "AddStruct"]
Expand Down Expand Up @@ -156,6 +185,21 @@ pub fn tagged_enum_4_echo(tagged_enum: TaggedEnum4) -> TaggedEnum4 {
tagged_enum
}

#[derive(NifTaggedEnum)]
pub enum RenamedTaggedEnum {
#[rustler(rename = "renamed")]
Named {
#[rustler(rename = "end")]
end_: i32,
y: i32,
},
}

#[rustler::nif]
pub fn renamed_tagged_enum_echo(tagged_enum: RenamedTaggedEnum) -> RenamedTaggedEnum {
tagged_enum
}

#[derive(NifUntaggedEnum)]
pub enum UntaggedEnum {
Foo(u32),
Expand Down
15 changes: 15 additions & 0 deletions rustler_tests/test/codegen_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,16 @@ defmodule RustlerTest.CodegenTest do
assert value == RustlerTest.map_echo(value)
end

test "renamed fields" do
value = %{type: :import, start: 0, end: 12, async: true}
assert value == RustlerTest.renamed_map_echo(value)
end

test "renamed fields with unicode atoms" do
value = %{"name_ä": 1, "name_ö": 2}
assert value == RustlerTest.unicode_renamed_map_echo(value)
end

test "with invalid map" do
value = %{lhs: "invalid", rhs: 2, loc: {57, 15}}

Expand Down Expand Up @@ -330,6 +340,11 @@ defmodule RustlerTest.CodegenTest do
end)
end

test "renamed tagged enum variants and fields" do
value = {:renamed, %{end: 1, y: 2}}
assert value == RustlerTest.renamed_tagged_enum_echo(value)
end

test "untagged enum transcoder" do
assert 123 == RustlerTest.untagged_enum_echo(123)
assert "Hello" == RustlerTest.untagged_enum_echo("Hello")
Expand Down
Loading