package keys

import (
	"context"
	"crypto/rand"
	"crypto/rsa"
	"crypto/sha256"
	"crypto/x509"
	"encoding/pem"
	"errors"
	"fmt"
	"hash/crc32"
	"sync"

	kms "cloud.google.com/go/kms/apiv1"
	"cloud.google.com/go/kms/apiv1/kmspb"
	"google.golang.org/protobuf/types/known/wrapperspb"
)

// castagnoliTable is the CRC32C (Castagnoli) table used for the integrity
// checksums exchanged with Cloud KMS. It is built once at package init rather
// than on every checksum computation.
var castagnoliTable = crc32.MakeTable(crc32.Castagnoli)

// GoogleKMS encrypts locally with the public half of a Cloud KMS asymmetric
// key and decrypts remotely through the Cloud KMS AsymmetricDecrypt API.
//
// Parsed public keys are cached per key name. This type never removes cache
// entries. Create values with NewGoogleKMS and release them with Close.
type GoogleKMS struct {
	client *kms.KeyManagementClient

	mx       sync.RWMutex              // guards keyCache
	keyCache map[string]*rsa.PublicKey // keyName -> parsed RSA public key
}

// NewGoogleKMS creates a Cloud KMS client (with no extra client options) and
// returns a GoogleKMS wrapping it. The caller must call Close when finished.
func NewGoogleKMS(ctx context.Context) (*GoogleKMS, error) {
	client, err := kms.NewKeyManagementClient(ctx)
	if err != nil {
		return nil, err
	}
	return &GoogleKMS{
		client:   client,
		keyCache: make(map[string]*rsa.PublicKey),
	}, nil
}

// checkCache returns the cached public key for keyName, or nil if none has
// been stored yet. Safe for concurrent use.
func (g *GoogleKMS) checkCache(keyName string) *rsa.PublicKey {
	g.mx.RLock()
	defer g.mx.RUnlock()
	return g.keyCache[keyName]
}

// getPublicKey returns the RSA public key for keyName. It fetches the key from
// Cloud KMS on the first call and caches it.
//
// Integrity checking mirrors Decrypt: when KMS supplies a CRC32C for the PEM,
// the PEM is verified against it. An error is returned if the KMS call fails,
// the checksum does not match, the PEM is not a "PUBLIC KEY" block, or the
// key is not RSA.
//
// Two goroutines that miss the cache at the same time may both fetch the key.
// The last writer wins, which is harmless because both fetched the same key.
func (g *GoogleKMS) getPublicKey(ctx context.Context, keyName string) (*rsa.PublicKey, error) {
	if key := g.checkCache(keyName); key != nil {
		return key, nil
	}

	response, err := g.client.GetPublicKey(ctx, &kmspb.GetPublicKeyRequest{Name: keyName})
	if err != nil {
		return nil, err
	}

	pemBytes := []byte(response.GetPem())
	// Verify only when a checksum is present, so responses that omit it are
	// not newly rejected.
	if crc := response.GetPemCrc32C(); crc != nil && int64(calcCRC32(pemBytes)) != crc.GetValue() {
		return nil, errors.New("GetPublicKey: response corrupted in-transit")
	}

	block, _ := pem.Decode(pemBytes)
	if block == nil || block.Type != "PUBLIC KEY" {
		return nil, errors.New("failed to decode PEM public key")
	}

	publicKey, err := x509.ParsePKIXPublicKey(block.Bytes)
	if err != nil {
		return nil, fmt.Errorf("failed to parse public key: %w", err)
	}

	rsaKey, ok := publicKey.(*rsa.PublicKey)
	if !ok {
		return nil, errors.New("public key is not an RSA key")
	}

	g.mx.Lock()
	defer g.mx.Unlock()
	g.keyCache[keyName] = rsaKey
	return rsaKey, nil
}

// Encrypt encrypts plaintext locally with RSA-OAEP (SHA-256, empty label),
// using the public key for keyName. The public key is fetched from KMS on
// first use.
//
// rsa.EncryptOAEP rejects plaintexts longer than the key size permits; that
// error is returned wrapped.
//
// NOTE(review): the OAEP hash is fixed to SHA-256. This presumably requires
// keyName to reference an RSA_DECRYPT_OAEP_*_SHA256 key version so that KMS
// can decrypt the result; confirm against the keys in use.
func (g *GoogleKMS) Encrypt(ctx context.Context, keyName string, plaintext []byte) ([]byte, error) {
	publicKey, err := g.getPublicKey(ctx, keyName)
	if err != nil {
		return nil, err
	}

	cipherText, err := rsa.EncryptOAEP(sha256.New(), rand.Reader, publicKey, plaintext, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to encrypt plaintext: %w", err)
	}
	return cipherText, nil
}

// Decrypt decrypts ciphertext with keyName via the Cloud KMS AsymmetricDecrypt
// API, checking CRC32C integrity in both directions.
//
// It returns an error in these cases:
//   - the KMS call fails;
//   - KMS reports that the ciphertext checksum it received did not verify;
//   - the returned plaintext's checksum is missing or does not match.
func (g *GoogleKMS) Decrypt(ctx context.Context, keyName string, ciphertext []byte) ([]byte, error) {
	req := &kmspb.AsymmetricDecryptRequest{
		Name:             keyName,
		Ciphertext:       ciphertext,
		CiphertextCrc32C: wrapperspb.Int64(int64(calcCRC32(ciphertext))),
	}

	result, err := g.client.AsymmetricDecrypt(ctx, req)
	if err != nil {
		return nil, fmt.Errorf("failed to decrypt ciphertext: %w", err)
	}

	// KMS echoes whether the checksum we sent matched what it received.
	if !result.GetVerifiedCiphertextCrc32C() {
		return nil, errors.New("AsymmetricDecrypt: request corrupted in-transit")
	}

	// Use nil-safe getters: a missing plaintext checksum is reported as
	// corruption instead of panicking on a nil wrapper.
	plaintext := result.GetPlaintext()
	plaintextCRC := result.GetPlaintextCrc32C()
	if plaintextCRC == nil || int64(calcCRC32(plaintext)) != plaintextCRC.GetValue() {
		return nil, errors.New("AsymmetricDecrypt: response corrupted in-transit")
	}
	return plaintext, nil
}

// Close releases the underlying Cloud KMS client connection.
func (g *GoogleKMS) Close() error {
	return g.client.Close()
}

// calcCRC32 returns the CRC32C (Castagnoli polynomial) checksum of data. This
// is the checksum form used for the KMS integrity fields in this file.
func calcCRC32(data []byte) uint32 {
	return crc32.Checksum(data, castagnoliTable)
}

// GCPKeyName identifies a Cloud KMS crypto key version by its resource path
// components.
type GCPKeyName struct {
	Project          string
	Location         string
	KeyRing          string
	CryptoKey        string
	CryptoKeyVersion string
}

// String formats k as a full resource name of the form
// "projects/P/locations/L/keyRings/R/cryptoKeys/K/cryptoKeyVersions/V".
// Components are inserted verbatim, without validation or escaping.
func (k GCPKeyName) String() string {
	return fmt.Sprintf("projects/%s/locations/%s/keyRings/%s/cryptoKeys/%s/cryptoKeyVersions/%s",
		k.Project, k.Location, k.KeyRing, k.CryptoKey, k.CryptoKeyVersion)
}