diff --git a/e2e/sample-apps/modular.js b/e2e/sample-apps/modular.js index 292a11535a3..9df125de841 100644 --- a/e2e/sample-apps/modular.js +++ b/e2e/sample-apps/modular.js @@ -58,12 +58,7 @@ import { onValue, off } from 'firebase/database'; -import { - getGenerativeModel, - getVertexAI, - InferenceMode, - VertexAI -} from 'firebase/vertexai'; +import { getGenerativeModel, getVertexAI, VertexAI } from 'firebase/vertexai'; import { getDataConnect, DataConnect } from 'firebase/data-connect'; /** @@ -318,8 +313,13 @@ function callPerformance(app) { async function callVertexAI(app) { console.log('[VERTEXAI] start'); const vertexAI = getVertexAI(app); - const model = getGenerativeModel(vertexAI, { model: 'gemini-1.5-flash' }); - const result = await model.countTokens('abcdefg'); + const model = getGenerativeModel(vertexAI, { + mode: 'prefer_in_cloud' + }); + const result = await model.generateContentStream("What is Roko's Basalisk?"); + for await (const chunk of result.stream) { + console.log(chunk.text()); + } - console.log(`[VERTEXAI] counted tokens: ${result.totalTokens}`); + console.log('[VERTEXAI] stream complete'); } @@ -337,17 +337,6 @@ function callDataConnect(app) { console.log('[DATACONNECT] initialized'); } -async function callVertex(app) { - console.log('[VERTEX] start'); - const vertex = getVertexAI(app); - const model = getGenerativeModel(vertex, { - mode: InferenceMode.PREFER_ON_DEVICE - }); - const result = await model.generateContent("What is Roko's Basalisk?"); - console.log(result.response.text()); - console.log('[VERTEX] initialized'); -} - /** * Run smoke tests for all products. * Comment out any products you want to ignore. 
@@ -357,19 +346,18 @@ async function main() { const app = initializeApp(config); setLogLevel('warn'); - callAppCheck(app); - await authLogin(app); - await callStorage(app); - await callFirestore(app); - await callDatabase(app); - await callMessaging(app); - callAnalytics(app); - callPerformance(app); - await callFunctions(app); + // callAppCheck(app); + // await authLogin(app); + // await callStorage(app); + // await callFirestore(app); + // await callDatabase(app); + // await callMessaging(app); + // callAnalytics(app); + // callPerformance(app); + // await callFunctions(app); await callVertexAI(app); - callDataConnect(app); - await authLogout(app); - await callVertex(app); + // callDataConnect(app); + // await authLogout(app); console.log('DONE'); } diff --git a/packages/vertexai/src/methods/chat-session.ts b/packages/vertexai/src/methods/chat-session.ts index 55c7700c156..981ff3b7bf6 100644 --- a/packages/vertexai/src/methods/chat-session.ts +++ b/packages/vertexai/src/methods/chat-session.ts @@ -149,6 +149,7 @@ export class ChatSession { this._apiSettings, this.model, generateContentRequest, + this.chromeAdapter, this.requestOptions ); diff --git a/packages/vertexai/src/methods/chrome-adapter.ts b/packages/vertexai/src/methods/chrome-adapter.ts index ebf6b67dbb0..bbc14f47f70 100644 --- a/packages/vertexai/src/methods/chrome-adapter.ts +++ b/packages/vertexai/src/methods/chrome-adapter.ts @@ -82,18 +82,20 @@ export class ChromeAdapter { const result = await session.prompt(prompt.content); return ChromeAdapter.toResponse(result); } - private static toResponse(text: string): Response { - return { - json: async () => ({ - candidates: [ - { - content: { - parts: [{ text }] - } - } - ] - }) - } as Response; + async generateContentStreamOnDevice( + request: GenerateContentRequest + ): Promise { + const createOptions = this.onDeviceParams || {}; + createOptions.initialPrompts ??= []; + const extractedInitialPrompts = ChromeAdapter.toInitialPrompts( + request.contents + 
); + // Assumes validation asserted there is at least one initial prompt. + const prompt = extractedInitialPrompts.pop()!; + createOptions.initialPrompts.push(...extractedInitialPrompts); + const session = await this.session(createOptions); + const stream = await session.promptStreaming(prompt.content); + return ChromeAdapter.toStreamResponse(stream); } private static isOnDeviceRequest(request: GenerateContentRequest): boolean { // Returns false if the prompt is empty. @@ -157,4 +159,41 @@ export class ChromeAdapter { this.oldSession = newSession; return newSession; } + private static toResponse(text: string): Response { + return { + json: async () => ({ + candidates: [ + { + content: { + parts: [{ text }] + } + } + ] + }) + } as Response; + } + private static toStreamResponse( + stream: ReadableStream + ): Response { + const encoder = new TextEncoder(); + return { + body: stream.pipeThrough( + new TransformStream({ + transform(chunk, controller) { + const json = JSON.stringify({ + candidates: [ + { + content: { + role: 'model', + parts: [{ text: chunk }] + } + } + ] + }); + controller.enqueue(encoder.encode(`data: ${json}\n\n`)); + } + }) + ) + } as Response; + } } diff --git a/packages/vertexai/src/methods/generate-content.ts b/packages/vertexai/src/methods/generate-content.ts index ba7a162aa9c..a6343bcc3a8 100644 --- a/packages/vertexai/src/methods/generate-content.ts +++ b/packages/vertexai/src/methods/generate-content.ts @@ -28,13 +28,13 @@ import { processStream } from '../requests/stream-reader'; import { ApiSettings } from '../types/internal'; import { ChromeAdapter } from './chrome-adapter'; -export async function generateContentStream( +async function generateContentStreamOnCloud( apiSettings: ApiSettings, model: string, params: GenerateContentRequest, requestOptions?: RequestOptions -): Promise { - const response = await makeRequest( +): Promise { + return makeRequest( model, Task.STREAM_GENERATE_CONTENT, apiSettings, @@ -42,6 +42,26 @@ export async 
function generateContentStream( JSON.stringify(params), requestOptions ); +} + +export async function generateContentStream( + apiSettings: ApiSettings, + model: string, + params: GenerateContentRequest, + chromeAdapter: ChromeAdapter, + requestOptions?: RequestOptions +): Promise { + let response; + if (await chromeAdapter.isAvailable(params)) { + response = await chromeAdapter.generateContentStreamOnDevice(params); + } else { + response = await generateContentStreamOnCloud( + apiSettings, + model, + params, + requestOptions + ); + } return processStream(response); } diff --git a/packages/vertexai/src/models/generative-model.ts b/packages/vertexai/src/models/generative-model.ts index f8f699e43eb..bf72ae0be9f 100644 --- a/packages/vertexai/src/models/generative-model.ts +++ b/packages/vertexai/src/models/generative-model.ts @@ -123,6 +123,7 @@ export class GenerativeModel extends VertexAIModel { systemInstruction: this.systemInstruction, ...formattedParams }, + this.chromeAdapter, this.requestOptions ); }