Skip to content

Commit

Permalink
more ptr utils + errkit improvements (#573)
Browse files Browse the repository at this point in the history
* more ptr utils

* errkit: stick to slog standards + format improvements
  • Loading branch information
tarunKoyalwar authored Dec 2, 2024
1 parent cebafa1 commit 7ba513a
Show file tree
Hide file tree
Showing 7 changed files with 212 additions and 54 deletions.
2 changes: 1 addition & 1 deletion env/env.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func ExpandWithEnv(variables ...*string) {

// EnvType is a type that can be used as a type for environment variables.
type EnvType interface {
~string | ~int | ~bool | ~float64 | time.Duration
~string | ~int | ~bool | ~float64 | time.Duration | ~rune
}

// GetEnvOrDefault returns the value of the environment variable or the default value if the variable is not set.
Expand Down
199 changes: 153 additions & 46 deletions errkit/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@ import (
"errors"
"fmt"
"log/slog"
"runtime"
"strconv"
"strings"
"time"

"github.com/projectdiscovery/utils/env"
"golang.org/x/exp/maps"
)

const (
Expand All @@ -24,38 +26,78 @@ const (
DelimMultiLine = "\n - "
// MultiLinePrefix is the prefix used for multiline errors
MultiLineErrPrefix = "the following errors occurred:"
// Space is the identifier used for indentation
Space = " "
)

var (
// MaxErrorDepth is the maximum depth of errors to be unwrapped or maintained
// all errors beyond this depth will be ignored
MaxErrorDepth = env.GetEnvOrDefault("MAX_ERROR_DEPTH", 3)
// ErrorSeperator is the seperator used to join errors
ErrorSeperator = env.GetEnvOrDefault("ERROR_SEPERATOR", "; ")
// FieldSeperator
ErrFieldSeparator = env.GetEnvOrDefault("ERR_FIELD_SEPERATOR", Space)
// ErrChainSeperator
ErrChainSeperator = env.GetEnvOrDefault("ERR_CHAIN_SEPERATOR", DelimSemiColon)
// EnableTimestamp controls whether error timestamps are included
EnableTimestamp = env.GetEnvOrDefault("ENABLE_ERR_TIMESTAMP", false)
// EnableTrace controls whether error stack traces are included
EnableTrace = env.GetEnvOrDefault("ENABLE_ERR_TRACE", false)
)

// ErrorX is a custom error type that can handle all known types of errors
// wrapping and joining strategies including custom ones and it supports error class
// which can be shown to client/users in more meaningful way
type ErrorX struct {
kind ErrKind
attrs map[string]slog.Attr
errs []error
uniqErrs map[string]struct{}
kind ErrKind
record *slog.Record
source *slog.Source
errs []error
}

func (e *ErrorX) init(skipStack ...int) {
// initializes if necessary
if e.record == nil {
e.record = &slog.Record{}
if EnableTimestamp {
e.record.Time = time.Now()
}
if EnableTrace {
// get fn name
var pcs [1]uintptr
// skip [runtime.Callers, ErrorX.init, parent]
skip := 3
if len(skipStack) > 0 {
skip = skipStack[0]
}
runtime.Callers(skip, pcs[:])
pc := pcs[0]
fs := runtime.CallersFrames([]uintptr{pc})
f, _ := fs.Next()
e.source = &slog.Source{
Function: f.Function,
File: f.File,
Line: f.Line,
}
}
}
}

// append is internal method to append given
// error to error slice , it removes duplicates
// earlier it used map which causes more allocations that necessary
func (e *ErrorX) append(errs ...error) {
if e.uniqErrs == nil {
e.uniqErrs = make(map[string]struct{})
}
for _, err := range errs {
if _, ok := e.uniqErrs[err.Error()]; ok {
continue
for _, nerr := range errs {
found := false
new:
for _, oerr := range e.errs {
if oerr.Error() == nerr.Error() {
found = true
break new
}
}
if !found {
e.errs = append(e.errs, nerr)
}
e.uniqErrs[err.Error()] = struct{}{}
e.errs = append(e.errs, err)
}
}

Expand All @@ -71,8 +113,11 @@ func (e ErrorX) MarshalJSON() ([]byte, error) {
"kind": e.kind.String(),
"errors": tmp,
}
if len(e.attrs) > 0 {
m["attrs"] = slog.GroupValue(maps.Values(e.attrs)...)
if e.record != nil && e.record.NumAttrs() > 0 {
m["attrs"] = slog.GroupValue(e.Attrs()...)
}
if e.source != nil {
m["source"] = e.source
}
return json.Marshal(m)
}
Expand All @@ -84,10 +129,15 @@ func (e *ErrorX) Errors() []error {

// Attrs returns all attributes associated with the error
func (e *ErrorX) Attrs() []slog.Attr {
if e.attrs == nil {
if e.record == nil || e.record.NumAttrs() == 0 {
return nil
}
return maps.Values(e.attrs)
values := []slog.Attr{}
e.record.Attrs(func(a slog.Attr) bool {
values = append(values, a)
return true
})
return values
}

// Build returns the object as error interface
Expand All @@ -103,6 +153,7 @@ func (e *ErrorX) Unwrap() []error {
// Is checks if current error contains given error
func (e *ErrorX) Is(err error) bool {
x := &ErrorX{}
x.init()
parseError(x, err)
// even one submatch is enough
for _, orig := range e.errs {
Expand All @@ -118,20 +169,26 @@ func (e *ErrorX) Is(err error) bool {
// Error returns the error string
func (e *ErrorX) Error() string {
var sb strings.Builder
if e.kind != nil && e.kind.String() != "" {
sb.WriteString("errKind=")
sb.WriteString(e.kind.String())
sb.WriteString(" ")
}
if len(e.attrs) > 0 {
sb.WriteString(slog.GroupValue(maps.Values(e.attrs)...).String())
sb.WriteString(" ")
sb.WriteString("cause=")
sb.WriteString(strconv.Quote(e.errs[0].Error()))
if e.record != nil && e.record.NumAttrs() > 0 {
values := []string{}
e.record.Attrs(func(a slog.Attr) bool {
values = append(values, a.String())
return true
})
sb.WriteString(Space)
sb.WriteString(strings.Join(values, " "))
}
for _, err := range e.errs {
sb.WriteString(err.Error())
sb.WriteString(ErrorSeperator)
if len(e.errs) > 1 {
chain := []string{}
for _, value := range e.errs[1:] {
chain = append(chain, strings.TrimSpace(value.Error()))
}
sb.WriteString(Space)
sb.WriteString("chain=" + strconv.Quote(strings.Join(chain, ErrChainSeperator)))
}
return strings.TrimSuffix(sb.String(), ErrorSeperator)
return sb.String()
}

// Cause return the original error that caused this without any wrapping
Expand All @@ -158,28 +215,65 @@ func FromError(err error) *ErrorX {
return nil
}
nucleiErr := &ErrorX{}
nucleiErr.init()
parseError(nucleiErr, err)
return nucleiErr
}

// New creates a new error with the given message
func New(format string, args ...interface{}) *ErrorX {
// it follows slog pattern of adding and expects in the same way
//
// Example:
//
// this is correct (√)
// errkit.New("this is a nuclei error","address",host)
//
// this is not readable/recommended (x)
// errkit.New("this is a nuclei error",slog.String("address",host))
//
// this is wrong (x)
// errkit.New("this is a nuclei error %s",host)
func New(msg string, args ...interface{}) *ErrorX {
e := &ErrorX{}
e.append(fmt.Errorf(format, args...))
e.init()
if len(args) > 0 {
e.record.Add(args...)
}
e.append(errors.New(msg))
return e
}

// Msgf adds a message to the error
// it follows slog pattern of adding and expects in the same way
//
// Example:
//
// this is correct (√)
// myError.Msgf("dial error","network","tcp")
//
// this is not readable/recommended (x)
// myError.Msgf(slog.String("address",host))
//
// this is wrong (x)
// myError.Msgf("this is a nuclei error %s",host)
func (e *ErrorX) Msgf(format string, args ...interface{}) {
if e == nil {
return
}
if len(args) == 0 {
e.append(errors.New(format))
}
e.append(fmt.Errorf(format, args...))
}

// SetClass sets the class of the error
// if underlying error class was already set, then it is given preference
// when generating final error msg
//
// Example:
//
// this is correct (√)
// myError.SetKind(errkit.ErrKindNetworkPermanent)
func (e *ErrorX) SetKind(kind ErrKind) *ErrorX {
if e.kind == nil {
e.kind = kind
Expand All @@ -189,23 +283,30 @@ func (e *ErrorX) SetKind(kind ErrKind) *ErrorX {
return e
}

// ResetKind resets the error class of the error
//
// Example:
//
// myError.ResetKind()
func (e *ErrorX) ResetKind() *ErrorX {
e.kind = nil
return e
}

// Deprecated: use Attrs instead
//
// SetAttr sets additional attributes to a given error
// it only adds unique attributes and ignores duplicates
// Note: only key is checked for uniqueness
//
// Example:
//
// this is correct (√)
// myError.SetAttr(slog.String("address",host))
func (e *ErrorX) SetAttr(s ...slog.Attr) *ErrorX {
e.init()
for _, attr := range s {
if e.attrs == nil {
e.attrs = make(map[string]slog.Attr)
}
// check if this exists
if _, ok := e.attrs[attr.Key]; !ok && len(e.attrs) < MaxErrorDepth {
e.attrs[attr.Key] = attr
}
e.record.Add(attr)
}
return e
}
Expand All @@ -217,6 +318,7 @@ func parseError(to *ErrorX, err error) {
}
if to == nil {
to = &ErrorX{}
to.init(4)
}
if len(to.errs) >= MaxErrorDepth {
return
Expand All @@ -225,6 +327,17 @@ func parseError(to *ErrorX, err error) {
switch v := err.(type) {
case *ErrorX:
to.append(v.errs...)
if to.record == nil {
to.record = v.record
} else {
v.record.Attrs(func(a slog.Attr) bool {
to.record.Add(a)
return true
})
}
if to.source == nil {
to.source = v.source
}
to.kind = CombineErrKinds(to.kind, v.kind)
case JoinedError:
foundAny := false
Expand Down Expand Up @@ -283,9 +396,3 @@ func parseError(to *ErrorX, err error) {
}
}
}

// WrappedError is implemented by errors that are wrapped
type WrappedError interface {
// Unwrap returns the underlying error
Unwrap() error
}
12 changes: 12 additions & 0 deletions errkit/errors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,3 +114,15 @@ func TestMarshalError(t *testing.T) {
require.NoError(t, err, "expected to be able to marshal the error")
require.Equal(t, `{"errors":["port closed or filtered","this is a wrapped error"],"kind":"network-permanent-error"}`, string(marshalled))
}

func TestErrorString(t *testing.T) {
var x error = New("i/o timeout")
x = With(x, "ip", "10.0.0.1", "port", 80)
x = WithMessage(x, "tcp dial error")
x = Append(x, errors.New("some other error"))

require.Equal(t,
`cause="i/o timeout" ip=10.0.0.1 port=80 chain="tcp dial error; some other error"`,
x.Error(),
)
}
14 changes: 8 additions & 6 deletions errkit/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,19 +193,21 @@ func IsNetworkPermanentErr(err error) bool {
return isNetworkPermanentErr(x)
}

// WithAttr wraps error with given attributes
// With adds extra attributes to the error
//
// err = errkit.WithAttr(err,slog.Any("resource",domain))
func WithAttr(err error, attrs ...slog.Attr) error {
// err = errkit.With(err,"resource",domain)
func With(err error, args ...any) error {
if err == nil {
return nil
}
if len(attrs) == 0 {
if len(args) == 0 {
return err
}
x := &ErrorX{}
x.init()
parseError(x, err)
return x.SetAttr(attrs...)
x.record.Add(args...)
return x
}

// GetAttr returns all attributes of given error if it has any
Expand Down Expand Up @@ -271,7 +273,7 @@ func GetAttrValue(err error, key string) slog.Value {
}
x := &ErrorX{}
parseError(x, err)
for _, attr := range x.attrs {
for _, attr := range x.Attrs() {
if attr.Key == key {
return attr.Value
}
Expand Down
6 changes: 6 additions & 0 deletions errkit/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,9 @@ type ComparableError interface {
// Is checks if current error contains given error
Is(err error) bool
}

// WrappedError is implemented by errors that are wrapped
type WrappedError interface {
// Unwrap returns the underlying error
Unwrap() error
}
Loading

0 comments on commit 7ba513a

Please sign in to comment.