fix: token refresh locking (#613)

* fix: kv use tls explicit

* fix: token refresh locking

* remove logs

* compile fix

* compile fix

---------

Co-authored-by: Igor Loskutov <igor.loskutoff@gmail.com>
This commit is contained in:
Igor Monadical
2025-09-05 23:03:24 -04:00
committed by GitHub
parent 08d88ec349
commit 7f5a4c9ddc
9 changed files with 159 additions and 62 deletions

View File

@@ -8,10 +8,11 @@ import { assertCustomSession, CustomSession } from "./types";
import { Session } from "next-auth";
import { SessionAutoRefresh } from "./SessionAutoRefresh";
import { REFRESH_ACCESS_TOKEN_ERROR } from "./auth";
import { assertExists } from "./utils";
type AuthContextType = (
| { status: "loading" }
| { status: "refreshing" }
| { status: "refreshing"; user: CustomSession["user"] }
| { status: "unauthenticated"; error?: string }
| {
status: "authenticated";
@@ -41,7 +42,10 @@ export function AuthProvider({ children }: { children: React.ReactNode }) {
return { status };
}
case true: {
return { status: "refreshing" as const };
return {
status: "refreshing" as const,
user: assertExists(customSession).user,
};
}
default: {
const _: never = sessionIsHere;

View File

@@ -15,6 +15,7 @@ const REFRESH_BEFORE = REFRESH_ACCESS_TOKEN_BEFORE;
export function SessionAutoRefresh({ children }) {
const auth = useAuth();
const accessTokenExpires =
auth.status === "authenticated" ? auth.accessTokenExpires : null;
@@ -23,17 +24,16 @@ export function SessionAutoRefresh({ children }) {
// and not too slow (debuggable)
const INTERVAL_REFRESH_MS = 5000;
const interval = setInterval(() => {
if (accessTokenExpires !== null) {
const timeLeft = accessTokenExpires - Date.now();
if (timeLeft < REFRESH_BEFORE) {
auth
.update()
.then(() => {})
.catch((e) => {
// note: 401 won't be considered error here
console.error("error refreshing auth token", e);
});
}
if (accessTokenExpires === null) return;
const timeLeft = accessTokenExpires - Date.now();
if (timeLeft < REFRESH_BEFORE) {
auth
.update()
.then(() => {})
.catch((e) => {
// note: 401 won't be considered error here
console.error("error refreshing auth token", e);
});
}
}, INTERVAL_REFRESH_MS);

View File

@@ -2,7 +2,11 @@ import { AuthOptions } from "next-auth";
import AuthentikProvider from "next-auth/providers/authentik";
import type { JWT } from "next-auth/jwt";
import { JWTWithAccessToken, CustomSession } from "./types";
import { assertExists, assertExistsAndNonEmptyString } from "./utils";
import {
assertExists,
assertExistsAndNonEmptyString,
assertNotExists,
} from "./utils";
import {
REFRESH_ACCESS_TOKEN_BEFORE,
REFRESH_ACCESS_TOKEN_ERROR,
@@ -12,14 +16,10 @@ import {
setTokenCache,
deleteTokenCache,
} from "./redisTokenCache";
import { tokenCacheRedis } from "./redisClient";
import { tokenCacheRedis, redlock } from "./redisClient";
import { isBuildPhase } from "./next";
// REFRESH_ACCESS_TOKEN_BEFORE because refresh is based on access token expiration (imagine we cache it 30 days)
const TOKEN_CACHE_TTL = REFRESH_ACCESS_TOKEN_BEFORE;
const refreshLocks = new Map<string, Promise<JWTWithAccessToken>>();
const CLIENT_ID = !isBuildPhase
? assertExistsAndNonEmptyString(process.env.AUTHENTIK_CLIENT_ID)
: "noop";
@@ -45,31 +45,48 @@ export const authOptions: AuthOptions = {
},
callbacks: {
async jwt({ token, account, user }) {
const KEY = `token:${token.sub}`;
if (account && !account.access_token) {
await deleteTokenCache(tokenCacheRedis, `token:${token.sub}`);
}
if (account && user) {
// called only on first login
// XXX account.expires_in used in example is not defined for authentik backend, but expires_at is
const expiresAtS = assertExists(account.expires_at);
const expiresAtMs = expiresAtS * 1000;
if (!account.access_token) {
await deleteTokenCache(tokenCacheRedis, KEY);
} else {
if (account.access_token) {
const expiresAtS = assertExists(account.expires_at);
const expiresAtMs = expiresAtS * 1000;
const jwtToken: JWTWithAccessToken = {
...token,
accessToken: account.access_token,
accessTokenExpires: expiresAtMs,
refreshToken: account.refresh_token,
};
await setTokenCache(tokenCacheRedis, KEY, {
token: jwtToken,
timestamp: Date.now(),
});
return jwtToken;
if (jwtToken.error) {
await deleteTokenCache(tokenCacheRedis, `token:${token.sub}`);
} else {
assertNotExists(
jwtToken.error,
`panic! trying to cache token with error in jwt: ${jwtToken.error}`,
);
await setTokenCache(tokenCacheRedis, `token:${token.sub}`, {
token: jwtToken,
timestamp: Date.now(),
});
return jwtToken;
}
}
}
const currentToken = await getTokenCache(tokenCacheRedis, KEY);
const currentToken = await getTokenCache(
tokenCacheRedis,
`token:${token.sub}`,
);
console.debug(
"currentToken from cache",
JSON.stringify(currentToken, null, 2),
"will be returned?",
currentToken && Date.now() < currentToken.token.accessTokenExpires,
);
if (currentToken && Date.now() < currentToken.token.accessTokenExpires) {
return currentToken.token;
}
@@ -97,20 +114,22 @@ export const authOptions: AuthOptions = {
async function lockedRefreshAccessToken(
token: JWT,
): Promise<JWTWithAccessToken> {
const lockKey = `${token.sub}-refresh`;
const lockKey = `${token.sub}-lock`;
const existingRefresh = refreshLocks.get(lockKey);
if (existingRefresh) {
return await existingRefresh;
}
const refreshPromise = (async () => {
try {
return redlock
.using([lockKey], 10000, async () => {
const cached = await getTokenCache(tokenCacheRedis, `token:${token.sub}`);
if (cached)
console.debug(
"received cached token. to delete?",
Date.now() - cached.timestamp > TOKEN_CACHE_TTL,
);
else console.debug("no cached token received");
if (cached) {
if (Date.now() - cached.timestamp > TOKEN_CACHE_TTL) {
await deleteTokenCache(tokenCacheRedis, `token:${token.sub}`);
} else if (Date.now() < cached.token.accessTokenExpires) {
console.debug("returning cached token", cached.token);
return cached.token;
}
}
@@ -118,19 +137,35 @@ async function lockedRefreshAccessToken(
const currentToken = cached?.token || (token as JWTWithAccessToken);
const newToken = await refreshAccessToken(currentToken);
console.debug("current token during refresh", currentToken);
console.debug("new token during refresh", newToken);
if (newToken.error) {
await deleteTokenCache(tokenCacheRedis, `token:${token.sub}`);
return newToken;
}
assertNotExists(
newToken.error,
`panic! trying to cache token with error during refresh: ${newToken.error}`,
);
await setTokenCache(tokenCacheRedis, `token:${token.sub}`, {
token: newToken,
timestamp: Date.now(),
});
return newToken;
} finally {
setTimeout(() => refreshLocks.delete(lockKey), 100);
}
})();
refreshLocks.set(lockKey, refreshPromise);
return refreshPromise;
})
.catch((e) => {
console.error("error refreshing token", e);
deleteTokenCache(tokenCacheRedis, `token:${token.sub}`).catch((e) => {
console.error("error deleting errored token", e);
});
return {
...token,
error: REFRESH_ACCESS_TOKEN_ERROR,
} as JWTWithAccessToken;
});
}
async function refreshAccessToken(token: JWT): Promise<JWTWithAccessToken> {

View File

@@ -1,20 +1,29 @@
import Redis from "ioredis";
import { isBuildPhase } from "./next";
import Redlock, { ResourceLockedError } from "redlock";
export type RedisClient = Pick<Redis, "get" | "setex" | "del">;
/**
 * Minimal locking surface this module exposes to callers: a single `using`
 * helper that takes lock key(s), a duration and an async callback, and yields
 * the callback's result type.
 *
 * NOTE(review): presumably shaped to match the `using` method of the `redlock`
 * package so a real Redlock instance and a no-op stand-in are interchangeable
 * — confirm against the library typings.
 */
export type RedlockClient = {
using: <T>(
// One resource key or a list of keys to lock.
keys: string | string[],
// Lock duration. NOTE(review): assumed to be milliseconds, as in Redlock — confirm.
ttl: number,
// Work to run; its resolved value becomes the result of `using`.
cb: () => Promise<T>,
) => Promise<T>;
};
const KV_USE_TLS = process.env.KV_USE_TLS
? process.env.KV_USE_TLS === "true"
: undefined;
let redisClient: Redis | null = null;
const getRedisClient = (): RedisClient => {
if (redisClient) return redisClient;
const redisUrl = process.env.KV_URL;
if (!redisUrl) {
throw new Error("KV_URL environment variable is required");
}
const redis = new Redis(redisUrl, {
redisClient = new Redis(redisUrl, {
maxRetriesPerRequest: 3,
lazyConnect: true,
...(KV_USE_TLS === true
? {
tls: {},
@@ -22,18 +31,11 @@ const getRedisClient = (): RedisClient => {
: {}),
});
redis.on("error", (error) => {
redisClient.on("error", (error) => {
console.error("Redis error:", error);
});
// not necessary but will indicate redis config errors by failfast at startup
// happens only once; after that connection is allowed to die and the lib is assumed to be able to restore it eventually
redis.connect().catch((e) => {
console.error("Failed to connect to Redis:", e);
process.exit(1);
});
return redis;
return redisClient;
};
// next.js buildtime usage - we want to isolate next.js "build" time concepts here
@@ -52,4 +54,25 @@ const noopClient: RedisClient = (() => {
del: noopDel,
};
})();
/**
 * Lock client that performs no locking at all: `using` ignores the requested
 * keys and duration and simply invokes the callback, returning its promise.
 */
const noopRedlock: RedlockClient = {
  using<T>(
    _resource: string | string[],
    _ttl: number,
    cb: () => Promise<T>,
  ): Promise<T> {
    // No lock is acquired; the callback runs immediately.
    return cb();
  },
};
/**
 * Creates a Redlock instance on top of the shared Redis client and attaches
 * an error listener to it.
 *
 * The listener drops `ResourceLockedError` events silently and writes every
 * other error to `console.error`.
 */
function createRedlock(): RedlockClient {
  const lock = new Redlock([getRedisClient()], {});
  lock.on("error", (error) => {
    if (!(error instanceof ResourceLockedError)) {
      // Log all other errors.
      console.error(error);
    }
  });
  return lock;
}

/**
 * Lock client used by the rest of the app: the no-op implementation when
 * `isBuildPhase` is set, otherwise a Redis-backed Redlock instance.
 */
export const redlock: RedlockClient = isBuildPhase
  ? noopRedlock
  : createRedlock();
export const tokenCacheRedis = isBuildPhase ? noopClient : getRedisClient();

View File

@@ -9,7 +9,6 @@ const TokenCacheEntrySchema = z.object({
accessToken: z.string(),
accessTokenExpires: z.number(),
refreshToken: z.string().optional(),
error: z.string().optional(),
}),
timestamp: z.number(),
});
@@ -46,14 +45,15 @@ export async function getTokenCache(
}
}
const TTL_SECONDS = 30 * 24 * 60 * 60;
export async function setTokenCache(
redis: KV,
key: string,
value: TokenCacheEntry,
): Promise<void> {
const encodedValue = TokenCacheEntryCodec.encode(value);
const ttlSeconds = Math.floor(REFRESH_ACCESS_TOKEN_BEFORE / 1000);
await redis.setex(key, ttlSeconds, encodedValue);
await redis.setex(key, TTL_SECONDS, encodedValue);
}
export async function deleteTokenCache(redis: KV, key: string): Promise<void> {

View File

@@ -2,6 +2,7 @@ import { useAuth } from "./AuthProvider";
export const useUserName = (): string | null | undefined => {
const auth = useAuth();
if (auth.status !== "authenticated") return undefined;
if (auth.status !== "authenticated" && auth.status !== "refreshing")
return undefined;
return auth.user?.name || null;
};

View File

@@ -158,6 +158,17 @@ export const assertExists = <T>(
return value;
};
/**
 * Asserts that `value` is absent, i.e. `null` or `undefined`.
 *
 * @param value - The value expected to be missing.
 * @param err - Optional detail used in the error message instead of the
 *   default "value is not null or undefined".
 * @throws Error with message `Assertion failed: <detail>` when `value` is
 *   anything other than `null` or `undefined` (including `0`, `""`, `false`).
 */
export const assertNotExists = <T>(
  value: T | null | undefined,
  err?: string,
): void => {
  // `== null` matches exactly null and undefined, nothing else.
  if (value == null) return;

  const detail = err ?? "value is not null or undefined";
  throw new Error(`Assertion failed: ${detail}`);
};
export const assertExistsAndNonEmptyString = (
value: string | null | undefined,
): NonEmptyString =>

View File

@@ -45,6 +45,7 @@
"react-markdown": "^9.0.0",
"react-qr-code": "^2.0.12",
"react-select-search": "^4.1.7",
"redlock": "5.0.0-beta.2",
"sass": "^1.63.6",
"simple-peer": "^9.11.1",
"tailwindcss": "^3.3.2",

22
www/pnpm-lock.yaml generated
View File

@@ -106,6 +106,9 @@ importers:
react-select-search:
specifier: ^4.1.7
version: 4.1.8(prop-types@15.8.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
redlock:
specifier: 5.0.0-beta.2
version: 5.0.0-beta.2
sass:
specifier: ^1.63.6
version: 1.90.0
@@ -6566,6 +6569,12 @@ packages:
sass:
optional: true
node-abort-controller@3.1.1:
resolution:
{
integrity: sha512-AGK2yQKIjRuqnc6VkX2Xj5d+QW8xZ87pa1UK6yA6ouUyuxfHuMP6umE5QK7UmTeOAymo+Zx1Fxiuw9rVx8taHQ==,
}
node-addon-api@7.1.1:
resolution:
{
@@ -7433,6 +7442,13 @@ packages:
}
engines: { node: ">=4" }
redlock@5.0.0-beta.2:
resolution:
{
integrity: sha512-2RDWXg5jgRptDrB1w9O/JgSZC0j7y4SlaXnor93H/UJm/QyDiFgBKNtrh0TI6oCXqYSaSoXxFh6Sd3VtYfhRXw==,
}
engines: { node: ">=12" }
redux-thunk@3.1.0:
resolution:
{
@@ -13812,6 +13828,8 @@ snapshots:
- "@babel/core"
- babel-plugin-macros
node-abort-controller@3.1.1: {}
node-addon-api@7.1.1:
optional: true
@@ -14290,6 +14308,10 @@ snapshots:
dependencies:
redis-errors: 1.2.0
redlock@5.0.0-beta.2:
dependencies:
node-abort-controller: 3.1.1
redux-thunk@3.1.0(redux@5.0.1):
dependencies:
redux: 5.0.1