fix: ensure variants also work for completely custom models (#6481)
Co-authored-by: Daniel Smolsky <dannysmo@gmail.com>
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import z from "zod"
|
||||
import fuzzysort from "fuzzysort"
|
||||
import { Config } from "../config/config"
|
||||
import { mapValues, mergeDeep, sortBy } from "remeda"
|
||||
import { mapValues, mergeDeep, omit, pickBy, sortBy } from "remeda"
|
||||
import { NoSuchModelError, type Provider as SDK } from "ai"
|
||||
import { Log } from "../util/log"
|
||||
import { BunProc } from "../bun"
|
||||
@@ -405,16 +405,6 @@ export namespace Provider {
|
||||
},
|
||||
}
|
||||
|
||||
export const Variant = z
|
||||
.object({
|
||||
disabled: z.boolean(),
|
||||
})
|
||||
.catchall(z.any())
|
||||
.meta({
|
||||
ref: "Variant",
|
||||
})
|
||||
export type Variant = z.infer<typeof Variant>
|
||||
|
||||
export const Model = z
|
||||
.object({
|
||||
id: z.string(),
|
||||
@@ -478,7 +468,7 @@ export namespace Provider {
|
||||
options: z.record(z.string(), z.any()),
|
||||
headers: z.record(z.string(), z.string()),
|
||||
release_date: z.string(),
|
||||
variants: z.record(z.string(), Variant).optional(),
|
||||
variants: z.record(z.string(), z.record(z.string(), z.any())).optional(),
|
||||
})
|
||||
.meta({
|
||||
ref: "Model",
|
||||
@@ -561,7 +551,7 @@ export namespace Provider {
|
||||
variants: {},
|
||||
}
|
||||
|
||||
m.variants = mapValues(ProviderTransform.variants(m), (v) => ({ disabled: false, ...v }))
|
||||
m.variants = mapValues(ProviderTransform.variants(m), (v) => v)
|
||||
|
||||
return m
|
||||
}
|
||||
@@ -697,7 +687,13 @@ export namespace Provider {
|
||||
headers: mergeDeep(existingModel?.headers ?? {}, model.headers ?? {}),
|
||||
family: model.family ?? existingModel?.family ?? "",
|
||||
release_date: model.release_date ?? existingModel?.release_date ?? "",
|
||||
variants: {},
|
||||
}
|
||||
const merged = mergeDeep(ProviderTransform.variants(parsedModel), model.variants ?? {})
|
||||
parsedModel.variants = mapValues(
|
||||
pickBy(merged, (v) => !v.disabled),
|
||||
(v) => omit(v, ["disabled"]),
|
||||
)
|
||||
parsed.models[modelID] = parsedModel
|
||||
}
|
||||
database[providerID] = parsed
|
||||
@@ -822,6 +818,16 @@ export namespace Provider {
|
||||
(configProvider?.whitelist && !configProvider.whitelist.includes(modelID))
|
||||
)
|
||||
delete provider.models[modelID]
|
||||
|
||||
// Filter out disabled variants from config
|
||||
const configVariants = configProvider?.models?.[modelID]?.variants
|
||||
if (configVariants && model.variants) {
|
||||
const merged = mergeDeep(model.variants, configVariants)
|
||||
model.variants = mapValues(
|
||||
pickBy(merged, (v) => !v.disabled),
|
||||
(v) => omit(v, ["disabled"]),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
if (Object.keys(provider.models).length === 0) {
|
||||
|
||||
Reference in New Issue
Block a user