diff --git a/Extension/src/LanguageServer/copilotCompletionContextProvider.ts b/Extension/src/LanguageServer/copilotCompletionContextProvider.ts index cdf75308f..a456ab73c 100644 --- a/Extension/src/LanguageServer/copilotCompletionContextProvider.ts +++ b/Extension/src/LanguageServer/copilotCompletionContextProvider.ts @@ -2,7 +2,7 @@ * Copyright (c) Microsoft Corporation. All Rights Reserved. * See 'LICENSE' in the project root for license information. * ------------------------------------------------------------------------------------------ */ -import { ContextResolver, ResolveRequest, SupportedContextItem } from '@github/copilot-language-server'; +import { ContextResolver, ResolveRequest, SupportedContextItem, type ContextProvider } from '@github/copilot-language-server'; import { randomUUID } from 'crypto'; import * as vscode from 'vscode'; import { DocumentSelector } from 'vscode-languageserver-protocol'; @@ -11,7 +11,7 @@ import { getOutputChannelLogger, Logger } from '../logger'; import * as telemetry from '../telemetry'; import { CopilotCompletionContextResult } from './client'; import { CopilotCompletionContextTelemetry } from './copilotCompletionContextTelemetry'; -import { getCopilotApi } from './copilotProviders'; +import { getCopilotChatApi, getCopilotClientApi, type CopilotContextProviderAPI } from './copilotProviders'; import { clients } from './extension'; import { CppSettings } from './settings'; @@ -83,7 +83,7 @@ export class CopilotCompletionContextProvider implements ContextResolver = {}; const registerCopilotContextProvider = 'registerCopilotContextProvider'; try { - const copilotApi = await getCopilotApi(); - if (!copilotApi) { throw new CopilotContextProviderException("getCopilotApi() returned null, Copilot is missing or inactive."); } - const hasGetContextProviderAPI = "getContextProviderAPI" in copilotApi; - if (!hasGetContextProviderAPI) { throw new CopilotContextProviderException("getContextProviderAPI() is not available."); } - const contextAPI = await copilotApi.getContextProviderAPI("v1"); - if (!contextAPI) { throw new CopilotContextProviderException("getContextProviderAPI(v1) returned null."); } - this.contextProviderDisposable = contextAPI.registerContextProvider({ + const copilotApi = await getCopilotClientApi(); + const copilotChatApi = await getCopilotChatApi(); + if (!copilotApi && !copilotChatApi) { throw new CopilotContextProviderException("getCopilotApi() returned null, Copilot is missing or inactive."); } + const contextProvider = { id: CopilotCompletionContextProvider.providerId, selector: CopilotCompletionContextProvider.defaultCppDocumentSelector, resolver: this - }); - properties["cppCodeSnippetsProviderRegistered"] = "true"; + }; + type InstallSummary = { hasGetContextProviderAPI: boolean; hasAPI: boolean }; + const installSummary: { client?: InstallSummary; chat?: InstallSummary } = {}; + if (copilotApi) { + installSummary.client = await this.installContextProvider(copilotApi, contextProvider); + } + if (copilotChatApi) { + installSummary.chat = await this.installContextProvider(copilotChatApi, contextProvider); + } + if (installSummary.client?.hasAPI || installSummary.chat?.hasAPI) { + properties["cppCodeSnippetsProviderRegistered"] = "true"; + } else { + if (installSummary.client?.hasGetContextProviderAPI === false && + installSummary.chat?.hasGetContextProviderAPI === false) { + throw new CopilotContextProviderException("getContextProviderAPI() is not available."); + } else { + throw new CopilotContextProviderException("getContextProviderAPI(v1) returned null."); + } + } } catch (e) { console.debug("Failed to register the Copilot Context Provider."); properties["error"] = "Failed to register the Copilot Context Provider"; @@ -466,4 +485,18 @@ ${copilotCompletionContext?.areSnippetsMissing ? "(missing code snippets)" : ""} telemetry.logCopilotEvent(registerCopilotContextProvider, { ...properties }); } } + + private async installContextProvider(copilotAPI: CopilotContextProviderAPI, contextProvider: ContextProvider): Promise<{ hasGetContextProviderAPI: boolean; hasAPI: boolean }> { + const hasGetContextProviderAPI = typeof copilotAPI.getContextProviderAPI === 'function'; + if (hasGetContextProviderAPI) { + const contextAPI = await copilotAPI.getContextProviderAPI("v1"); + if (contextAPI) { + this.contextProviderDisposables = this.contextProviderDisposables ?? []; + this.contextProviderDisposables.push(contextAPI.registerContextProvider(contextProvider)); + } + return { hasGetContextProviderAPI, hasAPI: contextAPI !== undefined }; + } else { + return { hasGetContextProviderAPI: false, hasAPI: false }; + } + } } diff --git a/Extension/src/LanguageServer/copilotProviders.ts b/Extension/src/LanguageServer/copilotProviders.ts index e0551edcb..31cf21f3e 100644 --- a/Extension/src/LanguageServer/copilotProviders.ts +++ b/Extension/src/LanguageServer/copilotProviders.ts @@ -24,7 +24,11 @@ export interface CopilotTrait { promptTextOverride?: string; } -export interface CopilotApi { +export interface CopilotContextProviderAPI { + getContextProviderAPI(version: string): Promise; +} + +export interface CopilotApi extends CopilotContextProviderAPI { registerRelatedFilesProvider( providerId: { extensionId: string; languageId: string }, callback: ( @@ -33,11 +37,10 @@ export interface CopilotApi { cancellationToken: vscode.CancellationToken ) => Promise<{ entries: vscode.Uri[]; traits?: CopilotTrait[] } | undefined> ): Disposable; - getContextProviderAPI(version: string): Promise; } export async function registerRelatedFilesProvider(): Promise { - const api = await getCopilotApi(); + const api = await getCopilotClientApi(); if (util.extensionContext && api) { try { for (const languageId of ['c', 'cpp', 'cuda-cpp']) { @@ -129,7 +132,7 @@ async function getIncludes(uri: vscode.Uri, maxDepth: number): Promise { +export async function getCopilotClientApi(): Promise { const copilotExtension = vscode.extensions.getExtension('github.copilot'); if (!copilotExtension) { return undefined; @@ -145,3 +148,31 @@ export async function getCopilotApi(): Promise { return copilotExtension.exports; } } + +export async function getCopilotChatApi(): Promise { + type CopilotChatApi = { getAPI?(version: number): CopilotContextProviderAPI | undefined }; + const copilotExtension = vscode.extensions.getExtension('github.copilot-chat'); + if (!copilotExtension) { + return undefined; + } + + let exports: CopilotChatApi | undefined; + if (!copilotExtension.isActive) { + try { + exports = await copilotExtension.activate(); + } catch { + return undefined; + } + } else { + exports = copilotExtension.exports; + } + if (!exports || typeof exports.getAPI !== 'function') { + return undefined; + } + const result = exports.getAPI(1); + return result; +} + +interface Disposable { + dispose(): void; +}