aboutsummaryrefslogtreecommitdiff
path: root/backend/internal/keys
diff options
context:
space:
mode:
authorGravatar Anshul Gupta <ansg191@anshulg.com> 2024-08-05 18:55:10 -0700
committerGravatar Anshul Gupta <ansg191@anshulg.com> 2024-08-05 18:55:19 -0700
commitb96fcd1a54a46a95f98467b49a051564bc21c23c (patch)
tree93caeeb05f8d6310e241095608ea2428c749b18c /backend/internal/keys
downloadibd-trader-b96fcd1a54a46a95f98467b49a051564bc21c23c.tar.gz
ibd-trader-b96fcd1a54a46a95f98467b49a051564bc21c23c.tar.zst
ibd-trader-b96fcd1a54a46a95f98467b49a051564bc21c23c.zip
Initial Commit
Diffstat (limited to 'backend/internal/keys')
-rw-r--r--backend/internal/keys/gcp.go131
-rw-r--r--backend/internal/keys/keys.go150
-rw-r--r--backend/internal/keys/keys_test.go64
-rw-r--r--backend/internal/keys/mock_keys_test.go156
4 files changed, 501 insertions, 0 deletions
diff --git a/backend/internal/keys/gcp.go b/backend/internal/keys/gcp.go
new file mode 100644
index 0000000..9d10fc5
--- /dev/null
+++ b/backend/internal/keys/gcp.go
@@ -0,0 +1,131 @@
+package keys
+
+import (
+ "context"
+ "crypto/rand"
+ "crypto/rsa"
+ "crypto/sha256"
+ "crypto/x509"
+ "encoding/pem"
+ "errors"
+ "fmt"
+ "hash/crc32"
+ "sync"
+
+ kms "cloud.google.com/go/kms/apiv1"
+ "cloud.google.com/go/kms/apiv1/kmspb"
+ "google.golang.org/protobuf/types/known/wrapperspb"
+)
+
+type GoogleKMS struct {
+ client *kms.KeyManagementClient
+
+ mx sync.RWMutex
+ keyCache map[string]*rsa.PublicKey
+}
+
+func NewGoogleKMS(ctx context.Context) (*GoogleKMS, error) {
+ client, err := kms.NewKeyManagementClient(ctx)
+ if err != nil {
+ return nil, err
+ }
+
+ return &GoogleKMS{
+ client: client,
+ keyCache: make(map[string]*rsa.PublicKey),
+ }, nil
+}
+
+func (g *GoogleKMS) checkCache(keyName string) *rsa.PublicKey {
+ g.mx.RLock()
+ defer g.mx.RUnlock()
+
+ return g.keyCache[keyName]
+}
+
+func (g *GoogleKMS) getPublicKey(ctx context.Context, keyName string) (*rsa.PublicKey, error) {
+ if key := g.checkCache(keyName); key != nil {
+ return key, nil
+ }
+
+ response, err := g.client.GetPublicKey(ctx, &kmspb.GetPublicKeyRequest{Name: keyName})
+ if err != nil {
+ return nil, err
+ }
+
+ block, _ := pem.Decode([]byte(response.Pem))
+ if block == nil || block.Type != "PUBLIC KEY" {
+ return nil, errors.New("failed to decode PEM public key")
+ }
+ publicKey, err := x509.ParsePKIXPublicKey(block.Bytes)
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse public key: %w", err)
+ }
+ rsaKey, ok := publicKey.(*rsa.PublicKey)
+ if !ok {
+ return nil, errors.New("public key is not an RSA key")
+ }
+
+ g.mx.Lock()
+ defer g.mx.Unlock()
+ g.keyCache[keyName] = rsaKey
+
+ return rsaKey, nil
+}
+
+func (g *GoogleKMS) Encrypt(ctx context.Context, keyName string, plaintext []byte) ([]byte, error) {
+ publicKey, err := g.getPublicKey(ctx, keyName)
+ if err != nil {
+ return nil, err
+ }
+
+ cipherText, err := rsa.EncryptOAEP(sha256.New(), rand.Reader, publicKey, plaintext, nil)
+ if err != nil {
+ return nil, fmt.Errorf("failed to encrypt plaintext: %w", err)
+ }
+
+ return cipherText, nil
+}
+
+func (g *GoogleKMS) Decrypt(ctx context.Context, keyName string, ciphertext []byte) ([]byte, error) {
+ req := &kmspb.AsymmetricDecryptRequest{
+ Name: keyName,
+ Ciphertext: ciphertext,
+ CiphertextCrc32C: wrapperspb.Int64(int64(calcCRC32(ciphertext))),
+ }
+
+ result, err := g.client.AsymmetricDecrypt(ctx, req)
+ if err != nil {
+ return nil, fmt.Errorf("failed to decrypt ciphertext: %w", err)
+ }
+
+ if !result.VerifiedCiphertextCrc32C {
+ return nil, errors.New("AsymmetricDecrypt: request corrupted in-transit")
+ }
+ if int64(calcCRC32(result.Plaintext)) != result.PlaintextCrc32C.Value {
+ return nil, fmt.Errorf("AsymmetricDecrypt: response corrupted in-transit")
+ }
+
+ return result.Plaintext, nil
+}
+
+func (g *GoogleKMS) Close() error {
+ return g.client.Close()
+}
+
+func calcCRC32(data []byte) uint32 {
+ t := crc32.MakeTable(crc32.Castagnoli)
+ return crc32.Checksum(data, t)
+}
+
+type GCPKeyName struct {
+ Project string
+ Location string
+ KeyRing string
+ CryptoKey string
+ CryptoKeyVersion string
+}
+
+func (k GCPKeyName) String() string {
+ return fmt.Sprintf("projects/%s/locations/%s/keyRings/%s/cryptoKeys/%s/cryptoKeyVersions/%s", k.Project, k.Location, k.KeyRing, k.CryptoKey, k.CryptoKeyVersion)
+}
diff --git a/backend/internal/keys/keys.go b/backend/internal/keys/keys.go
new file mode 100644
index 0000000..ac73173
--- /dev/null
+++ b/backend/internal/keys/keys.go
@@ -0,0 +1,150 @@
+package keys
+
+import (
+ "context"
+ "crypto/aes"
+ "crypto/cipher"
+ "crypto/rand"
+ "fmt"
+ "io"
+)
+
+var CSRNG = rand.Reader
+
+//go:generate mockgen -destination mock_keys_test.go -package keys_test -typed . KeyManagementService
+type KeyManagementService interface {
+ io.Closer
+
+ // Encrypt encrypts the given plaintext using the key with the given key name.
+ Encrypt(ctx context.Context, keyName string, plaintext []byte) ([]byte, error)
+
+ // Decrypt decrypts the given ciphertext using the key with the given key name.
+ Decrypt(ctx context.Context, keyName string, ciphertext []byte) ([]byte, error)
+}
+
+// Encrypt encrypts the given plaintext using a hybrid encryption scheme.
+//
+// It first generates a random AES 256-bit key and encrypts the plaintext with it.
+// Then, it encrypts the AES key using the KMS.
+//
+// It returns the ciphertext, the encrypted AES key, and any errors that occurred.
+func Encrypt(
+ ctx context.Context,
+ kms KeyManagementService,
+ keyName string,
+ plaintext []byte,
+) (ciphertext []byte, encryptedKey []byte, err error) {
+ // Generate a random AES key
+ aesKey := make([]byte, 32)
+ if _, err = io.ReadFull(CSRNG, aesKey); err != nil {
+ return nil, nil, fmt.Errorf("unable to generate AES key: %w", err)
+ }
+
+ // Encrypt the plaintext using the AES key
+ ciphertext, err = encrypt(aesKey, plaintext)
+ if err != nil {
+ return nil, nil, fmt.Errorf("unable to encrypt plaintext: %w", err)
+ }
+
+ // Encrypt the AES key using the KMS
+ encryptedKey, err = kms.Encrypt(ctx, keyName, aesKey)
+ if err != nil {
+ return nil, nil, fmt.Errorf("unable to encrypt AES key: %w", err)
+ }
+
+ return ciphertext, encryptedKey, nil
+}
+
+// EncryptWithKey encrypts the given plaintext using a hybrid encryption scheme.
+//
+// This works similarly to Encrypt, but instead of generating a new AES key, it uses a given already encrypted AES key.
+func EncryptWithKey(
+ ctx context.Context,
+ kms KeyManagementService,
+ keyName string,
+ encryptedKey []byte,
+ plaintext []byte,
+) ([]byte, error) {
+ // Decrypt the AES key
+ aesKey, err := kms.Decrypt(ctx, keyName, encryptedKey)
+ if err != nil {
+ return nil, fmt.Errorf("unable to decrypt AES key: %w", err)
+ }
+
+ // Encrypt the plaintext using the AES key
+ ciphertext, err := encrypt(aesKey, plaintext)
+ if err != nil {
+ return nil, fmt.Errorf("unable to encrypt plaintext: %w", err)
+ }
+
+ return ciphertext, nil
+}
+
+func encrypt(aesKey []byte, plaintext []byte) ([]byte, error) {
+ // Create an AES cipher
+ blockCipher, err := aes.NewCipher(aesKey)
+ if err != nil {
+ return nil, fmt.Errorf("unable to create AES cipher: %w", err)
+ }
+
+ gcm, err := cipher.NewGCM(blockCipher)
+ if err != nil {
+ return nil, fmt.Errorf("unable to create GCM: %w", err)
+ }
+
+ // Generate a random nonce
+ nonce := make([]byte, gcm.NonceSize())
+ if _, err = io.ReadFull(CSRNG, nonce); err != nil {
+ return nil, fmt.Errorf("unable to generate nonce: %w", err)
+ }
+
+ // Encrypt the plaintext
+ ciphertext := gcm.Seal(nonce, nonce, plaintext, nil)
+ return ciphertext, nil
+}
+
+// Decrypt decrypts the given ciphertext using a hybrid encryption scheme.
+//
+// It first decrypts the AES key using the KMS.
+// Then, it decrypts the ciphertext using the decrypted AES key.
+//
+// It returns the plaintext and any errors that occurred.
+func Decrypt(
+ ctx context.Context,
+ kms KeyManagementService,
+ keyName string,
+ ciphertext []byte,
+ encryptedKey []byte,
+) ([]byte, error) {
+ // Decrypt the AES key
+ aesKey, err := kms.Decrypt(ctx, keyName, encryptedKey)
+ if err != nil {
+ return nil, fmt.Errorf("unable to decrypt AES key: %w", err)
+ }
+
+ // Create an AES cipher
+ blockCipher, err := aes.NewCipher(aesKey)
+ if err != nil {
+ return nil, fmt.Errorf("unable to create AES cipher: %w", err)
+ }
+
+ gcm, err := cipher.NewGCM(blockCipher)
+ if err != nil {
+ return nil, fmt.Errorf("unable to create GCM: %w", err)
+ }
+
+ // Extract the nonce from the ciphertext
+ nonceSize := gcm.NonceSize()
+ if len(ciphertext) < nonceSize {
+ return nil, fmt.Errorf("ciphertext is too short")
+ }
+ nonce, ciphertext := ciphertext[:nonceSize], ciphertext[nonceSize:]
+
+ // Decrypt the ciphertext
+ plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
+ if err != nil {
+ return nil, fmt.Errorf("unable to decrypt ciphertext: %w", err)
+ }
+
+ return plaintext, nil
+}
diff --git a/backend/internal/keys/keys_test.go b/backend/internal/keys/keys_test.go
new file mode 100644
index 0000000..14bbcc2
--- /dev/null
+++ b/backend/internal/keys/keys_test.go
@@ -0,0 +1,64 @@
+package keys_test
+
+import (
+ "bytes"
+ "context"
+ "encoding/hex"
+ "testing"
+
+ "ibd-trader/internal/keys"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "go.uber.org/mock/gomock"
+)
+
+func TestEncrypt(t *testing.T) {
+ ctrl := gomock.NewController(t)
+
+ // Replace RNG with a deterministic RNG
+ aesKey := []byte("0123456789abcdef0123456789abcdef")
+ nonce := []byte("0123456789ab")
+ keys.CSRNG = bytes.NewReader(append(aesKey, nonce...))
+
+ // Create a mock KMS
+ kms := NewMockKeyManagementService(ctrl)
+ keyName := "keyName"
+
+ ctx := context.Background()
+ plaintext := []byte("plaintext")
+
+ kms.EXPECT().
+ Encrypt(ctx, keyName, aesKey).
+ Return([]byte("encryptedKey"), nil)
+
+ ciphertext, encryptedKey, err := keys.Encrypt(ctx, kms, keyName, plaintext)
+ require.NoError(t, err)
+
+ encrypted, err := hex.DecodeString("e9c586532dbefd63812293e1c4baf71edb7042a294c49c2020")
+ require.NoError(t, err)
+ assert.Equal(t, append(nonce, encrypted...), ciphertext)
+ assert.Equal(t, []byte("encryptedKey"), encryptedKey)
+}
+
+func TestDecrypt(t *testing.T) {
+ ctrl := gomock.NewController(t)
+
+ kms := NewMockKeyManagementService(ctrl)
+ keyName := "keyName"
+
+ ctx := context.Background()
+ encryptedKey := []byte("encryptedKey")
+ ciphertext, err := hex.DecodeString("e9c586532dbefd63812293e1c4baf71edb7042a294c49c2020")
+ require.NoError(t, err)
+ ciphertext = append([]byte("0123456789ab"), ciphertext...)
+
+ aesKey := []byte("0123456789abcdef0123456789abcdef")
+ kms.EXPECT().
+ Decrypt(ctx, keyName, encryptedKey).
+ Return(aesKey, nil)
+
+ plaintext, err := keys.Decrypt(ctx, kms, keyName, ciphertext, encryptedKey)
+ require.NoError(t, err)
+ assert.Equal(t, []byte("plaintext"), plaintext)
+}
diff --git a/backend/internal/keys/mock_keys_test.go b/backend/internal/keys/mock_keys_test.go
new file mode 100644
index 0000000..5a435a0
--- /dev/null
+++ b/backend/internal/keys/mock_keys_test.go
@@ -0,0 +1,156 @@
+// Code generated by MockGen. DO NOT EDIT.
+// Source: ibd-trader/internal/keys (interfaces: KeyManagementService)
+//
+// Generated by this command:
+//
+// mockgen -destination mock_keys_test.go -package keys_test -typed . KeyManagementService
+//
+
+// Package keys_test is a generated GoMock package.
+package keys_test
+
+import (
+ context "context"
+ reflect "reflect"
+
+ gomock "go.uber.org/mock/gomock"
+)
+
+// MockKeyManagementService is a mock of KeyManagementService interface.
+type MockKeyManagementService struct {
+ ctrl *gomock.Controller
+ recorder *MockKeyManagementServiceMockRecorder
+}
+
+// MockKeyManagementServiceMockRecorder is the mock recorder for MockKeyManagementService.
+type MockKeyManagementServiceMockRecorder struct {
+ mock *MockKeyManagementService
+}
+
+// NewMockKeyManagementService creates a new mock instance.
+func NewMockKeyManagementService(ctrl *gomock.Controller) *MockKeyManagementService {
+ mock := &MockKeyManagementService{ctrl: ctrl}
+ mock.recorder = &MockKeyManagementServiceMockRecorder{mock}
+ return mock
+}
+
+// EXPECT returns an object that allows the caller to indicate expected use.
+func (m *MockKeyManagementService) EXPECT() *MockKeyManagementServiceMockRecorder {
+ return m.recorder
+}
+
+// Close mocks base method.
+func (m *MockKeyManagementService) Close() error {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "Close")
+ ret0, _ := ret[0].(error)
+ return ret0
+}
+
+// Close indicates an expected call of Close.
+func (mr *MockKeyManagementServiceMockRecorder) Close() *MockKeyManagementServiceCloseCall {
+ mr.mock.ctrl.T.Helper()
+ call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockKeyManagementService)(nil).Close))
+ return &MockKeyManagementServiceCloseCall{Call: call}
+}
+
+// MockKeyManagementServiceCloseCall wrap *gomock.Call
+type MockKeyManagementServiceCloseCall struct {
+ *gomock.Call
+}
+
+// Return rewrite *gomock.Call.Return
+func (c *MockKeyManagementServiceCloseCall) Return(arg0 error) *MockKeyManagementServiceCloseCall {
+ c.Call = c.Call.Return(arg0)
+ return c
+}
+
+// Do rewrite *gomock.Call.Do
+func (c *MockKeyManagementServiceCloseCall) Do(f func() error) *MockKeyManagementServiceCloseCall {
+ c.Call = c.Call.Do(f)
+ return c
+}
+
+// DoAndReturn rewrite *gomock.Call.DoAndReturn
+func (c *MockKeyManagementServiceCloseCall) DoAndReturn(f func() error) *MockKeyManagementServiceCloseCall {
+ c.Call = c.Call.DoAndReturn(f)
+ return c
+}
+
+// Decrypt mocks base method.
+func (m *MockKeyManagementService) Decrypt(arg0 context.Context, arg1 string, arg2 []byte) ([]byte, error) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "Decrypt", arg0, arg1, arg2)
+ ret0, _ := ret[0].([]byte)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// Decrypt indicates an expected call of Decrypt.
+func (mr *MockKeyManagementServiceMockRecorder) Decrypt(arg0, arg1, arg2 any) *MockKeyManagementServiceDecryptCall {
+ mr.mock.ctrl.T.Helper()
+ call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Decrypt", reflect.TypeOf((*MockKeyManagementService)(nil).Decrypt), arg0, arg1, arg2)
+ return &MockKeyManagementServiceDecryptCall{Call: call}
+}
+
+// MockKeyManagementServiceDecryptCall wrap *gomock.Call
+type MockKeyManagementServiceDecryptCall struct {
+ *gomock.Call
+}
+
+// Return rewrite *gomock.Call.Return
+func (c *MockKeyManagementServiceDecryptCall) Return(arg0 []byte, arg1 error) *MockKeyManagementServiceDecryptCall {
+ c.Call = c.Call.Return(arg0, arg1)
+ return c
+}
+
+// Do rewrite *gomock.Call.Do
+func (c *MockKeyManagementServiceDecryptCall) Do(f func(context.Context, string, []byte) ([]byte, error)) *MockKeyManagementServiceDecryptCall {
+ c.Call = c.Call.Do(f)
+ return c
+}
+
+// DoAndReturn rewrite *gomock.Call.DoAndReturn
+func (c *MockKeyManagementServiceDecryptCall) DoAndReturn(f func(context.Context, string, []byte) ([]byte, error)) *MockKeyManagementServiceDecryptCall {
+ c.Call = c.Call.DoAndReturn(f)
+ return c
+}
+
+// Encrypt mocks base method.
+func (m *MockKeyManagementService) Encrypt(arg0 context.Context, arg1 string, arg2 []byte) ([]byte, error) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "Encrypt", arg0, arg1, arg2)
+ ret0, _ := ret[0].([]byte)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// Encrypt indicates an expected call of Encrypt.
+func (mr *MockKeyManagementServiceMockRecorder) Encrypt(arg0, arg1, arg2 any) *MockKeyManagementServiceEncryptCall {
+ mr.mock.ctrl.T.Helper()
+ call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Encrypt", reflect.TypeOf((*MockKeyManagementService)(nil).Encrypt), arg0, arg1, arg2)
+ return &MockKeyManagementServiceEncryptCall{Call: call}
+}
+
+// MockKeyManagementServiceEncryptCall wrap *gomock.Call
+type MockKeyManagementServiceEncryptCall struct {
+ *gomock.Call
+}
+
+// Return rewrite *gomock.Call.Return
+func (c *MockKeyManagementServiceEncryptCall) Return(arg0 []byte, arg1 error) *MockKeyManagementServiceEncryptCall {
+ c.Call = c.Call.Return(arg0, arg1)
+ return c
+}
+
+// Do rewrite *gomock.Call.Do
+func (c *MockKeyManagementServiceEncryptCall) Do(f func(context.Context, string, []byte) ([]byte, error)) *MockKeyManagementServiceEncryptCall {
+ c.Call = c.Call.Do(f)
+ return c
+}
+
+// DoAndReturn rewrite *gomock.Call.DoAndReturn
+func (c *MockKeyManagementServiceEncryptCall) DoAndReturn(f func(context.Context, string, []byte) ([]byte, error)) *MockKeyManagementServiceEncryptCall {
+ c.Call = c.Call.DoAndReturn(f)
+ return c
+}