Retake on the refresh token flow

This commit is contained in:
Evert Prants 2024-06-08 10:35:13 +03:00
parent 5171f1fc0f
commit 77a4aa1e4e
Signed by: evert
GPG Key ID: 1688DA83D222D0B5
8 changed files with 75 additions and 65 deletions

View File

@ -1,6 +1,6 @@
import { ApiUtils } from '$lib/server/api-utils'; import { ApiUtils } from '$lib/server/api-utils';
import { AccessDenied } from '../error'; import { AccessDenied } from '../error';
import { OAuth2AccessTokens, type OAuth2AccessToken } from '../model'; import { OAuth2AccessTokens, OAuth2Tokens, type OAuth2AccessToken } from '../model';
export class OAuth2BearerController { export class OAuth2BearerController {
static bearerFromRequest = async ( static bearerFromRequest = async (
@ -38,10 +38,8 @@ export class OAuth2BearerController {
// Try to fetch access token // Try to fetch access token
const object = await OAuth2AccessTokens.fetchByToken(token); const object = await OAuth2AccessTokens.fetchByToken(token);
if (!object) { if (!object || !OAuth2Tokens.checkTTL(object)) {
throw new AccessDenied('Token not found or has expired'); throw new AccessDenied('Token not found or has expired');
} else if (!OAuth2AccessTokens.checkTTL(object)) {
throw new AccessDenied('Token is expired');
} }
return object; return object;
}; };

View File

@ -55,7 +55,6 @@ export class OAuth2TokenController {
// console.debug('Parameter grant_type is', grantType); // console.debug('Parameter grant_type is', grantType);
const client = await OAuth2Clients.fetchById(clientId); const client = await OAuth2Clients.fetchById(clientId);
if (!client || client.activated === 0) { if (!client || client.activated === 0) {
throw new InvalidClient('Client not found'); throw new InvalidClient('Client not found');
} }
@ -67,9 +66,8 @@ export class OAuth2TokenController {
if (!OAuth2Clients.checkGrantType(client, grantType) && grantType !== 'refresh_token') { if (!OAuth2Clients.checkGrantType(client, grantType) && grantType !== 'refresh_token') {
throw new UnauthorizedClient('Invalid grant type for the client'); throw new UnauthorizedClient('Invalid grant type for the client');
} else {
// console.debug('Grant type check passed');
} }
// console.debug('Grant type check passed');
let tokenResponse: OAuth2TokenResponse = {}; let tokenResponse: OAuth2TokenResponse = {};
try { try {
@ -93,6 +91,7 @@ export class OAuth2TokenController {
} catch (e) { } catch (e) {
return OAuth2Response.error(url, e as OAuth2Error); return OAuth2Response.error(url, e as OAuth2Error);
} }
throw new ServerError('Internal error'); throw new ServerError('Internal error');
}; };
} }

View File

@ -47,7 +47,7 @@ export async function authorizationCode(
throw new InvalidGrant('Code was issued by another client'); throw new InvalidGrant('Code was issued by another client');
} }
if (!OAuth2Codes.checkTTL(code)) { if (!OAuth2Tokens.checkTTL(code)) {
throw new InvalidGrant('Code has already expired'); throw new InvalidGrant('Code has already expired');
} }
} else { } else {
@ -94,7 +94,7 @@ export async function authorizationCode(
respObj.refresh_token = await OAuth2RefreshTokens.create(userId, clientId, scope); respObj.refresh_token = await OAuth2RefreshTokens.create(userId, clientId, scope);
} catch (err) { } catch (err) {
console.error(err); console.error(err);
throw new ServerError('Failed to call refreshToken.create function'); throw new ServerError('Failed to call refreshTokens.create function');
} }
} }
@ -107,7 +107,7 @@ export async function authorizationCode(
); );
} catch (err) { } catch (err) {
console.error(err); console.error(err);
throw new ServerError('Failed to call accessToken.create function'); throw new ServerError('Failed to call accessTokens.create function');
} }
respObj.expires_in = OAuth2Tokens.tokenTtl; respObj.expires_in = OAuth2Tokens.tokenTtl;
@ -117,7 +117,7 @@ export async function authorizationCode(
await OAuth2Codes.removeByCode(providedCode); await OAuth2Codes.removeByCode(providedCode);
} catch (err) { } catch (err) {
console.error(err); console.error(err);
throw new ServerError('Failed to call code.removeByCode function'); throw new ServerError('Failed to call codes.removeByCode function');
} }
if (cleanScope.includes('openid')) { if (cleanScope.includes('openid')) {

View File

@ -21,7 +21,6 @@ export async function refreshToken(
providedToken: string providedToken: string
): Promise<OAuth2TokenResponse> { ): Promise<OAuth2TokenResponse> {
let user: User | undefined = undefined; let user: User | undefined = undefined;
let ttl: number | null = null;
let refreshToken: OAuth2RefreshToken | undefined = undefined; let refreshToken: OAuth2RefreshToken | undefined = undefined;
let accessToken: OAuth2AccessToken | undefined = undefined; let accessToken: OAuth2AccessToken | undefined = undefined;
@ -39,8 +38,8 @@ export async function refreshToken(
throw new ServerError('Failed to call refreshToken.fetchByToken function'); throw new ServerError('Failed to call refreshToken.fetchByToken function');
} }
if (!refreshToken) { if (!refreshToken || !OAuth2Tokens.checkTTL(refreshToken)) {
throw new InvalidGrant('Refresh token not found'); throw new InvalidGrant('Refresh token not found or it is already expired');
} }
if (refreshToken.clientId !== client.id) { if (refreshToken.clientId !== client.id) {
@ -68,31 +67,36 @@ export async function refreshToken(
throw new ServerError('Failed to call accessToken.fetchByUserIdClientId function'); throw new ServerError('Failed to call accessToken.fetchByUserIdClientId function');
} }
if (accessToken) { // Remove old token
ttl = OAuth2AccessTokens.getTTL(accessToken); if (accessToken && OAuth2Tokens.checkTTL(accessToken)) {
await OAuth2Tokens.remove(accessToken);
if (!ttl) {
accessToken = undefined;
} else {
resObj.access_token = accessToken.token;
resObj.expires_in = ttl;
}
} }
if (!accessToken) { // Remove old refresh token
try { await OAuth2Tokens.remove(refreshToken);
resObj.access_token = await OAuth2AccessTokens.create(
user.id,
client.client_id,
refreshToken.scope || '',
OAuth2Tokens.tokenTtl
);
} catch (err) {
throw new ServerError('Failed to call accessToken.create function');
}
resObj.expires_in = OAuth2Tokens.tokenTtl; try {
resObj.refresh_token = await OAuth2RefreshTokens.create(
user.id,
client.client_id,
refreshToken.scope || ''
);
} catch (err) {
throw new ServerError('Failed to call refreshTokens.create function');
} }
try {
resObj.access_token = await OAuth2AccessTokens.create(
user.id,
client.client_id,
refreshToken.scope || '',
OAuth2Tokens.tokenTtl
);
} catch (err) {
throw new ServerError('Failed to call accessTokens.create function');
}
resObj.expires_in = OAuth2Tokens.tokenTtl;
return resObj; return resObj;
} }

View File

@ -216,6 +216,19 @@ export class OAuth2Clients {
const filterText = `%${filters?.filter?.toLowerCase()}%`; const filterText = `%${filters?.filter?.toLowerCase()}%`;
const limit = filters?.limit || 20; const limit = filters?.limit || 20;
const whereQuery = and(
!filters?.listAll
? or(eq(oauth2Client.ownerId, subject.id), eq(oauth2ClientManager.userId, subject.id))
: undefined,
filters?.clientId ? eq(oauth2Client.client_id, filters.clientId) : undefined,
filters?.filter
? or(
sql`lower(${oauth2Client.title}) like ${filterText}`,
like(oauth2Client.client_id, filterText)
)
: undefined
);
// LEFT JOINs below will include more rows than we allow with LIMIT, // LEFT JOINs below will include more rows than we allow with LIMIT,
// so we need to do a subquery with the limiting first. // so we need to do a subquery with the limiting first.
// The LEFT JOIN in the subquery only contributes to the WHERE clause // The LEFT JOIN in the subquery only contributes to the WHERE clause
@ -224,20 +237,7 @@ export class OAuth2Clients {
.select({ id: oauth2Client.id }) .select({ id: oauth2Client.id })
.from(oauth2Client) .from(oauth2Client)
.leftJoin(oauth2ClientManager, eq(oauth2ClientManager.clientId, oauth2Client.id)) .leftJoin(oauth2ClientManager, eq(oauth2ClientManager.clientId, oauth2Client.id))
.where( .where(whereQuery)
and(
!filters?.listAll
? or(eq(oauth2Client.ownerId, subject.id), eq(oauth2ClientManager.userId, subject.id))
: undefined,
filters?.clientId ? eq(oauth2Client.client_id, filters.clientId) : undefined,
filters?.filter
? or(
sql`lower(${oauth2Client.title}) like ${filterText}`,
like(oauth2Client.client_id, filterText)
)
: undefined
)
)
.groupBy(oauth2Client.id) .groupBy(oauth2Client.id)
.limit(limit) .limit(limit)
.offset(filters?.offset || 0) .offset(filters?.offset || 0)
@ -247,8 +247,9 @@ export class OAuth2Clients {
.select({ .select({
rowCount: count(oauth2Client.id).mapWith(Number) rowCount: count(oauth2Client.id).mapWith(Number)
}) })
.from(allowedClients) .from(oauth2Client)
.innerJoin(oauth2Client, eq(allowedClients.id, oauth2Client.id)); .leftJoin(oauth2ClientManager, eq(oauth2ClientManager.clientId, oauth2Client.id))
.where(whereQuery);
const junkList = await DB.drizzle const junkList = await DB.drizzle
.select({ .select({

View File

@ -112,6 +112,26 @@ export class OAuth2Tokens {
static async remove(token: OAuth2Token) { static async remove(token: OAuth2Token) {
await DB.drizzle.delete(oauth2Token).where(eq(oauth2Token.id, token.id)); await DB.drizzle.delete(oauth2Token).where(eq(oauth2Token.id, token.id));
} }
static async invalidate(token: OAuth2Token) {
await DB.drizzle
.update(oauth2Token)
.set({ expires_at: new Date() })
.where(eq(oauth2Token.id, token.id));
}
/**
* Check token for expiry
* @param token Access token to check
* @returns true if still valid
*/
static checkTTL(token: OAuth2AccessToken): boolean {
return new Date().getTime() < new Date(token.expires_at).getTime();
}
static getTTL(token: OAuth2AccessToken): number {
return new Date(token.expires_at).getTime() - Date.now();
}
} }
export class OAuth2Codes { export class OAuth2Codes {
@ -184,10 +204,6 @@ export class OAuth2Codes {
return true; return true;
} }
static checkTTL(code: OAuth2Code): boolean {
return new Date(code.expires_at).getTime() > Date.now();
}
static getCodeChallenge(code: OAuth2Code) { static getCodeChallenge(code: OAuth2Code) {
return { return {
method: code.code_challenge_method, method: code.code_challenge_method,
@ -241,14 +257,6 @@ export class OAuth2AccessTokens {
}; };
} }
static checkTTL(token: OAuth2AccessToken): boolean {
return new Date() < new Date(token.expires_at);
}
static getTTL(token: OAuth2AccessToken): number {
return new Date(token.expires_at).getTime() - Date.now();
}
static async fetchByUserIdClientId( static async fetchByUserIdClientId(
userId: number, userId: number,
clientId: string clientId: string

View File

@ -17,7 +17,7 @@
</svelte:head> </svelte:head>
<TitleRow> <TitleRow>
<h1>{$t('admin.oauth2.title')}</h1> <h1>{$t('admin.oauth2.title')} ({data.meta.rowCount})</h1>
{#if data.createPrivileges} {#if data.createPrivileges}
<a href="oauth2/new">{$t('admin.oauth2.new')}</a> <a href="oauth2/new">{$t('admin.oauth2.new')}</a>
{/if} {/if}

View File

@ -15,7 +15,7 @@
<title>{$t('admin.users.title')} - {env.PUBLIC_SITE_NAME} {$t('admin.title')}</title> <title>{$t('admin.users.title')} - {env.PUBLIC_SITE_NAME} {$t('admin.title')}</title>
</svelte:head> </svelte:head>
<h1>{$t('admin.users.title')}</h1> <h1>{$t('admin.users.title')} ({data.meta.rowCount})</h1>
<ColumnView> <ColumnView>
<form action="" method="get"> <form action="" method="get">