diff options
author | 2024-08-11 13:15:50 -0700 | |
---|---|---|
committer | 2024-08-11 13:15:50 -0700 | |
commit | 6a3c21fb0b1c126849f2bbff494403bbe901448e (patch) | |
tree | 5d7805524357c2c8a9819c39d2051a4e3633a1d5 /backend/internal/keys/keys.go | |
parent | 29c6040a51616e9e4cf6c70ee16391b2a3b238c9 (diff) | |
parent | f34b92ded11b07f78575ac62c260a380c468e5ea (diff) | |
download | ibd-trader-6a3c21fb0b1c126849f2bbff494403bbe901448e.tar.gz ibd-trader-6a3c21fb0b1c126849f2bbff494403bbe901448e.tar.zst ibd-trader-6a3c21fb0b1c126849f2bbff494403bbe901448e.zip |
Merge remote-tracking branch 'backend/main'
Diffstat (limited to 'backend/internal/keys/keys.go')
-rw-r--r-- | backend/internal/keys/keys.go | 150 |
1 files changed, 150 insertions, 0 deletions
diff --git a/backend/internal/keys/keys.go b/backend/internal/keys/keys.go new file mode 100644 index 0000000..ac73173 --- /dev/null +++ b/backend/internal/keys/keys.go @@ -0,0 +1,150 @@ +package keys + +import ( + "context" + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "fmt" + "io" +) + +var CSRNG = rand.Reader + +//go:generate mockgen -destination mock_keys_test.go -package keys_test -typed . KeyManagementService +type KeyManagementService interface { + io.Closer + + // Encrypt encrypts the given plaintext using the key with the given key name. + Encrypt(ctx context.Context, keyName string, plaintext []byte) ([]byte, error) + + // Decrypt decrypts the given ciphertext using the key with the given key name. + Decrypt(ctx context.Context, keyName string, ciphertext []byte) ([]byte, error) +} + +// Encrypt encrypts the given plaintext using a hybrid encryption scheme. +// +// It first generates a random AES 256-bit key and encrypts the plaintext with it. +// Then, it encrypts the AES key using the KMS. +// +// It returns the ciphertext, the encrypted AES key, and any errors that occurred. +func Encrypt( + ctx context.Context, + kms KeyManagementService, + keyName string, + plaintext []byte, +) (ciphertext []byte, encryptedKey []byte, err error) { + // Generate a random AES key + aesKey := make([]byte, 32) + if _, err = io.ReadFull(CSRNG, aesKey); err != nil { + return nil, nil, fmt.Errorf("unable to generate AES key: %w", err) + } + + // Encrypt the plaintext using the AES key + ciphertext, err = encrypt(aesKey, plaintext) + if err != nil { + return nil, nil, fmt.Errorf("unable to encrypt plaintext: %w", err) + } + + // Encrypt the AES key using the KMS + encryptedKey, err = kms.Encrypt(ctx, keyName, aesKey) + if err != nil { + return nil, nil, fmt.Errorf("unable to encrypt AES key: %w", err) + } + + return ciphertext, encryptedKey, nil +} + +// EncryptWithKey encrypts the given plaintext using a hybrid encryption scheme. +// +// This works similarly to Encrypt, but instead of generating a new AES key, it uses a given already encrypted AES key. +func EncryptWithKey( + ctx context.Context, + kms KeyManagementService, + keyName string, + encryptedKey []byte, + plaintext []byte, +) ([]byte, error) { + // Decrypt the AES key + aesKey, err := kms.Decrypt(ctx, keyName, encryptedKey) + if err != nil { + return nil, fmt.Errorf("unable to decrypt AES key: %w", err) + } + + // Encrypt the plaintext using the AES key + ciphertext, err := encrypt(aesKey, plaintext) + if err != nil { + return nil, fmt.Errorf("unable to encrypt plaintext: %w", err) + } + + return ciphertext, nil +} + +func encrypt(aesKey []byte, plaintext []byte) ([]byte, error) { + // Create an AES cipher + blockCipher, err := aes.NewCipher(aesKey) + if err != nil { + return nil, fmt.Errorf("unable to create AES cipher: %w", err) + } + + gcm, err := cipher.NewGCM(blockCipher) + if err != nil { + return nil, fmt.Errorf("unable to create GCM: %w", err) + } + + // Generate a random nonce + nonce := make([]byte, gcm.NonceSize()) + if _, err = io.ReadFull(CSRNG, nonce); err != nil { + return nil, fmt.Errorf("unable to generate nonce: %w", err) + } + + // Encrypt the plaintext + ciphertext := gcm.Seal(nonce, nonce, plaintext, nil) + return ciphertext, nil +} + +// Decrypt decrypts the given ciphertext using a hybrid encryption scheme. +// +// It first decrypts the AES key using the KMS. +// Then, it decrypts the ciphertext using the decrypted AES key. +// +// It returns the plaintext and any errors that occurred. +func Decrypt( + ctx context.Context, + kms KeyManagementService, + keyName string, + ciphertext []byte, + encryptedKey []byte, +) ([]byte, error) { + // Decrypt the AES key + aesKey, err := kms.Decrypt(ctx, keyName, encryptedKey) + if err != nil { + return nil, fmt.Errorf("unable to decrypt AES key: %w", err) + } + + // Create an AES cipher + blockCipher, err := aes.NewCipher(aesKey) + if err != nil { + return nil, fmt.Errorf("unable to create AES cipher: %w", err) + } + + gcm, err := cipher.NewGCM(blockCipher) + if err != nil { + return nil, fmt.Errorf("unable to create GCM: %w", err) + } + + // Extract the nonce from the ciphertext + nonceSize := gcm.NonceSize() + if len(ciphertext) < nonceSize { + return nil, fmt.Errorf("ciphertext is too short") + } + nonce, ciphertext := ciphertext[:nonceSize], ciphertext[nonceSize:] + + // Decrypt the ciphertext + plaintext, err := gcm.Open(nil, nonce, ciphertext, nil) + if err != nil { + return nil, fmt.Errorf("unable to decrypt ciphertext: %w", err) + } + + return plaintext, nil +} |