Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(go): support http #2639

Merged
merged 2 commits into from
Jan 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ type Configuration struct {
AppID string
ApiKey string

Hosts []string
Hosts []StatefulHost
DefaultHeader map[string]string
UserAgent string
Requester Requester
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,19 @@ const (
)

type Host struct {
scheme string
host string
timeout time.Duration
}

type RetryStrategy struct {
sync.RWMutex
hosts []*StatefulHost
hosts []StatefulHost
writeTimeout time.Duration
readTimeout time.Duration
}

func newRetryStrategy(hosts []*StatefulHost, readTimeout, writeTimeout time.Duration) *RetryStrategy {
func newRetryStrategy(hosts []StatefulHost, readTimeout, writeTimeout time.Duration) *RetryStrategy {
if readTimeout == 0 {
readTimeout = DefaultReadTimeout
}
Expand Down Expand Up @@ -74,7 +75,7 @@ func (s *RetryStrategy) GetTryableHosts(k call.Kind) []Host {

for _, h := range s.hosts {
if !h.isDown && h.accept(k) {
hosts = append(hosts, Host{h.host, time.Duration(h.retryCount+1) * baseTimeout})
hosts = append(hosts, Host{h.scheme, h.host, time.Duration(h.retryCount+1) * baseTimeout})
}
}

Expand All @@ -85,7 +86,7 @@ func (s *RetryStrategy) GetTryableHosts(k call.Kind) []Host {
for _, h := range s.hosts {
if h.accept(k) {
h.reset()
hosts = append(hosts, Host{h.host, time.Duration(h.retryCount+1) * baseTimeout})
hosts = append(hosts, Host{h.scheme, h.host, time.Duration(h.retryCount+1) * baseTimeout})
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,17 @@ const (
)

type StatefulHost struct {
scheme string
host string
isDown bool
retryCount int
lastUpdate time.Time
accept func(k call.Kind) bool
}

func NewStatefulHost(host string, accept func(k call.Kind) bool) *StatefulHost {
return &StatefulHost{
func NewStatefulHost(scheme string, host string, accept func(k call.Kind) bool) StatefulHost {
return StatefulHost{
scheme: scheme,
host: host,
isDown: false,
retryCount: 0,
Expand Down
10 changes: 5 additions & 5 deletions clients/algoliasearch-client-go/algolia/transport/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ type Transport struct {
}

func New(
hosts []*StatefulHost,
hosts []StatefulHost,
requester Requester,
readTimeout time.Duration,
writeTimeout time.Duration,
Expand Down Expand Up @@ -65,7 +65,7 @@ func (t *Transport) Request(ctx context.Context, req *http.Request, k call.Kind)
// cancelled` error may happen when the body is read.
perRequestCtx, cancel := context.WithTimeout(ctx, h.timeout)
req = req.WithContext(perRequestCtx)
res, err := t.request(req, h.host, h.timeout, t.connectTimeout)
res, err := t.request(req, h, h.timeout, t.connectTimeout)

code := 0
if res != nil {
Expand Down Expand Up @@ -116,9 +116,9 @@ func (t *Transport) Request(ctx context.Context, req *http.Request, k call.Kind)
return nil, nil, errs.ErrNoMoreHostToTry
}

func (t *Transport) request(req *http.Request, host string, timeout time.Duration, connectTimeout time.Duration) (*http.Response, error) {
req.URL.Scheme = "https"
req.URL.Host = host
func (t *Transport) request(req *http.Request, host Host, timeout time.Duration, connectTimeout time.Duration) (*http.Response, error) {
req.URL.Scheme = host.scheme
req.URL.Host = host.host

debug.Display(req)
res, err := t.requester.Request(req, timeout, connectTimeout)
Expand Down
4 changes: 2 additions & 2 deletions clients/algoliasearch-client-go/algolia/transport/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@ import (
"math/rand"
)

func Shuffle(hosts []*StatefulHost) []*StatefulHost {
func Shuffle(hosts []StatefulHost) []StatefulHost {
if hosts == nil {
return nil
}
shuffled := make([]*StatefulHost, len(hosts))
shuffled := make([]StatefulHost, len(hosts))
for i, v := range rand.Perm(len(hosts)) {
shuffled[i] = hosts[v]
}
Expand Down
8 changes: 5 additions & 3 deletions templates/go/api.mustache
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,10 @@ func (c *APIClient) {{nickname}}WithContext(ctx context.Context, {{#hasParams}}r
{{/returnType}}
)

{{#vendorExtensions}}
requestPath := "{{{path}}}"{{#pathParams}}
requestPath = strings.ReplaceAll(requestPath, {{=<% %>=}}"{<%baseName%>}"<%={{ }}=%>, url.PathEscape(parameterToString(r.{{paramName}}))){{/pathParams}}
requestPath = strings.ReplaceAll(requestPath, {{=<% %>=}}"{<%baseName%>}"<%={{ }}=%>, {{#x-is-custom-request}}parameterToString(r.{{paramName}}){{/x-is-custom-request}}{{^x-is-custom-request}}url.PathEscape(parameterToString(r.{{paramName}})){{/x-is-custom-request}}){{/pathParams}}
{{/vendorExtensions}}

headers := make(map[string]string)
queryParams := url.Values{}
Expand Down Expand Up @@ -287,7 +289,7 @@ func (c *APIClient) {{nickname}}WithContext(ctx context.Context, {{#hasParams}}r
}

var v ErrorBase
err = c.decode(&v, resBody, res.Header.Get("Content-Type"))
err = c.decode(&v, resBody)
if err != nil {
newErr.Message = err.Error()
return {{#returnType}}returnValue, {{/returnType}}newErr
Expand All @@ -297,7 +299,7 @@ func (c *APIClient) {{nickname}}WithContext(ctx context.Context, {{#hasParams}}r
}

{{#returnType}}
err = c.decode(&returnValue, resBody, res.Header.Get("Content-Type"))
err = c.decode(&returnValue, resBody)
if err != nil {
return {{#returnType}}returnValue, {{/returnType}}reportError("cannot decode result: %w", err)
}
Expand Down
77 changes: 36 additions & 41 deletions templates/go/client.mustache
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@ import (
"github.com/algolia/algoliasearch-client-go/v4/algolia/transport"
)

var jsonCheck = regexp.MustCompile(`(?i:(?:application|text)/(?:vnd\.[^;]+\+)?json)`)

// APIClient manages communication with the {{appName}} API v{{version}}
// In most cases there should be only one, shared, APIClient.
type APIClient struct {
Expand All @@ -51,7 +49,7 @@ return NewClientWithConfig(Configuration{

// NewClientWithConfig creates a new API client with the given configuration to fully customize the client behaviour.
func NewClientWithConfig(cfg Configuration) (*APIClient, error) {
var hosts []*transport.StatefulHost
var hosts []transport.StatefulHost

if cfg.AppID == "" {
return nil, errors.New("`appId` is missing.")
Expand All @@ -65,9 +63,7 @@ func NewClientWithConfig(cfg Configuration) (*APIClient, error) {
}{{/hasRegionalHost}}
hosts = getDefaultHosts({{#hasRegionalHost}}cfg.Region{{/hasRegionalHost}}{{#hostWithAppID}}cfg.AppID{{/hostWithAppID}})
} else {
for _, h := range cfg.Hosts {
hosts = append(hosts, transport.NewStatefulHost(h, call.IsReadWrite))
}
hosts = cfg.Hosts
}
if cfg.Requester == nil {
cfg.Requester = transport.NewDefaultRequester(&cfg.ConnectTimeout)
Expand All @@ -91,34 +87,34 @@ func NewClientWithConfig(cfg Configuration) (*APIClient, error) {
}

{{#hasRegionalHost}}
func getDefaultHosts(r Region) []*transport.StatefulHost {
func getDefaultHosts(r Region) []transport.StatefulHost {
{{#fallbackToAliasHost}}
if r == "" {
return []*transport.StatefulHost{transport.NewStatefulHost("{{{hostWithFallback}}}", call.IsReadWrite)}
return []transport.StatefulHost{transport.NewStatefulHost("https", "{{{hostWithFallback}}}", call.IsReadWrite)}
}{{/fallbackToAliasHost}}

return []*transport.StatefulHost{transport.NewStatefulHost(strings.ReplaceAll("{{{regionalHost}}}", "{region}", string(r)), call.IsReadWrite)}
return []transport.StatefulHost{transport.NewStatefulHost("https", strings.ReplaceAll("{{{regionalHost}}}", "{region}", string(r)), call.IsReadWrite)}
}
{{/hasRegionalHost}}
{{#hostWithAppID}}
func getDefaultHosts(appID string) []*transport.StatefulHost {
hosts := []*transport.StatefulHost{
transport.NewStatefulHost(appID + "-dsn.algolia.net", call.IsRead),
transport.NewStatefulHost(appID + ".algolia.net", call.IsWrite),
func getDefaultHosts(appID string) []transport.StatefulHost {
hosts := []transport.StatefulHost{
transport.NewStatefulHost("https", appID + "-dsn.algolia.net", call.IsRead),
transport.NewStatefulHost("https", appID + ".algolia.net", call.IsWrite),
}
hosts = append(hosts, transport.Shuffle(
[]*transport.StatefulHost{
transport.NewStatefulHost(fmt.Sprintf(appID + "-1.algolianet.com"), call.IsReadWrite),
transport.NewStatefulHost(fmt.Sprintf(appID + "-2.algolianet.com"), call.IsReadWrite),
transport.NewStatefulHost(fmt.Sprintf(appID + "-3.algolianet.com"), call.IsReadWrite),
[]transport.StatefulHost{
transport.NewStatefulHost("https", fmt.Sprintf("%s-1.algolianet.com", appID), call.IsReadWrite),
transport.NewStatefulHost("https", fmt.Sprintf("%s-2.algolianet.com", appID), call.IsReadWrite),
transport.NewStatefulHost("https", fmt.Sprintf("%s-3.algolianet.com", appID), call.IsReadWrite),
},
)...)
return hosts
}
{{/hostWithAppID}}
{{#uniqueHost}}
func getDefaultHosts() []*transport.StatefulHost {
return []*transport.StatefulHost{transport.NewStatefulHost("{{{.}}}", call.IsReadWrite)}
func getDefaultHosts() []transport.StatefulHost {
return []transport.StatefulHost{transport.NewStatefulHost("https", "{{{.}}}", call.IsReadWrite)}
}
{{/uniqueHost}}

Expand Down Expand Up @@ -171,9 +167,7 @@ func (c *APIClient) prepareRequest(
headerParams map[string]string,
queryParams url.Values) (req *http.Request, err error) {

contentType := "application/json"

body, err := setBody(postBody, contentType, c.cfg.Compression)
body, err := setBody(postBody, c.cfg.Compression)
if err != nil {
return nil, fmt.Errorf("failed to set the body: %w", err)
}
Expand Down Expand Up @@ -214,6 +208,8 @@ func (c *APIClient) prepareRequest(
}
}

contentType := "application/json"

// Add the user agent to the request.
req.Header.Add("User-Agent", c.cfg.UserAgent)
req.Header.Add("X-Algolia-Application-Id", c.cfg.AppID)
Expand All @@ -233,29 +229,28 @@ func (c *APIClient) prepareRequest(
return req, nil
}

func (c *APIClient) decode(v any, b []byte, contentType string) error {
func (c *APIClient) decode(v any, b []byte) error {
if len(b) == 0 {
return nil
}
if s, ok := v.(*string); ok {
*s = string(b)
return nil
}
if jsonCheck.MatchString(contentType) {
if actualObj, ok := v.(interface{ GetActualInstance() any }); ok { // oneOf, anyOf schemas
if unmarshalObj, ok := actualObj.(interface{ UnmarshalJSON([]byte) error }); ok { // make sure it has UnmarshalJSON defined
if err := unmarshalObj.UnmarshalJSON(b); err != nil {
return fmt.Errorf("failed to unmarshal one of in response body: %w", err)
}
} else {
return errors.New("Unknown type with GetActualInstance but no unmarshalObj.UnmarshalJSON defined")
}
} else if err := json.Unmarshal(b, v); err != nil { // simple model
return fmt.Errorf("failed to unmarshal response body: %w", err)
}
return nil
}
return errors.New("undefined response type")

if actualObj, ok := v.(interface{ GetActualInstance() any }); ok { // oneOf, anyOf schemas
if unmarshalObj, ok := actualObj.(interface{ UnmarshalJSON([]byte) error }); ok { // make sure it has UnmarshalJSON defined
if err := unmarshalObj.UnmarshalJSON(b); err != nil {
return fmt.Errorf("failed to unmarshal one of in response body: %w", err)
}
} else {
return errors.New("Unknown type with GetActualInstance but no unmarshalObj.UnmarshalJSON defined")
}
} else if err := json.Unmarshal(b, v); err != nil { // simple model
return fmt.Errorf("failed to unmarshal response body: %w", err)
}

return nil
}

// Prevent trying to import "fmt"
Expand All @@ -282,7 +277,7 @@ func validateStruct(v any) error { //nolint:unused
}

// Set request body from an any
func setBody(body any, contentType string, c compression.Compression) (*bytes.Buffer, error) {
func setBody(body any, c compression.Compression) (*bytes.Buffer, error) {
if body == nil {
return nil, nil
}
Expand All @@ -304,7 +299,7 @@ func setBody(body any, contentType string, c compression.Compression) (*bytes.Bu
_, err = bodyBuf.WriteString(s)
} else if s, ok := body.(*string); ok {
_, err = bodyBuf.WriteString(*s)
} else if jsonCheck.MatchString(contentType) {
} else {
err = json.NewEncoder(bodyBuf).Encode(body)
}
}
Expand All @@ -314,7 +309,7 @@ func setBody(body any, contentType string, c compression.Compression) (*bytes.Bu
}

if bodyBuf.Len() == 0 {
return nil, fmt.Errorf("Invalid body type %s\n", contentType)
return nil, errors.New("Invalid body type, or empty body")
}
return bodyBuf, nil
}
Expand Down
25 changes: 13 additions & 12 deletions templates/go/tests/client/suite.mustache
Original file line number Diff line number Diff line change
Expand Up @@ -33,37 +33,38 @@ func create{{#lambda.titlecase}}{{clientPrefix}}{{/lambda.titlecase}}Client(t *t

{{#blocksClient}}
{{#tests}}
// {{testName}}
func Test{{#lambda.titlecase}}{{clientPrefix}}{{testType}}{{/lambda.titlecase}}{{testIndex}}(t *testing.T) {
var err error
{{#autoCreateClient}}
client, echo := create{{#lambda.titlecase}}{{clientPrefix}}{{/lambda.titlecase}}Client(t)
_ = echo
{{/autoCreateClient}}
{{^autoCreateClient}}
echo := &tests.EchoRequester{}
var client *{{clientPrefix}}.APIClient
var cfg {{clientPrefix}}.Configuration
_ = client
{{/autoCreateClient}}
_ = echo
{{#steps}}
{{#isError}}
{{#dynamicTemplate}}{{/dynamicTemplate}}
require.EqualError(t, err, "{{{expectedError}}}")
{{/isError}}
{{^isError}}
require.NoError(t, err)
{{#dynamicTemplate}}{{/dynamicTemplate}}
require.NoError(t, err)
{{#match}}
{{#testUserAgent}}
require.Regexp(t, regexp.MustCompile(`{{{match}}}`), echo.Header.Get("User-Agent"))
{{/testUserAgent}}
{{#testTimeouts}}
require.Equal(t, int64({{{match.parametersWithDataTypeMap.connectTimeout.value}}}), echo.ConnectTimeout.Milliseconds())
require.Equal(t, int64({{{match.parametersWithDataTypeMap.responseTimeout.value}}}), echo.Timeout.Milliseconds())
{{/testTimeouts}}
{{#testHost}}
require.Equal(t, "{{{match}}}", echo.Host)
{{/testHost}}
{{#testUserAgent}}
require.Regexp(t, regexp.MustCompile(`{{{match}}}`), echo.Header.Get("User-Agent"))
{{/testUserAgent}}
{{#testTimeouts}}
require.Equal(t, int64({{{match.parametersWithDataTypeMap.connectTimeout.value}}}), echo.ConnectTimeout.Milliseconds())
require.Equal(t, int64({{{match.parametersWithDataTypeMap.responseTimeout.value}}}), echo.Timeout.Milliseconds())
{{/testTimeouts}}
{{#testHost}}
require.Equal(t, "{{{match}}}", echo.Host)
{{/testHost}}
{{/match}}
{{/isError}}
{{/steps}}
Expand Down
4 changes: 1 addition & 3 deletions templates/go/tests/requests/requests.mustache
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,7 @@ func Test{{#lambda.titlecase}}{{clientPrefix}}{{/lambda.titlecase}}_{{#lambda.ti
{{/requestOptions}})
require.NoError(t, err)

expectedPath, err := url.QueryUnescape("{{{request.path}}}")
require.NoError(t, err)
require.Equal(t, expectedPath, echo.Path)
require.Equal(t, "{{{request.path}}}", echo.Path)
require.Equal(t, "{{{request.method}}}", echo.Method)

{{#request.body}}
Expand Down
2 changes: 1 addition & 1 deletion tests/output/go/tests/echo.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ type EchoRequester struct {

func (e *EchoRequester) Request(req *http.Request, timeout time.Duration, connectTimeout time.Duration) (*http.Response, error) {
e.Host = req.URL.Host
e.Path = req.URL.Path
e.Path = req.URL.EscapedPath()
e.Method = req.Method
e.Header = req.Header
e.Query = req.URL.Query()
Expand Down
Loading