Skip to content

Commit

Permalink
Merge pull request #1613 from Shelf-nu/sso-multi-domain
Browse files Browse the repository at this point in the history
feat: multi domain support for SSO
  • Loading branch information
DonKoko authored Jan 30, 2025
2 parents c6e3c0f + 5aab26f commit 0d311c8
Show file tree
Hide file tree
Showing 6 changed files with 222 additions and 114 deletions.
2 changes: 1 addition & 1 deletion app/database/schema.prisma
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,7 @@ enum OrganizationRoles {
model SsoDetails {
id String @id @default(cuid())
// The domain of the organization
// The domains of the organization. Comma separated for multiple domains
domain String
organizations Organization[]
Expand Down
101 changes: 71 additions & 30 deletions app/modules/organization/service.server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -65,45 +65,59 @@ export const getOrganizationByUserId = async ({
}
};

export const getOrganizationsBySsoDomain = async (domain: string) => {
/**
* Gets organizations that use the email domain for SSO
* Supports multiple domains per organization via comma-separated domain strings
* @param emailDomain - Email domain to check
* @returns Array of organizations that use this domain for SSO
*/
export async function getOrganizationsBySsoDomain(emailDomain: string) {
try {
const orgs = await db.organization
.findMany({
// We dont throw as we need to handle the case where no organization is found for the domain in the app logic
where: {
ssoDetails: {
is: {
domain: domain,
if (!emailDomain) {
throw new ShelfError({
cause: null,
message: "Email domain is required",
additionalData: { emailDomain },
label: "SSO",
});
}

// Query for organizations where the domain field contains the email domain
const organizations = await db.organization.findMany({
where: {
ssoDetails: {
isNot: null,
},
AND: [
{
ssoDetails: {
domain: {
contains: emailDomain,
},
},
},
type: "TEAM",
},
include: {
ssoDetails: true,
},
})
.catch((cause) => {
throw new ShelfError({
cause,
title: "Organization not found",
message:
"It looks like the organization you're trying to log in to is not found. Please contact our support team to get access to your organization.",
additionalData: { domain },
label,
});
});
],
},
include: {
ssoDetails: true,
},
});

return orgs;
// Filter to ensure exact domain matches
return organizations.filter((org) =>
org.ssoDetails?.domain
? emailMatchesDomains(emailDomain, org.ssoDetails.domain)
: false
);
} catch (cause) {
throw new ShelfError({
cause,
message:
"Something went wrong with fetching the organizations related to your domain",
additionalData: { domain },
label,
message: "Failed to get organizations by SSO domain",
additionalData: { emailDomain },
label: "SSO",
});
}
};
}

export async function createOrganization({
name,
Expand Down Expand Up @@ -362,3 +376,30 @@ export async function toggleOrganizationSso({
});
}
}

/**
* Utility function to parse and validate domains from a comma-separated string
* @param domainsString - Comma-separated string of domains
* @returns Array of cleaned domain strings
*/
export function parseDomains(domainsString: string): string[] {
return domainsString
.split(",")
.map((domain) => domain.trim().toLowerCase())
.filter(Boolean);
}

/**
* Checks if a given email matches any of the provided comma-separated domains
* @param email - Email address to check
* @param domainsString - Comma-separated string of domains
* @returns boolean indicating if email matches any domain
*/
export function emailMatchesDomains(
emailDomain: string,
domainsString: string | null
): boolean {
if (!emailDomain || !domainsString) return false;
const domains = parseDomains(domainsString);
return domains.includes(emailDomain.toLowerCase());
}
112 changes: 48 additions & 64 deletions app/modules/user/service.server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -219,9 +219,8 @@ export async function createUserFromSSO(
) {
try {
const { email, userId } = authSession;

const { firstName, lastName, groups } = userData;
const domain = email.split("@")[1];
const emailDomain = email.split("@")[1];

// Create user with personal workspace - all users get this now
const user = await createUser({
Expand All @@ -233,45 +232,40 @@ export async function createUserFromSSO(
isSSO: true,
});

// Find organizations that use this email domain for SSO
const organizations = await getOrganizationsBySsoDomain(domain);

// No organizations using this domain is perfectly valid for Pure SSO
// User can still log in and will be able to access workspaces through invites
if (organizations.length > 0) {
// Process SCIM access for organizations that have group mappings
for (const org of organizations) {
const { ssoDetails } = org;
if (!ssoDetails) continue;

// Check if this organization uses SCIM (has group mappings)
const hasGroupMappings = !!(
ssoDetails.adminGroupId ||
ssoDetails.baseUserGroupId ||
ssoDetails.selfServiceGroupId
);
// Find organizations that match this domain - handles multiple domains per org
const organizations = await getOrganizationsBySsoDomain(emailDomain);

if (hasGroupMappings) {
const role = getRoleFromGroupId(ssoDetails, groups);
if (role) {
await createUserOrgAssociation(db, {
userId: user.id,
organizationIds: [org.id],
roles: [role],
});
// For each matching organization, handle SCIM access if configured
for (const org of organizations) {
const { ssoDetails } = org;
if (!ssoDetails) continue;

await createTeamMember({
name: `${firstName} ${lastName}`,
organizationId: org.id,
userId,
});
}
// Check if this organization uses SCIM (has group mappings)
const hasGroupMappings = !!(
ssoDetails.adminGroupId ||
ssoDetails.baseUserGroupId ||
ssoDetails.selfServiceGroupId
);

if (hasGroupMappings) {
const role = getRoleFromGroupId(ssoDetails, groups);
if (role) {
await createUserOrgAssociation(db, {
userId: user.id,
organizationIds: [org.id],
roles: [role],
});

await createTeamMember({
name: `${firstName} ${lastName}`,
organizationId: org.id,
userId,
});
}
}
}

// Return the user and org (if any SCIM orgs exist)
// For pure SSO with no org mappings, org will be null
// Return the user and first matching org (if any)
return { user, org: organizations[0] || null };
} catch (cause: any) {
throw new ShelfError({
Expand Down Expand Up @@ -383,21 +377,7 @@ async function handleSCIMTransition(

/**
* Updates an existing SSO user on subsequent logins.
* Handles both Pure SSO and SCIM SSO scenarios:
*
* For Pure SSO users:
* - Updates their name if changed in IDP
* - Maintains their personal workspace and manual workspace invites
*
* For SCIM SSO users:
* - Updates their name if changed in IDP
* - Updates their workspace access based on current IDP group membership
* - Maintains their personal workspace regardless of group membership
*
* @param authSession - The authentication session from Supabase
* @param existingUser - The existing user record from our database
* @param userData - Updated user data from SSO provider
* @returns Object containing updated user and org (if any SCIM orgs exist)
* Handles both Pure SSO and SCIM SSO scenarios for multiple domains.
*/
export async function updateUserFromSSO(
authSession: AuthSession,
Expand All @@ -416,8 +396,7 @@ export async function updateUserFromSSO(
}> {
const { email, userId } = authSession;
const { firstName, lastName, groups } = userData;
const domain = email.split("@")[1];
const transitions: UserOrgTransition[] = [];
const emailDomain = email.split("@")[1];

try {
let user = existingUser;
Expand All @@ -431,21 +410,27 @@ export async function updateUserFromSSO(
});
}

const domainOrganizations = await getOrganizationsBySsoDomain(domain);
// Find organizations that match this user's email domain
// getOrganizationsBySsoDomain now handles multiple domains per org
const domainOrganizations = await getOrganizationsBySsoDomain(emailDomain);
const existingUserOrganizations = user.userOrganizations;

const transitions: UserOrgTransition[] = [];

for (const org of domainOrganizations) {
const { ssoDetails } = org;
if (!ssoDetails) continue;

// Check if this organization uses SCIM (has group mappings)
const hasGroupMappings = !!(
ssoDetails.adminGroupId ||
ssoDetails.baseUserGroupId ||
ssoDetails.selfServiceGroupId
);

if (hasGroupMappings) {
// Get desired role based on user's groups
const desiredRole = getRoleFromGroupId(ssoDetails, groups);
// Find if user already has access to this org
const existingOrgAccess = existingUserOrganizations.find(
(uo) => uo.organization.id === org.id
);
Expand All @@ -460,27 +445,27 @@ export async function updateUserFromSSO(
);
transitions.push(transition);
} else if (desiredRole) {
// First create the org association
// User doesn't have access but should - grant it
await createUserOrgAssociation(db, {
userId: user.id,
organizationIds: [org.id],
roles: [desiredRole],
});

// Create team member for the new organization access
await createTeamMember({
name: `${firstName} ${lastName}`,
organizationId: org.id,
userId,
});

// Then handle it as a transition from no roles to the desired role
const transition = await handleSCIMTransition(
transitions.push({
userId,
org,
[], // No previous roles
desiredRole
);
transitions.push(transition);
organizationId: org.id,
previousRoles: [],
newRole: desiredRole,
transitionType: "ACCESS_GRANTED",
});
}
}
}
Expand All @@ -506,8 +491,7 @@ export async function updateUserFromSSO(
additionalData: {
email,
userId,
domain,
transitions,
domain: emailDomain,
},
label: "SSO",
});
Expand Down
21 changes: 16 additions & 5 deletions app/routes/_layout+/admin-dashboard+/org.$organizationId.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ import { ShelfError, makeShelfError } from "~/utils/error";
import { isFormProcessing } from "~/utils/form";
import { getParams, data, error, parseData } from "~/utils/http.server";
import { extractCSVDataFromContentImport } from "~/utils/import.server";
import { isValidDomain } from "~/utils/misc";
import { requireAdmin } from "~/utils/roles.server";
import { validateDomains } from "~/utils/sso.server";

export const loader = async ({ context, params }: LoaderFunctionArgs) => {
const authSession = context.getSession();
Expand Down Expand Up @@ -117,10 +117,21 @@ export const action = async ({
selfServiceGroupId: z.string(),
domain: z
.string()
.transform((email) => email.toLowerCase())
.refine(isValidDomain, () => ({
message: "Please enter a valid domain name",
})),
.transform((domains) => domains.toLowerCase())
.transform((domains, ctx) => {
try {
return validateDomains(domains).join(", ");
} catch (error) {
ctx.addIssue({
code: z.ZodIssueCode.custom,
message:
error instanceof Error
? error.message
: "Invalid domains",
});
return z.NEVER;
}
}),
})
);

Expand Down
Loading

0 comments on commit 0d311c8

Please sign in to comment.