feat: retry parts (#3369)
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
import { streamText, type ModelMessage, LoadAPIKeyError } from "ai"
|
||||
import { streamText, type ModelMessage, LoadAPIKeyError, type StreamTextResult, type Tool as AITool } from "ai"
|
||||
import { Session } from "."
|
||||
import { Identifier } from "../id/id"
|
||||
import { Instance } from "../project/instance"
|
||||
@@ -14,8 +14,8 @@ import { Flag } from "../flag/flag"
|
||||
import { Token } from "../util/token"
|
||||
import { Log } from "../util/log"
|
||||
import { SessionLock } from "./lock"
|
||||
import { NamedError } from "../util/error"
|
||||
import { ProviderTransform } from "@/provider/transform"
|
||||
import { SessionRetry } from "./retry"
|
||||
|
||||
export namespace SessionCompaction {
|
||||
const log = Log.create({ service: "session.compaction" })
|
||||
@@ -41,6 +41,7 @@ export namespace SessionCompaction {
|
||||
|
||||
export const PRUNE_MINIMUM = 20_000
|
||||
export const PRUNE_PROTECT = 40_000
|
||||
const MAX_RETRIES = 10
|
||||
|
||||
// goes backwards through parts until there are 40_000 tokens worth of tool
|
||||
// calls. then erases output of previous tool calls. idea is to throw away old
|
||||
@@ -142,8 +143,10 @@ export namespace SessionCompaction {
|
||||
},
|
||||
})) as MessageV2.TextPart
|
||||
|
||||
const stream = streamText({
|
||||
maxRetries: 10,
|
||||
const doStream = () =>
|
||||
streamText({
|
||||
// set to 0, we handle loop
|
||||
maxRetries: 0,
|
||||
model: model.language,
|
||||
providerOptions: ProviderTransform.providerOptions(model.npm, model.providerID, model.info.options),
|
||||
abortSignal: signal,
|
||||
@@ -172,6 +175,12 @@ export namespace SessionCompaction {
|
||||
],
|
||||
})
|
||||
|
||||
// TODO: reduce duplication between compaction.ts & prompt.ts
|
||||
const process = async (
|
||||
stream: StreamTextResult<Record<string, AITool>, never>,
|
||||
retries: { count: number; max: number },
|
||||
) => {
|
||||
let shouldRetry = false
|
||||
try {
|
||||
for await (const value of stream.fullStream) {
|
||||
signal.throwIfAborted()
|
||||
@@ -212,42 +221,95 @@ export namespace SessionCompaction {
|
||||
log.error("compaction error", {
|
||||
error: e,
|
||||
})
|
||||
switch (true) {
|
||||
case e instanceof DOMException && e.name === "AbortError":
|
||||
msg.error = new MessageV2.AbortedError(
|
||||
{ message: e.message },
|
||||
{
|
||||
cause: e,
|
||||
const error = MessageV2.fromError(e, { providerID: input.providerID })
|
||||
if (retries.count < retries.max && MessageV2.APIError.isInstance(error) && error.data.isRetryable) {
|
||||
shouldRetry = true
|
||||
await Session.updatePart({
|
||||
id: Identifier.ascending("part"),
|
||||
messageID: msg.id,
|
||||
sessionID: msg.sessionID,
|
||||
type: "retry",
|
||||
attempt: retries.count + 1,
|
||||
time: {
|
||||
created: Date.now(),
|
||||
},
|
||||
).toObject()
|
||||
break
|
||||
case MessageV2.OutputLengthError.isInstance(e):
|
||||
msg.error = e
|
||||
break
|
||||
case LoadAPIKeyError.isInstance(e):
|
||||
msg.error = new MessageV2.AuthError(
|
||||
{
|
||||
providerID: model.providerID,
|
||||
message: e.message,
|
||||
},
|
||||
{ cause: e },
|
||||
).toObject()
|
||||
break
|
||||
case e instanceof Error:
|
||||
msg.error = new NamedError.Unknown({ message: e.toString() }, { cause: e }).toObject()
|
||||
break
|
||||
default:
|
||||
msg.error = new NamedError.Unknown({ message: JSON.stringify(e) }, { cause: e })
|
||||
}
|
||||
error,
|
||||
})
|
||||
} else {
|
||||
msg.error = error
|
||||
Bus.publish(Session.Event.Error, {
|
||||
sessionID: input.sessionID,
|
||||
sessionID: msg.sessionID,
|
||||
error: msg.error,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
const parts = await Session.getParts(msg.id)
|
||||
return {
|
||||
info: msg,
|
||||
parts,
|
||||
shouldRetry,
|
||||
}
|
||||
}
|
||||
|
||||
let stream = doStream()
|
||||
let result = await process(stream, {
|
||||
count: 0,
|
||||
max: MAX_RETRIES,
|
||||
})
|
||||
if (result.shouldRetry) {
|
||||
for (let retry = 1; retry < MAX_RETRIES; retry++) {
|
||||
const lastRetryPart = result.parts.findLast((p) => p.type === "retry")
|
||||
|
||||
if (lastRetryPart) {
|
||||
const delayMs = SessionRetry.getRetryDelayInMs(lastRetryPart.error, retry)
|
||||
|
||||
log.info("retrying with backoff", {
|
||||
attempt: retry,
|
||||
delayMs,
|
||||
})
|
||||
|
||||
const stop = await SessionRetry.sleep(delayMs, signal)
|
||||
.then(() => false)
|
||||
.catch((error) => {
|
||||
if (error instanceof DOMException && error.name === "AbortError") {
|
||||
const err = new MessageV2.AbortedError(
|
||||
{ message: error.message },
|
||||
{
|
||||
cause: error,
|
||||
},
|
||||
).toObject()
|
||||
result.info.error = err
|
||||
Bus.publish(Session.Event.Error, {
|
||||
sessionID: result.info.sessionID,
|
||||
error: result.info.error,
|
||||
})
|
||||
return true
|
||||
}
|
||||
throw error
|
||||
})
|
||||
|
||||
if (stop) break
|
||||
}
|
||||
|
||||
stream = doStream()
|
||||
result = await process(stream, {
|
||||
count: retry,
|
||||
max: MAX_RETRIES,
|
||||
})
|
||||
if (!result.shouldRetry) {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
msg.time.completed = Date.now()
|
||||
|
||||
if (!msg.error || MessageV2.AbortedError.isInstance(msg.error)) {
|
||||
if (
|
||||
!msg.error ||
|
||||
(MessageV2.AbortedError.isInstance(msg.error) &&
|
||||
result.parts.some((part) => part.type === "text" && part.text.length > 0))
|
||||
) {
|
||||
msg.summary = true
|
||||
Bus.publish(Event.Compacted, {
|
||||
sessionID: input.sessionID,
|
||||
@@ -257,7 +319,7 @@ export namespace SessionCompaction {
|
||||
|
||||
return {
|
||||
info: msg,
|
||||
parts: [part],
|
||||
parts: result.parts,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@ import z from "zod/v4"
|
||||
import { Bus } from "../bus"
|
||||
import { NamedError } from "../util/error"
|
||||
import { Message } from "./message"
|
||||
import { convertToModelMessages, type ModelMessage, type UIMessage } from "ai"
|
||||
import { APICallError, convertToModelMessages, LoadAPIKeyError, type ModelMessage, type UIMessage } from "ai"
|
||||
import { Identifier } from "../id/id"
|
||||
import { LSP } from "../lsp"
|
||||
import { Snapshot } from "@/snapshot"
|
||||
@@ -18,6 +18,17 @@ export namespace MessageV2 {
|
||||
message: z.string(),
|
||||
}),
|
||||
)
|
||||
export const APIError = NamedError.create(
|
||||
"APIError",
|
||||
z.object({
|
||||
message: z.string(),
|
||||
statusCode: z.number().optional(),
|
||||
isRetryable: z.boolean(),
|
||||
responseHeaders: z.record(z.string(), z.string()).optional(),
|
||||
responseBody: z.string().optional(),
|
||||
}),
|
||||
)
|
||||
export type APIError = z.infer<typeof APIError.Schema>
|
||||
|
||||
const PartBase = z.object({
|
||||
id: z.string(),
|
||||
@@ -130,6 +141,18 @@ export namespace MessageV2 {
|
||||
})
|
||||
export type AgentPart = z.infer<typeof AgentPart>
|
||||
|
||||
export const RetryPart = PartBase.extend({
|
||||
type: z.literal("retry"),
|
||||
attempt: z.number(),
|
||||
error: APIError.Schema,
|
||||
time: z.object({
|
||||
created: z.number(),
|
||||
}),
|
||||
}).meta({
|
||||
ref: "RetryPart",
|
||||
})
|
||||
export type RetryPart = z.infer<typeof RetryPart>
|
||||
|
||||
export const StepStartPart = PartBase.extend({
|
||||
type: z.literal("step-start"),
|
||||
snapshot: z.string().optional(),
|
||||
@@ -265,6 +288,7 @@ export namespace MessageV2 {
|
||||
SnapshotPart,
|
||||
PatchPart,
|
||||
AgentPart,
|
||||
RetryPart,
|
||||
])
|
||||
.meta({
|
||||
ref: "Part",
|
||||
@@ -283,6 +307,7 @@ export namespace MessageV2 {
|
||||
NamedError.Unknown.Schema,
|
||||
OutputLengthError.Schema,
|
||||
AbortedError.Schema,
|
||||
APIError.Schema,
|
||||
])
|
||||
.optional(),
|
||||
system: z.string().array(),
|
||||
@@ -610,4 +635,41 @@ export namespace MessageV2 {
|
||||
if (i === -1) return msgs.slice()
|
||||
return msgs.slice(i)
|
||||
}
|
||||
|
||||
export function fromError(e: unknown, ctx: { providerID: string }) {
|
||||
switch (true) {
|
||||
case e instanceof DOMException && e.name === "AbortError":
|
||||
return new MessageV2.AbortedError(
|
||||
{ message: e.message },
|
||||
{
|
||||
cause: e,
|
||||
},
|
||||
).toObject()
|
||||
case MessageV2.OutputLengthError.isInstance(e):
|
||||
return e
|
||||
case LoadAPIKeyError.isInstance(e):
|
||||
return new MessageV2.AuthError(
|
||||
{
|
||||
providerID: ctx.providerID,
|
||||
message: e.message,
|
||||
},
|
||||
{ cause: e },
|
||||
).toObject()
|
||||
case APICallError.isInstance(e):
|
||||
return new MessageV2.APIError(
|
||||
{
|
||||
message: e.message,
|
||||
statusCode: e.statusCode,
|
||||
isRetryable: e.isRetryable,
|
||||
responseHeaders: e.responseHeaders,
|
||||
responseBody: e.responseBody,
|
||||
},
|
||||
{ cause: e },
|
||||
).toObject()
|
||||
case e instanceof Error:
|
||||
return new NamedError.Unknown({ message: e.toString() }, { cause: e }).toObject()
|
||||
default:
|
||||
return new NamedError.Unknown({ message: JSON.stringify(e) }, { cause: e })
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,7 +17,6 @@ import {
|
||||
tool,
|
||||
wrapLanguageModel,
|
||||
type StreamTextResult,
|
||||
LoadAPIKeyError,
|
||||
stepCountIs,
|
||||
jsonSchema,
|
||||
} from "ai"
|
||||
@@ -28,6 +27,7 @@ import { Bus } from "../bus"
|
||||
import { ProviderTransform } from "../provider/transform"
|
||||
import { SystemPrompt } from "./system"
|
||||
import { Plugin } from "../plugin"
|
||||
import { SessionRetry } from "./retry"
|
||||
|
||||
import PROMPT_PLAN from "../session/prompt/plan.txt"
|
||||
import BUILD_SWITCH from "../session/prompt/build-switch.txt"
|
||||
@@ -44,7 +44,6 @@ import { TaskTool } from "../tool/task"
|
||||
import { FileTime } from "../file/time"
|
||||
import { Permission } from "../permission"
|
||||
import { Snapshot } from "../snapshot"
|
||||
import { NamedError } from "../util/error"
|
||||
import { ulid } from "ulid"
|
||||
import { spawn } from "child_process"
|
||||
import { Command } from "../command"
|
||||
@@ -55,6 +54,7 @@ import { MessageSummary } from "./summary"
|
||||
export namespace SessionPrompt {
|
||||
const log = Log.create({ service: "session.prompt" })
|
||||
export const OUTPUT_TOKEN_MAX = 32_000
|
||||
const MAX_RETRIES = 10
|
||||
|
||||
export const Event = {
|
||||
Idle: Bus.event(
|
||||
@@ -240,7 +240,8 @@ export namespace SessionPrompt {
|
||||
await using _ = defer(async () => {
|
||||
await processor.end()
|
||||
})
|
||||
const stream = streamText({
|
||||
const doStream = () =>
|
||||
streamText({
|
||||
onError(error) {
|
||||
log.error("stream error", {
|
||||
error,
|
||||
@@ -274,7 +275,8 @@ export namespace SessionPrompt {
|
||||
"x-opencode-request": userMsg.info.id,
|
||||
}
|
||||
: undefined,
|
||||
maxRetries: 10,
|
||||
// set to 0, we handle loop
|
||||
maxRetries: 0,
|
||||
activeTools: Object.keys(tools).filter((x) => x !== "invalid"),
|
||||
maxOutputTokens: ProviderTransform.maxOutputTokens(
|
||||
model.providerID,
|
||||
@@ -326,7 +328,57 @@ export namespace SessionPrompt {
|
||||
],
|
||||
}),
|
||||
})
|
||||
const result = await processor.process(stream)
|
||||
|
||||
let stream = doStream()
|
||||
let result = await processor.process(stream, {
|
||||
count: 0,
|
||||
max: MAX_RETRIES,
|
||||
})
|
||||
if (result.shouldRetry) {
|
||||
for (let retry = 1; retry < MAX_RETRIES; retry++) {
|
||||
const lastRetryPart = result.parts.findLast((p) => p.type === "retry")
|
||||
|
||||
if (lastRetryPart) {
|
||||
const delayMs = SessionRetry.getRetryDelayInMs(lastRetryPart.error, retry)
|
||||
|
||||
log.info("retrying with backoff", {
|
||||
attempt: retry,
|
||||
delayMs,
|
||||
})
|
||||
|
||||
const stop = await SessionRetry.sleep(delayMs, abort.signal)
|
||||
.then(() => false)
|
||||
.catch((error) => {
|
||||
if (error instanceof DOMException && error.name === "AbortError") {
|
||||
const err = new MessageV2.AbortedError(
|
||||
{ message: error.message },
|
||||
{
|
||||
cause: error,
|
||||
},
|
||||
).toObject()
|
||||
result.info.error = err
|
||||
Bus.publish(Session.Event.Error, {
|
||||
sessionID: result.info.sessionID,
|
||||
error: result.info.error,
|
||||
})
|
||||
return true
|
||||
}
|
||||
throw error
|
||||
})
|
||||
|
||||
if (stop) break
|
||||
}
|
||||
|
||||
stream = doStream()
|
||||
result = await processor.process(stream, {
|
||||
count: retry,
|
||||
max: MAX_RETRIES,
|
||||
})
|
||||
if (!result.shouldRetry) {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
await processor.end()
|
||||
|
||||
const queued = state().queued.get(input.sessionID) ?? []
|
||||
@@ -959,9 +1011,10 @@ export namespace SessionPrompt {
|
||||
partFromToolCall(toolCallID: string) {
|
||||
return toolcalls[toolCallID]
|
||||
},
|
||||
async process(stream: StreamTextResult<Record<string, AITool>, never>) {
|
||||
async process(stream: StreamTextResult<Record<string, AITool>, never>, retries: { count: number; max: number }) {
|
||||
log.info("process")
|
||||
if (!assistantMsg) throw new Error("call next() first before processing")
|
||||
let shouldRetry = false
|
||||
try {
|
||||
let currentText: MessageV2.TextPart | undefined
|
||||
let reasoningMap: Record<string, MessageV2.ReasoningPart> = {}
|
||||
@@ -1314,38 +1367,28 @@ export namespace SessionPrompt {
|
||||
log.error("process", {
|
||||
error: e,
|
||||
})
|
||||
switch (true) {
|
||||
case e instanceof DOMException && e.name === "AbortError":
|
||||
assistantMsg.error = new MessageV2.AbortedError(
|
||||
{ message: e.message },
|
||||
{
|
||||
cause: e,
|
||||
const error = MessageV2.fromError(e, { providerID: input.providerID })
|
||||
if (retries.count < retries.max && MessageV2.APIError.isInstance(error) && error.data.isRetryable) {
|
||||
shouldRetry = true
|
||||
await Session.updatePart({
|
||||
id: Identifier.ascending("part"),
|
||||
messageID: assistantMsg.id,
|
||||
sessionID: assistantMsg.sessionID,
|
||||
type: "retry",
|
||||
attempt: retries.count + 1,
|
||||
time: {
|
||||
created: Date.now(),
|
||||
},
|
||||
).toObject()
|
||||
break
|
||||
case MessageV2.OutputLengthError.isInstance(e):
|
||||
assistantMsg.error = e
|
||||
break
|
||||
case LoadAPIKeyError.isInstance(e):
|
||||
assistantMsg.error = new MessageV2.AuthError(
|
||||
{
|
||||
providerID: input.providerID,
|
||||
message: e.message,
|
||||
},
|
||||
{ cause: e },
|
||||
).toObject()
|
||||
break
|
||||
case e instanceof Error:
|
||||
assistantMsg.error = new NamedError.Unknown({ message: e.toString() }, { cause: e }).toObject()
|
||||
break
|
||||
default:
|
||||
assistantMsg.error = new NamedError.Unknown({ message: JSON.stringify(e) }, { cause: e })
|
||||
}
|
||||
error,
|
||||
})
|
||||
} else {
|
||||
assistantMsg.error = error
|
||||
Bus.publish(Session.Event.Error, {
|
||||
sessionID: assistantMsg.sessionID,
|
||||
error: assistantMsg.error,
|
||||
})
|
||||
}
|
||||
}
|
||||
const p = await Session.getParts(assistantMsg.id)
|
||||
for (const part of p) {
|
||||
if (part.type === "tool" && part.state.status !== "completed" && part.state.status !== "error") {
|
||||
@@ -1363,9 +1406,11 @@ export namespace SessionPrompt {
|
||||
})
|
||||
}
|
||||
}
|
||||
if (!shouldRetry) {
|
||||
assistantMsg.time.completed = Date.now()
|
||||
}
|
||||
await Session.updateMessage(assistantMsg)
|
||||
return { info: assistantMsg, parts: p, blocked }
|
||||
return { info: assistantMsg, parts: p, blocked, shouldRetry }
|
||||
},
|
||||
}
|
||||
return result
|
||||
|
||||
57
packages/opencode/src/session/retry.ts
Normal file
57
packages/opencode/src/session/retry.ts
Normal file
@@ -0,0 +1,57 @@
|
||||
import { MessageV2 } from "./message-v2"
|
||||
|
||||
export namespace SessionRetry {
|
||||
export const RETRY_INITIAL_DELAY = 2000
|
||||
export const RETRY_BACKOFF_FACTOR = 2
|
||||
|
||||
export async function sleep(ms: number, signal: AbortSignal): Promise<void> {
|
||||
return new Promise((resolve, reject) => {
|
||||
const timeout = setTimeout(resolve, ms)
|
||||
signal.addEventListener(
|
||||
"abort",
|
||||
() => {
|
||||
clearTimeout(timeout)
|
||||
reject(new DOMException("Aborted", "AbortError"))
|
||||
},
|
||||
{ once: true },
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
export function getRetryDelayInMs(error: MessageV2.APIError, attempt: number): number {
|
||||
const base = RETRY_INITIAL_DELAY * Math.pow(RETRY_BACKOFF_FACTOR, attempt - 1)
|
||||
const headers = error.data.responseHeaders
|
||||
if (!headers) return base
|
||||
|
||||
const retryAfterMs = headers["retry-after-ms"]
|
||||
if (retryAfterMs) {
|
||||
const parsed = Number.parseFloat(retryAfterMs)
|
||||
const normalized = normalizeDelay({ base, candidate: parsed })
|
||||
if (normalized != null) return normalized
|
||||
}
|
||||
|
||||
const retryAfter = headers["retry-after"]
|
||||
if (!retryAfter) return base
|
||||
|
||||
const seconds = Number.parseFloat(retryAfter)
|
||||
if (!Number.isNaN(seconds)) {
|
||||
const normalized = normalizeDelay({ base, candidate: seconds * 1000 })
|
||||
if (normalized != null) return normalized
|
||||
return base
|
||||
}
|
||||
|
||||
const dateMs = Date.parse(retryAfter) - Date.now()
|
||||
const normalized = normalizeDelay({ base, candidate: dateMs })
|
||||
if (normalized != null) return normalized
|
||||
|
||||
return base
|
||||
}
|
||||
|
||||
function normalizeDelay(input: { base: number; candidate: number }): number | undefined {
|
||||
if (Number.isNaN(input.candidate)) return undefined
|
||||
if (input.candidate < 0) return undefined
|
||||
if (input.candidate < 60_000) return input.candidate
|
||||
if (input.candidate < input.base) return input.candidate
|
||||
return undefined
|
||||
}
|
||||
}
|
||||
47
packages/opencode/test/session/retry.test.ts
Normal file
47
packages/opencode/test/session/retry.test.ts
Normal file
@@ -0,0 +1,47 @@
|
||||
import { describe, expect, test } from "bun:test"
|
||||
import { SessionRetry } from "../../src/session/retry"
|
||||
import { MessageV2 } from "../../src/session/message-v2"
|
||||
|
||||
function apiError(headers?: Record<string, string>): MessageV2.APIError {
|
||||
return new MessageV2.APIError({
|
||||
message: "boom",
|
||||
isRetryable: true,
|
||||
responseHeaders: headers,
|
||||
}).toObject() as MessageV2.APIError
|
||||
}
|
||||
|
||||
describe("session.retry.getRetryDelayInMs", () => {
|
||||
test("doubles delay on each attempt when headers missing", () => {
|
||||
const error = apiError()
|
||||
const delays = Array.from({ length: 7 }, (_, index) => SessionRetry.getRetryDelayInMs(error, index + 1))
|
||||
expect(delays).toStrictEqual([2000, 4000, 8000, 16000, 32000, 64000, 128000])
|
||||
})
|
||||
|
||||
test("prefers retry-after-ms when shorter than exponential", () => {
|
||||
const error = apiError({ "retry-after-ms": "1500" })
|
||||
expect(SessionRetry.getRetryDelayInMs(error, 4)).toBe(1500)
|
||||
})
|
||||
|
||||
test("uses retry-after seconds when reasonable", () => {
|
||||
const error = apiError({ "retry-after": "30" })
|
||||
expect(SessionRetry.getRetryDelayInMs(error, 3)).toBe(30000)
|
||||
})
|
||||
|
||||
test("falls back to exponential when server delay is long", () => {
|
||||
const error = apiError({ "retry-after": "120" })
|
||||
expect(SessionRetry.getRetryDelayInMs(error, 2)).toBe(4000)
|
||||
})
|
||||
|
||||
test("accepts http-date retry-after values", () => {
|
||||
const date = new Date(Date.now() + 20000).toUTCString()
|
||||
const error = apiError({ "retry-after": date })
|
||||
const delay = SessionRetry.getRetryDelayInMs(error, 1)
|
||||
expect(delay).toBeGreaterThanOrEqual(19000)
|
||||
expect(delay).toBeLessThanOrEqual(20000)
|
||||
})
|
||||
|
||||
test("ignores invalid retry hints", () => {
|
||||
const error = apiError({ "retry-after": "not-a-number" })
|
||||
expect(SessionRetry.getRetryDelayInMs(error, 1)).toBe(2000)
|
||||
})
|
||||
})
|
||||
Reference in New Issue
Block a user