diff --git a/.changeset/gemini-provider-tools.md b/.changeset/gemini-provider-tools.md new file mode 100644 index 000000000..3b1093432 --- /dev/null +++ b/.changeset/gemini-provider-tools.md @@ -0,0 +1,5 @@ +--- +'@livekit/agents-plugin-google': minor +--- + +Add Gemini provider tools for Google Search, Google Maps, URL context, File Search, code execution, and Vertex RAG retrieval, and serialize them from `ToolContext` for Google LLM and realtime sessions. diff --git a/.changeset/list-syntax-toolcontext.md b/.changeset/list-syntax-toolcontext.md index f50fa8216..5e854a813 100644 --- a/.changeset/list-syntax-toolcontext.md +++ b/.changeset/list-syntax-toolcontext.md @@ -2,4 +2,15 @@ '@livekit/agents': minor --- -**BREAKING**: `Agent({ tools })` and `agent.updateTools()` now accept a flat list `(FunctionTool | ProviderDefinedTool | Toolset)[]` instead of a `Record` map, and `llm.tool({ ... })` requires a `name` field. `ToolContext` is now a Python-parity class with `functionTools` / `providerTools` / `toolsets` accessors, plus `flatten()`, `hasTool(name)`, `getFunctionTool(name)`, `updateTools()`, `copy()`, and `equals()`. To match the Python reference, registering two **different** function-tool instances under the same `name` now throws `duplicate function name: ` instead of silently overriding the earlier entry; passing the **same instance** twice is a no-op. `agent.toolCtx` returns a defensive copy so callers can no longer mutate the agent's internal state. `LLM.chat({ toolCtx })` accepts either a `ToolContext` instance or a raw `(FunctionTool | ProviderDefinedTool | Toolset)[]` array (`ToolCtxInput`) and normalizes it internally, so callers don't have to construct a `ToolContext` themselves. +**BREAKING**: `Agent({ tools })` and `agent.updateTools()` now accept a flat list `(FunctionTool | ProviderTool | Toolset)[]` instead of a `Record` map, and `llm.tool({ ... })` requires a `name` field. `ToolContext` is now a Python-parity class with `functionTools` / `providerTools` / `toolsets` accessors, plus `flatten()`, `hasTool(id)`, `getFunctionTool(id)`, `updateTools()`, `copy()`, and `equals()`. To match the Python reference, registering two **different** function-tool instances under the same `name` now throws `duplicate function name: ` instead of silently overriding the earlier entry; passing the **same instance** twice is a no-op. `agent.toolCtx` returns a defensive copy so callers can no longer mutate the agent's internal state. `LLM.chat({ toolCtx })` accepts either a `ToolContext` instance or a raw `(FunctionTool | ProviderTool | Toolset)[]` array (`ToolCtxInput`) and normalizes it internally, so callers don't have to construct a `ToolContext` themselves. + +Tools also expose an `id: string` field on the base `Tool` interface (parity with Python's `Tool.id` property): for `FunctionTool` it mirrors `name`, for `ProviderTool` it is the provider tool id. `ToolContext` keys and equality now use `tool.id` consistently. + +**BREAKING**: Provider tools are now modeled to match Python's `ProviderTool`: + +- `ProviderDefinedTool` is renamed to `ProviderTool`, and `isProviderDefinedTool` is renamed to `isProviderTool`. +- `ProviderTool` is now an **abstract class** (Python parity). Plugins must subclass it (`class WebSearch extends ProviderTool { ... }`) to attach provider-specific fields and serializers; bare `new ProviderTool(...)` is rejected at compile time. +- The `tool({ id })` factory overload is removed; `tool({ ... })` only creates function tools now. Construct provider tools by instantiating a `ProviderTool` subclass. +- The `ToolType` literal for provider tools is renamed from `'provider-defined'` to `'provider'`. + +`Toolset` now carries a `TOOLSET_SYMBOL` marker and is detected via a new `isToolset()` guard (consistent with `isFunctionTool` / `isProviderTool`). Existing `instanceof Toolset` checks still work, but symbol-based detection is preferred for cross-realm safety. diff --git a/.changeset/openai-provider-tools.md b/.changeset/openai-provider-tools.md new file mode 100644 index 000000000..8e793a935 --- /dev/null +++ b/.changeset/openai-provider-tools.md @@ -0,0 +1,5 @@ +--- +'@livekit/agents-plugin-openai': minor +--- + +Add OpenAI Responses provider tools for web search, file search, and code interpreter. diff --git a/agents/src/llm/chat_context.test.ts b/agents/src/llm/chat_context.test.ts index 350dcd1cd..101a9fe46 100644 --- a/agents/src/llm/chat_context.test.ts +++ b/agents/src/llm/chat_context.test.ts @@ -19,7 +19,7 @@ import { isInstructions, renderInstructions, } from './chat_context.js'; -import { ToolContext, tool } from './tool_context.js'; +import { ProviderTool, ToolContext, tool } from './tool_context.js'; initializeLogger({ pretty: false, level: 'error' }); @@ -1498,7 +1498,8 @@ describe('ChatContext.copy with toolCtx filter', () => { }); it('keeps provider-tool calls when the ToolContext holds a matching provider tool id', () => { - const provider = tool({ id: 'code_runner', config: {} }); + class CodeRunner extends ProviderTool {} + const provider = new CodeRunner({ id: 'code_runner' }); const ctx = new ChatContext([ FunctionCall.create({ callId: 'p1', name: 'code_runner', args: '{}' }), FunctionCall.create({ callId: 'p2', name: 'other', args: '{}' }), diff --git a/agents/src/llm/index.ts b/agents/src/llm/index.ts index bbb77a2fa..0c7a42679 100644 --- a/agents/src/llm/index.ts +++ b/agents/src/llm/index.ts @@ -4,8 +4,10 @@ export { handoff, isFunctionTool, - isProviderDefinedTool, + isProviderTool, isTool, + isToolset, + ProviderTool, tool, ToolContext, ToolError, @@ -14,7 +16,6 @@ export { toToolContext, type AgentHandoff, type FunctionTool, - type ProviderDefinedTool, type Tool, type ToolCalledEvent, type ToolChoice, diff --git a/agents/src/llm/llm.ts b/agents/src/llm/llm.ts index 0c05bbb2d..828e134e9 100644 --- a/agents/src/llm/llm.ts +++ b/agents/src/llm/llm.ts @@ -98,7 +98,7 @@ export abstract class LLM extends (EventEmitter as new () => TypedEmitter { @@ -448,8 +448,20 @@ describe('tool() name requirement', () => { }); expect(t.name).toBe('doStuff'); }); + + it('exposes id mirroring the function tool name', () => { + const t = tool({ + name: 'doStuff', + description: 'd', + execute: async () => 'x', + }); + expect(t.id).toBe('doStuff'); + expect(t.id).toBe(t.name); + }); }); +class TestProviderTool extends ProviderTool {} + describe('ToolContext', () => { const makeFn = (name: string) => tool({ @@ -497,7 +509,7 @@ describe('ToolContext', () => { it('separates provider tools from function tools', () => { const fnA = makeFn('a'); - const provider = tool({ id: 'code', config: { language: 'python' } }); + const provider = new TestProviderTool({ id: 'code' }); const ctx = new ToolContext([fnA, provider]); expect(ctx.functionTools).toEqual({ a: fnA }); @@ -537,7 +549,7 @@ describe('ToolContext', () => { it('equals() is reflexive', () => { const a = makeFn('a'); - const provider = tool({ id: 'code', config: { language: 'python' } }); + const provider = new TestProviderTool({ id: 'code' }); const ctx = new ToolContext([a, provider]); expect(ctx.equals(ctx)).toBe(true); }); @@ -547,22 +559,22 @@ describe('ToolContext', () => { // that hold the same provider-tool identities in different order are still equal so // realtime-session / preemptive-generation reuse fast paths are not invalidated. const a = makeFn('a'); - const p1 = tool({ id: 'code', config: { language: 'python' } }); - const p2 = tool({ id: 'browser', config: {} }); + const p1 = new TestProviderTool({ id: 'code' }); + const p2 = new TestProviderTool({ id: 'browser' }); expect(new ToolContext([a, p1, p2]).equals(new ToolContext([a, p2, p1]))).toBe(true); }); it('equals() supports contexts with only provider tools', () => { - const p1 = tool({ id: 'code', config: {} }); - const p2 = tool({ id: 'browser', config: {} }); + const p1 = new TestProviderTool({ id: 'code' }); + const p2 = new TestProviderTool({ id: 'browser' }); expect(new ToolContext([p1, p2]).equals(new ToolContext([p1, p2]))).toBe(true); - const p3 = tool({ id: 'code', config: {} }); // distinct identity, same id + const p3 = new TestProviderTool({ id: 'code' }); // distinct identity, same id expect(new ToolContext([p1]).equals(new ToolContext([p3]))).toBe(false); }); it('hasTool() matches function tools by name and provider tools by id', () => { const a = makeFn('a'); - const provider = tool({ id: 'code_runner', config: {} }); + const provider = new TestProviderTool({ id: 'code_runner' }); const ctx = new ToolContext([a, provider]); expect(ctx.hasTool('a')).toBe(true); @@ -574,7 +586,7 @@ describe('ToolContext', () => { // Matches Python's `flatten()`: list(self._fnc_tools_map.values()) + self._provider_tools. const a = makeFn('a'); const b = makeFn('b'); - const provider = tool({ id: 'code', config: {} }); + const provider = new TestProviderTool({ id: 'code' }); const ctx = new ToolContext([b, provider, a]); expect(ctx.flatten()).toEqual([b, a, provider]); diff --git a/agents/src/llm/tool_context.ts b/agents/src/llm/tool_context.ts index f2153a487..aa2ec3cdc 100644 --- a/agents/src/llm/tool_context.ts +++ b/agents/src/llm/tool_context.ts @@ -12,7 +12,8 @@ import { isZodObjectSchema, isZodSchema } from './zod-utils.js'; const TOOL_SYMBOL = Symbol('tool'); const FUNCTION_TOOL_SYMBOL = Symbol('function_tool'); -const PROVIDER_DEFINED_TOOL_SYMBOL = Symbol('provider_defined_tool'); +const PROVIDER_TOOL_SYMBOL = Symbol('provider_tool'); +const TOOLSET_SYMBOL = Symbol('toolset'); const TOOL_ERROR_SYMBOL = Symbol('tool_error'); const HANDOFF_SYMBOL = Symbol('handoff'); @@ -57,7 +58,7 @@ export type InferToolInput = T extends { _output: infer O } ? O : any; // eslint-disable-line @typescript-eslint/no-explicit-any -- Fallback type for JSON Schema objects without type inference -export type ToolType = 'function' | 'provider-defined'; +export type ToolType = 'function' | 'provider'; export type ToolChoice = | 'auto' @@ -136,28 +137,32 @@ export type ToolExecuteFunction< export interface Tool { /** * The type of the tool. - * @internal Either user-defined core tool or provider-defined tool. + * @internal Either user-defined function tool or provider-side tool. */ type: ToolType; + /** + * Stable identifier used to key the tool inside a `ToolContext`. For function tools this + * mirrors `name`; for provider tools this is the provider tool id. + */ + id: string; + [TOOL_SYMBOL]: true; } -// TODO(AJS-112): support provider-defined tools -export interface ProviderDefinedTool extends Tool { - type: 'provider-defined'; +// TODO(AJS-112): support provider tools +export abstract class ProviderTool implements Tool { + readonly type = 'provider' as const; - /** - * The ID of the tool. - */ - id: string; + readonly id: string; - /** - * The configuration of the tool. - */ - config: Record; + readonly [TOOL_SYMBOL] = true as const; - [PROVIDER_DEFINED_TOOL_SYMBOL]: true; + readonly [PROVIDER_TOOL_SYMBOL] = true as const; + + constructor({ id }: { id: string }) { + this.id = id; + } } export interface FunctionTool< @@ -169,7 +174,7 @@ export interface FunctionTool< /** * The name of the tool. Used to identify it inside a `ToolContext` and exposed to the LLM - * as the function name to call. + * as the function name to call. Also surfaced as the inherited `Tool.id`. */ name: string; @@ -213,8 +218,11 @@ export interface ToolCompletedEvent { */ export class Toolset { readonly #id: string; + readonly #tools: Tool[]; + readonly [TOOLSET_SYMBOL] = true as const; + constructor({ id, tools }: { id: string; tools: readonly Tool[] }) { this.#id = id; this.#tools = [...tools]; @@ -257,13 +265,13 @@ export function toToolContext( // eslint-disable-next-line @typescript-eslint/no-explicit-any -- ToolContext entries accept any function-tool parameter/result types export type ToolContextEntry = // eslint-disable-next-line @typescript-eslint/no-explicit-any - FunctionTool | ProviderDefinedTool | Toolset; + FunctionTool | ProviderTool | Toolset; export class ToolContext { private _tools: ToolContextEntry[] = []; // eslint-disable-next-line @typescript-eslint/no-explicit-any -- ToolContext stores generic function tools private _functionToolsMap: Map> = new Map(); - private _providerTools: ProviderDefinedTool[] = []; + private _providerTools: ProviderTool[] = []; private _toolsets: Toolset[] = []; constructor(tools: readonly ToolContextEntry[] = []) { @@ -281,7 +289,7 @@ export class ToolContext { } /** A copy of all provider tools in the tool context, including those in tool sets. */ - get providerTools(): ProviderDefinedTool[] { + get providerTools(): ProviderTool[] { return this._providerTools; } @@ -303,15 +311,15 @@ export class ToolContext { } // eslint-disable-next-line @typescript-eslint/no-explicit-any -- Generic registry over any parameter/result types - getFunctionTool(name: string): FunctionTool | undefined { - return this._functionToolsMap.get(name); + getFunctionTool(id: string): FunctionTool | undefined { + return this._functionToolsMap.get(id); } - hasTool(name: string): boolean { - if (this._functionToolsMap.has(name)) { + hasTool(id: string): boolean { + if (this._functionToolsMap.has(id)) { return true; } - return this._providerTools.some((tool) => tool.id === name); + return this._providerTools.some((tool) => tool.id === id); } updateTools(tools: readonly ToolContextEntry[]): void { @@ -322,7 +330,7 @@ export class ToolContext { // eslint-disable-next-line @typescript-eslint/no-explicit-any -- accepts any tool shape const addTool = (tool: any): void => { - if (tool instanceof Toolset) { + if (isToolset(tool)) { for (const inner of tool.tools) { addTool(inner); } @@ -330,20 +338,20 @@ export class ToolContext { return; } - if (isProviderDefinedTool(tool)) { + if (isProviderTool(tool)) { this._providerTools.push(tool); return; } if (isFunctionTool(tool)) { - const existing = this._functionToolsMap.get(tool.name); + const existing = this._functionToolsMap.get(tool.id); if (existing !== undefined) { if (existing !== tool) { - throw new Error(`duplicate function name: ${tool.name}`); + throw new Error(`duplicate function name: ${tool.id}`); } return; // same instance, skip } - this._functionToolsMap.set(tool.name, tool); + this._functionToolsMap.set(tool.id, tool); return; } @@ -363,14 +371,17 @@ export class ToolContext { if (this._functionToolsMap.size !== other._functionToolsMap.size) { return false; } - for (const [name, tool] of this._functionToolsMap) { - if (other._functionToolsMap.get(name) !== tool) { + + for (const [id, tool] of this._functionToolsMap) { + if (other._functionToolsMap.get(id) !== tool) { return false; } } + if (this._providerTools.length !== other._providerTools.length) { return false; } + // Provider tools compare as identity sets to match Python's `set(id(t) for t in ...)` // semantics — order is not significant. const otherProviderIds = new Set(other._providerTools); @@ -379,9 +390,11 @@ export class ToolContext { return false; } } + if (this._toolsets.length !== other._toolsets.length) { return false; } + const otherToolsets = new Set(other._toolsets); for (const ts of this._toolsets) { if (!otherToolsets.has(ts)) { @@ -445,63 +458,36 @@ export function tool({ flags?: number; }): FunctionTool, UserData, Result>; -/** - * Create a provider-defined tool. - * - * @param id - The ID of the tool. - * @param config - The configuration of the tool. - */ -export function tool({ - id, - config, -}: { - id: string; - config: Record; -}): ProviderDefinedTool; - // eslint-disable-next-line @typescript-eslint/no-explicit-any export function tool(tool: any): any { - if (tool.execute !== undefined) { - if (typeof tool.name !== 'string' || tool.name.length === 0) { - throw new Error('tool({ name, ... }) requires a non-empty name'); - } - - // Default parameters to z.object({}) if not provided - const parameters = tool.parameters ?? z.object({}); - - // if parameters is a Zod schema, ensure it's an object schema - if (isZodSchema(parameters) && !isZodObjectSchema(parameters)) { - throw new Error('Tool parameters must be a Zod object schema (z.object(...))'); - } + if (typeof tool.name !== 'string' || tool.name.length === 0) { + throw new Error('tool({ name, ... }) requires a non-empty name'); + } - // Ensure parameters is either a Zod schema or a plain object (JSON schema) - if (!isZodSchema(parameters) && !(typeof parameters === 'object')) { - throw new Error('Tool parameters must be a Zod object schema or a raw JSON schema'); - } + // Default parameters to z.object({}) if not provided + const parameters = tool.parameters ?? z.object({}); - return { - type: 'function', - name: tool.name, - description: tool.description, - parameters, - execute: tool.execute, - flags: tool.flags ?? ToolFlag.NONE, - [TOOL_SYMBOL]: true, - [FUNCTION_TOOL_SYMBOL]: true, - }; + // if parameters is a Zod schema, ensure it's an object schema + if (isZodSchema(parameters) && !isZodObjectSchema(parameters)) { + throw new Error('Tool parameters must be a Zod object schema (z.object(...))'); } - if (tool.config !== undefined && tool.id !== undefined) { - return { - type: 'provider-defined', - id: tool.id, - config: tool.config, - [TOOL_SYMBOL]: true, - [PROVIDER_DEFINED_TOOL_SYMBOL]: true, - }; + // Ensure parameters is either a Zod schema or a plain object (JSON schema) + if (!isZodSchema(parameters) && !(typeof parameters === 'object')) { + throw new Error('Tool parameters must be a Zod object schema or a raw JSON schema'); } - throw new Error('Invalid tool'); + return { + type: 'function', + id: tool.name, + name: tool.name, + description: tool.description, + parameters, + execute: tool.execute, + flags: tool.flags ?? ToolFlag.NONE, + [TOOL_SYMBOL]: true, + [FUNCTION_TOOL_SYMBOL]: true, + }; } // eslint-disable-next-line @typescript-eslint/no-explicit-any @@ -517,10 +503,15 @@ export function isFunctionTool(tool: any): tool is FunctionTool { } // eslint-disable-next-line @typescript-eslint/no-explicit-any -export function isProviderDefinedTool(tool: any): tool is ProviderDefinedTool { +export function isProviderTool(tool: any): tool is ProviderTool { const isTool = tool && tool[TOOL_SYMBOL] === true; - const isProviderDefinedTool = tool[PROVIDER_DEFINED_TOOL_SYMBOL] === true; - return isTool && isProviderDefinedTool; + const isProviderTool = tool[PROVIDER_TOOL_SYMBOL] === true; + return isTool && isProviderTool; +} + +// eslint-disable-next-line @typescript-eslint/no-explicit-any +export function isToolset(value: any): value is Toolset { + return value && value[TOOLSET_SYMBOL] === true; } // eslint-disable-next-line @typescript-eslint/no-explicit-any diff --git a/agents/src/llm/tool_context.type.test.ts b/agents/src/llm/tool_context.type.test.ts index 187f95e7b..5d33124ad 100644 --- a/agents/src/llm/tool_context.type.test.ts +++ b/agents/src/llm/tool_context.type.test.ts @@ -3,7 +3,7 @@ // SPDX-License-Identifier: Apache-2.0 import { describe, expect, expectTypeOf, it } from 'vitest'; import { z } from 'zod'; -import { type FunctionTool, type ProviderDefinedTool, type ToolOptions, tool } from './index.js'; +import { type FunctionTool, ProviderTool, type ToolOptions, tool } from './index.js'; describe('tool type inference', () => { it('should infer argument type from zod schema', () => { @@ -17,15 +17,15 @@ describe('tool type inference', () => { expectTypeOf(toolType).toEqualTypeOf>(); }); - it('should infer provider defined tool type', () => { - const toolType = tool({ - id: 'code-interpreter', - config: { - language: 'python', - }, - }); + it('rejects direct instantiation of the abstract ProviderTool base', () => { + // @ts-expect-error - ProviderTool is abstract; plugins must subclass it. + new ProviderTool({ id: 'code-interpreter' }); - expectTypeOf(toolType).toEqualTypeOf(); + class CodeInterpreter extends ProviderTool {} + const providerTool = new CodeInterpreter({ id: 'code-interpreter' }); + expectTypeOf(providerTool).toMatchTypeOf(); + expect(providerTool.id).toBe('code-interpreter'); + expect(providerTool.type).toBe('provider'); }); it('should infer run context type', () => { @@ -45,7 +45,6 @@ describe('tool type inference', () => { it('should not accept primitive zod schemas', () => { expect(() => { - // @ts-expect-error - Testing that non-object schemas are rejected tool({ name: 'test', description: 'test', @@ -57,7 +56,6 @@ describe('tool type inference', () => { it('should not accept array schemas', () => { expect(() => { - // @ts-expect-error - Testing that array schemas are rejected tool({ name: 'test', description: 'test', @@ -69,7 +67,6 @@ describe('tool type inference', () => { it('should not accept union schemas', () => { expect(() => { - // @ts-expect-error - Testing that union schemas are rejected tool({ name: 'test', description: 'test', diff --git a/agents/src/voice/agent_activity.ts b/agents/src/voice/agent_activity.ts index cf831350a..793809a3e 100644 --- a/agents/src/voice/agent_activity.ts +++ b/agents/src/voice/agent_activity.ts @@ -42,6 +42,7 @@ import { ToolFlag, Toolset, isFunctionTool, + isToolset, } from '../llm/index.js'; import type { LLMError } from '../llm/llm.js'; import { isSameToolChoice } from '../llm/tool_context.js'; @@ -1751,7 +1752,7 @@ export class AgentActivity implements RecognitionHooks { this.agent.toolCtx.tools.flatMap((t): ToolContextEntry[] => { const keepFn = (fn: Tool): boolean => !isFunctionTool(fn) || !(fn.flags & ToolFlag.IGNORE_ON_ENTER); - if (t instanceof Toolset) { + if (isToolset(t)) { return t.tools.filter(keepFn) as ToolContextEntry[]; } return keepFn(t) ? [t] : []; diff --git a/plugins/google/src/beta/realtime/realtime_api.ts b/plugins/google/src/beta/realtime/realtime_api.ts index 66bc6a7f9..1d8365584 100644 --- a/plugins/google/src/beta/realtime/realtime_api.ts +++ b/plugins/google/src/beta/realtime/realtime_api.ts @@ -33,7 +33,7 @@ import { import { Mutex } from '@livekit/mutex'; import { AudioFrame, AudioResampler, type VideoFrame } from '@livekit/rtc-node'; import { type LLMTools } from '../../tools.js'; -import { toFunctionDeclarations } from '../../utils.js'; +import { toToolsConfig } from '../../utils.js'; import type * as api_proto from './api_proto.js'; import type { LiveAPIModels, Voice } from './api_proto.js'; @@ -70,13 +70,6 @@ export interface InputTranscription { transcript: string; } -/** - * Helper function to check if two sets are equal - */ -function setsEqual(a: Set, b: Set): boolean { - return a.size === b.size && [...a].every((x) => b.has(x)); -} - /** * Internal realtime options for Google Realtime API */ @@ -455,7 +448,6 @@ export class RealtimeSession extends llm.RealtimeSession { private _chatCtx = llm.ChatContext.empty(); private options: RealtimeOptions; - private geminiDeclarations: types.FunctionDeclaration[] = []; private messageChannel = new Queue(); private inputResampler?: AudioResampler; private inputResamplerInputRate?: number; @@ -764,15 +756,12 @@ export class RealtimeSession extends llm.RealtimeSession { } async updateTools(tools: llm.ToolContext): Promise { - const newDeclarations = toFunctionDeclarations(tools); - const currentToolNames = new Set(this.geminiDeclarations.map((f) => f.name)); - const newToolNames = new Set(newDeclarations.map((f) => f.name)); - - if (!setsEqual(currentToolNames, newToolNames)) { - this.geminiDeclarations = newDeclarations; - this._tools = tools; - this.markRestartNeeded(); + if (this._tools.equals(tools)) { + return; } + + this._tools = tools; + this.markRestartNeeded(); } get chatCtx(): llm.ChatContext { @@ -1424,21 +1413,11 @@ export class RealtimeSession extends llm.RealtimeSession { }, languageCode: opts.language, }, - tools: - this.geminiDeclarations.length > 0 || this.options.geminiTools - ? [ - { - functionDeclarations: - this.options.toolBehavior !== undefined - ? this.geminiDeclarations.map((d) => ({ - ...d, - behavior: this.options.toolBehavior, - })) - : this.geminiDeclarations, - ...this.options.geminiTools, - }, - ] - : undefined, + tools: toToolsConfig({ + toolCtx: this._tools, + geminiTools: this.options.geminiTools, + toolBehavior: this.options.toolBehavior, + }), inputAudioTranscription: opts.inputAudioTranscription, outputAudioTranscription: opts.outputAudioTranscription, sessionResumption: this.sessionResumptionHandle diff --git a/plugins/google/src/index.ts b/plugins/google/src/index.ts index fbafc1d66..326ca6270 100644 --- a/plugins/google/src/index.ts +++ b/plugins/google/src/index.ts @@ -6,6 +6,7 @@ import { Plugin } from '@livekit/agents'; export * as beta from './beta/index.js'; export { LLM, LLMStream, type LLMOptions } from './llm.js'; export * from './models.js'; +export * from './tools.js'; class GooglePlugin extends Plugin { constructor() { diff --git a/plugins/google/src/llm.ts b/plugins/google/src/llm.ts index e452b70d2..4958ccc6c 100644 --- a/plugins/google/src/llm.ts +++ b/plugins/google/src/llm.ts @@ -13,7 +13,7 @@ import { } from '@livekit/agents'; import type { ChatModels } from './models.js'; import type { LLMTools } from './tools.js'; -import { toFunctionDeclarations } from './utils.js'; +import { toToolsConfig } from './utils.js'; interface GoogleFormatData { systemMessages: string[] | null; @@ -355,11 +355,11 @@ export class LLMStream extends llm.LLMStream { parts: turn.parts as types.Part[], })); - const functionDeclarations = this.toolCtx ? toFunctionDeclarations(this.toolCtx) : undefined; - const tools = - functionDeclarations && functionDeclarations.length > 0 - ? [{ functionDeclarations }] - : undefined; + const tools = toToolsConfig({ + toolCtx: this.toolCtx, + geminiTools: this.#geminiTools, + onlySingleType: true, + }); let systemInstruction: types.Content | undefined = undefined; if (extraData.systemMessages && extraData.systemMessages.length > 0) { diff --git a/plugins/google/src/tools.ts b/plugins/google/src/tools.ts index 90cd9cc7c..b864e98a6 100644 --- a/plugins/google/src/tools.ts +++ b/plugins/google/src/tools.ts @@ -1,6 +1,100 @@ // SPDX-FileCopyrightText: 2025 LiveKit, Inc. // // SPDX-License-Identifier: Apache-2.0 -import type { Tool } from '@google/genai'; +import type * as types from '@google/genai'; +import { llm } from '@livekit/agents'; -export type LLMTools = Omit; +export type LLMTools = Omit; + +export abstract class GeminiTool extends llm.ProviderTool { + abstract toToolConfig(): types.Tool; +} + +export class GoogleSearch extends GeminiTool { + constructor(public readonly options: types.GoogleSearch = {}) { + super({ id: 'gemini_google_search' }); + } + + toToolConfig(): types.Tool { + return { googleSearch: this.options }; + } +} + +export class GoogleMaps extends GeminiTool { + constructor(public readonly options: types.GoogleMaps = {}) { + super({ id: 'gemini_google_maps' }); + } + + toToolConfig(): types.Tool { + return { googleMaps: this.options }; + } +} + +export class URLContext extends GeminiTool { + constructor() { + super({ id: 'gemini_url_context' }); + } + + toToolConfig(): types.Tool { + return { urlContext: {} }; + } +} + +export interface FileSearchOptions extends types.FileSearch { + fileSearchStoreNames: string[]; +} + +export class FileSearch extends GeminiTool { + constructor(public readonly options: FileSearchOptions) { + super({ id: 'gemini_file_search' }); + } + + toToolConfig(): types.Tool { + return { fileSearch: this.options }; + } +} + +export class ToolCodeExecution extends GeminiTool { + constructor() { + super({ id: 'gemini_code_execution' }); + } + + toToolConfig(): types.Tool { + return { codeExecution: {} }; + } +} + +export interface VertexRAGRetrievalOptions { + ragResources: string[]; + similarityTopK?: number; + vectorDistanceThreshold?: number; +} + +export class VertexRAGRetrieval extends GeminiTool { + readonly ragResources: string[]; + readonly similarityTopK: number; + readonly vectorDistanceThreshold?: number; + + constructor({ + ragResources, + similarityTopK = 3, + vectorDistanceThreshold, + }: VertexRAGRetrievalOptions) { + super({ id: 'gemini_vertex_rag_retrieval' }); + this.ragResources = ragResources; + this.similarityTopK = similarityTopK; + this.vectorDistanceThreshold = vectorDistanceThreshold; + } + + toToolConfig(): types.Tool { + return { + retrieval: { + vertexRagStore: { + ragResources: this.ragResources.map((ragCorpus) => ({ ragCorpus })), + similarityTopK: this.similarityTopK, + vectorDistanceThreshold: this.vectorDistanceThreshold, + }, + }, + }; + } +} diff --git a/plugins/google/src/utils.ts b/plugins/google/src/utils.ts index 64a52a6c1..30b5fcc89 100644 --- a/plugins/google/src/utils.ts +++ b/plugins/google/src/utils.ts @@ -1,9 +1,11 @@ // SPDX-FileCopyrightText: 2025 LiveKit, Inc. // // SPDX-License-Identifier: Apache-2.0 +import type * as types from '@google/genai'; import type { FunctionDeclaration, Schema } from '@google/genai'; import { llm } from '@livekit/agents'; import type { JSONSchema7 } from 'json-schema'; +import { GeminiTool, type LLMTools } from './tools.js'; /** * JSON Schema v7 @@ -140,7 +142,7 @@ export function toFunctionDeclarations(toolCtx: llm.ToolContext): FunctionDeclar const functionDeclarations: FunctionDeclaration[] = []; for (const tool of toolCtx.flatten()) { - // TODO: support provider-defined tools in the Gemini schema. + // TODO: support provider tools in the Gemini schema. if (!llm.isFunctionTool(tool)) continue; const { name, description, parameters } = tool; const jsonSchema = llm.toJsonSchema(parameters, false); @@ -157,3 +159,57 @@ export function toFunctionDeclarations(toolCtx: llm.ToolContext): FunctionDeclar return functionDeclarations; } + +export function toToolsConfig({ + toolCtx, + geminiTools, + toolBehavior, + onlySingleType = false, +}: { + toolCtx?: llm.ToolContext; + geminiTools?: LLMTools; + toolBehavior?: types.Behavior; + onlySingleType?: boolean; +}): types.Tool[] | undefined { + const tools: types.Tool[] = []; + const providerTools: types.Tool[] = []; + + if (toolCtx) { + const functionDeclarations = toFunctionDeclarations(toolCtx); + if (functionDeclarations.length > 0) { + tools.push({ + functionDeclarations: + toolBehavior !== undefined + ? functionDeclarations.map((declaration) => ({ + ...declaration, + behavior: toolBehavior, + })) + : functionDeclarations, + }); + } + } + + if (geminiTools !== undefined) { + providerTools.push(geminiTools); + } + + if (toolCtx) { + for (const tool of toolCtx.providerTools) { + if (tool instanceof GeminiTool) { + providerTools.push(tool.toToolConfig()); + } + } + } + + if (tools.length > 0 && providerTools.length > 0) { + throw new Error('Gemini does not support mixing function tools and provider tools'); + } + + if (onlySingleType && tools.length > 0) { + return tools; + } + + tools.push(...providerTools); + + return tools.length > 0 ? tools : undefined; +} diff --git a/plugins/mistralai/src/llm.ts b/plugins/mistralai/src/llm.ts index f80bc8bcc..901325f0a 100644 --- a/plugins/mistralai/src/llm.ts +++ b/plugins/mistralai/src/llm.ts @@ -213,7 +213,7 @@ export class LLMStream extends llm.LLMStream { const toolsList: any[] = []; if (this.toolCtx) { for (const t of this.toolCtx.flatten()) { - // TODO: support provider-defined tools in the Mistral schema. + // TODO: support provider tools in the Mistral schema. if (!llm.isFunctionTool(t)) continue; toolsList.push({ type: 'function' as const, diff --git a/plugins/openai/src/index.ts b/plugins/openai/src/index.ts index ccffdcb3f..6a5d9cb7c 100644 --- a/plugins/openai/src/index.ts +++ b/plugins/openai/src/index.ts @@ -5,6 +5,7 @@ import { Plugin } from '@livekit/agents'; export { LLM, LLMStream, type LLMOptions } from './llm.js'; export * from './models.js'; +export * from './tools.js'; export * as realtime from './realtime/index.js'; export * as responses from './responses/index.js'; export { STT, type STTOptions } from './stt.js'; diff --git a/plugins/openai/src/realtime/realtime_model.ts b/plugins/openai/src/realtime/realtime_model.ts index 94f1e2988..49a0cb506 100644 --- a/plugins/openai/src/realtime/realtime_model.ts +++ b/plugins/openai/src/realtime/realtime_model.ts @@ -712,8 +712,9 @@ export class RealtimeSession extends llm.RealtimeSession { const oaiTools: api_proto.Tool[] = []; for (const t of _tools.flatten()) { - // TODO: support provider-defined tools in the Realtime session-update schema. + // TODO: support provider tools in the Realtime session-update schema. if (!llm.isFunctionTool(t)) continue; + try { const parameters = llm.toJsonSchema( t.parameters, diff --git a/plugins/openai/src/responses/llm.ts b/plugins/openai/src/responses/llm.ts index 20363a05f..4a1dc9d92 100644 --- a/plugins/openai/src/responses/llm.ts +++ b/plugins/openai/src/responses/llm.ts @@ -13,6 +13,7 @@ import { } from '@livekit/agents'; import OpenAI from 'openai'; import type { ChatModels } from '../models.js'; +import { toResponsesTools } from '../tool_utils.js'; import { WSLLM } from '../ws/llm.js'; export interface LLMOptions { @@ -186,27 +187,8 @@ class ResponsesHttpLLMStream extends llm.LLMStream { 'openai.responses', )) as OpenAI.Responses.ResponseInputItem[]; - // TODO: support provider-defined tools in the Responses schema. const tools = this.toolCtx - ? this.toolCtx - .flatten() - .filter(llm.isFunctionTool) - .map((t) => { - const oaiParams = { - type: 'function' as const, - name: t.name, - description: t.description, - parameters: llm.toJsonSchema( - t.parameters, - true, - this.strictToolSchema, - ) as unknown as OpenAI.Responses.FunctionTool['parameters'], - } as OpenAI.Responses.FunctionTool; - if (this.strictToolSchema) { - oaiParams.strict = true; - } - return oaiParams; - }) + ? toResponsesTools(this.toolCtx, this.strictToolSchema) : undefined; const requestOptions: Record = { ...this.modelOptions }; diff --git a/plugins/openai/src/tool_utils.test.ts b/plugins/openai/src/tool_utils.test.ts new file mode 100644 index 000000000..ce1922f52 --- /dev/null +++ b/plugins/openai/src/tool_utils.test.ts @@ -0,0 +1,84 @@ +// SPDX-FileCopyrightText: 2026 LiveKit, Inc. +// +// SPDX-License-Identifier: Apache-2.0 +import { llm } from '@livekit/agents'; +import { describe, expect, it } from 'vitest'; +import { z } from 'zod'; +import { toResponsesTools } from './tool_utils.js'; +import { CodeInterpreter, FileSearch, WebSearch } from './tools.js'; + +describe('toResponsesTools', () => { + it('serializes function tools', () => { + const fn = llm.tool({ + name: 'lookup_weather', + description: 'Look up weather', + parameters: z.object({ city: z.string() }), + execute: async () => 'sunny', + }); + + expect(toResponsesTools(new llm.ToolContext([fn]), true)).toEqual([ + { + type: 'function', + name: 'lookup_weather', + description: 'Look up weather', + parameters: { + $schema: 'http://json-schema.org/draft-07/schema#', + type: 'object', + properties: { city: { type: 'string' } }, + required: ['city'], + additionalProperties: false, + }, + strict: true, + }, + ]); + }); + + it('serializes OpenAI provider tools', () => { + const tools = toResponsesTools( + new llm.ToolContext([ + new WebSearch({ + filters: { allowed_domains: ['docs.livekit.io'] }, + searchContextSize: 'low', + userLocation: { type: 'approximate', country: 'US' }, + }), + new FileSearch({ + vectorStoreIds: ['vs_123'], + maxNumResults: 3, + rankingOptions: { ranker: 'auto' }, + }), + new CodeInterpreter({ container: { type: 'auto', file_ids: ['file_123'] } }), + ]), + false, + ); + + expect(tools).toEqual([ + { + type: 'web_search', + search_context_size: 'low', + filters: { allowed_domains: ['docs.livekit.io'] }, + user_location: { type: 'approximate', country: 'US' }, + }, + { + type: 'file_search', + vector_store_ids: ['vs_123'], + max_num_results: 3, + ranking_options: { ranker: 'auto' }, + }, + { type: 'code_interpreter', container: { type: 'auto', file_ids: ['file_123'] } }, + ]); + }); + + it('omits the code interpreter container when unset', () => { + expect(toResponsesTools(new llm.ToolContext([new CodeInterpreter()]), false)).toEqual([ + { type: 'code_interpreter' }, + ]); + }); + + it('ignores non-OpenAI provider tools', () => { + class OtherProviderTool extends llm.ProviderTool {} + + expect( + toResponsesTools(new llm.ToolContext([new OtherProviderTool({ id: 'other' })]), false), + ).toBeUndefined(); + }); +}); diff --git a/plugins/openai/src/tool_utils.ts b/plugins/openai/src/tool_utils.ts new file mode 100644 index 000000000..1e2e6c709 --- /dev/null +++ b/plugins/openai/src/tool_utils.ts @@ -0,0 +1,43 @@ +// SPDX-FileCopyrightText: 2026 LiveKit, Inc. +// +// SPDX-License-Identifier: Apache-2.0 +import { llm } from '@livekit/agents'; +import type OpenAI from 'openai'; +import { OpenAITool } from './tools.js'; + +export function toResponsesTools( + toolCtx: llm.ToolContext, + strictToolSchema: boolean, +): OpenAI.Responses.Tool[] | undefined { + const tools = toolCtx + .flatten() + .map((tool) => { + if (llm.isFunctionTool(tool)) { + const oaiParams = { + type: 'function' as const, + name: tool.name, + description: tool.description, + parameters: llm.toJsonSchema( + tool.parameters, + true, + strictToolSchema, + ) as unknown as OpenAI.Responses.FunctionTool['parameters'], + } as OpenAI.Responses.FunctionTool; + + if (strictToolSchema) { + oaiParams.strict = true; + } + + return oaiParams; + } + + if (tool instanceof OpenAITool) { + return tool.toToolConfig() as unknown as OpenAI.Responses.Tool; + } + + return undefined; + }) + .filter((tool): tool is OpenAI.Responses.Tool => tool !== undefined); + + return tools.length > 0 ? tools : undefined; +} diff --git a/plugins/openai/src/tools.ts b/plugins/openai/src/tools.ts new file mode 100644 index 000000000..1ce779376 --- /dev/null +++ b/plugins/openai/src/tools.ts @@ -0,0 +1,166 @@ +// SPDX-FileCopyrightText: 2026 LiveKit, Inc. +// +// SPDX-License-Identifier: Apache-2.0 +import { llm } from '@livekit/agents'; +import type OpenAI from 'openai'; + +/** Base class for OpenAI Responses API provider tools. */ +export abstract class OpenAITool extends llm.ProviderTool { + /** Convert this provider tool to the OpenAI Responses API tool configuration. */ + abstract toToolConfig(): Record; +} + +/** + * High-level guidance for the amount of context window space to use for web search. + * OpenAI defaults this to `medium`. + */ +export type WebSearchContextSize = 'low' | 'medium' | 'high'; + +/** Options for the OpenAI web search tool. */ +export interface WebSearchOptions { + /** + * Filters for the search, such as allowed domains. If not provided, all domains are allowed. + */ + filters?: OpenAI.Responses.WebSearchTool['filters']; + + /** + * Amount of context window space to use for the search. Defaults to `medium`. + */ + searchContextSize?: WebSearchContextSize | null; + + /** Approximate location of the user, such as city, region, country, or timezone. */ + userLocation?: OpenAI.Responses.WebSearchTool['user_location']; +} + +/** + * Search the Internet for sources related to the prompt. + * + * @see https://platform.openai.com/docs/guides/tools-web-search + */ +export class WebSearch extends OpenAITool { + /** Filters for the search, such as allowed domains. */ + readonly filters: OpenAI.Responses.WebSearchTool['filters'] | undefined; + + /** Amount of context window space to use for the search. */ + readonly searchContextSize: WebSearchContextSize | null; + + /** Approximate location of the user. */ + readonly userLocation: OpenAI.Responses.WebSearchTool['user_location'] | undefined; + + constructor({ filters, searchContextSize = 'medium', userLocation }: WebSearchOptions = {}) { + super({ id: 'openai_web_search' }); + this.filters = filters; + this.searchContextSize = searchContextSize; + this.userLocation = userLocation; + } + + toToolConfig(): Record { + const result: Record = { + type: 'web_search', + search_context_size: this.searchContextSize, + }; + if (this.userLocation !== undefined) { + result.user_location = this.userLocation; + } + if (this.filters !== undefined) { + result.filters = this.filters; + } + return result; + } +} + +/** Options for the OpenAI file search tool. */ +export interface FileSearchOptions { + /** IDs of the vector stores to search. */ + vectorStoreIds?: string[]; + + /** Filter to apply to file search results. */ + filters?: OpenAI.Responses.FileSearchTool['filters']; + + /** Maximum number of results to return. This should be between 1 and 50 inclusive. */ + maxNumResults?: number; + + /** Ranking options for search, including ranker and score threshold. */ + rankingOptions?: OpenAI.Responses.FileSearchTool.RankingOptions; +} + +/** + * Search for relevant content from uploaded files. + * + * @see https://platform.openai.com/docs/guides/tools-file-search + */ +export class FileSearch extends OpenAITool { + /** IDs of the vector stores to search. */ + readonly vectorStoreIds: string[]; + + /** Filter to apply to file search results. */ + readonly filters: OpenAI.Responses.FileSearchTool['filters'] | undefined; + + /** Maximum number of results to return. */ + readonly maxNumResults: number | undefined; + + /** Ranking options for search. */ + readonly rankingOptions: OpenAI.Responses.FileSearchTool.RankingOptions | undefined; + + constructor({ + vectorStoreIds = [], + filters, + maxNumResults, + rankingOptions, + }: FileSearchOptions = {}) { + super({ id: 'openai_file_search' }); + this.vectorStoreIds = [...vectorStoreIds]; + this.filters = filters; + this.maxNumResults = maxNumResults; + this.rankingOptions = rankingOptions; + } + + toToolConfig(): Record { + const result: Record = { + type: 'file_search', + vector_store_ids: this.vectorStoreIds, + }; + if (this.filters !== undefined) { + result.filters = this.filters; + } + if (this.maxNumResults !== undefined) { + result.max_num_results = this.maxNumResults; + } + if (this.rankingOptions !== undefined) { + result.ranking_options = this.rankingOptions; + } + return result; + } +} + +/** Options for the OpenAI code interpreter tool. */ +export interface CodeInterpreterOptions { + /** + * Code interpreter container. Can be a container ID or an object that specifies uploaded file IDs + * to make available to the code. + */ + container?: OpenAI.Responses.Tool.CodeInterpreter['container'] | null; +} + +/** + * Run Python code to help generate a response to a prompt. + * + * @see https://platform.openai.com/docs/guides/tools-code-interpreter + */ +export class CodeInterpreter extends OpenAITool { + /** Code interpreter container ID or configuration. */ + readonly container: OpenAI.Responses.Tool.CodeInterpreter['container'] | null; + + constructor({ container = null }: CodeInterpreterOptions = {}) { + super({ id: 'openai_code_interpreter' }); + this.container = container; + } + + toToolConfig(): Record { + const result: Record = { type: 'code_interpreter' }; + if (this.container !== null) { + result.container = this.container; + } + return result; + } +} diff --git a/plugins/openai/src/ws/llm.ts b/plugins/openai/src/ws/llm.ts index 64f2d641f..f75054387 100644 --- a/plugins/openai/src/ws/llm.ts +++ b/plugins/openai/src/ws/llm.ts @@ -15,6 +15,7 @@ import { import type OpenAI from 'openai'; import { WebSocket } from 'ws'; import type { ChatModels } from '../models.js'; +import { toResponsesTools } from '../tool_utils.js'; import type { WsOutputItemDoneEvent, WsOutputTextDeltaEvent, @@ -429,30 +430,7 @@ export class WSLLMStream extends llm.LLMStream { 'openai.responses', )) as OpenAI.Responses.ResponseInputItem[]; - // TODO: support provider-defined tools in the Responses schema. - const tools = this.toolCtx - ? this.toolCtx - .flatten() - .filter(llm.isFunctionTool) - .map((t) => { - const oaiParams = { - type: 'function' as const, - name: t.name, - description: t.description, - parameters: llm.toJsonSchema( - t.parameters, - true, - this.#strictToolSchema, - ) as unknown as OpenAI.Responses.FunctionTool['parameters'], - } as OpenAI.Responses.FunctionTool; - - if (this.#strictToolSchema) { - oaiParams.strict = true; - } - - return oaiParams; - }) - : undefined; + const tools = this.toolCtx ? toResponsesTools(this.toolCtx, this.#strictToolSchema) : undefined; const requestOptions: Record = { ...this.#modelOptions }; if (!tools) { diff --git a/plugins/phonic/src/realtime/realtime_model.ts b/plugins/phonic/src/realtime/realtime_model.ts index fd0e9baf9..bec3906dd 100644 --- a/plugins/phonic/src/realtime/realtime_model.ts +++ b/plugins/phonic/src/realtime/realtime_model.ts @@ -368,7 +368,7 @@ export class RealtimeSession extends llm.RealtimeSession { } this._tools = tools.copy(); - // TODO: support provider-defined tools in the Phonic schema. + // TODO: support provider tools in the Phonic schema. this.toolDefinitions = tools .flatten() .filter(llm.isFunctionTool) @@ -406,7 +406,7 @@ export class RealtimeSession extends llm.RealtimeSession { } if (tools !== undefined) { this._tools = tools.copy(); - // TODO: support provider-defined tools in the Phonic schema. + // TODO: support provider tools in the Phonic schema. this.toolDefinitions = tools .flatten() .filter(llm.isFunctionTool)