Skip to content

Commit

Permalink
use context in runner (#305)
Browse files Browse the repository at this point in the history
* use context in runner

* fix copy-paste error
  • Loading branch information
arriven authored Mar 12, 2022
1 parent bca677d commit b9ac6b6
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 75 deletions.
9 changes: 6 additions & 3 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,11 @@ func main() {

go utils.CheckCountry([]string{"Ukraine"})

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

if prometheusOn {
go metrics.ExportPrometheusMetrics(context.Background(), prometheusPushGateways)
go metrics.ExportPrometheusMetrics(ctx, prometheusPushGateways)
}

r, err := runner.New(&runner.Config{
Expand All @@ -132,8 +135,8 @@ func main() {
signal.Notify(sigs, syscall.SIGTERM)
<-sigs
log.Println("Terminating")
r.Stop()
cancel()
}()

r.Run()
r.Run(ctx)
}
4 changes: 2 additions & 2 deletions src/core/packetgen/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@ import (
)

// RandomPayload returns a byte slice to spoof ip packets with random payload in specified length
func RandomPayload(length int) []byte {
func RandomPayload(length int) string {
payload := make([]byte, length)
rand.Read(payload)
return payload
return string(payload)
}

// RandomIP returns a random ip to spoof packets
Expand Down
3 changes: 1 addition & 2 deletions src/jobs/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,14 +125,13 @@ func fastHTTPJob(ctx context.Context, globalConfig GlobalConfig, args Args, debu

req := fasthttp.AcquireRequest()
defer fasthttp.ReleaseRequest(req)

log.Printf("Attacking %v", jobConfig.Request["path"])
for jobConfig.Next(ctx) {
var requestConfig http.RequestConfig
if err := utils.Decode(requestTpl.Execute(ctx), &requestConfig); err != nil {
log.Printf("Error executing request template: %v", err)
return nil, err
}
log.Printf("Sent single http request to %v", requestConfig.Path)
dataSize := http.InitRequest(requestConfig, req)

trafficMonitor.Add(uint64(dataSize))
Expand Down
3 changes: 1 addition & 2 deletions src/jobs/rawnet.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package jobs

import (
"context"
"encoding/json"
"fmt"
"log"
"net"
Expand All @@ -21,7 +20,7 @@ type rawNetJobConfig struct {
BasicJobConfig

Address string
Body json.RawMessage
Body string
}

func tcpJob(ctx context.Context, globalConfig GlobalConfig, args Args, debug bool) (data interface{}, err error) {
Expand Down
118 changes: 52 additions & 66 deletions src/runner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"context"
"log"
"strings"
"sync"
"time"

"github.com/google/uuid"
Expand Down Expand Up @@ -37,8 +36,6 @@ type Runner struct {
currentRawConfig []byte // currently applied config

debug bool

stop chan interface{}
}

// New runner according to the config
Expand All @@ -51,101 +48,90 @@ func New(cfg *Config, debug bool) (*Runner, error) {
configFormat: cfg.Format,

debug: debug,

stop: make(chan interface{}),
}, nil
}

// Run the runner and block until Stop() is called
func (r *Runner) Run() {
func (r *Runner) Run(ctx context.Context) {
clientID := uuid.New()
refreshTimer := time.NewTicker(r.refreshTimeout)
defer refreshTimer.Stop()

var (
stop bool
ctx context.Context
cancel context.CancelFunc
wg sync.WaitGroup
)

for !stop {
for {
if cfg, raw := config.Update(r.configPaths, r.currentRawConfig, r.backupConfig, r.configFormat); cfg != nil {
if cancel != nil {
cancel()
}

ctx, cancel = context.WithCancel(context.Background())

var jobInstancesCount int

for i := range cfg.Jobs {
if len(cfg.Jobs[i].Filter) != 0 && strings.TrimSpace(templates.ParseAndExecute(cfg.Jobs[i].Filter, clientID.ID())) != "true" {
log.Println("There is a filter defined for a job but this client doesn't pass it - skip the job")
continue
}
cancel = r.runJobs(ctx, cfg, clientID)

job := jobs.Get(cfg.Jobs[i].Type)
if job == nil {
log.Printf("Unknown job %q", cfg.Jobs[i].Type)
r.currentRawConfig = raw
}

continue
}
// Wait for refresh timer or stop signal
select {
case <-refreshTimer.C:
case <-ctx.Done():
if cancel != nil {
cancel()
}
return
}

if cfg.Jobs[i].Count < 1 {
cfg.Jobs[i].Count = 1
}
if r.config.Global.ScaleFactor > 0 {
cfg.Jobs[i].Count = cfg.Jobs[i].Count * r.config.Global.ScaleFactor
}
cfgMap := make(map[string]interface{})
err := utils.Decode(cfg.Jobs[i], &cfgMap)
if err != nil {
log.Fatal("failed to encode cfg map")
}
ctx := context.WithValue(ctx, templates.ContextKey("config"), cfgMap)
dumpMetrics(clientID.String(), r.debug)
}
}

for j := 0; j < cfg.Jobs[i].Count; j++ {
wg.Add(1)
func (r *Runner) runJobs(ctx context.Context, cfg *config.Config, clientID uuid.UUID) (cancel context.CancelFunc) {
ctx, cancel = context.WithCancel(ctx)

go func(i int) {
_, err := job(ctx, r.config.Global, cfg.Jobs[i].Args, r.debug)
if err != nil {
log.Println("error running job:", err)
}
wg.Done()
}(i)
var jobInstancesCount int

jobInstancesCount++
}
}
for i := range cfg.Jobs {
if len(cfg.Jobs[i].Filter) != 0 && strings.TrimSpace(templates.ParseAndExecute(cfg.Jobs[i].Filter, clientID.ID())) != "true" {
log.Println("There is a filter defined for a job but this client doesn't pass it - skip the job")
continue
}

r.currentRawConfig = raw
job := jobs.Get(cfg.Jobs[i].Type)
if job == nil {
log.Printf("Unknown job %q", cfg.Jobs[i].Type)

log.Printf("%d job instances (re)started", jobInstancesCount)
continue
}

// Wait for refresh timer or stop signal
select {
case <-refreshTimer.C:
case <-r.stop:
refreshTimer.Stop()

stop = true
if cfg.Jobs[i].Count < 1 {
cfg.Jobs[i].Count = 1
}
if r.config.Global.ScaleFactor > 0 {
cfg.Jobs[i].Count = cfg.Jobs[i].Count * r.config.Global.ScaleFactor
}
cfgMap := make(map[string]interface{})
err := utils.Decode(cfg.Jobs[i], &cfgMap)
if err != nil {
log.Fatal("failed to encode cfg map")
}
ctx := context.WithValue(ctx, templates.ContextKey("config"), cfgMap)

dumpMetrics(clientID.String(), r.debug)
}
for j := 0; j < cfg.Jobs[i].Count; j++ {
go func(i int) {
_, err := job(ctx, r.config.Global, cfg.Jobs[i].Args, r.debug)
if err != nil {
log.Println("error running job:", err)
}
}(i)

if cancel != nil {
cancel()
jobInstancesCount++
}
}

wg.Wait()
log.Printf("%d job instances (re)started", jobInstancesCount)
return cancel
}

// Stop runner asynchronously
func (r *Runner) Stop() { close(r.stop) }

func dumpMetrics(clientID string, debug bool) {
defer func() {
if err := recover(); err != nil {
Expand Down

0 comments on commit b9ac6b6

Please sign in to comment.