Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for JSON string in defaultImage field for overriding pipeline specific images in the mappings #293

Merged
merged 12 commits into from
Jan 28, 2025
Merged
50 changes: 47 additions & 3 deletions worker/docker.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,20 +57,57 @@ var containerHostPorts = map[string]string{
}

// Mapping for per pipeline container images.
var defaultBaseImage = "livepeer/ai-runner:latest"
var pipelineToImage = map[string]string{
"segment-anything-2": "livepeer/ai-runner:segment-anything-2",
"text-to-speech": "livepeer/ai-runner:text-to-speech",
"audio-to-text": "livepeer/ai-runner:audio-to-text",
"llm": "livepeer/ai-runner:llm",
}

var livePipelineToImage = map[string]string{
"streamdiffusion": "livepeer/ai-runner:live-app-streamdiffusion",
"comfyui": "livepeer/ai-runner:live-app-comfyui",
"segment_anything_2": "livepeer/ai-runner:live-app-segment_anything_2",
"noop": "livepeer/ai-runner:live-app-noop",
}

// overridePipelineImages updates base and pipeline images with the provided overrides.
func overridePipelineImages(imageOverrides string) error {
if imageOverrides == "" {
return fmt.Errorf("empty string is not a valid image override")
}

// Handle JSON format for multiple pipeline images.
var imageMap map[string]string
if err := json.Unmarshal([]byte(imageOverrides), &imageMap); err == nil {
rickstaa marked this conversation as resolved.
Show resolved Hide resolved
for pipeline, image := range imageMap {
if pipeline == "base" {
defaultBaseImage = image
continue
}

// Check and update the pipeline images.
if _, exists := pipelineToImage[pipeline]; exists {
pipelineToImage[pipeline] = image
} else if _, exists := livePipelineToImage[pipeline]; exists {
livePipelineToImage[pipeline] = image
rickstaa marked this conversation as resolved.
Show resolved Hide resolved
} else {
return fmt.Errorf("can't override docker image for unknown pipeline: %s", pipeline)
}
}
return nil
}

// Check for invalid docker image string.
if strings.ContainsAny(imageOverrides, "{}[]\",") {
return fmt.Errorf("invalid JSON format for image overrides")
}

// Update the base image.
defaultBaseImage = imageOverrides
return nil
}

