Skip to content

Commit ac117e5

Browse files
erikeldridgegsiddh
authored andcommitted
Test content generation
1 parent f62a98f commit ac117e5

File tree

3 files changed

+89
-49
lines changed

3 files changed

+89
-49
lines changed

packages/vertexai/src/methods/chrome-adapter.test.ts

+61-3
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,14 @@ import { expect, use } from 'chai';
1919
import sinonChai from 'sinon-chai';
2020
import chaiAsPromised from 'chai-as-promised';
2121
import { ChromeAdapter } from './chrome-adapter';
22-
import { Availability, LanguageModel } from '../types/language-model';
22+
import {
23+
Availability,
24+
LanguageModel,
25+
LanguageModelCreateOptions
26+
} from '../types/language-model';
2327
import { stub } from 'sinon';
2428
import * as util from '@firebase/util';
29+
import { GenerateContentRequest } from '../types';
2530

2631
use(sinonChai);
2732
use(chaiAsPromised);
@@ -191,16 +196,18 @@ describe('ChromeAdapter', () => {
191196
const createStub = stub(languageModelProvider, 'create').resolves(
192197
{} as LanguageModel
193198
);
199+
const onDeviceParams = {} as LanguageModelCreateOptions;
194200
const adapter = new ChromeAdapter(
195201
languageModelProvider,
196-
'prefer_on_device'
202+
'prefer_on_device',
203+
onDeviceParams
197204
);
198205
expect(
199206
await adapter.isAvailable({
200207
contents: [{ role: 'user', parts: [{ text: 'hi' }] }]
201208
})
202209
).to.be.false;
203-
expect(createStub).to.have.been.calledOnce;
210+
expect(createStub).to.have.been.calledOnceWith(onDeviceParams);
204211
});
205212
it('avoids redundant downloads', async () => {
206213
const languageModelProvider = {
@@ -260,4 +267,55 @@ describe('ChromeAdapter', () => {
260267
).to.be.false;
261268
});
262269
});
270+
describe('generateContentOnDevice', () => {
271+
it('generates content', async () => {
272+
const languageModelProvider = {
273+
create: () => Promise.resolve({})
274+
} as LanguageModel;
275+
const languageModel = {
276+
prompt: i => Promise.resolve(i)
277+
} as LanguageModel;
278+
const createStub = stub(languageModelProvider, 'create').resolves(
279+
languageModel
280+
);
281+
const promptOutput = 'hi';
282+
const promptStub = stub(languageModel, 'prompt').resolves(promptOutput);
283+
const onDeviceParams = {
284+
systemPrompt: 'be yourself'
285+
} as LanguageModelCreateOptions;
286+
const adapter = new ChromeAdapter(
287+
languageModelProvider,
288+
'prefer_on_device',
289+
onDeviceParams
290+
);
291+
const request = {
292+
contents: [{ role: 'user', parts: [{ text: 'anything' }] }]
293+
} as GenerateContentRequest;
294+
const response = await adapter.generateContentOnDevice(request);
295+
// Asserts initialization params are proxied.
296+
expect(createStub).to.have.been.calledOnceWith(onDeviceParams);
297+
// Asserts Vertex input type is mapped to Chrome type.
298+
expect(promptStub).to.have.been.calledOnceWith([
299+
{
300+
role: request.contents[0].role,
301+
content: [
302+
{
303+
type: 'text',
304+
content: request.contents[0].parts[0].text
305+
}
306+
]
307+
}
308+
]);
309+
// Asserts expected output.
310+
expect(await response.json()).to.deep.equal({
311+
candidates: [
312+
{
313+
content: {
314+
parts: [{ text: promptOutput }]
315+
}
316+
}
317+
]
318+
});
319+
});
320+
});
263321
});

packages/vertexai/src/methods/chrome-adapter.ts

+24-42
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,9 @@ import {
2727
Availability,
2828
LanguageModel,
2929
LanguageModelCreateOptions,
30+
LanguageModelMessage,
3031
LanguageModelMessageRole,
31-
LanguageModelMessageShorthand
32+
LanguageModelMessageContent
3233
} from '../types/language-model';
3334
import { isChrome } from '@firebase/util';
3435

