Skip to content

Commit

Permalink
refactor: improve rpc message (#3115)
Browse files Browse the repository at this point in the history
* refactor: unify capabilities constraints message structure

This commit unifies the capabilities constraints message structure by combining
the minVersion constraint and CapabilityConstraints into a single Constraints
message field. The `NewCapabilitiesWithConstraints` method has been removed
for consistency with how the minVersion constraint is handled, allowing the
constraints fields to be private.

* fix(ai): prevent nil error crash

Ensure backward compatibility by handling cases where the Constraints
field is missing in the capabilities request message. This prevents crashes
when a Gateway with older software calls the updated orchestrator.
  • Loading branch information
rickstaa authored Aug 9, 2024
1 parent c336dd1 commit 350556d
Show file tree
Hide file tree
Showing 6 changed files with 216 additions and 206 deletions.
15 changes: 8 additions & 7 deletions cmd/livepeer/starter/starter.go
Original file line number Diff line number Diff line change
Expand Up @@ -1061,7 +1061,7 @@ func StartLivepeer(ctx context.Context, cfg LivepeerConfig) {
}

var aiCaps []core.Capability
capabilityConstraints := make(map[core.Capability]*core.PerCapabilityConstraints)
capabilityConstraints := make(core.PerCapabilityConstraints)

if *cfg.AIWorker {
gpus := []string{}
Expand Down Expand Up @@ -1159,7 +1159,7 @@ func StartLivepeer(ctx context.Context, cfg LivepeerConfig) {
_, ok := capabilityConstraints[core.Capability_TextToImage]
if !ok {
aiCaps = append(aiCaps, core.Capability_TextToImage)
capabilityConstraints[core.Capability_TextToImage] = &core.PerCapabilityConstraints{
capabilityConstraints[core.Capability_TextToImage] = &core.CapabilityConstraints{
Models: make(map[string]*core.ModelConstraint),
}
}
Expand All @@ -1173,7 +1173,7 @@ func StartLivepeer(ctx context.Context, cfg LivepeerConfig) {
_, ok := capabilityConstraints[core.Capability_ImageToImage]
if !ok {
aiCaps = append(aiCaps, core.Capability_ImageToImage)
capabilityConstraints[core.Capability_ImageToImage] = &core.PerCapabilityConstraints{
capabilityConstraints[core.Capability_ImageToImage] = &core.CapabilityConstraints{
Models: make(map[string]*core.ModelConstraint),
}
}
Expand All @@ -1187,7 +1187,7 @@ func StartLivepeer(ctx context.Context, cfg LivepeerConfig) {
_, ok := capabilityConstraints[core.Capability_ImageToVideo]
if !ok {
aiCaps = append(aiCaps, core.Capability_ImageToVideo)
capabilityConstraints[core.Capability_ImageToVideo] = &core.PerCapabilityConstraints{
capabilityConstraints[core.Capability_ImageToVideo] = &core.CapabilityConstraints{
Models: make(map[string]*core.ModelConstraint),
}
}
Expand All @@ -1201,7 +1201,7 @@ func StartLivepeer(ctx context.Context, cfg LivepeerConfig) {
_, ok := capabilityConstraints[core.Capability_Upscale]
if !ok {
aiCaps = append(aiCaps, core.Capability_Upscale)
capabilityConstraints[core.Capability_Upscale] = &core.PerCapabilityConstraints{
capabilityConstraints[core.Capability_Upscale] = &core.CapabilityConstraints{
Models: make(map[string]*core.ModelConstraint),
}
}
Expand All @@ -1215,7 +1215,7 @@ func StartLivepeer(ctx context.Context, cfg LivepeerConfig) {
_, ok := capabilityConstraints[core.Capability_AudioToText]
if !ok {
aiCaps = append(aiCaps, core.Capability_AudioToText)
capabilityConstraints[core.Capability_AudioToText] = &core.PerCapabilityConstraints{
capabilityConstraints[core.Capability_AudioToText] = &core.CapabilityConstraints{
Models: make(map[string]*core.ModelConstraint),
}
}
Expand Down Expand Up @@ -1405,7 +1405,8 @@ func StartLivepeer(ctx context.Context, cfg LivepeerConfig) {
*cfg.CliAddr = defaultAddr(*cfg.CliAddr, "127.0.0.1", TranscoderCliPort)
}

n.Capabilities = core.NewCapabilitiesWithConstraints(append(transcoderCaps, aiCaps...), core.MandatoryOCapabilities(), core.Constraints{}, capabilityConstraints)
n.Capabilities = core.NewCapabilities(append(transcoderCaps, aiCaps...), nil)
n.Capabilities.SetPerCapabilityConstraints(capabilityConstraints)
if cfg.OrchMinLivepeerVersion != nil {
n.Capabilities.SetMinVersionConstraint(*cfg.OrchMinLivepeerVersion)
}
Expand Down
71 changes: 38 additions & 33 deletions core/capabilities.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,21 +21,21 @@ type ModelConstraint struct {
type Capability int
type CapabilityString []uint64
type Constraints struct {
minVersion string
minVersion string
perCapability PerCapabilityConstraints
}
type PerCapabilityConstraints struct {
type CapabilityConstraints struct {
// Models contains a *ModelConstraint for each supported model ID
Models ModelConstraints
}
type CapabilityConstraints map[Capability]*PerCapabilityConstraints
type PerCapabilityConstraints map[Capability]*CapabilityConstraints
type Capabilities struct {
bitstring CapabilityString
mandatories CapabilityString
version string
constraints Constraints
capabilityConstraints CapabilityConstraints
capacities map[Capability]int
mutex sync.Mutex
bitstring CapabilityString
mandatories CapabilityString
version string
constraints Constraints
capacities map[Capability]int
mutex sync.Mutex
}
type CapabilityTest struct {
inVideoData []byte
Expand Down Expand Up @@ -239,7 +239,7 @@ func (c1 CapabilityString) CompatibleWith(c2 CapabilityString) bool {
return true
}

func (c1 CapabilityConstraints) CompatibleWith(c2 CapabilityConstraints) bool {
func (c1 PerCapabilityConstraints) CompatibleWith(c2 PerCapabilityConstraints) bool {
for c1Cap, c1Constraints := range c1 {
c2Constraints, ok := c2[c1Cap]
if !ok {
Expand All @@ -255,7 +255,7 @@ func (c1 CapabilityConstraints) CompatibleWith(c2 CapabilityConstraints) bool {
return true
}

func (c1 *PerCapabilityConstraints) CompatibleWith(c2 *PerCapabilityConstraints) bool {
func (c1 *CapabilityConstraints) CompatibleWith(c2 *CapabilityConstraints) bool {
return c1.Models.CompatibleWith(c2.Models)
}

Expand Down Expand Up @@ -453,8 +453,8 @@ func (bcast *Capabilities) CompatibleWith(orch *net.Capabilities) bool {
return false
}

orchCapabilityConstraints := CapabilitiesFromNetCapabilities(orch).capabilityConstraints
if !bcast.capabilityConstraints.CompatibleWith(orchCapabilityConstraints) {
orchCapabilityConstraints := CapabilitiesFromNetCapabilities(orch).constraints.perCapability
if !bcast.constraints.perCapability.CompatibleWith(orchCapabilityConstraints) {
return false
}

Expand All @@ -467,19 +467,19 @@ func (c *Capabilities) ToNetCapabilities() *net.Capabilities {
}
c.mutex.Lock()
defer c.mutex.Unlock()
netCaps := &net.Capabilities{Bitstring: c.bitstring, Mandatories: c.mandatories, Version: c.version, Capacities: make(map[uint32]uint32), Constraints: &net.Capabilities_Constraints{MinVersion: c.constraints.minVersion}, CapabilityConstraints: make(map[uint32]*net.Capabilities_CapabilityConstraints)}
netCaps := &net.Capabilities{Bitstring: c.bitstring, Mandatories: c.mandatories, Version: c.version, Capacities: make(map[uint32]uint32), Constraints: &net.Capabilities_Constraints{MinVersion: c.constraints.minVersion, PerCapability: make(map[uint32]*net.Capabilities_CapabilityConstraints)}}
for capability, capacity := range c.capacities {
netCaps.Capacities[uint32(capability)] = uint32(capacity)
}
for capability, constraints := range c.capabilityConstraints {
for capability, constraints := range c.constraints.perCapability {
models := make(map[string]*net.Capabilities_CapabilityConstraints_ModelConstraint)
for modelID, modelConstraint := range constraints.Models {
models[modelID] = &net.Capabilities_CapabilityConstraints_ModelConstraint{
Warm: modelConstraint.Warm,
}
}

netCaps.CapabilityConstraints[uint32(capability)] = &net.Capabilities_CapabilityConstraints{
netCaps.Constraints.PerCapability[uint32(capability)] = &net.Capabilities_CapabilityConstraints{
Models: models,
}
}
Expand All @@ -491,12 +491,11 @@ func CapabilitiesFromNetCapabilities(caps *net.Capabilities) *Capabilities {
return nil
}
coreCaps := &Capabilities{
bitstring: caps.Bitstring,
mandatories: caps.Mandatories,
capacities: make(map[Capability]int),
version: caps.Version,
constraints: Constraints{minVersion: caps.Constraints.GetMinVersion()},
capabilityConstraints: make(CapabilityConstraints),
bitstring: caps.Bitstring,
mandatories: caps.Mandatories,
capacities: make(map[Capability]int),
version: caps.Version,
constraints: Constraints{minVersion: caps.Constraints.GetMinVersion(), perCapability: make(PerCapabilityConstraints)},
}
if caps.Capacities == nil || len(caps.Capacities) == 0 {
// build capacities map if not present (struct received from previous versions)
Expand All @@ -514,13 +513,13 @@ func CapabilitiesFromNetCapabilities(caps *net.Capabilities) *Capabilities {
}
}

for capabilityInt, constraints := range caps.CapabilityConstraints {
for capabilityInt, constraints := range caps.Constraints.PerCapability {
models := make(map[string]*ModelConstraint)
for modelID, modelConstraint := range constraints.Models {
models[modelID] = &ModelConstraint{Warm: modelConstraint.Warm}
}

coreCaps.capabilityConstraints[Capability(capabilityInt)] = &PerCapabilityConstraints{
coreCaps.constraints.perCapability[Capability(capabilityInt)] = &CapabilityConstraints{
Models: models,
}
}
Expand All @@ -529,7 +528,7 @@ func CapabilitiesFromNetCapabilities(caps *net.Capabilities) *Capabilities {
}

func NewCapabilities(caps []Capability, m []Capability) *Capabilities {
c := &Capabilities{capacities: make(map[Capability]int), version: LivepeerVersion, capabilityConstraints: make(CapabilityConstraints)}
c := &Capabilities{capacities: make(map[Capability]int), constraints: Constraints{perCapability: make(PerCapabilityConstraints)}, version: LivepeerVersion}
if len(caps) > 0 {
c.bitstring = NewCapabilityString(caps)
// initialize capacities to 1 by default, mandatory capabilities doesn't have capacities
Expand All @@ -543,13 +542,6 @@ func NewCapabilities(caps []Capability, m []Capability) *Capabilities {
return c
}

func NewCapabilitiesWithConstraints(caps []Capability, m []Capability, constraints Constraints, capabilityConstraints CapabilityConstraints) *Capabilities {
c := NewCapabilities(caps, m)
c.constraints = constraints
c.capabilityConstraints = capabilityConstraints
return c
}

func (cap *Capabilities) AddCapacity(newCaps *Capabilities) {
cap.mutex.Lock()
defer cap.mutex.Unlock()
Expand Down Expand Up @@ -723,6 +715,19 @@ func (bcast *Capabilities) LegacyOnly() bool {
return bcast.bitstring.CompatibleWith(legacyCapabilityString)
}

func (bcast *Capabilities) SetPerCapabilityConstraints(constraints PerCapabilityConstraints) {
if bcast != nil {
bcast.constraints.perCapability = constraints
}
}

func (bcast *Capabilities) PerCapability() PerCapabilityConstraints {
if bcast != nil {
return bcast.constraints.perCapability
}
return nil
}

func (bcast *Capabilities) SetMinVersionConstraint(minVersionConstraint string) {
if bcast != nil {
bcast.constraints.minVersion = minVersionConstraint
Expand Down
30 changes: 16 additions & 14 deletions core/orchestrator.go
Original file line number Diff line number Diff line change
Expand Up @@ -347,21 +347,23 @@ func (orch *orchestrator) priceInfo(sender ethcommon.Address, manifestID Manifes
}
} else {
// The base price is the sum of the prices of individual capability + model ID pairs
for cap := range caps.Capacities {
// If the capability does not have constraints (and thus any model constraints) skip it
// because we only price a capability together with a model ID right now
constraints, ok := caps.CapabilityConstraints[cap]
if !ok {
continue
}
for modelID := range constraints.Models {
price := orch.node.GetBasePriceForCap(sender.String(), Capability(cap), modelID)
if price == nil {
price = orch.node.GetBasePriceForCap("default", Capability(cap), modelID)
if caps.Constraints != nil && caps.Constraints.PerCapability != nil {
for cap := range caps.Capacities {
// If the capability does not have constraints (and thus any model constraints) skip it
// because we only price a capability together with a model ID right now
constraints, ok := caps.Constraints.PerCapability[cap]
if !ok {
continue
}

if price != nil {
basePrice.Add(basePrice, price)
for modelID := range constraints.Models {
price := orch.node.GetBasePriceForCap(sender.String(), Capability(cap), modelID)
if price == nil {
price = orch.node.GetBasePriceForCap("default", Capability(cap), modelID)
}

if price != nil {
basePrice.Add(basePrice, price)
}
}
}
}
Expand Down
Loading

0 comments on commit 350556d

Please sign in to comment.