From 0366222ca7a11408a6e652c1db1e0e1af15b706c Mon Sep 17 00:00:00 2001 From: Alon Adam Date: Wed, 27 Apr 2022 16:19:04 +0300 Subject: [PATCH] Small refactoring to support 'HttpError'. --- kusto/conn.go | 6 +-- kusto/data/errors/errors.go | 52 +++++++++++++++---------- kusto/ingest/file_options_test.go | 2 +- kusto/ingest/internal/conn/conn.go | 2 +- kusto/ingest/internal/conn/conn_test.go | 3 +- kusto/ingest/streaming.go | 2 +- kusto/test/etoe/etoe_test.go | 2 +- 7 files changed, 39 insertions(+), 30 deletions(-) diff --git a/kusto/conn.go b/kusto/conn.go index 9c0ed048..37927888 100644 --- a/kusto/conn.go +++ b/kusto/conn.go @@ -183,12 +183,8 @@ func (c *conn) execute(ctx context.Context, execType int, db string, query Stmt, return execResp{}, err } - if resp.StatusCode == http.StatusTooManyRequests { - return execResp{}, errors.HTTPErrorCode(op, resp.StatusCode, body, fmt.Sprintf("request got throttled for query %q: ", query.String())) - } - if resp.StatusCode != http.StatusOK { - return execResp{}, errors.HTTP(op, resp.Status, body, fmt.Sprintf("error from Kusto endpoint for query %q: ", query.String())) + return execResp{}, errors.HTTP(op, resp.Status, resp.StatusCode, body, fmt.Sprintf("error from Kusto endpoint for query %q: ", query.String())) } var dec frames.Decoder diff --git a/kusto/data/errors/errors.go b/kusto/data/errors/errors.go index c0664847..79ab2807 100644 --- a/kusto/data/errors/errors.go +++ b/kusto/data/errors/errors.go @@ -18,7 +18,6 @@ import ( "io/ioutil" "net/http" "runtime" - "strconv" "strings" ) @@ -76,8 +75,10 @@ type Error struct { inner *Error } +type KustoError = Error + type HttpError struct { - err Error + KustoError StatusCode int } @@ -224,29 +225,23 @@ func ES(o Op, k Kind, s string, args ...interface{}) *Error { } // HTTP constructs an *Error from an *http.Response and a prefix to the error message. -func HTTP(o Op, status string, body io.ReadCloser, prefix string) *Error { +func HTTP(o Op, status string, statusCode int, body io.ReadCloser, prefix string) *HttpError { bodyBytes, err := ioutil.ReadAll(body) if err != nil { bodyBytes = []byte(fmt.Sprintf("Failed to read body: %v", err)) } - - e := &Error{ - Op: o, - Kind: KHTTPError, - restErrMsg: bodyBytes, - Err: fmt.Errorf("%s(%s):\n%s", prefix, status, string(bodyBytes)), + e := HttpError{ + KustoError: KustoError{ + Op: o, + Kind: KHTTPError, + restErrMsg: bodyBytes, + Err: fmt.Errorf("%s(%s):\n%s", prefix, status, string(bodyBytes)), + }, + StatusCode: statusCode, } - e.UnmarshalREST() - return e -} -func HTTPErrorCode(o Op, status int, body io.ReadCloser, prefix string) *HttpError { - err := HTTP(o, strconv.Itoa(status), body, prefix) - httpError := &HttpError{ - StatusCode: status, - err: *err, - } - return httpError + e.UnmarshalREST() + return &e } // e constructs an Error. You may pass in an Op, Kind, string or error. This will strip an *Error if you @@ -387,5 +382,22 @@ func (e *HttpError) IsThrottled() bool { } func (e *HttpError) Error() string { - return e.err.Error() + return e.KustoError.Error() +} + +func (e *HttpError) Unwrap() error { + if e == nil { + return nil + } + return e.KustoError.Unwrap() +} + +func GetKustoError(err error) (*Error, bool) { + if err, ok := err.(*Error); ok { + return err, true + } + if err, ok := err.(*HttpError); ok { + return &err.KustoError, true + } + return nil, false } diff --git a/kusto/ingest/file_options_test.go b/kusto/ingest/file_options_test.go index 75e5707a..a5216700 100644 --- a/kusto/ingest/file_options_test.go +++ b/kusto/ingest/file_options_test.go @@ -132,7 +132,7 @@ func TestOptions(t *testing.T) { case fromReader: _, err = test.ingestor.FromReader(ctx, bytes.NewReader([]byte{}), test.option) } - if e, ok := err.(*errors.Error); ok { + if e, ok := errors.GetKustoError(err); ok { assert.Equal(t, test.op, e.Op) assert.Equal(t, test.kind, e.Kind) } else { diff --git a/kusto/ingest/internal/conn/conn.go b/kusto/ingest/internal/conn/conn.go index 3835c96d..c195ce25 100644 --- a/kusto/ingest/internal/conn/conn.go +++ b/kusto/ingest/internal/conn/conn.go @@ -167,7 +167,7 @@ func (c *Conn) StreamIngest(ctx context.Context, db, table string, payload io.Re if err != nil { return err } - return errors.HTTP(writeOp, resp.Status, body, "streaming ingest issue") + return errors.HTTP(writeOp, resp.Status, resp.StatusCode, body, "streaming ingest issue") } return nil } diff --git a/kusto/ingest/internal/conn/conn_test.go b/kusto/ingest/internal/conn/conn_test.go index 9d19f772..bca6c4b8 100644 --- a/kusto/ingest/internal/conn/conn_test.go +++ b/kusto/ingest/internal/conn/conn_test.go @@ -196,7 +196,8 @@ func TestStream(t *testing.T) { err = conn.StreamIngest(ctx, db, "table", &payload, properties.JSON, test.mappingName, "") if test.err != nil { - assert.Equal(t, test.err, err.(*errors.Error).Err) + e, _ := errors.GetKustoError(err) + assert.Equal(t, test.err, e.Err) return } else { assert.NoError(t, err) diff --git a/kusto/ingest/streaming.go b/kusto/ingest/streaming.go index 9c00d331..30939a17 100644 --- a/kusto/ingest/streaming.go +++ b/kusto/ingest/streaming.go @@ -125,7 +125,7 @@ func streamImpl(c streamIngestor, ctx context.Context, payload io.Reader, props props.Streaming.ClientRequestId) if err != nil { - if e, ok := err.(*errors.Error); ok { + if e, ok := errors.GetKustoError(err); ok { return nil, e } return nil, errors.E(errors.OpIngestStream, errors.KClientArgs, err) diff --git a/kusto/test/etoe/etoe_test.go b/kusto/test/etoe/etoe_test.go index 451eb979..954e5092 100644 --- a/kusto/test/etoe/etoe_test.go +++ b/kusto/test/etoe/etoe_test.go @@ -1345,7 +1345,7 @@ func TestError(t *testing.T) { kusto.NewParameters().Must(kusto.QueryValues{"tableName": uuid.New().String()}), )) - kustoError, ok := err.(*errors.Error) + kustoError, ok := errors.GetKustoError(err) assert.True(t, ok) assert.Equal(t, errors.OpQuery, kustoError.Op) assert.Equal(t, errors.KHTTPError, kustoError.Kind)