diff --git a/packages/console/app/src/routes/zen/util/handler.ts b/packages/console/app/src/routes/zen/util/handler.ts index 9646cacd0..37708944d 100644 --- a/packages/console/app/src/routes/zen/util/handler.ts +++ b/packages/console/app/src/routes/zen/util/handler.ts @@ -38,6 +38,7 @@ type RetryOptions = { excludeProviders: string[] retryCount: number } +type BillingSource = "anonymous" | "free" | "byok" | "subscription" | "balance" export async function handler( input: APIEvent, @@ -51,6 +52,7 @@ export async function handler( type AuthInfo = Awaited> type ModelInfo = Awaited> type ProviderInfo = Awaited> + type CostInfo = ReturnType const MAX_FAILOVER_RETRIES = 3 const MAX_429_RETRIES = 3 @@ -139,21 +141,22 @@ export async function handler( "llm.error.code": res.status, "llm.error.message": res.statusText, }) + } - // Try another provider => stop retrying if using fallback provider - if ( - // ie. openai 404 error: Item with id 'msg_0ead8b004a3b165d0069436a6b6834819896da85b63b196a3f' not found. - res.status !== 404 && - // ie. cannot change codex model providers mid-session - modelInfo.stickyProvider !== "strict" && - modelInfo.fallbackProvider && - providerInfo.id !== modelInfo.fallbackProvider - ) { - return retriableRequest({ - excludeProviders: [...retry.excludeProviders, providerInfo.id], - retryCount: retry.retryCount + 1, - }) - } + // Try another provider => stop retrying if using fallback provider + if ( + res.status !== 200 && + // ie. openai 404 error: Item with id 'msg_0ead8b004a3b165d0069436a6b6834819896da85b63b196a3f' not found. + res.status !== 404 && + // ie. cannot change codex model providers mid-session + modelInfo.stickyProvider !== "strict" && + modelInfo.fallbackProvider && + providerInfo.id !== modelInfo.fallbackProvider + ) { + return retriableRequest({ + excludeProviders: [...retry.excludeProviders, providerInfo.id], + retryCount: retry.retryCount + 1, + }) } return { providerInfo, reqBody, res, startTimestamp } @@ -183,18 +186,25 @@ export async function handler( // Handle non-streaming response if (!isStream) { - const responseConverter = createResponseConverter(providerInfo.format, opts.format) const json = await res.json() - const body = JSON.stringify(responseConverter(json)) + const usageInfo = providerInfo.normalizeUsage(json.usage) + const costInfo = calculateCost(modelInfo, usageInfo) + await trialLimiter?.track(usageInfo) + await rateLimiter?.track() + await trackUsage(billingSource, authInfo, modelInfo, providerInfo, usageInfo, costInfo) + await reload(billingSource, authInfo, costInfo) + + const responseConverter = createResponseConverter(providerInfo.format, opts.format) + const body = JSON.stringify( + responseConverter({ + ...json, + cost: calculateOccuredCost(billingSource, costInfo), + }), + ) logger.metric({ response_length: body.length }) logger.debug("RESPONSE: " + body) dataDumper?.provideResponse(body) dataDumper?.flush() - const tokensInfo = providerInfo.normalizeUsage(json.usage) - await trialLimiter?.track(tokensInfo) - await rateLimiter?.track() - const costInfo = await trackUsage(authInfo, modelInfo, providerInfo, billingSource, tokensInfo) - await reload(authInfo, costInfo) return new Response(body, { status: resStatus, statusText: res.statusText, @@ -226,12 +236,16 @@ export async function handler( dataDumper?.flush() await rateLimiter?.track() const usage = usageParser.retrieve() + let cost = "0" if (usage) { - const tokensInfo = providerInfo.normalizeUsage(usage) - await trialLimiter?.track(tokensInfo) - const costInfo = await trackUsage(authInfo, modelInfo, providerInfo, billingSource, tokensInfo) - await reload(authInfo, costInfo) + const usageInfo = providerInfo.normalizeUsage(usage) + const costInfo = calculateCost(modelInfo, usageInfo) + await trialLimiter?.track(usageInfo) + await trackUsage(billingSource, authInfo, modelInfo, providerInfo, usageInfo, costInfo) + await reload(billingSource, authInfo, costInfo) + cost = calculateOccuredCost(billingSource, costInfo) } + c.enqueue(encoder.encode(usageParser.buidlCostChunk(cost))) c.close() return } @@ -283,7 +297,6 @@ export async function handler( return pump() }, }) - return new Response(stream, { status: resStatus, statusText: res.statusText, @@ -498,9 +511,9 @@ export async function handler( } } - function validateBilling(authInfo: AuthInfo, modelInfo: ModelInfo) { + function validateBilling(authInfo: AuthInfo, modelInfo: ModelInfo): BillingSource { if (!authInfo) return "anonymous" - if (authInfo.provider?.credentials) return "free" + if (authInfo.provider?.credentials) return "byok" if (authInfo.isFree) return "free" if (modelInfo.allowAnonymous) return "free" @@ -613,13 +626,7 @@ export async function handler( return res } - async function trackUsage( - authInfo: AuthInfo, - modelInfo: ModelInfo, - providerInfo: ProviderInfo, - billingSource: ReturnType, - usageInfo: UsageInfo, - ) { + function calculateCost(modelInfo: ModelInfo, usageInfo: UsageInfo) { const { inputTokens, outputTokens, reasoningTokens, cacheReadTokens, cacheWrite5mTokens, cacheWrite1hTokens } = usageInfo @@ -657,6 +664,33 @@ export async function handler( (cacheReadCost ?? 0) + (cacheWrite5mCost ?? 0) + (cacheWrite1hCost ?? 0) + return { + totalCostInCent, + inputCost, + outputCost, + reasoningCost, + cacheReadCost, + cacheWrite5mCost, + cacheWrite1hCost, + } + } + + function calculateOccuredCost(billingSource: BillingSource, costInfo: CostInfo) { + return billingSource === "balance" ? (costInfo.totalCostInCent / 100).toFixed(8) : "0" + } + + async function trackUsage( + billingSource: BillingSource, + authInfo: AuthInfo, + modelInfo: ModelInfo, + providerInfo: ProviderInfo, + usageInfo: UsageInfo, + costInfo: CostInfo, + ) { + const { inputTokens, outputTokens, reasoningTokens, cacheReadTokens, cacheWrite5mTokens, cacheWrite1hTokens } = + usageInfo + const { totalCostInCent, inputCost, outputCost, reasoningCost, cacheReadCost, cacheWrite5mCost, cacheWrite1hCost } = + costInfo logger.metric({ "tokens.input": inputTokens, @@ -677,7 +711,7 @@ export async function handler( if (billingSource === "anonymous") return authInfo = authInfo! - const cost = authInfo.provider?.credentials ? 0 : centsToMicroCents(totalCostInCent) + const cost = centsToMicroCents(totalCostInCent) await Database.use((db) => Promise.all([ db.insert(UsageTable).values({ @@ -772,16 +806,12 @@ export async function handler( return { costInMicroCents: cost } } - async function reload(authInfo: AuthInfo, costInfo: Awaited>) { - if (!authInfo) return - if (authInfo.isFree) return - if (authInfo.provider?.credentials) return - if (authInfo.subscription) return - - if (!costInfo) return + async function reload(billingSource: BillingSource, authInfo: AuthInfo, costInfo: CostInfo) { + if (billingSource !== "balance") return + authInfo = authInfo! const reloadTrigger = centsToMicroCents((authInfo.billing.reloadTrigger ?? Billing.RELOAD_TRIGGER) * 100) - if (authInfo.billing.balance - costInfo.costInMicroCents >= reloadTrigger) return + if (authInfo.billing.balance - costInfo.totalCostInCent >= reloadTrigger) return if (authInfo.billing.timeReloadLockedTill && authInfo.billing.timeReloadLockedTill > new Date()) return const lock = await Database.use((tx) => diff --git a/packages/console/app/src/routes/zen/util/provider/anthropic.ts b/packages/console/app/src/routes/zen/util/provider/anthropic.ts index 7081e980d..e2803459e 100644 --- a/packages/console/app/src/routes/zen/util/provider/anthropic.ts +++ b/packages/console/app/src/routes/zen/util/provider/anthropic.ts @@ -167,6 +167,7 @@ export const anthropicHelper: ProviderHelper = ({ reqModel, providerModel }) => } }, retrieve: () => usage, + buidlCostChunk: (cost: string) => `event: ping\ndata: ${JSON.stringify({ type: "ping", cost })}\n\n`, } }, normalizeUsage: (usage: Usage) => ({ diff --git a/packages/console/app/src/routes/zen/util/provider/google.ts b/packages/console/app/src/routes/zen/util/provider/google.ts index f6f7d6e19..ecf3b2d4d 100644 --- a/packages/console/app/src/routes/zen/util/provider/google.ts +++ b/packages/console/app/src/routes/zen/util/provider/google.ts @@ -56,6 +56,7 @@ export const googleHelper: ProviderHelper = ({ providerModel }) => ({ usage = json.usageMetadata }, retrieve: () => usage, + buidlCostChunk: (cost: string) => `data: ${JSON.stringify({ type: "ping", cost })}\n\n`, } }, normalizeUsage: (usage: Usage) => { diff --git a/packages/console/app/src/routes/zen/util/provider/openai-compatible.ts b/packages/console/app/src/routes/zen/util/provider/openai-compatible.ts index ce97a34d9..046bf8f0c 100644 --- a/packages/console/app/src/routes/zen/util/provider/openai-compatible.ts +++ b/packages/console/app/src/routes/zen/util/provider/openai-compatible.ts @@ -54,6 +54,7 @@ export const oaCompatHelper: ProviderHelper = () => ({ usage = json.usage }, retrieve: () => usage, + buidlCostChunk: (cost: string) => `data: ${JSON.stringify({ choices: [], cost })}\n\n`, } }, normalizeUsage: (usage: Usage) => { diff --git a/packages/console/app/src/routes/zen/util/provider/openai.ts b/packages/console/app/src/routes/zen/util/provider/openai.ts index f4d7699e9..db2dfa521 100644 --- a/packages/console/app/src/routes/zen/util/provider/openai.ts +++ b/packages/console/app/src/routes/zen/util/provider/openai.ts @@ -43,6 +43,7 @@ export const openaiHelper: ProviderHelper = () => ({ usage = json.response.usage }, retrieve: () => usage, + buidlCostChunk: (cost: string) => `event: ping\ndata: ${JSON.stringify({ type: "ping", cost })}\n\n`, } }, normalizeUsage: (usage: Usage) => { diff --git a/packages/console/app/src/routes/zen/util/provider/provider.ts b/packages/console/app/src/routes/zen/util/provider/provider.ts index bbf54f4f9..5f8b631cf 100644 --- a/packages/console/app/src/routes/zen/util/provider/provider.ts +++ b/packages/console/app/src/routes/zen/util/provider/provider.ts @@ -43,6 +43,7 @@ export type ProviderHelper = (input: { reqModel: string; providerModel: string } createUsageParser: () => { parse: (chunk: string) => void retrieve: () => any + buidlCostChunk: (cost: string) => string } normalizeUsage: (usage: any) => UsageInfo }