Skip to content
This repository has been archived by the owner on Feb 12, 2025. It is now read-only.

Add ability to cancel device code flow auth #474

Merged
merged 3 commits into from
Oct 14, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
# CHANGELOG

## v13.2.0

### New Features

- Added the following functions to replace their versions that don't take a context.
- `adal.InitiateDeviceAuthWithContext()`
- `adal.CheckForUserCompletionWithContext()`
- `adal.WaitForUserCompletionWithContext()`

## v13.1.0

### New Features
Expand Down
35 changes: 31 additions & 4 deletions autorest/adal/devicetoken.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ package adal
*/

import (
"context"
"encoding/json"
"fmt"
"io/ioutil"
Expand Down Expand Up @@ -101,7 +102,14 @@ type deviceToken struct {

// InitiateDeviceAuth initiates a device auth flow. It returns a DeviceCode
// that can be used with CheckForUserCompletion or WaitForUserCompletion.
// Deprecated: use InitiateDeviceAuthWithContext() instead.
func InitiateDeviceAuth(sender Sender, oauthConfig OAuthConfig, clientID, resource string) (*DeviceCode, error) {
return InitiateDeviceAuthWithContext(context.Background(), sender, oauthConfig, clientID, resource)
}

// InitiateDeviceAuthWithContext initiates a device auth flow. It returns a DeviceCode
// that can be used with CheckForUserCompletion or WaitForUserCompletion.
func InitiateDeviceAuthWithContext(ctx context.Context, sender Sender, oauthConfig OAuthConfig, clientID, resource string) (*DeviceCode, error) {
v := url.Values{
"client_id": []string{clientID},
"resource": []string{resource},
Expand All @@ -117,7 +125,7 @@ func InitiateDeviceAuth(sender Sender, oauthConfig OAuthConfig, clientID, resour

req.ContentLength = int64(len(s))
req.Header.Set(contentType, mimeTypeFormPost)
resp, err := sender.Do(req)
resp, err := sender.Do(req.WithContext(ctx))
if err != nil {
return nil, fmt.Errorf("%s %s: %s", logPrefix, errCodeSendingFails, err.Error())
}
Expand Down Expand Up @@ -151,7 +159,14 @@ func InitiateDeviceAuth(sender Sender, oauthConfig OAuthConfig, clientID, resour

// CheckForUserCompletion takes a DeviceCode and checks with the Azure AD OAuth endpoint
// to see if the device flow has: been completed, timed out, or otherwise failed
// Deprecated: use CheckForUserCompletionWithContext() instead.
func CheckForUserCompletion(sender Sender, code *DeviceCode) (*Token, error) {
return CheckForUserCompletionWithContext(context.Background(), sender, code)
}

// CheckForUserCompletionWithContext takes a DeviceCode and checks with the Azure AD OAuth endpoint
// to see if the device flow has: been completed, timed out, or otherwise failed
func CheckForUserCompletionWithContext(ctx context.Context, sender Sender, code *DeviceCode) (*Token, error) {
v := url.Values{
"client_id": []string{code.ClientID},
"code": []string{*code.DeviceCode},
Expand All @@ -169,7 +184,7 @@ func CheckForUserCompletion(sender Sender, code *DeviceCode) (*Token, error) {

req.ContentLength = int64(len(s))
req.Header.Set(contentType, mimeTypeFormPost)
resp, err := sender.Do(req)
resp, err := sender.Do(req.WithContext(ctx))
if err != nil {
return nil, fmt.Errorf("%s %s: %s", logPrefix, errTokenSendingFails, err.Error())
}
Expand Down Expand Up @@ -213,12 +228,19 @@ func CheckForUserCompletion(sender Sender, code *DeviceCode) (*Token, error) {

// WaitForUserCompletion calls CheckForUserCompletion repeatedly until a token is granted or an error state occurs.
// This prevents the user from looping and checking against 'ErrDeviceAuthorizationPending'.
// Deprecated: use WaitForUserCompletionWithContext() instead.
func WaitForUserCompletion(sender Sender, code *DeviceCode) (*Token, error) {
return WaitForUserCompletionWithContext(context.Background(), sender, code)
}

// WaitForUserCompletionWithContext calls CheckForUserCompletion repeatedly until a token is granted or an error
// state occurs. This prevents the user from looping and checking against 'ErrDeviceAuthorizationPending'.
func WaitForUserCompletionWithContext(ctx context.Context, sender Sender, code *DeviceCode) (*Token, error) {
intervalDuration := time.Duration(*code.Interval) * time.Second
waitDuration := intervalDuration

for {
token, err := CheckForUserCompletion(sender, code)
token, err := CheckForUserCompletionWithContext(ctx, sender, code)

if err == nil {
return token, nil
Expand All @@ -237,6 +259,11 @@ func WaitForUserCompletion(sender Sender, code *DeviceCode) (*Token, error) {
return nil, fmt.Errorf("%s Error waiting for user to complete device flow. Server told us to slow_down too much", logPrefix)
}

time.Sleep(waitDuration)
select {
case <-time.After(waitDuration):
// noop
case <-ctx.Done():
return nil, ctx.Err()
}
}
}
14 changes: 14 additions & 0 deletions autorest/adal/devicetoken_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@ package adal
// limitations under the License.

import (
"context"
"encoding/json"
"fmt"
"net/http"
"strings"
"testing"
"time"

"github.com/Azure/go-autorest/autorest/mocks"
)
Expand Down Expand Up @@ -328,3 +330,15 @@ func TestDeviceTokenReturnsErrorIfTokenEmptyAndStatusOK(t *testing.T) {
t.Fatalf("response body was left open!")
}
}

func TestWaitForUserCompletionWithContext(t *testing.T) {
sender := SenderFunc(func(*http.Request) (*http.Response, error) {
return mocks.NewResponseWithContent(`{"error":"authorization_pending"}`), nil
})
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
_, err := WaitForUserCompletionWithContext(ctx, sender, deviceCode())
if err != context.DeadlineExceeded {
t.Fatalf("adal: got wrong error expected(%s) actual(%s)", context.DeadlineExceeded.Error(), err.Error())
}
}
2 changes: 1 addition & 1 deletion autorest/version.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import (
"runtime"
)

const number = "v13.1.0"
const number = "v13.2.0"

var (
userAgent = fmt.Sprintf("Go/%s (%s-%s) go-autorest/%s",
Expand Down
2 changes: 1 addition & 1 deletion azure-pipelines.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ steps:
go get github.com/jstemmer/go-junit-report
go get github.com/axw/gocov/gocov
go get github.com/AlekSi/gocov-xml
go get -u gopkg.in/matm/v1/gocov-html
go get -u github.com/matm/gocov-html
workingDirectory: '$(sdkPath)'
displayName: 'Install Dependencies'
- script: |
Expand Down