aboutsummaryrefslogtreecommitdiff
path: root/backend/internal/keys/keys.go
diff options
context:
space:
mode:
Diffstat (limited to 'backend/internal/keys/keys.go')
-rw-r--r--backend/internal/keys/keys.go150
1 files changed, 150 insertions, 0 deletions
diff --git a/backend/internal/keys/keys.go b/backend/internal/keys/keys.go
new file mode 100644
index 0000000..ac73173
--- /dev/null
+++ b/backend/internal/keys/keys.go
@@ -0,0 +1,150 @@
+package keys
+
+import (
+ "context"
+ "crypto/aes"
+ "crypto/cipher"
+ "crypto/rand"
+ "fmt"
+ "io"
+)
+
+var CSRNG = rand.Reader
+
+//go:generate mockgen -destination mock_keys_test.go -package keys_test -typed . KeyManagementService
+type KeyManagementService interface {
+ io.Closer
+
+ // Encrypt encrypts the given plaintext using the key with the given key name.
+ Encrypt(ctx context.Context, keyName string, plaintext []byte) ([]byte, error)
+
+ // Decrypt decrypts the given ciphertext using the key with the given key name.
+ Decrypt(ctx context.Context, keyName string, ciphertext []byte) ([]byte, error)
+}
+
+// Encrypt encrypts the given plaintext using a hybrid encryption scheme.
+//
+// It first generates a random AES 256-bit key and encrypts the plaintext with it.
+// Then, it encrypts the AES key using the KMS.
+//
+// It returns the ciphertext, the encrypted AES key, and any errors that occurred.
+func Encrypt(
+ ctx context.Context,
+ kms KeyManagementService,
+ keyName string,
+ plaintext []byte,
+) (ciphertext []byte, encryptedKey []byte, err error) {
+ // Generate a random AES key
+ aesKey := make([]byte, 32)
+ if _, err = io.ReadFull(CSRNG, aesKey); err != nil {
+ return nil, nil, fmt.Errorf("unable to generate AES key: %w", err)
+ }
+
+ // Encrypt the plaintext using the AES key
+ ciphertext, err = encrypt(aesKey, plaintext)
+ if err != nil {
+ return nil, nil, fmt.Errorf("unable to encrypt plaintext: %w", err)
+ }
+
+ // Encrypt the AES key using the KMS
+ encryptedKey, err = kms.Encrypt(ctx, keyName, aesKey)
+ if err != nil {
+ return nil, nil, fmt.Errorf("unable to encrypt AES key: %w", err)
+ }
+
+ return ciphertext, encryptedKey, nil
+}
+
+// EncryptWithKey encrypts the given plaintext using a hybrid encryption scheme.
+//
+// This works similarly to Encrypt, but instead of generating a new AES key, it uses a given already encrypted AES key.
+func EncryptWithKey(
+ ctx context.Context,
+ kms KeyManagementService,
+ keyName string,
+ encryptedKey []byte,
+ plaintext []byte,
+) ([]byte, error) {
+ // Decrypt the AES key
+ aesKey, err := kms.Decrypt(ctx, keyName, encryptedKey)
+ if err != nil {
+ return nil, fmt.Errorf("unable to decrypt AES key: %w", err)
+ }
+
+ // Encrypt the plaintext using the AES key
+ ciphertext, err := encrypt(aesKey, plaintext)
+ if err != nil {
+ return nil, fmt.Errorf("unable to encrypt plaintext: %w", err)
+ }
+
+ return ciphertext, nil
+}
+
+func encrypt(aesKey []byte, plaintext []byte) ([]byte, error) {
+ // Create an AES cipher
+ blockCipher, err := aes.NewCipher(aesKey)
+ if err != nil {
+ return nil, fmt.Errorf("unable to create AES cipher: %w", err)
+ }
+
+ gcm, err := cipher.NewGCM(blockCipher)
+ if err != nil {
+ return nil, fmt.Errorf("unable to create GCM: %w", err)
+ }
+
+ // Generate a random nonce
+ nonce := make([]byte, gcm.NonceSize())
+ if _, err = io.ReadFull(CSRNG, nonce); err != nil {
+ return nil, fmt.Errorf("unable to generate nonce: %w", err)
+ }
+
+ // Encrypt the plaintext
+ ciphertext := gcm.Seal(nonce, nonce, plaintext, nil)
+ return ciphertext, nil
+}
+
+// Decrypt decrypts the given ciphertext using a hybrid encryption scheme.
+//
+// It first decrypts the AES key using the KMS.
+// Then, it decrypts the ciphertext using the decrypted AES key.
+//
+// It returns the plaintext and any errors that occurred.
+func Decrypt(
+ ctx context.Context,
+ kms KeyManagementService,
+ keyName string,
+ ciphertext []byte,
+ encryptedKey []byte,
+) ([]byte, error) {
+ // Decrypt the AES key
+ aesKey, err := kms.Decrypt(ctx, keyName, encryptedKey)
+ if err != nil {
+ return nil, fmt.Errorf("unable to decrypt AES key: %w", err)
+ }
+
+ // Create an AES cipher
+ blockCipher, err := aes.NewCipher(aesKey)
+ if err != nil {
+ return nil, fmt.Errorf("unable to create AES cipher: %w", err)
+ }
+
+ gcm, err := cipher.NewGCM(blockCipher)
+ if err != nil {
+ return nil, fmt.Errorf("unable to create GCM: %w", err)
+ }
+
+ // Extract the nonce from the ciphertext
+ nonceSize := gcm.NonceSize()
+ if len(ciphertext) < nonceSize {
+ return nil, fmt.Errorf("ciphertext is too short")
+ }
+ nonce, ciphertext := ciphertext[:nonceSize], ciphertext[nonceSize:]
+
+ // Decrypt the ciphertext
+ plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
+ if err != nil {
+ return nil, fmt.Errorf("unable to decrypt ciphertext: %w", err)
+ }
+
+ return plaintext, nil
+}