Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/calm-ants-build.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@livekit/agents-plugin-openai": patch
---

fix multiple function calls not firing
10 changes: 10 additions & 0 deletions plugins/openai/src/llm.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// SPDX-FileCopyrightText: 2024 LiveKit, Inc.
//
// SPDX-License-Identifier: Apache-2.0
import { llm } from '@livekit/agents-plugins-test';
import { describe } from 'vitest';
import { LLM } from './llm.js';

describe('OpenAI', async () => {
await llm(new LLM());
});
9 changes: 7 additions & 2 deletions plugins/openai/src/llm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,11 @@ export class LLMStream extends llm.LLMStream {
continue; // oai may add other tools in the future
}

let callChunk: llm.ChatChunk | undefined;
if (this.#toolCallId && tool.id && tool.id !== this.#toolCallId) {
callChunk = this.#tryBuildFunction(id, choice);
}

if (tool.function.name) {
this.#toolCallId = tool.id;
this.#fncName = tool.function.name;
Expand All @@ -509,8 +514,8 @@ export class LLMStream extends llm.LLMStream {
this.#fncRawArguments += tool.function.arguments;
}

if (this.#toolCallId && tool.id && tool.id !== this.#toolCallId) {
return this.#tryBuildFunction(id, choice);
if (callChunk) {
return callChunk;
}
}
}
Expand Down
9 changes: 5 additions & 4 deletions plugins/test/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,16 @@
"lint": "eslint -f unix \"src/**/*.{ts,js}\""
},
"devDependencies": {
"@types/node": "^22.5.5",
"typescript": "^5.0.0",
"@livekit/agents": "workspace:^x",
"@livekit/rtc-node": "^0.12.1",
"tsup": "^8.3.5"
"@types/node": "^22.5.5",
"tsup": "^8.3.5",
"typescript": "^5.0.0"
},
"dependencies": {
"fastest-levenshtein": "^1.0.16",
"vitest": "^1.6.0",
"fastest-levenshtein": "^1.0.16"
"zod": "^3.23.8"
},
"peerDependencies": {
"@livekit/agents": "workspace:^x",
Expand Down
1 change: 1 addition & 0 deletions plugins/test/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@
//
// SPDX-License-Identifier: Apache-2.0
export { tts } from './tts.js';
export { llm } from './llm.js';
export { stt } from './stt.js';
171 changes: 171 additions & 0 deletions plugins/test/src/llm.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
// SPDX-FileCopyrightText: 2024 LiveKit, Inc.
//
// SPDX-License-Identifier: Apache-2.0
import { initializeLogger, llm as llmlib } from '@livekit/agents';
import { describe, expect, it } from 'vitest';
import { z } from 'zod';

const fncCtx: llmlib.FunctionContext = {
getWeather: {
description: 'Get the current weather in a given location',
parameters: z.object({
location: z.string().describe('The city and state, e.g. San Francisco, CA'),
unit: z.enum(['celsius', 'fahrenheit']).describe('The temperature unit to use'),
}),
execute: async () => {},
},
playMusic: {
description: 'Play music',
parameters: z.object({
name: z.string().describe('The artist and name of the song'),
}),
execute: async () => {},
},
toggleLight: {
description: 'Turn on/off the lights in a room',
parameters: z.object({
name: z.string().describe('The room to control'),
on: z.boolean().describe('Whether to turn light on or off'),
}),
execute: async () => {
await new Promise((resolve) => setTimeout(resolve, 60_000));
},
},
selectCurrencies: {
description: 'Currencies of a specific area',
parameters: z.object({
currencies: z
.array(z.enum(['USD', 'EUR', 'GBP', 'JPY', 'SEK']))
.describe('The currencies to select'),
}),
execute: async () => {},
},
updateUserInfo: {
description: 'Update user info',
parameters: z.object({
email: z.string().optional().describe("User's email address"),
name: z.string().optional().describe("User's name"),
address: z.string().optional().describe("User's home address"),
}),
execute: async () => {},
},
simulateFailure: {
description: 'Simulate a failure',
parameters: z.object({}),
execute: async () => {
throw new Error('Simulated failure');
},
},
};

export const llm = async (llm: llmlib.LLM) => {
initializeLogger({ pretty: false });
describe('LLM', async () => {
it('should properly respond to chat', async () => {
const chatCtx = new llmlib.ChatContext().append({
text: 'You are an assistant at a drive-thru restaurant "Live-Burger". Ask the customer what they would like to order.',
role: llmlib.ChatRole.SYSTEM,
});

const stream = llm.chat({ chatCtx });
let text = '';
for await (const chunk of stream) {
if (!chunk.choices.length) continue;
text += chunk.choices[0]?.delta.content;
}

expect(text.length).toBeGreaterThan(0);
});
describe('function calling', async () => {
it('should handle function calling', async () => {
const stream = await requestFncCall(
llm,
"What's the weather in San Francisco and what's the weather in Paris?",
fncCtx,
);
const calls = stream.executeFunctions();
await Promise.all(calls.map((call) => call.task));
stream.close();

expect(calls.length).toStrictEqual(2);
});
it('should handle exceptions', async () => {
const stream = await requestFncCall(llm, 'Call the failing function', fncCtx);
const calls = stream.executeFunctions();
stream.close();

expect(calls.length).toStrictEqual(1);
const task = await calls[0]!.task!;
expect(task.error).toBeInstanceOf(Error);
expect(task.error.message).toStrictEqual('Simulated failure');
});
it('should handle arrays', async () => {
const stream = await requestFncCall(
llm,
'Can you select all currencies in Europe at once from given choices?',
fncCtx,
0.2,
);
const calls = stream.executeFunctions();
stream.close();

expect(calls.length).toStrictEqual(1);
expect(calls[0]!.params.currencies.length).toStrictEqual(3);
expect(calls[0]!.params.currencies).toContain('EUR');
expect(calls[0]!.params.currencies).toContain('GBP');
expect(calls[0]!.params.currencies).toContain('SEK');
});
it('should handle enums', async () => {
const stream = await requestFncCall(
llm,
"What's the weather in San Francisco, in Celsius?",
fncCtx,
);
const calls = stream.executeFunctions();
stream.close();

expect(calls.length).toStrictEqual(1);
expect(calls[0]!.params.unit).toStrictEqual('celsius');
});
it('should handle optional arguments', async () => {
const stream = await requestFncCall(
llm,
'Use a tool call to update the user info to name Theo',
fncCtx,
);
const calls = stream.executeFunctions();
stream.close();

expect(calls.length).toStrictEqual(1);
expect(calls[0]!.params.name).toStrictEqual('Theo');
expect(calls[0]!.params.email).toBeUndefined();
expect(calls[0]!.params.address).toBeUndefined();
});
});
});
};

const requestFncCall = async (
llm: llmlib.LLM,
text: string,
fncCtx: llmlib.FunctionContext,
temperature: number | undefined = undefined,
parallelToolCalls: boolean | undefined = undefined,
) => {
const stream = llm.chat({
chatCtx: new llmlib.ChatContext()
.append({
text: 'You are an helpful assistant. Follow the instructions provided by the user. You can use multiple tool calls at once.',
role: llmlib.ChatRole.SYSTEM,
})
.append({ text, role: llmlib.ChatRole.USER }),
fncCtx,
temperature,
parallelToolCalls,
});

for await (const _ of stream) {
_;
}
return stream;
};
3 changes: 3 additions & 0 deletions pnpm-lock.yaml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading