Skip to content

Commit

Permalink
fix(go): support http (#2639)
Browse files Browse the repository at this point in the history
  • Loading branch information
millotp authored Jan 30, 2024
1 parent 17814f1 commit 366c994
Show file tree
Hide file tree
Showing 10 changed files with 73 additions and 74 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ type Configuration struct {
AppID string
ApiKey string

Hosts []string
Hosts []StatefulHost
DefaultHeader map[string]string
UserAgent string
Requester Requester
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,19 @@ const (
)

type Host struct {
scheme string
host string
timeout time.Duration
}

type RetryStrategy struct {
sync.RWMutex
hosts []*StatefulHost
hosts []StatefulHost
writeTimeout time.Duration
readTimeout time.Duration
}

func newRetryStrategy(hosts []*StatefulHost, readTimeout, writeTimeout time.Duration) *RetryStrategy {
func newRetryStrategy(hosts []StatefulHost, readTimeout, writeTimeout time.Duration) *RetryStrategy {
if readTimeout == 0 {
readTimeout = DefaultReadTimeout
}
Expand Down Expand Up @@ -74,7 +75,7 @@ func (s *RetryStrategy) GetTryableHosts(k call.Kind) []Host {

for _, h := range s.hosts {
if !h.isDown && h.accept(k) {
hosts = append(hosts, Host{h.host, time.Duration(h.retryCount+1) * baseTimeout})
hosts = append(hosts, Host{h.scheme, h.host, time.Duration(h.retryCount+1) * baseTimeout})
}
}

Expand All @@ -85,7 +86,7 @@ func (s *RetryStrategy) GetTryableHosts(k call.Kind) []Host {
for _, h := range s.hosts {
if h.accept(k) {
h.reset()
hosts = append(hosts, Host{h.host, time.Duration(h.retryCount+1) * baseTimeout})
hosts = append(hosts, Host{h.scheme, h.host, time.Duration(h.retryCount+1) * baseTimeout})
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,17 @@ const (
)

type StatefulHost struct {
scheme string
host string
isDown bool
retryCount int
lastUpdate time.Time
accept func(k call.Kind) bool
}

func NewStatefulHost(host string, accept func(k call.Kind) bool) *StatefulHost {
return &StatefulHost{
func NewStatefulHost(scheme string, host string, accept func(k call.Kind) bool) StatefulHost {
return StatefulHost{
scheme: scheme,
host: host,
isDown: false,
retryCount: 0,
Expand Down
10 changes: 5 additions & 5 deletions clients/algoliasearch-client-go/algolia/transport/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ type Transport struct {
}

func New(
hosts []*StatefulHost,
hosts []StatefulHost,
requester Requester,
readTimeout time.Duration,
writeTimeout time.Duration,
Expand Down Expand Up @@ -65,7 +65,7 @@ func (t *Transport) Request(ctx context.Context, req *http.Request, k call.Kind)
// cancelled` error may happen when the body is read.
perRequestCtx, cancel := context.WithTimeout(ctx, h.timeout)
req = req.WithContext(perRequestCtx)
res, err := t.request(req, h.host, h.timeout, t.connectTimeout)
res, err := t.request(req, h, h.timeout, t.connectTimeout)

code := 0
if res != nil {
Expand Down Expand Up @@ -116,9 +116,9 @@ func (t *Transport) Request(ctx context.Context, req *http.Request, k call.Kind)
return nil, nil, errs.ErrNoMoreHostToTry
}

func (t *Transport) request(req *http.Request, host string, timeout time.Duration, connectTimeout time.Duration) (*http.Response, error) {
req.URL.Scheme = "https"
req.URL.Host = host
func (t *Transport) request(req *http.Request, host Host, timeout time.Duration, connectTimeout time.Duration) (*http.Response, error) {
req.URL.Scheme = host.scheme
req.URL.Host = host.host

debug.Display(req)
res, err := t.requester.Request(req, timeout, connectTimeout)
Expand Down
4 changes: 2 additions & 2 deletions clients/algoliasearch-client-go/algolia/transport/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@ import (
"math/rand"
)

func Shuffle(hosts []*StatefulHost) []*StatefulHost {
func Shuffle(hosts []StatefulHost) []StatefulHost {
if hosts == nil {
return nil
}
shuffled := make([]*StatefulHost, len(hosts))
shuffled := make([]StatefulHost, len(hosts))
for i, v := range rand.Perm(len(hosts)) {
shuffled[i] = hosts[v]
}
Expand Down
8 changes: 5 additions & 3 deletions templates/go/api.mustache
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,10 @@ func (c *APIClient) {{nickname}}WithContext(ctx context.Context, {{#hasParams}}r
{{/returnType}}
)

{{#vendorExtensions}}
requestPath := "{{{path}}}"{{#pathParams}}
requestPath = strings.ReplaceAll(requestPath, {{=<% %>=}}"{<%baseName%>}"<%={{ }}=%>, url.PathEscape(parameterToString(r.{{paramName}}))){{/pathParams}}
requestPath = strings.ReplaceAll(requestPath, {{=<% %>=}}"{<%baseName%>}"<%={{ }}=%>, {{#x-is-custom-request}}parameterToString(r.{{paramName}}){{/x-is-custom-request}}{{^x-is-custom-request}}url.PathEscape(parameterToString(r.{{paramName}})){{/x-is-custom-request}}){{/pathParams}}
{{/vendorExtensions}}

headers := make(map[string]string)
queryParams := url.Values{}
Expand Down Expand Up @@ -287,7 +289,7 @@ func (c *APIClient) {{nickname}}WithContext(ctx context.Context, {{#hasParams}}r
}

var v ErrorBase
err = c.decode(&v, resBody, res.Header.Get("Content-Type"))
err = c.decode(&v, resBody)
if err != nil {
newErr.Message = err.Error()
return {{#returnType}}returnValue, {{/returnType}}newErr
Expand All @@ -297,7 +299,7 @@ func (c *APIClient) {{nickname}}WithContext(ctx context.Context, {{#hasParams}}r
}

{{#returnType}}
err = c.decode(&returnValue, resBody, res.Header.Get("Content-Type"))
err = c.decode(&returnValue, resBody)
if err != nil {
return {{#returnType}}returnValue, {{/returnType}}reportError("cannot decode result: %w", err)
}
Expand Down
77 changes: 36 additions & 41 deletions templates/go/client.mustache
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@ import (
"github.com/algolia/algoliasearch-client-go/v4/algolia/transport"
)

var jsonCheck = regexp.MustCompile(`(?i:(?:application|text)/(?:vnd\.[^;]+\+)?json)`)

// APIClient manages communication with the {{appName}} API v{{version}}
// In most cases there should be only one, shared, APIClient.
type APIClient struct {
Expand All @@ -51,7 +49,7 @@ return NewClientWithConfig(Configuration{

// NewClientWithConfig creates a new API client with the given configuration to fully customize the client behaviour.
func NewClientWithConfig(cfg Configuration) (*APIClient, error) {
var hosts []*transport.StatefulHost
var hosts []transport.StatefulHost
if cfg.AppID == "" {
return nil, errors.New("`appId` is missing.")
Expand All @@ -65,9 +63,7 @@ func NewClientWithConfig(cfg Configuration) (*APIClient, error) {
}{{/hasRegionalHost}}
hosts = getDefaultHosts({{#hasRegionalHost}}cfg.Region{{/hasRegionalHost}}{{#hostWithAppID}}cfg.AppID{{/hostWithAppID}})
} else {
for _, h := range cfg.Hosts {
hosts = append(hosts, transport.NewStatefulHost(h, call.IsReadWrite))
}
hosts = cfg.Hosts
}
if cfg.Requester == nil {
cfg.Requester = transport.NewDefaultRequester(&cfg.ConnectTimeout)
Expand All @@ -91,34 +87,34 @@ func NewClientWithConfig(cfg Configuration) (*APIClient, error) {
}

{{#hasRegionalHost}}
func getDefaultHosts(r Region) []*transport.StatefulHost {
func getDefaultHosts(r Region) []transport.StatefulHost {
{{#fallbackToAliasHost}}
if r == "" {
return []*transport.StatefulHost{transport.NewStatefulHost("{{{hostWithFallback}}}", call.IsReadWrite)}
return []transport.StatefulHost{transport.NewStatefulHost("https", "{{{hostWithFallback}}}", call.IsReadWrite)}
}{{/fallbackToAliasHost}}

return []*transport.StatefulHost{transport.NewStatefulHost(strings.ReplaceAll("{{{regionalHost}}}", "{region}", string(r)), call.IsReadWrite)}
return []transport.StatefulHost{transport.NewStatefulHost("https", strings.ReplaceAll("{{{regionalHost}}}", "{region}", string(r)), call.IsReadWrite)}
}
{{/hasRegionalHost}}
{{#hostWithAppID}}
func getDefaultHosts(appID string) []*transport.StatefulHost {
hosts := []*transport.StatefulHost{
transport.NewStatefulHost(appID + "-dsn.algolia.net", call.IsRead),
transport.NewStatefulHost(appID + ".algolia.net", call.IsWrite),
func getDefaultHosts(appID string) []transport.StatefulHost {
hosts := []transport.StatefulHost{
transport.NewStatefulHost("https", appID + "-dsn.algolia.net", call.IsRead),
transport.NewStatefulHost("https", appID + ".algolia.net", call.IsWrite),
}
hosts = append(hosts, transport.Shuffle(
[]*transport.StatefulHost{
transport.NewStatefulHost(fmt.Sprintf(appID + "-1.algolianet.com"), call.IsReadWrite),
transport.NewStatefulHost(fmt.Sprintf(appID + "-2.algolianet.com"), call.IsReadWrite),
transport.NewStatefulHost(fmt.Sprintf(appID + "-3.algolianet.com"), call.IsReadWrite),
[]transport.StatefulHost{
transport.NewStatefulHost("https", fmt.Sprintf("%s-1.algolianet.com", appID), call.IsReadWrite),
transport.NewStatefulHost("https", fmt.Sprintf("%s-2.algolianet.com", appID), call.IsReadWrite),
transport.NewStatefulHost("https", fmt.Sprintf("%s-3.algolianet.com", appID), call.IsReadWrite),
},
)...)
return hosts
}
{{/hostWithAppID}}
{{#uniqueHost}}
func getDefaultHosts() []*transport.StatefulHost {
return []*transport.StatefulHost{transport.NewStatefulHost("{{{.}}}", call.IsReadWrite)}
func getDefaultHosts() []transport.StatefulHost {
return []transport.StatefulHost{transport.NewStatefulHost("https", "{{{.}}}", call.IsReadWrite)}
}
{{/uniqueHost}}

Expand Down Expand Up @@ -171,9 +167,7 @@ func (c *APIClient) prepareRequest(
headerParams map[string]string,
queryParams url.Values) (req *http.Request, err error) {
contentType := "application/json"
body, err := setBody(postBody, contentType, c.cfg.Compression)
body, err := setBody(postBody, c.cfg.Compression)
if err != nil {
return nil, fmt.Errorf("failed to set the body: %w", err)
}
Expand Down Expand Up @@ -214,6 +208,8 @@ func (c *APIClient) prepareRequest(
}
}

contentType := "application/json"

// Add the user agent to the request.
req.Header.Add("User-Agent", c.cfg.UserAgent)
req.Header.Add("X-Algolia-Application-Id", c.cfg.AppID)
Expand All @@ -233,29 +229,28 @@ func (c *APIClient) prepareRequest(
return req, nil
}

func (c *APIClient) decode(v any, b []byte, contentType string) error {
func (c *APIClient) decode(v any, b []byte) error {
if len(b) == 0 {
return nil
}
if s, ok := v.(*string); ok {
*s = string(b)
return nil
}
if jsonCheck.MatchString(contentType) {
if actualObj, ok := v.(interface{ GetActualInstance() any }); ok { // oneOf, anyOf schemas
if unmarshalObj, ok := actualObj.(interface{ UnmarshalJSON([]byte) error }); ok { // make sure it has UnmarshalJSON defined
if err := unmarshalObj.UnmarshalJSON(b); err != nil {
return fmt.Errorf("failed to unmarshal one of in response body: %w", err)
}
} else {
return errors.New("Unknown type with GetActualInstance but no unmarshalObj.UnmarshalJSON defined")
}
} else if err := json.Unmarshal(b, v); err != nil { // simple model
return fmt.Errorf("failed to unmarshal response body: %w", err)
}
return nil
}
return errors.New("undefined response type")

if actualObj, ok := v.(interface{ GetActualInstance() any }); ok { // oneOf, anyOf schemas
if unmarshalObj, ok := actualObj.(interface{ UnmarshalJSON([]byte) error }); ok { // make sure it has UnmarshalJSON defined
if err := unmarshalObj.UnmarshalJSON(b); err != nil {
return fmt.Errorf("failed to unmarshal one of in response body: %w", err)
}
} else {
return errors.New("Unknown type with GetActualInstance but no unmarshalObj.UnmarshalJSON defined")
}
} else if err := json.Unmarshal(b, v); err != nil { // simple model
return fmt.Errorf("failed to unmarshal response body: %w", err)
}

return nil
}

// Prevent trying to import "fmt"
Expand All @@ -282,7 +277,7 @@ func validateStruct(v any) error { //nolint:unused
}

// Set request body from an any
func setBody(body any, contentType string, c compression.Compression) (*bytes.Buffer, error) {
func setBody(body any, c compression.Compression) (*bytes.Buffer, error) {
if body == nil {
return nil, nil
}
Expand All @@ -304,7 +299,7 @@ func setBody(body any, contentType string, c compression.Compression) (*bytes.Bu
_, err = bodyBuf.WriteString(s)
} else if s, ok := body.(*string); ok {
_, err = bodyBuf.WriteString(*s)
} else if jsonCheck.MatchString(contentType) {
} else {
err = json.NewEncoder(bodyBuf).Encode(body)
}
}
Expand All @@ -314,7 +309,7 @@ func setBody(body any, contentType string, c compression.Compression) (*bytes.Bu
}

if bodyBuf.Len() == 0 {
return nil, fmt.Errorf("Invalid body type %s\n", contentType)
return nil, errors.New("Invalid body type, or empty body")
}
return bodyBuf, nil
}
Expand Down
25 changes: 13 additions & 12 deletions templates/go/tests/client/suite.mustache
Original file line number Diff line number Diff line change
Expand Up @@ -33,37 +33,38 @@ func create{{#lambda.titlecase}}{{clientPrefix}}{{/lambda.titlecase}}Client(t *t

{{#blocksClient}}
{{#tests}}
// {{testName}}
func Test{{#lambda.titlecase}}{{clientPrefix}}{{testType}}{{/lambda.titlecase}}{{testIndex}}(t *testing.T) {
var err error
{{#autoCreateClient}}
client, echo := create{{#lambda.titlecase}}{{clientPrefix}}{{/lambda.titlecase}}Client(t)
_ = echo
{{/autoCreateClient}}
{{^autoCreateClient}}
echo := &tests.EchoRequester{}
var client *{{clientPrefix}}.APIClient
var cfg {{clientPrefix}}.Configuration
_ = client
{{/autoCreateClient}}
_ = echo
{{#steps}}
{{#isError}}
{{#dynamicTemplate}}{{/dynamicTemplate}}
require.EqualError(t, err, "{{{expectedError}}}")
{{/isError}}
{{^isError}}
require.NoError(t, err)
{{#dynamicTemplate}}{{/dynamicTemplate}}
require.NoError(t, err)
{{#match}}
{{#testUserAgent}}
require.Regexp(t, regexp.MustCompile(`{{{match}}}`), echo.Header.Get("User-Agent"))
{{/testUserAgent}}
{{#testTimeouts}}
require.Equal(t, int64({{{match.parametersWithDataTypeMap.connectTimeout.value}}}), echo.ConnectTimeout.Milliseconds())
require.Equal(t, int64({{{match.parametersWithDataTypeMap.responseTimeout.value}}}), echo.Timeout.Milliseconds())
{{/testTimeouts}}
{{#testHost}}
require.Equal(t, "{{{match}}}", echo.Host)
{{/testHost}}
{{#testUserAgent}}
require.Regexp(t, regexp.MustCompile(`{{{match}}}`), echo.Header.Get("User-Agent"))
{{/testUserAgent}}
{{#testTimeouts}}
require.Equal(t, int64({{{match.parametersWithDataTypeMap.connectTimeout.value}}}), echo.ConnectTimeout.Milliseconds())
require.Equal(t, int64({{{match.parametersWithDataTypeMap.responseTimeout.value}}}), echo.Timeout.Milliseconds())
{{/testTimeouts}}
{{#testHost}}
require.Equal(t, "{{{match}}}", echo.Host)
{{/testHost}}
{{/match}}
{{/isError}}
{{/steps}}
Expand Down
4 changes: 1 addition & 3 deletions templates/go/tests/requests/requests.mustache
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,7 @@ func Test{{#lambda.titlecase}}{{clientPrefix}}{{/lambda.titlecase}}_{{#lambda.ti
{{/requestOptions}})
require.NoError(t, err)

expectedPath, err := url.QueryUnescape("{{{request.path}}}")
require.NoError(t, err)
require.Equal(t, expectedPath, echo.Path)
require.Equal(t, "{{{request.path}}}", echo.Path)
require.Equal(t, "{{{request.method}}}", echo.Method)

{{#request.body}}
Expand Down
2 changes: 1 addition & 1 deletion tests/output/go/tests/echo.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ type EchoRequester struct {

func (e *EchoRequester) Request(req *http.Request, timeout time.Duration, connectTimeout time.Duration) (*http.Response, error) {
e.Host = req.URL.Host
e.Path = req.URL.Path
e.Path = req.URL.EscapedPath()
e.Method = req.Method
e.Header = req.Header
e.Query = req.URL.Query()
Expand Down

0 comments on commit 366c994

Please sign in to comment.