aboutsummaryrefslogtreecommitdiff
path: root/backend/internal/keys/gcp.go
diff options
context:
space:
mode:
authorGravatar Anshul Gupta <ansg191@anshulg.com> 2024-08-11 13:15:50 -0700
committerGravatar Anshul Gupta <ansg191@anshulg.com> 2024-08-11 13:15:50 -0700
commit6a3c21fb0b1c126849f2bbff494403bbe901448e (patch)
tree5d7805524357c2c8a9819c39d2051a4e3633a1d5 /backend/internal/keys/gcp.go
parent29c6040a51616e9e4cf6c70ee16391b2a3b238c9 (diff)
parentf34b92ded11b07f78575ac62c260a380c468e5ea (diff)
downloadibd-trader-6a3c21fb0b1c126849f2bbff494403bbe901448e.tar.gz
ibd-trader-6a3c21fb0b1c126849f2bbff494403bbe901448e.tar.zst
ibd-trader-6a3c21fb0b1c126849f2bbff494403bbe901448e.zip
Merge remote-tracking branch 'backend/main'
Diffstat (limited to 'backend/internal/keys/gcp.go')
-rw-r--r--backend/internal/keys/gcp.go131
1 files changed, 131 insertions, 0 deletions
diff --git a/backend/internal/keys/gcp.go b/backend/internal/keys/gcp.go
new file mode 100644
index 0000000..9d10fc5
--- /dev/null
+++ b/backend/internal/keys/gcp.go
@@ -0,0 +1,131 @@
+package keys
+
+import (
+ "context"
+ "crypto/rand"
+ "crypto/rsa"
+ "crypto/sha256"
+ "crypto/x509"
+ "encoding/pem"
+ "errors"
+ "fmt"
+ "hash/crc32"
+ "sync"
+
+ kms "cloud.google.com/go/kms/apiv1"
+ "cloud.google.com/go/kms/apiv1/kmspb"
+ "google.golang.org/protobuf/types/known/wrapperspb"
+)
+
+type GoogleKMS struct {
+ client *kms.KeyManagementClient
+
+ mx sync.RWMutex
+ keyCache map[string]*rsa.PublicKey
+}
+
+func NewGoogleKMS(ctx context.Context) (*GoogleKMS, error) {
+ client, err := kms.NewKeyManagementClient(ctx)
+ if err != nil {
+ return nil, err
+ }
+
+ return &GoogleKMS{
+ client: client,
+ keyCache: make(map[string]*rsa.PublicKey),
+ }, nil
+}
+
+func (g *GoogleKMS) checkCache(keyName string) *rsa.PublicKey {
+ g.mx.RLock()
+ defer g.mx.RUnlock()
+
+ return g.keyCache[keyName]
+}
+
+func (g *GoogleKMS) getPublicKey(ctx context.Context, keyName string) (*rsa.PublicKey, error) {
+ if key := g.checkCache(keyName); key != nil {
+ return key, nil
+ }
+
+ response, err := g.client.GetPublicKey(ctx, &kmspb.GetPublicKeyRequest{Name: keyName})
+ if err != nil {
+ return nil, err
+ }
+
+ block, _ := pem.Decode([]byte(response.Pem))
+ if block == nil || block.Type != "PUBLIC KEY" {
+ return nil, errors.New("failed to decode PEM public key")
+ }
+ publicKey, err := x509.ParsePKIXPublicKey(block.Bytes)
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse public key: %w", err)
+ }
+ rsaKey, ok := publicKey.(*rsa.PublicKey)
+ if !ok {
+ return nil, errors.New("public key is not an RSA key")
+ }
+
+ g.mx.Lock()
+ defer g.mx.Unlock()
+ g.keyCache[keyName] = rsaKey
+
+ return rsaKey, nil
+}
+
+func (g *GoogleKMS) Encrypt(ctx context.Context, keyName string, plaintext []byte) ([]byte, error) {
+ publicKey, err := g.getPublicKey(ctx, keyName)
+ if err != nil {
+ return nil, err
+ }
+
+ cipherText, err := rsa.EncryptOAEP(sha256.New(), rand.Reader, publicKey, plaintext, nil)
+ if err != nil {
+ return nil, fmt.Errorf("failed to encrypt plaintext: %w", err)
+ }
+
+ return cipherText, nil
+}
+
+func (g *GoogleKMS) Decrypt(ctx context.Context, keyName string, ciphertext []byte) ([]byte, error) {
+ req := &kmspb.AsymmetricDecryptRequest{
+ Name: keyName,
+ Ciphertext: ciphertext,
+ CiphertextCrc32C: wrapperspb.Int64(int64(calcCRC32(ciphertext))),
+ }
+
+ result, err := g.client.AsymmetricDecrypt(ctx, req)
+ if err != nil {
+ return nil, fmt.Errorf("failed to decrypt ciphertext: %w", err)
+ }
+
+ if !result.VerifiedCiphertextCrc32C {
+ return nil, errors.New("AsymmetricDecrypt: request corrupted in-transit")
+ }
+ if int64(calcCRC32(result.Plaintext)) != result.PlaintextCrc32C.Value {
+ return nil, fmt.Errorf("AsymmetricDecrypt: response corrupted in-transit")
+ }
+
+ return result.Plaintext, nil
+}
+
+func (g *GoogleKMS) Close() error {
+ return g.client.Close()
+}
+
+func calcCRC32(data []byte) uint32 {
+ t := crc32.MakeTable(crc32.Castagnoli)
+ return crc32.Checksum(data, t)
+}
+
+type GCPKeyName struct {
+ Project string
+ Location string
+ KeyRing string
+ CryptoKey string
+ CryptoKeyVersion string
+}
+
+func (k GCPKeyName) String() string {
+ return fmt.Sprintf("projects/%s/locations/%s/keyRings/%s/cryptoKeys/%s/cryptoKeyVersions/%s", k.Project, k.Location, k.KeyRing, k.CryptoKey, k.CryptoKeyVersion)
+}