Skip to content

Commit

Permalink
Realized the option of specifying the local address of zgrab2
Browse files Browse the repository at this point in the history
  • Loading branch information
Baoxd123 committed Feb 3, 2025
1 parent 35d17d9 commit 6db27b7
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 1 deletion.
10 changes: 10 additions & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ type Config struct {
ConnectionsPerHost int `long:"connections-per-host" default:"1" description:"Number of times to connect to each host (results in more output)"`
ReadLimitPerHost int `long:"read-limit-per-host" default:"96" description:"Maximum total kilobytes to read for a single host (default 96kb)"`
Prometheus string `long:"prometheus" description:"Address to use for Prometheus server (e.g. localhost:8080). If empty, Prometheus is disabled."`
LocalAddrStr string `long:"local-addr" description:"Local source address for outgoing connections (e.g. 192.168.10.2:0, port is required even if it's 0)"`
CustomDNS string `long:"dns" description:"Address of a custom DNS server for lookups. Default port is 53."`
Multiple MultipleCommand `command:"multiple" description:"Multiple module actions"`
inputFile *os.File
Expand Down Expand Up @@ -100,6 +101,15 @@ func validateFrameworkConfiguration() {
}
runtime.GOMAXPROCS(config.GOMAXPROCS)

// Parse and validate the local address if specified
if config.LocalAddrStr != "" {
var err error
config.localAddr, err = net.ResolveTCPAddr("tcp", config.LocalAddrStr)
if err != nil {
log.Fatalf("could not resolve local address %s: %v", config.LocalAddrStr, err)
}
}

//validate/start prometheus
if config.Prometheus != "" {
go func() {
Expand Down
11 changes: 10 additions & 1 deletion conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,13 @@ func NewTimeoutConnection(ctx context.Context, conn net.Conn, timeout, readTimeo
func DialTimeoutConnectionEx(proto string, target string, dialTimeout, sessionTimeout, readTimeout, writeTimeout time.Duration, bytesReadLimit int) (net.Conn, error) {
var conn net.Conn
var err error
dialer := &net.Dialer{
Timeout: dialTimeout,
}
if config.localAddr != nil {
dialer.LocalAddr = config.localAddr
}
conn, err = dialer.Dial(proto, target)
if dialTimeout > 0 {
conn, err = net.DialTimeout(proto, target, dialTimeout)
} else {
Expand Down Expand Up @@ -300,7 +307,9 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.
d.Dialer.KeepAlive = d.Timeout

// Copy over the source IP if set, or nil
d.Dialer.LocalAddr = config.localAddr
if config.localAddr != nil {
d.Dialer.LocalAddr = config.localAddr
}

dialContext, cancelDial := context.WithTimeout(ctx, d.Dialer.Timeout)
defer cancelDial()
Expand Down

0 comments on commit 6db27b7

Please sign in to comment.