diff --git a/.changeset/calm-ants-build.md b/.changeset/calm-ants-build.md
new file mode 100644
index 000000000..a8b40691e
--- /dev/null
+++ b/.changeset/calm-ants-build.md
@@ -0,0 +1,5 @@
+---
+"@livekit/agents-plugin-openai": patch
+---
+
+fix multiple function calls not firing
diff --git a/plugins/openai/src/llm.test.ts b/plugins/openai/src/llm.test.ts
new file mode 100644
index 000000000..a4879d17f
--- /dev/null
+++ b/plugins/openai/src/llm.test.ts
@@ -0,0 +1,10 @@
+// SPDX-FileCopyrightText: 2024 LiveKit, Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+import { llm } from '@livekit/agents-plugins-test';
+import { describe } from 'vitest';
+import { LLM } from './llm.js';
+
+describe('OpenAI', async () => {
+  await llm(new LLM());
+});
diff --git a/plugins/openai/src/llm.ts b/plugins/openai/src/llm.ts
index 56a0cbbb5..19386bb61 100644
--- a/plugins/openai/src/llm.ts
+++ b/plugins/openai/src/llm.ts
@@ -501,6 +501,11 @@ export class LLMStream extends llm.LLMStream {
           continue; // oai may add other tools in the future
         }
 
+        let callChunk: llm.ChatChunk | undefined;
+        if (this.#toolCallId && tool.id && tool.id !== this.#toolCallId) {
+          callChunk = this.#tryBuildFunction(id, choice);
+        }
+
         if (tool.function.name) {
           this.#toolCallId = tool.id;
           this.#fncName = tool.function.name;
@@ -509,8 +514,8 @@ export class LLMStream extends llm.LLMStream {
           this.#fncRawArguments += tool.function.arguments;
         }
 
-        if (this.#toolCallId && tool.id && tool.id !== this.#toolCallId) {
-          return this.#tryBuildFunction(id, choice);
+        if (callChunk) {
+          return callChunk;
         }
       }
     }
diff --git a/plugins/test/package.json b/plugins/test/package.json
index ee24fd9b1..e3c999f1e 100644
--- a/plugins/test/package.json
+++ b/plugins/test/package.json
@@ -27,15 +27,16 @@
     "lint": "eslint -f unix \"src/**/*.{ts,js}\""
   },
   "devDependencies": {
-    "@types/node": "^22.5.5",
-    "typescript": "^5.0.0",
     "@livekit/agents": "workspace:^x",
     "@livekit/rtc-node": "^0.12.1",
-    "tsup": "^8.3.5"
+    "@types/node": "^22.5.5",
+    "tsup": "^8.3.5",
+    "typescript": "^5.0.0"
   },
   "dependencies": {
+    "fastest-levenshtein": "^1.0.16",
     "vitest": "^1.6.0",
-    "fastest-levenshtein": "^1.0.16"
+    "zod": "^3.23.8"
   },
   "peerDependencies": {
     "@livekit/agents": "workspace:^x",
diff --git a/plugins/test/src/index.ts b/plugins/test/src/index.ts
index 94d6457d2..497d2b33a 100644
--- a/plugins/test/src/index.ts
+++ b/plugins/test/src/index.ts
@@ -2,4 +2,5 @@
 //
 // SPDX-License-Identifier: Apache-2.0
 export { tts } from './tts.js';
+export { llm } from './llm.js';
 export { stt } from './stt.js';
diff --git a/plugins/test/src/llm.ts b/plugins/test/src/llm.ts
new file mode 100644
index 000000000..bd4ca78f8
--- /dev/null
+++ b/plugins/test/src/llm.ts
@@ -0,0 +1,171 @@
+// SPDX-FileCopyrightText: 2024 LiveKit, Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+import { initializeLogger, llm as llmlib } from '@livekit/agents';
+import { describe, expect, it } from 'vitest';
+import { z } from 'zod';
+
+const fncCtx: llmlib.FunctionContext = {
+  getWeather: {
+    description: 'Get the current weather in a given location',
+    parameters: z.object({
+      location: z.string().describe('The city and state, e.g. San Francisco, CA'),
+      unit: z.enum(['celsius', 'fahrenheit']).describe('The temperature unit to use'),
+    }),
+    execute: async () => {},
+  },
+  playMusic: {
+    description: 'Play music',
+    parameters: z.object({
+      name: z.string().describe('The artist and name of the song'),
+    }),
+    execute: async () => {},
+  },
+  toggleLight: {
+    description: 'Turn on/off the lights in a room',
+    parameters: z.object({
+      name: z.string().describe('The room to control'),
+      on: z.boolean().describe('Whether to turn light on or off'),
+    }),
+    execute: async () => {
+      await new Promise((resolve) => setTimeout(resolve, 60_000));
+    },
+  },
+  selectCurrencies: {
+    description: 'Currencies of a specific area',
+    parameters: z.object({
+      currencies: z
+        .array(z.enum(['USD', 'EUR', 'GBP', 'JPY', 'SEK']))
+        .describe('The currencies to select'),
+    }),
+    execute: async () => {},
+  },
+  updateUserInfo: {
+    description: 'Update user info',
+    parameters: z.object({
+      email: z.string().optional().describe("User's email address"),
+      name: z.string().optional().describe("User's name"),
+      address: z.string().optional().describe("User's home address"),
+    }),
+    execute: async () => {},
+  },
+  simulateFailure: {
+    description: 'Simulate a failure',
+    parameters: z.object({}),
+    execute: async () => {
+      throw new Error('Simulated failure');
+    },
+  },
+};
+
+export const llm = async (llm: llmlib.LLM) => {
+  initializeLogger({ pretty: false });
+  describe('LLM', async () => {
+    it('should properly respond to chat', async () => {
+      const chatCtx = new llmlib.ChatContext().append({
+        text: 'You are an assistant at a drive-thru restaurant "Live-Burger". Ask the customer what they would like to order.',
+        role: llmlib.ChatRole.SYSTEM,
+      });
+
+      const stream = llm.chat({ chatCtx });
+      let text = '';
+      for await (const chunk of stream) {
+        if (!chunk.choices.length) continue;
+        text += chunk.choices[0]?.delta.content;
+      }
+
+      expect(text.length).toBeGreaterThan(0);
+    });
+    describe('function calling', async () => {
+      it('should handle function calling', async () => {
+        const stream = await requestFncCall(
+          llm,
+          "What's the weather in San Francisco and what's the weather in Paris?",
+          fncCtx,
+        );
+        const calls = stream.executeFunctions();
+        await Promise.all(calls.map((call) => call.task));
+        stream.close();
+
+        expect(calls.length).toStrictEqual(2);
+      });
+      it('should handle exceptions', async () => {
+        const stream = await requestFncCall(llm, 'Call the failing function', fncCtx);
+        const calls = stream.executeFunctions();
+        stream.close();
+
+        expect(calls.length).toStrictEqual(1);
+        const task = await calls[0]!.task!;
+        expect(task.error).toBeInstanceOf(Error);
+        expect(task.error.message).toStrictEqual('Simulated failure');
+      });
+      it('should handle arrays', async () => {
+        const stream = await requestFncCall(
+          llm,
+          'Can you select all currencies in Europe at once from given choices?',
+          fncCtx,
+          0.2,
+        );
+        const calls = stream.executeFunctions();
+        stream.close();
+
+        expect(calls.length).toStrictEqual(1);
+        expect(calls[0]!.params.currencies.length).toStrictEqual(3);
+        expect(calls[0]!.params.currencies).toContain('EUR');
+        expect(calls[0]!.params.currencies).toContain('GBP');
+        expect(calls[0]!.params.currencies).toContain('SEK');
+      });
+      it('should handle enums', async () => {
+        const stream = await requestFncCall(
+          llm,
+          "What's the weather in San Francisco, in Celsius?",
+          fncCtx,
+        );
+        const calls = stream.executeFunctions();
+        stream.close();
+
+        expect(calls.length).toStrictEqual(1);
+        expect(calls[0]!.params.unit).toStrictEqual('celsius');
+      });
+      it('should handle optional arguments', async () => {
+        const stream = await requestFncCall(
+          llm,
+          'Use a tool call to update the user info to name Theo',
+          fncCtx,
+        );
+        const calls = stream.executeFunctions();
+        stream.close();
+
+        expect(calls.length).toStrictEqual(1);
+        expect(calls[0]!.params.name).toStrictEqual('Theo');
+        expect(calls[0]!.params.email).toBeUndefined();
+        expect(calls[0]!.params.address).toBeUndefined();
+      });
+    });
+  });
+};
+
+const requestFncCall = async (
+  llm: llmlib.LLM,
+  text: string,
+  fncCtx: llmlib.FunctionContext,
+  temperature: number | undefined = undefined,
+  parallelToolCalls: boolean | undefined = undefined,
+) => {
+  const stream = llm.chat({
+    chatCtx: new llmlib.ChatContext()
+      .append({
+        text: 'You are an helpful assistant. Follow the instructions provided by the user. You can use multiple tool calls at once.',
+        role: llmlib.ChatRole.SYSTEM,
+      })
+      .append({ text, role: llmlib.ChatRole.USER }),
+    fncCtx,
+    temperature,
+    parallelToolCalls,
+  });
+
+  for await (const _ of stream) {
+    _;
+  }
+  return stream;
+};
diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml
index fea9bed69..8989c3088 100644
--- a/pnpm-lock.yaml
+++ b/pnpm-lock.yaml
@@ -296,6 +296,9 @@ importers:
       vitest:
         specifier: ^1.6.0
         version: 1.6.0(@types/node@22.5.5)
+      zod:
+        specifier: ^3.23.8
+        version: 3.23.8
     devDependencies:
       '@livekit/agents':
         specifier: workspace:^x
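The `plugins/openai/src/llm.ts` hunk above moves the "new tool id" check ahead of the buffer update, so a completed tool call is built from its own name and arguments before the next call's first delta overwrites them; with the old ordering the buffered id had already been replaced when the check ran, so it never fired and parallel calls were dropped. A minimal standalone sketch of that buffering pattern follows — `ToolDelta`, `BuiltCall`, and `ToolCallAccumulator` are hypothetical names for illustration, not the plugin's actual types:

```ts
// Sketch only: hypothetical types, not the plugin's real interfaces.
interface ToolDelta {
  id?: string;        // present on the first chunk of each tool call
  name?: string;      // present on the first chunk of each tool call
  arguments?: string; // streamed in fragments across chunks
}

interface BuiltCall {
  id: string;
  name: string;
  rawArguments: string;
}

class ToolCallAccumulator {
  #id?: string;
  #name?: string;
  #rawArguments = '';

  // Feed one streamed delta; returns the previously buffered call once the
  // stream moves on to a different tool-call id.
  push(delta: ToolDelta): BuiltCall | undefined {
    // Build the buffered call *before* its state is overwritten below; with
    // the check placed after the update, the ids always match and nothing
    // is ever emitted (the bug this patch fixes).
    let completed: BuiltCall | undefined;
    if (this.#id && delta.id && delta.id !== this.#id) {
      completed = { id: this.#id, name: this.#name!, rawArguments: this.#rawArguments };
    }

    if (delta.name) {
      // First chunk of a (new) tool call: reset the buffer.
      this.#id = delta.id;
      this.#name = delta.name;
      this.#rawArguments = delta.arguments ?? '';
    } else if (delta.arguments) {
      // Continuation chunk: append argument fragments.
      this.#rawArguments += delta.arguments;
    }

    return completed;
  }

  // Call once the stream ends to surface the last buffered tool call.
  flush(): BuiltCall | undefined {
    if (!this.#id) return undefined;
    const last = { id: this.#id, name: this.#name!, rawArguments: this.#rawArguments };
    this.#id = undefined;
    this.#name = undefined;
    this.#rawArguments = '';
    return last;
  }
}

// Two parallel tool calls arriving interleaved as streamed deltas:
const acc = new ToolCallAccumulator();
const deltas: ToolDelta[] = [
  { id: 'call_1', name: 'getWeather', arguments: '{"location":' },
  { arguments: '"San Francisco, CA"}' },
  { id: 'call_2', name: 'getWeather', arguments: '{"location":"Paris"}' },
];
for (const delta of deltas) {
  const done = acc.push(delta);
  if (done) console.log('completed:', done.name, done.rawArguments); // fires for call_1
}
console.log('completed:', acc.flush()); // call_2
```

The stream parser in the patch follows the same order: flush on id change, then start accumulating the new call, and only then return the flushed chunk to the caller.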