Skip to content

Commit

Permalink
Add support for validating the org_name claim [SDK-4414] (#782)
Browse files Browse the repository at this point in the history
  • Loading branch information
Widcket authored Jul 13, 2023
1 parent 9c3254d commit 2c89b61
Show file tree
Hide file tree
Showing 6 changed files with 191 additions and 38 deletions.
43 changes: 36 additions & 7 deletions Auth0/ClaimValidators.swift
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ struct IDTokenAuthTimeValidator: JWTValidator {
}
}

struct IDTokenOrgIdValidator: JWTValidator {
struct IDTokenOrgIDValidator: JWTValidator {
enum ValidationError: Auth0Error {
case missingOrgId
case mismatchedOrgId(actual: String, expected: String)
Expand All @@ -263,16 +263,45 @@ struct IDTokenOrgIdValidator: JWTValidator {
}
}

private let expectedOrganization: String
private let expectedOrgID: String

init(organization: String) {
self.expectedOrganization = organization
init(orgID: String) {
self.expectedOrgID = orgID
}

func validate(_ jwt: JWT) -> Auth0Error? {
guard let actualOrganization = jwt.claim(name: "org_id").string else { return ValidationError.missingOrgId }
guard actualOrganization == expectedOrganization else {
return ValidationError.mismatchedOrgId(actual: actualOrganization, expected: expectedOrganization)
guard let actualOrgID = jwt.claim(name: "org_id").string else { return ValidationError.missingOrgId }
guard actualOrgID == expectedOrgID else {
return ValidationError.mismatchedOrgId(actual: actualOrgID, expected: expectedOrgID)
}
return nil
}
}

struct IDTokenOrgNameValidator: JWTValidator {
enum ValidationError: Auth0Error {
case missingOrgName
case mismatchedOrgName(actual: String, expected: String)

var debugDescription: String {
switch self {
case .missingOrgName: return "Organization Name (org_name) claim must be a string present in the ID token"
case .mismatchedOrgName(let actual, let expected):
return "Organization Name (org_name) claim value mismatch in the ID token; expected (\(expected)), found (\(actual))"
}
}
}

private let expectedOrgName: String

init(orgName: String) {
self.expectedOrgName = orgName
}

func validate(_ jwt: JWT) -> Auth0Error? {
guard let actualOrgName = jwt.claim(name: "org_name").string else { return ValidationError.missingOrgName }
guard actualOrgName.caseInsensitiveCompare(expectedOrgName) == .orderedSame else {
return ValidationError.mismatchedOrgName(actual: actualOrgName, expected: expectedOrgName)
}
return nil
}
Expand Down
6 changes: 5 additions & 1 deletion Auth0/IDTokenValidator.swift
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,11 @@ func validate(idToken: String,
claimValidators.append(IDTokenAuthTimeValidator(leeway: context.leeway, maxAge: maxAge))
}
if let organization = context.organization {
claimValidators.append(IDTokenOrgIdValidator(organization: organization))
if organization.starts(with: "org_") {
claimValidators.append(IDTokenOrgIDValidator(orgID: organization))
} else {
claimValidators.append(IDTokenOrgNameValidator(orgName: organization))
}
}
let validator = IDTokenValidator(signatureValidator: signatureValidator ?? IDTokenSignatureValidator(context: context),
claimsValidator: claimsValidator ?? IDTokenClaimsValidator(validators: claimValidators),
Expand Down
81 changes: 65 additions & 16 deletions Auth0Tests/ClaimValidatorsSpec.swift
Original file line number Diff line number Diff line change
Expand Up @@ -414,47 +414,96 @@ class ClaimValidatorsSpec: IDTokenValidatorBaseSpec {

}

describe("organization validation") {
describe("organization id validation") {

var organizationValidator: IDTokenOrgIdValidator!
let expectedOrganization = "abc1234"
var orgIDValidator: IDTokenOrgIDValidator!
let expectedOrgID = "org_abc1234"

beforeEach {
organizationValidator = IDTokenOrgIdValidator(organization: expectedOrganization)
orgIDValidator = IDTokenOrgIDValidator(orgID: expectedOrgID)
}

context("missing org_id") {
it("should return nil if org_id is present") {
let jwt = generateJWT(organization: expectedOrganization)
let jwt = generateJWT(orgID: expectedOrgID)

expect(organizationValidator.validate(jwt)).to(beNil())
expect(orgIDValidator.validate(jwt)).to(beNil())
}

it("should return an error if org_id is missing") {
let jwt = generateJWT(organization: nil)
let expectedError = IDTokenOrgIdValidator.ValidationError.missingOrgId
let result = organizationValidator.validate(jwt)
let jwt = generateJWT(orgID: nil)
let expectedError = IDTokenOrgIDValidator.ValidationError.missingOrgId
let result = orgIDValidator.validate(jwt)

expect(result).to(matchError(expectedError))
expect(result?.localizedDescription).to(equal(expectedError.localizedDescription))
}
}

context("mismatched org_id") {
it("should return an error if org_id does not match the request organization") {
let organization = "xyz6789"
let jwt = generateJWT(organization: organization)
let expectedError = IDTokenOrgIdValidator.ValidationError.mismatchedOrgId(actual: organization,
expected: expectedOrganization)
let result = organizationValidator.validate(jwt)
it("should return an error if org_id does not match the request organization id") {
let orgID = "org_xyz6789"
let jwt = generateJWT(orgID: orgID)
let expectedError = IDTokenOrgIDValidator.ValidationError.mismatchedOrgId(actual: orgID,
expected: expectedOrgID)
let result = orgIDValidator.validate(jwt)

expect(result).to(matchError(expectedError))
expect(result?.localizedDescription).to(equal(expectedError.localizedDescription))
}
}

}


describe("organization name validation") {

var orgNameValidator: IDTokenOrgNameValidator!
let expectedOrgName = "abc1234"

beforeEach {
orgNameValidator = IDTokenOrgNameValidator(orgName: expectedOrgName)
}

context("missing org_name") {
it("should return nil if org_name is present") {
let jwt = generateJWT(orgName: expectedOrgName)

expect(orgNameValidator.validate(jwt)).to(beNil())
}

it("should return an error if org_name is missing") {
let jwt = generateJWT(orgName: nil)
let expectedError = IDTokenOrgNameValidator.ValidationError.missingOrgName
let result = orgNameValidator.validate(jwt)

expect(result).to(matchError(expectedError))
expect(result?.localizedDescription).to(equal(expectedError.localizedDescription))
}
}

context("mismatched org_name") {
it("should return an error if org_name does not match the request organization name") {
let orgName = "xyz6789"
let jwt = generateJWT(orgName: orgName)
let expectedError = IDTokenOrgNameValidator.ValidationError.mismatchedOrgName(actual: orgName,
expected: expectedOrgName)
let result = orgNameValidator.validate(jwt)

expect(result).to(matchError(expectedError))
expect(result?.localizedDescription).to(equal(expectedError.localizedDescription))
}
}

it("should perform a case insensitive compare") {
let orgName = "aBc1234"
let expectedOrgName = "AbC1234"
let jwt = generateJWT(orgName: orgName)
orgNameValidator = IDTokenOrgNameValidator(orgName: expectedOrgName)

expect(orgNameValidator.validate(jwt)).to(beNil())
}

}
}

}
20 changes: 13 additions & 7 deletions Auth0Tests/Generators.swift
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ private func generateJWTPayload(iss: String?,
nonce: String?,
maxAge: Int?,
authTime: Date?,
organization: String?) -> String {
orgID: String?,
orgName: String?) -> String {
var bodyDict: [String: Any] = [:]

if let iss = iss {
Expand Down Expand Up @@ -84,10 +85,14 @@ private func generateJWTPayload(iss: String?,
bodyDict["nonce"] = nonce
}

if let organization = organization {
bodyDict["org_id"] = organization
if let orgID = orgID {
bodyDict["org_id"] = orgID
}


if let orgName = orgName {
bodyDict["org_name"] = orgName
}

return encodeJWTPart(from: bodyDict)
}

Expand All @@ -102,7 +107,8 @@ func generateJWT(alg: String = JWTAlgorithm.rs256.rawValue,
nonce: String? = "a1b2c3d4e5",
maxAge: Int? = nil,
authTime: Date? = nil,
organization: String? = nil,
orgID: String? = nil,
orgName: String? = nil,
signature: String? = nil) -> JWT {
let header = generateJWTHeader(alg: alg, kid: kid)
let body = generateJWTPayload(iss: iss,
Expand All @@ -114,7 +120,8 @@ func generateJWT(alg: String = JWTAlgorithm.rs256.rawValue,
nonce: nonce,
maxAge: maxAge,
authTime: authTime,
organization: organization)
orgID: orgID,
orgName: orgName)

let signableParts = "\(header).\(body)"
var signaturePart = ""
Expand All @@ -128,7 +135,6 @@ func generateJWT(alg: String = JWTAlgorithm.rs256.rawValue,
signaturePart = (data! as Data).a0_encodeBase64URLSafe()!
}


return try! decode(jwt: "\(signableParts).\(signaturePart)")
}

Expand Down
73 changes: 69 additions & 4 deletions Auth0Tests/IDTokenValidatorSpec.swift
Original file line number Diff line number Diff line change
Expand Up @@ -216,16 +216,16 @@ class IDTokenValidatorSpec: IDTokenValidatorBaseSpec {
}
}

it("should validate a token with an organization") {
let organization = "abc1234"
let jwt = generateJWT(aud: aud, azp: nil, nonce: nil, maxAge: nil, authTime: nil, organization: organization)
it("should validate a token with an organization ID") {
let orgID = "org_abc1234"
let jwt = generateJWT(aud: aud, azp: nil, nonce: nil, maxAge: nil, authTime: nil, orgID: orgID)
let context = IDTokenValidatorContext(issuer: validatorContext.issuer,
audience: aud[0],
jwksRequest: validatorContext.jwksRequest,
leeway: validatorContext.leeway,
maxAge: nil,
nonce: nil,
organization: organization)
organization: orgID)

await waitUntil { done in
validate(idToken: jwt.string,
Expand All @@ -236,6 +236,71 @@ class IDTokenValidatorSpec: IDTokenValidatorBaseSpec {
}
}
}

it("should validate a token with an organization name") {
let orgName = "abc1234"
let jwt = generateJWT(aud: aud, azp: nil, nonce: nil, maxAge: nil, authTime: nil, orgName: orgName)
let context = IDTokenValidatorContext(issuer: validatorContext.issuer,
audience: aud[0],
jwksRequest: validatorContext.jwksRequest,
leeway: validatorContext.leeway,
maxAge: nil,
nonce: nil,
organization: orgName)

await waitUntil { done in
validate(idToken: jwt.string,
with: context,
signatureValidator: mockSignatureValidator) { error in
expect(error).to(beNil())
done()
}
}
}

it("should expect an organization ID instead of an organization name") {
let orgID = "org_abc1234"
let jwt = generateJWT(aud: aud, azp: nil, nonce: nil, maxAge: nil, authTime: nil, orgName: orgID)
let context = IDTokenValidatorContext(issuer: validatorContext.issuer,
audience: aud[0],
jwksRequest: validatorContext.jwksRequest,
leeway: validatorContext.leeway,
maxAge: nil,
nonce: nil,
organization: orgID)
let expectedError = IDTokenOrgIDValidator.ValidationError.missingOrgId

await waitUntil { done in
validate(idToken: jwt.string,
with: context,
signatureValidator: mockSignatureValidator) { error in
expect(error).to(matchError(expectedError))
done()
}
}
}

it("should expect an organization name instead of an organization ID") {
let orgName = "abc1234"
let jwt = generateJWT(aud: aud, azp: nil, nonce: nil, maxAge: nil, authTime: nil, orgID: orgName)
let context = IDTokenValidatorContext(issuer: validatorContext.issuer,
audience: aud[0],
jwksRequest: validatorContext.jwksRequest,
leeway: validatorContext.leeway,
maxAge: nil,
nonce: nil,
organization: orgName)
let expectedError = IDTokenOrgNameValidator.ValidationError.missingOrgName

await waitUntil { done in
validate(idToken: jwt.string,
with: context,
signatureValidator: mockSignatureValidator) { error in
expect(error).to(matchError(expectedError))
done()
}
}
}
}

}
Expand Down
6 changes: 3 additions & 3 deletions EXAMPLES.md
Original file line number Diff line number Diff line change
Expand Up @@ -1212,7 +1212,7 @@ Auth0
```swift
Auth0
.webAuth()
.organization("YOUR_AUTH0_ORGANIZATION_ID")
.organization("YOUR_AUTH0_ORGANIZATION_NAME_OR_ID")
.start { result in
switch result {
case .success(let credentials):
Expand All @@ -1230,7 +1230,7 @@ Auth0
do {
let credentials = try await Auth0
.webAuth()
.organization("YOUR_AUTH0_ORGANIZATION_ID")
.organization("YOUR_AUTH0_ORGANIZATION_NAME_OR_ID")
.start()
print("Obtained credentials: \(credentials)")
} catch {
Expand All @@ -1245,7 +1245,7 @@ do {
```swift
Auth0
.webAuth()
.organization("YOUR_AUTH0_ORGANIZATION_ID")
.organization("YOUR_AUTH0_ORGANIZATION_NAME_OR_ID")
.start()
.sink(receiveCompletion: { completion in
if case .failure(let error) = completion {
Expand Down

0 comments on commit 2c89b61

Please sign in to comment.