diff options
author | 2024-08-05 18:55:10 -0700 | |
---|---|---|
committer | 2024-08-05 18:55:19 -0700 | |
commit | b96fcd1a54a46a95f98467b49a051564bc21c23c (patch) | |
tree | 93caeeb05f8d6310e241095608ea2428c749b18c /backend/internal/keys | |
download | ibd-trader-b96fcd1a54a46a95f98467b49a051564bc21c23c.tar.gz ibd-trader-b96fcd1a54a46a95f98467b49a051564bc21c23c.tar.zst ibd-trader-b96fcd1a54a46a95f98467b49a051564bc21c23c.zip |
Initial Commit
Diffstat (limited to 'backend/internal/keys')
-rw-r--r-- | backend/internal/keys/gcp.go | 131 | ||||
-rw-r--r-- | backend/internal/keys/keys.go | 150 | ||||
-rw-r--r-- | backend/internal/keys/keys_test.go | 64 | ||||
-rw-r--r-- | backend/internal/keys/mock_keys_test.go | 156 |
4 files changed, 501 insertions, 0 deletions
diff --git a/backend/internal/keys/gcp.go b/backend/internal/keys/gcp.go new file mode 100644 index 0000000..9d10fc5 --- /dev/null +++ b/backend/internal/keys/gcp.go @@ -0,0 +1,131 @@ +package keys + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "crypto/x509" + "encoding/pem" + "errors" + "fmt" + "hash/crc32" + "sync" + + kms "cloud.google.com/go/kms/apiv1" + "cloud.google.com/go/kms/apiv1/kmspb" + "google.golang.org/protobuf/types/known/wrapperspb" +) + +type GoogleKMS struct { + client *kms.KeyManagementClient + + mx sync.RWMutex + keyCache map[string]*rsa.PublicKey +} + +func NewGoogleKMS(ctx context.Context) (*GoogleKMS, error) { + client, err := kms.NewKeyManagementClient(ctx) + if err != nil { + return nil, err + } + + return &GoogleKMS{ + client: client, + keyCache: make(map[string]*rsa.PublicKey), + }, nil +} + +func (g *GoogleKMS) checkCache(keyName string) *rsa.PublicKey { + g.mx.RLock() + defer g.mx.RUnlock() + + return g.keyCache[keyName] +} + +func (g *GoogleKMS) getPublicKey(ctx context.Context, keyName string) (*rsa.PublicKey, error) { + if key := g.checkCache(keyName); key != nil { + return key, nil + } + + response, err := g.client.GetPublicKey(ctx, &kmspb.GetPublicKeyRequest{Name: keyName}) + if err != nil { + return nil, err + } + + block, _ := pem.Decode([]byte(response.Pem)) + if block == nil || block.Type != "PUBLIC KEY" { + return nil, errors.New("failed to decode PEM public key") + } + publicKey, err := x509.ParsePKIXPublicKey(block.Bytes) + if err != nil { + return nil, fmt.Errorf("failed to parse public key: %w", err) + } + rsaKey, ok := publicKey.(*rsa.PublicKey) + if !ok { + return nil, errors.New("public key is not an RSA key") + } + + g.mx.Lock() + defer g.mx.Unlock() + g.keyCache[keyName] = rsaKey + + return rsaKey, nil +} + +func (g *GoogleKMS) Encrypt(ctx context.Context, keyName string, plaintext []byte) ([]byte, error) { + publicKey, err := g.getPublicKey(ctx, keyName) + if err != nil { + return nil, err + } + + cipherText, err := rsa.EncryptOAEP(sha256.New(), rand.Reader, publicKey, plaintext, nil) + if err != nil { + return nil, fmt.Errorf("failed to encrypt plaintext: %w", err) + } + + return cipherText, nil +} + +func (g *GoogleKMS) Decrypt(ctx context.Context, keyName string, ciphertext []byte) ([]byte, error) { + req := &kmspb.AsymmetricDecryptRequest{ + Name: keyName, + Ciphertext: ciphertext, + CiphertextCrc32C: wrapperspb.Int64(int64(calcCRC32(ciphertext))), + } + + result, err := g.client.AsymmetricDecrypt(ctx, req) + if err != nil { + return nil, fmt.Errorf("failed to decrypt ciphertext: %w", err) + } + + if !result.VerifiedCiphertextCrc32C { + return nil, errors.New("AsymmetricDecrypt: request corrupted in-transit") + } + if int64(calcCRC32(result.Plaintext)) != result.PlaintextCrc32C.Value { + return nil, fmt.Errorf("AsymmetricDecrypt: response corrupted in-transit") + } + + return result.Plaintext, nil +} + +func (g *GoogleKMS) Close() error { + return g.client.Close() +} + +func calcCRC32(data []byte) uint32 { + t := crc32.MakeTable(crc32.Castagnoli) + return crc32.Checksum(data, t) +} + +type GCPKeyName struct { + Project string + Location string + KeyRing string + CryptoKey string + CryptoKeyVersion string +} + +func (k GCPKeyName) String() string { + return fmt.Sprintf("projects/%s/locations/%s/keyRings/%s/cryptoKeys/%s/cryptoKeyVersions/%s", k.Project, k.Location, k.KeyRing, k.CryptoKey, k.CryptoKeyVersion) +} diff --git a/backend/internal/keys/keys.go b/backend/internal/keys/keys.go new file mode 100644 index 0000000..ac73173 --- /dev/null +++ b/backend/internal/keys/keys.go @@ -0,0 +1,150 @@ +package keys + +import ( + "context" + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "fmt" + "io" +) + +var CSRNG = rand.Reader + +//go:generate mockgen -destination mock_keys_test.go -package keys_test -typed . KeyManagementService +type KeyManagementService interface { + io.Closer + + // Encrypt encrypts the given plaintext using the key with the given key name. + Encrypt(ctx context.Context, keyName string, plaintext []byte) ([]byte, error) + + // Decrypt decrypts the given ciphertext using the key with the given key name. + Decrypt(ctx context.Context, keyName string, ciphertext []byte) ([]byte, error) +} + +// Encrypt encrypts the given plaintext using a hybrid encryption scheme. +// +// It first generates a random AES 256-bit key and encrypts the plaintext with it. +// Then, it encrypts the AES key using the KMS. +// +// It returns the ciphertext, the encrypted AES key, and any errors that occurred. +func Encrypt( + ctx context.Context, + kms KeyManagementService, + keyName string, + plaintext []byte, +) (ciphertext []byte, encryptedKey []byte, err error) { + // Generate a random AES key + aesKey := make([]byte, 32) + if _, err = io.ReadFull(CSRNG, aesKey); err != nil { + return nil, nil, fmt.Errorf("unable to generate AES key: %w", err) + } + + // Encrypt the plaintext using the AES key + ciphertext, err = encrypt(aesKey, plaintext) + if err != nil { + return nil, nil, fmt.Errorf("unable to encrypt plaintext: %w", err) + } + + // Encrypt the AES key using the KMS + encryptedKey, err = kms.Encrypt(ctx, keyName, aesKey) + if err != nil { + return nil, nil, fmt.Errorf("unable to encrypt AES key: %w", err) + } + + return ciphertext, encryptedKey, nil +} + +// EncryptWithKey encrypts the given plaintext using a hybrid encryption scheme. +// +// This works similarly to Encrypt, but instead of generating a new AES key, it uses a given already encrypted AES key. +func EncryptWithKey( + ctx context.Context, + kms KeyManagementService, + keyName string, + encryptedKey []byte, + plaintext []byte, +) ([]byte, error) { + // Decrypt the AES key + aesKey, err := kms.Decrypt(ctx, keyName, encryptedKey) + if err != nil { + return nil, fmt.Errorf("unable to decrypt AES key: %w", err) + } + + // Encrypt the plaintext using the AES key + ciphertext, err := encrypt(aesKey, plaintext) + if err != nil { + return nil, fmt.Errorf("unable to encrypt plaintext: %w", err) + } + + return ciphertext, nil +} + +func encrypt(aesKey []byte, plaintext []byte) ([]byte, error) { + // Create an AES cipher + blockCipher, err := aes.NewCipher(aesKey) + if err != nil { + return nil, fmt.Errorf("unable to create AES cipher: %w", err) + } + + gcm, err := cipher.NewGCM(blockCipher) + if err != nil { + return nil, fmt.Errorf("unable to create GCM: %w", err) + } + + // Generate a random nonce + nonce := make([]byte, gcm.NonceSize()) + if _, err = io.ReadFull(CSRNG, nonce); err != nil { + return nil, fmt.Errorf("unable to generate nonce: %w", err) + } + + // Encrypt the plaintext + ciphertext := gcm.Seal(nonce, nonce, plaintext, nil) + return ciphertext, nil +} + +// Decrypt decrypts the given ciphertext using a hybrid encryption scheme. +// +// It first decrypts the AES key using the KMS. +// Then, it decrypts the ciphertext using the decrypted AES key. +// +// It returns the plaintext and any errors that occurred. +func Decrypt( + ctx context.Context, + kms KeyManagementService, + keyName string, + ciphertext []byte, + encryptedKey []byte, +) ([]byte, error) { + // Decrypt the AES key + aesKey, err := kms.Decrypt(ctx, keyName, encryptedKey) + if err != nil { + return nil, fmt.Errorf("unable to decrypt AES key: %w", err) + } + + // Create an AES cipher + blockCipher, err := aes.NewCipher(aesKey) + if err != nil { + return nil, fmt.Errorf("unable to create AES cipher: %w", err) + } + + gcm, err := cipher.NewGCM(blockCipher) + if err != nil { + return nil, fmt.Errorf("unable to create GCM: %w", err) + } + + // Extract the nonce from the ciphertext + nonceSize := gcm.NonceSize() + if len(ciphertext) < nonceSize { + return nil, fmt.Errorf("ciphertext is too short") + } + nonce, ciphertext := ciphertext[:nonceSize], ciphertext[nonceSize:] + + // Decrypt the ciphertext + plaintext, err := gcm.Open(nil, nonce, ciphertext, nil) + if err != nil { + return nil, fmt.Errorf("unable to decrypt ciphertext: %w", err) + } + + return plaintext, nil +} diff --git a/backend/internal/keys/keys_test.go b/backend/internal/keys/keys_test.go new file mode 100644 index 0000000..14bbcc2 --- /dev/null +++ b/backend/internal/keys/keys_test.go @@ -0,0 +1,64 @@ +package keys_test + +import ( + "bytes" + "context" + "encoding/hex" + "testing" + + "ibd-trader/internal/keys" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" +) + +func TestEncrypt(t *testing.T) { + ctrl := gomock.NewController(t) + + // Replace RNG with a deterministic RNG + aesKey := []byte("0123456789abcdef0123456789abcdef") + nonce := []byte("0123456789ab") + keys.CSRNG = bytes.NewReader(append(aesKey, nonce...)) + + // Create a mock KMS + kms := NewMockKeyManagementService(ctrl) + keyName := "keyName" + + ctx := context.Background() + plaintext := []byte("plaintext") + + kms.EXPECT(). + Encrypt(ctx, keyName, aesKey). + Return([]byte("encryptedKey"), nil) + + ciphertext, encryptedKey, err := keys.Encrypt(ctx, kms, keyName, plaintext) + require.NoError(t, err) + + encrypted, err := hex.DecodeString("e9c586532dbefd63812293e1c4baf71edb7042a294c49c2020") + require.NoError(t, err) + assert.Equal(t, append(nonce, encrypted...), ciphertext) + assert.Equal(t, []byte("encryptedKey"), encryptedKey) +} + +func TestDecrypt(t *testing.T) { + ctrl := gomock.NewController(t) + + kms := NewMockKeyManagementService(ctrl) + keyName := "keyName" + + ctx := context.Background() + encryptedKey := []byte("encryptedKey") + ciphertext, err := hex.DecodeString("e9c586532dbefd63812293e1c4baf71edb7042a294c49c2020") + require.NoError(t, err) + ciphertext = append([]byte("0123456789ab"), ciphertext...) + + aesKey := []byte("0123456789abcdef0123456789abcdef") + kms.EXPECT(). + Decrypt(ctx, keyName, encryptedKey). + Return(aesKey, nil) + + plaintext, err := keys.Decrypt(ctx, kms, keyName, ciphertext, encryptedKey) + require.NoError(t, err) + assert.Equal(t, []byte("plaintext"), plaintext) +} diff --git a/backend/internal/keys/mock_keys_test.go b/backend/internal/keys/mock_keys_test.go new file mode 100644 index 0000000..5a435a0 --- /dev/null +++ b/backend/internal/keys/mock_keys_test.go @@ -0,0 +1,156 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: ibd-trader/internal/keys (interfaces: KeyManagementService) +// +// Generated by this command: +// +// mockgen -destination mock_keys_test.go -package keys_test -typed . KeyManagementService +// + +// Package keys_test is a generated GoMock package. +package keys_test + +import ( + context "context" + reflect "reflect" + + gomock "go.uber.org/mock/gomock" +) + +// MockKeyManagementService is a mock of KeyManagementService interface. +type MockKeyManagementService struct { + ctrl *gomock.Controller + recorder *MockKeyManagementServiceMockRecorder +} + +// MockKeyManagementServiceMockRecorder is the mock recorder for MockKeyManagementService. +type MockKeyManagementServiceMockRecorder struct { + mock *MockKeyManagementService +} + +// NewMockKeyManagementService creates a new mock instance. +func NewMockKeyManagementService(ctrl *gomock.Controller) *MockKeyManagementService { + mock := &MockKeyManagementService{ctrl: ctrl} + mock.recorder = &MockKeyManagementServiceMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockKeyManagementService) EXPECT() *MockKeyManagementServiceMockRecorder { + return m.recorder +} + +// Close mocks base method. +func (m *MockKeyManagementService) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockKeyManagementServiceMockRecorder) Close() *MockKeyManagementServiceCloseCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockKeyManagementService)(nil).Close)) + return &MockKeyManagementServiceCloseCall{Call: call} +} + +// MockKeyManagementServiceCloseCall wrap *gomock.Call +type MockKeyManagementServiceCloseCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockKeyManagementServiceCloseCall) Return(arg0 error) *MockKeyManagementServiceCloseCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockKeyManagementServiceCloseCall) Do(f func() error) *MockKeyManagementServiceCloseCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockKeyManagementServiceCloseCall) DoAndReturn(f func() error) *MockKeyManagementServiceCloseCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Decrypt mocks base method. +func (m *MockKeyManagementService) Decrypt(arg0 context.Context, arg1 string, arg2 []byte) ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Decrypt", arg0, arg1, arg2) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Decrypt indicates an expected call of Decrypt. +func (mr *MockKeyManagementServiceMockRecorder) Decrypt(arg0, arg1, arg2 any) *MockKeyManagementServiceDecryptCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Decrypt", reflect.TypeOf((*MockKeyManagementService)(nil).Decrypt), arg0, arg1, arg2) + return &MockKeyManagementServiceDecryptCall{Call: call} +} + +// MockKeyManagementServiceDecryptCall wrap *gomock.Call +type MockKeyManagementServiceDecryptCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockKeyManagementServiceDecryptCall) Return(arg0 []byte, arg1 error) *MockKeyManagementServiceDecryptCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockKeyManagementServiceDecryptCall) Do(f func(context.Context, string, []byte) ([]byte, error)) *MockKeyManagementServiceDecryptCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockKeyManagementServiceDecryptCall) DoAndReturn(f func(context.Context, string, []byte) ([]byte, error)) *MockKeyManagementServiceDecryptCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Encrypt mocks base method. +func (m *MockKeyManagementService) Encrypt(arg0 context.Context, arg1 string, arg2 []byte) ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Encrypt", arg0, arg1, arg2) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Encrypt indicates an expected call of Encrypt. +func (mr *MockKeyManagementServiceMockRecorder) Encrypt(arg0, arg1, arg2 any) *MockKeyManagementServiceEncryptCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Encrypt", reflect.TypeOf((*MockKeyManagementService)(nil).Encrypt), arg0, arg1, arg2) + return &MockKeyManagementServiceEncryptCall{Call: call} +} + +// MockKeyManagementServiceEncryptCall wrap *gomock.Call +type MockKeyManagementServiceEncryptCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockKeyManagementServiceEncryptCall) Return(arg0 []byte, arg1 error) *MockKeyManagementServiceEncryptCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockKeyManagementServiceEncryptCall) Do(f func(context.Context, string, []byte) ([]byte, error)) *MockKeyManagementServiceEncryptCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockKeyManagementServiceEncryptCall) DoAndReturn(f func(context.Context, string, []byte) ([]byte, error)) *MockKeyManagementServiceEncryptCall { + c.Call = c.Call.DoAndReturn(f) + return c +} |