Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions aescbc/aescbc.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,11 @@ func pkcs7Unpad(data []byte) ([]byte, error) {
if padding < 1 || padding > aes.BlockSize {
return nil, fmt.Errorf("invalid padding")
}
for _, b := range data[len(data)-padding:] {
if b != byte(padding) {
return nil, fmt.Errorf("invalid padding")
}
}
return data[:len(data)-padding], nil
}
Comment thread
patrislav marked this conversation as resolved.

Expand Down
6 changes: 6 additions & 0 deletions aesgcm/aesgcm.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,13 @@ import (

// Encrypt encrypts plaintext using key and random entropy. Key must be a valid AES-256 key with a length of 32 bytes.
// The result is a concatenation of nonce (using standard 12-byte nonce size) and the actual ciphertext.
//
// random must be a cryptographically secure random source (e.g. crypto/rand.Reader or an NSM session).
// AES-GCM is catastrophically broken under nonce reuse.
func Encrypt(random io.Reader, key []byte, plaintext []byte, additionalData []byte) ([]byte, error) {
if random == nil {
return nil, fmt.Errorf("random source must not be nil")
}
if len(key) != 32 {
return nil, fmt.Errorf("key must be 32 bytes for AES-256 but was %d", len(key))
}
Expand Down
7 changes: 3 additions & 4 deletions attestation/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,15 @@ func Middleware(enc *enclave.Enclave, errorFn func(http.ResponseWriter, error),
return context.WithValue(r.Context(), contextKey, att), cancelFunc, nil
}

runPostMiddleware := func(w http.ResponseWriter, r *http.Request, body []byte, nonce []byte) (err error) {
runPostMiddleware := func(w http.ResponseWriter, r *http.Request, reqBody []byte, resBody []byte, nonce []byte) (err error) {
log := loggerFromContextFn(r.Context())
ctx, span := tracing.Trace(r.Context(), "attestation.Middleware")
defer func() {
span.RecordError(err)
span.End()
}()

userData, err := generateUserData(r, body)
userData, err := generateUserData(r, reqBody, resBody)
if err != nil {
return err
}
Expand Down Expand Up @@ -109,8 +109,7 @@ func Middleware(enc *enclave.Enclave, errorFn func(http.ResponseWriter, error),

next.ServeHTTP(ww, r.WithContext(ctx))

r.Body = io.NopCloser(bytes.NewBuffer(reqBody))
if err := runPostMiddleware(ww, r, body.Bytes(), nonce); err != nil {
if err := runPostMiddleware(ww, r, reqBody, body.Bytes(), nonce); err != nil {
errorFn(w, err)
return
}
Expand Down
16 changes: 2 additions & 14 deletions attestation/userdata.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
package attestation

import (
"bytes"
"crypto/sha256"
"encoding/base64"
"fmt"
"io"
"net/http"
)

Expand All @@ -24,20 +22,10 @@ func (u *userData) String() string {
return fmt.Sprintf("%s/%d:%s", u.Prefix, u.Version, base64.StdEncoding.EncodeToString(u.Hash))
}

func generateUserData(r *http.Request, resBody []byte) ([]byte, error) {
func generateUserData(r *http.Request, reqBody []byte, resBody []byte) ([]byte, error) {
hasher := sha256.New()
hasher.Write([]byte(r.Method + " " + r.URL.Path + "\n"))

var reqBody []byte
var err error
if r.Body != nil {
reqBody, err = io.ReadAll(r.Body)
if err != nil {
return nil, fmt.Errorf("failed to read request body: %w", err)
}
r.Body = io.NopCloser(bytes.NewBuffer(reqBody))
hasher.Write(reqBody)
}
hasher.Write(reqBody)
hasher.Write([]byte("\n"))
hasher.Write(resBody)
hash := hasher.Sum(nil)
Expand Down
13 changes: 5 additions & 8 deletions cms/ber.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,14 +150,14 @@ func readObject(ber []byte, offset int) (asn1Object, int, error) {
for ber[offset] >= 0x80 {
tag = tag*128 + ber[offset] - 0x80
offset++
if offset > berLen {
if offset >= berLen {
return nil, 0, errors.New("ber2der: cannot move offset forward, end of ber data reached")
}
}
// jvehent 20170227: this doesn't appear to be used anywhere...
//tag = tag*128 + ber[offset] - 0x80
offset++
if offset > berLen {
if offset >= berLen {
return nil, 0, errors.New("ber2der: cannot move offset forward, end of ber data reached")
}
}
Expand All @@ -173,15 +173,15 @@ func readObject(ber []byte, offset int) (asn1Object, int, error) {
var length int
l := ber[offset]
offset++
if offset > berLen {
return nil, 0, errors.New("ber2der: cannot move offset forward, end of ber data reached")
}
indefinite := false
if l > 0x80 {
numberOfBytes := (int)(l & 0x7F)
if numberOfBytes > 4 { // int is only guaranteed to be 32bit
return nil, 0, errors.New("ber2der: BER tag length too long")
}
if offset+numberOfBytes > berLen {
return nil, 0, errors.New("ber2der: cannot move offset forward, end of ber data reached")
}
if numberOfBytes == 4 && (int)(ber[offset]) > 0x7F {
return nil, 0, errors.New("ber2der: BER tag length is negative")
}
Expand All @@ -193,9 +193,6 @@ func readObject(ber []byte, offset int) (asn1Object, int, error) {
for i := 0; i < numberOfBytes; i++ {
length = length*256 + (int)(ber[offset])
offset++
if offset > berLen {
return nil, 0, errors.New("ber2der: cannot move offset forward, end of ber data reached")
}
}
} else if l == 0x80 {
indefinite = true
Expand Down
4 changes: 4 additions & 0 deletions cms/cms.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ func (ek *EncryptedKey) Decrypt(key *rsa.PrivateKey) ([]byte, error) {
return nil, errors.New("pkcs7: encryption algorithm parameters are malformed")
}

if len(ek.cipherText) == 0 || len(ek.cipherText)%block.BlockSize() != 0 {
return nil, fmt.Errorf("cms: ciphertext length %d is not a multiple of block size %d", len(ek.cipherText), block.BlockSize())
}

mode := cipher.NewCBCDecrypter(block, ek.iv)
plaintext := make([]byte, len(ek.cipherText))
mode.CryptBlocks(plaintext, ek.cipherText)
Expand Down
35 changes: 35 additions & 0 deletions cms/cms_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,27 @@ cJEGAbCDYhyjvtjBLNy7YDQ1hdmCnqMxg/5AIwUMkvTTRg+qepfboA==
}
)

func TestParse_malformedBER(t *testing.T) {
tests := []struct {
name string
input []byte
}{
{"multi-byte tag truncated", []byte{0x1F, 0x80}},
{"multi-byte tag no length", []byte{0x1F, 0x01}},
{"tag only", []byte{0x30}},
{"long-form length truncated", []byte{0x30, 0x82}},
{"long-form length partial", []byte{0x30, 0x82, 0x01}},
{"length exceeds data", []byte{0x30, 0x10, 0x00}},
{"empty input", []byte{}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := cms.Parse(tt.input)
require.Error(t, err)
})
}
}

func TestDecodeCiphertextForRecipient(t *testing.T) {
block, _ := pem.Decode([]byte(testPrivateKey))
key, err := x509.ParsePKCS1PrivateKey(block.Bytes)
Expand All @@ -65,3 +86,17 @@ func TestDecodeCiphertextForRecipient(t *testing.T) {

require.Equal(t, plaintextKey, dataKey)
}

func TestDecryptEnvelopedKey_truncatedCiphertext(t *testing.T) {
block, _ := pem.Decode([]byte(testPrivateKey))
key, err := x509.ParsePKCS1PrivateKey(block.Bytes)
require.NoError(t, err)

ciphertext, err := base64.StdEncoding.DecodeString(testCiphertextString)
require.NoError(t, err)

// Truncate by one byte to misalign the inner ciphertext off a block boundary
truncated := ciphertext[:len(ciphertext)-1]
_, err = cms.DecryptEnvelopedKey(key, truncated)
require.Error(t, err)
}
10 changes: 8 additions & 2 deletions enclave/attestation.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,13 @@ func (a *Attestation) Document() []byte {

// Decrypt requests a decryption operation from KMS on ciphertext. If the key used to encrypt the
// original data is not one of allowedKeyIDs, Decrypt returns an error.
//
// allowedKeyIDs must not be empty — callers must explicitly specify which KMS keys are acceptable.
func (a *Attestation) Decrypt(ctx context.Context, ciphertext []byte, allowedKeyIDs []string) ([]byte, error) {
if len(allowedKeyIDs) == 0 {
return nil, fmt.Errorf("allowedKeyIDs must not be empty")
}

params := &kms.DecryptInput{
CiphertextBlob: ciphertext,
EncryptionAlgorithm: types.EncryptionAlgorithmSpecSymmetricDefault,
Expand Down Expand Up @@ -117,8 +123,8 @@ func (a *Attestation) GenerateDataKey(ctx context.Context, keyID string) (*DataK
}

func keyIsAllowed(key *string, allowedKeys []string) (string, bool) {
if key == nil || len(allowedKeys) == 0 {
return "", true
if key == nil {
return "", false
Comment thread
patrislav marked this conversation as resolved.
}

for _, v := range allowedKeys {
Expand Down
9 changes: 5 additions & 4 deletions enclave/attestation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,11 @@ func TestNitroAttestation_Decrypt(t *testing.T) {
doc, err := nitro.Parse(params.Recipient.AttestationDocument)
require.NoError(t, err)
assert.Equal(t, []byte("nonce"), doc.Nonce)
assert.NoError(t, doc.Validate(nitro.WithRootFingerprint("14e8bc5fabb52876f35f122289eaabfa08885837cc7f161149c6d242596258aa")))
assert.NoError(t, doc.Validate(nitro.WithRootFingerprint(doc.RootCertFingerprint())))
assert.NoError(t, doc.Verify())
assert.Equal(t, types.KeyEncryptionMechanismRsaesOaepSha256, params.Recipient.KeyEncryptionAlgorithm)
return &kms.DecryptOutput{CiphertextForRecipient: ciphertextForRecipient}, nil
keyID := "arn:aws:kms:us-east-1:000000000000:key/test-key-id"
return &kms.DecryptOutput{KeyId: &keyID, CiphertextForRecipient: ciphertextForRecipient}, nil
},
}

Expand All @@ -92,7 +93,7 @@ func TestNitroAttestation_Decrypt(t *testing.T) {
att, err := e.GetAttestation(context.Background(), []byte("nonce"), []byte("user-data"))
require.NoError(t, err)

plaintext, err := att.Decrypt(context.Background(), []byte("ciphertext"), nil)
plaintext, err := att.Decrypt(context.Background(), []byte("ciphertext"), []string{"arn:aws:kms:us-east-1:000000000000:key/test-key-id"})
require.NoError(t, err)
assert.Equal(t, expectedPlaintext, plaintext)
}()
Expand All @@ -116,7 +117,7 @@ func TestAttestation_GenerateDataKey(t *testing.T) {
doc, err := nitro.Parse(params.Recipient.AttestationDocument)
require.NoError(t, err)
assert.Equal(t, []byte("nonce"), doc.Nonce)
assert.NoError(t, doc.Validate(nitro.WithRootFingerprint("14e8bc5fabb52876f35f122289eaabfa08885837cc7f161149c6d242596258aa")))
assert.NoError(t, doc.Validate(nitro.WithRootFingerprint(doc.RootCertFingerprint())))
assert.NoError(t, doc.Verify())
assert.Equal(t, types.KeyEncryptionMechanismRsaesOaepSha256, params.Recipient.KeyEncryptionAlgorithm)
return &kms.GenerateDataKeyOutput{
Expand Down
103 changes: 42 additions & 61 deletions enclave/provider_dummy.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"crypto/x509"
"crypto/x509/pkix"
"encoding/asn1"
"encoding/pem"
"fmt"
"io"
"math"
Expand All @@ -23,78 +22,60 @@ import (
"github.com/fxamacker/cbor/v2"
)

var dummyPrivKey = `-----BEGIN RSA PRIVATE KEY-----
MIIEpAIBAAKCAQEAujDWnWEKVYoHUwieLegkzR2K+4z2Fg3uVEwmZ16iRJiYm5TO
ltLN6BSHaLCqreA1bYXXTFlIG10z2+h16fhkCNKzy4yKwjwUdXJlbBivypQers8h
Pwy1l4c+uID/VX5zXG4y7g7aNc0Ude+lzBvydh9vFz5PwupFzY6ok3czI95ODni7
hn/X/8TBGTyh0eYZu8ehfKy6W9AHbX7D+yL2qebSWWkJBEribptpCcaJi8QPUx9M
HWz8j1j83+M6rnG1FQpLl8VNOO6BXmzb5FNr+6lwEfvwHbht0Azhk0ArMQZ/r0lO
ObAvVDmE2AuudXyWWh5sRrXnXlVitDjTQybQAQIDAQABAoIBAQCYf9Poh0jdkvY4
zkAwvYkW73GcY3JT0gk4xj5WQC6MHKgyFgm3guXfhqD54GmLjK52DD+xaxciQo5t
OdMKVcYpa9qTh4NHX8oqAA6OIRIqzHLtHv3OFGzPtZhrqkx4C+AU/rV8QnH7ywNN
LYIQ0XsfwNNOqFzP+u49VPFCB0m9v7r7mJxeUXp8PDfdhquFT69hpKwNdpzuIDA7
kVOG4ATkkPTGp3AmJj9Vrit9ffi+xlbhrNIuBui9Fxo1v5G6VT2uBhXJU22zl1hS
uYWT4rCOwVQaV/TBDj4T8diDxYpnAXvpO8U+WdqLddhUNaYeDym/HPq2cFsN9VdY
9FYiVl4ZAoGBAOWVsrRAWgFTmx99nUwy6XhobSWgZDrCQiSK50VGzblBdVnmMvyW
Q3LmdqtVQUkZLETx7PZXYkvIzMRP4oWGcViBPaSZ/IqX/kF5WJeXWW7Zgl5HEXTk
GaN26xl7yFjQ5l0f++HAwSW485B2GXvMcdp+6n7OfG6Xo1cg8CgWck5TAoGBAM+c
/h03pASGVvUDNNfeDulyxcXR/PZZTt1YMTqeYLmkbkJcIJVa2uTdDmzcEbGDA0eq
ezMDA+omGB+WR7HRe9+vgmz7Ww4BZRhKjvnxRgHlTGYHBsHhYr21fgPteGv/aDi2
xhAGqyOj1jua8ooqpw8TviYXk6ZbxMNF7eV9KxXbAoGAasEjKaHKuFcyCICWhfoe
ifi02AwuzwvJSci1JYd43a3MbZMXHlCY6HK1t5GbG+xyo1SDRUD42hhy7s3enQwY
5HikO0fHIILwnW1ZfpPH6D2H22LcgSgXq+T+CQl/7ZyloaPfsee5aFsKFqBz1RcJ
0fm1/GTzg1FLiJYuVdWqLTUCgYAaOURHwH1xLN7S9+K22Y+coSimAg4nt8QkZT1i
oBqrmD9tFmHvO5imi92Elo+NknTZmokROnJGIyWs57iKl2FEMdERnvYzYK26UcCZ
hYZIOwRZZs3Ns4BbYg9Ww6oQSiSJ9VwzLgRz7f/ja4DzPsv3NZExEo1N2A2UdMLF
1/eXPQKBgQDSCJ1tWQYVLvjrzJBC5gute7kHf1AhMoIEqpsEvk51JXu7+xN8BMnb
zSwIPR3fSngqLJqGw+Tz5LT3iSsDNVj7EnaHoYvTrxsd2yFYtVmz2fHgnHXBjZmj
AzDn4G6VZ+F11K/sdfuo+1vfgxPendYDkjp0ZtgJc97iBq49Devv1A==
-----END RSA PRIVATE KEY-----`
var dummyCert = `-----BEGIN CERTIFICATE-----
MIIDITCCAgmgAwIBAgIBATANBgkqhkiG9w0BAQsFADAyMREwDwYDVQQKEwhTZXF1
ZW5jZTEdMBsGA1UEAxMUZHVtbXkubml0cm8tZW5jbGF2ZXMwHhcNMjUwNDI0MTM0
NjA5WhcNMzUwNDI0MTM0NjA5WjAyMREwDwYDVQQKEwhTZXF1ZW5jZTEdMBsGA1UE
AxMUZHVtbXkubml0cm8tZW5jbGF2ZXMwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAw
ggEKAoIBAQC6MNadYQpVigdTCJ4t6CTNHYr7jPYWDe5UTCZnXqJEmJiblM6W0s3o
FIdosKqt4DVthddMWUgbXTPb6HXp+GQI0rPLjIrCPBR1cmVsGK/KlB6uzyE/DLWX
hz64gP9VfnNcbjLuDto1zRR176XMG/J2H28XPk/C6kXNjqiTdzMj3k4OeLuGf9f/
xMEZPKHR5hm7x6F8rLpb0AdtfsP7Ivap5tJZaQkESuJum2kJxomLxA9TH0wdbPyP
WPzf4zqucbUVCkuXxU047oFebNvkU2v7qXAR+/AduG3QDOGTQCsxBn+vSU45sC9U
OYTYC651fJZaHmxGtedeVWK0ONNDJtABAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIB
hjAPBgNVHRMBAf8EBTADAQH/MB0GA1UdDgQWBBSw2hfihIyfiqyiuiuTp3OCt0Sl
8DANBgkqhkiG9w0BAQsFAAOCAQEAl55+EnYlS5/YTQQhZozA/XW7Y9Kt00w9k0Ix
9vXTVeZdzTNR/YKCAzG7ynNjNbdFkhJcqqwKycVOSID0Xz4dWvB6jVukIV6B3W2u
ta/P4SYg4VQ9YzPqF1n1sUzX3OwKOhEcSxQQjvs8ssRaWq9aqEHyxCxuc9BWoqvB
Am9iwrNpmUmlRbFwDOwtICZRbqAf799pOFo1i8WKQc/J5y1KwZCCg3GAEBv8CNQE
vMVH5ygi1fMeQPNg8oWDD+3gP1GmLGMP14kHT/aPyDAHHUMrq7nSgA8SXTC9fihO
sygULgtpiSjKgeg9cTvK9yhz7T0c2CxFgyhUnz4v6uZtQTJK2Q==
-----END CERTIFICATE-----`

func DummyProvider(random io.Reader) func() (Session, error) {
if random == nil {
random = rand.Reader
}
return func() (Session, error) {
block, _ := pem.Decode([]byte(dummyPrivKey))
if block == nil || block.Type != "RSA PRIVATE KEY" {
return nil, fmt.Errorf("invalid PEM block")

// Generate an ephemeral CA key pair and self-signed certificate once per provider instance.
// This avoids shipping static key material in the source code.
caKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return func() (Session, error) {
return nil, fmt.Errorf("failed to generate CA key: %v", err)
}
}

serialNumber, err := rand.Int(rand.Reader, big.NewInt(math.MaxInt64))
if err != nil {
return func() (Session, error) {
return nil, fmt.Errorf("failed to generate serial number: %v", err)
}
Comment thread
patrislav marked this conversation as resolved.
key, err := x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil {
return nil, fmt.Errorf("failed to parse dummy private key: %v", err)
}

caTemplate := x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
Organization: []string{"Dummy"},
CommonName: "dummy.nitro-enclaves",
},
NotBefore: time.Now(),
NotAfter: time.Now().Add(10 * 365 * 24 * time.Hour),
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
BasicConstraintsValid: true,
IsCA: true,
}

caCertDER, err := x509.CreateCertificate(rand.Reader, &caTemplate, &caTemplate, &caKey.PublicKey, caKey)
if err != nil {
return func() (Session, error) {
return nil, fmt.Errorf("failed to create CA certificate: %v", err)
}
}

certBlock, _ := pem.Decode([]byte(dummyCert))
caCert, err := x509.ParseCertificate(certBlock.Bytes)
if err != nil {
caCert, err := x509.ParseCertificate(caCertDER)
if err != nil {
return func() (Session, error) {
return nil, fmt.Errorf("failed to parse CA certificate: %v", err)
}
}

return func() (Session, error) {
return &dummySession{
random: random,
privateKey: key,
privateKey: caKey,
caCert: caCert,
caCertDER: certBlock.Bytes,
caCertDER: caCertDER,
}, nil
}
}
Expand Down Expand Up @@ -231,7 +212,7 @@ func (d *dummySession) generateCertificate() ([]byte, *ecdsa.PrivateKey, error)
template := x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
Organization: []string{"Sequence"},
Organization: []string{"Dummy"},
CommonName: "dummy.nitro-enclaves",
},
NotBefore: time.Now(),
Expand Down
Loading
Loading