diff --git a/email/ncmec.go b/email/ncmec.go index 438f3f7..c4808e4 100644 --- a/email/ncmec.go +++ b/email/ncmec.go @@ -247,7 +247,7 @@ func (c *NCMECClient) status() (reportResponse, error) { // get is a helper function that executes a GET request on the given endpoint // with the provided query values. The response will get unmarshaled into the // given response object. -func (c *NCMECClient) get(endpoint string, query url.Values, headers http.Header, obj interface{}) error { +func (c *NCMECClient) get(endpoint string, query url.Values, headers http.Header, obj interface{}) (err error) { url := fmt.Sprintf("%s%s", c.staticBaseUri, endpoint) queryString := query.Encode() @@ -255,7 +255,8 @@ func (c *NCMECClient) get(endpoint string, query url.Values, headers http.Header url += "?" + queryString } - req, err := http.NewRequest(http.MethodGet, url, nil) + var req *http.Request + req, err = http.NewRequest(http.MethodGet, url, nil) if err != nil { return errors.AddContext(err, "failed to create request") } @@ -264,13 +265,14 @@ func (c *NCMECClient) get(endpoint string, query url.Values, headers http.Header for k, v := range headers { req.Header.Set(k, v[0]) } - res, err := http.DefaultClient.Do(req) + var res *http.Response + res, err = http.DefaultClient.Do(req) if err != nil { return err } defer func() { io.Copy(ioutil.Discard, res.Body) - res.Body.Close() + err = errors.Compose(err, res.Body.Close()) }() // decode the response body @@ -279,7 +281,7 @@ func (c *NCMECClient) get(endpoint string, query url.Values, headers http.Header // post is a helper function that executes a POST request on the given endpoint // with the provided query values. -func (c *NCMECClient) post(endpoint string, query url.Values, headers http.Header, body io.Reader, obj interface{}) error { +func (c *NCMECClient) post(endpoint string, query url.Values, headers http.Header, body io.Reader, obj interface{}) (err error) { url := fmt.Sprintf("%s%s", c.staticBaseUri, endpoint) queryString := query.Encode() @@ -288,7 +290,8 @@ func (c *NCMECClient) post(endpoint string, query url.Values, headers http.Heade } // create the request - req, err := http.NewRequest(http.MethodPost, url, body) + var req *http.Request + req, err = http.NewRequest(http.MethodPost, url, body) if err != nil { return errors.AddContext(err, "failed to create request") } @@ -297,19 +300,16 @@ func (c *NCMECClient) post(endpoint string, query url.Values, headers http.Heade for k, v := range headers { req.Header.Set(k, v[0]) } - res, err := http.DefaultClient.Do(req) + var res *http.Response + res, err = http.DefaultClient.Do(req) if err != nil { return err } defer func() { io.Copy(ioutil.Discard, res.Body) - res.Body.Close() + err = errors.Compose(err, res.Body.Close()) }() // decode the response body - err = xml.NewDecoder(res.Body).Decode(obj) - if err != nil { - return err - } - return nil + return xml.NewDecoder(res.Body).Decode(obj) } diff --git a/email/ncmec_test.go b/email/ncmec_test.go index 8ab4153..fd59385 100644 --- a/email/ncmec_test.go +++ b/email/ncmec_test.go @@ -13,7 +13,6 @@ func TestNCMECClient(t *testing.T) { if testing.Short() { t.SkipNow() } - t.Parallel() os.Setenv("NCMEC_USERNAME", "") os.Setenv("NCMEC_PASSWORD", "") @@ -58,10 +57,16 @@ func TestNCMECClient(t *testing.T) { // testFinishReport is a unit test that verifies whether we can finish a report func testFinishReport(t *testing.T, c *NCMECClient) { + // quickly assert NCMEC is up + res, err := c.status() + if err != nil || res.ResponseCode != ncmecStatusOK { + t.Fatal("NCMEC down") + } + // open a report now := time.Now().UTC().Add(-time.Hour) report := newTestReport(now) - res, err := c.openReport(report) + res, err = c.openReport(report) if err != nil { t.Fatal(err) } @@ -83,10 +88,16 @@ func testFinishReport(t *testing.T, c *NCMECClient) { // testOpenReport is a unit test that verifies we can open a report with NCMEC func testOpenReport(t *testing.T, c *NCMECClient) { + // quickly assert NCMEC is up + res, err := c.status() + if err != nil || res.ResponseCode != ncmecStatusOK { + t.Fatal("NCMEC down") + } + // open a report now := time.Now().UTC().Add(time.Hour) report := newTestReport(now) - res, err := c.openReport(report) + res, err = c.openReport(report) if err != nil { t.Fatal(err) }