Skip to content

Commit

Permalink
implement Functional Options Pattern
Browse files Browse the repository at this point in the history
  • Loading branch information
gi8lino committed Oct 17, 2024
1 parent 4861957 commit 2b59367
Show file tree
Hide file tree
Showing 8 changed files with 429 additions and 320 deletions.
11 changes: 6 additions & 5 deletions cmd/portpatrol/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,22 +26,23 @@ func run(ctx context.Context, args []string, output io.Writer) error {
if err != nil {
return fmt.Errorf("configuration error: %w", err)
}
f.Version = version

logger := logger.SetupLogger(f, output)

checkers, err := flags.ParseChecker(f.Targets, f.DefaultCheckInterval)
checkers, err := flags.ParseTargets(f.Targets, f.DefaultCheckInterval)
if err != nil {
return fmt.Errorf("configuration error: %w", err)
return fmt.Errorf("parse error: %w", err)
}

logger := logger.SetupLogger(f, output)

eg, ctx := errgroup.WithContext(ctx)

for _, chk := range checkers {
checker := chk // Capture loop variable
eg.Go(func() error {
err := wait.WaitUntilReady(ctx, checker.Interval, checker.Checker, logger)
if err != nil {
return fmt.Errorf("checker '%s' failed: %w", checker.Checker.String(), err)
return fmt.Errorf("checker '%s' failed: %w", checker.Checker.Name(), err)
}
return nil
})
Expand Down
50 changes: 30 additions & 20 deletions internal/checker/checker.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@ import (
"context"
"fmt"
"strings"
"time"
)

// CheckType represents the type of check to perform.
type CheckType int

const (
Expand All @@ -14,13 +16,24 @@ const (
ICMP // ICMP represents a check using the ICMP protocol (ping).
)

const defaultCheckInterval time.Duration = 1 * time.Second

// String returns the string representation of the CheckType.
func (c CheckType) String() string {
return [...]string{"TCP", "HTTP", "ICMP"}[c]
}

type CheckerConfig interface {
// Marker interface; no methods required
// Option defines a functional option for configuring a Checker.
type Option interface {
apply(Checker)
}

// OptionFunc is a function that applies an Option to a Checker.
type OptionFunc func(Checker)

// apply calls the OptionFunc with the given Checker.
func (f OptionFunc) apply(c Checker) {
f(c)
}

// Checker defines an interface for performing various types of checks, such as TCP, HTTP, or ICMP.
Expand All @@ -29,30 +42,27 @@ type Checker interface {
// Check performs a check and returns an error if the check fails.
Check(ctx context.Context) error

// String returns the name of the checker.
String() string
// Name returns the name of the checker.
Name() string
}

func NewChecker(checkType CheckType, name, address string, config CheckerConfig) (Checker, error) {
// NewChecker creates a new Checker based on the specified CheckType, name, address, and options.
func NewChecker(checkType CheckType, name, address string, opts ...Option) (Checker, error) {
switch checkType {
case HTTP:
httpConfig, ok := config.(HTTPCheckerConfig)
if !ok {
return nil, fmt.Errorf("invalid config for HTTP checker")
}
return NewHTTPChecker(name, address, httpConfig)
return newHTTPChecker(name, address, opts...)
case TCP:
tcpConfig, ok := config.(TCPCheckerConfig)
if !ok {
return nil, fmt.Errorf("invalid config for TCP checker")
}
return NewTCPChecker(name, address, tcpConfig)
// The "tcp://" prefix is used to identify the check type and is not needed for further processing,
// so it must be removed before passing the address to other functions.
address = strings.TrimPrefix(address, "tcp://")

return newTCPChecker(name, address, opts...)
case ICMP:
icmpConfig, ok := config.(ICMPCheckerConfig)
if !ok {
return nil, fmt.Errorf("invalid config for ICMP checker")
}
return NewICMPChecker(name, address, icmpConfig)
// The "icmp://" prefix is used to identify the check type and is not needed for further processing,
// so it must be removed before passing the address to other functions.
address = strings.TrimPrefix(address, "icmp://")

return newICMPChecker(name, address, opts...)
default:
return nil, fmt.Errorf("unsupported check type: %d", checkType)
}
Expand Down
138 changes: 81 additions & 57 deletions internal/checker/http_checker.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,85 +9,58 @@ import (
)

const (
defaultHTTPDialTimeout time.Duration = 1 * time.Second
defaultHTTPMethod = http.MethodGet
defaultHTTPTimeout time.Duration = 1 * time.Second
defaultHTTPMethod string = http.MethodGet
defaultHTTPSkipTLSVerify bool = false
)

var defaultHTTPExpectedStatusCodes = []int{200} // Slices cannot be constants
var defaultHTTPExpectedStatusCodes = []int{200}

// HTTPChecker implements the Checker interface for HTTP checks.
type HTTPChecker struct {
name string // The name of the checker.
address string // The address of the target.
expectedStatusCodes []int // The expected status codes.
method string // The HTTP method to use.
headers map[string]string // The HTTP headers to include in the request.
allowDuplicateHeaders bool // Whether to allow duplicate headers.
skipTLSVerify bool // Whether to skip TLS verification.
dialTimeout time.Duration // The timeout for dialing the target.

client *http.Client // The HTTP client to use for the request.
name string
address string
method string
headers map[string]string
expectedStatusCodes []int
skipTLSVerify bool
timeout time.Duration
client *http.Client
}

type HTTPCheckerConfig struct {
Interval time.Duration
Method string
Headers map[string]string
ExpectedStatusCodes []int
SkipTLSVerify bool
Timeout time.Duration
}

// String returns the name of the checker.
func (c *HTTPChecker) String() string {
// Name returns the name of the checker.
func (c *HTTPChecker) Name() string {
return c.name
}

// NewHTTPChecker creates a new HTTPChecker with default values and applies any provided options.
func NewHTTPChecker(name, address string, cfg HTTPCheckerConfig) (Checker, error) {
// Set defaults if necessary
method := cfg.Method
if method == "" {
method = defaultHTTPMethod
}

expectedStatusCodes := cfg.ExpectedStatusCodes
if len(expectedStatusCodes) == 0 {
expectedStatusCodes = defaultHTTPExpectedStatusCodes
}

headers := cfg.Headers
if headers == nil {
headers = make(map[string]string)
// newHTTPChecker creates a new HTTPChecker with functional options.
func newHTTPChecker(name, address string, opts ...Option) (*HTTPChecker, error) {
checker := &HTTPChecker{
name: name,
address: address,
method: defaultHTTPMethod,
headers: make(map[string]string),
expectedStatusCodes: defaultHTTPExpectedStatusCodes,
skipTLSVerify: defaultHTTPSkipTLSVerify,
timeout: defaultHTTPTimeout,
}

timeout := cfg.Timeout
if timeout == 0 {
timeout = defaultHTTPDialTimeout
// Apply options
for _, opt := range opts {
opt.apply(checker)
}

// Initialize the HTTP client
client := &http.Client{
Timeout: timeout,
checker.client = &http.Client{
Timeout: checker.timeout,
Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
TLSClientConfig: &tls.Config{
InsecureSkipVerify: cfg.SkipTLSVerify,
InsecureSkipVerify: checker.skipTLSVerify,
},
},
}

checker := &HTTPChecker{
name: name,
address: address,
method: method,
expectedStatusCodes: expectedStatusCodes,
headers: headers,
skipTLSVerify: cfg.SkipTLSVerify,
dialTimeout: timeout,
client: client,
}

return checker, nil
}

Expand Down Expand Up @@ -120,3 +93,54 @@ func (c *HTTPChecker) Check(ctx context.Context) error {

return fmt.Errorf("unexpected status code: got %d, expected one of %v", resp.StatusCode, c.expectedStatusCodes)
}

// WithHTTPMethod sets the HTTP method for the HTTPChecker.
func WithHTTPMethod(method string) Option {
return OptionFunc(func(c Checker) {
if httpChecker, ok := c.(*HTTPChecker); ok {
httpChecker.method = method
}
})
}

// WithHTTPHeaders sets the HTTP headers for the HTTPChecker.
func WithHTTPHeaders(headers map[string]string) Option {
return OptionFunc(func(c Checker) {
if httpChecker, ok := c.(*HTTPChecker); ok {
for key, value := range headers {
httpChecker.headers[key] = value
}
}
})
}

// WithExpectedStatusCodes sets the expected status codes for the HTTPChecker.
func WithExpectedStatusCodes(codes []int) Option {
return OptionFunc(func(c Checker) {
if httpChecker, ok := c.(*HTTPChecker); ok {
if len(codes) > 0 {
httpChecker.expectedStatusCodes = codes
}
}
})
}

// WithHTTPSkipTLSVerify sets the TLS verification flag for the HTTPChecker.
func WithHTTPSkipTLSVerify(skip bool) Option {
return OptionFunc(func(c Checker) {
if httpChecker, ok := c.(*HTTPChecker); ok {
httpChecker.skipTLSVerify = skip
}
})
}

// WithHTTPTimeout sets the timeout for the HTTPChecker.
func WithHTTPTimeout(timeout time.Duration) Option {
return OptionFunc(func(c Checker) {
if httpChecker, ok := c.(*HTTPChecker); ok {
if timeout > 0 {
httpChecker.timeout = timeout
}
}
})
}
76 changes: 41 additions & 35 deletions internal/checker/icmp_checker.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"fmt"
"net"
"os"
"strings"
"sync/atomic"
"time"
)
Expand All @@ -15,55 +14,40 @@ const (
defaultICMPWriteTimeout time.Duration = 1 * time.Second
)

type ICMPCheckerConfig struct {
Interval time.Duration
ReadTimeout time.Duration
WriteTimeout time.Duration
}

// ICMPChecker implements the Checker interface for ICMP checks.
type ICMPChecker struct {
name string // The name of the checker.
address string // The address of the target.
readTimeout time.Duration // The timeout for reading the ICMP reply.
writeTimeout time.Duration // The timeout for writing the ICMP request.

protocol Protocol // The protocol (ICMPv4 or ICMPv6) used for the check.
name string
address string
readTimeout time.Duration
writeTimeout time.Duration
protocol Protocol
}

// Name returns the name of the checker.
func (c *ICMPChecker) String() string {
func (c *ICMPChecker) Name() string {
return c.name
}

// NewICMPChecker initializes a new ICMPChecker with the given parameters.
func NewICMPChecker(name, address string, cfg ICMPCheckerConfig) (Checker, error) {
// The "icmp://" prefix is used to identify the check type and is not needed for further processing,
// so it must be removed before passing the address to other functions.
address = strings.TrimPrefix(address, "icmp://")

readTimeout := cfg.ReadTimeout
if readTimeout == 0 {
readTimeout = defaultICMPReadTimeout
// newICMPChecker initializes a new ICMPChecker with functional options.
func newICMPChecker(name, address string, opts ...Option) (*ICMPChecker, error) {
checker := &ICMPChecker{
name: name,
address: address,
readTimeout: defaultICMPReadTimeout,
writeTimeout: defaultICMPWriteTimeout,
}

writeTimeout := cfg.WriteTimeout
if writeTimeout == 0 {
writeTimeout = defaultICMPWriteTimeout
// Apply options
for _, opt := range opts {
opt.apply(checker)
}

protocol, err := newProtocol(address)
// Initialize protocol based on address
protocol, err := newProtocol(checker.address)
if err != nil {
return nil, fmt.Errorf("failed to create ICMP protocol: %w", err)
}

checker := &ICMPChecker{
name: name,
address: address,
readTimeout: readTimeout,
writeTimeout: writeTimeout,
protocol: protocol,
}
checker.protocol = protocol

return checker, nil
}
Expand Down Expand Up @@ -121,3 +105,25 @@ func (c *ICMPChecker) Check(ctx context.Context) error {

return nil
}

// WithICMPReadTimeout sets the read timeout for the ICMPChecker.
func WithICMPReadTimeout(timeout time.Duration) Option {
return OptionFunc(func(c Checker) {
if icmpChecker, ok := c.(*ICMPChecker); ok {
if timeout > 0 {
icmpChecker.readTimeout = timeout
}
}
})
}

// WithICMPWriteTimeout sets the write timeout for the ICMPChecker.
func WithICMPWriteTimeout(timeout time.Duration) Option {
return OptionFunc(func(c Checker) {
if icmpChecker, ok := c.(*ICMPChecker); ok {
if timeout > 0 {
icmpChecker.writeTimeout = timeout
}
}
})
}
Loading

0 comments on commit 2b59367

Please sign in to comment.