Skip to content

Commit d6cad7c

Browse files
committed
Fixed an issue where chat screens sometimes failed to refresh while the LLM was responding.
The problem was that mutations to `state.transcript` did not invalidate the cached `hashCode`, so the new state compared equal to the old one and no rebuild was triggered.
1 parent b1005dc commit d6cad7c

File tree

9 files changed

+215
-436
lines changed

9 files changed

+215
-436
lines changed

packages/mediapipe-task-genai/example/lib/bloc.dart

+71-69
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ class TranscriptBloc extends Bloc<TranscriptEvent, TranscriptState> {
3030
downloadModel: (e) => _downloadModel(e, emit),
3131
deleteModel: (e) => _deleteModel(e, emit),
3232
setPercentDownloaded: (e) => _setPercentDownloaded(e, emit),
33-
setTranscript: (e) => _setTranscript(e, emit),
3433
updateTemperature: (e) => _updateTemperature(e, emit),
3534
updateTopK: (e) => _updateTopK(e, emit),
3635
updateMaxTokens: (e) => _updateMaxTokens(e, emit),
@@ -39,9 +38,6 @@ class TranscriptBloc extends Bloc<TranscriptEvent, TranscriptState> {
3938
);
4039
},
4140
);
42-
for (final model in LlmModel.values) {
43-
transcript[model] = <ChatMessage>[];
44-
}
4541
final cacheDirFuture = path_provider.getApplicationCacheDirectory();
4642
modelProvider.ready.then((_) async {
4743
cacheDir = (await cacheDirFuture).absolute.path;
@@ -54,13 +50,6 @@ class TranscriptBloc extends Bloc<TranscriptEvent, TranscriptState> {
5450
/// Utility which knows how to download and store the model file.
5551
ModelLocationProvider modelProvider;
5652

57-
/// Container for one message log per [LlmModel]. Mutable data type.
58-
/// Initialized to contain empty lists for each [LlmModel]. At any given time,
59-
/// one of these lists should be mirrored in [state.transcript]. Which one
60-
/// depends on the [model] value attached to the most recent model-specific
61-
/// event.
62-
final Map<LlmModel, List<ChatMessage>> transcript = {};
63-
6453
/// Constructor for the inference engine.
6554
final LlmInferenceEngine Function(LlmInferenceOptions) engineBuilder;
6655

@@ -88,10 +77,6 @@ class TranscriptBloc extends Bloc<TranscriptEvent, TranscriptState> {
8877
}
8978
}
9079

91-
void _setTranscript(SetTranscript event, Emit emit) {
92-
emit(state.copyWith(transcript: transcript[event.model]!));
93-
}
94-
9580
ModelInfo _getInfo(LlmModel model) {
9681
final existingPath = modelProvider.pathFor(model);
9782
return existingPath != null
@@ -125,9 +110,6 @@ class TranscriptBloc extends Bloc<TranscriptEvent, TranscriptState> {
125110
/// However, if you have already made other modifications to a state object
126111
/// (via copying), consider passing that copied value to the [overrideState]
127112
/// parameter.
128-
///
129-
/// See also:
130-
/// * [_emitTranscript] for the transcript version of this method.
131113
void _emitNewModelInfo(
132114
LlmModel model,
133115
ModelInfo info,
@@ -220,6 +202,7 @@ class TranscriptBloc extends Bloc<TranscriptEvent, TranscriptState> {
220202
);
221203
await Future.delayed(const Duration(milliseconds: 100));
222204
emit(state.copyWith(error: null));
205+
return;
223206
}
224207

225208
if (downloadStream != null) {
@@ -239,18 +222,12 @@ class TranscriptBloc extends Bloc<TranscriptEvent, TranscriptState> {
239222
add(CheckForModel(modelToDelete));
240223
}
241224

242-
void _completeResponse(CompleteResponse e, emit) {
243-
transcript[e.model]!.last = transcript[e.model]!.last.complete();
244-
final newState = state.copyWith(
245-
isLlmTyping: false,
246-
transcript: transcript[e.model]!,
247-
);
248-
_emitTranscript(emit, e.model, newState);
249-
}
225+
void _completeResponse(CompleteResponse e, emit) =>
226+
emit(state.copyWith(isLlmTyping: false).completeMessage(e.model));
250227

251228
Future<void> _addMessage(AddMessage event, Emit emit) async {
252229
// Value equality will detect that this state is "new"
253-
addMessageToTranscript(event, emit);
230+
emit(state.addMessage(event.message, event.model));
254231
if (state.engine == null) {
255232
await _queueMessageForLlm(event, emit);
256233
} else {
@@ -269,16 +246,20 @@ class TranscriptBloc extends Bloc<TranscriptEvent, TranscriptState> {
269246
AddMessage event,
270247
Emit emit,
271248
) async {
272-
final formattedChatHistory = _formatChatHistoryForLlm(state.transcript);
249+
final formattedChatHistory =
250+
_formatChatHistoryForLlm(state.transcript[event.model]!);
273251
final responseStream = state.engine!.generateResponse(formattedChatHistory);
274252

275253
// Add a blank response for the LLM into which we can write its answer.
276254
// Create a synthetic event just to pass to this helper method, but don't
277255
// route it through the `add` method.
278-
addMessageToTranscript(AddMessage(ChatMessage.llm(''), event.model), emit);
279-
emit(state.copyWith(isLlmTyping: true));
256+
emit(
257+
state
258+
.copyWith(isLlmTyping: true)
259+
.addMessage(ChatMessage.llm(''), event.model),
260+
);
280261

281-
final messageIndex = transcript[event.model]!.length - 1;
262+
final messageIndex = state.transcript[event.model]!.length - 1;
282263
bool first = true;
283264
await for (final String chunk in responseStream) {
284265
add(
@@ -297,46 +278,22 @@ class TranscriptBloc extends Bloc<TranscriptEvent, TranscriptState> {
297278
chunk: '',
298279
model: event.model,
299280
index: messageIndex,
300-
first: false,
281+
first: first,
301282
last: true,
302283
),
303284
);
304285
add(CompleteResponse(event.model));
305286
}
306287

307-
void _extendMessage(ExtendMessage event, Emit emit) {
308-
final message = transcript[event.model]![event.index];
309-
transcript[event.model]![event.index] = message.copyWith(
310-
body: '${message.body}${event.chunk}'.sanitize(event.first, event.last),
311-
);
312-
_emitTranscript(emit, event.model);
313-
}
314-
315-
void addMessageToTranscript(AddMessage event, Emit emit) {
316-
transcript[event.model]!.add(event.message);
317-
_emitTranscript(emit, event.model);
318-
}
319-
320-
/// Re-emits the current state with an updated copy of the transcript from the
321-
/// Bloc's own storage. However, if you have already made other modifications
322-
/// to a state object (via copying), consider passing that copied value to
323-
/// the [overrideState] parameter.
324-
///
325-
/// See also:
326-
/// * [_emitNewModelInfo] for the [ModelInfo] version of this method.
327-
void _emitTranscript(
328-
Emit emit,
329-
LlmModel model, [
330-
TranscriptState? overrideState,
331-
]) {
332-
// It would be nice to not copy the whole list every time.
333-
final newState = (overrideState ?? state).copyWith(
334-
transcript: List<ChatMessage>.from(
335-
transcript[model]!,
336-
),
337-
);
338-
emit(newState);
339-
}
288+
void _extendMessage(ExtendMessage event, Emit emit) => emit(
289+
state.extendMessage(
290+
event.chunk,
291+
model: event.model,
292+
index: event.index,
293+
first: event.first,
294+
last: event.last,
295+
),
296+
);
340297

341298
static const _begin = '<begin_transmission>';
342299
static const _end = '<end_transmission>';
@@ -368,13 +325,13 @@ class TranscriptBloc extends Bloc<TranscriptEvent, TranscriptState> {
368325
}
369326
}
370327

371-
@Freezed()
328+
@Freezed(makeCollectionsUnmodifiable: false)
372329
class TranscriptState with _$TranscriptState {
373330
TranscriptState._();
374331
factory TranscriptState({
375332
/// Log of messages for the [selectedModel]. Other models may have other
376333
/// message logs found on the [TranscriptBloc].
377-
@Default(<ChatMessage>[]) List<ChatMessage> transcript,
334+
required Map<LlmModel, List<ChatMessage>> transcript,
378335

379336
/// True only after the [ModelLocationProvider] has sorted out the initial
380337
/// state.
@@ -407,17 +364,63 @@ class TranscriptState with _$TranscriptState {
407364

408365
factory TranscriptState.initial([int? seed]) {
409366
final modelInfoMap = <LlmModel, ModelInfo>{};
367+
final transcript = <LlmModel, List<ChatMessage>>{};
410368
for (final model in LlmModel.values) {
411369
modelInfoMap[model] = const ModelInfo();
370+
transcript[model] = <ChatMessage>[];
412371
}
413372
return TranscriptState(
414-
transcript: <ChatMessage>[],
373+
transcript: transcript,
415374
randomSeed: seed ?? Random().nextInt(1 << 32),
416375
modelInfoMap: modelInfoMap,
417376
);
418377
}
419378

420379
int get sequenceBatchSize => 20;
380+
381+
Map<LlmModel, List<ChatMessage>> _copyTranscript() {
382+
final newTranscript = <LlmModel, List<ChatMessage>>{};
383+
for (final key in transcript.keys) {
384+
newTranscript[key] = List<ChatMessage>.from(transcript[key]!);
385+
}
386+
return newTranscript;
387+
}
388+
389+
TranscriptState addMessage(ChatMessage message, LlmModel model) {
390+
final newTranscript = _copyTranscript();
391+
newTranscript[model]!.add(message);
392+
return copyWith(transcript: newTranscript);
393+
}
394+
395+
TranscriptState extendMessage(
396+
String chunk, {
397+
required int index,
398+
required LlmModel model,
399+
required bool first,
400+
required bool last,
401+
}) {
402+
final newTranscript = _copyTranscript();
403+
assert(() {
404+
if (newTranscript[model]!.length < index + 1) {
405+
throw Exception('Tried to add to index $index, but length is '
406+
'only ${newTranscript[model]!.length} for $model');
407+
}
408+
return true;
409+
}());
410+
final oldMessage = newTranscript[model]![index];
411+
final newMessage = oldMessage.copyWith(
412+
body: '${oldMessage.body}$chunk'.sanitize(first, last),
413+
);
414+
newTranscript[model]![index] = newMessage;
415+
return copyWith(transcript: newTranscript);
416+
}
417+
418+
TranscriptState completeMessage(LlmModel model) {
419+
final newTranscript = _copyTranscript();
420+
newTranscript[model]!.last =
421+
newTranscript[model]!.last.copyWith(isComplete: true);
422+
return copyWith(transcript: newTranscript);
423+
}
421424
}
422425

423426
@Freezed()
@@ -435,7 +438,6 @@ class TranscriptEvent with _$TranscriptEvent {
435438
UpdateTemperature;
436439
const factory TranscriptEvent.updateTopK(int value) = UpdateTopK;
437440
const factory TranscriptEvent.updateMaxTokens(int value) = UpdateMaxTokens;
438-
const factory TranscriptEvent.setTranscript(LlmModel model) = SetTranscript;
439441
const factory TranscriptEvent.addMessage(
440442
ChatMessage message, LlmModel model) = AddMessage;
441443
const factory TranscriptEvent.extendMessage({

0 commit comments

Comments
 (0)