@@ -85,17 +86,13 @@ export class ChromeAdapter {
8586
async generateContentOnDevice(
8687
request: GenerateContentRequest
8788
): Promise<Response> {
88-
const initialPrompts = ChromeAdapter.toInitialPrompts(request.contents);
89-
// Assumes validation asserted there is at least one initial prompt.
90-
const prompt = initialPrompts.pop()!;
91-
const systemPrompt = ChromeAdapter.toSystemPrompt(
92-
request.systemInstruction
89+
const session = await this.session(
90+
// TODO: normalize on-device params during construction.
91+
this.onDeviceParams || {}
9392
);
94-
const session = await this.session({
95-
initialPrompts,
96-
systemPrompt
97-
});
98-
const text = await session.prompt(prompt.content);
93+
const messages = ChromeAdapter.toLanguageModelMessages(request.contents);
94+
const text = await session.prompt(messages);
95+
console.log(text);
9996
return {
10097
json: () =>
10198
Promise.resolve({
@@ -162,45 +159,30 @@ export class ChromeAdapter {
162159
this.isDownloading = false;
163160
});
164161
}
165-
private static toSystemPrompt(
166-
prompt: string | Content | Part | undefined
167-
): string | undefined {
168-
if (!prompt) {
169-
return undefined;
170-
}
171-
172-
if (typeof prompt === 'string') {
173-
return prompt;
174-
}
175-
176-
const systemContent = prompt as Content;
177-
if (
178-
systemContent.parts &&
179-
systemContent.parts[0] &&
180-
systemContent.parts[0].text
181-
) {
182-
return systemContent.parts[0].text;
183-
}
184-
185-
const systemPart = prompt as Part;
186-
if (systemPart.text) {
187-
return systemPart.text;
188-
}
189-
190-
return undefined;
191-
}
192162
private static toOnDeviceRole(role: Role): LanguageModelMessageRole {
193163
return role === 'model' ? 'assistant' : 'user';
194164
}
195-
private static toInitialPrompts(
165+
private static toLanguageModelMessages(
196166
contents: Content[]
197-
): LanguageModelMessageShorthand[] {
167+
): LanguageModelMessage[] {
198168
return contents.map(c => ({
199169
role: ChromeAdapter.toOnDeviceRole(c.role),
200-
// Assumes contents have been verified to contain only a single TextPart.
201-
content: c.parts[0].text!
170+
content: c.parts.map(ChromeAdapter.toLanguageModelMessageContent)
202171
}));
203172
}
173+
private static toLanguageModelMessageContent(
174+
part: Part
175+
): LanguageModelMessageContent {
176+
if (part.text) {
177+
return {
178+
type: 'text',
179+
content: part.text
180+
};
181+
}
182+
// Assumes contents have been verified to contain only a single TextPart.
183+
// TODO: support other input types
184+
throw new Error('Not yet implemented');
185+
}
204186
private async session(
205187
opts: LanguageModelCreateOptions
206188
): Promise<LanguageModel> {

packages/vertexai/src/types/language-model.ts

+4-4
Original file line numberDiff line numberDiff line change
@@ -56,22 +56,22 @@ interface LanguageModelExpectedInput {
5656
type: LanguageModelMessageType;
5757
languages?: string[];
5858
}
59-
type LanguageModelPrompt =
59+
export type LanguageModelPrompt =
6060
| LanguageModelMessage[]
6161
| LanguageModelMessageShorthand[]
6262
| string;
6363
type LanguageModelInitialPrompts =
6464
| LanguageModelMessage[]
6565
| LanguageModelMessageShorthand[];
66-
interface LanguageModelMessage {
66+
export interface LanguageModelMessage {
6767
role: LanguageModelMessageRole;
6868
content: LanguageModelMessageContent[];
6969
}
70-
export interface LanguageModelMessageShorthand {
70+
interface LanguageModelMessageShorthand {
7171
role: LanguageModelMessageRole;
7272
content: string;
7373
}
74-
interface LanguageModelMessageContent {
74+
export interface LanguageModelMessageContent {
7575
type: LanguageModelMessageType;
7676
content: LanguageModelMessageContentValue;
7777
}

0 commit comments

Comments
 (0)