From 0b56d1ec09ddd56e61a18cf79990797b6cbf77f5 Mon Sep 17 00:00:00 2001
From: Mostafa Moradian <mostafa@gatewayd.io>
Date: Mon, 7 Oct 2024 21:24:24 +0200
Subject: [PATCH 1/3] Check for nil or empty result

---
 act/registry.go | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/act/registry.go b/act/registry.go
index 68fe97ab..e3a0b9c6 100644
--- a/act/registry.go
+++ b/act/registry.go
@@ -407,6 +407,10 @@ func runActionWithTimeout(
 
 // RunAll run all the actions in the outputs and returns the end result.
 func (r *Registry) RunAll(result map[string]any) map[string]any {
+	if result == nil || len(result) == 0 {
+		return result
+	}
+
 	if _, exists := result[sdkAct.Outputs]; !exists {
 		r.Logger.Debug().Msg("Outputs key is not present, returning the result as-is")
 		return result

From 2e498e91b434d7a0c3ad0ef3518f10b20f338633 Mon Sep 17 00:00:00 2001
From: Mostafa Moradian <mostafa@gatewayd.io>
Date: Mon, 7 Oct 2024 21:25:09 +0200
Subject: [PATCH 2/3] Use the result of notification hooks by running actions

---
 cmd/run.go        | 33 +++++++++++++++++++++++++++------
 network/proxy.go  | 16 +++++++++++++---
 network/server.go | 47 ++++++++++++++++++++++++++++++++++++++---------
 3 files changed, 78 insertions(+), 18 deletions(-)

diff --git a/cmd/run.go b/cmd/run.go
index 98a96061..bf992467 100644
--- a/cmd/run.go
+++ b/cmd/run.go
@@ -109,7 +109,7 @@ func StopGracefully(
 		defer cancel()
 
 		//nolint:contextcheck
-		_, err := pluginRegistry.Run(
+		result, err := pluginRegistry.Run(
 			pluginTimeoutCtx,
 			map[string]any{"signal": currentSignal},
 			v1.HookName_HOOK_NAME_ON_SIGNAL,
@@ -118,6 +118,9 @@ func StopGracefully(
 			logger.Error().Err(err).Msg("Failed to run OnSignal hooks")
 			span.RecordError(err)
 		}
+		if result != nil {
+			_ = pluginRegistry.ActRegistry.RunAll(result)
+		}
 	}
 
 	logger.Info().Msg("GatewayD is shutting down")
@@ -434,6 +437,9 @@ var runCmd = &cobra.Command{
 			logger.Error().Err(err).Msg("Failed to run OnConfigLoaded hooks")
 			span.RecordError(err)
 		}
+		if updatedGlobalConfig != nil {
+			updatedGlobalConfig = pluginRegistry.ActRegistry.RunAll(updatedGlobalConfig)
+		}
 
 		// If the config was modified by the plugins, merge it with the one loaded from the file.
 		// Only global configuration is merged, which means that plugins cannot modify the plugin
@@ -606,12 +612,15 @@ var runCmd = &cobra.Command{
 		defer cancel()
 
 		if data, ok := conf.GlobalKoanf.Get("loggers").(map[string]any); ok {
-			_, err = pluginRegistry.Run(
+			result, err := pluginRegistry.Run(
 				pluginTimeoutCtx, data, v1.HookName_HOOK_NAME_ON_NEW_LOGGER)
 			if err != nil {
 				logger.Error().Err(err).Msg("Failed to run OnNewLogger hooks")
 				span.RecordError(err)
 			}
+			if result != nil {
+				_ = pluginRegistry.ActRegistry.RunAll(result)
+			}
 		} else {
 			logger.Error().Msg("Failed to get loggers from config")
 		}
@@ -767,12 +776,15 @@ var runCmd = &cobra.Command{
 							"backoffMultiplier":  clientConfig.BackoffMultiplier,
 							"disableBackoffCaps": clientConfig.DisableBackoffCaps,
 						}
-						_, err := pluginRegistry.Run(
+						result, err := pluginRegistry.Run(
 							pluginTimeoutCtx, clientCfg, v1.HookName_HOOK_NAME_ON_NEW_CLIENT)
 						if err != nil {
 							logger.Error().Err(err).Msg("Failed to run OnNewClient hooks")
 							span.RecordError(err)
 						}
+						if result != nil {
+							_ = pluginRegistry.ActRegistry.RunAll(result)
+						}
 
 						err = pools[configGroupName][configBlockName].Put(client.ID, client)
 						if err != nil {
@@ -822,7 +834,7 @@ var runCmd = &cobra.Command{
 					context.Background(), conf.Plugin.Timeout)
 				defer cancel()
 
-				_, err = pluginRegistry.Run(
+				result, err := pluginRegistry.Run(
 					pluginTimeoutCtx,
 					map[string]any{"name": configBlockName, "size": currentPoolSize},
 					v1.HookName_HOOK_NAME_ON_NEW_POOL)
@@ -830,6 +842,9 @@ var runCmd = &cobra.Command{
 					logger.Error().Err(err).Msg("Failed to run OnNewPool hooks")
 					span.RecordError(err)
 				}
+				if result != nil {
+					_ = pluginRegistry.ActRegistry.RunAll(result)
+				}
 			}
 		}
 
@@ -877,12 +892,15 @@ var runCmd = &cobra.Command{
 				defer cancel()
 
 				if data, ok := conf.GlobalKoanf.Get("proxies").(map[string]any); ok {
-					_, err = pluginRegistry.Run(
+					result, err := pluginRegistry.Run(
 						pluginTimeoutCtx, data, v1.HookName_HOOK_NAME_ON_NEW_PROXY)
 					if err != nil {
 						logger.Error().Err(err).Msg("Failed to run OnNewProxy hooks")
 						span.RecordError(err)
 					}
+					if result != nil {
+						_ = pluginRegistry.ActRegistry.RunAll(result)
+					}
 				} else {
 					logger.Error().Msg("Failed to get proxy from config")
 				}
@@ -948,12 +966,15 @@ var runCmd = &cobra.Command{
 			defer cancel()
 
 			if data, ok := conf.GlobalKoanf.Get("servers").(map[string]any); ok {
-				_, err = pluginRegistry.Run(
+				result, err := pluginRegistry.Run(
 					pluginTimeoutCtx, data, v1.HookName_HOOK_NAME_ON_NEW_SERVER)
 				if err != nil {
 					logger.Error().Err(err).Msg("Failed to run OnNewServer hooks")
 					span.RecordError(err)
 				}
+				if result != nil {
+					_ = pluginRegistry.ActRegistry.RunAll(result)
+				}
 			} else {
 				logger.Error().Msg("Failed to get the servers configuration")
 			}
diff --git a/network/proxy.go b/network/proxy.go
index cdf9b71f..6d60b310 100644
--- a/network/proxy.go
+++ b/network/proxy.go
@@ -446,7 +446,7 @@ func (pr *Proxy) PassThroughToServer(conn *ConnWrapper, stack *Stack) *gerr.Gate
 	defer cancel()
 
 	// Run the OnTrafficToServer hooks.
-	_, err = pr.PluginRegistry.Run(
+	result, err = pr.PluginRegistry.Run(
 		pluginTimeoutCtx,
 		trafficData(
 			conn.Conn(),
@@ -463,8 +463,11 @@ func (pr *Proxy) PassThroughToServer(conn *ConnWrapper, stack *Stack) *gerr.Gate
 		pr.Logger.Error().Err(err).Msg("Error running hook")
 		span.RecordError(err)
 	}
-	span.AddEvent("Ran the OnTrafficToServer hooks")
+	if result != nil {
+		_ = pr.PluginRegistry.ActRegistry.RunAll(result)
+	}
 
+	span.AddEvent("Ran the OnTrafficToServer hooks")
 	metrics.ProxyPassThroughsToServer.WithLabelValues(pr.GetGroupName(), pr.GetBlockName()).Inc()
 
 	return nil
@@ -558,6 +561,9 @@ func (pr *Proxy) PassThroughToClient(conn *ConnWrapper, stack *Stack) *gerr.Gate
 		pr.Logger.Error().Err(err).Msg("Error running hook")
 		span.RecordError(err)
 	}
+	if result != nil {
+		result = pr.PluginRegistry.ActRegistry.RunAll(result)
+	}
 	span.AddEvent("Ran the OnTrafficFromServer hooks")
 
 	// If the hook modified the response, use the modified response.
@@ -575,7 +581,7 @@ func (pr *Proxy) PassThroughToClient(conn *ConnWrapper, stack *Stack) *gerr.Gate
 	pluginTimeoutCtx, cancel = context.WithTimeout(context.Background(), pr.PluginTimeout)
 	defer cancel()
 
-	_, err = pr.PluginRegistry.Run(
+	result, err = pr.PluginRegistry.Run(
 		pluginTimeoutCtx,
 		trafficData(
 			conn.Conn(),
@@ -597,6 +603,10 @@ func (pr *Proxy) PassThroughToClient(conn *ConnWrapper, stack *Stack) *gerr.Gate
 		pr.Logger.Error().Err(err).Msg("Error running hook")
 		span.RecordError(err)
 	}
+	if result != nil {
+		_ = pr.PluginRegistry.ActRegistry.RunAll(result)
+	}
+	span.AddEvent("Ran the OnTrafficToClient hooks")
 
 	if errVerdict != nil {
 		span.RecordError(errVerdict)
diff --git a/network/server.go b/network/server.go
index 60af37ac..a6e98b88 100644
--- a/network/server.go
+++ b/network/server.go
@@ -98,7 +98,7 @@ func (s *Server) OnBoot() Action {
 	pluginTimeoutCtx, cancel := context.WithTimeout(context.Background(), s.PluginTimeout)
 	defer cancel()
 	// Run the OnBooting hooks.
-	_, err := s.PluginRegistry.Run(
+	result, err := s.PluginRegistry.Run(
 		pluginTimeoutCtx,
 		map[string]any{"status": fmt.Sprint(s.Status)},
 		v1.HookName_HOOK_NAME_ON_BOOTING)
@@ -106,6 +106,9 @@ func (s *Server) OnBoot() Action {
 		s.Logger.Error().Err(err).Msg("Failed to run OnBooting hook")
 		span.RecordError(err)
 	}
+	if result != nil {
+		_ = s.PluginRegistry.ActRegistry.RunAll(result)
+	}
 	span.AddEvent("Ran the OnBooting hooks")
 
 	// Set the server status to running.
@@ -117,7 +120,7 @@ func (s *Server) OnBoot() Action {
 	pluginTimeoutCtx, cancel = context.WithTimeout(context.Background(), s.PluginTimeout)
 	defer cancel()
 
-	_, err = s.PluginRegistry.Run(
+	result, err = s.PluginRegistry.Run(
 		pluginTimeoutCtx,
 		map[string]any{"status": fmt.Sprint(s.Status)},
 		v1.HookName_HOOK_NAME_ON_BOOTED)
@@ -125,6 +128,9 @@ func (s *Server) OnBoot() Action {
 		s.Logger.Error().Err(err).Msg("Failed to run OnBooted hook")
 		span.RecordError(err)
 	}
+	if result != nil {
+		_ = s.PluginRegistry.ActRegistry.RunAll(result)
+	}
 	span.AddEvent("Ran the OnBooted hooks")
 
 	s.Logger.Debug().Msg("GatewayD booted")
@@ -150,12 +156,15 @@ func (s *Server) OnOpen(conn *ConnWrapper) ([]byte, Action) {
 			"remote": RemoteAddr(conn.Conn()),
 		},
 	}
-	_, err := s.PluginRegistry.Run(
+	result, err := s.PluginRegistry.Run(
 		pluginTimeoutCtx, onOpeningData, v1.HookName_HOOK_NAME_ON_OPENING)
 	if err != nil {
 		s.Logger.Error().Err(err).Msg("Failed to run OnOpening hook")
 		span.RecordError(err)
 	}
+	if result != nil {
+		_ = s.PluginRegistry.ActRegistry.RunAll(result)
+	}
 	span.AddEvent("Ran the OnOpening hooks")
 
 	// Attempt to retrieve the next proxy.
@@ -195,12 +204,15 @@ func (s *Server) OnOpen(conn *ConnWrapper) ([]byte, Action) {
 			"remote": RemoteAddr(conn.Conn()),
 		},
 	}
-	_, err = s.PluginRegistry.Run(
+	result, err = s.PluginRegistry.Run(
 		pluginTimeoutCtx, onOpenedData, v1.HookName_HOOK_NAME_ON_OPENED)
 	if err != nil {
 		s.Logger.Error().Err(err).Msg("Failed to run OnOpened hook")
 		span.RecordError(err)
 	}
+	if result != nil {
+		_ = s.PluginRegistry.ActRegistry.RunAll(result)
+	}
 	span.AddEvent("Ran the OnOpened hooks")
 
 	metrics.ClientConnections.WithLabelValues(s.GroupName, proxy.GetBlockName()).Inc()
@@ -231,12 +243,15 @@ func (s *Server) OnClose(conn *ConnWrapper, err error) Action {
 	if err != nil {
 		data["error"] = err.Error()
 	}
-	_, gatewaydErr := s.PluginRegistry.Run(
+	result, gatewaydErr := s.PluginRegistry.Run(
 		pluginTimeoutCtx, data, v1.HookName_HOOK_NAME_ON_CLOSING)
 	if gatewaydErr != nil {
 		s.Logger.Error().Err(gatewaydErr).Msg("Failed to run OnClosing hook")
 		span.RecordError(gatewaydErr)
 	}
+	if result != nil {
+		_ = s.PluginRegistry.ActRegistry.RunAll(result)
+	}
 	span.AddEvent("Ran the OnClosing hooks")
 
 	// Shutdown the server if there are no more connections and the server is stopped.
@@ -291,12 +306,15 @@ func (s *Server) OnClose(conn *ConnWrapper, err error) Action {
 	if err != nil {
 		data["error"] = err.Error()
 	}
-	_, gatewaydErr = s.PluginRegistry.Run(
+	result, gatewaydErr = s.PluginRegistry.Run(
 		pluginTimeoutCtx, data, v1.HookName_HOOK_NAME_ON_CLOSED)
 	if gatewaydErr != nil {
 		s.Logger.Error().Err(gatewaydErr).Msg("Failed to run OnClosed hook")
 		span.RecordError(gatewaydErr)
 	}
+	if result != nil {
+		_ = s.PluginRegistry.ActRegistry.RunAll(result)
+	}
 	span.AddEvent("Ran the OnClosed hooks")
 
 	metrics.ClientConnections.WithLabelValues(s.GroupName, proxy.GetBlockName()).Dec()
@@ -320,12 +338,15 @@ func (s *Server) OnTraffic(conn *ConnWrapper, stopConnection chan struct{}) Acti
 			"remote": RemoteAddr(conn.Conn()),
 		},
 	}
-	_, err := s.PluginRegistry.Run(
+	result, err := s.PluginRegistry.Run(
 		pluginTimeoutCtx, onTrafficData, v1.HookName_HOOK_NAME_ON_TRAFFIC)
 	if err != nil {
 		s.Logger.Error().Err(err).Msg("Failed to run OnTraffic hook")
 		span.RecordError(err)
 	}
+	if result != nil {
+		_ = s.PluginRegistry.ActRegistry.RunAll(result)
+	}
 	span.AddEvent("Ran the OnTraffic hooks")
 
 	stack := NewStack()
@@ -391,7 +412,7 @@ func (s *Server) OnShutdown() {
 	pluginTimeoutCtx, cancel := context.WithTimeout(context.Background(), s.PluginTimeout)
 	defer cancel()
 	// Run the OnShutdown hooks.
-	_, err := s.PluginRegistry.Run(
+	result, err := s.PluginRegistry.Run(
 		pluginTimeoutCtx,
 		map[string]any{"connections": s.CountConnections()},
 		v1.HookName_HOOK_NAME_ON_SHUTDOWN)
@@ -399,6 +420,9 @@ func (s *Server) OnShutdown() {
 		s.Logger.Error().Err(err).Msg("Failed to run OnShutdown hook")
 		span.RecordError(err)
 	}
+	if result != nil {
+		_ = s.PluginRegistry.ActRegistry.RunAll(result)
+	}
 	span.AddEvent("Ran the OnShutdown hooks")
 
 	// Shutdown proxies.
@@ -424,7 +448,7 @@ func (s *Server) OnTick() (time.Duration, Action) {
 	pluginTimeoutCtx, cancel := context.WithTimeout(context.Background(), s.PluginTimeout)
 	defer cancel()
 	// Run the OnTick hooks.
-	_, err := s.PluginRegistry.Run(
+	result, err := s.PluginRegistry.Run(
 		pluginTimeoutCtx,
 		map[string]any{"connections": s.CountConnections()},
 		v1.HookName_HOOK_NAME_ON_TICK)
@@ -432,6 +456,9 @@ func (s *Server) OnTick() (time.Duration, Action) {
 		s.Logger.Error().Err(err).Msg("Failed to run OnTick hook")
 		span.RecordError(err)
 	}
+	if result != nil {
+		_ = s.PluginRegistry.ActRegistry.RunAll(result)
+	}
 	span.AddEvent("Ran the OnTick hooks")
 
 	// TODO: Investigate whether to move schedulers here or not
@@ -474,6 +501,8 @@ func (s *Server) Run() *gerr.GatewayDError {
 	span.AddEvent("Ran the OnRun hooks")
 
 	if result != nil {
+		_ = s.PluginRegistry.ActRegistry.RunAll(result)
+
 		if errMsg, ok := result["error"].(string); ok && errMsg != "" {
 			s.Logger.Error().Str("error", errMsg).Msg("Error in hook")
 		}

From dc5759a4893a5182c0b8bd474da07249a0a80e7c Mon Sep 17 00:00:00 2001
From: Mostafa Moradian <mostafa@gatewayd.io>
Date: Mon, 7 Oct 2024 21:51:55 +0200
Subject: [PATCH 3/3] Fix linter errors

---
 act/registry.go | 2 +-
 cmd/run.go      | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/act/registry.go b/act/registry.go
index e3a0b9c6..04b093a9 100644
--- a/act/registry.go
+++ b/act/registry.go
@@ -407,7 +407,7 @@ func runActionWithTimeout(
 
 // RunAll run all the actions in the outputs and returns the end result.
 func (r *Registry) RunAll(result map[string]any) map[string]any {
-	if result == nil || len(result) == 0 {
+	if len(result) == 0 {
 		return result
 	}
 
diff --git a/cmd/run.go b/cmd/run.go
index bf992467..e8007f01 100644
--- a/cmd/run.go
+++ b/cmd/run.go
@@ -119,7 +119,7 @@ func StopGracefully(
 			span.RecordError(err)
 		}
 		if result != nil {
-			_ = pluginRegistry.ActRegistry.RunAll(result)
+			_ = pluginRegistry.ActRegistry.RunAll(result) //nolint:contextcheck
 		}
 	}