diff --git a/src/client/telemetry/importTracker.ts b/src/client/telemetry/importTracker.ts index 4a06fbb04e54..edf188684857 100644 --- a/src/client/telemetry/importTracker.ts +++ b/src/client/telemetry/importTracker.ts @@ -67,19 +67,6 @@ export class ImportTracker implements IExtensionSingleActivationService { this.documentManager.textDocuments.forEach((d) => this.onOpenedOrSavedDocument(d)); } - private getDocumentLines(document: TextDocument): (string | undefined)[] { - const array = Array<string>(Math.min(document.lineCount, MAX_DOCUMENT_LINES)).fill(''); - return array - .map((_a: string, i: number) => { - const line = document.lineAt(i); - if (line && !line.isEmptyOrWhitespace) { - return line.text; - } - return undefined; - }) - .filter((f: string | undefined) => f); - } - private onOpenedOrSavedDocument(document: TextDocument) { // Make sure this is a Python file. if (path.extname(document.fileName) === '.py') { @@ -112,7 +99,7 @@ export class ImportTracker implements IExtensionSingleActivationService { @captureTelemetry(EventName.HASHED_PACKAGE_PERF) private checkDocument(document: TextDocument) { this.pendingChecks.delete(document.fileName); - const lines = this.getDocumentLines(document); + const lines = getDocumentLines(document); this.lookForImports(lines); } @@ -152,3 +139,16 @@ export class ImportTracker implements IExtensionSingleActivationService { } } } + +export function getDocumentLines(document: TextDocument): (string | undefined)[] { + const array = Array<string>(Math.min(document.lineCount, MAX_DOCUMENT_LINES)).fill(''); + return array + .map((_a: string, i: number) => { + const line = document.lineAt(i); + if (line && !line.isEmptyOrWhitespace) { + return line.text; + } + return undefined; + }) + .filter((f: string | undefined) => f); +} diff --git a/src/client/tensorBoard/serviceRegistry.ts b/src/client/tensorBoard/serviceRegistry.ts index baee4a03c7b4..ab4c1155501b 100644 --- a/src/client/tensorBoard/serviceRegistry.ts +++ b/src/client/tensorBoard/serviceRegistry.ts @@ -4,8 +4,10 @@ import { IExtensionSingleActivationService } from '../activation/types'; import { IServiceManager } from '../ioc/types'; import { TensorBoardFileWatcher } from './tensorBoardFileWatcher'; +import { TensorBoardImportTracker } from './tensorBoardImportTracker'; import { TensorBoardPrompt } from './tensorBoardPrompt'; import { TensorBoardSessionProvider } from './tensorBoardSessionProvider'; +import { ITensorBoardImportTracker } from './types'; export function registerTypes(serviceManager: IServiceManager) { serviceManager.addSingleton<IExtensionSingleActivationService>( @@ -17,4 +19,6 @@ export function registerTypes(serviceManager: IServiceManager) { TensorBoardFileWatcher ); serviceManager.addSingleton<TensorBoardPrompt>(TensorBoardPrompt, TensorBoardPrompt); + serviceManager.addSingleton<ITensorBoardImportTracker>(ITensorBoardImportTracker, TensorBoardImportTracker); + serviceManager.addBinding(ITensorBoardImportTracker, IExtensionSingleActivationService); } diff --git a/src/client/tensorBoard/tensorBoardImportTracker.ts b/src/client/tensorBoard/tensorBoardImportTracker.ts new file mode 100644 index 000000000000..24d56840d88d --- /dev/null +++ b/src/client/tensorBoard/tensorBoardImportTracker.ts @@ -0,0 +1,91 @@ +import { inject, injectable } from 'inversify'; +import { noop } from 'lodash'; +import * as path from 'path'; +import { Event, EventEmitter, TextEditor, window } from 'vscode'; +import { IExtensionSingleActivationService } from '../activation/types'; +import { IDocumentManager } from '../common/application/types'; +import { IDisposableRegistry } from '../common/types'; +import { getDocumentLines } from '../telemetry/importTracker'; +import { ITensorBoardImportTracker } from './types'; + +// While it is uncommon for users to `import tensorboard`, TensorBoard is frequently +// included as a submodule of other packages, e.g. torch.utils.tensorboard. +// This is a modified version of the regex from src/client/telemetry/importTracker.ts +// in order to match on imported submodules as well, since the original regex only +// matches the 'main' module. +const ImportRegEx = /^\s*from (?<fromImport>\w+(?:\.\w+)*) import (?<fromImportTarget>\w+(?:, \w+)*)(?: as \w+)?|import (?<importImport>\w+(?:, \w+)*)(?: as \w+)?$/; + +@injectable() +export class TensorBoardImportTracker implements ITensorBoardImportTracker, IExtensionSingleActivationService { + private pendingChecks = new Map<string, NodeJS.Timer | number>(); + private _onDidImportTensorBoard = new EventEmitter<void>(); + + constructor( + @inject(IDocumentManager) private documentManager: IDocumentManager, + @inject(IDisposableRegistry) private disposables: IDisposableRegistry + ) { + this.documentManager.onDidChangeActiveTextEditor( + (e) => this.onChangedActiveTextEditor(e), + this, + this.disposables + ); + } + + // Fires when the active text editor contains a tensorboard import. + public get onDidImportTensorBoard(): Event<void> { + return this._onDidImportTensorBoard.event; + } + + public dispose() { + this.pendingChecks.clear(); + } + + public async activate(): Promise<void> { + // Process active text editor with a timeout delay + this.onChangedActiveTextEditor(window.activeTextEditor); + } + + private onChangedActiveTextEditor(editor: TextEditor | undefined) { + if (!editor || !editor.document) { + return; + } + const document = editor.document; + if ( + (path.extname(document.fileName) === '.ipynb' && document.languageId === 'python') || + path.extname(document.fileName) === '.py' + ) { + const lines = getDocumentLines(document); + this.lookForImports(lines); + } + } + + private lookForImports(lines: (string | undefined)[]) { + try { + for (const s of lines) { + const matches = s ? ImportRegEx.exec(s) : null; + if (matches === null || matches.groups === undefined) { + continue; + } + let componentsToCheck: string[] = []; + if (matches.groups.fromImport && matches.groups.fromImportTarget) { + // from x.y.z import u, v, w + componentsToCheck = matches.groups.fromImport + .split('.') + .concat(matches.groups.fromImportTarget.split(',')); + } else if (matches.groups.importImport) { + // import package1, package2, ... + componentsToCheck = matches.groups.importImport.split(','); + } + for (const component of componentsToCheck) { + if (component && component.trim() === 'tensorboard') { + this._onDidImportTensorBoard.fire(); + return; + } + } + } + } catch { + // Don't care about failures. + noop(); + } + } +} diff --git a/src/client/tensorBoard/tensorBoardPrompt.ts b/src/client/tensorBoard/tensorBoardPrompt.ts index 060c2bc2598b..8ddc18b4200a 100644 --- a/src/client/tensorBoard/tensorBoardPrompt.ts +++ b/src/client/tensorBoard/tensorBoardPrompt.ts @@ -4,8 +4,9 @@ import { inject, injectable } from 'inversify'; import { IApplicationShell, ICommandManager } from '../common/application/types'; import { Commands } from '../common/constants'; -import { IPersistentState, IPersistentStateFactory } from '../common/types'; +import { IDisposableRegistry, IPersistentState, IPersistentStateFactory } from '../common/types'; import { Common, TensorBoard } from '../common/utils/localize'; +import { ITensorBoardImportTracker } from './types'; enum TensorBoardPromptStateKeys { ShowNativeTensorBoardPrompt = 'showNativeTensorBoardPrompt' @@ -15,11 +16,14 @@ enum TensorBoardPromptStateKeys { export class TensorBoardPrompt { private state: IPersistentState<boolean>; private enabled: Promise<boolean> | undefined; + private enabledInCurrentSession: boolean = true; private waitingForUserSelection: boolean = false; constructor( @inject(IApplicationShell) private applicationShell: IApplicationShell, @inject(ICommandManager) private commandManager: ICommandManager, + @inject(ITensorBoardImportTracker) private importTracker: ITensorBoardImportTracker, + @inject(IDisposableRegistry) private disposableRegistry: IDisposableRegistry, @inject(IPersistentStateFactory) private persistentStateFactory: IPersistentStateFactory ) { this.state = this.persistentStateFactory.createWorkspacePersistentState<boolean>( @@ -27,10 +31,11 @@ export class TensorBoardPrompt { true ); this.enabled = this.isPromptEnabled(); + this.importTracker.onDidImportTensorBoard(this.showNativeTensorBoardPrompt, this, this.disposableRegistry); } public async showNativeTensorBoardPrompt() { - if ((await this.enabled) && !this.waitingForUserSelection) { + if ((await this.enabled) && this.enabledInCurrentSession && !this.waitingForUserSelection) { const yes = Common.bannerLabelYes(); const no = Common.bannerLabelNo(); const doNotAskAgain = Common.doNotShowAgain(); @@ -41,10 +46,10 @@ export class TensorBoardPrompt { ...options ); this.waitingForUserSelection = false; + this.enabledInCurrentSession = false; switch (selection) { case yes: await this.commandManager.executeCommand(Commands.LaunchTensorBoard); - await this.disablePrompt(); break; case doNotAskAgain: await this.disablePrompt(); diff --git a/src/client/tensorBoard/types.ts b/src/client/tensorBoard/types.ts new file mode 100644 index 000000000000..6e2c274d63f4 --- /dev/null +++ b/src/client/tensorBoard/types.ts @@ -0,0 +1,9 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import { Event } from 'vscode'; + +export const ITensorBoardImportTracker = Symbol('ITensorBoardImportTracker'); +export interface ITensorBoardImportTracker { + onDidImportTensorBoard: Event<void>; +} diff --git a/src/test/startPage/mockDocument.ts b/src/test/startPage/mockDocument.ts index eacf8a549940..0daa23dacf7e 100644 --- a/src/test/startPage/mockDocument.ts +++ b/src/test/startPage/mockDocument.ts @@ -51,13 +51,20 @@ export class MockDocument implements TextDocument { private _contents: string = ''; private _isUntitled = false; private _isDirty = false; + private _language = 'python'; private _onSave: (doc: TextDocument) => Promise<boolean>; - constructor(contents: string, fileName: string, onSave: (doc: TextDocument) => Promise<boolean>) { + constructor( + contents: string, + fileName: string, + onSave: (doc: TextDocument) => Promise<boolean>, + language?: string + ) { this._uri = Uri.file(fileName); this._contents = contents; this._lines = this.createLines(); this._onSave = onSave; + this._language = language ?? this._language; } public setContent(contents: string) { @@ -85,7 +92,7 @@ export class MockDocument implements TextDocument { return this._isUntitled; } public get languageId(): string { - return 'python'; + return this._language; } public get version(): number { return this._version; diff --git a/src/test/startPage/mockDocumentManager.ts b/src/test/startPage/mockDocumentManager.ts index 4601fcd91fd2..f19c83f406cc 100644 --- a/src/test/startPage/mockDocumentManager.ts +++ b/src/test/startPage/mockDocumentManager.ts @@ -94,12 +94,12 @@ export class MockDocumentManager implements IDocumentManager { throw new Error('Method not implemented.'); } - public addDocument(code: string, file: string) { + public addDocument(code: string, file: string, language?: string) { let existing = this.textDocuments.find((d) => d.uri.fsPath === file) as MockDocument; if (existing) { existing.setContent(code); } else { - existing = new MockDocument(code, file, this.saveDocument); + existing = new MockDocument(code, file, this.saveDocument, language); this.textDocuments.push(existing); } return existing; diff --git a/src/test/tensorBoard/tensorBoardImportTracker.unit.test.ts b/src/test/tensorBoard/tensorBoardImportTracker.unit.test.ts new file mode 100644 index 000000000000..b7726a4f9854 --- /dev/null +++ b/src/test/tensorBoard/tensorBoardImportTracker.unit.test.ts @@ -0,0 +1,86 @@ +import * as sinon from 'sinon'; +import { TensorBoardImportTracker } from '../../client/tensorBoard/tensorBoardImportTracker'; +import { MockDocumentManager } from '../startPage/mockDocumentManager'; + +suite('TensorBoard import tracker', () => { + let documentManager: MockDocumentManager; + let tensorBoardImportTracker: TensorBoardImportTracker; + let onDidImportTensorBoardListener: sinon.SinonExpectation; + + setup(() => { + documentManager = new MockDocumentManager(); + tensorBoardImportTracker = new TensorBoardImportTracker(documentManager, []); + onDidImportTensorBoardListener = sinon.expectation.create('onDidImportTensorBoardListener'); + tensorBoardImportTracker.onDidImportTensorBoard(onDidImportTensorBoardListener); + }); + + test('Simple tensorboard import in Python file', async () => { + const document = documentManager.addDocument('import tensorboard', 'foo.py'); + await documentManager.showTextDocument(document); + await tensorBoardImportTracker.activate(); + onDidImportTensorBoardListener.once().verify(); + }); + test('Simple tensorboard import in Python ipynb', async () => { + const document = documentManager.addDocument('import tensorboard', 'foo.ipynb'); + await documentManager.showTextDocument(document); + await tensorBoardImportTracker.activate(); + onDidImportTensorBoardListener.once().verify(); + }); + test('`from x.y.tensorboard import z` import', async () => { + const document = documentManager.addDocument('from torch.utils.tensorboard import SummaryWriter', 'foo.py'); + await documentManager.showTextDocument(document); + await tensorBoardImportTracker.activate(); + onDidImportTensorBoardListener.once().verify(); + }); + test('`from x.y import tensorboard` import', async () => { + const document = documentManager.addDocument('from torch.utils import tensorboard', 'foo.py'); + await documentManager.showTextDocument(document); + await tensorBoardImportTracker.activate(); + onDidImportTensorBoardListener.once().verify(); + }); + test('`import x, y` import', async () => { + const document = documentManager.addDocument('import tensorboard, tensorflow', 'foo.py'); + await documentManager.showTextDocument(document); + await tensorBoardImportTracker.activate(); + onDidImportTensorBoardListener.once().verify(); + }); + test('`import pkg as _` import', async () => { + const document = documentManager.addDocument('import tensorboard as tb', 'foo.py'); + await documentManager.showTextDocument(document); + await tensorBoardImportTracker.activate(); + onDidImportTensorBoardListener.once().verify(); + }); + test('Fire on changed text editor', async () => { + await tensorBoardImportTracker.activate(); + const document = documentManager.addDocument('import tensorboard as tb', 'foo.py'); + await documentManager.showTextDocument(document); + onDidImportTensorBoardListener.once().verify(); + }); + test('Do not fire event if no tensorboard import', async () => { + const document = documentManager.addDocument('import tensorflow as tf\nfrom torch.utils import foo', 'foo.py'); + await documentManager.showTextDocument(document); + await tensorBoardImportTracker.activate(); + onDidImportTensorBoardListener.never().verify(); + }); + test('Do not fire event if language is not Python', async () => { + const document = documentManager.addDocument( + 'import tensorflow as tf\nfrom torch.utils import foo', + 'foo.cpp', + 'cpp' + ); + await documentManager.showTextDocument(document); + await tensorBoardImportTracker.activate(); + onDidImportTensorBoardListener.never().verify(); + }); + test('Ignore docstrings', async () => { + const document = documentManager.addDocument( + `""" +import tensorboard +"""`, + 'foo.py' + ); + await documentManager.showTextDocument(document); + await tensorBoardImportTracker.activate(); + onDidImportTensorBoardListener.never().verify(); + }); +});