Skip to content

Commit

Permalink
feat: Update Preset Model Tags for Release (#578)
Browse files Browse the repository at this point in the history
**Reason for Change**:
Update the Preset Image tags used by Kaito Controller
  • Loading branch information
ishaansehgal99 authored Aug 23, 2024
1 parent 2e967ca commit 5c30038
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 34 deletions.
23 changes: 12 additions & 11 deletions presets/models/falcon/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,14 @@ var (
PresetFalcon40BInstructModel = PresetFalcon40BModel + "-instruct"

PresetFalconTagMap = map[string]string{
"Falcon7B": "0.0.5",
"Falcon7BInstruct": "0.0.5",
"Falcon40B": "0.0.6",
"Falcon40BInstruct": "0.0.6",
"Falcon7B": "0.0.6",
"Falcon7BInstruct": "0.0.6",
"Falcon40B": "0.0.7",
"Falcon40BInstruct": "0.0.7",
}

baseCommandPresetFalcon = "python3 metrics_server.py & accelerate launch"
baseCommandPresetFalconInference = "accelerate launch"
baseCommandPresetFalconTuning = "python3 metrics_server.py & accelerate launch"
falconRunParams = map[string]string{
"torch_dtype": "bfloat16",
"pipeline": "text-generation",
Expand All @@ -66,7 +67,7 @@ func (*falcon7b) GetInferenceParameters() *model.PresetParam {
TorchRunParams: inference.DefaultAccelerateParams,
ModelRunParams: falconRunParams,
ReadinessTimeout: time.Duration(30) * time.Minute,
BaseCommand: baseCommandPresetFalcon,
BaseCommand: baseCommandPresetFalconInference,
Tag: PresetFalconTagMap["Falcon7B"],
}
}
Expand All @@ -81,7 +82,7 @@ func (*falcon7b) GetTuningParameters() *model.PresetParam {
TorchRunParams: tuning.DefaultAccelerateParams,
//ModelRunPrams: falconRunTuningParams, // TODO
ReadinessTimeout: time.Duration(30) * time.Minute,
BaseCommand: baseCommandPresetFalcon,
BaseCommand: baseCommandPresetFalconTuning,
Tag: PresetFalconTagMap["Falcon7B"],
TuningPerGPUMemoryRequirement: map[string]int{"qlora": 16},
}
Expand Down Expand Up @@ -109,7 +110,7 @@ func (*falcon7bInst) GetInferenceParameters() *model.PresetParam {
TorchRunParams: inference.DefaultAccelerateParams,
ModelRunParams: falconRunParams,
ReadinessTimeout: time.Duration(30) * time.Minute,
BaseCommand: baseCommandPresetFalcon,
BaseCommand: baseCommandPresetFalconInference,
Tag: PresetFalconTagMap["Falcon7BInstruct"],
}

Expand Down Expand Up @@ -139,7 +140,7 @@ func (*falcon40b) GetInferenceParameters() *model.PresetParam {
TorchRunParams: inference.DefaultAccelerateParams,
ModelRunParams: falconRunParams,
ReadinessTimeout: time.Duration(30) * time.Minute,
BaseCommand: baseCommandPresetFalcon,
BaseCommand: baseCommandPresetFalconInference,
Tag: PresetFalconTagMap["Falcon40B"],
}

Expand All @@ -155,7 +156,7 @@ func (*falcon40b) GetTuningParameters() *model.PresetParam {
TorchRunParams: tuning.DefaultAccelerateParams,
//ModelRunPrams: falconRunTuningParams, // TODO
ReadinessTimeout: time.Duration(30) * time.Minute,
BaseCommand: baseCommandPresetFalcon,
BaseCommand: baseCommandPresetFalconTuning,
Tag: PresetFalconTagMap["Falcon40B"],
}
}
Expand All @@ -181,7 +182,7 @@ func (*falcon40bInst) GetInferenceParameters() *model.PresetParam {
TorchRunParams: inference.DefaultAccelerateParams,
ModelRunParams: falconRunParams,
ReadinessTimeout: time.Duration(30) * time.Minute,
BaseCommand: baseCommandPresetFalcon,
BaseCommand: baseCommandPresetFalconInference,
Tag: PresetFalconTagMap["Falcon40BInstruct"],
}
}
Expand Down
13 changes: 7 additions & 6 deletions presets/models/mistral/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,12 @@ var (
PresetMistral7BInstructModel = PresetMistral7BModel + "-instruct"

PresetMistralTagMap = map[string]string{
"Mistral7B": "0.0.6",
"Mistral7BInstruct": "0.0.6",
"Mistral7B": "0.0.7",
"Mistral7BInstruct": "0.0.7",
}

baseCommandPresetMistral = "python3 metrics_server.py & accelerate launch"
baseCommandPresetMistralInference = "accelerate launch"
baseCommandPresetMistralTuning = "python3 metrics_server.py & accelerate launch"
mistralRunParams = map[string]string{
"torch_dtype": "bfloat16",
"pipeline": "text-generation",
Expand All @@ -53,7 +54,7 @@ func (*mistral7b) GetInferenceParameters() *model.PresetParam {
TorchRunParams: inference.DefaultAccelerateParams,
ModelRunParams: mistralRunParams,
ReadinessTimeout: time.Duration(30) * time.Minute,
BaseCommand: baseCommandPresetMistral,
BaseCommand: baseCommandPresetMistralInference,
Tag: PresetMistralTagMap["Mistral7B"],
}

Expand All @@ -69,7 +70,7 @@ func (*mistral7b) GetTuningParameters() *model.PresetParam {
//TorchRunParams: tuning.DefaultAccelerateParams,
//ModelRunParams: mistralRunParams,
ReadinessTimeout: time.Duration(30) * time.Minute,
BaseCommand: baseCommandPresetMistral,
BaseCommand: baseCommandPresetMistralTuning,
Tag: PresetMistralTagMap["Mistral7B"],
}
}
Expand All @@ -96,7 +97,7 @@ func (*mistral7bInst) GetInferenceParameters() *model.PresetParam {
TorchRunParams: inference.DefaultAccelerateParams,
ModelRunParams: mistralRunParams,
ReadinessTimeout: time.Duration(30) * time.Minute,
BaseCommand: baseCommandPresetMistral,
BaseCommand: baseCommandPresetMistralInference,
Tag: PresetMistralTagMap["Mistral7BInstruct"],
}

Expand Down
9 changes: 5 additions & 4 deletions presets/models/phi2/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,11 @@ var (
PresetPhi2Model = "phi-2"

PresetPhiTagMap = map[string]string{
"Phi2": "0.0.4",
"Phi2": "0.0.5",
}

baseCommandPresetPhi = "python3 metrics_server.py & accelerate launch"
baseCommandPresetPhiInference = "accelerate launch"
baseCommandPresetPhiTuning = "python3 metrics_server.py & accelerate launch"
phiRunParams = map[string]string{
"torch_dtype": "float16",
"pipeline": "text-generation",
Expand All @@ -47,7 +48,7 @@ func (*phi2) GetInferenceParameters() *model.PresetParam {
TorchRunParams: inference.DefaultAccelerateParams,
ModelRunParams: phiRunParams,
ReadinessTimeout: time.Duration(30) * time.Minute,
BaseCommand: baseCommandPresetPhi,
BaseCommand: baseCommandPresetPhiInference,
Tag: PresetPhiTagMap["Phi2"],
}
}
Expand All @@ -62,7 +63,7 @@ func (*phi2) GetTuningParameters() *model.PresetParam {
// TorchRunParams: inference.DefaultAccelerateParams,
// ModelRunParams: phiRunParams,
ReadinessTimeout: time.Duration(30) * time.Minute,
BaseCommand: baseCommandPresetPhi,
BaseCommand: baseCommandPresetPhiTuning,
Tag: PresetPhiTagMap["Phi2"],
}
}
Expand Down
27 changes: 14 additions & 13 deletions presets/models/phi3/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,14 @@ var (
PresetPhi3Medium128kModel = "phi-3-medium-128k-instruct"

PresetPhiTagMap = map[string]string{
"Phi3Mini4kInstruct": "0.0.1",
"Phi3Mini128kInstruct": "0.0.1",
"Phi3Medium4kInstruct": "0.0.1",
"Phi3Medium128kInstruct": "0.0.1",
"Phi3Mini4kInstruct": "0.0.2",
"Phi3Mini128kInstruct": "0.0.2",
"Phi3Medium4kInstruct": "0.0.2",
"Phi3Medium128kInstruct": "0.0.2",
}

baseCommandPresetPhi = "python3 metrics_server.py & accelerate launch"
baseCommandPresetPhiInference = "accelerate launch"
baseCommandPresetPhiTuning = "python3 metrics_server.py & accelerate launch"
phiRunParams = map[string]string{
"torch_dtype": "auto",
"pipeline": "text-generation",
Expand All @@ -66,7 +67,7 @@ func (*phi3Mini4KInst) GetInferenceParameters() *model.PresetParam {
TorchRunParams: inference.DefaultAccelerateParams,
ModelRunParams: phiRunParams,
ReadinessTimeout: time.Duration(30) * time.Minute,
BaseCommand: baseCommandPresetPhi,
BaseCommand: baseCommandPresetPhiInference,
Tag: PresetPhiTagMap["Phi3Mini4kInstruct"],
}
}
Expand All @@ -81,7 +82,7 @@ func (*phi3Mini4KInst) GetTuningParameters() *model.PresetParam {
// TorchRunParams: inference.DefaultAccelerateParams,
// ModelRunParams: phiRunParams,
ReadinessTimeout: time.Duration(30) * time.Minute,
BaseCommand: baseCommandPresetPhi,
BaseCommand: baseCommandPresetPhiTuning,
Tag: PresetPhiTagMap["Phi3Mini4kInstruct"],
}
}
Expand All @@ -105,7 +106,7 @@ func (*phi3Mini128KInst) GetInferenceParameters() *model.PresetParam {
TorchRunParams: inference.DefaultAccelerateParams,
ModelRunParams: phiRunParams,
ReadinessTimeout: time.Duration(30) * time.Minute,
BaseCommand: baseCommandPresetPhi,
BaseCommand: baseCommandPresetPhiInference,
Tag: PresetPhiTagMap["Phi3Mini128kInstruct"],
}
}
Expand All @@ -120,7 +121,7 @@ func (*phi3Mini128KInst) GetTuningParameters() *model.PresetParam {
// TorchRunParams: inference.DefaultAccelerateParams,
// ModelRunParams: phiRunParams,
ReadinessTimeout: time.Duration(30) * time.Minute,
BaseCommand: baseCommandPresetPhi,
BaseCommand: baseCommandPresetPhiTuning,
Tag: PresetPhiTagMap["Phi3Mini128kInstruct"],
}
}
Expand All @@ -144,7 +145,7 @@ func (*Phi3Medium4kInstruct) GetInferenceParameters() *model.PresetParam {
TorchRunParams: inference.DefaultAccelerateParams,
ModelRunParams: phiRunParams,
ReadinessTimeout: time.Duration(30) * time.Minute,
BaseCommand: baseCommandPresetPhi,
BaseCommand: baseCommandPresetPhiInference,
Tag: PresetPhiTagMap["Phi3Medium4kInstruct"],
}
}
Expand All @@ -159,7 +160,7 @@ func (*Phi3Medium4kInstruct) GetTuningParameters() *model.PresetParam {
// TorchRunParams: inference.DefaultAccelerateParams,
// ModelRunParams: phiRunParams,
ReadinessTimeout: time.Duration(30) * time.Minute,
BaseCommand: baseCommandPresetPhi,
BaseCommand: baseCommandPresetPhiTuning,
Tag: PresetPhiTagMap["Phi3Medium4kInstruct"],
}
}
Expand All @@ -183,7 +184,7 @@ func (*Phi3Medium128kInstruct) GetInferenceParameters() *model.PresetParam {
TorchRunParams: inference.DefaultAccelerateParams,
ModelRunParams: phiRunParams,
ReadinessTimeout: time.Duration(30) * time.Minute,
BaseCommand: baseCommandPresetPhi,
BaseCommand: baseCommandPresetPhiInference,
Tag: PresetPhiTagMap["Phi3Medium128kInstruct"],
}
}
Expand All @@ -198,7 +199,7 @@ func (*Phi3Medium128kInstruct) GetTuningParameters() *model.PresetParam {
// TorchRunParams: inference.DefaultAccelerateParams,
// ModelRunParams: phiRunParams,
ReadinessTimeout: time.Duration(30) * time.Minute,
BaseCommand: baseCommandPresetPhi,
BaseCommand: baseCommandPresetPhiTuning,
Tag: PresetPhiTagMap["Phi3Medium128kInstruct"],
}
}
Expand Down

0 comments on commit 5c30038

Please sign in to comment.