Merged
5 changes: 5 additions & 0 deletions .changeset/hot-otters-pay.md
@@ -0,0 +1,5 @@
+---
+"@livekit/agents": minor
+---
+
+support nested speech handles in pipeline agent
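Note: "nested speech handles" means speech scheduled while another speech is still playing (most importantly, `agent.say(...)` called from inside an AI function) is queued as a child of the playing speech and played in order, rather than racing the main playout queue. A rough usage sketch; the weather lookup is hypothetical, and the `agent` accessor on `AgentCallContext` is assumed from the existing class, not shown in this diff:

```ts
import { pipeline } from '@livekit/agents';

// Inside an AI function's execute() callback, while the agent is speaking:
const callCtx = pipeline.AgentCallContext.getCurrent();
// say() now returns its SpeechHandle; because a speech is already playing,
// this utterance is queued as nested speech under it, and its text is added
// to the call context so the follow-up LLM turn can see it.
const handle = await callCtx.agent.say('Give me a second to look that up…');
```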
2 changes: 1 addition & 1 deletion agents/src/pipeline/index.ts
@@ -7,9 +7,9 @@ export {
   type BeforeTTSCallback,
   type BeforeLLMCallback,
   type VPACallbacks,
-  type AgentCallContext,
   type AgentTranscriptionOptions,
   type VPAOptions,
   VPAEvent,
   VoicePipelineAgent,
+  AgentCallContext,
 } from './pipeline_agent.js';
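Note: `AgentCallContext` was previously re-exported as a type only. Callers now need the class value at runtime (for `AgentCallContext.getCurrent()` inside function calls), so it moves to the value exports. A sketch, assuming the package still re-exports this module under the `pipeline` namespace:

```ts
import { pipeline } from '@livekit/agents';

// Works after this change; with the old `type`-only export the class was
// erased at compile time and unavailable at runtime.
const ctx = pipeline.AgentCallContext.getCurrent();
```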
201 changes: 139 additions & 62 deletions agents/src/pipeline/pipeline_agent.ts
@@ -82,6 +82,7 @@ export class AgentCallContext {
   #agent: VoicePipelineAgent;
   #llmStream: LLMStream;
   #metadata = new Map<string, any>();
+  #extraChatMessages: ChatMessage[] = [];
   static #current: AgentCallContext;
 
   constructor(agent: VoicePipelineAgent, llmStream: LLMStream) {
@@ -109,6 +110,14 @@ export class AgentCallContext {
   get llmStream(): LLMStream {
     return this.#llmStream;
   }
+
+  get extraChatMessages() {
+    return this.#extraChatMessages;
+  }
+
+  addExtraChatMessage(message: ChatMessage) {
+    this.#extraChatMessages.push(message);
+  }
 }
 
 const defaultBeforeLLMCallback: BeforeLLMCallback = (
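Note: the `#extraChatMessages` accumulator is how speech spoken during a function call gets injected into the follow-up LLM request (see the `chatCtx.messages.push(...AgentCallContext.getCurrent().extraChatMessages)` call in pipeline_agent.ts below). A minimal sketch of the accessor pair, using the same `ChatMessage.create` call as this diff:

```ts
const ctx = AgentCallContext.getCurrent();
ctx.addExtraChatMessage(ChatMessage.create({ text: 'One moment…', role: ChatRole.ASSISTANT }));
console.log(ctx.extraChatMessages.length); // 1
```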
@@ -175,7 +184,7 @@ export interface VPAOptions {
   interruptMinWords: number;
   /** Delay to wait before considering the user speech done. */
   minEndpointingDelay: number;
-  maxRecursiveFncCalls: number;
+  maxNestedFncCalls: number;
   /* Whether to preemptively synthesize responses. */
   preemptiveSynthesis: boolean;
   /*
@@ -205,7 +214,7 @@ const defaultVPAOptions: VPAOptions = {
   interruptSpeechDuration: 50,
   interruptMinWords: 0,
   minEndpointingDelay: 500,
-  maxRecursiveFncCalls: 1,
+  maxNestedFncCalls: 1,
   preemptiveSynthesis: false,
   beforeLLMCallback: defaultBeforeLLMCallback,
   beforeTTSCallback: defaultBeforeTTSCallback,
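Note: the rename from `maxRecursiveFncCalls` to `maxNestedFncCalls` tracks the new mechanism: instead of looping recursively inside one speech, each tool-triggered answer becomes a nested speech one `fncNestedDepth` level deeper, and the chain stops at this limit. The rename is breaking for anyone passing the old key; a sketch, assuming the usual (vad, stt, llm, tts, options) constructor:

```ts
const agent = new VoicePipelineAgent(vad, stt, llm, tts, {
  maxNestedFncCalls: 2, // was: maxRecursiveFncCalls
});
```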
@@ -368,12 +377,51 @@ export class VoicePipelineAgent extends (EventEmitter as new () => TypedEmitter<
     source: string | LLMStream | AsyncIterable<string>,
     allowInterruptions = true,
     addToChatCtx = true,
-  ) {
+  ): Promise<SpeechHandle> {
     await this.#trackPublishedFut.await;
+
+    let callContext: AgentCallContext | undefined;
+    let fncSource: string | AsyncIterable<string> | undefined;
+    if (addToChatCtx) {
+      callContext = AgentCallContext.getCurrent();
+      if (source instanceof LLMStream) {
+        this.#logger.warn('LLMStream will be ignored for function call chat context');
+      } else if (typeof source === 'string') {
+        fncSource = source;
+      } else {
+        fncSource = source;
+        source = new AsyncIterableQueue<string>();
+      }
+    }
+
     const newHandle = SpeechHandle.createAssistantSpeech(allowInterruptions, addToChatCtx);
     const synthesisHandle = this.#synthesizeAgentSpeech(newHandle.id, source);
     newHandle.initialize(source, synthesisHandle);
-    this.#addSpeechForPlayout(newHandle);
+
+    if (this.#playingSpeech && !this.#playingSpeech.nestedSpeechFinished) {
+      this.#playingSpeech.addNestedSpeech(newHandle);
+    } else {
+      this.#addSpeechForPlayout(newHandle);
+    }
+
+    if (callContext && fncSource) {
+      let text: string;
+      if (typeof source === 'string') {
+        text = fncSource as string;
+      } else {
+        text = '';
+        for await (const chunk of fncSource) {
+          (source as AsyncIterableQueue<string>).put(chunk);
+          text += chunk;
+        }
+        (source as AsyncIterableQueue<string>).close();
+      }
+
+      callContext.addExtraChatMessage(ChatMessage.create({ text, role: ChatRole.ASSISTANT }));
+      this.#logger.child({ text }).debug('added speech to function call chat context');
+    }
+
+    return newHandle;
   }
 
   #updateState(state: AgentState, delay = 0) {
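Note: `say()` now resolves to the `SpeechHandle` it creates. Two behaviors follow from the body above: with an `AsyncIterable` source during a function call, the text is teed through an `AsyncIterableQueue` so the same chunks feed both synthesis and the call context, which means the returned promise only settles once the iterable is drained; and while another speech is playing with nested playout still open, the new handle is parked on that speech rather than on the main queue. Minimal sketch of the new return value:

```ts
// The returned handle can be inspected after scheduling:
const handle = await agent.say('Hello there');
console.log(handle.id, handle.interrupted);
```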
@@ -646,82 +694,109 @@ export class VoicePipelineAgent extends (EventEmitter as new () => TypedEmitter<
     commitUserQuestionIfNeeded();
 
     // TODO(nbsp): what goes here
-    let collectedText = '';
+    const collectedText = '';
     const isUsingTools = handle.source instanceof LLMStream && !!handle.source.functionCalls.length;
-    const extraToolsMessages = []; // additional messages from the functions to add to the context
-    let interrupted = handle.interrupted;
+    const interrupted = handle.interrupted;
 
-    // if the answer is using tools, execute the functions and automatically generate
-    // a response to the user question from the returned values
-    if (isUsingTools && !interrupted) {
+    const executeFunctionCalls = async () => {
+      // if the answer is using tools, execute the functions and automatically generate
+      // a response to the user question from the returned values
+      if (!isUsingTools || interrupted) return;
+
+      if (handle.fncNestedDepth >= this.#opts.maxNestedFncCalls) {
+        this.#logger
+          .child({ speechId: handle.id, fncNestedDepth: handle.fncNestedDepth })
+          .warn('max function calls nested depth reached');
+        return;
+      }
+
       if (!userQuestion || !handle.userCommitted) {
         throw new Error('user speech should have been committed before using tools');
       }
       const llmStream = handle.source;
-      let newFunctionCalls = llmStream.functionCalls;
-
-      for (let i = 0; i < this.#opts.maxRecursiveFncCalls; i++) {
-        this.emit(VPAEvent.FUNCTION_CALLS_COLLECTED, newFunctionCalls);
-        const calledFuncs: FunctionCallInfo[] = [];
-        for (const func of newFunctionCalls) {
-          const task = func.func.execute(func.params).then(
-            (result) => ({ name: func.name, toolCallId: func.toolCallId, result }),
-            (error) => ({ name: func.name, toolCallId: func.toolCallId, error }),
-          );
-          calledFuncs.push({ ...func, task });
-          this.#logger
-            .child({ function: func.name, speechId: handle.id })
-            .debug('executing AI function');
-          try {
-            await task;
-          } catch {
-            this.#logger
-              .child({ function: func.name, speechId: handle.id })
-              .error('error executing AI function');
-          }
-        }
-
-        const toolCallsInfo = [];
-        const toolCallsResults = [];
-        for (const fnc of calledFuncs) {
-          // ignore the function calls that return void
-          const task = await fnc.task;
-          if (!task || task.result === undefined) continue;
-          toolCallsInfo.push(fnc);
-          toolCallsResults.push(ChatMessage.createToolFromFunctionResult(task));
-        }
-
-        if (!toolCallsInfo.length) break;
-
-        // generate an answer from the tool calls
-        extraToolsMessages.push(ChatMessage.createToolCalls(toolCallsInfo, collectedText));
-        extraToolsMessages.push(...toolCallsResults);
-
-        const chatCtx = handle.source.chatCtx.copy();
-        chatCtx.messages.push(...extraToolsMessages);
-
-        const answerLLMStream = this.llm.chat({
-          chatCtx,
-          fncCtx: this.fncCtx,
-        });
-        const answerSynthesis = this.#synthesizeAgentSpeech(handle.id, answerLLMStream);
-        // replace the synthesis handle with the new one to allow interruption
-        handle.synthesisHandle = answerSynthesis;
-        const playHandle = answerSynthesis.play();
-        await playHandle.join().await;
-
-        // TODO(nbsp): what text goes here
-        collectedText = '';
-        interrupted = answerSynthesis.interrupted;
-        newFunctionCalls = answerLLMStream.functionCalls;
-
-        this.emit(VPAEvent.FUNCTION_CALLS_FINISHED, calledFuncs);
-        if (!newFunctionCalls) break;
-      }
-    }
+      const newFunctionCalls = llmStream.functionCalls;
+
+      new AgentCallContext(this, llmStream);
+
+      this.emit(VPAEvent.FUNCTION_CALLS_COLLECTED, newFunctionCalls);
+      const calledFuncs: FunctionCallInfo[] = [];
+      for (const func of newFunctionCalls) {
+        const task = func.func.execute(func.params).then(
+          (result) => ({ name: func.name, toolCallId: func.toolCallId, result }),
+          (error) => ({ name: func.name, toolCallId: func.toolCallId, error }),
+        );
+        calledFuncs.push({ ...func, task });
+        this.#logger
+          .child({ function: func.name, speechId: handle.id })
+          .debug('executing AI function');
+        try {
+          await task;
+        } catch {
+          this.#logger
+            .child({ function: func.name, speechId: handle.id })
+            .error('error executing AI function');
+        }
+      }
+
+      const toolCallsInfo = [];
+      const toolCallsResults = [];
+      for (const fnc of calledFuncs) {
+        // ignore the function calls that return void
+        const task = await fnc.task;
+        if (!task || task.result === undefined) continue;
+        toolCallsInfo.push(fnc);
+        toolCallsResults.push(ChatMessage.createToolFromFunctionResult(task));
+      }
+
+      if (!toolCallsInfo.length) return;
+
+      // generate an answer from the tool calls
+      const extraToolsMessages = [ChatMessage.createToolCalls(toolCallsInfo, collectedText)];
+      extraToolsMessages.push(...toolCallsResults);
+
+      // create a nested speech handle
+      const newSpeechHandle = SpeechHandle.createToolSpeech(
+        handle.allowInterruptions,
+        handle.addToChatCtx,
+        handle.fncNestedDepth + 1,
+        extraToolsMessages,
+      );
+
+      // synthesize the tool speech with the chat ctx from llmStream
+      const chatCtx = handle.source.chatCtx.copy();
+      chatCtx.messages.push(...extraToolsMessages);
+      chatCtx.messages.push(...AgentCallContext.getCurrent().extraChatMessages);
+
+      const answerLLMStream = this.llm.chat({
+        chatCtx,
+        fncCtx: this.fncCtx,
+      });
+      const answerSynthesis = this.#synthesizeAgentSpeech(newSpeechHandle.id, answerLLMStream);
+      newSpeechHandle.initialize(answerLLMStream, answerSynthesis);
+      handle.addNestedSpeech(newSpeechHandle);
+
+      this.emit(VPAEvent.FUNCTION_CALLS_FINISHED, calledFuncs);
+    };
+
+    const task = executeFunctionCalls().then(() => {
+      handle.markNestedSpeechFinished();
+    });
+    while (!handle.nestedSpeechFinished) {
+      const changed = handle.nestedSpeechChanged();
+      await Promise.race([changed, task]);
+      while (handle.nestedSpeechHandles.length) {
+        const speech = handle.nestedSpeechHandles[0]!;
+        this.#playingSpeech = speech;
+        await this.#playSpeech(speech);
+        handle.nestedSpeechHandles.shift();
+        this.#playingSpeech = handle;
+      }
+    }
 
     if (handle.addToChatCtx && (!userQuestion || handle.userCommitted)) {
-      this.chatCtx.messages.push(...extraToolsMessages);
+      if (handle.extraToolsMessages) {
+        this.chatCtx.messages.push(...handle.extraToolsMessages);
+      }
       if (interrupted) {
         collectedText + '…';
       }
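Note: this rewrite is the heart of the PR. Function execution now runs as a background task (`executeFunctionCalls`), while the loop underneath drains nested handles FIFO, pointing `#playingSpeech` at each child during playout and restoring it afterwards. A reduced model of the scheduling, illustrative only (the real `SpeechHandle` also carries synthesis state):

```ts
// Illustrative model of nested playout, not the library API.
interface Speech {
  id: string;
  nested: Speech[];
}

async function playWithNested(speech: Speech, play: (s: Speech) => Promise<void>): Promise<void> {
  await play(speech);
  // Children queued during playback (e.g. say() from a tool call) run FIFO;
  // each child may itself enqueue deeper speech, bounded by maxNestedFncCalls.
  while (speech.nested.length) {
    const next = speech.nested.shift()!;
    await playWithNested(next, play);
  }
}
```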
@@ -743,6 +818,8 @@ export class VoicePipelineAgent extends (EventEmitter as new () => TypedEmitter<
         speechId: handle.id,
       })
       .debug('committed agent speech');
+
+    handle.setDone();
   }
 }