zen: return cost

This commit is contained in:
Frank
2026-02-12 00:57:08 -05:00
parent 03de51bd3c
commit d86f24b6b3
6 changed files with 79 additions and 44 deletions

View File

@@ -38,6 +38,7 @@ type RetryOptions = {
excludeProviders: string[]
retryCount: number
}
type BillingSource = "anonymous" | "free" | "byok" | "subscription" | "balance"
export async function handler(
input: APIEvent,
@@ -51,6 +52,7 @@ export async function handler(
type AuthInfo = Awaited<ReturnType<typeof authenticate>>
type ModelInfo = Awaited<ReturnType<typeof validateModel>>
type ProviderInfo = Awaited<ReturnType<typeof selectProvider>>
type CostInfo = ReturnType<typeof calculateCost>
const MAX_FAILOVER_RETRIES = 3
const MAX_429_RETRIES = 3
@@ -139,21 +141,22 @@ export async function handler(
"llm.error.code": res.status,
"llm.error.message": res.statusText,
})
}
// Try another provider => stop retrying if using fallback provider
if (
// ie. openai 404 error: Item with id 'msg_0ead8b004a3b165d0069436a6b6834819896da85b63b196a3f' not found.
res.status !== 404 &&
// ie. cannot change codex model providers mid-session
modelInfo.stickyProvider !== "strict" &&
modelInfo.fallbackProvider &&
providerInfo.id !== modelInfo.fallbackProvider
) {
return retriableRequest({
excludeProviders: [...retry.excludeProviders, providerInfo.id],
retryCount: retry.retryCount + 1,
})
}
// Try another provider => stop retrying if using fallback provider
if (
res.status !== 200 &&
// ie. openai 404 error: Item with id 'msg_0ead8b004a3b165d0069436a6b6834819896da85b63b196a3f' not found.
res.status !== 404 &&
// ie. cannot change codex model providers mid-session
modelInfo.stickyProvider !== "strict" &&
modelInfo.fallbackProvider &&
providerInfo.id !== modelInfo.fallbackProvider
) {
return retriableRequest({
excludeProviders: [...retry.excludeProviders, providerInfo.id],
retryCount: retry.retryCount + 1,
})
}
return { providerInfo, reqBody, res, startTimestamp }
@@ -183,18 +186,25 @@ export async function handler(
// Handle non-streaming response
if (!isStream) {
const responseConverter = createResponseConverter(providerInfo.format, opts.format)
const json = await res.json()
const body = JSON.stringify(responseConverter(json))
const usageInfo = providerInfo.normalizeUsage(json.usage)
const costInfo = calculateCost(modelInfo, usageInfo)
await trialLimiter?.track(usageInfo)
await rateLimiter?.track()
await trackUsage(billingSource, authInfo, modelInfo, providerInfo, usageInfo, costInfo)
await reload(billingSource, authInfo, costInfo)
const responseConverter = createResponseConverter(providerInfo.format, opts.format)
const body = JSON.stringify(
responseConverter({
...json,
cost: calculateOccuredCost(billingSource, costInfo),
}),
)
logger.metric({ response_length: body.length })
logger.debug("RESPONSE: " + body)
dataDumper?.provideResponse(body)
dataDumper?.flush()
const tokensInfo = providerInfo.normalizeUsage(json.usage)
await trialLimiter?.track(tokensInfo)
await rateLimiter?.track()
const costInfo = await trackUsage(authInfo, modelInfo, providerInfo, billingSource, tokensInfo)
await reload(authInfo, costInfo)
return new Response(body, {
status: resStatus,
statusText: res.statusText,
@@ -226,12 +236,16 @@ export async function handler(
dataDumper?.flush()
await rateLimiter?.track()
const usage = usageParser.retrieve()
let cost = "0"
if (usage) {
const tokensInfo = providerInfo.normalizeUsage(usage)
await trialLimiter?.track(tokensInfo)
const costInfo = await trackUsage(authInfo, modelInfo, providerInfo, billingSource, tokensInfo)
await reload(authInfo, costInfo)
const usageInfo = providerInfo.normalizeUsage(usage)
const costInfo = calculateCost(modelInfo, usageInfo)
await trialLimiter?.track(usageInfo)
await trackUsage(billingSource, authInfo, modelInfo, providerInfo, usageInfo, costInfo)
await reload(billingSource, authInfo, costInfo)
cost = calculateOccuredCost(billingSource, costInfo)
}
c.enqueue(encoder.encode(usageParser.buidlCostChunk(cost)))
c.close()
return
}
@@ -283,7 +297,6 @@ export async function handler(
return pump()
},
})
return new Response(stream, {
status: resStatus,
statusText: res.statusText,
@@ -498,9 +511,9 @@ export async function handler(
}
}
function validateBilling(authInfo: AuthInfo, modelInfo: ModelInfo) {
function validateBilling(authInfo: AuthInfo, modelInfo: ModelInfo): BillingSource {
if (!authInfo) return "anonymous"
if (authInfo.provider?.credentials) return "free"
if (authInfo.provider?.credentials) return "byok"
if (authInfo.isFree) return "free"
if (modelInfo.allowAnonymous) return "free"
@@ -613,13 +626,7 @@ export async function handler(
return res
}
async function trackUsage(
authInfo: AuthInfo,
modelInfo: ModelInfo,
providerInfo: ProviderInfo,
billingSource: ReturnType<typeof validateBilling>,
usageInfo: UsageInfo,
) {
function calculateCost(modelInfo: ModelInfo, usageInfo: UsageInfo) {
const { inputTokens, outputTokens, reasoningTokens, cacheReadTokens, cacheWrite5mTokens, cacheWrite1hTokens } =
usageInfo
@@ -657,6 +664,33 @@ export async function handler(
(cacheReadCost ?? 0) +
(cacheWrite5mCost ?? 0) +
(cacheWrite1hCost ?? 0)
return {
totalCostInCent,
inputCost,
outputCost,
reasoningCost,
cacheReadCost,
cacheWrite5mCost,
cacheWrite1hCost,
}
}
function calculateOccuredCost(billingSource: BillingSource, costInfo: CostInfo) {
return billingSource === "balance" ? (costInfo.totalCostInCent / 100).toFixed(8) : "0"
}
async function trackUsage(
billingSource: BillingSource,
authInfo: AuthInfo,
modelInfo: ModelInfo,
providerInfo: ProviderInfo,
usageInfo: UsageInfo,
costInfo: CostInfo,
) {
const { inputTokens, outputTokens, reasoningTokens, cacheReadTokens, cacheWrite5mTokens, cacheWrite1hTokens } =
usageInfo
const { totalCostInCent, inputCost, outputCost, reasoningCost, cacheReadCost, cacheWrite5mCost, cacheWrite1hCost } =
costInfo
logger.metric({
"tokens.input": inputTokens,
@@ -677,7 +711,7 @@ export async function handler(
if (billingSource === "anonymous") return
authInfo = authInfo!
const cost = authInfo.provider?.credentials ? 0 : centsToMicroCents(totalCostInCent)
const cost = centsToMicroCents(totalCostInCent)
await Database.use((db) =>
Promise.all([
db.insert(UsageTable).values({
@@ -772,16 +806,12 @@ export async function handler(
return { costInMicroCents: cost }
}
async function reload(authInfo: AuthInfo, costInfo: Awaited<ReturnType<typeof trackUsage>>) {
if (!authInfo) return
if (authInfo.isFree) return
if (authInfo.provider?.credentials) return
if (authInfo.subscription) return
if (!costInfo) return
async function reload(billingSource: BillingSource, authInfo: AuthInfo, costInfo: CostInfo) {
if (billingSource !== "balance") return
authInfo = authInfo!
const reloadTrigger = centsToMicroCents((authInfo.billing.reloadTrigger ?? Billing.RELOAD_TRIGGER) * 100)
if (authInfo.billing.balance - costInfo.costInMicroCents >= reloadTrigger) return
if (authInfo.billing.balance - costInfo.totalCostInCent >= reloadTrigger) return
if (authInfo.billing.timeReloadLockedTill && authInfo.billing.timeReloadLockedTill > new Date()) return
const lock = await Database.use((tx) =>

View File

@@ -167,6 +167,7 @@ export const anthropicHelper: ProviderHelper = ({ reqModel, providerModel }) =>
}
},
retrieve: () => usage,
buidlCostChunk: (cost: string) => `event: ping\ndata: ${JSON.stringify({ type: "ping", cost })}\n\n`,
}
},
normalizeUsage: (usage: Usage) => ({

View File

@@ -56,6 +56,7 @@ export const googleHelper: ProviderHelper = ({ providerModel }) => ({
usage = json.usageMetadata
},
retrieve: () => usage,
buidlCostChunk: (cost: string) => `data: ${JSON.stringify({ type: "ping", cost })}\n\n`,
}
},
normalizeUsage: (usage: Usage) => {

View File

@@ -54,6 +54,7 @@ export const oaCompatHelper: ProviderHelper = () => ({
usage = json.usage
},
retrieve: () => usage,
buidlCostChunk: (cost: string) => `data: ${JSON.stringify({ choices: [], cost })}\n\n`,
}
},
normalizeUsage: (usage: Usage) => {

View File

@@ -43,6 +43,7 @@ export const openaiHelper: ProviderHelper = () => ({
usage = json.response.usage
},
retrieve: () => usage,
buidlCostChunk: (cost: string) => `event: ping\ndata: ${JSON.stringify({ type: "ping", cost })}\n\n`,
}
},
normalizeUsage: (usage: Usage) => {

View File

@@ -43,6 +43,7 @@ export type ProviderHelper = (input: { reqModel: string; providerModel: string }
createUsageParser: () => {
parse: (chunk: string) => void
retrieve: () => any
buidlCostChunk: (cost: string) => string
}
normalizeUsage: (usage: any) => UsageInfo
}