diff --git a/e2e/sample-apps/modular.js b/e2e/sample-apps/modular.js index 4c5238d44dc..e3170bb3c57 100644 --- a/e2e/sample-apps/modular.js +++ b/e2e/sample-apps/modular.js @@ -58,12 +58,7 @@ import { onValue, off } from 'firebase/database'; -import { - getGenerativeModel, - getVertexAI, - InferenceMode, - VertexAI -} from 'firebase/vertexai'; +import { getGenerativeModel, getVertexAI } from 'firebase/vertexai'; import { getDataConnect, DataConnect } from 'firebase/data-connect'; /** @@ -318,8 +313,13 @@ function callPerformance(app) { async function callVertexAI(app) { console.log('[VERTEXAI] start'); const vertexAI = getVertexAI(app); - const model = getGenerativeModel(vertexAI, { model: 'gemini-1.5-flash' }); - const result = await model.countTokens('abcdefg'); + const model = getGenerativeModel(vertexAI, { + mode: 'prefer_in_cloud' + }); + const result = await model.generateContentStream("What is Roko's Basalisk?"); + for await (const chunk of result.stream) { + console.log(chunk.text()); + } console.log(`[VERTEXAI] counted tokens: ${result.totalTokens}`); } diff --git a/packages/vertexai/src/methods/chat-session.ts b/packages/vertexai/src/methods/chat-session.ts index 4188872cff7..112ddf5857e 100644 --- a/packages/vertexai/src/methods/chat-session.ts +++ b/packages/vertexai/src/methods/chat-session.ts @@ -149,6 +149,7 @@ export class ChatSession { this._apiSettings, this.model, generateContentRequest, + this.chromeAdapter, this.requestOptions ); diff --git a/packages/vertexai/src/methods/chrome-adapter.test.ts b/packages/vertexai/src/methods/chrome-adapter.test.ts index cce97b25f5a..a18812374c0 100644 --- a/packages/vertexai/src/methods/chrome-adapter.test.ts +++ b/packages/vertexai/src/methods/chrome-adapter.test.ts @@ -30,6 +30,25 @@ import { GenerateContentRequest } from '../types'; use(sinonChai); use(chaiAsPromised); +/** + * Converts the ReadableStream from response.body to an array of strings. 
+ */ +async function toStringArray( + stream: ReadableStream +): Promise { + const decoder = new TextDecoder(); + const actual = []; + const reader = stream.getReader(); + while (true) { + const { done, value } = await reader.read(); + if (done) { + break; + } + actual.push(decoder.decode(value)); + } + return actual; +} + describe('ChromeAdapter', () => { describe('isAvailable', () => { it('returns false if mode is only cloud', async () => { @@ -280,7 +299,7 @@ describe('ChromeAdapter', () => { const request = { contents: [{ role: 'user', parts: [{ text: 'anything' }] }] } as GenerateContentRequest; - const response = await adapter.generateContentOnDevice(request); + const response = await adapter.generateContent(request); // Asserts initialization params are proxied. expect(createStub).to.have.been.calledOnceWith(onDeviceParams); // Asserts Vertex input type is mapped to Chrome type. @@ -325,6 +344,7 @@ describe('ChromeAdapter', () => { const createStub = stub(languageModelProvider, 'create').resolves( languageModel ); + // overrides impl with stub method const measureInputUsageStub = stub( languageModel, @@ -336,6 +356,7 @@ describe('ChromeAdapter', () => { 'prefer_on_device', onDeviceParams ); + const countTokenRequest = { contents: [{ role: 'user', parts: [{ text: inputText }] }] } as GenerateContentRequest; @@ -359,4 +380,52 @@ describe('ChromeAdapter', () => { }); }); }); + describe('generateContentStreamOnDevice', () => { + it('generates content stream', async () => { + const languageModelProvider = { + create: () => Promise.resolve({}) + } as LanguageModel; + const languageModel = { + promptStreaming: _i => new ReadableStream() + } as LanguageModel; + const createStub = stub(languageModelProvider, 'create').resolves( + languageModel + ); + const part = 'hi'; + const promptStub = stub(languageModel, 'promptStreaming').returns( + new ReadableStream({ + start(controller) { + controller.enqueue([part]); + controller.close(); + } + }) + ); + const 
onDeviceParams = {} as LanguageModelCreateOptions; + const adapter = new ChromeAdapter( + languageModelProvider, + 'prefer_on_device', + onDeviceParams + ); + const request = { + contents: [{ role: 'user', parts: [{ text: 'anything' }] }] + } as GenerateContentRequest; + const response = await adapter.generateContentStream(request); + expect(createStub).to.have.been.calledOnceWith(onDeviceParams); + expect(promptStub).to.have.been.calledOnceWith([ + { + role: request.contents[0].role, + content: [ + { + type: 'text', + content: request.contents[0].parts[0].text + } + ] + } + ]); + const actual = await toStringArray(response.body!); + expect(actual).to.deep.equal([ + `data: {"candidates":[{"content":{"role":"model","parts":[{"text":["${part}"]}]}}]}\n\n` + ]); + }); + }); }); diff --git a/packages/vertexai/src/methods/chrome-adapter.ts b/packages/vertexai/src/methods/chrome-adapter.ts index 225d2bd581d..dcdb38b7fd8 100644 --- a/packages/vertexai/src/methods/chrome-adapter.ts +++ b/packages/vertexai/src/methods/chrome-adapter.ts @@ -95,7 +95,25 @@ export class ChromeAdapter { * @param request a standard Vertex {@link GenerateContentRequest} * @returns {@link Response}, so we can reuse common response formatting. */ - async generateContentOnDevice( + async generateContent(request: GenerateContentRequest): Promise { + const session = await this.createSession( + // TODO: normalize on-device params during construction. + this.onDeviceParams || {} + ); + const messages = ChromeAdapter.toLanguageModelMessages(request.contents); + const text = await session.prompt(messages); + return ChromeAdapter.toResponse(text); + } + + /** + * Generates content stream on device. + * + *

<p>This is comparable to {@link GenerativeModel.generateContentStream} for generating content in + * Cloud.</p>

+ * @param request a standard Vertex {@link GenerateContentRequest} + * @returns {@link Response}, so we can reuse common response formatting. + */ + async generateContentStream( request: GenerateContentRequest ): Promise { const session = await this.createSession( @@ -103,19 +121,8 @@ export class ChromeAdapter { this.onDeviceParams || {} ); const messages = ChromeAdapter.toLanguageModelMessages(request.contents); - const text = await session.prompt(messages); - return { - json: () => - Promise.resolve({ - candidates: [ - { - content: { - parts: [{ text }] - } - } - ] - }) - } as Response; + const stream = await session.promptStreaming(messages); + return ChromeAdapter.toStreamResponse(stream); } async countTokens(request: CountTokensRequest): Promise { @@ -240,4 +247,47 @@ export class ChromeAdapter { this.oldSession = newSession; return newSession; } + + /** + * Formats string returned by Chrome as a {@link Response} returned by Vertex. + */ + private static toResponse(text: string): Response { + return { + json: async () => ({ + candidates: [ + { + content: { + parts: [{ text }] + } + } + ] + }) + } as Response; + } + + /** + * Formats string stream returned by Chrome as SSE returned by Vertex. 
+ */ + private static toStreamResponse(stream: ReadableStream): Response { + const encoder = new TextEncoder(); + return { + body: stream.pipeThrough( + new TransformStream({ + transform(chunk, controller) { + const json = JSON.stringify({ + candidates: [ + { + content: { + role: 'model', + parts: [{ text: chunk }] + } + } + ] + }); + controller.enqueue(encoder.encode(`data: ${json}\n\n`)); + } + }) + ) + } as Response; + } } diff --git a/packages/vertexai/src/methods/generate-content.test.ts b/packages/vertexai/src/methods/generate-content.test.ts index f714ec4d535..19c32941090 100644 --- a/packages/vertexai/src/methods/generate-content.test.ts +++ b/packages/vertexai/src/methods/generate-content.test.ts @@ -308,6 +308,7 @@ describe('generateContent()', () => { ); expect(mockFetch).to.be.called; }); + // TODO: define a similar test for generateContentStream it('on-device', async () => { const chromeAdapter = new ChromeAdapter(); const isAvailableStub = stub(chromeAdapter, 'isAvailable').resolves(true); @@ -315,10 +316,9 @@ describe('generateContent()', () => { 'vertexAI', 'unary-success-basic-reply-short.json' ); - const generateContentStub = stub( - chromeAdapter, - 'generateContentOnDevice' - ).resolves(mockResponse as Response); + const generateContentStub = stub(chromeAdapter, 'generateContent').resolves( + mockResponse as Response + ); const result = await generateContent( fakeApiSettings, 'model', diff --git a/packages/vertexai/src/methods/generate-content.ts b/packages/vertexai/src/methods/generate-content.ts index ba7a162aa9c..1dc5918516e 100644 --- a/packages/vertexai/src/methods/generate-content.ts +++ b/packages/vertexai/src/methods/generate-content.ts @@ -28,13 +28,13 @@ import { processStream } from '../requests/stream-reader'; import { ApiSettings } from '../types/internal'; import { ChromeAdapter } from './chrome-adapter'; -export async function generateContentStream( +async function generateContentStreamOnCloud( apiSettings: ApiSettings, model: 
string, params: GenerateContentRequest, requestOptions?: RequestOptions -): Promise { - const response = await makeRequest( +): Promise { + return makeRequest( model, Task.STREAM_GENERATE_CONTENT, apiSettings, @@ -42,6 +42,26 @@ export async function generateContentStream( JSON.stringify(params), requestOptions ); +} + +export async function generateContentStream( + apiSettings: ApiSettings, + model: string, + params: GenerateContentRequest, + chromeAdapter: ChromeAdapter, + requestOptions?: RequestOptions +): Promise { + let response; + if (await chromeAdapter.isAvailable(params)) { + response = await chromeAdapter.generateContentStream(params); + } else { + response = await generateContentStreamOnCloud( + apiSettings, + model, + params, + requestOptions + ); + } return processStream(response); } @@ -70,7 +90,7 @@ export async function generateContent( ): Promise { let response; if (await chromeAdapter.isAvailable(params)) { - response = await chromeAdapter.generateContentOnDevice(params); + response = await chromeAdapter.generateContent(params); } else { response = await generateContentOnCloud( apiSettings, diff --git a/packages/vertexai/src/models/generative-model.ts b/packages/vertexai/src/models/generative-model.ts index 0f7e408282c..81856819312 100644 --- a/packages/vertexai/src/models/generative-model.ts +++ b/packages/vertexai/src/models/generative-model.ts @@ -123,6 +123,7 @@ export class GenerativeModel extends VertexAIModel { systemInstruction: this.systemInstruction, ...formattedParams }, + this.chromeAdapter, this.requestOptions ); }