
VinF Hybrid Inference: narrow Chrome input type #8953


Merged: 5 commits, Apr 22, 2025
35 changes: 18 additions & 17 deletions e2e/sample-apps/modular.js
@@ -314,13 +314,14 @@ async function callVertexAI(app) {
console.log('[VERTEXAI] start');
const vertexAI = getVertexAI(app);
const model = getGenerativeModel(vertexAI, {
mode: 'prefer_in_cloud'
mode: 'prefer_on_device'
});
const result = await model.generateContentStream("What is Roko's Basalisk?");
for await (const chunk of result.stream) {
console.log(chunk.text());
}
console.log(`[VERTEXAI] counted tokens: ${result.totalTokens}`);
const singleResult = await model.generateContent([
{ text: 'describe the following:' },
{ text: 'the mojave desert' }
]);
console.log(`Generated text: ${singleResult.response.text()}`);
console.log(`[VERTEXAI] end`);
}

/**
@@ -346,18 +347,18 @@ async function main() {
const app = initializeApp(config);
setLogLevel('warn');

callAppCheck(app);
await authLogin(app);
await callStorage(app);
await callFirestore(app);
await callDatabase(app);
await callMessaging(app);
callAnalytics(app);
callPerformance(app);
await callFunctions(app);
// callAppCheck(app);
// await authLogin(app);
// await callStorage(app);
// await callFirestore(app);
// await callDatabase(app);
// await callMessaging(app);
// callAnalytics(app);
// callPerformance(app);
// await callFunctions(app);
await callVertexAI(app);
callDataConnect(app);
await authLogout(app);
// callDataConnect(app);
// await authLogout(app);
console.log('DONE');
}

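For context, here is a minimal standalone sketch of the hybrid-inference call the sample app now exercises. The getVertexAI, getGenerativeModel, generateContent, and `{ mode: 'prefer_on_device' }` pieces are taken from the diff above; the import paths and the config handling are illustrative assumptions, not part of this change.

```ts
// Sketch only: import paths are assumed; adjust to the SDK packaging in use.
import { initializeApp } from 'firebase/app';
import { getVertexAI, getGenerativeModel } from 'firebase/vertexai';

async function demoHybridInference(config: Record<string, string>): Promise<void> {
  const app = initializeApp(config);
  const vertexAI = getVertexAI(app);
  // 'prefer_on_device' routes prompts to Chrome's built-in model when it is
  // available and falls back to the cloud backend otherwise.
  const model = getGenerativeModel(vertexAI, { mode: 'prefer_on_device' });
  const result = await model.generateContent([
    { text: 'describe the following:' },
    { text: 'the mojave desert' }
  ]);
  console.log(`Generated text: ${result.response.text()}`);
}
```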
51 changes: 11 additions & 40 deletions packages/vertexai/src/methods/chrome-adapter.test.ts
@@ -22,7 +22,8 @@ import { ChromeAdapter } from './chrome-adapter';
import {
Availability,
LanguageModel,
LanguageModelCreateOptions
LanguageModelCreateOptions,
LanguageModelMessageContent
} from '../types/language-model';
import { stub } from 'sinon';
import { GenerateContentRequest } from '../types';
Expand Down Expand Up @@ -105,22 +106,6 @@ describe('ChromeAdapter', () => {
})
).to.be.false;
});
it('returns false if request content has multiple parts', async () => {
const adapter = new ChromeAdapter(
{} as LanguageModel,
'prefer_on_device'
);
expect(
await adapter.isAvailable({
contents: [
{
role: 'user',
parts: [{ text: 'a' }, { text: 'b' }]
}
]
})
).to.be.false;
});
it('returns false if request content has non-text part', async () => {
const adapter = new ChromeAdapter(
{} as LanguageModel,
@@ -281,7 +266,8 @@ describe('ChromeAdapter', () => {
create: () => Promise.resolve({})
} as LanguageModel;
const languageModel = {
prompt: i => Promise.resolve(i)
// eslint-disable-next-line @typescript-eslint/no-unused-vars
prompt: (p: LanguageModelMessageContent[]) => Promise.resolve('')
} as LanguageModel;
const createStub = stub(languageModelProvider, 'create').resolves(
languageModel
@@ -305,13 +291,8 @@ describe('ChromeAdapter', () => {
// Asserts Vertex input type is mapped to Chrome type.
expect(promptStub).to.have.been.calledOnceWith([
{
role: request.contents[0].role,
content: [
{
type: 'text',
content: request.contents[0].parts[0].text
}
]
type: 'text',
content: request.contents[0].parts[0].text
}
]);
// Asserts expected output.
@@ -366,21 +347,16 @@ describe('ChromeAdapter', () => {
// Asserts Vertex input type is mapped to Chrome type.
expect(measureInputUsageStub).to.have.been.calledOnceWith([
{
role: 'user',
content: [
{
type: 'text',
content: inputText
}
]
type: 'text',
content: inputText
}
]);
expect(await response.json()).to.deep.equal({
totalTokens: expectedCount
});
});
});
describe('generateContentStreamOnDevice', () => {
describe('generateContentStream', () => {
it('generates content stream', async () => {
const languageModelProvider = {
create: () => Promise.resolve({})
@@ -413,13 +389,8 @@
expect(createStub).to.have.been.calledOnceWith(onDeviceParams);
expect(promptStub).to.have.been.calledOnceWith([
{
role: request.contents[0].role,
content: [
{
type: 'text',
content: request.contents[0].parts[0].text
}
]
type: 'text',
content: request.contents[0].parts[0].text
}
]);
const actual = await toStringArray(response.body!);
53 changes: 19 additions & 34 deletions packages/vertexai/src/methods/chrome-adapter.ts
@@ -16,19 +16,15 @@
*/

import {
Content,
CountTokensRequest,
GenerateContentRequest,
InferenceMode,
Part,
Role
Part
} from '../types';
import {
Availability,
LanguageModel,
LanguageModelCreateOptions,
LanguageModelMessage,
LanguageModelMessageRole,
LanguageModelMessageContent
} from '../types/language-model';

@@ -100,8 +96,12 @@ export class ChromeAdapter {
// TODO: normalize on-device params during construction.
this.onDeviceParams || {}
);
const messages = ChromeAdapter.toLanguageModelMessages(request.contents);
const text = await session.prompt(messages);
// TODO: support multiple content objects when Chrome supports
// sequence<LanguageModelMessage>
const contents = request.contents[0].parts.map(
ChromeAdapter.toLanguageModelMessageContent
);
const text = await session.prompt(contents);
return ChromeAdapter.toResponse(text);
}

@@ -120,8 +120,12 @@
// TODO: normalize on-device params during construction.
this.onDeviceParams || {}
);
const messages = ChromeAdapter.toLanguageModelMessages(request.contents);
const stream = await session.promptStreaming(messages);
// TODO: support multiple content objects when Chrome supports
// sequence<LanguageModelMessage>
const contents = request.contents[0].parts.map(
ChromeAdapter.toLanguageModelMessageContent
);
const stream = await session.promptStreaming(contents);
return ChromeAdapter.toStreamResponse(stream);
}

@@ -131,8 +135,12 @@
// TODO: normalize on-device params during construction.
this.onDeviceParams || {}
);
const messages = ChromeAdapter.toLanguageModelMessages(request.contents);
const tokenCount = await session.measureInputUsage(messages);
// TODO: support multiple content objects when Chrome supports
// sequence<LanguageModelMessage>
const contents = request.contents[0].parts.map(
ChromeAdapter.toLanguageModelMessageContent
);
const tokenCount = await session.measureInputUsage(contents);
return {
json: async () => ({
totalTokens: tokenCount
@@ -155,10 +163,6 @@
return false;
}

if (content.parts.length > 1) {
return false;
}

if (!content.parts[0].text) {
return false;
}
@@ -188,25 +192,6 @@
});
}

/**
* Converts a Vertex role string to a Chrome role string.
*/
private static toOnDeviceRole(role: Role): LanguageModelMessageRole {
return role === 'model' ? 'assistant' : 'user';
}

/**
* Converts a Vertex Content object to a Chrome LanguageModelMessage object.
*/
private static toLanguageModelMessages(
contents: Content[]
): LanguageModelMessage[] {
return contents.map(c => ({
role: ChromeAdapter.toOnDeviceRole(c.role),
content: c.parts.map(ChromeAdapter.toLanguageModelMessageContent)
}));
}

/**
* Converts a Vertex Part object to a Chrome LanguageModelMessageContent object.
*/
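To summarize the shape change in this file, below is a hedged sketch of the narrowed Vertex-to-Chrome mapping. The interfaces are pared-down stand-ins for the real Part, Content, and LanguageModelMessageContent types, and toPrompt is an illustrative helper rather than the adapter's actual API; only the flat-array result mirrors what ChromeAdapter now passes to session.prompt, promptStreaming, and measureInputUsage.

```ts
// Pared-down stand-ins for the real types; for illustration only.
interface Part { text?: string }
interface Content { role: string; parts: Part[] }
interface LanguageModelMessageContent {
  type: 'text' | 'image' | 'audio';
  content: unknown;
}

// Mirrors ChromeAdapter.toLanguageModelMessageContent: availability checks
// elsewhere ensure every part is text-only before this runs.
function toLanguageModelMessageContent(part: Part): LanguageModelMessageContent {
  return { type: 'text', content: part.text };
}

// Mirrors the narrowed prompt construction: only contents[0] is forwarded
// until Chrome accepts sequence<LanguageModelMessage>, and roles are dropped.
function toPrompt(contents: Content[]): LanguageModelMessageContent[] {
  return contents[0].parts.map(toLanguageModelMessageContent);
}

// Example: a single-turn request becomes a flat array of typed content entries.
const prompt = toPrompt([{ role: 'user', parts: [{ text: 'a' }, { text: 'b' }] }]);
// prompt -> [{ type: 'text', content: 'a' }, { type: 'text', content: 'b' }]
```

This flat shape is also what the updated tests assert promptStub and measureInputUsageStub receive, in place of the earlier role-tagged LanguageModelMessage objects.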
10 changes: 4 additions & 6 deletions packages/vertexai/src/types/language-model.ts
@@ -56,14 +56,12 @@ interface LanguageModelExpectedInput {
type: LanguageModelMessageType;
languages?: string[];
}
export type LanguageModelPrompt =
| LanguageModelMessage[]
| LanguageModelMessageShorthand[]
| string;
// TODO: revert to type from Prompt API explainer once it's supported.
export type LanguageModelPrompt = LanguageModelMessageContent[];
type LanguageModelInitialPrompts =
| LanguageModelMessage[]
| LanguageModelMessageShorthand[];
export interface LanguageModelMessage {
interface LanguageModelMessage {
role: LanguageModelMessageRole;
content: LanguageModelMessageContent[];
}
@@ -75,7 +73,7 @@ export interface LanguageModelMessageContent {
type: LanguageModelMessageType;
content: LanguageModelMessageContentValue;
}
export type LanguageModelMessageRole = 'system' | 'user' | 'assistant';
type LanguageModelMessageRole = 'system' | 'user' | 'assistant';
type LanguageModelMessageType = 'text' | 'image' | 'audio';
type LanguageModelMessageContentValue =
| ImageBitmapSource
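Finally, a small usage sketch of the narrowed LanguageModelPrompt type, assuming a session object that follows the Prompt API surface declared in this file; the `declare const session` below is a placeholder for a created LanguageModel session, not the real binding.

```ts
// Local copies of the narrowed types, for illustration only.
type LanguageModelMessageType = 'text' | 'image' | 'audio';
interface LanguageModelMessageContent {
  type: LanguageModelMessageType;
  content: unknown;
}
type LanguageModelPrompt = LanguageModelMessageContent[];

// Placeholder session; in the adapter this comes from LanguageModel.create().
declare const session: {
  prompt(input: LanguageModelPrompt): Promise<string>;
  measureInputUsage(input: LanguageModelPrompt): Promise<number>;
};

async function example(): Promise<void> {
  // Callers now pass a flat array of typed content entries rather than
  // role-tagged messages.
  const input: LanguageModelPrompt = [
    { type: 'text', content: 'describe the following:' },
    { type: 'text', content: 'the mojave desert' }
  ];
  const text = await session.prompt(input);
  const totalTokens = await session.measureInputUsage(input);
  console.log(text, totalTokens);
}
```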