Skip to content

VinF Hybrid Inference #4: ChromeAdapter in stream methods (rebased) #8949

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 23 commits into from
Apr 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
a3d869b
Define HybridParams
erikeldridge Mar 24, 2025
14eee16
Copy over most types from @types package
erikeldridge Mar 24, 2025
b242749
Trim unused AI types
erikeldridge Mar 25, 2025
5e97457
Assert HybridParams sets the model name
erikeldridge Mar 25, 2025
1fe8a08
Use dom-chromium-ai package directly
erikeldridge Mar 26, 2025
869fee7
Define ChromeAdapter class
erikeldridge Mar 25, 2025
ff31b42
Implement ChromeAdapter class
erikeldridge Apr 2, 2025
1e487d5
Integrate with e2e test app
erikeldridge Apr 2, 2025
8307fe5
Parameterize default in-cloud model name
erikeldridge Mar 28, 2025
291c53b
Use type for inference mode and update docs
erikeldridge Apr 2, 2025
fe2bebc
Remove stray ai.ts
erikeldridge Apr 3, 2025
d4286d6
Run yarn format
erikeldridge Apr 3, 2025
b898cd0
Test request-based availability checks
erikeldridge Apr 4, 2025
2fb2795
Remove request.systemInstruction validation
erikeldridge Apr 4, 2025
ef893c9
Integrate chrome adapter into stream methods
erikeldridge Apr 7, 2025
4c37859
Refactor to emulate Vertex response
erikeldridge Apr 8, 2025
eb25fec
Group response formatting methods together
erikeldridge Apr 8, 2025
b8d849c
Run docgen
erikeldridge Apr 18, 2025
1b9c98d
Re-remove isChrome
erikeldridge Apr 18, 2025
5092bd8
Re-remove dom-chromium-ai
erikeldridge Apr 18, 2025
025b786
Unit test stream method
erikeldridge Apr 18, 2025
34c658e
Remove redundant ondevice suffix
erikeldridge Apr 18, 2025
7af0f8d
Merge remote-tracking branch 'public/vaihi-exp' into erikeldridge-ver…
erikeldridge Apr 19, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions e2e/sample-apps/modular.js
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,7 @@ import {
onValue,
off
} from 'firebase/database';
import {
getGenerativeModel,
getVertexAI,
InferenceMode,
VertexAI
} from 'firebase/vertexai';
import { getGenerativeModel, getVertexAI } from 'firebase/vertexai';
import { getDataConnect, DataConnect } from 'firebase/data-connect';

/**
Expand Down Expand Up @@ -318,8 +313,13 @@ function callPerformance(app) {
/**
 * Exercises the VertexAI streaming API from the sample app.
 * @param app an initialized Firebase app
 */
async function callVertexAI(app) {
  console.log('[VERTEXAI] start');
  const vertexAI = getVertexAI(app);
  const model = getGenerativeModel(vertexAI, {
    mode: 'prefer_in_cloud'
  });
  const result = await model.generateContentStream("What is Roko's Basalisk?");
  for await (const chunk of result.stream) {
    console.log(chunk.text());
  }
  // The streaming result has no totalTokens field; the previous
  // "counted tokens" log referenced a property that is now undefined.
  console.log('[VERTEXAI] streaming done');
}

Expand Down
1 change: 1 addition & 0 deletions packages/vertexai/src/methods/chat-session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ export class ChatSession {
this._apiSettings,
this.model,
generateContentRequest,
this.chromeAdapter,
this.requestOptions
);

Expand Down
71 changes: 70 additions & 1 deletion packages/vertexai/src/methods/chrome-adapter.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,25 @@ import { GenerateContentRequest } from '../types';
use(sinonChai);
use(chaiAsPromised);

/**
* Converts the ReadableStream from response.body to an array of strings.
*/
/**
 * Converts the ReadableStream from response.body to an array of strings.
 * @param stream the byte stream to drain
 * @returns one decoded string per chunk read from the stream
 */
async function toStringArray(
  stream: ReadableStream<Uint8Array>
): Promise<string[]> {
  const decoder = new TextDecoder();
  const chunks: string[] = [];
  const reader = stream.getReader();
  try {
    while (true) {
      const { done, value } = await reader.read();
      if (done) {
        break;
      }
      // stream: true carries a partial multi-byte sequence over to the next
      // chunk instead of emitting U+FFFD replacement characters.
      chunks.push(decoder.decode(value, { stream: true }));
    }
    // Flush any buffered partial sequence left at end of stream.
    const tail = decoder.decode();
    if (tail) {
      chunks.push(tail);
    }
  } finally {
    reader.releaseLock();
  }
  return chunks;
}

