package keys

import (
	"context"
	"crypto/aes"
	"crypto/cipher"
	"crypto/rand"
	"fmt"
	"io"
)

// CSRNG is the randomness source used for generated AES keys and GCM nonces.
// It defaults to crypto/rand.Reader. Because it is a package-level variable it
// can be replaced (for example in tests); doing so affects every caller of
// this package.
var CSRNG = rand.Reader

//go:generate mockgen -destination mock_keys_test.go -package keys_test -typed . KeyManagementService

// KeyManagementService encrypts and decrypts small payloads (in this package:
// AES data keys) under keys that are referenced only by name, so the named
// key material itself never passes through this package.
type KeyManagementService interface {
	io.Closer

	// Encrypt encrypts the given plaintext using the key with the given key name.
	Encrypt(ctx context.Context, keyName string, plaintext []byte) ([]byte, error)

	// Decrypt decrypts the given ciphertext using the key with the given key name.
	Decrypt(ctx context.Context, keyName string, ciphertext []byte) ([]byte, error)
}

// Encrypt seals plaintext using a hybrid encryption scheme.
//
// A fresh 256-bit AES key is read from CSRNG and used to seal plaintext with
// AES-GCM. That AES key is then encrypted by kms under keyName.
//
// The returned ciphertext is the GCM nonce followed by the sealed data.
// encryptedKey is the KMS-encrypted AES key. Both are required by Decrypt.
// On error, both byte results are nil.
func Encrypt(
	ctx context.Context,
	kms KeyManagementService,
	keyName string,
	plaintext []byte,
) (ciphertext []byte, encryptedKey []byte, err error) {
	dataKey := make([]byte, 32) // 32 bytes selects AES-256 in aes.NewCipher
	if _, readErr := io.ReadFull(CSRNG, dataKey); readErr != nil {
		return nil, nil, fmt.Errorf("unable to generate AES key: %w", readErr)
	}

	// Seal the payload locally first; the KMS is only contacted once that succeeded.
	sealed, sealErr := encrypt(dataKey, plaintext)
	if sealErr != nil {
		return nil, nil, fmt.Errorf("unable to encrypt plaintext: %w", sealErr)
	}

	wrappedKey, wrapErr := kms.Encrypt(ctx, keyName, dataKey)
	if wrapErr != nil {
		return nil, nil, fmt.Errorf("unable to encrypt AES key: %w", wrapErr)
	}

	return sealed, wrappedKey, nil
}

// EncryptWithKey seals plaintext like Encrypt, but reuses an existing data key
// instead of generating one. encryptedKey (typically the encryptedKey returned
// by Encrypt) is first decrypted by kms under keyName. The recovered AES key
// then seals plaintext with a freshly drawn random nonce.
//
// Because one AES key may be reused across many calls this way, note that
// random 96-bit GCM nonces are only safe for a bounded number of messages per
// key (NIST SP 800-38D recommends at most 2^32).
func EncryptWithKey( ctx context.Context, kms KeyManagementService, keyName string, encryptedKey []byte, plaintext []byte, ) ([]byte, error) { // Decrypt the AES key aesKey, err := kms.Decrypt(ctx, keyName, encryptedKey) if err != nil { return nil, fmt.Errorf("unable to decrypt AES key: %w", err) } // Encrypt the plaintext using the AES key ciphertext, err := encrypt(aesKey, plaintext) if err != nil { return nil, fmt.Errorf("unable to encrypt plaintext: %w", err) } return ciphertext, nil } func encrypt(aesKey []byte, plaintext []byte) ([]byte, error) { // Create an AES cipher blockCipher, err := aes.NewCipher(aesKey) if err != nil { return nil, fmt.Errorf("unable to create AES cipher: %w", err) } gcm, err := cipher.NewGCM(blockCipher) if err != nil { return nil, fmt.Errorf("unable to create GCM: %w", err) } // Generate a random nonce nonce := make([]byte, gcm.NonceSize()) if _, err = io.ReadFull(CSRNG, nonce); err != nil { return nil, fmt.Errorf("unable to generate nonce: %w", err) } // Encrypt the plaintext ciphertext := gcm.Seal(nonce, nonce, plaintext, nil) return ciphertext, nil } // Decrypt decrypts the given ciphertext using a hybrid encryption scheme. // // It first decrypts the AES key using the KMS. // Then, it decrypts the ciphertext using the decrypted AES key. // // It returns the plaintext and any errors that occurred. 
func Decrypt( ctx context.Context, kms KeyManagementService, keyName string, ciphertext []byte, encryptedKey []byte, ) ([]byte, error) { // Decrypt the AES key aesKey, err := kms.Decrypt(ctx, keyName, encryptedKey) if err != nil { return nil, fmt.Errorf("unable to decrypt AES key: %w", err) } // Create an AES cipher blockCipher, err := aes.NewCipher(aesKey) if err != nil { return nil, fmt.Errorf("unable to create AES cipher: %w", err) } gcm, err := cipher.NewGCM(blockCipher) if err != nil { return nil, fmt.Errorf("unable to create GCM: %w", err) } // Extract the nonce from the ciphertext nonceSize := gcm.NonceSize() if len(ciphertext) < nonceSize { return nil, fmt.Errorf("ciphertext is too short") } nonce, ciphertext := ciphertext[:nonceSize], ciphertext[nonceSize:] // Decrypt the ciphertext plaintext, err := gcm.Open(nil, nonce, ciphertext, nil) if err != nil { return nil, fmt.Errorf("unable to decrypt ciphertext: %w", err) } return plaintext, nil }