Skip to content

Commit

Permalink
apply some reviews
Browse files Browse the repository at this point in the history
  • Loading branch information
efectn committed Mar 2, 2024
1 parent a140863 commit 6b01572
Show file tree
Hide file tree
Showing 7 changed files with 222 additions and 73 deletions.
13 changes: 12 additions & 1 deletion client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,10 @@ func (c *Client) SetRootCertificate(path string) *Client {
c.logger.Panicf("client: %v", err)
}
defer func() {
_ = file.Close() //nolint:errcheck // It is fine to ignore the error here
if err := file.Close(); err != nil {
c.logger.Panicf("client: failed to close file: %v", err)
}

}()

pem, err := io.ReadAll(file)
Expand Down Expand Up @@ -580,6 +583,14 @@ func (c *Client) Reset() {
c.timeout = 0
c.userAgent = ""
c.referer = ""
c.proxyURL = ""
c.retryConfig = nil
c.debug = false

if c.cookieJar != nil {
c.cookieJar.Release()
c.cookieJar = nil
}

c.path.Reset()
c.cookies.Reset()
Expand Down
103 changes: 100 additions & 3 deletions client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"github.com/gofiber/fiber/v3/internal/tlstest"
"github.com/gofiber/utils/v2"
"github.com/stretchr/testify/require"
"github.com/valyala/bytebufferpool"
)