describe('ChromeAdapter', () => {
describe('isAvailable', () => {
it('returns false if mode is only cloud', async () => {
Expand Down Expand Up @@ -280,7 +299,7 @@ describe('ChromeAdapter', () => {
const request = {
contents: [{ role: 'user', parts: [{ text: 'anything' }] }]
} as GenerateContentRequest;
const response = await adapter.generateContentOnDevice(request);
const response = await adapter.generateContent(request);
// Asserts initialization params are proxied.
expect(createStub).to.have.been.calledOnceWith(onDeviceParams);
// Asserts Vertex input type is mapped to Chrome type.
Expand Down Expand Up @@ -325,6 +344,7 @@ describe('ChromeAdapter', () => {
const createStub = stub(languageModelProvider, 'create').resolves(
languageModel
);

// overrides impl with stub method
const measureInputUsageStub = stub(
languageModel,
Expand All @@ -336,6 +356,7 @@ describe('ChromeAdapter', () => {
'prefer_on_device',
onDeviceParams
);

const countTokenRequest = {
contents: [{ role: 'user', parts: [{ text: inputText }] }]
} as GenerateContentRequest;
Expand All @@ -359,4 +380,52 @@ describe('ChromeAdapter', () => {
});
});
});
describe('generateContentStream', () => {
  it('generates content stream', async () => {
    const languageModelProvider = {
      create: () => Promise.resolve({})
    } as LanguageModel;
    const languageModel = {
      promptStreaming: _i => new ReadableStream()
    } as LanguageModel;
    const createStub = stub(languageModelProvider, 'create').resolves(
      languageModel
    );
    const part = 'hi';
    // promptStreaming yields strings; enqueuing an array here would make the
    // adapter serialize the candidate text as ["hi"] instead of "hi".
    const promptStub = stub(languageModel, 'promptStreaming').returns(
      new ReadableStream({
        start(controller) {
          controller.enqueue(part);
          controller.close();
        }
      })
    );
    const onDeviceParams = {} as LanguageModelCreateOptions;
    const adapter = new ChromeAdapter(
      languageModelProvider,
      'prefer_on_device',
      onDeviceParams
    );
    const request = {
      contents: [{ role: 'user', parts: [{ text: 'anything' }] }]
    } as GenerateContentRequest;
    const response = await adapter.generateContentStream(request);
    // Asserts initialization params are proxied to session creation.
    expect(createStub).to.have.been.calledOnceWith(onDeviceParams);
    // Asserts Vertex input type is mapped to the Chrome message type.
    expect(promptStub).to.have.been.calledOnceWith([
      {
        role: request.contents[0].role,
        content: [
          {
            type: 'text',
            content: request.contents[0].parts[0].text
          }
        ]
      }
    ]);
    // Each chunk becomes one SSE data line whose candidate text is a string.
    const actual = await toStringArray(response.body!);
    expect(actual).to.deep.equal([
      `data: {"candidates":[{"content":{"role":"model","parts":[{"text":"${part}"}]}}]}\n\n`
    ]);
  });
});
});
78 changes: 64 additions & 14 deletions packages/vertexai/src/methods/chrome-adapter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -95,27 +95,34 @@ export class ChromeAdapter {
* @param request a standard Vertex {@link GenerateContentRequest}
* @returns {@link Response}, so we can reuse common response formatting.
*/
async generateContentOnDevice(
async generateContent(request: GenerateContentRequest): Promise<Response> {
const session = await this.createSession(
// TODO: normalize on-device params during construction.
this.onDeviceParams || {}
);
const messages = ChromeAdapter.toLanguageModelMessages(request.contents);
const text = await session.prompt(messages);
return ChromeAdapter.toResponse(text);
}

/**
 * Generates content stream on device.
 *
 * <p>This is comparable to {@link GenerativeModel.generateContentStream} for generating content in
 * Cloud.</p>
 * @param request a standard Vertex {@link GenerateContentRequest}
 * @returns {@link Response}, so we can reuse common response formatting.
 */
async generateContentStream(
  request: GenerateContentRequest
): Promise<Response> {
  const session = await this.createSession(
    // TODO: normalize on-device params during construction.
    this.onDeviceParams || {}
  );
  const messages = ChromeAdapter.toLanguageModelMessages(request.contents);
  // promptStreaming yields raw strings; wrap them as Vertex SSE so the shared
  // stream-reader can process on-device and in-cloud responses identically.
  const stream = await session.promptStreaming(messages);
  return ChromeAdapter.toStreamResponse(stream);
}

async countTokens(request: CountTokensRequest): Promise<Response> {
Expand Down Expand Up @@ -240,4 +247,47 @@ export class ChromeAdapter {
this.oldSession = newSession;
return newSession;
}

/**
 * Formats string returned by Chrome as a {@link Response} returned by Vertex.
 */
private static toResponse(text: string): Response {
  // Only the json() accessor is consumed by the shared response handling,
  // so a minimal Response-shaped object is sufficient.
  return {
    json: () =>
      Promise.resolve({
        candidates: [{ content: { parts: [{ text }] } }]
      })
  } as Response;
}

/**
 * Formats string stream returned by Chrome as SSE returned by Vertex.
 */
private static toStreamResponse(stream: ReadableStream<string>): Response {
  const encoder = new TextEncoder();
  // Wraps one on-device text chunk as a single Vertex-style SSE data line.
  const toSseBytes = (text: string): Uint8Array => {
    const json = JSON.stringify({
      candidates: [
        {
          content: {
            role: 'model',
            parts: [{ text }]
          }
        }
      ]
    });
    return encoder.encode(`data: ${json}\n\n`);
  };
  const transformer = new TransformStream<string, Uint8Array>({
    transform(chunk, controller) {
      controller.enqueue(toSseBytes(chunk));
    }
  });
  return { body: stream.pipeThrough(transformer) } as Response;
}
}
8 changes: 4 additions & 4 deletions packages/vertexai/src/methods/generate-content.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -308,17 +308,17 @@ describe('generateContent()', () => {
);
expect(mockFetch).to.be.called;
});
// TODO: define a similar test for generateContentStream
it('on-device', async () => {
const chromeAdapter = new ChromeAdapter();
const isAvailableStub = stub(chromeAdapter, 'isAvailable').resolves(true);
const mockResponse = getMockResponse(
'vertexAI',
'unary-success-basic-reply-short.json'
);
const generateContentStub = stub(
chromeAdapter,
'generateContentOnDevice'
).resolves(mockResponse as Response);
const generateContentStub = stub(chromeAdapter, 'generateContent').resolves(
mockResponse as Response
);
const result = await generateContent(
fakeApiSettings,
'model',
Expand Down
28 changes: 24 additions & 4 deletions packages/vertexai/src/methods/generate-content.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,20 +28,40 @@ import { processStream } from '../requests/stream-reader';
import { ApiSettings } from '../types/internal';
import { ChromeAdapter } from './chrome-adapter';

export async function generateContentStream(
async function generateContentStreamOnCloud(
apiSettings: ApiSettings,
model: string,
params: GenerateContentRequest,
requestOptions?: RequestOptions
): Promise<GenerateContentStreamResult> {
const response = await makeRequest(
): Promise<Response> {
return makeRequest(
model,
Task.STREAM_GENERATE_CONTENT,
apiSettings,
/* stream */ true,
JSON.stringify(params),
requestOptions
);
}

/**
 * Streams generated content, using on-device inference when the
 * ChromeAdapter reports it can serve this request and falling back to the
 * cloud endpoint otherwise. Both paths produce a Response consumable by
 * the shared stream reader.
 */
export async function generateContentStream(
  apiSettings: ApiSettings,
  model: string,
  params: GenerateContentRequest,
  chromeAdapter: ChromeAdapter,
  requestOptions?: RequestOptions
): Promise<GenerateContentStreamResult> {
  const onDevice = await chromeAdapter.isAvailable(params);
  const response = onDevice
    ? await chromeAdapter.generateContentStream(params)
    : await generateContentStreamOnCloud(
        apiSettings,
        model,
        params,
        requestOptions
      );
  return processStream(response);
}

Expand Down Expand Up @@ -70,7 +90,7 @@ export async function generateContent(
): Promise<GenerateContentResult> {
let response;
if (await chromeAdapter.isAvailable(params)) {
response = await chromeAdapter.generateContentOnDevice(params);
response = await chromeAdapter.generateContent(params);
} else {
response = await generateContentOnCloud(
apiSettings,
Expand Down
1 change: 1 addition & 0 deletions packages/vertexai/src/models/generative-model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ export class GenerativeModel extends VertexAIModel {
systemInstruction: this.systemInstruction,
...formattedParams
},
this.chromeAdapter,
this.requestOptions
);
}
Expand Down
Loading