| 1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
 | package keys
import (
	"context"
	"crypto/aes"
	"crypto/cipher"
	"crypto/rand"
	"fmt"
	"io"
)
// CSRNG is the source of randomness used to generate AES data keys (in
// Encrypt) and GCM nonces (in Encrypt and EncryptWithKey). It defaults to
// crypto/rand.Reader. It is an exported variable so it can be replaced, e.g.
// with a deterministic reader in tests; outside of tests it must remain a
// cryptographically secure source.
var CSRNG io.Reader = rand.Reader
//go:generate mockgen -destination mock_keys_test.go -package keys_test -typed . KeyManagementService

// KeyManagementService is the key-wrapping backend used by this package's
// hybrid encryption helpers. Encrypt uses it to encrypt each freshly generated
// AES key; EncryptWithKey and Decrypt use it to decrypt a previously encrypted
// AES key.
//
// It embeds io.Closer; callers should call Close when they are done with it.
type KeyManagementService interface {
	io.Closer
	// Encrypt encrypts the given plaintext using the key with the given key name.
	Encrypt(ctx context.Context, keyName string, plaintext []byte) ([]byte, error)
	// Decrypt decrypts the given ciphertext using the key with the given key name.
	Decrypt(ctx context.Context, keyName string, ciphertext []byte) ([]byte, error)
}
// Encrypt seals plaintext using hybrid (envelope) encryption.
//
// A new 256-bit AES data key is read from CSRNG and used to encrypt plaintext
// with AES-GCM. The data key itself is then encrypted by kms under keyName.
//
// On success it returns the nonce-prefixed ciphertext together with the
// KMS-encrypted data key; Decrypt needs both. On failure both byte slices are
// nil and err names the step that failed.
func Encrypt(
	ctx context.Context,
	kms KeyManagementService,
	keyName string,
	plaintext []byte,
) (ciphertext []byte, encryptedKey []byte, err error) {
	dataKey := make([]byte, 32) // 32 bytes selects AES-256
	if _, readErr := io.ReadFull(CSRNG, dataKey); readErr != nil {
		return nil, nil, fmt.Errorf("unable to generate AES key: %w", readErr)
	}

	sealed, sealErr := encrypt(dataKey, plaintext)
	if sealErr != nil {
		return nil, nil, fmt.Errorf("unable to encrypt plaintext: %w", sealErr)
	}

	// The data key is wrapped only after local encryption succeeded, so a
	// failing plaintext encryption never costs a KMS round trip.
	wrappedKey, wrapErr := kms.Encrypt(ctx, keyName, dataKey)
	if wrapErr != nil {
		return nil, nil, fmt.Errorf("unable to encrypt AES key: %w", wrapErr)
	}

	return sealed, wrappedKey, nil
}
// EncryptWithKey encrypts plaintext under an existing, KMS-encrypted AES key.
//
// Unlike Encrypt, it does not generate a new data key. encryptedKey is first
// decrypted with kms under keyName, and the recovered AES key is then used to
// seal plaintext with AES-GCM and a fresh random nonce read from CSRNG.
//
// It returns the nonce-prefixed ciphertext, or an error if decrypting the key
// or encrypting the plaintext fails.
//
// NOTE(review): every call reuses the same data key, so message distinctness
// rests entirely on the random nonces. Keep the number of messages per key
// within AES-GCM's limits for randomly generated nonces.
func EncryptWithKey(
	ctx context.Context,
	kms KeyManagementService,
	keyName string,
	encryptedKey []byte,
	plaintext []byte,
) ([]byte, error) {
	dataKey, unwrapErr := kms.Decrypt(ctx, keyName, encryptedKey)
	if unwrapErr != nil {
		return nil, fmt.Errorf("unable to decrypt AES key: %w", unwrapErr)
	}

	sealed, sealErr := encrypt(dataKey, plaintext)
	if sealErr != nil {
		return nil, fmt.Errorf("unable to encrypt plaintext: %w", sealErr)
	}
	return sealed, nil
}
// encrypt seals plaintext with AES-GCM under aesKey.
//
// A fresh nonce of the GCM's nonce size is read from CSRNG and prepended to the
// output. The result is therefore the nonce followed by the sealed data, which
// includes the GCM authentication tag. aesKey must be a valid AES key length;
// otherwise an error is returned.
func encrypt(aesKey []byte, plaintext []byte) ([]byte, error) {
	block, err := aes.NewCipher(aesKey)
	if err != nil {
		return nil, fmt.Errorf("unable to create AES cipher: %w", err)
	}
	aead, err := cipher.NewGCM(block)
	if err != nil {
		return nil, fmt.Errorf("unable to create GCM: %w", err)
	}

	iv := make([]byte, aead.NonceSize())
	if _, err := io.ReadFull(CSRNG, iv); err != nil {
		return nil, fmt.Errorf("unable to generate nonce: %w", err)
	}

	// Seal appends its output to dst. Passing iv as dst yields iv followed by
	// the sealed data, which is the layout Decrypt expects.
	return aead.Seal(iv, iv, plaintext, nil), nil
}
// Decrypt reverses Encrypt / EncryptWithKey using hybrid encryption.
//
// It first decrypts encryptedKey with kms under keyName to recover the AES
// key. It then uses that key to open ciphertext, which must be laid out as the
// GCM nonce followed by the sealed data, as produced by Encrypt and
// EncryptWithKey.
//
// It returns the plaintext, or an error if the AES key cannot be decrypted,
// the ciphertext is too short to contain a nonce, or GCM authentication fails.
func Decrypt(
	ctx context.Context,
	kms KeyManagementService,
	keyName string,
	ciphertext []byte,
	encryptedKey []byte,
) ([]byte, error) {
	// Decrypt the AES key
	aesKey, err := kms.Decrypt(ctx, keyName, encryptedKey)
	if err != nil {
		return nil, fmt.Errorf("unable to decrypt AES key: %w", err)
	}
	// decrypt already wraps its errors with context, so they are returned
	// as-is to keep the error messages identical for callers.
	return decrypt(aesKey, ciphertext)
}

// decrypt is the counterpart of encrypt. It opens an AES-GCM ciphertext laid
// out as nonce followed by sealed data, using aesKey.
func decrypt(aesKey []byte, ciphertext []byte) ([]byte, error) {
	blockCipher, err := aes.NewCipher(aesKey)
	if err != nil {
		return nil, fmt.Errorf("unable to create AES cipher: %w", err)
	}
	gcm, err := cipher.NewGCM(blockCipher)
	if err != nil {
		return nil, fmt.Errorf("unable to create GCM: %w", err)
	}
	// Split off the nonce that encrypt prepended.
	nonceSize := gcm.NonceSize()
	if len(ciphertext) < nonceSize {
		return nil, fmt.Errorf("ciphertext is too short")
	}
	nonce, sealed := ciphertext[:nonceSize], ciphertext[nonceSize:]
	// Open verifies the authentication tag before returning any plaintext.
	plaintext, err := gcm.Open(nil, nonce, sealed, nil)
	if err != nil {
		return nil, fmt.Errorf("unable to decrypt ciphertext: %w", err)
	}
	return plaintext, nil
}
 |