Skip to content

Commit

Permalink
Add endpoint to get SAML providers for a user.
Browse files Browse the repository at this point in the history
View is combined with user SSO views.

Includes a new version of the view that takes explicit "username" or "email".

OC-4285
  • Loading branch information
jcdyer committed Oct 18, 2018
1 parent 75a739e commit b3521e0
Show file tree
Hide file tree
Showing 5 changed files with 298 additions and 50 deletions.
98 changes: 86 additions & 12 deletions common/djangoapps/third_party_auth/api/tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,16 @@
import unittest

import ddt
from django.urls import reverse
import six
from django.conf import settings
from django.http import QueryDict
from django.test.utils import override_settings
from django.urls import reverse
from mock import patch
from provider.constants import CONFIDENTIAL
from provider.oauth2.models import Client, AccessToken
from openedx.core.lib.api.permissions import ApiKeyHeaderPermission
from rest_framework.test import APITestCase
from django.conf import settings
from django.test.utils import override_settings
from social_django.models import UserSocialAuth

from student.tests.factories import UserFactory
Expand All @@ -29,6 +30,7 @@
CARL_USERNAME = "carl"
STAFF_USERNAME = "staff"
ADMIN_USERNAME = "admin"
NONEXISTENT_USERNAME = "nobody"
# These users will be created and linked to third party accounts:
LINKED_USERS = (ALICE_USERNAME, STAFF_USERNAME, ADMIN_USERNAME)
PASSWORD = "edx"
Expand Down Expand Up @@ -62,9 +64,10 @@ def setUp(self):
make_staff = (username == STAFF_USERNAME) or make_superuser
user = UserFactory.create(
username=username,
email='{}@example.com'.format(username),
password=PASSWORD,
is_staff=make_staff,
is_superuser=make_superuser
is_superuser=make_superuser,
)
UserSocialAuth.objects.create(
user=user,
Expand All @@ -77,15 +80,13 @@ def setUp(self):
uid='{}:remote_{}'.format(testshib.slug, username),
)
# Create another user not linked to any providers:
UserFactory.create(username=CARL_USERNAME, password=PASSWORD)
UserFactory.create(username=CARL_USERNAME, email='{}@example.com'.format(CARL_USERNAME), password=PASSWORD)


@override_settings(EDX_API_KEY=VALID_API_KEY)
@ddt.ddt
@unittest.skipUnless(settings.ROOT_URLCONF == 'lms.urls', 'Test only valid in lms')
class UserViewAPITests(TpaAPITestCase):
class UserViewsMixin(object):
"""
Test the Third Party Auth User REST API
Generic TestCase to exercise the v1 and v2 UserViews.
"""

def expected_active(self, username):
Expand Down Expand Up @@ -124,7 +125,7 @@ def expected_active(self, username):
@ddt.unpack
def test_list_connected_providers(self, request_user, target_user, expect_result):
self.client.login(username=request_user, password=PASSWORD)
url = reverse('third_party_auth_users_api', kwargs={'username': target_user})
url = self.make_url({'username': target_user})

response = self.client.get(url)
self.assertEqual(response.status_code, expect_result)
Expand All @@ -140,14 +141,87 @@ def test_list_connected_providers(self, request_user, target_user, expect_result
(None, ALICE_USERNAME, 403),
)
@ddt.unpack
def test_list_connected_providers__withapi_key(self, api_key, target_user, expect_result):
url = reverse('third_party_auth_users_api', kwargs={'username': target_user})
def test_list_connected_providers_with_api_key(self, api_key, target_user, expect_result):
url = self.make_url({'username': target_user})
response = self.client.get(url, HTTP_X_EDX_API_KEY=api_key)
self.assertEqual(response.status_code, expect_result)
if expect_result == 200:
self.assertIn("active", response.data)
self.assertItemsEqual(response.data["active"], self.expected_active(target_user))

@ddt.data(
(True, ALICE_USERNAME, 200, True),
(True, CARL_USERNAME, 200, False),
(False, ALICE_USERNAME, 200, True),
(False, CARL_USERNAME, 403, None),
)
@ddt.unpack
def test_allow_unprivileged_response(self, allow_unprivileged, requesting_user, expect, include_remote_id):
self.client.login(username=requesting_user, password=PASSWORD)
with override_settings(ALLOW_UNPRIVILEGED_SSO_PROVIDER_QUERY=allow_unprivileged):
url = self.make_url({'username': ALICE_USERNAME})
response = self.client.get(url)
self.assertEqual(response.status_code, expect)
if response.status_code == 200:
self.assertGreater(len(response.data['active']), 0)
for provider_data in response.data['active']:
self.assertEqual(include_remote_id, 'remote_id' in provider_data)

def test_allow_query_by_email(self):
self.client.login(username=ALICE_USERNAME, password=PASSWORD)
url = self.make_url({'email': '{}@example.com'.format(ALICE_USERNAME)})
response = self.client.get(url)
self.assertEqual(response.status_code, 200)
self.assertGreater(len(response.data['active']), 0)

def test_throttling(self):
# Default throttle is 10/min. Make 11 requests to verify
throttling_user = UserFactory.create(password=PASSWORD)
self.client.login(username=throttling_user.username, password=PASSWORD)
url = self.make_url({'username': ALICE_USERNAME})
with override_settings(ALLOW_UNPRIVILEGED_SSO_PROVIDER_QUERY=True):
for _ in range(10):
response = self.client.get(url)
self.assertEqual(response.status_code, 200)
response = self.client.get(url)
self.assertEqual(response.status_code, 200)


@override_settings(EDX_API_KEY=VALID_API_KEY)
@ddt.ddt
@unittest.skipUnless(settings.ROOT_URLCONF == 'lms.urls', 'Test only valid in lms')
class UserViewAPITests(UserViewsMixin, TpaAPITestCase):
"""
Test the Third Party Auth User REST API
"""

def make_url(self, identifier):
"""
Return the view URL, with the identifier provided
"""
return reverse(
'third_party_auth_users_api',
kwargs={'username': identifier.values()[0]}
)


@override_settings(EDX_API_KEY=VALID_API_KEY)
@ddt.ddt
@unittest.skipUnless(settings.ROOT_URLCONF == 'lms.urls', 'Test only valid in lms')
class UserViewV2APITests(UserViewsMixin, TpaAPITestCase):
"""
Test the Third Party Auth User REST API
"""

def make_url(self, identifier):
"""
Return the view URL, with the identifier provided
"""
return '?'.join([
reverse('third_party_auth_users_api_v2'),
six.moves.urllib.parse.urlencode(identifier)
])


@override_settings(EDX_API_KEY=VALID_API_KEY)
@ddt.ddt
Expand Down
3 changes: 2 additions & 1 deletion common/djangoapps/third_party_auth/api/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from django.conf import settings
from django.conf.urls import url

from .views import UserMappingView, UserView
from .views import UserMappingView, UserView, UserViewV2


PROVIDER_PATTERN = r'(?P<provider_id>[\w.+-]+)(?:\:(?P<idp_slug>[\w.+-]+))?'
Expand All @@ -14,6 +14,7 @@
UserView.as_view(),
name='third_party_auth_users_api',
),
url(r'^v0/users/', UserViewV2.as_view(), name='third_party_auth_users_api_v2'),
url(
r'^v0/providers/{provider_pattern}/users$'.format(provider_pattern=PROVIDER_PATTERN),
UserMappingView.as_view(),
Expand Down
Loading

0 comments on commit b3521e0

Please sign in to comment.