diff --git a/worker/docker.go b/worker/docker.go index 9e64a4af..263e18d9 100644 --- a/worker/docker.go +++ b/worker/docker.go @@ -56,14 +56,14 @@ var containerHostPorts = map[string]string{ "live-video-to-video": "8900", } -// Mapping for per pipeline container images. +// Default pipeline container image mapping to use if no overrides are provided. +var defaultBaseImage = "livepeer/ai-runner:latest" var pipelineToImage = map[string]string{ "segment-anything-2": "livepeer/ai-runner:segment-anything-2", "text-to-speech": "livepeer/ai-runner:text-to-speech", "audio-to-text": "livepeer/ai-runner:audio-to-text", "llm": "livepeer/ai-runner:llm", } - var livePipelineToImage = map[string]string{ "streamdiffusion": "livepeer/ai-runner:live-app-streamdiffusion", "comfyui": "livepeer/ai-runner:live-app-comfyui", @@ -71,6 +71,12 @@ var livePipelineToImage = map[string]string{ "noop": "livepeer/ai-runner:live-app-noop", } +type ImageOverrides struct { + Default string `json:"default"` + Batch map[string]string `json:"batch"` + Live map[string]string `json:"live"` +} + // DockerClient is an interface for the Docker client, allowing for mocking in tests. // NOTE: ensure any docker.Client methods used in this package are added. type DockerClient interface { @@ -91,9 +97,9 @@ var _ DockerClient = (*docker.Client)(nil) var dockerWaitUntilRunningFunc = dockerWaitUntilRunning type DockerManager struct { - defaultImage string - gpus []string - modelDir string + gpus []string + modelDir string + overrides ImageOverrides dockerClient DockerClient // gpu ID => container name @@ -103,7 +109,7 @@ type DockerManager struct { mu *sync.Mutex } -func NewDockerManager(defaultImage string, gpus []string, modelDir string, client DockerClient) (*DockerManager, error) { +func NewDockerManager(overrides ImageOverrides, gpus []string, modelDir string, client DockerClient) (*DockerManager, error) { ctx, cancel := context.WithTimeout(context.Background(), containerTimeout) if err := removeExistingContainers(ctx, client); err != nil { cancel() @@ -112,9 +118,9 @@ func NewDockerManager(defaultImage string, gpus []string, modelDir string, clien cancel() manager := &DockerManager{ - defaultImage: defaultImage, gpus: gpus, modelDir: modelDir, + overrides: overrides, dockerClient: client, gpuContainers: make(map[string]string), containers: make(map[string]*RunnerContainer), @@ -215,17 +221,24 @@ func (m *DockerManager) returnContainer(rc *RunnerContainer) { func (m *DockerManager) getContainerImageName(pipeline, modelID string) (string, error) { if pipeline == "live-video-to-video" { // We currently use the model ID as the live pipeline name for legacy reasons. - if image, ok := livePipelineToImage[modelID]; ok { + if image, ok := m.overrides.Live[modelID]; ok { + return image, nil + } else if image, ok := livePipelineToImage[modelID]; ok { return image, nil } return "", fmt.Errorf("no container image found for live pipeline %s", modelID) } - if image, ok := pipelineToImage[pipeline]; ok { + if image, ok := m.overrides.Batch[pipeline]; ok { + return image, nil + } else if image, ok := pipelineToImage[pipeline]; ok { return image, nil } - return m.defaultImage, nil + if m.overrides.Default != "" { + return m.overrides.Default, nil + } + return defaultBaseImage, nil } // HasCapacity checks if an unused managed container exists or if a GPU is available for a new container. diff --git a/worker/docker_test.go b/worker/docker_test.go index f615ecb3..8698e8f5 100644 --- a/worker/docker_test.go +++ b/worker/docker_test.go @@ -96,9 +96,9 @@ func NewMockServer() *MockServer { // createDockerManager creates a DockerManager with a mock DockerClient. func createDockerManager(mockDockerClient *MockDockerClient) *DockerManager { return &DockerManager{ - defaultImage: "default-image", gpus: []string{"gpu0"}, modelDir: "/models", + overrides: ImageOverrides{Default: "default-image"}, dockerClient: mockDockerClient, gpuContainers: make(map[string]string), containers: make(map[string]*RunnerContainer), @@ -110,10 +110,10 @@ func TestNewDockerManager(t *testing.T) { mockDockerClient := new(MockDockerClient) createAndVerifyManager := func() *DockerManager { - manager, err := NewDockerManager("default-image", []string{"gpu0"}, "/models", mockDockerClient) + manager, err := NewDockerManager(ImageOverrides{Default: "default-image"}, []string{"gpu0"}, "/models", mockDockerClient) require.NoError(t, err) require.NotNil(t, manager) - require.Equal(t, "default-image", manager.defaultImage) + require.Equal(t, "default-image", manager.overrides.Default) require.Equal(t, []string{"gpu0"}, manager.gpus) require.Equal(t, "/models", manager.modelDir) require.Equal(t, mockDockerClient, manager.dockerClient) @@ -301,10 +301,11 @@ func TestDockerManager_returnContainer(t *testing.T) { func TestDockerManager_getContainerImageName(t *testing.T) { mockDockerClient := new(MockDockerClient) - manager := createDockerManager(mockDockerClient) + dockerManager := createDockerManager(mockDockerClient) tests := []struct { name string + setup func(*DockerManager, *MockDockerClient) pipeline string modelID string expectedImage string @@ -312,6 +313,7 @@ func TestDockerManager_getContainerImageName(t *testing.T) { }{ { name: "live-video-to-video with valid modelID", + setup: func(dockerManager *DockerManager, mockDockerClient *MockDockerClient) {}, pipeline: "live-video-to-video", modelID: "streamdiffusion", expectedImage: "livepeer/ai-runner:live-app-streamdiffusion", @@ -319,12 +321,14 @@ func TestDockerManager_getContainerImageName(t *testing.T) { }, { name: "live-video-to-video with invalid modelID", + setup: func(dockerManager *DockerManager, mockDockerClient *MockDockerClient) {}, pipeline: "live-video-to-video", modelID: "invalid-model", expectError: true, }, { name: "valid pipeline", + setup: func(dockerManager *DockerManager, mockDockerClient *MockDockerClient) {}, pipeline: "text-to-speech", modelID: "", expectedImage: "livepeer/ai-runner:text-to-speech", @@ -332,16 +336,95 @@ func TestDockerManager_getContainerImageName(t *testing.T) { }, { name: "invalid pipeline", + setup: func(dockerManager *DockerManager, mockDockerClient *MockDockerClient) {}, pipeline: "invalid-pipeline", modelID: "", expectedImage: "default-image", expectError: false, }, + { + name: "override default image", + setup: func(dockerManager *DockerManager, mockDockerClient *MockDockerClient) { + dockerManager.overrides = ImageOverrides{ + Default: "custom-image", + } + }, + pipeline: "", + modelID: "", + expectedImage: "custom-image", + expectError: false, + }, + { + name: "override batch image", + setup: func(dockerManager *DockerManager, mockDockerClient *MockDockerClient) { + dockerManager.overrides = ImageOverrides{ + Batch: map[string]string{ + "text-to-speech": "custom-image", + }, + } + }, + pipeline: "text-to-speech", + modelID: "", + expectedImage: "custom-image", + expectError: false, + }, + { + name: "override live image", + setup: func(dockerManager *DockerManager, mockDockerClient *MockDockerClient) { + dockerManager.overrides = ImageOverrides{ + Live: map[string]string{ + "streamdiffusion": "custom-image", + }, + } + }, + pipeline: "live-video-to-video", + modelID: "streamdiffusion", + expectedImage: "custom-image", + expectError: false, + }, + { + name: "non-overridden batch image", + setup: func(dockerManager *DockerManager, mockDockerClient *MockDockerClient) { + dockerManager.overrides = ImageOverrides{ + Default: "default-image", + Batch: map[string]string{ + "text-to-speech": "custom-batch-image", + }, + Live: map[string]string{ + "streamdiffusion": "custom-live-image", + }, + } + }, + pipeline: "audio-to-text", + modelID: "", + expectedImage: "livepeer/ai-runner:audio-to-text", + expectError: false, + }, + { + name: "non-overridden live image", + setup: func(dockerManager *DockerManager, mockDockerClient *MockDockerClient) { + dockerManager.overrides = ImageOverrides{ + Default: "default-image", + Batch: map[string]string{ + "text-to-speech": "custom-batch-image", + }, + Live: map[string]string{ + "streamdiffusion": "custom-live-image", + }, + } + }, + pipeline: "live-video-to-video", + modelID: "comfyui", + expectedImage: "livepeer/ai-runner:live-app-comfyui", + expectError: false, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - image, err := manager.getContainerImageName(tt.pipeline, tt.modelID) + tt.setup(dockerManager, mockDockerClient) + + image, err := dockerManager.getContainerImageName(tt.pipeline, tt.modelID) if tt.expectError { require.Error(t, err) require.Equal(t, fmt.Sprintf("no container image found for live pipeline %s", tt.modelID), err.Error()) @@ -500,7 +583,7 @@ func TestDockerManager_createContainer(t *testing.T) { dockerManager.gpus = []string{gpu} dockerManager.gpuContainers = make(map[string]string) dockerManager.containers = make(map[string]*RunnerContainer) - dockerManager.defaultImage = containerImage + dockerManager.overrides.Default = containerImage mockDockerClient.On("ContainerCreate", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(container.CreateResponse{ID: containerID}, nil) mockDockerClient.On("ContainerStart", mock.Anything, containerID, mock.Anything).Return(nil) diff --git a/worker/worker.go b/worker/worker.go index 726485a2..aefc4934 100644 --- a/worker/worker.go +++ b/worker/worker.go @@ -51,13 +51,13 @@ type Worker struct { mu *sync.Mutex } -func NewWorker(defaultImage string, gpus []string, modelDir string) (*Worker, error) { +func NewWorker(imageOverrides ImageOverrides, gpus []string, modelDir string) (*Worker, error) { dockerClient, err := docker.NewClientWithOpts(docker.FromEnv, docker.WithAPIVersionNegotiation()) if err != nil { return nil, err } - manager, err := NewDockerManager(defaultImage, gpus, modelDir, dockerClient) + manager, err := NewDockerManager(imageOverrides, gpus, modelDir, dockerClient) if err != nil { return nil, err }