Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: introduced RWMutex to flag state to prevent concurrent r/w of map #370

Merged
merged 2 commits into from
Feb 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions pkg/eval/fractional_evaluation_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package eval

import (
"sync"
"testing"

"github.com/open-feature/flagd/pkg/logger"
Expand All @@ -10,6 +11,7 @@ import (

func TestFractionalEvaluation(t *testing.T) {
flags := Flags{
mx: &sync.RWMutex{},
Flags: map[string]Flag{
"headerColor": {
State: "ENABLED",
Expand Down Expand Up @@ -113,6 +115,7 @@ func TestFractionalEvaluation(t *testing.T) {
},
"non even split": {
flags: Flags{
mx: &sync.RWMutex{},
Flags: map[string]Flag{
"headerColor": {
State: "ENABLED",
Expand Down Expand Up @@ -164,6 +167,7 @@ func TestFractionalEvaluation(t *testing.T) {
},
"fallback to default variant if no email provided": {
flags: Flags{
mx: &sync.RWMutex{},
Flags: map[string]Flag{
"headerColor": {
State: "ENABLED",
Expand Down Expand Up @@ -206,6 +210,7 @@ func TestFractionalEvaluation(t *testing.T) {
},
"fallback to default variant if invalid variant as result of fractional evaluation": {
flags: Flags{
mx: &sync.RWMutex{},
Flags: map[string]Flag{
"headerColor": {
State: "ENABLED",
Expand Down Expand Up @@ -240,6 +245,7 @@ func TestFractionalEvaluation(t *testing.T) {
},
"fallback to default variant if percentages don't sum to 100": {
flags: Flags{
mx: &sync.RWMutex{},
Flags: map[string]Flag{
"headerColor": {
State: "ENABLED",
Expand Down
18 changes: 16 additions & 2 deletions pkg/eval/json_evaluator.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"regexp"
"strconv"
"strings"
mxSync "sync"

"github.com/open-feature/flagd/pkg/sync"

Expand Down Expand Up @@ -47,6 +48,7 @@ func NewJSONEvaluator(logger *logger.Logger) *JSONEvaluator {
),
state: Flags{
Flags: map[string]Flag{},
mx: &mxSync.RWMutex{},
},
}
jsonlogic.AddOperator("fractionalEvaluation", ev.fractionalEvaluation)
Expand Down Expand Up @@ -110,6 +112,8 @@ func (je *JSONEvaluator) ResolveAllValues(reqID string, context *structpb.Struct
var variant string
var reason string
var err error
je.state.mx.RLock()
defer je.state.mx.RUnlock()
for flagKey, flag := range je.state.Flags {
defaultValue := flag.Variants[flag.DefaultVariant]
switch defaultValue.(type) {
Expand Down Expand Up @@ -161,6 +165,8 @@ func (je *JSONEvaluator) ResolveBooleanValue(reqID string, flagKey string, conte
reason string,
err error,
) {
je.state.mx.RLock()
defer je.state.mx.RUnlock()
je.Logger.DebugWithID(reqID, fmt.Sprintf("evaluating boolean flag: %s", flagKey))
return resolve[bool](reqID, flagKey, context, je.evaluateVariant, je.state.Flags[flagKey].Variants)
}
Expand All @@ -171,6 +177,8 @@ func (je *JSONEvaluator) ResolveStringValue(reqID string, flagKey string, contex
reason string,
err error,
) {
je.state.mx.RLock()
defer je.state.mx.RUnlock()
je.Logger.DebugWithID(reqID, fmt.Sprintf("evaluating string flag: %s", flagKey))
return resolve[string](reqID, flagKey, context, je.evaluateVariant, je.state.Flags[flagKey].Variants)
}
Expand All @@ -181,6 +189,8 @@ func (je *JSONEvaluator) ResolveFloatValue(reqID string, flagKey string, context
reason string,
err error,
) {
je.state.mx.RLock()
defer je.state.mx.RUnlock()
je.Logger.DebugWithID(reqID, fmt.Sprintf("evaluating float flag: %s", flagKey))
value, variant, reason, err = resolve[float64](
reqID, flagKey, context, je.evaluateVariant, je.state.Flags[flagKey].Variants)
Expand All @@ -193,6 +203,8 @@ func (je *JSONEvaluator) ResolveIntValue(reqID string, flagKey string, context *
reason string,
err error,
) {
je.state.mx.RLock()
defer je.state.mx.RUnlock()
je.Logger.DebugWithID(reqID, fmt.Sprintf("evaluating int flag: %s", flagKey))
var val float64
val, variant, reason, err = resolve[float64](
Expand All @@ -207,6 +219,8 @@ func (je *JSONEvaluator) ResolveObjectValue(reqID string, flagKey string, contex
reason string,
err error,
) {
je.state.mx.RLock()
defer je.state.mx.RUnlock()
je.Logger.DebugWithID(reqID, fmt.Sprintf("evaluating object flag: %s", flagKey))
return resolve[map[string]any](reqID, flagKey, context, je.evaluateVariant, je.state.Flags[flagKey].Variants)
}
Expand Down Expand Up @@ -256,7 +270,7 @@ func (je *JSONEvaluator) evaluateVariant(
variant = strings.ReplaceAll(strings.TrimSpace(result.String()), "\"", "")

// if this is a valid variant, return it
if _, ok := je.state.Flags[flagKey].Variants[variant]; ok {
if _, ok := flag.Variants[variant]; ok {
return variant, model.TargetingMatchReason, nil
}

Expand All @@ -266,7 +280,7 @@ func (je *JSONEvaluator) evaluateVariant(
reason = model.StaticReason
}

return je.state.Flags[flagKey].DefaultVariant, reason, nil
return flag.DefaultVariant, reason, nil
}

// configToFlags convert string configurations to flags and store them to pointer newFlags
Expand Down
34 changes: 29 additions & 5 deletions pkg/eval/json_evaluator_model.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"encoding/json"
"fmt"
"reflect"
"sync"

"github.com/open-feature/flagd/pkg/logger"
)
Expand All @@ -21,6 +22,7 @@ type Evaluators struct {
}

type Flags struct {
mx *sync.RWMutex
Flags map[string]Flag `json:"flags"`
}

Expand All @@ -29,7 +31,10 @@ func (f Flags) Add(logger *logger.Logger, source string, ff Flags) map[string]in
notifications := map[string]interface{}{}

for k, newFlag := range ff.Flags {
if storedFlag, ok := f.Flags[k]; ok && storedFlag.Source != source {
f.mx.RLock()
storedFlag, ok := f.Flags[k]
f.mx.RUnlock()
if ok && storedFlag.Source != source {
logger.Warn(fmt.Sprintf(
"flag with key %s from source %s already exist, overriding this with flag from source %s",
k,
Expand All @@ -45,7 +50,9 @@ func (f Flags) Add(logger *logger.Logger, source string, ff Flags) map[string]in

// Store the new version of the flag
newFlag.Source = source
f.mx.Lock()
f.Flags[k] = newFlag
f.mx.Unlock()
}

return notifications
Expand All @@ -56,14 +63,18 @@ func (f Flags) Update(logger *logger.Logger, source string, ff Flags) map[string
notifications := map[string]interface{}{}

for k, flag := range ff.Flags {
if storedFlag, ok := f.Flags[k]; !ok {
f.mx.RLock()
storedFlag, ok := f.Flags[k]
f.mx.RUnlock()
if !ok {
logger.Warn(
fmt.Sprintf("failed to update the flag, flag with key %s from source %s does not exisit.",
fmt.Sprintf("failed to update the flag, flag with key %s from source %s does not exist.",
k,
source))

continue
} else if storedFlag.Source != source {
}
if storedFlag.Source != source {
logger.Warn(fmt.Sprintf(
"flag with key %s from source %s already exist, overriding this with flag from source %s",
k,
Expand All @@ -78,7 +89,9 @@ func (f Flags) Update(logger *logger.Logger, source string, ff Flags) map[string
}

flag.Source = source
f.mx.Lock()
f.Flags[k] = flag
f.mx.Unlock()
}

return notifications
Expand All @@ -89,13 +102,18 @@ func (f Flags) Delete(logger *logger.Logger, source string, ff Flags) map[string
notifications := map[string]interface{}{}

for k := range ff.Flags {
if _, ok := f.Flags[k]; ok {
f.mx.RLock()
_, ok := f.Flags[k]
f.mx.RUnlock()
if ok {
notifications[k] = map[string]interface{}{
"type": string(NotificationDelete),
"source": source,
}

f.mx.Lock()
delete(f.Flags, k)
f.mx.Unlock()
} else {
logger.Warn(
fmt.Sprintf("failed to remove flag, flag with key %s from source %s does not exisit.",
Expand All @@ -111,6 +129,7 @@ func (f Flags) Delete(logger *logger.Logger, source string, ff Flags) map[string
func (f Flags) Merge(logger *logger.Logger, source string, ff Flags) map[string]interface{} {
notifications := map[string]interface{}{}

f.mx.Lock()
for k, v := range f.Flags {
if v.Source == source {
if _, ok := ff.Flags[k]; !ok {
Expand All @@ -124,11 +143,14 @@ func (f Flags) Merge(logger *logger.Logger, source string, ff Flags) map[string]
}
}
}
f.mx.Unlock()

for k, newFlag := range ff.Flags {
newFlag.Source = source

f.mx.RLock()
storedFlag, ok := f.Flags[k]
f.mx.RUnlock()
if !ok {
notifications[k] = map[string]interface{}{
"type": string(NotificationCreate),
Expand All @@ -151,8 +173,10 @@ func (f Flags) Merge(logger *logger.Logger, source string, ff Flags) map[string]
}
}

f.mx.Lock()
// Store the new version of the flag
f.Flags[k] = newFlag
f.mx.Unlock()
}

return notifications
Expand Down
Loading