fix: token refresh locking (#613)

* fix: kv use tls explicit

* fix: token refresh locking

* remove logs

* compile fix

* compile fix

---------

Co-authored-by: Igor Loskutov <igor.loskutoff@gmail.com>
This commit is contained in:
Igor Monadical
2025-09-05 23:03:24 -04:00
committed by GitHub
parent 08d88ec349
commit 7f5a4c9ddc
9 changed files with 159 additions and 62 deletions

View File

@@ -8,10 +8,11 @@ import { assertCustomSession, CustomSession } from "./types";
import { Session } from "next-auth"; import { Session } from "next-auth";
import { SessionAutoRefresh } from "./SessionAutoRefresh"; import { SessionAutoRefresh } from "./SessionAutoRefresh";
import { REFRESH_ACCESS_TOKEN_ERROR } from "./auth"; import { REFRESH_ACCESS_TOKEN_ERROR } from "./auth";
import { assertExists } from "./utils";
type AuthContextType = ( type AuthContextType = (
| { status: "loading" } | { status: "loading" }
| { status: "refreshing" } | { status: "refreshing"; user: CustomSession["user"] }
| { status: "unauthenticated"; error?: string } | { status: "unauthenticated"; error?: string }
| { | {
status: "authenticated"; status: "authenticated";
@@ -41,7 +42,10 @@ export function AuthProvider({ children }: { children: React.ReactNode }) {
return { status }; return { status };
} }
case true: { case true: {
return { status: "refreshing" as const }; return {
status: "refreshing" as const,
user: assertExists(customSession).user,
};
} }
default: { default: {
const _: never = sessionIsHere; const _: never = sessionIsHere;

View File

@@ -15,6 +15,7 @@ const REFRESH_BEFORE = REFRESH_ACCESS_TOKEN_BEFORE;
export function SessionAutoRefresh({ children }) { export function SessionAutoRefresh({ children }) {
const auth = useAuth(); const auth = useAuth();
const accessTokenExpires = const accessTokenExpires =
auth.status === "authenticated" ? auth.accessTokenExpires : null; auth.status === "authenticated" ? auth.accessTokenExpires : null;
@@ -23,17 +24,16 @@ export function SessionAutoRefresh({ children }) {
// and not too slow (debuggable) // and not too slow (debuggable)
const INTERVAL_REFRESH_MS = 5000; const INTERVAL_REFRESH_MS = 5000;
const interval = setInterval(() => { const interval = setInterval(() => {
if (accessTokenExpires !== null) { if (accessTokenExpires === null) return;
const timeLeft = accessTokenExpires - Date.now(); const timeLeft = accessTokenExpires - Date.now();
if (timeLeft < REFRESH_BEFORE) { if (timeLeft < REFRESH_BEFORE) {
auth auth
.update() .update()
.then(() => {}) .then(() => {})
.catch((e) => { .catch((e) => {
// note: 401 won't be considered error here // note: 401 won't be considered error here
console.error("error refreshing auth token", e); console.error("error refreshing auth token", e);
}); });
}
} }
}, INTERVAL_REFRESH_MS); }, INTERVAL_REFRESH_MS);

View File

@@ -2,7 +2,11 @@ import { AuthOptions } from "next-auth";
import AuthentikProvider from "next-auth/providers/authentik"; import AuthentikProvider from "next-auth/providers/authentik";
import type { JWT } from "next-auth/jwt"; import type { JWT } from "next-auth/jwt";
import { JWTWithAccessToken, CustomSession } from "./types"; import { JWTWithAccessToken, CustomSession } from "./types";
import { assertExists, assertExistsAndNonEmptyString } from "./utils"; import {
assertExists,
assertExistsAndNonEmptyString,
assertNotExists,
} from "./utils";
import { import {
REFRESH_ACCESS_TOKEN_BEFORE, REFRESH_ACCESS_TOKEN_BEFORE,
REFRESH_ACCESS_TOKEN_ERROR, REFRESH_ACCESS_TOKEN_ERROR,
@@ -12,14 +16,10 @@ import {
setTokenCache, setTokenCache,
deleteTokenCache, deleteTokenCache,
} from "./redisTokenCache"; } from "./redisTokenCache";
import { tokenCacheRedis } from "./redisClient"; import { tokenCacheRedis, redlock } from "./redisClient";
import { isBuildPhase } from "./next"; import { isBuildPhase } from "./next";
// REFRESH_ACCESS_TOKEN_BEFORE because refresh is based on access token expiration (imagine we cache it 30 days)
const TOKEN_CACHE_TTL = REFRESH_ACCESS_TOKEN_BEFORE; const TOKEN_CACHE_TTL = REFRESH_ACCESS_TOKEN_BEFORE;
const refreshLocks = new Map<string, Promise<JWTWithAccessToken>>();
const CLIENT_ID = !isBuildPhase const CLIENT_ID = !isBuildPhase
? assertExistsAndNonEmptyString(process.env.AUTHENTIK_CLIENT_ID) ? assertExistsAndNonEmptyString(process.env.AUTHENTIK_CLIENT_ID)
: "noop"; : "noop";
@@ -45,31 +45,48 @@ export const authOptions: AuthOptions = {
}, },
callbacks: { callbacks: {
async jwt({ token, account, user }) { async jwt({ token, account, user }) {
const KEY = `token:${token.sub}`; if (account && !account.access_token) {
await deleteTokenCache(tokenCacheRedis, `token:${token.sub}`);
}
if (account && user) { if (account && user) {
// called only on first login // called only on first login
// XXX account.expires_in used in example is not defined for authentik backend, but expires_at is // XXX account.expires_in used in example is not defined for authentik backend, but expires_at is
const expiresAtS = assertExists(account.expires_at); if (account.access_token) {
const expiresAtMs = expiresAtS * 1000; const expiresAtS = assertExists(account.expires_at);
if (!account.access_token) { const expiresAtMs = expiresAtS * 1000;
await deleteTokenCache(tokenCacheRedis, KEY);
} else {
const jwtToken: JWTWithAccessToken = { const jwtToken: JWTWithAccessToken = {
...token, ...token,
accessToken: account.access_token, accessToken: account.access_token,
accessTokenExpires: expiresAtMs, accessTokenExpires: expiresAtMs,
refreshToken: account.refresh_token, refreshToken: account.refresh_token,
}; };
await setTokenCache(tokenCacheRedis, KEY, { if (jwtToken.error) {
token: jwtToken, await deleteTokenCache(tokenCacheRedis, `token:${token.sub}`);
timestamp: Date.now(), } else {
}); assertNotExists(
return jwtToken; jwtToken.error,
`panic! trying to cache token with error in jwt: ${jwtToken.error}`,
);
await setTokenCache(tokenCacheRedis, `token:${token.sub}`, {
token: jwtToken,
timestamp: Date.now(),
});
return jwtToken;
}
} }
} }
const currentToken = await getTokenCache(tokenCacheRedis, KEY); const currentToken = await getTokenCache(
tokenCacheRedis,
`token:${token.sub}`,
);
console.debug(
"currentToken from cache",
JSON.stringify(currentToken, null, 2),
"will be returned?",
currentToken && Date.now() < currentToken.token.accessTokenExpires,
);
if (currentToken && Date.now() < currentToken.token.accessTokenExpires) { if (currentToken && Date.now() < currentToken.token.accessTokenExpires) {
return currentToken.token; return currentToken.token;
} }
@@ -97,20 +114,22 @@ export const authOptions: AuthOptions = {
async function lockedRefreshAccessToken( async function lockedRefreshAccessToken(
token: JWT, token: JWT,
): Promise<JWTWithAccessToken> { ): Promise<JWTWithAccessToken> {
const lockKey = `${token.sub}-refresh`; const lockKey = `${token.sub}-lock`;
const existingRefresh = refreshLocks.get(lockKey); return redlock
if (existingRefresh) { .using([lockKey], 10000, async () => {
return await existingRefresh;
}
const refreshPromise = (async () => {
try {
const cached = await getTokenCache(tokenCacheRedis, `token:${token.sub}`); const cached = await getTokenCache(tokenCacheRedis, `token:${token.sub}`);
if (cached)
console.debug(
"received cached token. to delete?",
Date.now() - cached.timestamp > TOKEN_CACHE_TTL,
);
else console.debug("no cached token received");
if (cached) { if (cached) {
if (Date.now() - cached.timestamp > TOKEN_CACHE_TTL) { if (Date.now() - cached.timestamp > TOKEN_CACHE_TTL) {
await deleteTokenCache(tokenCacheRedis, `token:${token.sub}`); await deleteTokenCache(tokenCacheRedis, `token:${token.sub}`);
} else if (Date.now() < cached.token.accessTokenExpires) { } else if (Date.now() < cached.token.accessTokenExpires) {
console.debug("returning cached token", cached.token);
return cached.token; return cached.token;
} }
} }
@@ -118,19 +137,35 @@ async function lockedRefreshAccessToken(
const currentToken = cached?.token || (token as JWTWithAccessToken); const currentToken = cached?.token || (token as JWTWithAccessToken);
const newToken = await refreshAccessToken(currentToken); const newToken = await refreshAccessToken(currentToken);
console.debug("current token during refresh", currentToken);
console.debug("new token during refresh", newToken);
if (newToken.error) {
await deleteTokenCache(tokenCacheRedis, `token:${token.sub}`);
return newToken;
}
assertNotExists(
newToken.error,
`panic! trying to cache token with error during refresh: ${newToken.error}`,
);
await setTokenCache(tokenCacheRedis, `token:${token.sub}`, { await setTokenCache(tokenCacheRedis, `token:${token.sub}`, {
token: newToken, token: newToken,
timestamp: Date.now(), timestamp: Date.now(),
}); });
return newToken; return newToken;
} finally { })
setTimeout(() => refreshLocks.delete(lockKey), 100); .catch((e) => {
} console.error("error refreshing token", e);
})(); deleteTokenCache(tokenCacheRedis, `token:${token.sub}`).catch((e) => {
console.error("error deleting errored token", e);
refreshLocks.set(lockKey, refreshPromise); });
return refreshPromise; return {
...token,
error: REFRESH_ACCESS_TOKEN_ERROR,
} as JWTWithAccessToken;
});
} }
async function refreshAccessToken(token: JWT): Promise<JWTWithAccessToken> { async function refreshAccessToken(token: JWT): Promise<JWTWithAccessToken> {

View File

@@ -1,20 +1,29 @@
import Redis from "ioredis"; import Redis from "ioredis";
import { isBuildPhase } from "./next"; import { isBuildPhase } from "./next";
import Redlock, { ResourceLockedError } from "redlock";
export type RedisClient = Pick<Redis, "get" | "setex" | "del">; export type RedisClient = Pick<Redis, "get" | "setex" | "del">;
export type RedlockClient = {
using: <T>(
keys: string | string[],
ttl: number,
cb: () => Promise<T>,
) => Promise<T>;
};
const KV_USE_TLS = process.env.KV_USE_TLS const KV_USE_TLS = process.env.KV_USE_TLS
? process.env.KV_USE_TLS === "true" ? process.env.KV_USE_TLS === "true"
: undefined; : undefined;
let redisClient: Redis | null = null;
const getRedisClient = (): RedisClient => { const getRedisClient = (): RedisClient => {
if (redisClient) return redisClient;
const redisUrl = process.env.KV_URL; const redisUrl = process.env.KV_URL;
if (!redisUrl) { if (!redisUrl) {
throw new Error("KV_URL environment variable is required"); throw new Error("KV_URL environment variable is required");
} }
const redis = new Redis(redisUrl, { redisClient = new Redis(redisUrl, {
maxRetriesPerRequest: 3, maxRetriesPerRequest: 3,
lazyConnect: true,
...(KV_USE_TLS === true ...(KV_USE_TLS === true
? { ? {
tls: {}, tls: {},
@@ -22,18 +31,11 @@ const getRedisClient = (): RedisClient => {
: {}), : {}),
}); });
redis.on("error", (error) => { redisClient.on("error", (error) => {
console.error("Redis error:", error); console.error("Redis error:", error);
}); });
// not necessary but will indicate redis config errors by failfast at startup return redisClient;
// happens only once; after that connection is allowed to die and the lib is assumed to be able to restore it eventually
redis.connect().catch((e) => {
console.error("Failed to connect to Redis:", e);
process.exit(1);
});
return redis;
}; };
// next.js buildtime usage - we want to isolate next.js "build" time concepts here // next.js buildtime usage - we want to isolate next.js "build" time concepts here
@@ -52,4 +54,25 @@ const noopClient: RedisClient = (() => {
del: noopDel, del: noopDel,
}; };
})(); })();
const noopRedlock: RedlockClient = {
using: <T>(resource: string | string[], ttl: number, cb: () => Promise<T>) =>
cb(),
};
export const redlock: RedlockClient = isBuildPhase
? noopRedlock
: (() => {
const r = new Redlock([getRedisClient()], {});
r.on("error", (error) => {
if (error instanceof ResourceLockedError) {
return;
}
// Log all other errors.
console.error(error);
});
return r;
})();
export const tokenCacheRedis = isBuildPhase ? noopClient : getRedisClient(); export const tokenCacheRedis = isBuildPhase ? noopClient : getRedisClient();

View File

@@ -9,7 +9,6 @@ const TokenCacheEntrySchema = z.object({
accessToken: z.string(), accessToken: z.string(),
accessTokenExpires: z.number(), accessTokenExpires: z.number(),
refreshToken: z.string().optional(), refreshToken: z.string().optional(),
error: z.string().optional(),
}), }),
timestamp: z.number(), timestamp: z.number(),
}); });
@@ -46,14 +45,15 @@ export async function getTokenCache(
} }
} }
const TTL_SECONDS = 30 * 24 * 60 * 60;
export async function setTokenCache( export async function setTokenCache(
redis: KV, redis: KV,
key: string, key: string,
value: TokenCacheEntry, value: TokenCacheEntry,
): Promise<void> { ): Promise<void> {
const encodedValue = TokenCacheEntryCodec.encode(value); const encodedValue = TokenCacheEntryCodec.encode(value);
const ttlSeconds = Math.floor(REFRESH_ACCESS_TOKEN_BEFORE / 1000); await redis.setex(key, TTL_SECONDS, encodedValue);
await redis.setex(key, ttlSeconds, encodedValue);
} }
export async function deleteTokenCache(redis: KV, key: string): Promise<void> { export async function deleteTokenCache(redis: KV, key: string): Promise<void> {

View File

@@ -2,6 +2,7 @@ import { useAuth } from "./AuthProvider";
export const useUserName = (): string | null | undefined => { export const useUserName = (): string | null | undefined => {
const auth = useAuth(); const auth = useAuth();
if (auth.status !== "authenticated") return undefined; if (auth.status !== "authenticated" && auth.status !== "refreshing")
return undefined;
return auth.user?.name || null; return auth.user?.name || null;
}; };

View File

@@ -158,6 +158,17 @@ export const assertExists = <T>(
return value; return value;
}; };
export const assertNotExists = <T>(
value: T | null | undefined,
err?: string,
): void => {
if (value !== null && value !== undefined) {
throw new Error(
`Assertion failed: ${err ?? "value is not null or undefined"}`,
);
}
};
export const assertExistsAndNonEmptyString = ( export const assertExistsAndNonEmptyString = (
value: string | null | undefined, value: string | null | undefined,
): NonEmptyString => ): NonEmptyString =>

View File

@@ -45,6 +45,7 @@
"react-markdown": "^9.0.0", "react-markdown": "^9.0.0",
"react-qr-code": "^2.0.12", "react-qr-code": "^2.0.12",
"react-select-search": "^4.1.7", "react-select-search": "^4.1.7",
"redlock": "5.0.0-beta.2",
"sass": "^1.63.6", "sass": "^1.63.6",
"simple-peer": "^9.11.1", "simple-peer": "^9.11.1",
"tailwindcss": "^3.3.2", "tailwindcss": "^3.3.2",

22
www/pnpm-lock.yaml generated
View File

@@ -106,6 +106,9 @@ importers:
react-select-search: react-select-search:
specifier: ^4.1.7 specifier: ^4.1.7
version: 4.1.8(prop-types@15.8.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) version: 4.1.8(prop-types@15.8.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
redlock:
specifier: 5.0.0-beta.2
version: 5.0.0-beta.2
sass: sass:
specifier: ^1.63.6 specifier: ^1.63.6
version: 1.90.0 version: 1.90.0
@@ -6566,6 +6569,12 @@ packages:
sass: sass:
optional: true optional: true
node-abort-controller@3.1.1:
resolution:
{
integrity: sha512-AGK2yQKIjRuqnc6VkX2Xj5d+QW8xZ87pa1UK6yA6ouUyuxfHuMP6umE5QK7UmTeOAymo+Zx1Fxiuw9rVx8taHQ==,
}
node-addon-api@7.1.1: node-addon-api@7.1.1:
resolution: resolution:
{ {
@@ -7433,6 +7442,13 @@ packages:
} }
engines: { node: ">=4" } engines: { node: ">=4" }
redlock@5.0.0-beta.2:
resolution:
{
integrity: sha512-2RDWXg5jgRptDrB1w9O/JgSZC0j7y4SlaXnor93H/UJm/QyDiFgBKNtrh0TI6oCXqYSaSoXxFh6Sd3VtYfhRXw==,
}
engines: { node: ">=12" }
redux-thunk@3.1.0: redux-thunk@3.1.0:
resolution: resolution:
{ {
@@ -13812,6 +13828,8 @@ snapshots:
- "@babel/core" - "@babel/core"
- babel-plugin-macros - babel-plugin-macros
node-abort-controller@3.1.1: {}
node-addon-api@7.1.1: node-addon-api@7.1.1:
optional: true optional: true
@@ -14290,6 +14308,10 @@ snapshots:
dependencies: dependencies:
redis-errors: 1.2.0 redis-errors: 1.2.0
redlock@5.0.0-beta.2:
dependencies:
node-abort-controller: 3.1.1
redux-thunk@3.1.0(redux@5.0.1): redux-thunk@3.1.0(redux@5.0.1):
dependencies: dependencies:
redux: 5.0.1 redux: 5.0.1