Skip to content

Commit 5e5a221

Browse files
committed
test(llm): add OpenAI LLM tests (#203)
1 parent 6d7c2c4 commit 5e5a221

File tree

7 files changed

+202
-6
lines changed

7 files changed

+202
-6
lines changed

.changeset/calm-ants-build.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
"@livekit/agents-plugin-openai": patch
3+
---
4+
5+
fix multiple function calls not firing

plugins/openai/src/llm.test.ts

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
// SPDX-FileCopyrightText: 2024 LiveKit, Inc.
2+
//
3+
// SPDX-License-Identifier: Apache-2.0
4+
import { llm } from '@livekit/agents-plugins-test';
5+
import { describe } from 'vitest';
6+
import { LLM } from './llm.js';
7+
8+
describe('OpenAI', async () => {
9+
await llm(new LLM());
10+
});

plugins/openai/src/llm.ts

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -498,6 +498,11 @@ export class LLMStream extends llm.LLMStream {
498498
continue; // oai may add other tools in the future
499499
}
500500

501+
let callChunk: llm.ChatChunk | undefined;
502+
if (this.#toolCallId && tool.id && tool.id !== this.#toolCallId) {
503+
callChunk = this.#tryBuildFunction(id, choice);
504+
}
505+
501506
if (tool.function.name) {
502507
this.#toolCallId = tool.id;
503508
this.#fncName = tool.function.name;
@@ -506,8 +511,8 @@ export class LLMStream extends llm.LLMStream {
506511
this.#fncRawArguments += tool.function.arguments;
507512
}
508513

509-
if (this.#toolCallId && tool.id && tool.id !== this.#toolCallId) {
510-
return this.#tryBuildFunction(id, choice);
514+
if (callChunk) {
515+
return callChunk;
511516
}
512517
}
513518
}

plugins/test/package.json

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,16 @@
2727
"lint": "eslint -f unix \"src/**/*.{ts,js}\""
2828
},
2929
"devDependencies": {
30-
"@types/node": "^22.5.5",
31-
"typescript": "^5.0.0",
3230
"@livekit/agents": "workspace:^x",
3331
"@livekit/rtc-node": "^0.12.1",
34-
"tsup": "^8.3.5"
32+
"@types/node": "^22.5.5",
33+
"tsup": "^8.3.5",
34+
"typescript": "^5.0.0"
3535
},
3636
"dependencies": {
37+
"fastest-levenshtein": "^1.0.16",
3738
"vitest": "^1.6.0",
38-
"fastest-levenshtein": "^1.0.16"
39+
"zod": "^3.23.8"
3940
},
4041
"peerDependencies": {
4142
"@livekit/agents": "workspace:^x",

plugins/test/src/index.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,5 @@
22
//
33
// SPDX-License-Identifier: Apache-2.0
44
export { tts } from './tts.js';
5+
export { llm } from './llm.js';
56
export { stt } from './stt.js';

plugins/test/src/llm.ts

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
// SPDX-FileCopyrightText: 2024 LiveKit, Inc.
2+
//
3+
// SPDX-License-Identifier: Apache-2.0
4+
import { initializeLogger, llm as llmlib } from '@livekit/agents';
5+
import { describe, expect, it } from 'vitest';
6+
import { z } from 'zod';
7+
8+
const fncCtx: llmlib.FunctionContext = {
9+
getWeather: {
10+
description: 'Get the current weather in a given location',
11+
parameters: z.object({
12+
location: z.string().describe('The city and state, e.g. San Francisco, CA'),
13+
unit: z.enum(['celsius', 'fahrenheit']).describe('The temperature unit to use'),
14+
}),
15+
execute: async () => {},
16+
},
17+
playMusic: {
18+
description: 'Play music',
19+
parameters: z.object({
20+
name: z.string().describe('The artist and name of the song'),
21+
}),
22+
execute: async () => {},
23+
},
24+
toggleLight: {
25+
description: 'Turn on/off the lights in a room',
26+
parameters: z.object({
27+
name: z.string().describe('The room to control'),
28+
on: z.boolean().describe('Whether to turn light on or off'),
29+
}),
30+
execute: async () => {
31+
await new Promise((resolve) => setTimeout(resolve, 60_000));
32+
},
33+
},
34+
selectCurrencies: {
35+
description: 'Currencies of a specific area',
36+
parameters: z.object({
37+
currencies: z
38+
.array(z.enum(['USD', 'EUR', 'GBP', 'JPY', 'SEK']))
39+
.describe('The currencies to select'),
40+
}),
41+
execute: async () => {},
42+
},
43+
updateUserInfo: {
44+
description: 'Update user info',
45+
parameters: z.object({
46+
email: z.string().optional().describe("User's email address"),
47+
name: z.string().optional().describe("User's name"),
48+
address: z.string().optional().describe("User's home address"),
49+
}),
50+
execute: async () => {},
51+
},
52+
simulateFailure: {
53+
description: 'Simulate a failure',
54+
parameters: z.object({}),
55+
execute: async () => {
56+
throw new Error('Simulated failure');
57+
},
58+
},
59+
};
60+
61+
export const llm = async (llm: llmlib.LLM) => {
62+
initializeLogger({ pretty: false });
63+
describe('LLM', async () => {
64+
it('should properly respond to chat', async () => {
65+
const chatCtx = new llmlib.ChatContext().append({
66+
text: 'You are an assistant at a drive-thru restaurant "Live-Burger". Ask the customer what they would like to order.',
67+
role: llmlib.ChatRole.SYSTEM,
68+
});
69+
70+
const stream = llm.chat({ chatCtx });
71+
let text = '';
72+
for await (const chunk of stream) {
73+
if (!chunk.choices.length) continue;
74+
text += chunk.choices[0]?.delta.content;
75+
}
76+
77+
expect(text.length).toBeGreaterThan(0);
78+
});
79+
describe('function calling', async () => {
80+
it('should handle function calling', async () => {
81+
const stream = await requestFncCall(
82+
llm,
83+
"What's the weather in San Francisco and what's the weather in Paris?",
84+
fncCtx,
85+
);
86+
const calls = stream.executeFunctions();
87+
await Promise.all(calls.map((call) => call.task));
88+
stream.close();
89+
90+
expect(calls.length).toStrictEqual(2);
91+
});
92+
it('should handle exceptions', async () => {
93+
const stream = await requestFncCall(llm, 'Call the failing function', fncCtx);
94+
const calls = stream.executeFunctions();
95+
stream.close();
96+
97+
expect(calls.length).toStrictEqual(1);
98+
const task = await calls[0]!.task!;
99+
expect(task.error).toBeInstanceOf(Error);
100+
expect(task.error.message).toStrictEqual('Simulated failure');
101+
});
102+
it('should handle arrays', async () => {
103+
const stream = await requestFncCall(
104+
llm,
105+
'Can you select all currencies in Europe at once from given choices?',
106+
fncCtx,
107+
0.2,
108+
);
109+
const calls = stream.executeFunctions();
110+
stream.close();
111+
112+
expect(calls.length).toStrictEqual(1);
113+
expect(calls[0]!.params.currencies.length).toStrictEqual(3);
114+
expect(calls[0]!.params.currencies).toContain('EUR');
115+
expect(calls[0]!.params.currencies).toContain('GBP');
116+
expect(calls[0]!.params.currencies).toContain('SEK');
117+
});
118+
it('should handle enums', async () => {
119+
const stream = await requestFncCall(
120+
llm,
121+
"What's the weather in San Francisco, in Celsius?",
122+
fncCtx,
123+
);
124+
const calls = stream.executeFunctions();
125+
stream.close();
126+
127+
expect(calls.length).toStrictEqual(1);
128+
expect(calls[0]!.params.unit).toStrictEqual('celsius');
129+
});
130+
it('should handle optional arguments', async () => {
131+
const stream = await requestFncCall(
132+
llm,
133+
'Use a tool call to update the user info to name Theo',
134+
fncCtx,
135+
);
136+
const calls = stream.executeFunctions();
137+
stream.close();
138+
139+
expect(calls.length).toStrictEqual(1);
140+
expect(calls[0]!.params.name).toStrictEqual('Theo');
141+
expect(calls[0]!.params.email).toBeUndefined();
142+
expect(calls[0]!.params.address).toBeUndefined();
143+
});
144+
});
145+
});
146+
};
147+
148+
const requestFncCall = async (
149+
llm: llmlib.LLM,
150+
text: string,
151+
fncCtx: llmlib.FunctionContext,
152+
temperature: number | undefined = undefined,
153+
parallelToolCalls: boolean | undefined = undefined,
154+
) => {
155+
const stream = llm.chat({
156+
chatCtx: new llmlib.ChatContext()
157+
.append({
158+
text: 'You are an helpful assistant. Follow the instructions provided by the user. You can use multiple tool calls at once.',
159+
role: llmlib.ChatRole.SYSTEM,
160+
})
161+
.append({ text, role: llmlib.ChatRole.USER }),
162+
fncCtx,
163+
temperature,
164+
parallelToolCalls,
165+
});
166+
167+
for await (const _ of stream) {
168+
_;
169+
}
170+
return stream;
171+
};

pnpm-lock.yaml

Lines changed: 3 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments (0)