Skip to content

Commit

Permalink
feat: Sensitive type (#1284)
Browse files Browse the repository at this point in the history
  • Loading branch information
adityathebe authored Jan 23, 2025
1 parent 1546617 commit da73584
Show file tree
Hide file tree
Showing 5 changed files with 264 additions and 33 deletions.
42 changes: 42 additions & 0 deletions secret/ciphertext.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package secret

import (
"encoding/base64"
"encoding/json"
"fmt"
"strings"
)

const ciphertextPrefix = "enc:"

type Ciphertext []byte

func (t Ciphertext) String() string {
return fmt.Sprintf("%s%s", ciphertextPrefix, base64.StdEncoding.EncodeToString(t))
}

func (t Ciphertext) MarshalJSON() ([]byte, error) {
return json.Marshal(t.String())
}

func (t Ciphertext) MarshalText() ([]byte, error) {
return []byte(t.String()), nil
}

func ParseCiphertext(s string) (Ciphertext, error) {
if !strings.HasPrefix(s, ciphertextPrefix) {
return nil, fmt.Errorf("invalid ciphertext prefix")
}

encoded := s[len(ciphertextPrefix):]
if encoded == "" {
return nil, fmt.Errorf("empty ciphertext")
}

data, err := base64.StdEncoding.DecodeString(encoded)
if err != nil {
return nil, err
}

return Ciphertext(data), nil
}
104 changes: 104 additions & 0 deletions secret/keeper.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
package secret

import (
"fmt"
"sync"
"time"

"github.com/flanksource/duty/connection"
"github.com/flanksource/duty/context"
"github.com/flanksource/duty/models"
"github.com/patrickmn/go-cache"
"github.com/samber/lo"
"github.com/samber/oops"
"gocloud.dev/secrets"
)

const defaultKeeperTTL = time.Minute * 10

var (
keeperCache = cache.New(defaultKeeperTTL, defaultKeeperTTL*2)

// keeperLock locks access to the keeperCache
keeperLock sync.RWMutex
)

var (
// KMSConnection is the connection to the key management service
// that's used to encrypt and decrypt secrets.
KMSConnection string

allowedConnectionTypes = []string{
models.ConnectionTypeAWSKMS,
models.ConnectionTypeGCPKMS,
models.ConnectionTypeAzureKeyVault,
// Vault not supported yet
}
)

func init() {
keeperCache.OnEvicted(func(key string, keeper interface{}) {
if keeper != nil {
keeper.(*secrets.Keeper).Close()
}
})
}

// createOrGetKeeper creates a new Keeper from the KMSConnection if it doesn't
// exist in the cache, otherwise it returns the cached Keeper.
func createOrGetKeeper(ctx context.Context) (*secrets.Keeper, error) {
if KMSConnection == "" {
return nil, oops.Errorf("secret keeper connection is not set")
}

keeperLock.RLock()
cached, ok := keeperCache.Get("keeper")
keeperLock.RUnlock()
if ok {
return cached.(*secrets.Keeper), nil
}

keeperLock.Lock()
defer keeperLock.Unlock()

keeper, err := KeeperFromConnection(ctx, KMSConnection)
if err != nil {
return nil, err
}

ttl := ctx.Properties().Duration("secretkeeper.cache.ttl", defaultKeeperTTL)
keeperCache.Set("keeper", keeper, ttl)
return keeper, nil
}

func KeeperFromConnection(ctx context.Context, connectionString string) (*secrets.Keeper, error) {
conn, err := ctx.HydrateConnectionByURL(connectionString)
if err != nil {
return nil, fmt.Errorf("failed to hydrate connection: %w", err)
} else if conn == nil {
return nil, fmt.Errorf("connection not found: %s", connectionString)
}

if !lo.Contains(allowedConnectionTypes, conn.Type) {
return nil, fmt.Errorf("connection type %s cannot be used to create a SecretKeeper", conn.Type)
}

switch conn.Type {
case models.ConnectionTypeAWSKMS:
var kmsConn connection.AWSKMS
kmsConn.FromModel(*conn)
return kmsConn.SecretKeeper(ctx)

case models.ConnectionTypeAzureKeyVault:
var keyvaultConn connection.AzureKeyVault
keyvaultConn.FromModel(*conn)
return keyvaultConn.SecretKeeper(ctx)

case models.ConnectionTypeGCPKMS:
var kmsConn connection.GCPKMS
kmsConn.FromModel(*conn)
return kmsConn.SecretKeeper(ctx)
}

return nil, nil
}
51 changes: 18 additions & 33 deletions secret/secret.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,48 +3,33 @@ package secret
import (
"fmt"

"github.com/flanksource/duty/connection"
"github.com/flanksource/duty/context"
"github.com/flanksource/duty/models"
"github.com/samber/lo"
"gocloud.dev/secrets"
)

var allowedConnectionTypes = []string{
models.ConnectionTypeAWSKMS,
models.ConnectionTypeGCPKMS,
models.ConnectionTypeAzureKeyVault,
// Vault not supported yet
}
func Encrypt(ctx context.Context, sensitive Sensitive) (Ciphertext, error) {
keeper, err := createOrGetKeeper(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get secret keeper from connection (%s): %w", KMSConnection, err)
}

func KeeperFromConnection(ctx context.Context, connectionString string) (*secrets.Keeper, error) {
conn, err := ctx.HydrateConnectionByURL(connectionString)
ciphertext, err := keeper.Encrypt(ctx, []byte(sensitive.PlainText()))
if err != nil {
return nil, fmt.Errorf("failed to hydrate connection: %w", err)
} else if conn == nil {
return nil, fmt.Errorf("connection not found: %s", connectionString)
return nil, fmt.Errorf("failed to encrypt secret: %w", err)
}

if !lo.Contains(allowedConnectionTypes, conn.Type) {
return nil, fmt.Errorf("connection type %s cannot be used to create a SecretKeeper", conn.Type)
return ciphertext, nil
}

func Decrypt(ctx context.Context, ciphertext Ciphertext) (Sensitive, error) {
keeper, err := createOrGetKeeper(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get secret keeper from connection (%s): %w", KMSConnection, err)
}

switch conn.Type {
case models.ConnectionTypeAWSKMS:
var kmsConn connection.AWSKMS
kmsConn.FromModel(*conn)
return kmsConn.SecretKeeper(ctx)

case models.ConnectionTypeAzureKeyVault:
var keyvaultConn connection.AzureKeyVault
keyvaultConn.FromModel(*conn)
return keyvaultConn.SecretKeeper(ctx)

case models.ConnectionTypeGCPKMS:
var kmsConn connection.GCPKMS
kmsConn.FromModel(*conn)
return kmsConn.SecretKeeper(ctx)
decrypted, err := keeper.Decrypt(ctx, ciphertext)
if err != nil {
return nil, fmt.Errorf("failed to decrypt secret: %w", err)
}

return nil, nil
return Sensitive(decrypted), nil
}
31 changes: 31 additions & 0 deletions secret/sensitive.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package secret

import "encoding/json"

const sensitivePlaceholder = "[REDACTED]"

type Sensitive []byte

func (t Sensitive) String() string {
return sensitivePlaceholder
}

func (t Sensitive) PlainText() string {
return string(t)
}

func (t Sensitive) MarshalJSON() ([]byte, error) {
return json.Marshal(sensitivePlaceholder)
}

func (t Sensitive) MarshalText() ([]byte, error) {
return []byte(sensitivePlaceholder), nil
}

func (t *Sensitive) Clear() {
*t = make([]byte, len(*t))
for i := range *t {
(*t)[i] = 0
}
*t = nil
}
69 changes: 69 additions & 0 deletions secret/sensitive_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package secret

import (
"bytes"
"encoding/json"
"fmt"
"log/slog"
"testing"
)

func TestSecretString(t *testing.T) {
const secretKey = "my_secret_string"

t.Run("simple .String()", func(t *testing.T) {
secret := Sensitive(secretKey)
if secret.String() != sensitivePlaceholder {
t.Errorf("Expected secret.String() to return %s", sensitivePlaceholder)
}
})

t.Run("formatted", func(t *testing.T) {
secret := Sensitive(secretKey)
if fmt.Sprintf("%s.", secret) != sensitivePlaceholder+"." { // added a period to avoid LSP warning
t.Errorf("Expected secret.String() to return %s", sensitivePlaceholder)
}
})

t.Run("JSON", func(t *testing.T) {
type myJSON struct {
Secret Sensitive
}

m := myJSON{
Secret: Sensitive(secretKey),
}
marshalled, err := json.Marshal(m)
if err != nil {
t.Errorf("Failed to marshal JSON: %s", err)
}

if string(marshalled) != fmt.Sprintf(`{"Secret":"%s"}`, sensitivePlaceholder) {
t.Errorf("Expected marshalled JSON to contain redacted")
}
})

t.Run("Clear", func(t *testing.T) {
secret := Sensitive(secretKey)
secret.Clear()
if len(secret) != 0 {
t.Errorf("Expected secret to be cleared")
}
})

t.Run("PlainText", func(t *testing.T) {
secret := Sensitive(secretKey)
if secret.PlainText() != secretKey {
t.Errorf("Expected secret to match plain text")
}
})

t.Run("Logger", func(t *testing.T) {
var buffer bytes.Buffer
myLogger := slog.New(slog.NewTextHandler(&buffer, nil))
myLogger.Info("secret: %s", slog.Any("secret", Sensitive(secretKey)))
if bytes.Contains(buffer.Bytes(), []byte(secretKey)) || !bytes.Contains(buffer.Bytes(), []byte(sensitivePlaceholder)) {
t.Errorf("Expected log to contain redacted")
}
})
}

0 comments on commit da73584

Please sign in to comment.