func startTestServerWithPort(t *testing.T, beforeStarting func(app *fiber.App)) (*fiber.App, string) {
Expand All @@ -30,12 +31,13 @@ func startTestServerWithPort(t *testing.T, beforeStarting func(app *fiber.App))

addrChan := make(chan string)
go func() {
require.NoError(t, app.Listen(":0", fiber.ListenConfig{
err := app.Listen(":0", fiber.ListenConfig{
DisableStartupMessage: true,
ListenerAddrFunc: func(addr net.Addr) {
addrChan <- addr.String()
},
}))
})
require.NoError(t, err)
}()

addr := <-addrChan
Expand All @@ -47,19 +49,28 @@ func Test_Client_Add_Hook(t *testing.T) {

t.Run("add request hooks", func(t *testing.T) {
t.Parallel()

buf := bytebufferpool.Get()
defer bytebufferpool.Put(buf)

client := NewClient().AddRequestHook(func(_ *Client, _ *Request) error {
buf.WriteString("hook1")
return nil
})

require.Len(t, client.RequestHook(), 1)

client.AddRequestHook(func(_ *Client, _ *Request) error {
buf.WriteString("hook2")
return nil
}, func(_ *Client, _ *Request) error {
buf.WriteString("hook3")
return nil
})

require.Len(t, client.RequestHook(), 3)

client.builtinRequestHooks[0](client, &Request{})
})

t.Run("add response hooks", func(t *testing.T) {
Expand All @@ -80,6 +91,34 @@ func Test_Client_Add_Hook(t *testing.T) {
})
}

func Test_Client_Add_Hook_CheckOrder(t *testing.T) {
t.Parallel()

buf := bytebufferpool.Get()
defer bytebufferpool.Put(buf)

client := NewClient().
AddRequestHook(func(_ *Client, _ *Request) error {
buf.WriteString("hook1")
return nil
}).
AddRequestHook(func(_ *Client, _ *Request) error {
buf.WriteString("hook2")
return nil
}).
AddRequestHook(func(_ *Client, _ *Request) error {
buf.WriteString("hook3")
return nil
})

for _, hook := range client.RequestHook() {
require.NoError(t, hook(client, &Request{}))
}

require.Equal(t, "hook1hook2hook3", buf.String())

}

func Test_Client_Marshal(t *testing.T) {
t.Parallel()

Expand All @@ -95,6 +134,18 @@ func Test_Client_Marshal(t *testing.T) {
require.Equal(t, []byte("hello"), val)
})

t.Run("set json marshal error", func(t *testing.T) {
t.Parallel()
client := NewClient().
SetJSONMarshal(func(_ any) ([]byte, error) {
return nil, errors.New("empty json")
})

val, err := client.JSONMarshal()(nil)
require.Nil(t, val)
require.Equal(t, errors.New("empty json"), err)
})

t.Run("set json unmarshal", func(t *testing.T) {
t.Parallel()
client := NewClient().
Expand All @@ -106,6 +157,17 @@ func Test_Client_Marshal(t *testing.T) {
require.Equal(t, errors.New("empty json"), err)
})

t.Run("set json unmarshal error", func(t *testing.T) {
t.Parallel()
client := NewClient().
SetJSONUnmarshal(func(_ []byte, _ any) error {
return errors.New("empty json")
})

err := client.JSONUnmarshal()(nil, nil)
require.Equal(t, errors.New("empty json"), err)
})

t.Run("set xml marshal", func(t *testing.T) {
t.Parallel()
client := NewClient().
Expand All @@ -118,6 +180,18 @@ func Test_Client_Marshal(t *testing.T) {
require.Equal(t, []byte("hello"), val)
})

t.Run("set xml marshal error", func(t *testing.T) {
t.Parallel()
client := NewClient().
SetXMLMarshal(func(_ any) ([]byte, error) {
return nil, errors.New("empty xml")
})

val, err := client.XMLMarshal()(nil)
require.Nil(t, val)
require.Equal(t, errors.New("empty xml"), err)
})

t.Run("set xml unmarshal", func(t *testing.T) {
t.Parallel()
client := NewClient().
Expand All @@ -128,6 +202,17 @@ func Test_Client_Marshal(t *testing.T) {
err := client.XMLUnmarshal()(nil, nil)
require.Equal(t, errors.New("empty xml"), err)
})

t.Run("set xml unmarshal error", func(t *testing.T) {
t.Parallel()
client := NewClient().
SetXMLUnmarshal(func(_ []byte, _ any) error {
return errors.New("empty xml")
})

err := client.XMLUnmarshal()(nil, nil)
require.Equal(t, errors.New("empty xml"), err)
})
}

func Test_Client_SetBaseURL(t *testing.T) {
Expand All @@ -151,7 +236,7 @@ func Test_Client_Invalid_URL(t *testing.T) {

_, err := NewClient().SetDial(dial).
R().
Get("http://example.com\r\n\r\nGET /\r\n\r\n")
Get("http//example")

require.ErrorIs(t, err, ErrURLFormat)
}
Expand Down Expand Up @@ -663,6 +748,18 @@ func Test_Client_Header(t *testing.T) {
require.Len(t, res, 1)
require.Equal(t, "foo", res[0])
})

t.Run("set header case insensitive", func(t *testing.T) {
t.Parallel()
req := NewClient()
req.SetHeader("foo", "bar").
AddHeader("FOO", "fiber")

res := req.Header("foo")
require.Len(t, res, 2)
require.Equal(t, "bar", res[0])
require.Equal(t, "fiber", res[1])
})
}

func Test_Client_Header_With_Server(t *testing.T) {
Expand Down
9 changes: 5 additions & 4 deletions client/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,11 @@ func (c *core) execFunc() (*Response, error) {
var err error
go func() {
respv := fasthttp.AcquireResponse()
defer func() {
fasthttp.ReleaseRequest(reqv)
fasthttp.ReleaseResponse(respv)
}()

if cfg != nil {
err = retry.NewExponentialBackoff(*cfg).Retry(func() error {
if c.req.maxRedirects > 0 && (string(reqv.Header.Method()) == fiber.MethodGet || string(reqv.Header.Method()) == fiber.MethodHead) {
Expand All @@ -105,10 +110,6 @@ func (c *core) execFunc() (*Response, error) {
err = c.client.client.Do(reqv, respv)
}
}
defer func() {
fasthttp.ReleaseRequest(reqv)
fasthttp.ReleaseResponse(respv)
}()

if atomic.CompareAndSwapInt32(&done, 0, 1) {
if err != nil {
Expand Down
Loading

0 comments on commit 6b01572

Please sign in to comment.