Skip to content

Commit

Permalink
fix: Bump plugins tags (#500)
Browse files Browse the repository at this point in the history
**Reason for Change**:
Bump the image preset plugins tags
  • Loading branch information
ishaansehgal99 authored Jul 9, 2024
1 parent 814f247 commit d601ae8
Show file tree
Hide file tree
Showing 6 changed files with 97 additions and 50 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/kind-cluster/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def update_model(model_name, model_commit):
# run_command(f"rm -rf {os.path.join(git_files_path, 'lfs')}")
except Exception as e:
print(f"An error occurred: {e}")
exit(1)
finally:
# Change back to the original directory
os.chdir(start_dir)
Expand All @@ -93,6 +94,7 @@ def download_new_model(model_name, model_url):
shutil.move(os.path.join(weights_path, ".git"), git_files_path)
except Exception as e:
print(f"An error occurred: {e}")
exit(1)
finally:
os.chdir(start_dir)

Expand Down
8 changes: 4 additions & 4 deletions presets/models/falcon/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@ var (
PresetFalcon40BInstructModel = PresetFalcon40BModel + "-instruct"

PresetFalconTagMap = map[string]string{
"Falcon7B": "0.0.4",
"Falcon7BInstruct": "0.0.4",
"Falcon40B": "0.0.5",
"Falcon40BInstruct": "0.0.5",
"Falcon7B": "0.0.5",
"Falcon7BInstruct": "0.0.5",
"Falcon40B": "0.0.6",
"Falcon40BInstruct": "0.0.6",
}

baseCommandPresetFalcon = "accelerate launch"
Expand Down
4 changes: 2 additions & 2 deletions presets/models/mistral/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ var (
PresetMistral7BInstructModel = PresetMistral7BModel + "-instruct"

PresetMistralTagMap = map[string]string{
"Mistral7B": "0.0.4",
"Mistral7BInstruct": "0.0.4",
"Mistral7B": "0.0.5",
"Mistral7BInstruct": "0.0.5",
}

baseCommandPresetMistral = "accelerate launch"
Expand Down
2 changes: 1 addition & 1 deletion presets/models/phi2/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ var (
PresetPhi2Model = "phi-2"

PresetPhiTagMap = map[string]string{
"Phi2": "0.0.3",
"Phi2": "0.0.4",
}

baseCommandPresetPhi = "accelerate launch"
Expand Down
95 changes: 53 additions & 42 deletions test/e2e/preset_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,23 +8,18 @@ import (
"log"
"math/rand"
"os"
"path/filepath"
"strconv"
"strings"
"time"

batchv1 "k8s.io/api/batch/v1"
"k8s.io/client-go/kubernetes"
"k8s.io/client-go/rest"
"k8s.io/client-go/tools/clientcmd"

"github.com/aws/karpenter-core/pkg/apis/v1alpha5"
kaitov1alpha1 "github.com/azure/kaito/api/v1alpha1"
"github.com/azure/kaito/test/e2e/utils"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/samber/lo"
appsv1 "k8s.io/api/apps/v1"
batchv1 "k8s.io/api/batch/v1"
v1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
Expand Down Expand Up @@ -453,24 +448,7 @@ func validateTuningResource(workspaceObj *kaitov1alpha1.Workspace) {
}

func validateACRTuningResultsUploaded(workspaceObj *kaitov1alpha1.Workspace, jobName string) {
var config *rest.Config
var err error

if os.Getenv("KUBERNETES_SERVICE_HOST") != "" && os.Getenv("KUBERNETES_SERVICE_PORT") != "" {
config, err = rest.InClusterConfig()
if err != nil {
log.Fatalf("Failed to get in-cluster config: %v", err)
}
} else {
// Use kubeconfig file for local development
kubeconfig := filepath.Join(os.Getenv("HOME"), ".kube", "config")
config, err = clientcmd.BuildConfigFromFlags("", kubeconfig)
if err != nil {
log.Fatalf("Failed to load kubeconfig: %v", err)
}
}

coreClient, err := kubernetes.NewForConfig(config)
coreClient, err := utils.GetK8sConfig()
if err != nil {
log.Fatalf("Failed to create core client: %v", err)
}
Expand Down Expand Up @@ -555,6 +533,31 @@ func deleteWorkspace(workspaceObj *kaitov1alpha1.Workspace) error {
return nil
}

// printPodLogsOnFailure is a best-effort debugging helper: it lists the pods
// in namespace that match labelSelector (an empty selector matches every pod)
// and prints the logs of each entry in pod.Spec.Containers to stdout.
//
// Failures are logged with log.Printf and never abort the caller. If the
// client cannot be created or the pods cannot be listed, the function returns
// without printing any logs; a failure to fetch one container's logs is
// logged and the remaining containers are still processed.
func printPodLogsOnFailure(namespace, labelSelector string) {
	coreClient, err := utils.GetK8sConfig()
	if err != nil {
		log.Printf("Failed to create core client: %v", err)
		// Without a client there is nothing to query; continuing would
		// dereference a nil clientset below.
		return
	}

	pods, err := coreClient.CoreV1().Pods(namespace).List(ctx, metav1.ListOptions{
		LabelSelector: labelSelector,
	})
	if err != nil {
		log.Printf("Failed to list pods: %v", err)
		return
	}

	for _, pod := range pods.Items {
		for _, container := range pod.Spec.Containers {
			logs, err := utils.GetPodLogs(coreClient, namespace, pod.Name, container.Name)
			if err != nil {
				log.Printf("Failed to get logs from pod %s, container %s: %v", pod.Name, container.Name, err)
				continue
			}
			fmt.Printf("Logs from pod %s, container %s:\n%s\n", pod.Name, container.Name, logs)
		}
	}
}

var runLlama13B bool
var aiModelsRegistry string
var aiModelsRegistrySecret string
Expand All @@ -569,25 +572,34 @@ var _ = Describe("Workspace Preset", func() {
loadModelVersions()
})

It("should create a mistral workspace with preset public mode successfully", func() {
numOfNode := 1
workspaceObj := createMistralWorkspaceWithPresetPublicMode(numOfNode)

defer cleanupResources(workspaceObj)
time.Sleep(30 * time.Second)

validateMachineCreation(workspaceObj, numOfNode)
validateResourceStatus(workspaceObj)

time.Sleep(30 * time.Second)

validateAssociatedService(workspaceObj)

validateInferenceResource(workspaceObj, int32(numOfNode), false)

validateWorkspaceReadiness(workspaceObj)
AfterEach(func() {
if CurrentSpecReport().Failed() {
printPodLogsOnFailure(namespaceName, "") // The Preset Pod
printPodLogsOnFailure("kaito-workspace", "") // The Kaito Workspace Pod
printPodLogsOnFailure("gpu-provisioner", "") // The gpu-provisioner Pod
Fail("Fail threshold reached")
}
})

//It("should create a mistral workspace with preset public mode successfully", func() {
// numOfNode := 1
// workspaceObj := createMistralWorkspaceWithPresetPublicMode(numOfNode)
//
// defer cleanupResources(workspaceObj)
// time.Sleep(30 * time.Second)
//
// validateMachineCreation(workspaceObj, numOfNode)
// validateResourceStatus(workspaceObj)
//
// time.Sleep(30 * time.Second)
//
// validateAssociatedService(workspaceObj)
//
// validateInferenceResource(workspaceObj, int32(numOfNode), false)
//
// validateWorkspaceReadiness(workspaceObj)
//})

It("should create a Phi-2 workspace with preset public mode successfully", func() {
numOfNode := 1
workspaceObj := createPhi2WorkspaceWithPresetPublicMode(numOfNode)
Expand Down Expand Up @@ -729,7 +741,6 @@ var _ = Describe("Workspace Preset", func() {

time.Sleep(30 * time.Second)

// TODO: Need to check if tuning job uploaded to ACR
validateTuningResource(workspaceObj)

validateACRTuningResultsUploaded(workspaceObj, jobName)
Expand Down
36 changes: 35 additions & 1 deletion test/e2e/utils/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,12 @@ import (
"fmt"
"io"
"io/ioutil"
"k8s.io/client-go/rest"
"k8s.io/client-go/tools/clientcmd"
"log"
"math/rand"
"os"
"path/filepath"
"strings"
"time"

Expand Down Expand Up @@ -81,8 +85,38 @@ func GetPodNameForJob(coreClient *kubernetes.Clientset, namespace, jobName strin
return podList.Items[0].Name, nil
}

// GetK8sConfig builds a Kubernetes clientset for the e2e tests.
//
// When both KUBERNETES_SERVICE_HOST and KUBERNETES_SERVICE_PORT are set to
// non-empty values, the in-cluster configuration is used. Otherwise the
// kubeconfig at $HOME/.kube/config is loaded (local development).
//
// Every failure path calls log.Fatalf, which terminates the process, so a
// non-nil error is never actually observed by the caller.
func GetK8sConfig() (*kubernetes.Clientset, error) {
	host := os.Getenv("KUBERNETES_SERVICE_HOST")
	port := os.Getenv("KUBERNETES_SERVICE_PORT")

	var restCfg *rest.Config
	if host == "" || port == "" {
		// Use kubeconfig file for local development.
		kubeconfigPath := filepath.Join(os.Getenv("HOME"), ".kube", "config")
		loaded, loadErr := clientcmd.BuildConfigFromFlags("", kubeconfigPath)
		if loadErr != nil {
			log.Fatalf("Failed to load kubeconfig: %v", loadErr)
		}
		restCfg = loaded
	} else {
		inCluster, clusterErr := rest.InClusterConfig()
		if clusterErr != nil {
			log.Fatalf("Failed to get in-cluster config: %v", clusterErr)
		}
		restCfg = inCluster
	}

	clientset, clientErr := kubernetes.NewForConfig(restCfg)
	if clientErr != nil {
		log.Fatalf("Failed to create core client: %v", clientErr)
	}
	return clientset, clientErr
}

func GetPodLogs(coreClient *kubernetes.Clientset, namespace, podName, containerName string) (string, error) {
req := coreClient.CoreV1().Pods(namespace).GetLogs(podName, &v1.PodLogOptions{Container: containerName})
options := &v1.PodLogOptions{}
if containerName != "" {
options.Container = containerName
}

req := coreClient.CoreV1().Pods(namespace).GetLogs(podName, options)
logs, err := req.Stream(context.Background())
if err != nil {
return "", err
Expand Down

0 comments on commit d601ae8

Please sign in to comment.