diff --git a/CHANGELOG.md b/CHANGELOG.md index dddf3ad1d..05fe5185c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,14 @@ # CHANGELOG +## v13.2.0 + +### New Features + +- Added the following functions to replace their versions that don't take a context. + - `adal.InitiateDeviceAuthWithContext()` + - `adal.CheckForUserCompletionWithContext()` + - `adal.WaitForUserCompletionWithContext()` + ## v13.1.0 ### New Features diff --git a/autorest/adal/devicetoken.go b/autorest/adal/devicetoken.go index b38f4c245..914f8af5e 100644 --- a/autorest/adal/devicetoken.go +++ b/autorest/adal/devicetoken.go @@ -24,6 +24,7 @@ package adal */ import ( + "context" "encoding/json" "fmt" "io/ioutil" @@ -101,7 +102,14 @@ type deviceToken struct { // InitiateDeviceAuth initiates a device auth flow. It returns a DeviceCode // that can be used with CheckForUserCompletion or WaitForUserCompletion. +// Deprecated: use InitiateDeviceAuthWithContext() instead. func InitiateDeviceAuth(sender Sender, oauthConfig OAuthConfig, clientID, resource string) (*DeviceCode, error) { + return InitiateDeviceAuthWithContext(context.Background(), sender, oauthConfig, clientID, resource) +} + +// InitiateDeviceAuthWithContext initiates a device auth flow. It returns a DeviceCode +// that can be used with CheckForUserCompletion or WaitForUserCompletion. +func InitiateDeviceAuthWithContext(ctx context.Context, sender Sender, oauthConfig OAuthConfig, clientID, resource string) (*DeviceCode, error) { v := url.Values{ "client_id": []string{clientID}, "resource": []string{resource}, @@ -117,7 +125,7 @@ func InitiateDeviceAuth(sender Sender, oauthConfig OAuthConfig, clientID, resour req.ContentLength = int64(len(s)) req.Header.Set(contentType, mimeTypeFormPost) - resp, err := sender.Do(req) + resp, err := sender.Do(req.WithContext(ctx)) if err != nil { return nil, fmt.Errorf("%s %s: %s", logPrefix, errCodeSendingFails, err.Error()) } @@ -151,7 +159,14 @@ func InitiateDeviceAuth(sender Sender, oauthConfig OAuthConfig, clientID, resour // CheckForUserCompletion takes a DeviceCode and checks with the Azure AD OAuth endpoint // to see if the device flow has: been completed, timed out, or otherwise failed +// Deprecated: use CheckForUserCompletionWithContext() instead. func CheckForUserCompletion(sender Sender, code *DeviceCode) (*Token, error) { + return CheckForUserCompletionWithContext(context.Background(), sender, code) +} + +// CheckForUserCompletionWithContext takes a DeviceCode and checks with the Azure AD OAuth endpoint +// to see if the device flow has: been completed, timed out, or otherwise failed +func CheckForUserCompletionWithContext(ctx context.Context, sender Sender, code *DeviceCode) (*Token, error) { v := url.Values{ "client_id": []string{code.ClientID}, "code": []string{*code.DeviceCode}, @@ -169,7 +184,7 @@ func CheckForUserCompletion(sender Sender, code *DeviceCode) (*Token, error) { req.ContentLength = int64(len(s)) req.Header.Set(contentType, mimeTypeFormPost) - resp, err := sender.Do(req) + resp, err := sender.Do(req.WithContext(ctx)) if err != nil { return nil, fmt.Errorf("%s %s: %s", logPrefix, errTokenSendingFails, err.Error()) } @@ -213,12 +228,19 @@ func CheckForUserCompletion(sender Sender, code *DeviceCode) (*Token, error) { // WaitForUserCompletion calls CheckForUserCompletion repeatedly until a token is granted or an error state occurs. // This prevents the user from looping and checking against 'ErrDeviceAuthorizationPending'. +// Deprecated: use WaitForUserCompletionWithContext() instead. func WaitForUserCompletion(sender Sender, code *DeviceCode) (*Token, error) { + return WaitForUserCompletionWithContext(context.Background(), sender, code) +} + +// WaitForUserCompletionWithContext calls CheckForUserCompletion repeatedly until a token is granted or an error +// state occurs. This prevents the user from looping and checking against 'ErrDeviceAuthorizationPending'. +func WaitForUserCompletionWithContext(ctx context.Context, sender Sender, code *DeviceCode) (*Token, error) { intervalDuration := time.Duration(*code.Interval) * time.Second waitDuration := intervalDuration for { - token, err := CheckForUserCompletion(sender, code) + token, err := CheckForUserCompletionWithContext(ctx, sender, code) if err == nil { return token, nil @@ -237,6 +259,11 @@ func WaitForUserCompletion(sender Sender, code *DeviceCode) (*Token, error) { return nil, fmt.Errorf("%s Error waiting for user to complete device flow. Server told us to slow_down too much", logPrefix) } - time.Sleep(waitDuration) + select { + case <-time.After(waitDuration): + // noop + case <-ctx.Done(): + return nil, ctx.Err() + } } } diff --git a/autorest/adal/devicetoken_test.go b/autorest/adal/devicetoken_test.go index 6beee161f..060749dc9 100644 --- a/autorest/adal/devicetoken_test.go +++ b/autorest/adal/devicetoken_test.go @@ -15,11 +15,13 @@ package adal // limitations under the License. import ( + "context" "encoding/json" "fmt" "net/http" "strings" "testing" + "time" "github.com/Azure/go-autorest/autorest/mocks" ) @@ -328,3 +330,15 @@ func TestDeviceTokenReturnsErrorIfTokenEmptyAndStatusOK(t *testing.T) { t.Fatalf("response body was left open!") } } + +func TestWaitForUserCompletionWithContext(t *testing.T) { + sender := SenderFunc(func(*http.Request) (*http.Response, error) { + return mocks.NewResponseWithContent(`{"error":"authorization_pending"}`), nil + }) + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + _, err := WaitForUserCompletionWithContext(ctx, sender, deviceCode()) + if err != context.DeadlineExceeded { + t.Fatalf("adal: got wrong error expected(%s) actual(%s)", context.DeadlineExceeded.Error(), err.Error()) + } +} diff --git a/autorest/version.go b/autorest/version.go index 925d046fa..b73e695ac 100644 --- a/autorest/version.go +++ b/autorest/version.go @@ -19,7 +19,7 @@ import ( "runtime" ) -const number = "v13.1.0" +const number = "v13.2.0" var ( userAgent = fmt.Sprintf("Go/%s (%s-%s) go-autorest/%s", diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 8799e19ef..f33aa0710 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -24,7 +24,7 @@ steps: go get github.com/jstemmer/go-junit-report go get github.com/axw/gocov/gocov go get github.com/AlekSi/gocov-xml - go get -u gopkg.in/matm/v1/gocov-html + go get -u github.com/matm/gocov-html workingDirectory: '$(sdkPath)' displayName: 'Install Dependencies' - script: |