Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 77 additions & 0 deletions packages/__tests__/cost/modelCostFromRegistry.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -436,4 +436,81 @@ describe("modelCostBreakdownFromRegistry", () => {
}
});
});

describe("cost override", () => {
  // Shared fixture: 1000 input tokens and 500 output tokens, priced as openai/gpt-4o.
  const usage: ModelUsage = { input: 1000, output: 500 };

  // Costs are products and sums of small floats, so compare to 12 decimal
  // places (|diff| < 5e-13) instead of bit-for-bit. This is far tighter than any
  // rate used below, but tolerant of one-ulp rounding differences.
  const PRECISION = 12;

  // Derive the override shape from the function under test so this block
  // needs no extra type import.
  type Override = Parameters<typeof modelCostBreakdownFromRegistry>[0]["costOverride"];

  /**
   * Prices the shared usage with an optional override and returns the breakdown.
   * Throws when the function returns null, so callers can assert on the
   * fields unconditionally instead of wrapping expectations in an `if`.
   * When no override is passed, the `costOverride` key is omitted entirely.
   */
  const priceWith = (costOverride?: Override) => {
    const breakdown = modelCostBreakdownFromRegistry({
      modelUsage: usage,
      providerModelId: "gpt-4o",
      provider: "openai" as ModelProviderName,
      ...(costOverride === undefined ? {} : { costOverride }),
    });
    expect(breakdown).not.toBeNull();
    if (breakdown === null) {
      throw new Error("expected a cost breakdown for openai/gpt-4o, got null");
    }
    return breakdown;
  };

  it("overrides both input and output per-token cost", () => {
    const breakdown = priceWith({ inputCostPerToken: 0.00001, outputCostPerToken: 0.00002 });

    expect(breakdown.inputCost).toBeCloseTo(1000 * 0.00001, PRECISION);
    expect(breakdown.outputCost).toBeCloseTo(500 * 0.00002, PRECISION);
    expect(breakdown.totalCost).toBeCloseTo(1000 * 0.00001 + 500 * 0.00002, PRECISION);
  });

  it("applies a partial override per field, falling back to the registry rate", () => {
    // Input overridden; output is expected at 0.00001 per token (0.01 per 1K),
    // the gpt-4o registry rate this test assumes.
    const inputOnly = priceWith({ inputCostPerToken: 0.00001 });
    expect(inputOnly.inputCost).toBeCloseTo(1000 * 0.00001, PRECISION);
    expect(inputOnly.outputCost).toBeCloseTo(500 * 0.00001, PRECISION);

    // Output overridden; input is expected at 0.0000025 per token (0.0025 per 1K),
    // the gpt-4o registry rate this test assumes.
    const outputOnly = priceWith({ outputCostPerToken: 0.00002 });
    expect(outputOnly.inputCost).toBeCloseTo(1000 * 0.0000025, PRECISION);
    expect(outputOnly.outputCost).toBeCloseTo(500 * 0.00002, PRECISION);
  });

  it("treats a zero override as free, not as unset", () => {
    const breakdown = priceWith({ inputCostPerToken: 0, outputCostPerToken: 0 });

    // Zero times a token count is exactly zero, so exact equality is safe here.
    expect(breakdown.inputCost).toBe(0);
    expect(breakdown.outputCost).toBe(0);
    expect(breakdown.totalCost).toBe(0);
  });

  it("uses registry pricing when no override is provided", () => {
    const breakdown = priceWith();

    // 1000 * 0.0000025 + 500 * 0.00001 = 0.0075 at the assumed registry rates.
    expect(breakdown.totalCost).toBeCloseTo(0.0075, PRECISION);
  });
});
});
6 changes: 4 additions & 2 deletions packages/cost/costCalc.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { costOfPrompt } from "./index";
import type { ModelUsage } from "./usage/types";
import type { ModelProviderName } from "./models/providers";
import { calculateModelCostBreakdown, CostBreakdown } from "./models/calculate-cost";
import { calculateModelCostBreakdown, CostBreakdown, CostOverride } from "./models/calculate-cost";

// Costs stored in ClickHouse are scaled up by a multiplier;
// divide by that multiplier to get the real cost in US dollars.
Expand Down Expand Up @@ -53,13 +53,15 @@ export function modelCostBreakdownFromRegistry(params: {
provider: ModelProviderName;
providerModelId: string;
requestCount?: number;
costOverride?: CostOverride;
}): CostBreakdown | null {
const breakdown = calculateModelCostBreakdown({
modelUsage: params.modelUsage,
providerModelId: params.providerModelId,
provider: params.provider,
requestCount: params.requestCount,
costOverride: params.costOverride,
});

return breakdown;
}
12 changes: 9 additions & 3 deletions packages/cost/models/calculate-cost.ts
Original file line number Diff line number Diff line change
Expand Up @@ -166,13 +166,19 @@ function getThresholdValueFunction(provider: ModelProviderName): (usage: ModelUs
}
}

/**
 * Optional per-token rates that take the place of registry pricing when a
 * cost breakdown is calculated.
 *
 * Each field is applied independently via `??`: a field left undefined falls
 * back to the registry rate for that side, while an explicit `0` is used as-is.
 *
 * NOTE(review): units presumably match the registry's per-token pricing — confirm.
 * NOTE(review): in the visible code only the plain input/output terms consult
 * this override; the thinking rate is read from base pricing — confirm intended.
 */
export interface CostOverride {
/** Rate multiplied by the input token count instead of the registry input rate. */
inputCostPerToken?: number;
/** Rate multiplied by the output token count instead of the registry output rate. */
outputCostPerToken?: number;
}

export function calculateModelCostBreakdown(params: {
modelUsage: ModelUsage;
providerModelId: string;
provider: ModelProviderName;
requestCount?: number;
costOverride?: CostOverride;
}): CostBreakdown | null {
const { modelUsage, providerModelId, provider, requestCount = 1 } = params;
const { modelUsage, providerModelId, provider, requestCount = 1, costOverride } = params;

const configResult = registry.getModelProviderConfigByProviderModelId(
providerModelId,
Expand Down Expand Up @@ -205,7 +211,7 @@ export function calculateModelCostBreakdown(params: {
};

const inputPricing = getPricingTier(preprocessedPricing, getThresholdValue(modelUsage, "inputCost"));
breakdown.inputCost = modelUsage.input * inputPricing.input;
breakdown.inputCost = modelUsage.input * (costOverride?.inputCostPerToken ?? inputPricing.input);

if (modelUsage.cacheDetails) {
if (modelUsage.cacheDetails.cachedInput > 0) {
Expand All @@ -229,7 +235,7 @@ export function calculateModelCostBreakdown(params: {
}

const outputPricing = getPricingTier(preprocessedPricing, getThresholdValue(modelUsage, "outputCost"));
breakdown.outputCost = modelUsage.output * outputPricing.output;
breakdown.outputCost = modelUsage.output * (costOverride?.outputCostPerToken ?? outputPricing.output);

if (modelUsage.thinking) {
const thinkingRate = basePricing.thinking ?? basePricing.output;
Expand Down