Skip to content

Commit

Permalink
fix races
Browse files Browse the repository at this point in the history
  • Loading branch information
omrikiei committed Dec 29, 2024
1 parent 619c0b0 commit 9ea51a0
Show file tree
Hide file tree
Showing 10 changed files with 542 additions and 485 deletions.
20 changes: 13 additions & 7 deletions cmd/expose.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,10 @@ ktunnel expose redis 6379
deploymentAnnotations[parsed[0]] = parsed[1]
}

svc, err := k8s.NewKubeService(KubeContext, Namespace)
if err != nil {
log.Fatalf("Failed to create new kube service: %v", err)
}
podTolerations := make([]apiv1.Toleration, 0, len(PodTolerations))
for _, label := range PodTolerations {
parsed := strings.Split(label, "=")
Expand All @@ -110,15 +114,15 @@ ktunnel expose redis 6379
}

if Force {
err := k8s.TeardownExposedService(Namespace, svcName, &KubeContext, DeploymentOnly)
err := svc.TeardownExposedService(svcName, DeploymentOnly)
if err != nil {
log.Infof("Force delete: Failed deleting k8s objects: %s", err)
}
}

err := k8s.ExposeAsService(
&Namespace,
&svcName,
err = svc.ExposeAsService(
Namespace,
svcName,
port,
Scheme,
ports,
Expand All @@ -134,7 +138,7 @@ ktunnel expose redis 6379
CertFile,
KeyFile,
ServiceType,
&KubeContext,
KubeContext,
ServerCPURequest,
ServerCPULimit,
ServerMemRequest,
Expand All @@ -159,7 +163,7 @@ ktunnel expose redis 6379
}
cancel()
if !Reuse {
err := k8s.TeardownExposedService(Namespace, svcName, &KubeContext, DeploymentOnly)
err := svc.TeardownExposedService(svcName, DeploymentOnly)
if err != nil {
log.Errorf("Failed deleting k8s objects: %s", err)
}
Expand All @@ -171,11 +175,13 @@ ktunnel expose redis 6379
log.Info("waiting for deployment to be ready")
<-readyChan

// Kube Service
kubeService, err := k8s.NewKubeService(KubeContext, Namespace)
// port-Forward
strPort := strconv.FormatInt(int64(port), 10)
stopChan := make(chan struct{}, 1)
// Create a tunnel client for each replica
sourcePorts, err := k8s.PortForward(&Namespace, &svcName, strPort, wg, stopChan, &KubeContext)
sourcePorts, err := kubeService.PortForward(Namespace, svcName, strPort, wg, stopChan)
if err != nil {
log.Fatalf("Failed to run port forwarding: %v", err)
os.Exit(1)
Expand Down
11 changes: 8 additions & 3 deletions cmd/inject.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,12 @@ ktunnel inject deployment mydeployment 3306 6379
// Inject
deployment := args[0]
readyChan := make(chan bool, 1)
_, err := k8s.InjectSidecar(&Namespace, &deployment, &port, ServerImage, CertFile, KeyFile, readyChan, &KubeContext)
// Kube Service
svc, err := k8s.NewKubeService(KubeContext, Namespace)
if err != nil {
log.Fatalf("failed creating kube service: %v", err)
}
_, err = svc.InjectSidecar(&Namespace, &deployment, &port, ServerImage, CertFile, KeyFile, readyChan, &KubeContext)
if err != nil {
log.Fatalf("failed injecting sidecar: %v", err)
}
Expand All @@ -67,7 +72,7 @@ ktunnel inject deployment mydeployment 3306 6379
wg.Wait()
if eject {
readyChan = make(chan bool, 1)
ok, err := k8s.RemoveSidecar(&Namespace, &deployment, ServerImage, readyChan, &KubeContext)
ok, err := svc.RemoveSidecar(&Namespace, &deployment, ServerImage, readyChan, &KubeContext)
if !ok {
log.Errorf("Failed removing tunnel sidecar; %v", err)
}
Expand All @@ -90,7 +95,7 @@ ktunnel inject deployment mydeployment 3306 6379
// port-Forward
strPort := strconv.FormatInt(int64(port), 10)
// Create a tunnel client for each replica
sourcePorts, err := k8s.PortForward(&Namespace, &deployment, strPort, wg, stopChan, &KubeContext)
sourcePorts, err := svc.PortForward(Namespace, deployment, strPort, wg, stopChan)
if err != nil {
log.Fatalf("Failed to run port forwarding: %v", err)
os.Exit(1)
Expand Down
797 changes: 403 additions & 394 deletions coverage.txt

Large diffs are not rendered by default.

8 changes: 5 additions & 3 deletions pkg/k8s/cleanup.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (

// ResourceTracker keeps track of resources created by ktunnel for cleanup
type ResourceTracker struct {
clients *Clients
namespace string
deployments []string
services []string
Expand All @@ -23,8 +24,9 @@ type ResourceTracker struct {
}

// NewResourceTracker creates a new ResourceTracker for the given namespace
func NewResourceTracker(namespace string) *ResourceTracker {
func NewResourceTracker(namespace string, clients *Clients) *ResourceTracker {
return &ResourceTracker{
clients: clients,
namespace: namespace,
deployments: make([]string, 0),
services: make([]string, 0),
Expand Down Expand Up @@ -101,7 +103,7 @@ func (rt *ResourceTracker) Cleanup(ctx context.Context) error {
wg.Add(1)
go func(name string) {
defer wg.Done()
err := getDeploymentsClient().Delete(ctx, name, metav1.DeleteOptions{})
err := rt.clients.Deployments.Delete(ctx, name, metav1.DeleteOptions{})
if err != nil {
log.Warnf("Failed to delete deployment %s: %v", name, err)
select {
Expand All @@ -119,7 +121,7 @@ func (rt *ResourceTracker) Cleanup(ctx context.Context) error {
wg.Add(1)
go func(name string) {
defer wg.Done()
err := getServicesClient().Delete(ctx, name, metav1.DeleteOptions{})
err := rt.clients.Services.Delete(ctx, name, metav1.DeleteOptions{})
if err != nil {
log.Warnf("Failed to delete service %s: %v", name, err)
select {
Expand Down
31 changes: 28 additions & 3 deletions pkg/k8s/cleanup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,17 @@ import (
)

func TestResourceTracker_AddAndRemove(t *testing.T) {
rt := NewResourceTracker("test-namespace")
namespace := "test-namespace"
fakeClient := testclient.NewSimpleClientset()
// Initialize the clients
deploymentsClient = fakeClient.AppsV1().Deployments(namespace)
svcClient = fakeClient.CoreV1().Services(namespace)

clients := &Clients{
Deployments: deploymentsClient,
Services: svcClient,
}
rt := NewResourceTracker("test-namespace", clients)

// Test adding resources
rt.AddDeployment("test-deployment-1")
Expand Down Expand Up @@ -51,8 +61,13 @@ func TestResourceTracker_Cleanup(t *testing.T) {
deploymentsClient = fakeClient.AppsV1().Deployments(namespace)
svcClient = fakeClient.CoreV1().Services(namespace)

clients := &Clients{
Deployments: deploymentsClient,
Services: svcClient,
}

// Create a resource tracker
rt := NewResourceTracker(namespace)
rt := NewResourceTracker(namespace, clients)
rt.SetTimeout(5 * time.Second)

// Create test resources in the fake client
Expand Down Expand Up @@ -88,7 +103,17 @@ func TestResourceTracker_Cleanup(t *testing.T) {
}

func TestResourceTracker_CleanupTimeout(t *testing.T) {
rt := NewResourceTracker("test-namespace")
namespace := "test-namespace"
fakeClient := testclient.NewSimpleClientset()
// Initialize the clients
deploymentsClient = fakeClient.AppsV1().Deployments(namespace)
svcClient = fakeClient.CoreV1().Services(namespace)

clients := &Clients{
Deployments: deploymentsClient,
Services: svcClient,
}
rt := NewResourceTracker("test-namespace", clients)
rt.SetTimeout(1 * time.Millisecond) // Very short timeout

// Add some resources
Expand Down
14 changes: 14 additions & 0 deletions pkg/k8s/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@ var (
svcClient v12.ServiceInterface
)

// Clients bundles the namespaced Kubernetes API clients that ktunnel
// operates on (deployments, pods, and services), so callers can pass a
// single handle instead of reading the package-level client variables.
type Clients struct {
	Deployments v1.DeploymentInterface
	Pods v12.PodInterface
	Services v12.ServiceInterface
}

func getDeploymentsClient() v1.DeploymentInterface {
clientMutex.RLock()
defer clientMutex.RUnlock()
Expand All @@ -27,6 +33,14 @@ func getServicesClient() v12.ServiceInterface {
return svcClient
}

// NewClients wraps the given deployment, pod, and service interfaces in a
// single Clients value.
func NewClients(deployments v1.DeploymentInterface, pods v12.PodInterface, services v12.ServiceInterface) *Clients {
	c := Clients{
		Services:    services,
		Pods:        pods,
		Deployments: deployments,
	}
	return &c
}

func setClients(deployments v1.DeploymentInterface, pods v12.PodInterface, services v12.ServiceInterface) {
clientMutex.Lock()
defer clientMutex.Unlock()
Expand Down
70 changes: 48 additions & 22 deletions pkg/k8s/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,38 @@ const (

type ByCreationTime []apiv1.Pod

// KubeService carries the resolved kube REST config together with the
// namespaced API clients, so methods on it no longer need to thread a
// kube context / namespace pair (or package globals) through every call.
type KubeService struct {
	clients *Clients
	config *rest.Config
}

// NewKubeService builds a KubeService for the given kube context and
// namespace, resolving the kube config and the namespaced API clients up
// front. The error return is part of the signature for callers, but the
// current implementation always returns a nil error.
func NewKubeService(kubeCtx, namespace string) (*KubeService, error) {
	restConfig := GetKubeConfig(kubeCtx)
	svc := &KubeService{
		config:  restConfig,
		clients: GetClients(restConfig, namespace),
	}
	return svc, nil
}

// GetClients builds namespaced Deployments/Pods/Services clients from the
// given rest config and returns them as a Clients bundle. It also refreshes
// the package-level client variables so legacy call sites that still read
// them observe the same clients.
//
// NOTE(review): on failure to construct the clientset this logs and exits
// the process, mirroring getClients in this file; returning an error would
// require a signature change for existing callers.
func GetClients(cfg *rest.Config, namespace string) *Clients {
	clientSet, err := kubernetes.NewForConfig(cfg)
	if err != nil {
		log.Errorf("Failed to get k8s client: %v", err)
		os.Exit(1)
	}

	deployments := clientSet.AppsV1().Deployments(namespace)
	pods := clientSet.CoreV1().Pods(namespace)
	services := clientSet.CoreV1().Services(namespace)

	// The package-level globals are read under clientMutex elsewhere in this
	// file (getDeploymentsClient/getServicesClient) and written under it in
	// getClients/setClients; writing them unsynchronized here would be a data
	// race, so take the same lock.
	clientMutex.Lock()
	deploymentsClient = deployments
	podsClient = pods
	svcClient = services
	clientMutex.Unlock()

	return &Clients{
		Deployments: deployments,
		Pods:        pods,
		Services:    services,
	}
}

// Len returns the number of pods in the slice (sort.Interface).
func (a ByCreationTime) Len() int { return len(a) }
func (a ByCreationTime) Less(i, j int) bool {
return a[i].CreationTimestamp.After(a[j].CreationTimestamp.Time)
Expand Down Expand Up @@ -68,7 +100,7 @@ func IsVerbose() bool {
return verbose
}

func getKubeConfig(kubeCtx *string) *rest.Config {
func GetKubeConfig(kubeCtx string) *rest.Config {
configMutex.RLock()
if kubeconfig != nil {
defer configMutex.RUnlock()
Expand All @@ -93,8 +125,8 @@ func getKubeConfig(kubeCtx *string) *rest.Config {
}

var configOverrides *clientcmd.ConfigOverrides
if (kubeCtx) != nil {
configOverrides = &clientcmd.ConfigOverrides{CurrentContext: *kubeCtx}
if (kubeCtx) != "" {
configOverrides = &clientcmd.ConfigOverrides{CurrentContext: kubeCtx}
}

config, err := clientcmd.NewNonInteractiveDeferredLoadingClientConfig(loadingRules, configOverrides).ClientConfig()
Expand All @@ -105,7 +137,7 @@ func getKubeConfig(kubeCtx *string) *rest.Config {
return kubeconfig
}

func getClients(namespace *string, kubeCtx *string) {
func getClients(namespace *string, kubeCtx string) {
clientMutex.Lock()
defer clientMutex.Unlock()

Expand All @@ -114,7 +146,7 @@ func getClients(namespace *string, kubeCtx *string) {
return
}

clientSet, err := kubernetes.NewForConfig(getKubeConfig(kubeCtx))
clientSet, err := kubernetes.NewForConfig(GetKubeConfig(kubeCtx))
if err != nil {
log.Errorf("Failed to get k8s client: %v", err)
os.Exit(1)
Expand All @@ -125,13 +157,10 @@ func getClients(namespace *string, kubeCtx *string) {
svcClient = clientSet.CoreV1().Services(*namespace)
}

func getPodsFilteredByLabel(namespace, kubeCtx, labelSelector *string) (*apiv1.PodList, error) {
getClients(namespace, kubeCtx)
clientMutex.RLock()
defer clientMutex.RUnlock()
pods, err := podsClient.List(
func (k *KubeService) getPodsFilteredByLabel(labelSelector string) (*apiv1.PodList, error) {
pods, err := k.clients.Pods.List(
context.Background(), metav1.ListOptions{
LabelSelector: *labelSelector,
LabelSelector: labelSelector,
},
)
if err != nil {
Expand Down Expand Up @@ -260,13 +289,12 @@ func newService(namespace, name string, ports []apiv1.ServicePort, serviceType a
}
}

func getPodNames(namespace, deploymentName *string, podsPtr *[]string, kubeCtx *string) error {
labelSelector := deploymentNameLabel + "=" + *deploymentName + "," + deploymentInstanceLabel + "=" + *deploymentName
filteredPods, err := getPodsFilteredByLabel(namespace, kubeCtx, &labelSelector)
func (k *KubeService) getPodNames(deploymentName string, pods []string) error {
labelSelector := deploymentNameLabel + "=" + deploymentName + "," + deploymentInstanceLabel + "=" + deploymentName
filteredPods, err := k.getPodsFilteredByLabel(labelSelector)
if err != nil {
return err
}
pods := *podsPtr
matchingPods := ByCreationTime{}
pIndex := 0
for _, p := range filteredPods.Items {
Expand All @@ -286,18 +314,16 @@ func getPodNames(namespace, deploymentName *string, podsPtr *[]string, kubeCtx *
return nil
}

func PortForward(namespace, deploymentName *string, targetPort string, fwdWaitGroup *sync.WaitGroup, stopChan <-chan struct{}, kubecontext *string) (*[]string, error) {
getClients(namespace, kubecontext)

func (k *KubeService) PortForward(namespace, deploymentName string, targetPort string, fwdWaitGroup *sync.WaitGroup, stopChan <-chan struct{}) (*[]string, error) {
clientMutex.RLock()
deployment, err := deploymentsClient.Get(context.Background(), *deploymentName, metav1.GetOptions{})
deployment, err := deploymentsClient.Get(context.Background(), deploymentName, metav1.GetOptions{})
clientMutex.RUnlock()
if err != nil {
return nil, err
}

podNames := make([]string, *deployment.Spec.Replicas)
err = getPodNames(namespace, deploymentName, &podNames, kubecontext)
err = k.getPodNames(deploymentName, podNames)
fwdWaitGroup.Add(int(*deployment.Spec.Replicas))

if err != nil {
Expand All @@ -317,9 +343,9 @@ func PortForward(namespace, deploymentName *string, targetPort string, fwdWaitGr
for i, podName := range podNames {
readyChan := make(chan struct{}, 1)
ports := []string{fmt.Sprintf("%s:%s", sourcePorts[i], targetPort)}
serverURL := getPortForwardURL(getKubeConfig(kubecontext), *namespace, podName)
serverURL := getPortForwardURL(k.config, namespace, podName)

transport, upgrader, err := spdy.RoundTripperFor(getKubeConfig(kubecontext))
transport, upgrader, err := spdy.RoundTripperFor(k.config)
if err != nil {
return nil, err
}
Expand Down
Loading

0 comments on commit 9ea51a0

Please sign in to comment.