Retake on the refresh token flow
This commit is contained in:
parent
5171f1fc0f
commit
77a4aa1e4e
@ -1,6 +1,6 @@
|
|||||||
import { ApiUtils } from '$lib/server/api-utils';
|
import { ApiUtils } from '$lib/server/api-utils';
|
||||||
import { AccessDenied } from '../error';
|
import { AccessDenied } from '../error';
|
||||||
import { OAuth2AccessTokens, type OAuth2AccessToken } from '../model';
|
import { OAuth2AccessTokens, OAuth2Tokens, type OAuth2AccessToken } from '../model';
|
||||||
|
|
||||||
export class OAuth2BearerController {
|
export class OAuth2BearerController {
|
||||||
static bearerFromRequest = async (
|
static bearerFromRequest = async (
|
||||||
@ -38,10 +38,8 @@ export class OAuth2BearerController {
|
|||||||
|
|
||||||
// Try to fetch access token
|
// Try to fetch access token
|
||||||
const object = await OAuth2AccessTokens.fetchByToken(token);
|
const object = await OAuth2AccessTokens.fetchByToken(token);
|
||||||
if (!object) {
|
if (!object || !OAuth2Tokens.checkTTL(object)) {
|
||||||
throw new AccessDenied('Token not found or has expired');
|
throw new AccessDenied('Token not found or has expired');
|
||||||
} else if (!OAuth2AccessTokens.checkTTL(object)) {
|
|
||||||
throw new AccessDenied('Token is expired');
|
|
||||||
}
|
}
|
||||||
return object;
|
return object;
|
||||||
};
|
};
|
||||||
|
@ -55,7 +55,6 @@ export class OAuth2TokenController {
|
|||||||
// console.debug('Parameter grant_type is', grantType);
|
// console.debug('Parameter grant_type is', grantType);
|
||||||
|
|
||||||
const client = await OAuth2Clients.fetchById(clientId);
|
const client = await OAuth2Clients.fetchById(clientId);
|
||||||
|
|
||||||
if (!client || client.activated === 0) {
|
if (!client || client.activated === 0) {
|
||||||
throw new InvalidClient('Client not found');
|
throw new InvalidClient('Client not found');
|
||||||
}
|
}
|
||||||
@ -67,9 +66,8 @@ export class OAuth2TokenController {
|
|||||||
|
|
||||||
if (!OAuth2Clients.checkGrantType(client, grantType) && grantType !== 'refresh_token') {
|
if (!OAuth2Clients.checkGrantType(client, grantType) && grantType !== 'refresh_token') {
|
||||||
throw new UnauthorizedClient('Invalid grant type for the client');
|
throw new UnauthorizedClient('Invalid grant type for the client');
|
||||||
} else {
|
|
||||||
// console.debug('Grant type check passed');
|
|
||||||
}
|
}
|
||||||
|
// console.debug('Grant type check passed');
|
||||||
|
|
||||||
let tokenResponse: OAuth2TokenResponse = {};
|
let tokenResponse: OAuth2TokenResponse = {};
|
||||||
try {
|
try {
|
||||||
@ -93,6 +91,7 @@ export class OAuth2TokenController {
|
|||||||
} catch (e) {
|
} catch (e) {
|
||||||
return OAuth2Response.error(url, e as OAuth2Error);
|
return OAuth2Response.error(url, e as OAuth2Error);
|
||||||
}
|
}
|
||||||
|
|
||||||
throw new ServerError('Internal error');
|
throw new ServerError('Internal error');
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
@ -47,7 +47,7 @@ export async function authorizationCode(
|
|||||||
throw new InvalidGrant('Code was issued by another client');
|
throw new InvalidGrant('Code was issued by another client');
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!OAuth2Codes.checkTTL(code)) {
|
if (!OAuth2Tokens.checkTTL(code)) {
|
||||||
throw new InvalidGrant('Code has already expired');
|
throw new InvalidGrant('Code has already expired');
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@ -94,7 +94,7 @@ export async function authorizationCode(
|
|||||||
respObj.refresh_token = await OAuth2RefreshTokens.create(userId, clientId, scope);
|
respObj.refresh_token = await OAuth2RefreshTokens.create(userId, clientId, scope);
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
console.error(err);
|
console.error(err);
|
||||||
throw new ServerError('Failed to call refreshToken.create function');
|
throw new ServerError('Failed to call refreshTokens.create function');
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -107,7 +107,7 @@ export async function authorizationCode(
|
|||||||
);
|
);
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
console.error(err);
|
console.error(err);
|
||||||
throw new ServerError('Failed to call accessToken.create function');
|
throw new ServerError('Failed to call accessTokens.create function');
|
||||||
}
|
}
|
||||||
|
|
||||||
respObj.expires_in = OAuth2Tokens.tokenTtl;
|
respObj.expires_in = OAuth2Tokens.tokenTtl;
|
||||||
@ -117,7 +117,7 @@ export async function authorizationCode(
|
|||||||
await OAuth2Codes.removeByCode(providedCode);
|
await OAuth2Codes.removeByCode(providedCode);
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
console.error(err);
|
console.error(err);
|
||||||
throw new ServerError('Failed to call code.removeByCode function');
|
throw new ServerError('Failed to call codes.removeByCode function');
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cleanScope.includes('openid')) {
|
if (cleanScope.includes('openid')) {
|
||||||
|
@ -21,7 +21,6 @@ export async function refreshToken(
|
|||||||
providedToken: string
|
providedToken: string
|
||||||
): Promise<OAuth2TokenResponse> {
|
): Promise<OAuth2TokenResponse> {
|
||||||
let user: User | undefined = undefined;
|
let user: User | undefined = undefined;
|
||||||
let ttl: number | null = null;
|
|
||||||
let refreshToken: OAuth2RefreshToken | undefined = undefined;
|
let refreshToken: OAuth2RefreshToken | undefined = undefined;
|
||||||
let accessToken: OAuth2AccessToken | undefined = undefined;
|
let accessToken: OAuth2AccessToken | undefined = undefined;
|
||||||
|
|
||||||
@ -39,8 +38,8 @@ export async function refreshToken(
|
|||||||
throw new ServerError('Failed to call refreshToken.fetchByToken function');
|
throw new ServerError('Failed to call refreshToken.fetchByToken function');
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!refreshToken) {
|
if (!refreshToken || !OAuth2Tokens.checkTTL(refreshToken)) {
|
||||||
throw new InvalidGrant('Refresh token not found');
|
throw new InvalidGrant('Refresh token not found or it is already expired');
|
||||||
}
|
}
|
||||||
|
|
||||||
if (refreshToken.clientId !== client.id) {
|
if (refreshToken.clientId !== client.id) {
|
||||||
@ -68,31 +67,36 @@ export async function refreshToken(
|
|||||||
throw new ServerError('Failed to call accessToken.fetchByUserIdClientId function');
|
throw new ServerError('Failed to call accessToken.fetchByUserIdClientId function');
|
||||||
}
|
}
|
||||||
|
|
||||||
if (accessToken) {
|
// Remove old token
|
||||||
ttl = OAuth2AccessTokens.getTTL(accessToken);
|
if (accessToken && OAuth2Tokens.checkTTL(accessToken)) {
|
||||||
|
await OAuth2Tokens.remove(accessToken);
|
||||||
if (!ttl) {
|
|
||||||
accessToken = undefined;
|
|
||||||
} else {
|
|
||||||
resObj.access_token = accessToken.token;
|
|
||||||
resObj.expires_in = ttl;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!accessToken) {
|
// Remove old refresh token
|
||||||
try {
|
await OAuth2Tokens.remove(refreshToken);
|
||||||
resObj.access_token = await OAuth2AccessTokens.create(
|
|
||||||
user.id,
|
|
||||||
client.client_id,
|
|
||||||
refreshToken.scope || '',
|
|
||||||
OAuth2Tokens.tokenTtl
|
|
||||||
);
|
|
||||||
} catch (err) {
|
|
||||||
throw new ServerError('Failed to call accessToken.create function');
|
|
||||||
}
|
|
||||||
|
|
||||||
resObj.expires_in = OAuth2Tokens.tokenTtl;
|
try {
|
||||||
|
resObj.refresh_token = await OAuth2RefreshTokens.create(
|
||||||
|
user.id,
|
||||||
|
client.client_id,
|
||||||
|
refreshToken.scope || ''
|
||||||
|
);
|
||||||
|
} catch (err) {
|
||||||
|
throw new ServerError('Failed to call refreshTokens.create function');
|
||||||
}
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
resObj.access_token = await OAuth2AccessTokens.create(
|
||||||
|
user.id,
|
||||||
|
client.client_id,
|
||||||
|
refreshToken.scope || '',
|
||||||
|
OAuth2Tokens.tokenTtl
|
||||||
|
);
|
||||||
|
} catch (err) {
|
||||||
|
throw new ServerError('Failed to call accessTokens.create function');
|
||||||
|
}
|
||||||
|
|
||||||
|
resObj.expires_in = OAuth2Tokens.tokenTtl;
|
||||||
|
|
||||||
return resObj;
|
return resObj;
|
||||||
}
|
}
|
||||||
|
@ -216,6 +216,19 @@ export class OAuth2Clients {
|
|||||||
const filterText = `%${filters?.filter?.toLowerCase()}%`;
|
const filterText = `%${filters?.filter?.toLowerCase()}%`;
|
||||||
const limit = filters?.limit || 20;
|
const limit = filters?.limit || 20;
|
||||||
|
|
||||||
|
const whereQuery = and(
|
||||||
|
!filters?.listAll
|
||||||
|
? or(eq(oauth2Client.ownerId, subject.id), eq(oauth2ClientManager.userId, subject.id))
|
||||||
|
: undefined,
|
||||||
|
filters?.clientId ? eq(oauth2Client.client_id, filters.clientId) : undefined,
|
||||||
|
filters?.filter
|
||||||
|
? or(
|
||||||
|
sql`lower(${oauth2Client.title}) like ${filterText}`,
|
||||||
|
like(oauth2Client.client_id, filterText)
|
||||||
|
)
|
||||||
|
: undefined
|
||||||
|
);
|
||||||
|
|
||||||
// LEFT JOINs below will include more rows than we allow with LIMIT,
|
// LEFT JOINs below will include more rows than we allow with LIMIT,
|
||||||
// so we need to do a subquery with the limiting first.
|
// so we need to do a subquery with the limiting first.
|
||||||
// The LEFT JOIN in the subquery only contributes to the WHERE clause
|
// The LEFT JOIN in the subquery only contributes to the WHERE clause
|
||||||
@ -224,20 +237,7 @@ export class OAuth2Clients {
|
|||||||
.select({ id: oauth2Client.id })
|
.select({ id: oauth2Client.id })
|
||||||
.from(oauth2Client)
|
.from(oauth2Client)
|
||||||
.leftJoin(oauth2ClientManager, eq(oauth2ClientManager.clientId, oauth2Client.id))
|
.leftJoin(oauth2ClientManager, eq(oauth2ClientManager.clientId, oauth2Client.id))
|
||||||
.where(
|
.where(whereQuery)
|
||||||
and(
|
|
||||||
!filters?.listAll
|
|
||||||
? or(eq(oauth2Client.ownerId, subject.id), eq(oauth2ClientManager.userId, subject.id))
|
|
||||||
: undefined,
|
|
||||||
filters?.clientId ? eq(oauth2Client.client_id, filters.clientId) : undefined,
|
|
||||||
filters?.filter
|
|
||||||
? or(
|
|
||||||
sql`lower(${oauth2Client.title}) like ${filterText}`,
|
|
||||||
like(oauth2Client.client_id, filterText)
|
|
||||||
)
|
|
||||||
: undefined
|
|
||||||
)
|
|
||||||
)
|
|
||||||
.groupBy(oauth2Client.id)
|
.groupBy(oauth2Client.id)
|
||||||
.limit(limit)
|
.limit(limit)
|
||||||
.offset(filters?.offset || 0)
|
.offset(filters?.offset || 0)
|
||||||
@ -247,8 +247,9 @@ export class OAuth2Clients {
|
|||||||
.select({
|
.select({
|
||||||
rowCount: count(oauth2Client.id).mapWith(Number)
|
rowCount: count(oauth2Client.id).mapWith(Number)
|
||||||
})
|
})
|
||||||
.from(allowedClients)
|
.from(oauth2Client)
|
||||||
.innerJoin(oauth2Client, eq(allowedClients.id, oauth2Client.id));
|
.leftJoin(oauth2ClientManager, eq(oauth2ClientManager.clientId, oauth2Client.id))
|
||||||
|
.where(whereQuery);
|
||||||
|
|
||||||
const junkList = await DB.drizzle
|
const junkList = await DB.drizzle
|
||||||
.select({
|
.select({
|
||||||
|
@ -112,6 +112,26 @@ export class OAuth2Tokens {
|
|||||||
static async remove(token: OAuth2Token) {
|
static async remove(token: OAuth2Token) {
|
||||||
await DB.drizzle.delete(oauth2Token).where(eq(oauth2Token.id, token.id));
|
await DB.drizzle.delete(oauth2Token).where(eq(oauth2Token.id, token.id));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static async invalidate(token: OAuth2Token) {
|
||||||
|
await DB.drizzle
|
||||||
|
.update(oauth2Token)
|
||||||
|
.set({ expires_at: new Date() })
|
||||||
|
.where(eq(oauth2Token.id, token.id));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Check token for expiry
|
||||||
|
* @param token Access token to check
|
||||||
|
* @returns true if still valid
|
||||||
|
*/
|
||||||
|
static checkTTL(token: OAuth2AccessToken): boolean {
|
||||||
|
return new Date().getTime() < new Date(token.expires_at).getTime();
|
||||||
|
}
|
||||||
|
|
||||||
|
static getTTL(token: OAuth2AccessToken): number {
|
||||||
|
return new Date(token.expires_at).getTime() - Date.now();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
export class OAuth2Codes {
|
export class OAuth2Codes {
|
||||||
@ -184,10 +204,6 @@ export class OAuth2Codes {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
static checkTTL(code: OAuth2Code): boolean {
|
|
||||||
return new Date(code.expires_at).getTime() > Date.now();
|
|
||||||
}
|
|
||||||
|
|
||||||
static getCodeChallenge(code: OAuth2Code) {
|
static getCodeChallenge(code: OAuth2Code) {
|
||||||
return {
|
return {
|
||||||
method: code.code_challenge_method,
|
method: code.code_challenge_method,
|
||||||
@ -241,14 +257,6 @@ export class OAuth2AccessTokens {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
static checkTTL(token: OAuth2AccessToken): boolean {
|
|
||||||
return new Date() < new Date(token.expires_at);
|
|
||||||
}
|
|
||||||
|
|
||||||
static getTTL(token: OAuth2AccessToken): number {
|
|
||||||
return new Date(token.expires_at).getTime() - Date.now();
|
|
||||||
}
|
|
||||||
|
|
||||||
static async fetchByUserIdClientId(
|
static async fetchByUserIdClientId(
|
||||||
userId: number,
|
userId: number,
|
||||||
clientId: string
|
clientId: string
|
||||||
|
@ -17,7 +17,7 @@
|
|||||||
</svelte:head>
|
</svelte:head>
|
||||||
|
|
||||||
<TitleRow>
|
<TitleRow>
|
||||||
<h1>{$t('admin.oauth2.title')}</h1>
|
<h1>{$t('admin.oauth2.title')} ({data.meta.rowCount})</h1>
|
||||||
{#if data.createPrivileges}
|
{#if data.createPrivileges}
|
||||||
<a href="oauth2/new">{$t('admin.oauth2.new')}</a>
|
<a href="oauth2/new">{$t('admin.oauth2.new')}</a>
|
||||||
{/if}
|
{/if}
|
||||||
|
@ -15,7 +15,7 @@
|
|||||||
<title>{$t('admin.users.title')} - {env.PUBLIC_SITE_NAME} {$t('admin.title')}</title>
|
<title>{$t('admin.users.title')} - {env.PUBLIC_SITE_NAME} {$t('admin.title')}</title>
|
||||||
</svelte:head>
|
</svelte:head>
|
||||||
|
|
||||||
<h1>{$t('admin.users.title')}</h1>
|
<h1>{$t('admin.users.title')} ({data.meta.rowCount})</h1>
|
||||||
|
|
||||||
<ColumnView>
|
<ColumnView>
|
||||||
<form action="" method="get">
|
<form action="" method="get">
|
||||||
|
Loading…
Reference in New Issue
Block a user