Retake on the refresh token flow
This commit is contained in:
parent
5171f1fc0f
commit
77a4aa1e4e
@ -1,6 +1,6 @@
|
||||
import { ApiUtils } from '$lib/server/api-utils';
|
||||
import { AccessDenied } from '../error';
|
||||
import { OAuth2AccessTokens, type OAuth2AccessToken } from '../model';
|
||||
import { OAuth2AccessTokens, OAuth2Tokens, type OAuth2AccessToken } from '../model';
|
||||
|
||||
export class OAuth2BearerController {
|
||||
static bearerFromRequest = async (
|
||||
@ -38,10 +38,8 @@ export class OAuth2BearerController {
|
||||
|
||||
// Try to fetch access token
|
||||
const object = await OAuth2AccessTokens.fetchByToken(token);
|
||||
if (!object) {
|
||||
if (!object || !OAuth2Tokens.checkTTL(object)) {
|
||||
throw new AccessDenied('Token not found or has expired');
|
||||
} else if (!OAuth2AccessTokens.checkTTL(object)) {
|
||||
throw new AccessDenied('Token is expired');
|
||||
}
|
||||
return object;
|
||||
};
|
||||
|
@ -55,7 +55,6 @@ export class OAuth2TokenController {
|
||||
// console.debug('Parameter grant_type is', grantType);
|
||||
|
||||
const client = await OAuth2Clients.fetchById(clientId);
|
||||
|
||||
if (!client || client.activated === 0) {
|
||||
throw new InvalidClient('Client not found');
|
||||
}
|
||||
@ -67,9 +66,8 @@ export class OAuth2TokenController {
|
||||
|
||||
if (!OAuth2Clients.checkGrantType(client, grantType) && grantType !== 'refresh_token') {
|
||||
throw new UnauthorizedClient('Invalid grant type for the client');
|
||||
} else {
|
||||
// console.debug('Grant type check passed');
|
||||
}
|
||||
// console.debug('Grant type check passed');
|
||||
|
||||
let tokenResponse: OAuth2TokenResponse = {};
|
||||
try {
|
||||
@ -93,6 +91,7 @@ export class OAuth2TokenController {
|
||||
} catch (e) {
|
||||
return OAuth2Response.error(url, e as OAuth2Error);
|
||||
}
|
||||
|
||||
throw new ServerError('Internal error');
|
||||
};
|
||||
}
|
||||
|
@ -47,7 +47,7 @@ export async function authorizationCode(
|
||||
throw new InvalidGrant('Code was issued by another client');
|
||||
}
|
||||
|
||||
if (!OAuth2Codes.checkTTL(code)) {
|
||||
if (!OAuth2Tokens.checkTTL(code)) {
|
||||
throw new InvalidGrant('Code has already expired');
|
||||
}
|
||||
} else {
|
||||
@ -94,7 +94,7 @@ export async function authorizationCode(
|
||||
respObj.refresh_token = await OAuth2RefreshTokens.create(userId, clientId, scope);
|
||||
} catch (err) {
|
||||
console.error(err);
|
||||
throw new ServerError('Failed to call refreshToken.create function');
|
||||
throw new ServerError('Failed to call refreshTokens.create function');
|
||||
}
|
||||
}
|
||||
|
||||
@ -107,7 +107,7 @@ export async function authorizationCode(
|
||||
);
|
||||
} catch (err) {
|
||||
console.error(err);
|
||||
throw new ServerError('Failed to call accessToken.create function');
|
||||
throw new ServerError('Failed to call accessTokens.create function');
|
||||
}
|
||||
|
||||
respObj.expires_in = OAuth2Tokens.tokenTtl;
|
||||
@ -117,7 +117,7 @@ export async function authorizationCode(
|
||||
await OAuth2Codes.removeByCode(providedCode);
|
||||
} catch (err) {
|
||||
console.error(err);
|
||||
throw new ServerError('Failed to call code.removeByCode function');
|
||||
throw new ServerError('Failed to call codes.removeByCode function');
|
||||
}
|
||||
|
||||
if (cleanScope.includes('openid')) {
|
||||
|
@ -21,7 +21,6 @@ export async function refreshToken(
|
||||
providedToken: string
|
||||
): Promise<OAuth2TokenResponse> {
|
||||
let user: User | undefined = undefined;
|
||||
let ttl: number | null = null;
|
||||
let refreshToken: OAuth2RefreshToken | undefined = undefined;
|
||||
let accessToken: OAuth2AccessToken | undefined = undefined;
|
||||
|
||||
@ -39,8 +38,8 @@ export async function refreshToken(
|
||||
throw new ServerError('Failed to call refreshToken.fetchByToken function');
|
||||
}
|
||||
|
||||
if (!refreshToken) {
|
||||
throw new InvalidGrant('Refresh token not found');
|
||||
if (!refreshToken || !OAuth2Tokens.checkTTL(refreshToken)) {
|
||||
throw new InvalidGrant('Refresh token not found or it is already expired');
|
||||
}
|
||||
|
||||
if (refreshToken.clientId !== client.id) {
|
||||
@ -68,31 +67,36 @@ export async function refreshToken(
|
||||
throw new ServerError('Failed to call accessToken.fetchByUserIdClientId function');
|
||||
}
|
||||
|
||||
if (accessToken) {
|
||||
ttl = OAuth2AccessTokens.getTTL(accessToken);
|
||||
|
||||
if (!ttl) {
|
||||
accessToken = undefined;
|
||||
} else {
|
||||
resObj.access_token = accessToken.token;
|
||||
resObj.expires_in = ttl;
|
||||
}
|
||||
// Remove old token
|
||||
if (accessToken && OAuth2Tokens.checkTTL(accessToken)) {
|
||||
await OAuth2Tokens.remove(accessToken);
|
||||
}
|
||||
|
||||
if (!accessToken) {
|
||||
try {
|
||||
resObj.access_token = await OAuth2AccessTokens.create(
|
||||
user.id,
|
||||
client.client_id,
|
||||
refreshToken.scope || '',
|
||||
OAuth2Tokens.tokenTtl
|
||||
);
|
||||
} catch (err) {
|
||||
throw new ServerError('Failed to call accessToken.create function');
|
||||
}
|
||||
// Remove old refresh token
|
||||
await OAuth2Tokens.remove(refreshToken);
|
||||
|
||||
resObj.expires_in = OAuth2Tokens.tokenTtl;
|
||||
try {
|
||||
resObj.refresh_token = await OAuth2RefreshTokens.create(
|
||||
user.id,
|
||||
client.client_id,
|
||||
refreshToken.scope || ''
|
||||
);
|
||||
} catch (err) {
|
||||
throw new ServerError('Failed to call refreshTokens.create function');
|
||||
}
|
||||
|
||||
try {
|
||||
resObj.access_token = await OAuth2AccessTokens.create(
|
||||
user.id,
|
||||
client.client_id,
|
||||
refreshToken.scope || '',
|
||||
OAuth2Tokens.tokenTtl
|
||||
);
|
||||
} catch (err) {
|
||||
throw new ServerError('Failed to call accessTokens.create function');
|
||||
}
|
||||
|
||||
resObj.expires_in = OAuth2Tokens.tokenTtl;
|
||||
|
||||
return resObj;
|
||||
}
|
||||
|
@ -216,6 +216,19 @@ export class OAuth2Clients {
|
||||
const filterText = `%${filters?.filter?.toLowerCase()}%`;
|
||||
const limit = filters?.limit || 20;
|
||||
|
||||
const whereQuery = and(
|
||||
!filters?.listAll
|
||||
? or(eq(oauth2Client.ownerId, subject.id), eq(oauth2ClientManager.userId, subject.id))
|
||||
: undefined,
|
||||
filters?.clientId ? eq(oauth2Client.client_id, filters.clientId) : undefined,
|
||||
filters?.filter
|
||||
? or(
|
||||
sql`lower(${oauth2Client.title}) like ${filterText}`,
|
||||
like(oauth2Client.client_id, filterText)
|
||||
)
|
||||
: undefined
|
||||
);
|
||||
|
||||
// LEFT JOINs below will include more rows than we allow with LIMIT,
|
||||
// so we need to do a subquery with the limiting first.
|
||||
// The LEFT JOIN in the subquery only contributes to the WHERE clause
|
||||
@ -224,20 +237,7 @@ export class OAuth2Clients {
|
||||
.select({ id: oauth2Client.id })
|
||||
.from(oauth2Client)
|
||||
.leftJoin(oauth2ClientManager, eq(oauth2ClientManager.clientId, oauth2Client.id))
|
||||
.where(
|
||||
and(
|
||||
!filters?.listAll
|
||||
? or(eq(oauth2Client.ownerId, subject.id), eq(oauth2ClientManager.userId, subject.id))
|
||||
: undefined,
|
||||
filters?.clientId ? eq(oauth2Client.client_id, filters.clientId) : undefined,
|
||||
filters?.filter
|
||||
? or(
|
||||
sql`lower(${oauth2Client.title}) like ${filterText}`,
|
||||
like(oauth2Client.client_id, filterText)
|
||||
)
|
||||
: undefined
|
||||
)
|
||||
)
|
||||
.where(whereQuery)
|
||||
.groupBy(oauth2Client.id)
|
||||
.limit(limit)
|
||||
.offset(filters?.offset || 0)
|
||||
@ -247,8 +247,9 @@ export class OAuth2Clients {
|
||||
.select({
|
||||
rowCount: count(oauth2Client.id).mapWith(Number)
|
||||
})
|
||||
.from(allowedClients)
|
||||
.innerJoin(oauth2Client, eq(allowedClients.id, oauth2Client.id));
|
||||
.from(oauth2Client)
|
||||
.leftJoin(oauth2ClientManager, eq(oauth2ClientManager.clientId, oauth2Client.id))
|
||||
.where(whereQuery);
|
||||
|
||||
const junkList = await DB.drizzle
|
||||
.select({
|
||||
|
@ -112,6 +112,26 @@ export class OAuth2Tokens {
|
||||
static async remove(token: OAuth2Token) {
|
||||
await DB.drizzle.delete(oauth2Token).where(eq(oauth2Token.id, token.id));
|
||||
}
|
||||
|
||||
static async invalidate(token: OAuth2Token) {
|
||||
await DB.drizzle
|
||||
.update(oauth2Token)
|
||||
.set({ expires_at: new Date() })
|
||||
.where(eq(oauth2Token.id, token.id));
|
||||
}
|
||||
|
||||
/**
|
||||
* Check token for expiry
|
||||
* @param token Access token to check
|
||||
* @returns true if still valid
|
||||
*/
|
||||
static checkTTL(token: OAuth2AccessToken): boolean {
|
||||
return new Date().getTime() < new Date(token.expires_at).getTime();
|
||||
}
|
||||
|
||||
static getTTL(token: OAuth2AccessToken): number {
|
||||
return new Date(token.expires_at).getTime() - Date.now();
|
||||
}
|
||||
}
|
||||
|
||||
export class OAuth2Codes {
|
||||
@ -184,10 +204,6 @@ export class OAuth2Codes {
|
||||
return true;
|
||||
}
|
||||
|
||||
static checkTTL(code: OAuth2Code): boolean {
|
||||
return new Date(code.expires_at).getTime() > Date.now();
|
||||
}
|
||||
|
||||
static getCodeChallenge(code: OAuth2Code) {
|
||||
return {
|
||||
method: code.code_challenge_method,
|
||||
@ -241,14 +257,6 @@ export class OAuth2AccessTokens {
|
||||
};
|
||||
}
|
||||
|
||||
static checkTTL(token: OAuth2AccessToken): boolean {
|
||||
return new Date() < new Date(token.expires_at);
|
||||
}
|
||||
|
||||
static getTTL(token: OAuth2AccessToken): number {
|
||||
return new Date(token.expires_at).getTime() - Date.now();
|
||||
}
|
||||
|
||||
static async fetchByUserIdClientId(
|
||||
userId: number,
|
||||
clientId: string
|
||||
|
@ -17,7 +17,7 @@
|
||||
</svelte:head>
|
||||
|
||||
<TitleRow>
|
||||
<h1>{$t('admin.oauth2.title')}</h1>
|
||||
<h1>{$t('admin.oauth2.title')} ({data.meta.rowCount})</h1>
|
||||
{#if data.createPrivileges}
|
||||
<a href="oauth2/new">{$t('admin.oauth2.new')}</a>
|
||||
{/if}
|
||||
|
@ -15,7 +15,7 @@
|
||||
<title>{$t('admin.users.title')} - {env.PUBLIC_SITE_NAME} {$t('admin.title')}</title>
|
||||
</svelte:head>
|
||||
|
||||
<h1>{$t('admin.users.title')}</h1>
|
||||
<h1>{$t('admin.users.title')} ({data.meta.rowCount})</h1>
|
||||
|
||||
<ColumnView>
|
||||
<form action="" method="get">
|
||||
|
Loading…
Reference in New Issue
Block a user