diff --git a/packages/__tests__/cost/modelCostFromRegistry.test.ts b/packages/__tests__/cost/modelCostFromRegistry.test.ts index 73aa4f6602..4f2ab987ce 100644 --- a/packages/__tests__/cost/modelCostFromRegistry.test.ts +++ b/packages/__tests__/cost/modelCostFromRegistry.test.ts @@ -436,4 +436,81 @@ describe("modelCostBreakdownFromRegistry", () => { } }); }); + + describe("cost override", () => { + const usage: ModelUsage = { input: 1000, output: 500 }; + + it("overrides both input and output per-token cost", () => { + const breakdown = modelCostBreakdownFromRegistry({ + modelUsage: usage, + providerModelId: "gpt-4o", + provider: "openai" as ModelProviderName, + costOverride: { inputCostPerToken: 0.00001, outputCostPerToken: 0.00002 }, + }); + + expect(breakdown).not.toBeNull(); + if (breakdown) { + expect(breakdown.inputCost).toBe(1000 * 0.00001); + expect(breakdown.outputCost).toBe(500 * 0.00002); + expect(breakdown.totalCost).toBe(1000 * 0.00001 + 500 * 0.00002); + } + }); + + it("applies a partial override per field, falling back to the registry rate", () => { + const inputOnly = modelCostBreakdownFromRegistry({ + modelUsage: usage, + providerModelId: "gpt-4o", + provider: "openai" as ModelProviderName, + costOverride: { inputCostPerToken: 0.00001 }, + }); + expect(inputOnly).not.toBeNull(); + if (inputOnly) { + // input overridden; output keeps the gpt-4o registry rate (0.01 / 1K) + expect(inputOnly.inputCost).toBe(1000 * 0.00001); + expect(inputOnly.outputCost).toBe(500 * 0.00001); + } + + const outputOnly = modelCostBreakdownFromRegistry({ + modelUsage: usage, + providerModelId: "gpt-4o", + provider: "openai" as ModelProviderName, + costOverride: { outputCostPerToken: 0.00002 }, + }); + expect(outputOnly).not.toBeNull(); + if (outputOnly) { + // output overridden; input keeps the gpt-4o registry rate (0.0025 / 1K) + expect(outputOnly.inputCost).toBe(1000 * 0.0000025); + expect(outputOnly.outputCost).toBe(500 * 0.00002); + } + }); + + it("treats a zero override as free, not as unset", () => { + const breakdown = modelCostBreakdownFromRegistry({ + modelUsage: usage, + providerModelId: "gpt-4o", + provider: "openai" as ModelProviderName, + costOverride: { inputCostPerToken: 0, outputCostPerToken: 0 }, + }); + + expect(breakdown).not.toBeNull(); + if (breakdown) { + expect(breakdown.inputCost).toBe(0); + expect(breakdown.outputCost).toBe(0); + expect(breakdown.totalCost).toBe(0); + } + }); + + it("uses registry pricing when no override is provided", () => { + const breakdown = modelCostBreakdownFromRegistry({ + modelUsage: usage, + providerModelId: "gpt-4o", + provider: "openai" as ModelProviderName, + }); + + expect(breakdown).not.toBeNull(); + if (breakdown) { + expect(breakdown.totalCost).toBe(0.0075); + } + }); + }); }); \ No newline at end of file diff --git a/packages/cost/costCalc.ts b/packages/cost/costCalc.ts index 404e31423f..a6c88722f8 100644 --- a/packages/cost/costCalc.ts +++ b/packages/cost/costCalc.ts @@ -1,7 +1,7 @@ import { costOfPrompt } from "./index"; import type { ModelUsage } from "./usage/types"; import type { ModelProviderName } from "./models/providers"; -import { calculateModelCostBreakdown, CostBreakdown } from "./models/calculate-cost"; +import { calculateModelCostBreakdown, CostBreakdown, CostOverride } from "./models/calculate-cost"; // since costs in clickhouse are multiplied by the multiplier // divide to get real cost in USD in dollars @@ -53,13 +53,15 @@ export function modelCostBreakdownFromRegistry(params: { provider: ModelProviderName; providerModelId: string; requestCount?: number; + costOverride?: CostOverride; }): CostBreakdown | null { const breakdown = calculateModelCostBreakdown({ modelUsage: params.modelUsage, providerModelId: params.providerModelId, provider: params.provider, requestCount: params.requestCount, + costOverride: params.costOverride, }); - + return breakdown; } diff --git a/packages/cost/models/calculate-cost.ts b/packages/cost/models/calculate-cost.ts index 860396a924..7adcec37c9 100644 --- a/packages/cost/models/calculate-cost.ts +++ b/packages/cost/models/calculate-cost.ts @@ -166,13 +166,19 @@ function getThresholdValueFunction(provider: ModelProviderName): (usage: ModelUs } } +export interface CostOverride { + inputCostPerToken?: number; + outputCostPerToken?: number; +} + export function calculateModelCostBreakdown(params: { modelUsage: ModelUsage; providerModelId: string; provider: ModelProviderName; requestCount?: number; + costOverride?: CostOverride; }): CostBreakdown | null { - const { modelUsage, providerModelId, provider, requestCount = 1 } = params; + const { modelUsage, providerModelId, provider, requestCount = 1, costOverride } = params; const configResult = registry.getModelProviderConfigByProviderModelId( providerModelId, @@ -205,7 +211,7 @@ export function calculateModelCostBreakdown(params: { }; const inputPricing = getPricingTier(preprocessedPricing, getThresholdValue(modelUsage, "inputCost")); - breakdown.inputCost = modelUsage.input * inputPricing.input; + breakdown.inputCost = modelUsage.input * (costOverride?.inputCostPerToken ?? inputPricing.input); if (modelUsage.cacheDetails) { if (modelUsage.cacheDetails.cachedInput > 0) { @@ -229,7 +235,7 @@ export function calculateModelCostBreakdown(params: { } const outputPricing = getPricingTier(preprocessedPricing, getThresholdValue(modelUsage, "outputCost")); - breakdown.outputCost = modelUsage.output * outputPricing.output; + breakdown.outputCost = modelUsage.output * (costOverride?.outputCostPerToken ?? outputPricing.output); if (modelUsage.thinking) { const thinkingRate = basePricing.thinking ?? basePricing.output;