Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 46 additions & 13 deletions Extension/src/LanguageServer/copilotCompletionContextProvider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
* Copyright (c) Microsoft Corporation. All Rights Reserved.
* See 'LICENSE' in the project root for license information.
* ------------------------------------------------------------------------------------------ */
import { ContextResolver, ResolveRequest, SupportedContextItem } from '@github/copilot-language-server';
import { ContextResolver, ResolveRequest, SupportedContextItem, type ContextProvider } from '@github/copilot-language-server';
import { randomUUID } from 'crypto';
import * as vscode from 'vscode';
import { DocumentSelector } from 'vscode-languageserver-protocol';
Expand All @@ -11,7 +11,7 @@ import { getOutputChannelLogger, Logger } from '../logger';
import * as telemetry from '../telemetry';
import { CopilotCompletionContextResult } from './client';
import { CopilotCompletionContextTelemetry } from './copilotCompletionContextTelemetry';
import { getCopilotApi } from './copilotProviders';
import { getCopilotChatApi, getCopilotClientApi, type CopilotContextProviderAPI } from './copilotProviders';
import { clients } from './extension';
import { CppSettings } from './settings';

Expand Down Expand Up @@ -83,7 +83,7 @@ export class CopilotCompletionContextProvider implements ContextResolver<Support
private static readonly defaultMaxSnippetLength = 3 * 1024;
private static readonly defaultDoAggregateSnippets = true;
private completionContextCancellation = new vscode.CancellationTokenSource();
private contextProviderDisposable: vscode.Disposable | undefined;
private contextProviderDisposables: vscode.Disposable[] | undefined;
static readonly CppContextProviderEnabledFeatures = 'enabledFeatures';
static readonly CppContextProviderTimeBudgetMs = 'timeBudgetMs';
static readonly CppContextProviderMaxSnippetCount = 'maxSnippetCount';
Expand Down Expand Up @@ -312,7 +312,12 @@ export class CopilotCompletionContextProvider implements ContextResolver<Support

public dispose(): void {
this.completionContextCancellation.cancel();
this.contextProviderDisposable?.dispose();
if (this.contextProviderDisposables) {
for (const disposable of this.contextProviderDisposables) {
disposable.dispose();
}
this.contextProviderDisposables = undefined;
}
}

public removeFile(fileUri: string): void {
Expand Down Expand Up @@ -444,18 +449,32 @@ ${copilotCompletionContext?.areSnippetsMissing ? "(missing code snippets)" : ""}
const properties: Record<string, string> = {};
const registerCopilotContextProvider = 'registerCopilotContextProvider';
try {
const copilotApi = await getCopilotApi();
if (!copilotApi) { throw new CopilotContextProviderException("getCopilotApi() returned null, Copilot is missing or inactive."); }
const hasGetContextProviderAPI = "getContextProviderAPI" in copilotApi;
if (!hasGetContextProviderAPI) { throw new CopilotContextProviderException("getContextProviderAPI() is not available."); }
const contextAPI = await copilotApi.getContextProviderAPI("v1");
if (!contextAPI) { throw new CopilotContextProviderException("getContextProviderAPI(v1) returned null."); }
this.contextProviderDisposable = contextAPI.registerContextProvider({
const copilotApi = await getCopilotClientApi();
const copilotChatApi = await getCopilotChatApi();
if (!copilotApi && !copilotChatApi) { throw new CopilotContextProviderException("getCopilotApi() returned null, Copilot is missing or inactive."); }
const contextProvider = {
id: CopilotCompletionContextProvider.providerId,
selector: CopilotCompletionContextProvider.defaultCppDocumentSelector,
resolver: this
});
properties["cppCodeSnippetsProviderRegistered"] = "true";
};
type InstallSummary = { hasGetContextProviderAPI: boolean; hasAPI: boolean };
const installSummary: { client?: InstallSummary; chat?: InstallSummary } = {};
if (copilotApi) {
installSummary.client = await this.installContextProvider(copilotApi, contextProvider);
}
if (copilotChatApi) {
installSummary.chat = await this.installContextProvider(copilotChatApi, contextProvider);
}
if (installSummary.client?.hasAPI || installSummary.chat?.hasAPI) {
properties["cppCodeSnippetsProviderRegistered"] = "true";
} else {
if (installSummary.client?.hasGetContextProviderAPI === false &&
installSummary.chat?.hasGetContextProviderAPI === false) {
throw new CopilotContextProviderException("getContextProviderAPI() is not available.");
} else {
throw new CopilotContextProviderException("getContextProviderAPI(v1) returned null.");
}
}
} catch (e) {
console.debug("Failed to register the Copilot Context Provider.");
properties["error"] = "Failed to register the Copilot Context Provider";
Expand All @@ -466,4 +485,18 @@ ${copilotCompletionContext?.areSnippetsMissing ? "(missing code snippets)" : ""}
telemetry.logCopilotEvent(registerCopilotContextProvider, { ...properties });
}
}

