Retake on the refresh token flow

This commit is contained in:
Evert Prants 2024-06-08 10:35:13 +03:00
parent 5171f1fc0f
commit 77a4aa1e4e
Signed by: evert
GPG Key ID: 1688DA83D222D0B5
8 changed files with 75 additions and 65 deletions

View File

@ -1,6 +1,6 @@
import { ApiUtils } from '$lib/server/api-utils';
import { AccessDenied } from '../error';
import { OAuth2AccessTokens, type OAuth2AccessToken } from '../model';
import { OAuth2AccessTokens, OAuth2Tokens, type OAuth2AccessToken } from '../model';
export class OAuth2BearerController {
static bearerFromRequest = async (
@ -38,10 +38,8 @@ export class OAuth2BearerController {
// Try to fetch access token
const object = await OAuth2AccessTokens.fetchByToken(token);
if (!object) {
if (!object || !OAuth2Tokens.checkTTL(object)) {
throw new AccessDenied('Token not found or has expired');
} else if (!OAuth2AccessTokens.checkTTL(object)) {
throw new AccessDenied('Token is expired');
}
return object;
};

View File

@ -55,7 +55,6 @@ export class OAuth2TokenController {
// console.debug('Parameter grant_type is', grantType);
const client = await OAuth2Clients.fetchById(clientId);
if (!client || client.activated === 0) {
throw new InvalidClient('Client not found');
}
@ -67,9 +66,8 @@ export class OAuth2TokenController {
if (!OAuth2Clients.checkGrantType(client, grantType) && grantType !== 'refresh_token') {
throw new UnauthorizedClient('Invalid grant type for the client');
} else {
// console.debug('Grant type check passed');
}
// console.debug('Grant type check passed');
let tokenResponse: OAuth2TokenResponse = {};
try {
@ -93,6 +91,7 @@ export class OAuth2TokenController {
} catch (e) {
return OAuth2Response.error(url, e as OAuth2Error);
}
throw new ServerError('Internal error');
};
}

View File

@ -47,7 +47,7 @@ export async function authorizationCode(
throw new InvalidGrant('Code was issued by another client');
}
if (!OAuth2Codes.checkTTL(code)) {
if (!OAuth2Tokens.checkTTL(code)) {
throw new InvalidGrant('Code has already expired');
}
} else {
@ -94,7 +94,7 @@ export async function authorizationCode(
respObj.refresh_token = await OAuth2RefreshTokens.create(userId, clientId, scope);
} catch (err) {
console.error(err);
throw new ServerError('Failed to call refreshToken.create function');
throw new ServerError('Failed to call refreshTokens.create function');
}
}
@ -107,7 +107,7 @@ export async function authorizationCode(
);
} catch (err) {
console.error(err);
throw new ServerError('Failed to call accessToken.create function');
throw new ServerError('Failed to call accessTokens.create function');
}
respObj.expires_in = OAuth2Tokens.tokenTtl;
@ -117,7 +117,7 @@ export async function authorizationCode(
await OAuth2Codes.removeByCode(providedCode);
} catch (err) {
console.error(err);
throw new ServerError('Failed to call code.removeByCode function');
throw new ServerError('Failed to call codes.removeByCode function');
}
if (cleanScope.includes('openid')) {

View File

@ -21,7 +21,6 @@ export async function refreshToken(
providedToken: string
): Promise<OAuth2TokenResponse> {
let user: User | undefined = undefined;
let ttl: number | null = null;
let refreshToken: OAuth2RefreshToken | undefined = undefined;
let accessToken: OAuth2AccessToken | undefined = undefined;
@ -39,8 +38,8 @@ export async function refreshToken(
throw new ServerError('Failed to call refreshToken.fetchByToken function');
}
if (!refreshToken) {
throw new InvalidGrant('Refresh token not found');
if (!refreshToken || !OAuth2Tokens.checkTTL(refreshToken)) {
throw new InvalidGrant('Refresh token not found or it is already expired');
}
if (refreshToken.clientId !== client.id) {
@ -68,31 +67,36 @@ export async function refreshToken(
throw new ServerError('Failed to call accessToken.fetchByUserIdClientId function');
}
if (accessToken) {
ttl = OAuth2AccessTokens.getTTL(accessToken);
if (!ttl) {
accessToken = undefined;
} else {
resObj.access_token = accessToken.token;
resObj.expires_in = ttl;
}
// Remove old token
if (accessToken && OAuth2Tokens.checkTTL(accessToken)) {
await OAuth2Tokens.remove(accessToken);
}
if (!accessToken) {
try {
resObj.access_token = await OAuth2AccessTokens.create(
user.id,
client.client_id,
refreshToken.scope || '',
OAuth2Tokens.tokenTtl
);
} catch (err) {
throw new ServerError('Failed to call accessToken.create function');
}
// Remove old refresh token
await OAuth2Tokens.remove(refreshToken);
resObj.expires_in = OAuth2Tokens.tokenTtl;
try {
resObj.refresh_token = await OAuth2RefreshTokens.create(
user.id,
client.client_id,
refreshToken.scope || ''
);
} catch (err) {
throw new ServerError('Failed to call refreshTokens.create function');
}
try {
resObj.access_token = await OAuth2AccessTokens.create(
user.id,
client.client_id,
refreshToken.scope || '',
OAuth2Tokens.tokenTtl
);
} catch (err) {
throw new ServerError('Failed to call accessTokens.create function');
}
resObj.expires_in = OAuth2Tokens.tokenTtl;
return resObj;
}

View File

@ -216,6 +216,19 @@ export class OAuth2Clients {
const filterText = `%${filters?.filter?.toLowerCase()}%`;
const limit = filters?.limit || 20;
const whereQuery = and(
!filters?.listAll
? or(eq(oauth2Client.ownerId, subject.id), eq(oauth2ClientManager.userId, subject.id))
: undefined,
filters?.clientId ? eq(oauth2Client.client_id, filters.clientId) : undefined,
filters?.filter
? or(
sql`lower(${oauth2Client.title}) like ${filterText}`,
like(oauth2Client.client_id, filterText)
)
: undefined
);
// LEFT JOINs below will include more rows than we allow with LIMIT,
// so we need to do a subquery with the limiting first.
// The LEFT JOIN in the subquery only contributes to the WHERE clause
@ -224,20 +237,7 @@ export class OAuth2Clients {
.select({ id: oauth2Client.id })
.from(oauth2Client)
.leftJoin(oauth2ClientManager, eq(oauth2ClientManager.clientId, oauth2Client.id))
.where(
and(
!filters?.listAll
? or(eq(oauth2Client.ownerId, subject.id), eq(oauth2ClientManager.userId, subject.id))
: undefined,
filters?.clientId ? eq(oauth2Client.client_id, filters.clientId) : undefined,
filters?.filter
? or(
sql`lower(${oauth2Client.title}) like ${filterText}`,
like(oauth2Client.client_id, filterText)
)
: undefined
)
)
.where(whereQuery)
.groupBy(oauth2Client.id)
.limit(limit)
.offset(filters?.offset || 0)
@ -247,8 +247,9 @@ export class OAuth2Clients {
.select({
rowCount: count(oauth2Client.id).mapWith(Number)
})
.from(allowedClients)
.innerJoin(oauth2Client, eq(allowedClients.id, oauth2Client.id));
.from(oauth2Client)
.leftJoin(oauth2ClientManager, eq(oauth2ClientManager.clientId, oauth2Client.id))
.where(whereQuery);
const junkList = await DB.drizzle
.select({

View File

@ -112,6 +112,26 @@ export class OAuth2Tokens {
static async remove(token: OAuth2Token) {
await DB.drizzle.delete(oauth2Token).where(eq(oauth2Token.id, token.id));
}
static async invalidate(token: OAuth2Token) {
await DB.drizzle
.update(oauth2Token)
.set({ expires_at: new Date() })
.where(eq(oauth2Token.id, token.id));
}
/**
* Check token for expiry
* @param token Access token to check
* @returns true if still valid
*/
static checkTTL(token: OAuth2AccessToken): boolean {
return new Date().getTime() < new Date(token.expires_at).getTime();
}
static getTTL(token: OAuth2AccessToken): number {
return new Date(token.expires_at).getTime() - Date.now();
}
}
export class OAuth2Codes {
@ -184,10 +204,6 @@ export class OAuth2Codes {
return true;
}
static checkTTL(code: OAuth2Code): boolean {
return new Date(code.expires_at).getTime() > Date.now();
}
static getCodeChallenge(code: OAuth2Code) {
return {
method: code.code_challenge_method,
@ -241,14 +257,6 @@ export class OAuth2AccessTokens {
};
}
static checkTTL(token: OAuth2AccessToken): boolean {
return new Date() < new Date(token.expires_at);
}
static getTTL(token: OAuth2AccessToken): number {
return new Date(token.expires_at).getTime() - Date.now();
}
static async fetchByUserIdClientId(
userId: number,
clientId: string

View File

@ -17,7 +17,7 @@
</svelte:head>
<TitleRow>
<h1>{$t('admin.oauth2.title')}</h1>
<h1>{$t('admin.oauth2.title')} ({data.meta.rowCount})</h1>
{#if data.createPrivileges}
<a href="oauth2/new">{$t('admin.oauth2.new')}</a>
{/if}

View File

@ -15,7 +15,7 @@
<title>{$t('admin.users.title')} - {env.PUBLIC_SITE_NAME} {$t('admin.title')}</title>
</svelte:head>
<h1>{$t('admin.users.title')}</h1>
<h1>{$t('admin.users.title')} ({data.meta.rowCount})</h1>
<ColumnView>
<form action="" method="get">