From 77a4aa1e4ef3b40679cf49c9bba3af6ac751142d Mon Sep 17 00:00:00 2001 From: Evert Prants Date: Sat, 8 Jun 2024 10:35:13 +0300 Subject: [PATCH] Retake on the refresh token flow --- src/lib/server/oauth2/controller/bearer.ts | 6 +-- src/lib/server/oauth2/controller/token.ts | 5 +- .../controller/tokens/authorizationCode.ts | 8 +-- .../oauth2/controller/tokens/refreshToken.ts | 52 ++++++++++--------- src/lib/server/oauth2/model/client.ts | 33 ++++++------ src/lib/server/oauth2/model/tokens.ts | 32 +++++++----- src/routes/ssoadmin/oauth2/+page.svelte | 2 +- src/routes/ssoadmin/users/+page.svelte | 2 +- 8 files changed, 75 insertions(+), 65 deletions(-) diff --git a/src/lib/server/oauth2/controller/bearer.ts b/src/lib/server/oauth2/controller/bearer.ts index 4244b63..acbd05f 100644 --- a/src/lib/server/oauth2/controller/bearer.ts +++ b/src/lib/server/oauth2/controller/bearer.ts @@ -1,6 +1,6 @@ import { ApiUtils } from '$lib/server/api-utils'; import { AccessDenied } from '../error'; -import { OAuth2AccessTokens, type OAuth2AccessToken } from '../model'; +import { OAuth2AccessTokens, OAuth2Tokens, type OAuth2AccessToken } from '../model'; export class OAuth2BearerController { static bearerFromRequest = async ( @@ -38,10 +38,8 @@ export class OAuth2BearerController { // Try to fetch access token const object = await OAuth2AccessTokens.fetchByToken(token); - if (!object) { + if (!object || !OAuth2Tokens.checkTTL(object)) { throw new AccessDenied('Token not found or has expired'); - } else if (!OAuth2AccessTokens.checkTTL(object)) { - throw new AccessDenied('Token is expired'); } return object; }; diff --git a/src/lib/server/oauth2/controller/token.ts b/src/lib/server/oauth2/controller/token.ts index 3dcc43c..b67500a 100644 --- a/src/lib/server/oauth2/controller/token.ts +++ b/src/lib/server/oauth2/controller/token.ts @@ -55,7 +55,6 @@ export class OAuth2TokenController { // console.debug('Parameter grant_type is', grantType); const client = await OAuth2Clients.fetchById(clientId); - if (!client || client.activated === 0) { throw new InvalidClient('Client not found'); } @@ -67,9 +66,8 @@ export class OAuth2TokenController { if (!OAuth2Clients.checkGrantType(client, grantType) && grantType !== 'refresh_token') { throw new UnauthorizedClient('Invalid grant type for the client'); - } else { - // console.debug('Grant type check passed'); } + // console.debug('Grant type check passed'); let tokenResponse: OAuth2TokenResponse = {}; try { @@ -93,6 +91,7 @@ export class OAuth2TokenController { } catch (e) { return OAuth2Response.error(url, e as OAuth2Error); } + throw new ServerError('Internal error'); }; } diff --git a/src/lib/server/oauth2/controller/tokens/authorizationCode.ts b/src/lib/server/oauth2/controller/tokens/authorizationCode.ts index 596a9c8..5121d70 100644 --- a/src/lib/server/oauth2/controller/tokens/authorizationCode.ts +++ b/src/lib/server/oauth2/controller/tokens/authorizationCode.ts @@ -47,7 +47,7 @@ export async function authorizationCode( throw new InvalidGrant('Code was issued by another client'); } - if (!OAuth2Codes.checkTTL(code)) { + if (!OAuth2Tokens.checkTTL(code)) { throw new InvalidGrant('Code has already expired'); } } else { @@ -94,7 +94,7 @@ export async function authorizationCode( respObj.refresh_token = await OAuth2RefreshTokens.create(userId, clientId, scope); } catch (err) { console.error(err); - throw new ServerError('Failed to call refreshToken.create function'); + throw new ServerError('Failed to call refreshTokens.create function'); } } @@ -107,7 +107,7 @@ export async function authorizationCode( ); } catch (err) { console.error(err); - throw new ServerError('Failed to call accessToken.create function'); + throw new ServerError('Failed to call accessTokens.create function'); } respObj.expires_in = OAuth2Tokens.tokenTtl; @@ -117,7 +117,7 @@ export async function authorizationCode( await OAuth2Codes.removeByCode(providedCode); } catch (err) { console.error(err); - throw new ServerError('Failed to call code.removeByCode function'); + throw new ServerError('Failed to call codes.removeByCode function'); } if (cleanScope.includes('openid')) { diff --git a/src/lib/server/oauth2/controller/tokens/refreshToken.ts b/src/lib/server/oauth2/controller/tokens/refreshToken.ts index 73e7913..729886d 100644 --- a/src/lib/server/oauth2/controller/tokens/refreshToken.ts +++ b/src/lib/server/oauth2/controller/tokens/refreshToken.ts @@ -21,7 +21,6 @@ export async function refreshToken( providedToken: string ): Promise { let user: User | undefined = undefined; - let ttl: number | null = null; let refreshToken: OAuth2RefreshToken | undefined = undefined; let accessToken: OAuth2AccessToken | undefined = undefined; @@ -39,8 +38,8 @@ export async function refreshToken( throw new ServerError('Failed to call refreshToken.fetchByToken function'); } - if (!refreshToken) { - throw new InvalidGrant('Refresh token not found'); + if (!refreshToken || !OAuth2Tokens.checkTTL(refreshToken)) { + throw new InvalidGrant('Refresh token not found or it is already expired'); } if (refreshToken.clientId !== client.id) { @@ -68,31 +67,36 @@ export async function refreshToken( throw new ServerError('Failed to call accessToken.fetchByUserIdClientId function'); } - if (accessToken) { - ttl = OAuth2AccessTokens.getTTL(accessToken); - - if (!ttl) { - accessToken = undefined; - } else { - resObj.access_token = accessToken.token; - resObj.expires_in = ttl; - } + // Remove old token + if (accessToken && OAuth2Tokens.checkTTL(accessToken)) { + await OAuth2Tokens.remove(accessToken); } - if (!accessToken) { - try { - resObj.access_token = await OAuth2AccessTokens.create( - user.id, - client.client_id, - refreshToken.scope || '', - OAuth2Tokens.tokenTtl - ); - } catch (err) { - throw new ServerError('Failed to call accessToken.create function'); - } + // Remove old refresh token + await OAuth2Tokens.remove(refreshToken); - resObj.expires_in = OAuth2Tokens.tokenTtl; + try { + resObj.refresh_token = await OAuth2RefreshTokens.create( + user.id, + client.client_id, + refreshToken.scope || '' + ); + } catch (err) { + throw new ServerError('Failed to call refreshTokens.create function'); } + try { + resObj.access_token = await OAuth2AccessTokens.create( + user.id, + client.client_id, + refreshToken.scope || '', + OAuth2Tokens.tokenTtl + ); + } catch (err) { + throw new ServerError('Failed to call accessTokens.create function'); + } + + resObj.expires_in = OAuth2Tokens.tokenTtl; + return resObj; } diff --git a/src/lib/server/oauth2/model/client.ts b/src/lib/server/oauth2/model/client.ts index c7f107e..f07c9ce 100644 --- a/src/lib/server/oauth2/model/client.ts +++ b/src/lib/server/oauth2/model/client.ts @@ -216,6 +216,19 @@ export class OAuth2Clients { const filterText = `%${filters?.filter?.toLowerCase()}%`; const limit = filters?.limit || 20; + const whereQuery = and( + !filters?.listAll + ? or(eq(oauth2Client.ownerId, subject.id), eq(oauth2ClientManager.userId, subject.id)) + : undefined, + filters?.clientId ? eq(oauth2Client.client_id, filters.clientId) : undefined, + filters?.filter + ? or( + sql`lower(${oauth2Client.title}) like ${filterText}`, + like(oauth2Client.client_id, filterText) + ) + : undefined + ); + // LEFT JOINs below will include more rows than we allow with LIMIT, // so we need to do a subquery with the limiting first. // The LEFT JOIN in the subquery only contributes to the WHERE clause @@ -224,20 +237,7 @@ export class OAuth2Clients { .select({ id: oauth2Client.id }) .from(oauth2Client) .leftJoin(oauth2ClientManager, eq(oauth2ClientManager.clientId, oauth2Client.id)) - .where( - and( - !filters?.listAll - ? or(eq(oauth2Client.ownerId, subject.id), eq(oauth2ClientManager.userId, subject.id)) - : undefined, - filters?.clientId ? eq(oauth2Client.client_id, filters.clientId) : undefined, - filters?.filter - ? or( - sql`lower(${oauth2Client.title}) like ${filterText}`, - like(oauth2Client.client_id, filterText) - ) - : undefined - ) - ) + .where(whereQuery) .groupBy(oauth2Client.id) .limit(limit) .offset(filters?.offset || 0) @@ -247,8 +247,9 @@ export class OAuth2Clients { .select({ rowCount: count(oauth2Client.id).mapWith(Number) }) - .from(allowedClients) - .innerJoin(oauth2Client, eq(allowedClients.id, oauth2Client.id)); + .from(oauth2Client) + .leftJoin(oauth2ClientManager, eq(oauth2ClientManager.clientId, oauth2Client.id)) + .where(whereQuery); const junkList = await DB.drizzle .select({ diff --git a/src/lib/server/oauth2/model/tokens.ts b/src/lib/server/oauth2/model/tokens.ts index c835bac..d264d6a 100644 --- a/src/lib/server/oauth2/model/tokens.ts +++ b/src/lib/server/oauth2/model/tokens.ts @@ -112,6 +112,26 @@ export class OAuth2Tokens { static async remove(token: OAuth2Token) { await DB.drizzle.delete(oauth2Token).where(eq(oauth2Token.id, token.id)); } + + static async invalidate(token: OAuth2Token) { + await DB.drizzle + .update(oauth2Token) + .set({ expires_at: new Date() }) + .where(eq(oauth2Token.id, token.id)); + } + + /** + * Check token for expiry + * @param token Access token to check + * @returns true if still valid + */ + static checkTTL(token: OAuth2AccessToken): boolean { + return new Date().getTime() < new Date(token.expires_at).getTime(); + } + + static getTTL(token: OAuth2AccessToken): number { + return new Date(token.expires_at).getTime() - Date.now(); + } } export class OAuth2Codes { @@ -184,10 +204,6 @@ export class OAuth2Codes { return true; } - static checkTTL(code: OAuth2Code): boolean { - return new Date(code.expires_at).getTime() > Date.now(); - } - static getCodeChallenge(code: OAuth2Code) { return { method: code.code_challenge_method, @@ -241,14 +257,6 @@ export class OAuth2AccessTokens { }; } - static checkTTL(token: OAuth2AccessToken): boolean { - return new Date() < new Date(token.expires_at); - } - - static getTTL(token: OAuth2AccessToken): number { - return new Date(token.expires_at).getTime() - Date.now(); - } - static async fetchByUserIdClientId( userId: number, clientId: string diff --git a/src/routes/ssoadmin/oauth2/+page.svelte b/src/routes/ssoadmin/oauth2/+page.svelte index 6a18364..87a205c 100644 --- a/src/routes/ssoadmin/oauth2/+page.svelte +++ b/src/routes/ssoadmin/oauth2/+page.svelte @@ -17,7 +17,7 @@ -

{$t('admin.oauth2.title')}

+

{$t('admin.oauth2.title')} ({data.meta.rowCount})

{#if data.createPrivileges} {$t('admin.oauth2.new')} {/if} diff --git a/src/routes/ssoadmin/users/+page.svelte b/src/routes/ssoadmin/users/+page.svelte index 27920a4..f4753bf 100644 --- a/src/routes/ssoadmin/users/+page.svelte +++ b/src/routes/ssoadmin/users/+page.svelte @@ -15,7 +15,7 @@ {$t('admin.users.title')} - {env.PUBLIC_SITE_NAME} {$t('admin.title')} -

{$t('admin.users.title')}

+

{$t('admin.users.title')} ({data.meta.rowCount})