diff --git a/src/components/Content.ts b/src/components/Content.ts index f11dbfce..e1e2a97f 100644 --- a/src/components/Content.ts +++ b/src/components/Content.ts @@ -5,9 +5,8 @@ import { when } from "lit/directives/when.js"; import { localized } from "@lit/localize"; import { StateController } from "@lit-app/state"; import { portalState } from "../state/portal-state"; -import { getCurrentAccessToken } from "../utils/storage"; +import { tokenState } from "../state/token-state"; import { theme } from "../utils/theme"; -import { tokenMatchesScope } from "../utils/token"; @customElement("bkd-content") @localized() @@ -95,7 +94,7 @@ export class Content extends LitElement { } render() { - if (!tokenMatchesScope(getCurrentAccessToken(), portalState.app.scope)) { + if (tokenState.scope !== portalState.app.scope) { // Token scope does not match current app, wait for correct // token to be activated in component to avoid requests // resulting in 403 due to unsufficient rights. diff --git a/src/components/Header/SubstitutionsToggle.ts b/src/components/Header/SubstitutionsToggle.ts index b48cb2fa..b03a2b28 100644 --- a/src/components/Header/SubstitutionsToggle.ts +++ b/src/components/Header/SubstitutionsToggle.ts @@ -7,12 +7,11 @@ import caretIcon from "../../assets/icons/caret.svg?raw"; import closeSmallIcon from "../../assets/icons/close-small.svg?raw"; import substitutionIcon from "../../assets/icons/substitution.svg?raw"; import { DropdownController } from "../../controllers/dropdown"; +import { tokenState } from "../../state/token-state.ts"; import { Substitution, fetchCurrentSubstitutions } from "../../utils/fetch"; import { buildUrl } from "../../utils/routing.ts"; -import { getCurrentAccessToken } from "../../utils/storage"; import { submit } from "../../utils/submit"; import { theme } from "../../utils/theme"; -import { getTokenPayload } from "../../utils/token"; import { SubstitutionsDropdown } from "./SubstitutionsDropdown"; @customElement("bkd-substitutions-toggle") @@ -133,11 +132,7 @@ export class SubstitutionsToggle extends LitElement { } private getActiveSubstitutionId(): number | null { - const token = getCurrentAccessToken(); - if (!token) return null; - - const { substitutionId } = getTokenPayload(token); - return substitutionId ?? null; + return tokenState.accessTokenPayload?.substitutionId ?? null; } private toggle(event: Event) { @@ -181,7 +176,7 @@ export class SubstitutionsToggle extends LitElement { private redirect(url: string): void { submit("POST", url, { - access_token: getCurrentAccessToken() ?? "", + access_token: tokenState.accessToken ?? "", redirect_uri: buildUrl("home"), }); } diff --git a/src/components/Portal.ts b/src/components/Portal.ts index 6deb34a7..43cb3d66 100644 --- a/src/components/Portal.ts +++ b/src/components/Portal.ts @@ -5,6 +5,7 @@ import { localized } from "@lit/localize"; import { StateController, Unsubscribe } from "@lit-app/state"; import { settings } from "../settings"; import { portalState } from "../state/portal-state"; +import { tokenState } from "../state/token-state.ts"; import { activateTokenForScope, createOAuthClient, @@ -14,16 +15,16 @@ import { import { getInitialLocale } from "../utils/locale"; import { getNavigationItemByAppPath } from "../utils/navigation"; import { getHash, getScopeFromUrl, updateHash } from "../utils/routing"; -import { getCurrentAccessToken } from "../utils/storage"; import { customProperties, fontFaces, registerLightDomStyles, theme, } from "../utils/theme"; -import { tokenMatchesScope } from "../utils/token.ts"; +import { initializeTokenRenewal } from "../utils/token-renewal.ts"; const oAuthClient = createOAuthClient(); +initializeTokenRenewal(oAuthClient); const authReady = (async function () { // Start Authorization Code Flow with PKCE @@ -86,7 +87,7 @@ export class Portal extends LitElement { // has no access to the navigation item from the redirect URL, // hence is redirected to home (see // https://github.com/bkd-mba-fbi/evento-portal/issues/106). - if (!tokenMatchesScope(getCurrentAccessToken(), portalState.app.scope)) { + if (tokenState.scope !== portalState.app.scope) { activateTokenForScope( oAuthClient, portalState.app.scope, @@ -126,11 +127,6 @@ export class Portal extends LitElement { window.removeEventListener("hashchange", this.handleHashChange); } - private isAuthenticated(): boolean { - const token = getCurrentAccessToken(); - return Boolean(token); - } - /** * Update the document title based on the current state */ @@ -201,7 +197,7 @@ export class Portal extends LitElement { render() { return html` ${when( - this.authReady && this.isAuthenticated(), + this.authReady && tokenState.authenticated, () => html` diff --git a/src/state/portal-state.ts b/src/state/portal-state.ts index d8fd3d3b..93c56acf 100644 --- a/src/state/portal-state.ts +++ b/src/state/portal-state.ts @@ -11,8 +11,8 @@ import { fetchInstanceName, fetchUserAccessInfo } from "../utils/fetch"; import { getInitialLocale, getLocale, updateLocale } from "../utils/locale"; import { filterAllowed, getApp, getNavigationItem } from "../utils/navigation"; import { cleanupQueryParams, updateQueryParam } from "../utils/routing"; -import { getCurrentAccessToken, storeLocale } from "../utils/storage"; -import { getTokenPayload } from "../utils/token"; +import { storeLocale } from "../utils/storage"; +import { tokenState } from "./token-state"; export const LOCALE_QUERY_PARAM = "locale"; export const NAV_ITEM_QUERY_PARAM = "module"; @@ -182,10 +182,9 @@ export class PortalState extends State { } private updateNavigation(): void { - const token = getCurrentAccessToken(); - if (!token) return; + const { instanceId } = tokenState; + if (!instanceId) return; - const { instanceId } = getTokenPayload(token); this.navigation = filterAllowed( settings.navigation, instanceId, @@ -237,16 +236,14 @@ export class PortalState extends State { } private async loadRolesAndPermissions(): Promise { - const token = getCurrentAccessToken(); - if (!token) return; + if (!tokenState.authenticated) return; const { roles, permissions } = await fetchUserAccessInfo(); this.rolesAndPermissions = [...roles, ...permissions]; } private async loadInstanceName(): Promise { - const token = getCurrentAccessToken(); - if (!token) return; + if (!tokenState.authenticated) return; const instanceName = await fetchInstanceName(); this.instanceName = [msg("Evento"), instanceName] diff --git a/src/state/token-state.ts b/src/state/token-state.ts new file mode 100644 index 00000000..2d998f20 --- /dev/null +++ b/src/state/token-state.ts @@ -0,0 +1,152 @@ +import { + getCurrentAccessToken, + getRefreshToken, + resetAllTokens, + storeAccessToken, + storeCurrentAccessToken, + storeRefreshToken, +} from "../utils/storage"; +import { TokenPayload, getTokenPayload, isTokenExpired } from "../utils/token"; + +type Subscriber = (token: TokenPayload | null) => void; +type Unsubscribe = () => void; + +export class TokenState { + private _refreshToken = getRefreshToken(); + private _refreshTokenPayload: TokenPayload | null = null; + private _accessToken = getCurrentAccessToken(); + private _accessTokenPayload: TokenPayload | null = null; + private refreshTokenSubscribers: Subscriber[] = []; + private accessTokenSubscribers: Subscriber[] = []; + + constructor() { + this.afterRefreshTokenUpdate(this.refreshToken, false); + this.afterAccessTokenUpdate(this.accessToken, false); + } + + get refreshToken() { + return this._refreshToken; + } + set refreshToken(refreshToken: string | null) { + this._refreshToken = refreshToken; + this.afterRefreshTokenUpdate(refreshToken); + } + + get refreshTokenPayload() { + return this._refreshTokenPayload; + } + + get accessToken() { + return this._accessToken; + } + set accessToken(accessToken: string | null) { + this._accessToken = accessToken; + this.afterAccessTokenUpdate(accessToken); + } + + get accessTokenPayload() { + return this._accessTokenPayload; + } + + get authenticated(): boolean { + return Boolean(this.accessToken); + } + + get scope(): string | null { + return this.accessTokenPayload?.scope ?? null; + } + + get locale(): string | null { + return this.accessTokenPayload?.locale ?? null; + } + + get instanceId(): string | null { + return this.accessTokenPayload?.instanceId ?? null; + } + + isRefreshTokenExpired(): boolean { + return isTokenExpired(this._refreshTokenPayload); + } + + resetAllTokens(): void { + this._refreshToken = null; + this._refreshTokenPayload = null; + + this._accessToken = null; + this._accessTokenPayload = null; + + resetAllTokens(); + } + + onRefreshTokenUpdate(callback: Subscriber): Unsubscribe { + this.refreshTokenSubscribers.push(callback); + + // Initially call with current value + callback(this.refreshTokenPayload); + + return () => { + const index = this.refreshTokenSubscribers.findIndex( + (s) => s === callback, + ); + this.refreshTokenSubscribers.splice(index, 1); + }; + } + + onAccessTokenUpdate(callback: Subscriber): Unsubscribe { + this.accessTokenSubscribers.push(callback); + + // Initially call with current value + callback(this.accessTokenPayload); + + return () => { + const index = this.accessTokenSubscribers.findIndex( + (s) => s === callback, + ); + this.accessTokenSubscribers.splice(index, 1); + }; + } + + private afterRefreshTokenUpdate( + refreshToken: string | null, + store = true, + ): void { + this._refreshTokenPayload = refreshToken + ? getTokenPayload(refreshToken) + : null; + + if (refreshToken && store) { + storeRefreshToken(refreshToken); + } + + this.notifyRefreshTokenSubscribers(); + } + + private afterAccessTokenUpdate( + accessToken: string | null, + store = true, + ): void { + const payload = accessToken ? getTokenPayload(accessToken) : null; + this._accessTokenPayload = payload; + + if (accessToken && payload && store) { + storeAccessToken(payload.scope, accessToken); + storeCurrentAccessToken(accessToken); + } + + this.notifyAccessTokenSubscribers(); + } + + private notifyRefreshTokenSubscribers(): void { + this.refreshTokenSubscribers.forEach((callback) => + callback(this.refreshTokenPayload), + ); + } + + private notifyAccessTokenSubscribers(): void { + this.accessTokenSubscribers.forEach((callback) => + callback(this.accessTokenPayload), + ); + } +} + +export const tokenState = new TokenState(); diff --git a/src/utils/auth.ts b/src/utils/auth.ts index 193a8fd1..409710e4 100644 --- a/src/utils/auth.ts +++ b/src/utils/auth.ts @@ -7,26 +7,17 @@ import { import { generateQueryString } from "@badgateway/oauth2-client/dist/client"; import { getCodeChallenge } from "@badgateway/oauth2-client/dist/client/authorization-code"; import { LOCALE_QUERY_PARAM, portalState } from "../state/portal-state"; +import { tokenState } from "../state/token-state"; import { log } from "./logging"; import { consumeLoginState, getAccessToken, - getCurrentAccessToken, getInstance, - getRefreshToken, - resetAllTokens, - storeCurrentAccessToken, storeInstance, storeLoginState, - storeToken, } from "./storage"; -import { getTokenPayload, isTokenExpired, isValidToken } from "./token"; -import { - clearTokenRenewalTimers, - renewAccessTokenOnExpiration, - renewRefreshTokenOnExpiration, - renewTokenOnExpiration, -} from "./token-renewal"; +import { isValidToken } from "./token"; +import { clearTokenRenewalTimers } from "./token-renewal"; const envSettings = window.eventoPortal.settings; @@ -73,7 +64,7 @@ export async function ensureAuthenticated( if (loginResult) { // Successfully logged in log("Successfully logged in"); - handleLoginResult(client, loginResult, loginState); + handleLoginResult(loginResult, loginState); return; } @@ -81,7 +72,7 @@ export async function ensureAuthenticated( if (substitutionResult) { // Started or stopped substitution log("Successfully started or stopped substitution"); - handleSubstitutionResult(client, substitutionResult); + handleSubstitutionResult(substitutionResult); return; } @@ -109,16 +100,13 @@ export async function activateTokenForScope( ): Promise { log(`Activate token for scope "${scope}" and locale "${locale}"`); - const refreshToken = getRefreshToken(); - if (!refreshToken || isTokenExpired(refreshToken)) { + if (tokenState.isRefreshTokenExpired()) { // Not authenticated or refresh token expired, redirect to login log("Not authenticated or refresh token expired, redirect to login"); return redirect(client, scope, locale, loginUrl); - } else { - renewRefreshTokenOnExpiration(client, refreshToken); } - const currentAccessToken = getCurrentAccessToken(); + const currentAccessToken = tokenState.accessToken; const cachedAccessToken = getAccessToken(scope); if (isValidToken(currentAccessToken, scope, locale)) { @@ -126,14 +114,12 @@ export async function activateTokenForScope( log( `Current token for scope "${scope}" and locale "${locale}" already set`, ); - renewAccessTokenOnExpiration(client, currentAccessToken); } else if (isValidToken(cachedAccessToken, scope, locale)) { // Token for scope/locale cached, set as current log( `Token for scope "${scope}" and locale "${locale}" cached, set as current`, ); - storeCurrentAccessToken(cachedAccessToken); - renewAccessTokenOnExpiration(client, cachedAccessToken); + tokenState.accessToken = cachedAccessToken; } else { // No token for scope/locale present or half expired, redirect for // token fetch/refresh @@ -148,8 +134,8 @@ export async function logout(client: OAuth2Client): Promise { const instance = getInstance(); if (!instance) throw new Error("No instance available"); - const token = getCurrentAccessToken(); - if (!token) return; + const { accessToken, scope, locale } = tokenState; + if (!accessToken || !scope || !locale) return; // Logout & reset tokens try { @@ -157,7 +143,7 @@ export async function logout(client: OAuth2Client): Promise { client, `${envSettings.oAuthPrefix}/Authorization/${instance}/Logout`, { - access_token: token, + access_token: accessToken, }, ); } catch (e) { @@ -166,11 +152,10 @@ export async function logout(client: OAuth2Client): Promise { throw e; } } finally { - resetAllTokens(); + tokenState.resetAllTokens(); clearTokenRenewalTimers(); // Redirect to login with scope/locale of current token - const { scope, locale } = getTokenPayload(token); await redirect(client, scope, locale, loginUrl); } } @@ -249,13 +234,12 @@ export const refreshUrl: RedirectUrlBuilder = async ( const [codeChallengeMethod, codeChallenge] = await getCodeChallenge(codeVerifier); - const refreshToken = getRefreshToken(); // url.searchParams.set("clientId", client.settings.clientId); url.searchParams.set("redirectUrl", redirectUri); url.searchParams.set("culture_info", locale); url.searchParams.set("application_scope", scope); - url.searchParams.set("refresh_token", refreshToken ?? ""); + url.searchParams.set("refresh_token", tokenState.refreshToken ?? ""); url.searchParams.set("response_type", "code"); url.searchParams.set("code_challenge_method", codeChallengeMethod); url.searchParams.set("code_challenge", codeChallenge); @@ -284,21 +268,21 @@ async function getTokenAfterLogin( } function handleLoginResult( - client: OAuth2Client, - token: OAuth2Token, + { refreshToken, accessToken }: OAuth2Token, loginState: { codeVerifier: string; redirectUri?: string; } | null, ): void { - const { accessToken } = token; - const { scope, instanceId } = getTokenPayload(accessToken); - storeToken(scope, token); - storeCurrentAccessToken(accessToken); - renewTokenOnExpiration(client, token); + tokenState.refreshToken = refreshToken; + tokenState.accessToken = accessToken; // Remember the chosen instance for later logins - storeInstance(instanceId); + const instanceId = tokenState.accessTokenPayload?.instanceId; + if (instanceId) { + // TODO: move to TokenState as well? + storeInstance(instanceId); + } if (loginState?.redirectUri) { portalState.navigate(new URL(loginState.redirectUri)); @@ -330,15 +314,10 @@ function getTokenAfterSubstitutionRedirect(): OAuth2Token | null { return null; } -function handleSubstitutionResult( - client: OAuth2Client, - token: OAuth2Token, -): void { - const { accessToken } = token; - const { scope } = getTokenPayload(accessToken); - storeToken(scope, token); - storeCurrentAccessToken(accessToken); - renewTokenOnExpiration(client, token); +function handleSubstitutionResult(token: OAuth2Token): void { + const { refreshToken, accessToken } = token; + tokenState.refreshToken = refreshToken; + tokenState.accessToken = accessToken; // Remove sensitive information from URL const url = new URL(document.location.href); diff --git a/src/utils/fetch.ts b/src/utils/fetch.ts index 517d3d82..314acc16 100644 --- a/src/utils/fetch.ts +++ b/src/utils/fetch.ts @@ -1,4 +1,4 @@ -import { getCurrentAccessToken } from "./storage"; +import { tokenState } from "../state/token-state"; const envSettings = window.eventoPortal.settings; @@ -49,13 +49,13 @@ async function fetchApi( url: string | URL, { method = "GET" } = {}, ): Promise { - const token = getCurrentAccessToken(); - if (!token) { + const { accessToken } = tokenState; + if (!accessToken) { throw new Error("No token available"); } const headers = new Headers({ - "CLX-Authorization": `token_type=urn:ietf:params:oauth:token-type:jwt-bearer, access_token=${token}`, + "CLX-Authorization": `token_type=urn:ietf:params:oauth:token-type:jwt-bearer, access_token=${accessToken}`, "Content-Type": "application/json", }); diff --git a/src/utils/storage.ts b/src/utils/storage.ts index 8b0e19b9..f33f8778 100644 --- a/src/utils/storage.ts +++ b/src/utils/storage.ts @@ -1,4 +1,3 @@ -import { OAuth2Token } from "@badgateway/oauth2-client"; import { getTokenPayload } from "./token"; const INSTANCE_KEY = "bkdInstance"; @@ -30,14 +29,16 @@ export function getRefreshToken(): string | null { return localStorage.getItem(REFRESH_TOKEN_KEY); } -export function storeToken(scope: string, token: OAuth2Token): void { - const { refreshToken, accessToken } = token; - localStorage.setItem(`${ACCESS_TOKEN_KEY}_${scope}`, accessToken); +export function storeRefreshToken(refreshToken: string | null): void { if (refreshToken) { localStorage.setItem(REFRESH_TOKEN_KEY, refreshToken); } } +export function storeAccessToken(scope: string, accessToken: string): void { + localStorage.setItem(`${ACCESS_TOKEN_KEY}_${scope}`, accessToken); +} + export function resetAllTokens(): void { new Array(localStorage.length).fill(undefined).forEach((_, i) => { const key = localStorage.key(i); diff --git a/src/utils/token-renewal.ts b/src/utils/token-renewal.ts index 1f2d94a3..6f36c05e 100644 --- a/src/utils/token-renewal.ts +++ b/src/utils/token-renewal.ts @@ -1,8 +1,8 @@ -import { OAuth2Client, OAuth2Token } from "@badgateway/oauth2-client"; +import { OAuth2Client } from "@badgateway/oauth2-client"; +import { tokenState } from "../state/token-state"; import { loginUrl, redirect, refreshUrl } from "./auth"; import { log, logLazy } from "./logging"; -import { getCurrentAccessToken } from "./storage"; -import { getTokenExpireIn, getTokenPayload } from "./token"; +import { TokenPayload, getTokenExpireIn } from "./token"; enum TokenType { Refresh = "refresh", @@ -17,42 +17,40 @@ const expirationTimers: Record< access: undefined, }; -export function renewTokenOnExpiration( - client: OAuth2Client, - token: OAuth2Token, -): void { - const { refreshToken, accessToken } = token; - if (refreshToken) { - renewRefreshTokenOnExpiration(client, refreshToken); - } - renewAccessTokenOnExpiration(client, accessToken); +export function initializeTokenRenewal(client: OAuth2Client): void { + tokenState.onRefreshTokenUpdate((refreshToken) => + renewRefreshTokenOnExpiration(client, refreshToken), + ); + tokenState.onAccessTokenUpdate((accessToken) => + renewAccessTokenOnExpiration(client, accessToken), + ); } export function renewRefreshTokenOnExpiration( client: OAuth2Client, - refreshToken: string, + refreshToken: TokenPayload | null, ): void { onExpiration(TokenType.Refresh, refreshToken, () => { // Get the scope of the "current" access token at the time the // refresh token expires, since the user may have switched scopes // in the meantime - const accessToken = getCurrentAccessToken(); - if (!accessToken) { - return; - } + const accessToken = tokenState.accessTokenPayload; + if (!accessToken) return; log(`Refresh token expired, redirect to login`); - const { scope, locale } = getTokenPayload(accessToken); + const { scope, locale } = accessToken; redirect(client, scope, locale, loginUrl); }); } export function renewAccessTokenOnExpiration( client: OAuth2Client, - accessToken: string, + accessToken: TokenPayload | null, ): void { - const { scope, locale } = getTokenPayload(accessToken); onExpiration(TokenType.Access, accessToken, () => { + if (!accessToken) return; + + const { scope, locale } = accessToken; log( `Access token for scope "${scope}" and locale "${locale}" expired, redirect for token fetch/refresh`, ); @@ -74,20 +72,22 @@ export function clearTokenRenewalTimers(): void { */ function onExpiration( type: TokenType, - token: string, + token: TokenPayload | null, callback: () => void, ): void { if (expirationTimers[type]) { clearTimeout(expirationTimers[type]); } - expirationTimers[type] = setTimeout(callback, getTokenExpireIn(token)); - logLazy(() => { - const { expirationTime } = getTokenPayload(token); - const expirationDate = new Date(); - expirationDate.setTime(expirationTime * 1000); - return `Scheduled ${type} token expiration timeout in ${Math.floor( - getTokenExpireIn(token) / 1000 / 60, - )} minutes (at ${expirationDate})`; - }); + if (token) { + expirationTimers[type] = setTimeout(callback, getTokenExpireIn(token)); + logLazy(() => { + const { expirationTime } = token; + const expirationDate = new Date(); + expirationDate.setTime(expirationTime * 1000); + return `Scheduled ${type} token expiration timeout in ${Math.floor( + getTokenExpireIn(token) / 1000 / 60, + )} minutes (at ${expirationDate})`; + }); + } } diff --git a/src/utils/token.ts b/src/utils/token.ts index 9a0c1f34..9263afa3 100644 --- a/src/utils/token.ts +++ b/src/utils/token.ts @@ -59,25 +59,18 @@ export function isValidToken( ); } -export function isTokenExpired(token: string | null): boolean { +export function isTokenExpired(token: TokenPayload | null): boolean { if (!token) return true; - const { expirationTime } = getTokenPayload(token); + const { expirationTime } = token; const now = Math.floor(Date.now() / 1000); return expirationTime < now; } -export function isTokenHalfExpired(token: string | null): boolean; -export function isTokenHalfExpired(payload: TokenPayload | null): boolean; -export function isTokenHalfExpired( - tokenOrPayload: string | TokenPayload | null, -): boolean { - if (!tokenOrPayload) return true; +export function isTokenHalfExpired(token: TokenPayload | null): boolean { + if (!token) return true; - const { issueTime, expirationTime } = - typeof tokenOrPayload === "string" - ? getTokenPayload(tokenOrPayload) - : tokenOrPayload; + const { issueTime, expirationTime } = token; const validFor = expirationTime - issueTime; const now = Math.floor(Date.now() / 1000); @@ -88,22 +81,11 @@ export function isTokenHalfExpired( * Returns the time (in milliseconds) the token will expire from now (0 * if already expired). */ -export function getTokenExpireIn(token: string): number { - const { expirationTime } = getTokenPayload(token); +export function getTokenExpireIn(token: TokenPayload): number { + const { expirationTime } = token; return Math.max(expirationTime * 1000 - Date.now(), 0); } -/** - * Returns whether the given token matches the given scope. - */ -export function tokenMatchesScope( - token: string | null, - scope: string, -): boolean { - const tokenScope = token && getTokenPayload(token).scope; - return tokenScope === scope; -} - function parseTokenPayload(token: string): RawTokenPayload { const base64Url = token.split(".")[1]; const base64 = base64Url.replace("-", "+").replace("_", "/");