diff --git a/packages/vertexai/src/methods/chrome-adapter.ts b/packages/vertexai/src/methods/chrome-adapter.ts index 63e1db83e89..32dec23035d 100644 --- a/packages/vertexai/src/methods/chrome-adapter.ts +++ b/packages/vertexai/src/methods/chrome-adapter.ts @@ -41,8 +41,10 @@ export class ChromeAdapter { constructor( private languageModelProvider?: LanguageModel, private mode?: InferenceMode, - private onDeviceParams?: LanguageModelCreateOptions - ) {} + private onDeviceParams: LanguageModelCreateOptions = {} + ) { + this.addImageTypeAsExpectedInput(); + } /** * Checks if a given request can be made on-device. @@ -64,12 +66,8 @@ export class ChromeAdapter { return false; } - const availability = await this.languageModelProvider?.availability(); - - // Triggers async model download so it'll be available next time. - if (availability === Availability.downloadable) { - this.download(); - } + // Triggers out-of-band download so model will eventually become available. + const availability = await this.downloadIfAvailable(); if (this.mode === 'only_on_device') { return true; @@ -91,10 +89,7 @@ export class ChromeAdapter { * @returns {@link Response}, so we can reuse common response formatting. */ async generateContent(request: GenerateContentRequest): Promise { - const session = await this.createSession( - // TODO: normalize on-device params during construction. - this.onDeviceParams || {} - ); + const session = await this.createSession(); // TODO: support multiple content objects when Chrome supports // sequence const contents = await Promise.all( @@ -115,10 +110,7 @@ export class ChromeAdapter { async generateContentStream( request: GenerateContentRequest ): Promise { - const session = await this.createSession( - // TODO: normalize on-device params during construction. - this.onDeviceParams || {} - ); + const session = await this.createSession(); // TODO: support multiple content objects when Chrome supports // sequence const contents = await Promise.all( @@ -155,7 +147,22 @@ export class ChromeAdapter { } /** - * Triggers the download of an on-device model. + * Encapsulates logic to get availability and download a model if one is downloadable. + */ + private async downloadIfAvailable(): Promise { + const availability = await this.languageModelProvider?.availability( + this.onDeviceParams + ); + + if (availability === Availability.downloadable) { + this.download(); + } + + return availability; + } + + /** + * Triggers out-of-band download of an on-device model. * *

Chrome only downloads models as needed. Chrome knows a model is needed when code calls * LanguageModel.create.

@@ -168,10 +175,8 @@ export class ChromeAdapter { return; } this.isDownloading = true; - const options = this.onDeviceParams || {}; - ChromeAdapter.addImageTypeAsExpectedInput(options); this.downloadPromise = this.languageModelProvider - ?.create(options) + ?.create(this.onDeviceParams) .then(() => { this.isDownloading = false; }); @@ -214,19 +219,16 @@ export class ChromeAdapter { *

Chrome will remove a model from memory if it's no longer in use, so this method ensures a * new session is created before an old session is destroyed.

*/ - private async createSession( - // TODO: define a default value, since these are optional. - options: LanguageModelCreateOptions - ): Promise { + private async createSession(): Promise { if (!this.languageModelProvider) { throw new AIError( AIErrorCode.REQUEST_ERROR, 'Chrome AI requested for unsupported browser version.' ); } - // TODO: could we use this.onDeviceParams instead of passing in options? - ChromeAdapter.addImageTypeAsExpectedInput(options); - const newSession = await this.languageModelProvider!.create(options); + const newSession = await this.languageModelProvider.create( + this.onDeviceParams + ); if (this.oldSession) { this.oldSession.destroy(); } @@ -235,11 +237,9 @@ export class ChromeAdapter { return newSession; } - private static addImageTypeAsExpectedInput( - options: LanguageModelCreateOptions - ): void { - options.expectedInputs = options.expectedInputs || []; - options.expectedInputs.push({ type: 'image' }); + private addImageTypeAsExpectedInput(): void { + // Defaults to support image inputs for convenience. + this.onDeviceParams.expectedInputs ??= [{ type: 'image' }]; } /**