From 8afab4f5e2cd64665d0fd1bbc3f07d5aac534917 Mon Sep 17 00:00:00 2001 From: arckoor <33837362+arckoor@users.noreply.github.com> Date: Thu, 28 May 2026 17:52:26 +0200 Subject: [PATCH] Unify encoding and decoding key APIs --- src/crypto/aws_lc/ecdsa.rs | 8 +++++-- src/crypto/aws_lc/eddsa.rs | 6 +++-- src/crypto/aws_lc/hmac.rs | 16 ++++++++++---- src/crypto/aws_lc/rsa.rs | 2 +- src/crypto/rust_crypto/ecdsa.rs | 4 ++-- src/crypto/rust_crypto/eddsa.rs | 4 ++-- src/crypto/rust_crypto/hmac.rs | 16 ++++++++++---- src/crypto/rust_crypto/rsa.rs | 6 ++--- src/decoding.rs | 19 +++++----------- src/encoding.rs | 13 +++-------- src/jwk.rs | 39 ++++++++++----------------------- tests/hmac.rs | 12 ++++++++++ 12 files changed, 74 insertions(+), 71 deletions(-) diff --git a/src/crypto/aws_lc/ecdsa.rs b/src/crypto/aws_lc/ecdsa.rs index 5a5b364e..daf967c9 100644 --- a/src/crypto/aws_lc/ecdsa.rs +++ b/src/crypto/aws_lc/ecdsa.rs @@ -23,7 +23,7 @@ macro_rules! define_ecdsa_signer { } Ok(Self( - EcdsaKeyPair::from_pkcs8($signing_alg, encoding_key.inner()) + EcdsaKeyPair::from_pkcs8($signing_alg, encoding_key.as_bytes()) .map_err(|_| ErrorKind::InvalidEcdsaKey)?, )) } @@ -62,7 +62,11 @@ macro_rules! define_ecdsa_verifier { impl Verifier> for $name { fn verify(&self, msg: &[u8], signature: &Vec) -> std::result::Result<(), Error> { $verification_alg - .verify_sig(self.0.as_bytes(), msg, signature) + .verify_sig( + self.0.try_get_as_bytes().map_err(Error::from_source)?, + msg, + signature, + ) .map_err(Error::from_source)?; Ok(()) } diff --git a/src/crypto/aws_lc/eddsa.rs b/src/crypto/aws_lc/eddsa.rs index 085bf7c3..eaade202 100644 --- a/src/crypto/aws_lc/eddsa.rs +++ b/src/crypto/aws_lc/eddsa.rs @@ -16,7 +16,7 @@ impl EdDSASigner { } Ok(Self( - Ed25519KeyPair::from_pkcs8(encoding_key.inner()) + Ed25519KeyPair::from_pkcs8(encoding_key.as_bytes()) .map_err(|_| ErrorKind::InvalidEddsaKey)?, )) } @@ -48,7 +48,9 @@ impl EdDSAVerifier { impl Verifier> for EdDSAVerifier { fn verify(&self, msg: &[u8], signature: &Vec) -> std::result::Result<(), Error> { - ED25519.verify_sig(self.0.as_bytes(), msg, signature).map_err(Error::from_source)?; + ED25519 + .verify_sig(self.0.try_get_as_bytes().map_err(Error::from_source)?, msg, signature) + .map_err(Error::from_source)?; Ok(()) } } diff --git a/src/crypto/aws_lc/hmac.rs b/src/crypto/aws_lc/hmac.rs index 2997e491..cb968348 100644 --- a/src/crypto/aws_lc/hmac.rs +++ b/src/crypto/aws_lc/hmac.rs @@ -5,8 +5,8 @@ use aws_lc_rs::hmac; use signature::{Signer, Verifier}; use crate::crypto::{JwtSigner, JwtVerifier}; -use crate::errors::Result; -use crate::{Algorithm, DecodingKey, EncodingKey}; +use crate::errors::{ErrorKind, Result, new_error}; +use crate::{Algorithm, AlgorithmFamily, DecodingKey, EncodingKey}; macro_rules! define_hmac_signer { ($name:ident, $alg:expr, $hmac_alg:expr) => { @@ -14,7 +14,11 @@ macro_rules! define_hmac_signer { impl $name { pub(crate) fn new(encoding_key: &EncodingKey) -> Result { - Ok(Self(hmac::Key::new($hmac_alg, encoding_key.try_get_hmac_secret()?))) + if encoding_key.family() != AlgorithmFamily::Hmac { + return Err(new_error(ErrorKind::InvalidKeyFormat)); + } + + Ok(Self(hmac::Key::new($hmac_alg, encoding_key.as_bytes()))) } } @@ -38,7 +42,11 @@ macro_rules! define_hmac_verifier { impl $name { pub(crate) fn new(decoding_key: &DecodingKey) -> Result { - Ok(Self(hmac::Key::new($hmac_alg, decoding_key.try_get_hmac_secret()?))) + if decoding_key.family() != AlgorithmFamily::Hmac { + return Err(new_error(ErrorKind::InvalidKeyFormat)); + } + + Ok(Self(hmac::Key::new($hmac_alg, decoding_key.try_get_as_bytes()?))) } } diff --git a/src/crypto/aws_lc/rsa.rs b/src/crypto/aws_lc/rsa.rs index d72b0680..ce047b42 100644 --- a/src/crypto/aws_lc/rsa.rs +++ b/src/crypto/aws_lc/rsa.rs @@ -16,7 +16,7 @@ fn try_sign_rsa( encoding_key: &EncodingKey, msg: &[u8], ) -> std::result::Result, signature::Error> { - let key_pair = crypto_sig::RsaKeyPair::from_der(encoding_key.inner()) + let key_pair = crypto_sig::RsaKeyPair::from_der(encoding_key.as_bytes()) .map_err(signature::Error::from_source)?; let mut signature = vec![0; key_pair.public_modulus_len()]; diff --git a/src/crypto/rust_crypto/ecdsa.rs b/src/crypto/rust_crypto/ecdsa.rs index 9aad882e..cfcf105f 100644 --- a/src/crypto/rust_crypto/ecdsa.rs +++ b/src/crypto/rust_crypto/ecdsa.rs @@ -25,7 +25,7 @@ macro_rules! define_ecdsa_signer { } Ok(Self( - <$signing_key>::from_pkcs8_der(encoding_key.inner()) + <$signing_key>::from_pkcs8_der(encoding_key.as_bytes()) .map_err(|_| ErrorKind::InvalidEcdsaKey)?, )) } @@ -57,7 +57,7 @@ macro_rules! define_ecdsa_verifier { } Ok(Self( - <$verifying_key>::from_sec1_bytes(decoding_key.as_bytes()) + <$verifying_key>::from_sec1_bytes(decoding_key.try_get_as_bytes()?) .map_err(|_| ErrorKind::InvalidEcdsaKey)?, )) } diff --git a/src/crypto/rust_crypto/eddsa.rs b/src/crypto/rust_crypto/eddsa.rs index 9b77a9f8..5b509e4e 100644 --- a/src/crypto/rust_crypto/eddsa.rs +++ b/src/crypto/rust_crypto/eddsa.rs @@ -17,7 +17,7 @@ impl EdDSASigner { } Ok(Self( - SigningKey::from_pkcs8_der(encoding_key.inner()) + SigningKey::from_pkcs8_der(encoding_key.as_bytes()) .map_err(|_| ErrorKind::InvalidEddsaKey)?, )) } @@ -45,7 +45,7 @@ impl EdDSAVerifier { Ok(Self( VerifyingKey::from_bytes( - <&[u8; 32]>::try_from(&decoding_key.as_bytes()[..32]) + <&[u8; 32]>::try_from(&decoding_key.try_get_as_bytes()?[..32]) .map_err(|_| ErrorKind::InvalidEddsaKey)?, ) .map_err(|_| ErrorKind::InvalidEddsaKey)?, diff --git a/src/crypto/rust_crypto/hmac.rs b/src/crypto/rust_crypto/hmac.rs index fdcd2e12..ead21d97 100644 --- a/src/crypto/rust_crypto/hmac.rs +++ b/src/crypto/rust_crypto/hmac.rs @@ -6,8 +6,8 @@ use sha2::{Sha256, Sha384, Sha512}; use signature::{Signer, Verifier}; use crate::crypto::{JwtSigner, JwtVerifier}; -use crate::errors::{ErrorKind, Result}; -use crate::{Algorithm, DecodingKey, EncodingKey}; +use crate::errors::{ErrorKind, Result, new_error}; +use crate::{Algorithm, AlgorithmFamily, DecodingKey, EncodingKey}; type HmacSha256 = Hmac; type HmacSha384 = Hmac; @@ -20,7 +20,11 @@ macro_rules! define_hmac_signer { impl $name { pub(crate) fn new(encoding_key: &EncodingKey) -> Result { - let inner = <$hmac_type>::new_from_slice(encoding_key.try_get_hmac_secret()?) + if encoding_key.family() != AlgorithmFamily::Hmac { + return Err(new_error(ErrorKind::InvalidKeyFormat)); + } + + let inner = <$hmac_type>::new_from_slice(encoding_key.as_bytes()) .map_err(|_| ErrorKind::InvalidKeyFormat)?; Ok(Self(inner)) @@ -52,7 +56,11 @@ macro_rules! define_hmac_verifier { impl $name { pub(crate) fn new(decoding_key: &DecodingKey) -> Result { - let inner = <$hmac_type>::new_from_slice(decoding_key.try_get_hmac_secret()?) + if decoding_key.family() != AlgorithmFamily::Hmac { + return Err(new_error(ErrorKind::InvalidKeyFormat)); + } + + let inner = <$hmac_type>::new_from_slice(decoding_key.try_get_as_bytes()?) .map_err(|_| ErrorKind::InvalidKeyFormat)?; Ok(Self(inner)) diff --git a/src/crypto/rust_crypto/rsa.rs b/src/crypto/rust_crypto/rsa.rs index ba0af0fe..45e895a1 100644 --- a/src/crypto/rust_crypto/rsa.rs +++ b/src/crypto/rust_crypto/rsa.rs @@ -28,14 +28,12 @@ where H: Digest + AssociatedOid + FixedOutputReset, { let mut rng = rand::thread_rng(); + let private_key = rsa::RsaPrivateKey::from_pkcs1_der(encoding_key.as_bytes()) + .map_err(signature::Error::from_source)?; if pss { - let private_key = rsa::RsaPrivateKey::from_pkcs1_der(encoding_key.inner()) - .map_err(signature::Error::from_source)?; let signing_key = BlindedSigningKey::::new(private_key); Ok(signing_key.sign_with_rng(&mut rng, msg).to_vec()) } else { - let private_key = rsa::RsaPrivateKey::from_pkcs1_der(encoding_key.inner()) - .map_err(signature::Error::from_source)?; let signing_key = SigningKey::::new(private_key); Ok(signing_key.sign_with_rng(&mut rng, msg).to_vec()) } diff --git a/src/decoding.rs b/src/decoding.rs index b54a87e4..c0815a6d 100644 --- a/src/decoding.rs +++ b/src/decoding.rs @@ -229,20 +229,13 @@ impl DecodingKey { } } - /// Get the value of the key. - pub fn as_bytes(&self) -> &[u8] { + /// Try to get the key in raw byte format. + /// + /// To be used for defining your own `CryptoProvider`. + pub fn try_get_as_bytes(&self) -> Result<&[u8]> { match &self.kind { - DecodingKeyKind::SecretOrDer(b) => b, - DecodingKeyKind::RsaModulusExponent { .. } => unreachable!(), - } - } - - /// Try to get the HMAC secret from a key. - pub fn try_get_hmac_secret(&self) -> Result<&[u8]> { - if self.family == AlgorithmFamily::Hmac { - Ok(self.as_bytes()) - } else { - Err(new_error(ErrorKind::InvalidKeyFormat)) + DecodingKeyKind::SecretOrDer(b) => Ok(b), + DecodingKeyKind::RsaModulusExponent { .. } => Err(ErrorKind::InvalidKeyFormat.into()), } } } diff --git a/src/encoding.rs b/src/encoding.rs index 2c6bd671..746be053 100644 --- a/src/encoding.rs +++ b/src/encoding.rs @@ -110,18 +110,11 @@ impl EncodingKey { } /// Get the value of the key. - pub fn inner(&self) -> &[u8] { + /// + /// To be used for defining your own `CryptoProvider`. + pub fn as_bytes(&self) -> &[u8] { &self.content } - - /// Try to get the HMAC secret from a key. - pub fn try_get_hmac_secret(&self) -> Result<&[u8]> { - if self.family == AlgorithmFamily::Hmac { - Ok(self.inner()) - } else { - Err(new_error(ErrorKind::InvalidKeyFormat)) - } - } } impl Debug for EncodingKey { diff --git a/src/jwk.rs b/src/jwk.rs index 1648dcff..07bcf1ee 100644 --- a/src/jwk.rs +++ b/src/jwk.rs @@ -485,13 +485,13 @@ impl Jwk { algorithm: match key.family() { AlgorithmFamily::Hmac => AlgorithmParameters::OctetKey(OctetKeyParameters { key_type: OctetKeyType::Octet, - value: b64_encode(key.inner()), + value: b64_encode(key.as_bytes()), }), AlgorithmFamily::Rsa => { let (n, e) = (CryptoProvider::get_default() .key_utils .rsa_pub_components_from_private_key)( - key.inner() + key.as_bytes() )?; AlgorithmParameters::RSA(RSAKeyParameters { key_type: RSAKeyType::RSA, @@ -503,7 +503,7 @@ impl Jwk { let (curve, x, y) = (CryptoProvider::get_default() .key_utils .ec_pub_components_from_private_key)( - key.inner(), alg + key.as_bytes(), alg )?; AlgorithmParameters::EllipticCurve(EllipticCurveKeyParameters { key_type: EllipticCurveKeyType::EC, @@ -515,7 +515,7 @@ impl Jwk { AlgorithmFamily::Ed => { // Get the curve type based off the encoding key length // Note: here we will receive a DER key which contains a 16 byte ANS.1 header - let curve_type: EllipticCurve = match key.inner().len() { + let curve_type: EllipticCurve = match key.as_bytes().len() { // 16 byte header + 32 byte Ed25519 key 48 => Ok(EllipticCurve::Ed25519), _ => Err(Error::from(ErrorKind::InvalidEddsaKey)), @@ -525,7 +525,7 @@ impl Jwk { let public_key_bytes = (CryptoProvider::get_default() .key_utils .ed_pub_components_from_private_key)( - key.inner(), &curve_type + key.as_bytes(), &curve_type )?; AlgorithmParameters::OctetKeyPair(OctetKeyPairParameters { @@ -547,14 +547,9 @@ impl Jwk { common: CommonParameters { key_algorithm: alg.map(|a| a.into()), ..Default::default() }, algorithm: match key.family() { crate::algorithms::AlgorithmFamily::Hmac => { - let secret = match &key.kind() { - DecodingKeyKind::SecretOrDer(secret) => secret, - _ => return Err(ErrorKind::InvalidKeyFormat.into()), - }; - AlgorithmParameters::OctetKey(OctetKeyParameters { key_type: OctetKeyType::Octet, - value: b64_encode(secret), + value: b64_encode(key.try_get_as_bytes()?), }) } crate::algorithms::AlgorithmFamily::Rsa => { @@ -575,13 +570,7 @@ impl Jwk { AlgorithmParameters::RSA(RSAKeyParameters { key_type: RSAKeyType::RSA, n, e }) } crate::algorithms::AlgorithmFamily::Ec => { - let (curve, x, y) = match &key.kind() { - DecodingKeyKind::SecretOrDer(pub_bytes) => { - ec_pub_components_from_public_key(pub_bytes)? - } - _ => return Err(ErrorKind::InvalidKeyFormat.into()), - }; - + let (curve, x, y) = ec_pub_components_from_public_key(key.try_get_as_bytes()?)?; AlgorithmParameters::EllipticCurve(EllipticCurveKeyParameters { key_type: EllipticCurveKeyType::EC, curve, @@ -590,15 +579,11 @@ impl Jwk { }) } crate::algorithms::AlgorithmFamily::Ed => { - let (curve_type, x) = match &key.kind() { - DecodingKeyKind::SecretOrDer(pub_bytes) => { - match pub_bytes.len() { - // ED25519: https://datatracker.ietf.org/doc/html/rfc8032#section-5.1.5 - 32 => (EllipticCurve::Ed25519, pub_bytes), - _ => return Err(ErrorKind::InvalidEddsaKey.into()), - } - } - _ => return Err(ErrorKind::InvalidKeyFormat.into()), + let pub_bytes = key.try_get_as_bytes()?; + let (curve_type, x) = match pub_bytes.len() { + // ED25519: https://datatracker.ietf.org/doc/html/rfc8032#section-5.1.5 + 32 => (EllipticCurve::Ed25519, pub_bytes), + _ => return Err(ErrorKind::InvalidEddsaKey.into()), }; AlgorithmParameters::OctetKeyPair(OctetKeyPairParameters { diff --git a/tests/hmac.rs b/tests/hmac.rs index e309670a..794157c6 100644 --- a/tests/hmac.rs +++ b/tests/hmac.rs @@ -231,6 +231,18 @@ fn decode_token_wrong_algorithm() { assert_eq!(claims.unwrap_err().into_kind(), ErrorKind::InvalidAlgorithm); } +#[test] +#[wasm_bindgen_test] +fn decode_token_wrong_key_family() { + let token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJiQGIuY29tIiwiY29tcGFueSI6IkFDTUUiLCJleHAiOjI1MzI1MjQ4OTF9.9r56oF7ZliOBlOAyiOFperTGxBtPykRQiWNFxhDCW98"; + let claims = decode::( + token, + &DecodingKey::from_rsa_der(b"secret"), + &Validation::new(Algorithm::HS256), + ); + assert_eq!(claims.unwrap_err().into_kind(), ErrorKind::InvalidKeyFormat); +} + #[test] #[wasm_bindgen_test] fn encode_wrong_alg_family() {