diff options
Diffstat (limited to 'backend/internal/keys/gcp.go')
-rw-r--r-- | backend/internal/keys/gcp.go | 131 |
1 files changed, 131 insertions, 0 deletions
diff --git a/backend/internal/keys/gcp.go b/backend/internal/keys/gcp.go new file mode 100644 index 0000000..9d10fc5 --- /dev/null +++ b/backend/internal/keys/gcp.go @@ -0,0 +1,131 @@ +package keys + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "crypto/x509" + "encoding/pem" + "errors" + "fmt" + "hash/crc32" + "sync" + + kms "cloud.google.com/go/kms/apiv1" + "cloud.google.com/go/kms/apiv1/kmspb" + "google.golang.org/protobuf/types/known/wrapperspb" +) + +type GoogleKMS struct { + client *kms.KeyManagementClient + + mx sync.RWMutex + keyCache map[string]*rsa.PublicKey +} + +func NewGoogleKMS(ctx context.Context) (*GoogleKMS, error) { + client, err := kms.NewKeyManagementClient(ctx) + if err != nil { + return nil, err + } + + return &GoogleKMS{ + client: client, + keyCache: make(map[string]*rsa.PublicKey), + }, nil +} + +func (g *GoogleKMS) checkCache(keyName string) *rsa.PublicKey { + g.mx.RLock() + defer g.mx.RUnlock() + + return g.keyCache[keyName] +} + +func (g *GoogleKMS) getPublicKey(ctx context.Context, keyName string) (*rsa.PublicKey, error) { + if key := g.checkCache(keyName); key != nil { + return key, nil + } + + response, err := g.client.GetPublicKey(ctx, &kmspb.GetPublicKeyRequest{Name: keyName}) + if err != nil { + return nil, err + } + + block, _ := pem.Decode([]byte(response.Pem)) + if block == nil || block.Type != "PUBLIC KEY" { + return nil, errors.New("failed to decode PEM public key") + } + publicKey, err := x509.ParsePKIXPublicKey(block.Bytes) + if err != nil { + return nil, fmt.Errorf("failed to parse public key: %w", err) + } + rsaKey, ok := publicKey.(*rsa.PublicKey) + if !ok { + return nil, errors.New("public key is not an RSA key") + } + + g.mx.Lock() + defer g.mx.Unlock() + g.keyCache[keyName] = rsaKey + + return rsaKey, nil +} + +func (g *GoogleKMS) Encrypt(ctx context.Context, keyName string, plaintext []byte) ([]byte, error) { + publicKey, err := g.getPublicKey(ctx, keyName) + if err != nil { + return nil, err + } + + cipherText, err := rsa.EncryptOAEP(sha256.New(), rand.Reader, publicKey, plaintext, nil) + if err != nil { + return nil, fmt.Errorf("failed to encrypt plaintext: %w", err) + } + + return cipherText, nil +} + +func (g *GoogleKMS) Decrypt(ctx context.Context, keyName string, ciphertext []byte) ([]byte, error) { + req := &kmspb.AsymmetricDecryptRequest{ + Name: keyName, + Ciphertext: ciphertext, + CiphertextCrc32C: wrapperspb.Int64(int64(calcCRC32(ciphertext))), + } + + result, err := g.client.AsymmetricDecrypt(ctx, req) + if err != nil { + return nil, fmt.Errorf("failed to decrypt ciphertext: %w", err) + } + + if !result.VerifiedCiphertextCrc32C { + return nil, errors.New("AsymmetricDecrypt: request corrupted in-transit") + } + if int64(calcCRC32(result.Plaintext)) != result.PlaintextCrc32C.Value { + return nil, fmt.Errorf("AsymmetricDecrypt: response corrupted in-transit") + } + + return result.Plaintext, nil +} + +func (g *GoogleKMS) Close() error { + return g.client.Close() +} + +func calcCRC32(data []byte) uint32 { + t := crc32.MakeTable(crc32.Castagnoli) + return crc32.Checksum(data, t) +} + +type GCPKeyName struct { + Project string + Location string + KeyRing string + CryptoKey string + CryptoKeyVersion string +} + +func (k GCPKeyName) String() string { + return fmt.Sprintf("projects/%s/locations/%s/keyRings/%s/cryptoKeys/%s/cryptoKeyVersions/%s", k.Project, k.Location, k.KeyRing, k.CryptoKey, k.CryptoKeyVersion) +} |