Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,16 @@ public int MaximumConsecutiveErrorsPerRequest
set => _maximumConsecutiveErrorsPerRequest = Throw.IfLessThan(value, 0);
}

/// <summary>Gets or sets a collection of additional tools the client is able to invoke.</summary>
/// <remarks>
/// These will not impact the requests sent by the <see cref="FunctionInvokingChatClient"/>, which will pass through the
/// <see cref="ChatOptions.Tools" /> unmodified. However, if the inner client requests the invocation of a tool
/// that was not in <see cref="ChatOptions.Tools" />, this <see cref="AdditionalTools"/> collection will also be consulted
/// to look for a corresponding tool to invoke. This is useful when the service may have been pre-configured to be aware
/// of certain tools that aren't also sent on each individual request.
/// </remarks>
public IList<AITool>? AdditionalTools { get; set; }

/// <summary>Gets or sets a delegate used to invoke <see cref="AIFunction"/> instances.</summary>
/// <remarks>
/// By default, the protected <see cref="InvokeFunctionAsync"/> method is called for each <see cref="AIFunction"/> to be invoked,
Expand Down Expand Up @@ -250,7 +260,7 @@ public override async Task<ChatResponse> GetResponseAsync(

// Any function call work to do? If yes, ensure we're tracking that work in functionCallContents.
bool requiresFunctionInvocation =
options?.Tools is { Count: > 0 } &&
(options?.Tools is { Count: > 0 } || AdditionalTools is { Count: > 0 }) &&
iteration < MaximumIterationsPerRequest &&
CopyFunctionCalls(response.Messages, ref functionCallContents);

Expand Down Expand Up @@ -288,7 +298,7 @@ public override async Task<ChatResponse> GetResponseAsync(

// Add the responses from the function calls into the augmented history and also into the tracked
// list of response messages.
var modeAndMessages = await ProcessFunctionCallsAsync(augmentedHistory, options!, functionCallContents!, iteration, consecutiveErrorCount, isStreaming: false, cancellationToken);
var modeAndMessages = await ProcessFunctionCallsAsync(augmentedHistory, options, functionCallContents!, iteration, consecutiveErrorCount, isStreaming: false, cancellationToken);
responseMessages.AddRange(modeAndMessages.MessagesAdded);
consecutiveErrorCount = modeAndMessages.NewConsecutiveErrorCount;

Expand All @@ -297,7 +307,7 @@ public override async Task<ChatResponse> GetResponseAsync(
break;
}

UpdateOptionsForNextIteration(ref options!, response.ConversationId);
UpdateOptionsForNextIteration(ref options, response.ConversationId);
}

Debug.Assert(responseMessages is not null, "Expected to only be here if we have response messages.");
Expand Down Expand Up @@ -367,7 +377,7 @@ public override async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseA

// If there are no tools to call, or for any other reason we should stop, return the response.
if (functionCallContents is not { Count: > 0 } ||
options?.Tools is not { Count: > 0 } ||
(options?.Tools is not { Count: > 0 } && AdditionalTools is not { Count: > 0 }) ||
iteration >= _maximumIterationsPerRequest)
{
break;
Expand Down Expand Up @@ -535,9 +545,16 @@ private static bool CopyFunctionCalls(
return any;
}

private static void UpdateOptionsForNextIteration(ref ChatOptions options, string? conversationId)
private static void UpdateOptionsForNextIteration(ref ChatOptions? options, string? conversationId)
{
if (options.ToolMode is RequiredChatToolMode)
if (options is null)
{
if (conversationId is not null)
{
options = new() { ConversationId = conversationId };
}
}
else if (options.ToolMode is RequiredChatToolMode)
{
// We have to reset the tool mode to be non-required after the first iteration,
// as otherwise we'll be in an infinite loop.
Expand Down Expand Up @@ -566,7 +583,7 @@ private static void UpdateOptionsForNextIteration(ref ChatOptions options, strin
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests.</param>
/// <returns>A value indicating how the caller should proceed.</returns>
private async Task<(bool ShouldTerminate, int NewConsecutiveErrorCount, IList<ChatMessage> MessagesAdded)> ProcessFunctionCallsAsync(
List<ChatMessage> messages, ChatOptions options, List<FunctionCallContent> functionCallContents, int iteration, int consecutiveErrorCount,
List<ChatMessage> messages, ChatOptions? options, List<FunctionCallContent> functionCallContents, int iteration, int consecutiveErrorCount,
bool isStreaming, CancellationToken cancellationToken)
{
// We must add a response for every tool call, regardless of whether we successfully executed it or not.
Expand Down Expand Up @@ -695,13 +712,13 @@ private void ThrowIfNoFunctionResultsAdded(IList<ChatMessage>? messages)
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests.</param>
/// <returns>A value indicating how the caller should proceed.</returns>
private async Task<FunctionInvocationResult> ProcessFunctionCallAsync(
List<ChatMessage> messages, ChatOptions options, List<FunctionCallContent> callContents,
List<ChatMessage> messages, ChatOptions? options, List<FunctionCallContent> callContents,
int iteration, int functionCallIndex, bool captureExceptions, bool isStreaming, CancellationToken cancellationToken)
{
var callContent = callContents[functionCallIndex];

// Look up the AIFunction for the function call. If the requested function isn't available, send back an error.
AIFunction? aiFunction = options.Tools!.OfType<AIFunction>().FirstOrDefault(t => t.Name == callContent.Name);
AIFunction? aiFunction = FindAIFunction(options?.Tools, callContent.Name) ?? FindAIFunction(AdditionalTools, callContent.Name);
if (aiFunction is null)
{
return new(terminate: false, FunctionInvocationStatus.NotFound, callContent, result: null, exception: null);
Expand Down Expand Up @@ -746,6 +763,23 @@ private async Task<FunctionInvocationResult> ProcessFunctionCallAsync(
callContent,
result,
exception: null);

static AIFunction? FindAIFunction(IList<AITool>? tools, string functionName)
{
if (tools is not null)
{
int count = tools.Count;
for (int i = 0; i < count; i++)
{
if (tools[i] is AIFunction function && function.Name == functionName)
{
return function;
}
}
}

return null;
}
}

/// <summary>Creates one or more response messages for function invocation results.</summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,10 @@
}
],
"Properties": [
{
"Member": "System.Collections.Generic.IList<Microsoft.Extensions.AI.AITool>? Microsoft.Extensions.AI.FunctionInvokingChatClient.AdditionalTools { get; set; }",
"Stage": "Stable"
},
{
"Member": "bool Microsoft.Extensions.AI.FunctionInvokingChatClient.AllowConcurrentInvocation { get; set; }",
"Stage": "Stable"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ public void Ctor_HasExpectedDefaults()
Assert.Equal(40, client.MaximumIterationsPerRequest);
Assert.Equal(3, client.MaximumConsecutiveErrorsPerRequest);
Assert.Null(client.FunctionInvoker);
Assert.Null(client.AdditionalTools);
}

[Fact]
Expand Down Expand Up @@ -67,6 +68,11 @@ public void Properties_Roundtrip()
Func<FunctionInvocationContext, CancellationToken, ValueTask<object?>> invoker = (ctx, ct) => new ValueTask<object?>("test");
client.FunctionInvoker = invoker;
Assert.Same(invoker, client.FunctionInvoker);

Assert.Null(client.AdditionalTools);
IList<AITool> additionalTools = [AIFunctionFactory.Create(() => "Additional Tool")];
client.AdditionalTools = additionalTools;
Assert.Same(additionalTools, client.AdditionalTools);
}

[Fact]
Expand Down Expand Up @@ -99,6 +105,73 @@ public async Task SupportsSingleFunctionCallPerRequestAsync()
await InvokeAndAssertStreamingAsync(options, plan);
}

[Theory]
[InlineData(false)]
[InlineData(true)]
public async Task SupportsToolsProvidedByAdditionalTools(bool provideOptions)
{
ChatOptions? options = provideOptions ?
new() { Tools = [AIFunctionFactory.Create(() => "Shouldn't be invoked", "ChatOptionsFunc")] } :
null;

Func<ChatClientBuilder, ChatClientBuilder> configure = builder =>
builder.UseFunctionInvocation(configure: c => c.AdditionalTools =
[
AIFunctionFactory.Create(() => "Result 1", "Func1"),
AIFunctionFactory.Create((int i) => $"Result 2: {i}", "Func2"),
AIFunctionFactory.Create((int i) => { }, "VoidReturn"),
]);

List<ChatMessage> plan =
[
new ChatMessage(ChatRole.User, "hello"),
new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1")]),
new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId1", result: "Result 1")]),
new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId2", "Func2", arguments: new Dictionary<string, object?> { { "i", 42 } })]),
new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId2", result: "Result 2: 42")]),
new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId3", "VoidReturn", arguments: new Dictionary<string, object?> { { "i", 43 } })]),
new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId3", result: "Success: Function completed.")]),
new ChatMessage(ChatRole.Assistant, "world"),
];

await InvokeAndAssertAsync(options, plan, configurePipeline: configure);

await InvokeAndAssertStreamingAsync(options, plan, configurePipeline: configure);
}

[Fact]
public async Task PrefersToolsProvidedByChatOptions()
{
ChatOptions options = new()
{
Tools = [AIFunctionFactory.Create(() => "Result 1", "Func1")]
};

Func<ChatClientBuilder, ChatClientBuilder> configure = builder =>
builder.UseFunctionInvocation(configure: c => c.AdditionalTools =
[
AIFunctionFactory.Create(() => "Should never be invoked", "Func1"),
AIFunctionFactory.Create((int i) => $"Result 2: {i}", "Func2"),
AIFunctionFactory.Create((int i) => { }, "VoidReturn"),
]);

List<ChatMessage> plan =
[
new ChatMessage(ChatRole.User, "hello"),
new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1")]),
new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId1", result: "Result 1")]),
new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId2", "Func2", arguments: new Dictionary<string, object?> { { "i", 42 } })]),
new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId2", result: "Result 2: 42")]),
new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId3", "VoidReturn", arguments: new Dictionary<string, object?> { { "i", 43 } })]),
new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId3", result: "Success: Function completed.")]),
new ChatMessage(ChatRole.Assistant, "world"),
];

await InvokeAndAssertAsync(options, plan, configurePipeline: configure);

await InvokeAndAssertStreamingAsync(options, plan, configurePipeline: configure);
}

[Theory]
[InlineData(false)]
[InlineData(true)]
Expand Down Expand Up @@ -1002,7 +1075,7 @@ public override void Post(SendOrPostCallback d, object? state)
}

private static async Task<List<ChatMessage>> InvokeAndAssertAsync(
ChatOptions options,
ChatOptions? options,
List<ChatMessage> plan,
List<ChatMessage>? expected = null,
Func<ChatClientBuilder, ChatClientBuilder>? configurePipeline = null,
Expand Down Expand Up @@ -1102,7 +1175,7 @@ private static UsageDetails CreateRandomUsage()
}

private static async Task<List<ChatMessage>> InvokeAndAssertStreamingAsync(
ChatOptions options,
ChatOptions? options,
List<ChatMessage> plan,
List<ChatMessage>? expected = null,
Func<ChatClientBuilder, ChatClientBuilder>? configurePipeline = null,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
</PropertyGroup>

<PropertyGroup>
<NoWarn>$(NoWarn);CA1063;CA1861;SA1130;VSTHRD003</NoWarn>
<NoWarn>$(NoWarn);CA1063;CA1861;S104;SA1130;VSTHRD003</NoWarn>
<NoWarn>$(NoWarn);MEAI001</NoWarn>
<TreatWarningsAsErrors>true</TreatWarningsAsErrors>
</PropertyGroup>
Expand Down
Loading