Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Secret store locking issue #12988

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 40 additions & 17 deletions config/secret.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,23 @@ func (s *Secret) Get() ([]byte, error) {
return newsecret, protect(newsecret)
}

// Set overwrites the secret's value with a new one. Please note, the secret
// is not linked again, so only references to secret-stores can be used, e.g. by
// adding more clear-text or reordering secrets.
func (s *Secret) Set(value []byte) error {
// Link the new value can be resolved
secret, res, replaceErrs := resolve(value, s.resolvers)
if len(replaceErrs) > 0 {
return fmt.Errorf("linking new secrets failed: %s", strings.Join(replaceErrs, ";"))
}

// Set the new secret
s.enclave = memguard.NewEnclave(secret)
s.resolvers = res

return nil
}

// GetUnlinked return the parts of the secret that is not yet linked to a resolver
func (s *Secret) GetUnlinked() []string {
return s.unlinked
Expand All @@ -170,9 +187,6 @@ func (s *Secret) GetUnlinked() []string {
// Link used the given resolver map to link the secret parts to their
// secret-store resolvers.
func (s *Secret) Link(resolvers map[string]telegraf.ResolveFunc) error {
// Setup the resolver map
s.resolvers = make(map[string]telegraf.ResolveFunc)

// Decrypt the secret so we can return it
if s.enclave == nil {
return nil
Expand All @@ -184,9 +198,30 @@ func (s *Secret) Link(resolvers map[string]telegraf.ResolveFunc) error {
defer lockbuf.Destroy()
secret := lockbuf.Bytes()

// Iterate through the parts and try to resolve them. For static parts
// we directly replace them, while for dynamic ones we store the resolver.
newsecret, res, replaceErrs := resolve(secret, resolvers)
if len(replaceErrs) > 0 {
return fmt.Errorf("linking secrets failed: %s", strings.Join(replaceErrs, ";"))
}
s.resolvers = res

// Store the secret if it has changed
if string(secret) != string(newsecret) {
s.enclave = memguard.NewEnclave(newsecret)
}

// All linked now
s.unlinked = nil

return nil
}

func resolve(secret []byte, resolvers map[string]telegraf.ResolveFunc) ([]byte, map[string]telegraf.ResolveFunc, []string) {
// Iterate through the parts and try to resolve them. For static parts
// we directly replace them, while for dynamic ones we store the resolver.
replaceErrs := make([]string, 0)
remaining := make(map[string]telegraf.ResolveFunc)
newsecret := secretPattern.ReplaceAllFunc(secret, func(match []byte) []byte {
resolver, found := resolvers[string(match)]
if !found {
Expand All @@ -205,22 +240,10 @@ func (s *Secret) Link(resolvers map[string]telegraf.ResolveFunc) error {
}

// Keep the resolver for dynamic secrets
s.resolvers[string(match)] = resolver
remaining[string(match)] = resolver
return match
})
if len(replaceErrs) > 0 {
return fmt.Errorf("linking secrets failed: %s", strings.Join(replaceErrs, ";"))
}

// Store the secret if it has changed
if string(secret) != string(newsecret) {
s.enclave = memguard.NewEnclave(newsecret)
}

// All linked now
s.unlinked = nil

return nil
return newsecret, remaining, replaceErrs
}

func splitLink(s string) (storeid string, key string) {
Expand Down
101 changes: 101 additions & 0 deletions config/secret_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,8 @@ func TestSecretStoreInvalidReference(t *testing.T) {
}

func TestSecretStoreStaticChanging(t *testing.T) {
defer func() { unlinkedSecrets = make([]*Secret, 0) }()

cfg := []byte(
`
[[inputs.mockup]]
Expand Down Expand Up @@ -520,6 +522,8 @@ func TestSecretStoreStaticChanging(t *testing.T) {
}

func TestSecretStoreDynamic(t *testing.T) {
defer func() { unlinkedSecrets = make([]*Secret, 0) }()

cfg := []byte(
`
[[inputs.mockup]]
Expand Down Expand Up @@ -554,6 +558,8 @@ func TestSecretStoreDynamic(t *testing.T) {
}

func TestSecretStoreDeclarationMissingID(t *testing.T) {
defer func() { unlinkedSecrets = make([]*Secret, 0) }()

cfg := []byte(`[[secretstores.mockup]]`)

c := NewConfig()
Expand All @@ -562,6 +568,8 @@ func TestSecretStoreDeclarationMissingID(t *testing.T) {
}

func TestSecretStoreDeclarationInvalidID(t *testing.T) {
defer func() { unlinkedSecrets = make([]*Secret, 0) }()

invalidIDs := []string{"foo.bar", "dummy-123", "test!", "wohoo+"}
tmpl := `
[[secretstores.mockup]]
Expand All @@ -578,6 +586,8 @@ func TestSecretStoreDeclarationInvalidID(t *testing.T) {
}

func TestSecretStoreDeclarationValidID(t *testing.T) {
defer func() { unlinkedSecrets = make([]*Secret, 0) }()

validIDs := []string{"foobar", "dummy123", "test_id", "W0Hoo_lala123"}
tmpl := `
[[secretstores.mockup]]
Expand All @@ -593,6 +603,97 @@ func TestSecretStoreDeclarationValidID(t *testing.T) {
}
}

func TestSecretSet(t *testing.T) {
defer func() { unlinkedSecrets = make([]*Secret, 0) }()

cfg := []byte(`
[[inputs.mockup]]
secret = "a secret"
`)
c := NewConfig()
require.NoError(t, c.LoadConfigData(cfg))
require.Len(t, c.Inputs, 1)
require.NoError(t, c.LinkSecrets())

plugin := c.Inputs[0].Input.(*MockupSecretPlugin)

secret, err := plugin.Secret.Get()
require.NoError(t, err)
defer ReleaseSecret(secret)
require.EqualValues(t, "a secret", string(secret))

require.NoError(t, plugin.Secret.Set([]byte("another secret")))
newsecret, err := plugin.Secret.Get()
require.NoError(t, err)
defer ReleaseSecret(newsecret)
require.EqualValues(t, "another secret", string(newsecret))
}

func TestSecretSetResolve(t *testing.T) {
defer func() { unlinkedSecrets = make([]*Secret, 0) }()

cfg := []byte(`
[[inputs.mockup]]
secret = "@{mock:secret}"
`)
c := NewConfig()
require.NoError(t, c.LoadConfigData(cfg))
require.Len(t, c.Inputs, 1)

// Create a mockup secretstore
store := &MockupSecretStore{
Secrets: map[string][]byte{"secret": []byte("Ood Bnar")},
Dynamic: true,
}
require.NoError(t, store.Init())
c.SecretStores["mock"] = store
require.NoError(t, c.LinkSecrets())

plugin := c.Inputs[0].Input.(*MockupSecretPlugin)

secret, err := plugin.Secret.Get()
require.NoError(t, err)
defer ReleaseSecret(secret)
require.EqualValues(t, "Ood Bnar", string(secret))

require.NoError(t, plugin.Secret.Set([]byte("@{mock:secret} is cool")))
newsecret, err := plugin.Secret.Get()
require.NoError(t, err)
defer ReleaseSecret(newsecret)
require.EqualValues(t, "Ood Bnar is cool", string(newsecret))
}

func TestSecretSetResolveInvalid(t *testing.T) {
defer func() { unlinkedSecrets = make([]*Secret, 0) }()

cfg := []byte(`
[[inputs.mockup]]
secret = "@{mock:secret}"
`)
c := NewConfig()
require.NoError(t, c.LoadConfigData(cfg))
require.Len(t, c.Inputs, 1)

// Create a mockup secretstore
store := &MockupSecretStore{
Secrets: map[string][]byte{"secret": []byte("Ood Bnar")},
Dynamic: true,
}
require.NoError(t, store.Init())
c.SecretStores["mock"] = store
require.NoError(t, c.LinkSecrets())

plugin := c.Inputs[0].Input.(*MockupSecretPlugin)

secret, err := plugin.Secret.Get()
require.NoError(t, err)
defer ReleaseSecret(secret)
require.EqualValues(t, "Ood Bnar", string(secret))

err = plugin.Secret.Set([]byte("@{mock:another_secret}"))
require.ErrorContains(t, err, `linking new secrets failed: unlinked part "@{mock:another_secret}"`)
}

/*** Mockup (input) plugin for testing to avoid cyclic dependencies ***/
type MockupSecretPlugin struct {
Secret Secret `toml:"secret"`
Expand Down
13 changes: 6 additions & 7 deletions config/secret_with_mlock.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,17 @@
package config

import (
"syscall"

"github.com/awnumar/memguard"
)

func protect(secret []byte) error {
return syscall.Mlock(secret)
func protect(_ []byte) error {
//return syscall.Mlock(secret)
return nil
}

func ReleaseSecret(secret []byte) {
memguard.WipeBytes(secret)
if err := syscall.Munlock(secret); err != nil {
panic(err)
}
// if err := syscall.Munlock(secret); err != nil {
// panic(err)
// }
}
65 changes: 34 additions & 31 deletions plugins/inputs/mysql/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,31 +26,31 @@ import (
var sampleConfig string

type Mysql struct {
Servers []config.Secret `toml:"servers"`
PerfEventsStatementsDigestTextLimit int64 `toml:"perf_events_statements_digest_text_limit"`
PerfEventsStatementsLimit int64 `toml:"perf_events_statements_limit"`
PerfEventsStatementsTimeLimit int64 `toml:"perf_events_statements_time_limit"`
TableSchemaDatabases []string `toml:"table_schema_databases"`
GatherProcessList bool `toml:"gather_process_list"`
GatherUserStatistics bool `toml:"gather_user_statistics"`
GatherInfoSchemaAutoInc bool `toml:"gather_info_schema_auto_inc"`
GatherInnoDBMetrics bool `toml:"gather_innodb_metrics"`
GatherSlaveStatus bool `toml:"gather_slave_status"`
GatherAllSlaveChannels bool `toml:"gather_all_slave_channels"`
MariadbDialect bool `toml:"mariadb_dialect"`
GatherBinaryLogs bool `toml:"gather_binary_logs"`
GatherTableIOWaits bool `toml:"gather_table_io_waits"`
GatherTableLockWaits bool `toml:"gather_table_lock_waits"`
GatherIndexIOWaits bool `toml:"gather_index_io_waits"`
GatherEventWaits bool `toml:"gather_event_waits"`
GatherTableSchema bool `toml:"gather_table_schema"`
GatherFileEventsStats bool `toml:"gather_file_events_stats"`
GatherPerfEventsStatements bool `toml:"gather_perf_events_statements"`
GatherGlobalVars bool `toml:"gather_global_variables"`
GatherPerfSummaryPerAccountPerEvent bool `toml:"gather_perf_sum_per_acc_per_event"`
PerfSummaryEvents []string `toml:"perf_summary_events"`
IntervalSlow config.Duration `toml:"interval_slow"`
MetricVersion int `toml:"metric_version"`
Servers []*config.Secret `toml:"servers"`
PerfEventsStatementsDigestTextLimit int64 `toml:"perf_events_statements_digest_text_limit"`
PerfEventsStatementsLimit int64 `toml:"perf_events_statements_limit"`
PerfEventsStatementsTimeLimit int64 `toml:"perf_events_statements_time_limit"`
TableSchemaDatabases []string `toml:"table_schema_databases"`
GatherProcessList bool `toml:"gather_process_list"`
GatherUserStatistics bool `toml:"gather_user_statistics"`
GatherInfoSchemaAutoInc bool `toml:"gather_info_schema_auto_inc"`
GatherInnoDBMetrics bool `toml:"gather_innodb_metrics"`
GatherSlaveStatus bool `toml:"gather_slave_status"`
GatherAllSlaveChannels bool `toml:"gather_all_slave_channels"`
MariadbDialect bool `toml:"mariadb_dialect"`
GatherBinaryLogs bool `toml:"gather_binary_logs"`
GatherTableIOWaits bool `toml:"gather_table_io_waits"`
GatherTableLockWaits bool `toml:"gather_table_lock_waits"`
GatherIndexIOWaits bool `toml:"gather_index_io_waits"`
GatherEventWaits bool `toml:"gather_event_waits"`
GatherTableSchema bool `toml:"gather_table_schema"`
GatherFileEventsStats bool `toml:"gather_file_events_stats"`
GatherPerfEventsStatements bool `toml:"gather_perf_events_statements"`
GatherGlobalVars bool `toml:"gather_global_variables"`
GatherPerfSummaryPerAccountPerEvent bool `toml:"gather_perf_sum_per_acc_per_event"`
PerfSummaryEvents []string `toml:"perf_summary_events"`
IntervalSlow config.Duration `toml:"interval_slow"`
MetricVersion int `toml:"metric_version"`

Log telegraf.Logger `toml:"-"`
tls.ClientConfig
Expand Down Expand Up @@ -80,7 +80,8 @@ func (m *Mysql) Init() error {

// Default to localhost if nothing specified.
if len(m.Servers) == 0 {
m.Servers = append(m.Servers, config.NewSecret([]byte(localhost)))
s := config.NewSecret([]byte(localhost))
m.Servers = append(m.Servers, &s)
}

// Register the TLS configuration. Due to the registry being a global
Expand All @@ -106,7 +107,7 @@ func (m *Mysql) Init() error {
for i, server := range m.Servers {
s, err := server.Get()
if err != nil {
return fmt.Errorf("getting server %d failed", i)
return fmt.Errorf("getting server %d failed: %w", i, err)
}
dsn := string(s)
config.ReleaseSecret(s)
Expand All @@ -125,8 +126,10 @@ func (m *Mysql) Init() error {
conf.TLSConfig = tlsid
}

server.Destroy()
m.Servers[i] = config.NewSecret([]byte(conf.FormatDSN()))
if err := server.Set([]byte(conf.FormatDSN())); err != nil {
return fmt.Errorf("replacing server %q failed: %w", dsn, err)
}
m.Servers[i] = server
}

return nil
Expand All @@ -138,7 +141,7 @@ func (m *Mysql) Gather(acc telegraf.Accumulator) error {
// Loop through each server and collect metrics
for _, server := range m.Servers {
wg.Add(1)
go func(s config.Secret) {
go func(s *config.Secret) {
defer wg.Done()
acc.AddError(m.gatherServer(s, acc))
}(server)
Expand Down Expand Up @@ -411,7 +414,7 @@ const (
`
)

func (m *Mysql) gatherServer(server config.Secret, acc telegraf.Accumulator) error {
func (m *Mysql) gatherServer(server *config.Secret, acc telegraf.Accumulator) error {
s, err := server.Get()
if err != nil {
return err
Expand Down
Loading