Skip to content

Commit

Permalink
Merge pull request #20 from SkynetLabs/pj/ncmec-client-fu
Browse files Browse the repository at this point in the history
NCMEC F/U
  • Loading branch information
ro-tex authored Apr 4, 2022
2 parents af51ffa + 6907836 commit 8a25a75
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 16 deletions.
26 changes: 13 additions & 13 deletions email/ncmec.go
Original file line number Diff line number Diff line change
Expand Up @@ -247,15 +247,16 @@ func (c *NCMECClient) status() (reportResponse, error) {
// get is a helper function that executes a GET request on the given endpoint
// with the provided query values. The response will get unmarshaled into the
// given response object.
func (c *NCMECClient) get(endpoint string, query url.Values, headers http.Header, obj interface{}) error {
func (c *NCMECClient) get(endpoint string, query url.Values, headers http.Header, obj interface{}) (err error) {
url := fmt.Sprintf("%s%s", c.staticBaseUri, endpoint)

queryString := query.Encode()
if queryString != "" {
url += "?" + queryString
}

req, err := http.NewRequest(http.MethodGet, url, nil)
var req *http.Request
req, err = http.NewRequest(http.MethodGet, url, nil)
if err != nil {
return errors.AddContext(err, "failed to create request")
}
Expand All @@ -264,13 +265,14 @@ func (c *NCMECClient) get(endpoint string, query url.Values, headers http.Header
for k, v := range headers {
req.Header.Set(k, v[0])
}
res, err := http.DefaultClient.Do(req)
var res *http.Response
res, err = http.DefaultClient.Do(req)
if err != nil {
return err
}
defer func() {
io.Copy(ioutil.Discard, res.Body)
res.Body.Close()
err = errors.Compose(err, res.Body.Close())
}()

// decode the response body
Expand All @@ -279,7 +281,7 @@ func (c *NCMECClient) get(endpoint string, query url.Values, headers http.Header

// post is a helper function that executes a POST request on the given endpoint
// with the provided query values.
func (c *NCMECClient) post(endpoint string, query url.Values, headers http.Header, body io.Reader, obj interface{}) error {
func (c *NCMECClient) post(endpoint string, query url.Values, headers http.Header, body io.Reader, obj interface{}) (err error) {
url := fmt.Sprintf("%s%s", c.staticBaseUri, endpoint)

queryString := query.Encode()
Expand All @@ -288,7 +290,8 @@ func (c *NCMECClient) post(endpoint string, query url.Values, headers http.Heade
}

// create the request
req, err := http.NewRequest(http.MethodPost, url, body)
var req *http.Request
req, err = http.NewRequest(http.MethodPost, url, body)
if err != nil {
return errors.AddContext(err, "failed to create request")
}
Expand All @@ -297,19 +300,16 @@ func (c *NCMECClient) post(endpoint string, query url.Values, headers http.Heade
for k, v := range headers {
req.Header.Set(k, v[0])
}
res, err := http.DefaultClient.Do(req)
var res *http.Response
res, err = http.DefaultClient.Do(req)
if err != nil {
return err
}
defer func() {
io.Copy(ioutil.Discard, res.Body)
res.Body.Close()
err = errors.Compose(err, res.Body.Close())
}()

// decode the response body
err = xml.NewDecoder(res.Body).Decode(obj)
if err != nil {
return err
}
return nil
return xml.NewDecoder(res.Body).Decode(obj)
}
17 changes: 14 additions & 3 deletions email/ncmec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ func TestNCMECClient(t *testing.T) {
if testing.Short() {
t.SkipNow()
}
t.Parallel()

os.Setenv("NCMEC_USERNAME", "")
os.Setenv("NCMEC_PASSWORD", "")
Expand Down Expand Up @@ -58,10 +57,16 @@ func TestNCMECClient(t *testing.T) {

// testFinishReport is a unit test that verifies whether we can finish a report
func testFinishReport(t *testing.T, c *NCMECClient) {
// quickly assert NCMEC is up
res, err := c.status()
if err != nil || res.ResponseCode != ncmecStatusOK {
t.Fatal("NCMEC down")
}

// open a report
now := time.Now().UTC().Add(-time.Hour)
report := newTestReport(now)
res, err := c.openReport(report)
res, err = c.openReport(report)
if err != nil {
t.Fatal(err)
}
Expand All @@ -83,10 +88,16 @@ func testFinishReport(t *testing.T, c *NCMECClient) {

// testOpenReport is a unit test that verifies we can open a report with NCMEC
func testOpenReport(t *testing.T, c *NCMECClient) {
// quickly assert NCMEC is up
res, err := c.status()
if err != nil || res.ResponseCode != ncmecStatusOK {
t.Fatal("NCMEC down")
}

// open a report
now := time.Now().UTC().Add(time.Hour)
report := newTestReport(now)
res, err := c.openReport(report)
res, err = c.openReport(report)
if err != nil {
t.Fatal(err)
}
Expand Down

0 comments on commit 8a25a75

Please sign in to comment.