// DockerClient is an interface for the Docker client, allowing for mocking in tests.
// NOTE: ensure any docker.Client methods used in this package are added.
type DockerClient interface {
Expand Down Expand Up @@ -103,16 +140,23 @@ type DockerManager struct {
mu *sync.Mutex
}

func NewDockerManager(defaultImage string, gpus []string, modelDir string, client DockerClient) (*DockerManager, error) {
func NewDockerManager(imageOverrides string, gpus []string, modelDir string, client DockerClient) (*DockerManager, error) {
ctx, cancel := context.WithTimeout(context.Background(), containerTimeout)
if err := removeExistingContainers(ctx, client); err != nil {
cancel()
return nil, err
}
cancel()

// Override pipeline images if provided.
if imageOverrides != "" {
if err := overridePipelineImages(imageOverrides); err != nil {
return nil, err
}
}

manager := &DockerManager{
defaultImage: defaultImage,
defaultImage: defaultBaseImage,
gpus: gpus,
modelDir: modelDir,
dockerClient: client,
Expand Down
117 changes: 117 additions & 0 deletions worker/docker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,123 @@ func createDockerManager(mockDockerClient *MockDockerClient) *DockerManager {
}
}

// copyMap returns a deep copy of the given map.
func copyMap(m map[string]string) map[string]string {
copy := make(map[string]string)
for k, v := range m {
copy[k] = v
}
return copy
}

func TestOverridePipelineImages(t *testing.T) {
// Store the original values of the maps.
originalDefaultBaseImage := defaultBaseImage
originalPipelineToImage := copyMap(pipelineToImage)
originalLivePipelineToImage := copyMap(livePipelineToImage)

tests := []struct {
name string
inputJSON string
expectedBase string
expectedPipelineImages map[string]string
expectedLiveImages map[string]string
expectError bool
}{
{
name: "ValidPipelineOverrides",
inputJSON: `{"segment-anything-2": "custom-image:1.0", "text-to-speech": "speech-image:2.0"}`,
expectedBase: originalDefaultBaseImage,
expectedPipelineImages: map[string]string{
"segment-anything-2": "custom-image:1.0",
"text-to-speech": "speech-image:2.0",
"audio-to-text": originalPipelineToImage["audio-to-text"],
},
expectedLiveImages: originalLivePipelineToImage,
expectError: false,
},
{
name: "OverrideBaseImage",
inputJSON: "new-base-image:latest",
expectedBase: "new-base-image:latest",
expectedPipelineImages: originalPipelineToImage,
expectedLiveImages: originalLivePipelineToImage,
expectError: false,
},
{
name: "OverrideBaseImageJSON",
inputJSON: `{"base": "new-base-image:latest"}`,
expectedBase: "new-base-image:latest",
expectedPipelineImages: originalPipelineToImage,
expectedLiveImages: originalLivePipelineToImage,
expectError: false,
},
{
name: "EmptyJSON",
inputJSON: `{}`,
expectedBase: originalDefaultBaseImage,
expectedPipelineImages: originalPipelineToImage,
expectedLiveImages: originalLivePipelineToImage,
expectError: false,
},
{
name: "MalformedJSON",
inputJSON: `{"segment-anything-2": "custom-image:1.0"`,
expectedBase: originalDefaultBaseImage,
expectedPipelineImages: originalPipelineToImage,
expectedLiveImages: originalLivePipelineToImage,
expectError: true,
},
{
name: "EmptyString",
inputJSON: "",
expectedBase: originalDefaultBaseImage,
expectedPipelineImages: originalPipelineToImage,
expectedLiveImages: originalLivePipelineToImage,
expectError: true,
},
{
name: "UnknownPipeline",
inputJSON: `{"unknown-pipeline": "unknown-image:latest"}`,
expectedBase: originalDefaultBaseImage,
expectedPipelineImages: originalPipelineToImage,
expectedLiveImages: originalLivePipelineToImage,
expectError: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Register a cleanup function to reset state after the subtest.
t.Cleanup(func() {
defaultBaseImage = originalDefaultBaseImage
pipelineToImage = copyMap(originalPipelineToImage)
livePipelineToImage = copyMap(originalLivePipelineToImage)
})

// Call overridePipelineImages function with the mock data.
err := overridePipelineImages(tt.inputJSON)

if tt.expectError {
require.Error(t, err)
} else {
require.NoError(t, err)
require.Equal(t, tt.expectedBase, defaultBaseImage)

// Verify the expected pipeline images.
for pipeline, expectedImage := range tt.expectedPipelineImages {
require.Equal(t, expectedImage, pipelineToImage[pipeline])
}

// Verify the expected live pipeline images.
for livePipeline, expectedImage := range tt.expectedLiveImages {
require.Equal(t, expectedImage, livePipelineToImage[livePipeline])
}
}
})
}
}

func TestNewDockerManager(t *testing.T) {
mockDockerClient := new(MockDockerClient)

Expand Down
4 changes: 2 additions & 2 deletions worker/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,13 @@ type Worker struct {
mu *sync.Mutex
}

func NewWorker(defaultImage string, gpus []string, modelDir string) (*Worker, error) {
func NewWorker(imageOverrides string, gpus []string, modelDir string) (*Worker, error) {
dockerClient, err := docker.NewClientWithOpts(docker.FromEnv, docker.WithAPIVersionNegotiation())
if err != nil {
return nil, err
}

manager, err := NewDockerManager(defaultImage, gpus, modelDir, dockerClient)
manager, err := NewDockerManager(imageOverrides, gpus, modelDir, dockerClient)
if err != nil {
return nil, err
}
Expand Down
Loading