private async installContextProvider(copilotAPI: CopilotContextProviderAPI, contextProvider: ContextProvider<SupportedContextItem>): Promise<{ hasGetContextProviderAPI: boolean; hasAPI: boolean }> {
const hasGetContextProviderAPI = typeof copilotAPI.getContextProviderAPI === 'function';
if (hasGetContextProviderAPI) {
const contextAPI = await copilotAPI.getContextProviderAPI("v1");
if (contextAPI) {
this.contextProviderDisposables = this.contextProviderDisposables ?? [];
this.contextProviderDisposables.push(contextAPI.registerContextProvider(contextProvider));
}
return { hasGetContextProviderAPI, hasAPI: contextAPI !== undefined };
} else {
return { hasGetContextProviderAPI: false, hasAPI: false };
}
}
}
39 changes: 35 additions & 4 deletions Extension/src/LanguageServer/copilotProviders.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,11 @@ export interface CopilotTrait {
promptTextOverride?: string;
}

export interface CopilotApi {
export interface CopilotContextProviderAPI {
getContextProviderAPI(version: string): Promise<ContextProviderApiV1 | undefined>;
}

export interface CopilotApi extends CopilotContextProviderAPI {
registerRelatedFilesProvider(
providerId: { extensionId: string; languageId: string },
callback: (
Expand All @@ -33,11 +37,10 @@ export interface CopilotApi {
cancellationToken: vscode.CancellationToken
) => Promise<{ entries: vscode.Uri[]; traits?: CopilotTrait[] } | undefined>
): Disposable;
getContextProviderAPI(version: string): Promise<ContextProviderApiV1 | undefined>;
}

export async function registerRelatedFilesProvider(): Promise<void> {
const api = await getCopilotApi();
const api = await getCopilotClientApi();
if (util.extensionContext && api) {
try {
for (const languageId of ['c', 'cpp', 'cuda-cpp']) {
Expand Down Expand Up @@ -129,7 +132,7 @@ async function getIncludes(uri: vscode.Uri, maxDepth: number): Promise<GetInclud
return includes;
}

export async function getCopilotApi(): Promise<CopilotApi | undefined> {
export async function getCopilotClientApi(): Promise<CopilotApi | undefined> {
const copilotExtension = vscode.extensions.getExtension<CopilotApi>('github.copilot');
if (!copilotExtension) {
return undefined;
Expand All @@ -145,3 +148,31 @@ export async function getCopilotApi(): Promise<CopilotApi | undefined> {
return copilotExtension.exports;
}
}

export async function getCopilotChatApi(): Promise<CopilotContextProviderAPI | undefined> {
type CopilotChatApi = { getAPI?(version: number): CopilotContextProviderAPI | undefined };
const copilotExtension = vscode.extensions.getExtension<CopilotChatApi>('github.copilot-chat');
if (!copilotExtension) {
return undefined;
}

let exports: CopilotChatApi | undefined;
if (!copilotExtension.isActive) {
try {
exports = await copilotExtension.activate();
} catch {
return undefined;
}
} else {
exports = copilotExtension.exports;
}
if (!exports || typeof exports.getAPI !== 'function') {
return undefined;
}
const result = exports.getAPI(1);
return result;
}

interface Disposable {
dispose(): void;
}