diff --git a/pkg/resourceinterpreter/configurableinterpreter/configurable.go b/pkg/resourceinterpreter/configurableinterpreter/configurable.go index c8a3eea711ce..99340651a00d 100644 --- a/pkg/resourceinterpreter/configurableinterpreter/configurable.go +++ b/pkg/resourceinterpreter/configurableinterpreter/configurable.go @@ -17,6 +17,7 @@ import ( type ConfigurableInterpreter struct { // configManager caches all ResourceInterpreterCustomizations. configManager configmanager.ConfigManager + luaVM *luavm.VM } // NewConfigurableInterpreter builds a new interpreter by registering the @@ -24,6 +25,8 @@ type ConfigurableInterpreter struct { func NewConfigurableInterpreter(informer genericmanager.SingleClusterInformerManager) *ConfigurableInterpreter { return &ConfigurableInterpreter{ configManager: configmanager.NewInterpreterConfigManager(informer), + // TODO: set an appropriate pool size. + luaVM: luavm.New(false, 10), } } @@ -40,8 +43,7 @@ func (c *ConfigurableInterpreter) GetReplicas(object *unstructured.Unstructured) if !enabled { return } - vm := luavm.VM{UseOpenLibs: false} - replicas, requires, err = vm.GetReplicas(object, luaScript) + replicas, requires, err = c.luaVM.GetReplicas(object, luaScript) return } @@ -52,8 +54,7 @@ func (c *ConfigurableInterpreter) ReviseReplica(object *unstructured.Unstructure if !enabled { return } - vm := luavm.VM{UseOpenLibs: false} - revised, err = vm.ReviseReplica(object, replica, luaScript) + revised, err = c.luaVM.ReviseReplica(object, replica, luaScript) return } @@ -64,8 +65,7 @@ func (c *ConfigurableInterpreter) Retain(desired *unstructured.Unstructured, obs if !enabled { return } - vm := luavm.VM{UseOpenLibs: false} - retained, err = vm.Retain(desired, observed, luaScript) + retained, err = c.luaVM.Retain(desired, observed, luaScript) return } @@ -76,8 +76,7 @@ func (c *ConfigurableInterpreter) AggregateStatus(object *unstructured.Unstructu if !enabled { return } - vm := luavm.VM{UseOpenLibs: false} - status, err = vm.AggregateStatus(object, aggregatedStatusItems, luaScript) + status, err = c.luaVM.AggregateStatus(object, aggregatedStatusItems, luaScript) return } @@ -88,8 +87,7 @@ func (c *ConfigurableInterpreter) GetDependencies(object *unstructured.Unstructu if !enabled { return } - vm := luavm.VM{UseOpenLibs: false} - dependencies, err = vm.GetDependencies(object, luaScript) + dependencies, err = c.luaVM.GetDependencies(object, luaScript) return } @@ -100,8 +98,7 @@ func (c *ConfigurableInterpreter) ReflectStatus(object *unstructured.Unstructure if !enabled { return } - vm := luavm.VM{UseOpenLibs: false} - status, err = vm.ReflectStatus(object, luaScript) + status, err = c.luaVM.ReflectStatus(object, luaScript) return } @@ -112,8 +109,7 @@ func (c *ConfigurableInterpreter) InterpretHealth(object *unstructured.Unstructu if !enabled { return } - vm := luavm.VM{UseOpenLibs: false} - health, err = vm.InterpretHealth(object, luaScript) + health, err = c.luaVM.InterpretHealth(object, luaScript) return } diff --git a/pkg/resourceinterpreter/configurableinterpreter/luavm/lua.go b/pkg/resourceinterpreter/configurableinterpreter/luavm/lua.go index 2c3178d4c370..5b90db2a053d 100644 --- a/pkg/resourceinterpreter/configurableinterpreter/luavm/lua.go +++ b/pkg/resourceinterpreter/configurableinterpreter/luavm/lua.go @@ -15,6 +15,7 @@ import ( configv1alpha1 "github.com/karmada-io/karmada/pkg/apis/config/v1alpha1" workv1alpha2 "github.com/karmada-io/karmada/pkg/apis/work/v1alpha2" + "github.com/karmada-io/karmada/pkg/util/fixedpool" "github.com/karmada-io/karmada/pkg/util/lifted" ) @@ -22,46 +23,104 @@ import ( type VM struct { // UseOpenLibs flag to enable open libraries. Libraries are disabled by default while running, but enabled during testing to allow the use of print statements. UseOpenLibs bool + Pool *fixedpool.FixedPool } -// GetReplicas returns the desired replicas of the object as well as the requirements of each replica by lua script. -func (vm VM) GetReplicas(obj *unstructured.Unstructured, script string) (replica int32, requires *workv1alpha2.ReplicaRequirements, err error) { +// New creates a manager for lua VM +func New(useOpenLibs bool, poolSize int) *VM { + vm := &VM{ + UseOpenLibs: useOpenLibs, + } + vm.Pool = fixedpool.New( + func() (any, error) { return vm.NewLuaState() }, + func(a any) { a.(*lua.LState).Close() }, + poolSize) + return vm +} + +// NewLuaState creates a new lua state. +func (vm *VM) NewLuaState() (*lua.LState, error) { l := lua.NewState(lua.Options{ SkipOpenLibs: !vm.UseOpenLibs, }) - defer l.Close() // Opens table library to allow access to functions to manipulate tables - err = vm.setLib(l) + err := vm.setLib(l) if err != nil { - return 0, nil, err + return nil, err } // preload our 'safe' version of the OS library. Allows the 'local os = require("os")' to work l.PreloadModule(lua.OsLibName, lifted.SafeOsLoader) + return l, err +} + +// RunScript got a lua vm from pool, and execute script with given arguments. +func (vm *VM) RunScript(script string, fnName string, nRets int, args ...interface{}) ([]lua.LValue, error) { + a, err := vm.Pool.Get() + if err != nil { + return nil, err + } + defer vm.Pool.Put(a) - ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + l := a.(*lua.LState) + l.Pop(l.GetTop()) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() l.SetContext(ctx) err = l.DoString(script) - f := l.GetGlobal("GetReplicas") + if err != nil { + return nil, err + } + vArgs := make([]lua.LValue, len(args)) + for i, arg := range args { + vArgs[i], err = decodeValue(l, arg) + if err != nil { + return nil, err + } + } + + f := l.GetGlobal(fnName) if f.Type() == lua.LTNil { - return 0, nil, fmt.Errorf("can't get function GetReplicas, please check the lua script") + return nil, fmt.Errorf("not found function %v", fnName) + } + if f.Type() != lua.LTFunction { + return nil, fmt.Errorf("%s is not a function: %s", fnName, f.Type()) } - args := make([]lua.LValue, 1) - args[0], err = decodeValue(l, obj.Object) + err = l.CallByParam(lua.P{ + Fn: f, + NRet: nRets, + Protect: true, + }, vArgs...) if err != nil { - return + return nil, err } - err = l.CallByParam(lua.P{Fn: f, NRet: 2, Protect: true}, args...) + + // get rets from stack: [ret1, ret2, ret3 ...] + rets := make([]lua.LValue, nRets) + for i := range rets { + rets[i] = l.Get(i + 1) + } + // pop all the values in stack + l.Pop(l.GetTop()) + return rets, nil +} + +// GetReplicas returns the desired replicas of the object as well as the requirements of each replica by lua script. +func (vm *VM) GetReplicas(obj *unstructured.Unstructured, script string) (replica int32, requires *workv1alpha2.ReplicaRequirements, err error) { + results, err := vm.RunScript(script, "GetReplicas", 2, obj) if err != nil { return 0, nil, err } - replicaRequirementResult := l.Get(l.GetTop()) - l.Pop(1) + replica, err = ConvertLuaResultToInt(results[0]) + if err != nil { + return 0, nil, err + } + replicaRequirementResult := results[1] requires = &workv1alpha2.ReplicaRequirements{} if replicaRequirementResult.Type() == lua.LTTable { err = ConvertLuaResultInto(replicaRequirementResult, requires) @@ -75,56 +134,17 @@ func (vm VM) GetReplicas(obj *unstructured.Unstructured, script string) (replica return 0, nil, fmt.Errorf("expect the returned requires type is table but got %s", replicaRequirementResult.Type()) } - luaReplica := l.Get(l.GetTop()) - replica, err = ConvertLuaResultToInt(luaReplica) - if err != nil { - return 0, nil, err - } return } // ReviseReplica revises the replica of the given object by lua. -func (vm VM) ReviseReplica(object *unstructured.Unstructured, replica int64, script string) (*unstructured.Unstructured, error) { - l := lua.NewState(lua.Options{ - SkipOpenLibs: !vm.UseOpenLibs, - }) - defer l.Close() - // Opens table library to allow access to functions to manipulate tables - err := vm.setLib(l) +func (vm *VM) ReviseReplica(object *unstructured.Unstructured, replica int64, script string) (*unstructured.Unstructured, error) { + results, err := vm.RunScript(script, "ReviseReplica", 1, object, replica) if err != nil { return nil, err } - // preload our 'safe' version of the OS library. Allows the 'local os = require("os")' to work - l.PreloadModule(lua.OsLibName, lifted.SafeOsLoader) - ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) - defer cancel() - l.SetContext(ctx) - - err = l.DoString(script) - if err != nil { - return nil, err - } - reviseReplicaLuaFunc := l.GetGlobal("ReviseReplica") - if reviseReplicaLuaFunc.Type() == lua.LTNil { - return nil, fmt.Errorf("can't get function ReviseReplica, please check the lua script") - } - - args := make([]lua.LValue, 2) - args[0], err = decodeValue(l, object.Object) - if err != nil { - return nil, err - } - args[1], err = decodeValue(l, replica) - if err != nil { - return nil, err - } - err = l.CallByParam(lua.P{Fn: reviseReplicaLuaFunc, NRet: 1, Protect: true}, args...) - if err != nil { - return nil, err - } - - luaResult := l.Get(l.GetTop()) + luaResult := results[0] reviseReplicaResult := &unstructured.Unstructured{} if luaResult.Type() == lua.LTTable { err := ConvertLuaResultInto(luaResult, reviseReplicaResult) @@ -137,7 +157,7 @@ func (vm VM) ReviseReplica(object *unstructured.Unstructured, replica int64, scr return nil, fmt.Errorf("expect the returned requires type is table but got %s", luaResult.Type()) } -func (vm VM) setLib(l *lua.LState) error { +func (vm *VM) setLib(l *lua.LState) error { for _, pair := range []struct { n string f lua.LGFunction @@ -160,47 +180,13 @@ func (vm VM) setLib(l *lua.LState) error { } // Retain returns the objects that based on the "desired" object but with values retained from the "observed" object by lua. -func (vm VM) Retain(desired *unstructured.Unstructured, observed *unstructured.Unstructured, script string) (retained *unstructured.Unstructured, err error) { - l := lua.NewState(lua.Options{ - SkipOpenLibs: !vm.UseOpenLibs, - }) - defer l.Close() - // Opens table library to allow access to functions to manipulate tables - err = vm.setLib(l) - if err != nil { - return nil, err - } - // preload our 'safe' version of the OS library. Allows the 'local os = require("os")' to work - l.PreloadModule(lua.OsLibName, lifted.SafeOsLoader) - - ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) - defer cancel() - l.SetContext(ctx) - - err = l.DoString(script) - if err != nil { - return nil, err - } - retainLuaFunc := l.GetGlobal("Retain") - if retainLuaFunc.Type() == lua.LTNil { - return nil, fmt.Errorf("can't get function Retatin, please check the lua script") - } - - args := make([]lua.LValue, 2) - args[0], err = decodeValue(l, desired.Object) - if err != nil { - return - } - args[1], err = decodeValue(l, observed.Object) - if err != nil { - return - } - err = l.CallByParam(lua.P{Fn: retainLuaFunc, NRet: 1, Protect: true}, args...) +func (vm *VM) Retain(desired *unstructured.Unstructured, observed *unstructured.Unstructured, script string) (retained *unstructured.Unstructured, err error) { + results, err := vm.RunScript(script, "Retain", 1, desired, observed) if err != nil { return nil, err } - luaResult := l.Get(l.GetTop()) + luaResult := results[0] retainResult := &unstructured.Unstructured{} if luaResult.Type() == lua.LTTable { err := ConvertLuaResultInto(luaResult, retainResult) @@ -213,47 +199,13 @@ func (vm VM) Retain(desired *unstructured.Unstructured, observed *unstructured.U } // AggregateStatus returns the objects that based on the 'object' but with status aggregated by lua. -func (vm VM) AggregateStatus(object *unstructured.Unstructured, items []workv1alpha2.AggregatedStatusItem, script string) (*unstructured.Unstructured, error) { - l := lua.NewState(lua.Options{ - SkipOpenLibs: !vm.UseOpenLibs, - }) - defer l.Close() - // Opens table library to allow access to functions to manipulate tables - err := vm.setLib(l) - if err != nil { - return nil, err - } - // preload our 'safe' version of the OS library. Allows the 'local os = require("os")' to work - l.PreloadModule(lua.OsLibName, lifted.SafeOsLoader) - - ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) - defer cancel() - l.SetContext(ctx) - - err = l.DoString(script) - if err != nil { - return nil, err - } - - f := l.GetGlobal("AggregateStatus") - if f.Type() == lua.LTNil { - return nil, fmt.Errorf("can't get function AggregateStatus, please check the lua script") - } - args := make([]lua.LValue, 2) - args[0], err = decodeValue(l, object.Object) - if err != nil { - return nil, err - } - args[1], err = decodeValue(l, items) - if err != nil { - return nil, err - } - err = l.CallByParam(lua.P{Fn: f, NRet: 1, Protect: true}, args...) +func (vm *VM) AggregateStatus(object *unstructured.Unstructured, items []workv1alpha2.AggregatedStatusItem, script string) (*unstructured.Unstructured, error) { + results, err := vm.RunScript(script, "AggregateStatus", 1, object, items) if err != nil { return nil, err } - luaResult := l.Get(l.GetTop()) + luaResult := results[0] aggregateStatus := &unstructured.Unstructured{} if luaResult.Type() == lua.LTTable { err := ConvertLuaResultInto(luaResult, aggregateStatus) @@ -266,45 +218,14 @@ func (vm VM) AggregateStatus(object *unstructured.Unstructured, items []workv1al } // InterpretHealth returns the health state of the object by lua. -func (vm VM) InterpretHealth(object *unstructured.Unstructured, script string) (bool, error) { - l := lua.NewState(lua.Options{ - SkipOpenLibs: !vm.UseOpenLibs, - }) - defer l.Close() - // Opens table library to allow access to functions to manipulate tables - err := vm.setLib(l) - if err != nil { - return false, err - } - // preload our 'safe' version of the OS library. Allows the 'local os = require("os")' to work - l.PreloadModule(lua.OsLibName, lifted.SafeOsLoader) - - ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) - defer cancel() - l.SetContext(ctx) - - err = l.DoString(script) - if err != nil { - return false, err - } - f := l.GetGlobal("InterpretHealth") - if f.Type() == lua.LTNil { - return false, fmt.Errorf("can't get function InterpretHealth, please check the lua script") - } - - args := make([]lua.LValue, 1) - args[0], err = decodeValue(l, object.Object) - if err != nil { - return false, err - } - err = l.CallByParam(lua.P{Fn: f, NRet: 1, Protect: true}, args...) +func (vm *VM) InterpretHealth(object *unstructured.Unstructured, script string) (bool, error) { + results, err := vm.RunScript(script, "InterpretHealth", 1, object) if err != nil { return false, err } var health bool - luaResult := l.Get(l.GetTop()) - health, err = ConvertLuaResultToBool(luaResult) + health, err = ConvertLuaResultToBool(results[0]) if err != nil { return false, err } @@ -312,43 +233,13 @@ func (vm VM) InterpretHealth(object *unstructured.Unstructured, script string) ( } // ReflectStatus returns the status of the object by lua. -func (vm VM) ReflectStatus(object *unstructured.Unstructured, script string) (status *runtime.RawExtension, err error) { - l := lua.NewState(lua.Options{ - SkipOpenLibs: !vm.UseOpenLibs, - }) - defer l.Close() - // Opens table library to allow access to functions to manipulate tables - err = vm.setLib(l) - if err != nil { - return nil, err - } - // preload our 'safe' version of the OS library. Allows the 'local os = require("os")' to work - l.PreloadModule(lua.OsLibName, lifted.SafeOsLoader) - - ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) - defer cancel() - l.SetContext(ctx) - - err = l.DoString(script) +func (vm *VM) ReflectStatus(object *unstructured.Unstructured, script string) (status *runtime.RawExtension, err error) { + results, err := vm.RunScript(script, "ReflectStatus", 1, object) if err != nil { return nil, err } - f := l.GetGlobal("ReflectStatus") - if f.Type() == lua.LTNil { - return nil, fmt.Errorf("can't get function ReflectStatus, please check the lua script") - } - args := make([]lua.LValue, 1) - args[0], err = decodeValue(l, object.Object) - if err != nil { - return - } - err = l.CallByParam(lua.P{Fn: f, NRet: 1, Protect: true}, args...) - if err != nil { - return nil, err - } - luaStatusResult := l.Get(l.GetTop()) - l.Pop(1) + luaStatusResult := results[0] if luaStatusResult.Type() != lua.LTTable { return nil, fmt.Errorf("expect the returned replica type is table but got %s", luaStatusResult.Type()) } @@ -359,43 +250,13 @@ func (vm VM) ReflectStatus(object *unstructured.Unstructured, script string) (st } // GetDependencies returns the dependent resources of the given object by lua. -func (vm VM) GetDependencies(object *unstructured.Unstructured, script string) (dependencies []configv1alpha1.DependentObjectReference, err error) { - l := lua.NewState(lua.Options{ - SkipOpenLibs: !vm.UseOpenLibs, - }) - defer l.Close() - // Opens table library to allow access to functions to manipulate tables - err = vm.setLib(l) - if err != nil { - return nil, err - } - // preload our 'safe' version of the OS library. Allows the 'local os = require("os")' to work - l.PreloadModule(lua.OsLibName, lifted.SafeOsLoader) - - ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) - defer cancel() - l.SetContext(ctx) - - err = l.DoString(script) - if err != nil { - return nil, err - } - f := l.GetGlobal("GetDependencies") - if f.Type() == lua.LTNil { - return nil, fmt.Errorf("can't get function GetDependencies, please check the lua script") - } - - args := make([]lua.LValue, 1) - args[0], err = decodeValue(l, object.Object) - if err != nil { - return - } - err = l.CallByParam(lua.P{Fn: f, NRet: 1, Protect: true}, args...) +func (vm *VM) GetDependencies(object *unstructured.Unstructured, script string) (dependencies []configv1alpha1.DependentObjectReference, err error) { + results, err := vm.RunScript(script, "GetDependencies", 1, object) if err != nil { return nil, err } - luaResult := l.Get(l.GetTop()) + luaResult := results[0] if luaResult.Type() != lua.LTTable { return nil, fmt.Errorf("expect the returned requires type is table but got %s", luaResult.Type()) diff --git a/pkg/resourceinterpreter/configurableinterpreter/luavm/lua_test.go b/pkg/resourceinterpreter/configurableinterpreter/luavm/lua_test.go index df03749135c1..5059acb6ae1d 100644 --- a/pkg/resourceinterpreter/configurableinterpreter/luavm/lua_test.go +++ b/pkg/resourceinterpreter/configurableinterpreter/luavm/lua_test.go @@ -20,7 +20,7 @@ import ( func TestGetReplicas(t *testing.T) { var replicas int32 = 1 - vm := VM{UseOpenLibs: false} + vm := New(false, 1) tests := []struct { name string deploy *appsv1.Deployment @@ -150,7 +150,7 @@ func TestReviseDeploymentReplica(t *testing.T) { end`, }, } - vm := VM{UseOpenLibs: false} + vm := New(false, 1) for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -217,7 +217,7 @@ func TestAggregateDeploymentStatus(t *testing.T) { end`, }, } - vm := VM{UseOpenLibs: false} + vm := New(false, 1) for _, tt := range tests { actualObj, _ := vm.AggregateStatus(tt.curObj, tt.aggregatedStatusItems, tt.luaScript) @@ -265,7 +265,7 @@ func TestHealthDeploymentStatus(t *testing.T) { end `, }, } - vm := VM{UseOpenLibs: false} + vm := New(false, 1) for _, tt := range tests { flag, err := vm.InterpretHealth(tt.curObj, tt.luaScript) @@ -349,7 +349,7 @@ func TestRetainDeployment(t *testing.T) { end`, }, } - vm := VM{UseOpenLibs: false} + vm := New(false, 1) for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -397,7 +397,7 @@ func TestStatusReflection(t *testing.T) { }, } - vm := VM{UseOpenLibs: false} + vm := New(false, 1) for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := vm.ReflectStatus(tt.args.object, tt.luaScript) @@ -470,7 +470,7 @@ func TestGetDeployPodDependencies(t *testing.T) { }, } - vm := VM{UseOpenLibs: false} + vm := New(false, 1) for _, tt := range tests { res, err := vm.GetDependencies(tt.curObj, tt.luaScript) diff --git a/pkg/util/fixedpool/fixedpool.go b/pkg/util/fixedpool/fixedpool.go new file mode 100644 index 000000000000..cf2d96635ff0 --- /dev/null +++ b/pkg/util/fixedpool/fixedpool.go @@ -0,0 +1,72 @@ +package fixedpool + +import ( + "sync" +) + +// A FixedPool like sync.Pool. But it's limited capacity. +// When pool is full, Put will abandon and call destroyFunc to destroy the object. +type FixedPool struct { + lock sync.Mutex + pool []any + capacity int + + newFunc func() (any, error) + destroyFunc func(any) +} + +// New return a FixedPool +func New(newFunc func() (any, error), destroyFunc func(any), capacity int) *FixedPool { + return &FixedPool{ + pool: make([]any, 0, capacity), + capacity: capacity, + newFunc: newFunc, + destroyFunc: destroyFunc, + } +} + +// Get selects an arbitrary item from the pool, removes it from the +// pool, and returns it to the caller. +// Get may choose to ignore the pool and treat it as empty. +// Callers should not assume any relation between values passed to Put and +// the values returned by Get. +// +// If pool is empty, Get returns the result of calling newFunc. +func (p *FixedPool) Get() (any, error) { + o, ok := p.pop() + if ok { + return o, nil + } + + return p.newFunc() +} + +// Put adds x to the pool. If pool is full, x will be abandoned, +// and it's destroy function will be called. +func (p *FixedPool) Put(x any) { + if p.push(x) { + return + } + p.destroyFunc(x) +} + +func (p *FixedPool) pop() (any, bool) { + p.lock.Lock() + defer p.lock.Unlock() + if s := len(p.pool); s > 0 { + o := p.pool[s-1] + p.pool = p.pool[:s-1] + return o, true + } + return nil, false +} + +func (p *FixedPool) push(o any) bool { + p.lock.Lock() + defer p.lock.Unlock() + if s := len(p.pool); s < p.capacity { + p.pool = append(p.pool, o) + return true + } + return false +} diff --git a/pkg/util/fixedpool/fixedpool_test.go b/pkg/util/fixedpool/fixedpool_test.go new file mode 100644 index 000000000000..d3c4d9d813b0 --- /dev/null +++ b/pkg/util/fixedpool/fixedpool_test.go @@ -0,0 +1,141 @@ +package fixedpool + +import ( + "testing" +) + +func TestFixedPool_Get(t *testing.T) { + type fields struct { + pool []any + capacity int + } + type want struct { + len int + } + tests := []struct { + name string + fields fields + want want + }{ + { + name: "poll is empty", + fields: fields{ + pool: []any{}, + capacity: 3, + }, + want: want{ + len: 0, + }, + }, + { + name: "poll is not empty", + fields: fields{ + pool: []any{1}, + capacity: 3, + }, + want: want{ + len: 0, + }, + }, + { + name: "poll is full", + fields: fields{ + pool: []any{1, 2, 3}, + capacity: 3, + }, + want: want{ + len: 2, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := &FixedPool{ + pool: tt.fields.pool, + capacity: tt.fields.capacity, + newFunc: func() (any, error) { return &struct{}{}, nil }, + destroyFunc: func(a any) {}, + } + g, err := p.Get() + if err != nil { + t.Errorf("Get() returns error: %v", err) + return + } + if g == nil { + t.Errorf("Get() returns nil") + return + } + if got := len(p.pool); got != tt.want.len { + t.Errorf("Get() got = %v, want %v", got, tt.want.len) + } + }) + } +} + +func TestFixedPool_Put(t *testing.T) { + type fields struct { + pool []any + capacity int + } + type want struct { + len int + destroyed bool + } + tests := []struct { + name string + fields fields + want want + }{ + { + name: "pool is empty", + fields: fields{ + pool: nil, + capacity: 3, + }, + want: want{ + len: 1, + destroyed: false, + }, + }, + { + name: "pool is not empty", + fields: fields{ + pool: []any{1}, + capacity: 3, + }, + want: want{ + len: 2, + destroyed: false, + }, + }, + { + name: "pool is not full", + fields: fields{ + pool: []any{1, 2, 3}, + capacity: 3, + }, + want: want{ + len: 3, + destroyed: true, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + destroyed := false + p := &FixedPool{ + pool: tt.fields.pool, + capacity: tt.fields.capacity, + newFunc: func() (any, error) { return &struct{}{}, nil }, + destroyFunc: func(a any) { destroyed = true }, + } + p.Put(&struct{}{}) + if got := len(p.pool); got != tt.want.len { + t.Errorf("pool len got %v, want %v", got, tt.want) + } + if destroyed != tt.want.destroyed { + t.Errorf("destroyed got %v, want %v", destroyed, tt.want.destroyed) + } + }) + } +}