18
18
from absl import logging
19
19
from tensorflow import keras
20
20
21
+ # TODO (@chenmoneygithub): Refactor code to reuse snippets.
22
+
21
23
22
24
def validate_prompt (prompt ):
23
25
"""Helper function to validate input to text_generation utils."""
@@ -52,7 +54,11 @@ def mask_tokens_after_end_token(prompt, max_length, end_token_id, pad_token_id):
52
54
# Find index of first end_token_id.
53
55
end_indices = tf .math .argmax (prompt == end_token_id , - 1 )
54
56
# Use max_length if no `end_token_id` is found.
55
- end_indices = tf .where (end_indices == 0 , max_length , end_indices )
57
+ end_indices = tf .where (
58
+ end_indices == 0 ,
59
+ tf .cast (max_length , dtype = end_indices .dtype ),
60
+ end_indices ,
61
+ )
56
62
# Build a mask including end_token and replace tokens after end_token
57
63
# with `pad_token_id`.
58
64
valid_indices = tf .sequence_mask (end_indices + 1 , maxlen = max_length )
@@ -128,29 +134,48 @@ def token_probability_fn(inputs):
128
134
```
129
135
130
136
"""
131
- if not tf .executing_eagerly ():
132
- raise RuntimeError (
133
- "`keras_nlp.utils.greedy_search` currently requires an eager "
134
- "execution context. Please call `greedy_search` outside "
135
- "tf.function or run `tf.config.run_functions_eagerly(True)` to run "
136
- "tf.function in eager mode."
137
- )
138
-
139
137
prompt = validate_prompt (prompt )
140
138
141
139
input_is_1d = prompt .shape .rank == 1
142
140
if input_is_1d :
143
141
prompt = prompt [tf .newaxis , :]
144
142
validate_token_probability_fn (token_probability_fn , prompt )
145
143
146
- i = prompt .shape [1 ]
147
- while i < max_length :
148
- # If the prompt has reached our desired length, exit while loop.
149
- pred = token_probability_fn (prompt )
144
+ shape = tf .shape (prompt )
145
+ batch_size = shape [0 ]
146
+ length = shape [1 ]
147
+
148
+ # Pad the prompt with `pad_token_id` to `max_length`.
149
+ padding = tf .fill ((batch_size , max_length - length ), pad_token_id )
150
+ prompt = tf .concat ((prompt , padding ), axis = 1 )
151
+
152
+ def one_step (length , prompt ):
153
+ pred = token_probability_fn (prompt [:, :length ])
150
154
next_token = tf .cast (tf .argmax (pred , axis = - 1 ), dtype = prompt .dtype )
155
+
151
156
# Append the next token to current sequence.
152
- prompt = tf .concat ([prompt , next_token [:, tf .newaxis ]], axis = - 1 )
153
- i += 1
157
+ def add_token (args ):
158
+ sequence , token = args
159
+ return tf .tensor_scatter_nd_update (
160
+ tensor = sequence , indices = [[length ]], updates = [token ]
161
+ )
162
+
163
+ prompt = tf .map_fn (
164
+ fn = add_token ,
165
+ elems = (prompt , next_token ),
166
+ fn_output_signature = tf .TensorSpec (
167
+ shape = (max_length ), dtype = prompt .dtype
168
+ ),
169
+ )
170
+ length += 1
171
+ return (length , prompt )
172
+
173
+ # Run a while loop till text of length `max_length` has been generated.
174
+ length , prompt = tf .while_loop (
175
+ cond = lambda length , _ : tf .less (length , max_length ),
176
+ body = one_step ,
177
+ loop_vars = (length , prompt ),
178
+ )
154
179
155
180
if end_token_id is not None :
156
181
prompt = mask_tokens_after_end_token (
@@ -382,33 +407,55 @@ def token_probability_fn(inputs):
382
407
```
383
408
384
409
"""
385
- if not tf .executing_eagerly ():
386
- raise RuntimeError (
387
- "`keras_nlp.utils.random_sampling` currently requires an eager "
388
- "execution context. Please call `random_sampling` outside "
389
- "tf.function or run `tf.config.run_functions_eagerly(True)` to run "
390
- "tf.function in eager mode."
391
- )
392
-
393
410
prompt = validate_prompt (prompt )
394
411
input_is_1d = prompt .shape .rank == 1
395
412
if input_is_1d :
396
413
prompt = prompt [tf .newaxis , :]
397
414
validate_token_probability_fn (token_probability_fn , prompt )
398
415
399
- i = prompt .shape [1 ]
400
- while i < max_length :
401
- # If the prompt has reached our desired length, exit while loop.
416
+ shape = tf .shape (prompt )
417
+ batch_size = shape [0 ]
418
+ length = shape [1 ]
419
+
420
+ # Pad the prompt with `pad_token_id` to `max_length`.
421
+ padding = tf .fill ((batch_size , max_length - length ), pad_token_id )
422
+ prompt = tf .concat ((prompt , padding ), axis = 1 )
423
+
424
+ def one_step (length , prompt ):
402
425
pred = token_probability_fn (prompt )
403
426
if from_logits :
404
427
pred = keras .activations .softmax (pred , axis = - 1 )
405
- next_token = tf .cast (
406
- tf .random .categorical (tf .math .log (pred ), 1 , seed = seed ),
407
- dtype = prompt .dtype ,
428
+ next_token = tf .squeeze (
429
+ tf .cast (
430
+ tf .random .categorical (tf .math .log (pred ), 1 , seed = seed ),
431
+ dtype = prompt .dtype ,
432
+ ),
433
+ axis = 1 ,
408
434
)
435
+
409
436
# Append the next token to current sequence.
410
- prompt = tf .concat ([prompt , next_token ], axis = - 1 )
411
- i += 1
437
+ def add_token (args ):
438
+ sequence , token = args
439
+ return tf .tensor_scatter_nd_update (
440
+ tensor = sequence , indices = [[length ]], updates = [token ]
441
+ )
442
+
443
+ prompt = tf .map_fn (
444
+ fn = add_token ,
445
+ elems = (prompt , next_token ),
446
+ fn_output_signature = tf .TensorSpec (
447
+ shape = (max_length ), dtype = prompt .dtype
448
+ ),
449
+ )
450
+ length += 1
451
+ return (length , prompt )
452
+
453
+ # Run a while loop till text of length `max_length` has been generated.
454
+ length , prompt = tf .while_loop (
455
+ cond = lambda length , _ : tf .less (length , max_length ),
456
+ body = one_step ,
457
+ loop_vars = (length , prompt ),
458
+ )
412
459
413
460
if end_token_id is not None :
414
461
prompt = mask_tokens_after_end_token (
@@ -497,14 +544,6 @@ def token_probability_fn(inputs):
497
544
```
498
545
499
546
"""
500
- if not tf .executing_eagerly ():
501
- raise RuntimeError (
502
- "`keras_nlp.utils.top_k_search` currently requires an eager "
503
- "execution context. Please call `top_k_search` outside "
504
- "tf.function or run `tf.config.run_functions_eagerly(True)` to run "
505
- "tf.function in eager mode."
506
- )
507
-
508
547
prompt = validate_prompt (prompt )
509
548
input_is_1d = prompt .shape .rank == 1
510
549
if input_is_1d :
@@ -522,8 +561,15 @@ def token_probability_fn(inputs):
522
561
)
523
562
k = pred .shape [1 ]
524
563
525
- i = prompt .shape [1 ]
526
- while i < max_length :
564
+ shape = tf .shape (prompt )
565
+ batch_size = shape [0 ]
566
+ length = shape [1 ]
567
+
568
+ # Pad the prompt with `pad_token_id` to `max_length`.
569
+ padding = tf .fill ((batch_size , max_length - length ), pad_token_id )
570
+ prompt = tf .concat ((prompt , padding ), axis = 1 )
571
+
572
+ def one_step (length , prompt ):
527
573
pred = token_probability_fn (prompt )
528
574
if from_logits :
529
575
pred = keras .activations .softmax (pred , axis = - 1 )
@@ -534,12 +580,34 @@ def token_probability_fn(inputs):
534
580
next_token = tf .random .categorical (
535
581
tf .math .log (top_k_pred ), 1 , seed = seed
536
582
)
583
+
537
584
# Rearrange to get the next token idx from the original order.
538
585
next_token = tf .gather_nd (top_k_indices , next_token , batch_dims = 1 )
539
586
next_token = tf .cast (next_token , dtype = prompt .dtype )
587
+
540
588
# Append the next token to current sequence.
541
- prompt = tf .concat ([prompt , next_token [:, tf .newaxis ]], axis = - 1 )
542
- i += 1
589
+ def add_token (args ):
590
+ sequence , token = args
591
+ return tf .tensor_scatter_nd_update (
592
+ tensor = sequence , indices = [[length ]], updates = [token ]
593
+ )
594
+
595
+ prompt = tf .map_fn (
596
+ fn = add_token ,
597
+ elems = (prompt , next_token ),
598
+ fn_output_signature = tf .TensorSpec (
599
+ shape = (max_length ), dtype = prompt .dtype
600
+ ),
601
+ )
602
+ length += 1
603
+ return (length , prompt )
604
+
605
+ # Run a while loop till text of length `max_length` has been generated.
606
+ length , prompt = tf .while_loop (
607
+ cond = lambda length , _ : tf .less (length , max_length ),
608
+ body = one_step ,
609
+ loop_vars = (length , prompt ),
610
+ )
543
611
544
612
if end_token_id is not None :
545
613
prompt = mask_tokens_after_end_token (
@@ -630,13 +698,6 @@ def token_probability_fn(inputs):
630
698
```
631
699
632
700
"""
633
- if not tf .executing_eagerly ():
634
- raise RuntimeError (
635
- "`keras_nlp.utils.top_p_search` currently requires an eager "
636
- "execution context. Please call `top_p_search` outside "
637
- "tf.function or run `tf.config.run_functions_eagerly(True)` to run "
638
- "tf.function in eager mode."
639
- )
640
701
if p <= 0 or p >= 1 :
641
702
raise ValueError (
642
703
f"`p` should be in the range (0, 1). Received: `p={ p } `."
@@ -648,9 +709,15 @@ def token_probability_fn(inputs):
648
709
prompt = prompt [tf .newaxis , :]
649
710
validate_token_probability_fn (token_probability_fn , prompt )
650
711
651
- i = prompt .shape [1 ]
652
- while i < max_length :
653
- # If the prompt has reached our desired length, exit while loop.
712
+ shape = tf .shape (prompt )
713
+ batch_size = shape [0 ]
714
+ length = shape [1 ]
715
+
716
+ # Pad the prompt with `pad_token_id` to `max_length`.
717
+ padding = tf .fill ((batch_size , max_length - length ), pad_token_id )
718
+ prompt = tf .concat ((prompt , padding ), axis = 1 )
719
+
720
+ def one_step (length , prompt ):
654
721
pred = token_probability_fn (prompt )
655
722
if from_logits :
656
723
pred = keras .activations .softmax (pred , axis = - 1 )
@@ -679,9 +746,30 @@ def token_probability_fn(inputs):
679
746
sorted_indices , sorted_next_token , batch_dims = 1
680
747
)
681
748
next_token = tf .cast (next_token , dtype = prompt .dtype )
749
+
682
750
# Append the next token to current sequence.
683
- prompt = tf .concat ([prompt , next_token [:, tf .newaxis ]], axis = - 1 )
684
- i += 1
751
+ def add_token (args ):
752
+ sequence , token = args
753
+ return tf .tensor_scatter_nd_update (
754
+ tensor = sequence , indices = [[length ]], updates = [token ]
755
+ )
756
+
757
+ prompt = tf .map_fn (
758
+ fn = add_token ,
759
+ elems = (prompt , next_token ),
760
+ fn_output_signature = tf .TensorSpec (
761
+ shape = (max_length ), dtype = prompt .dtype
762
+ ),
763
+ )
764
+ length += 1
765
+ return (length , prompt )
766
+
767
+ # Run a while loop till text of length `max_length` has been generated.
768
+ length , prompt = tf .while_loop (
769
+ cond = lambda length , _ : tf .less (length , max_length ),
770
+ body = one_step ,
771
+ loop_vars = (length , prompt ),
772
+ )
685
773
686
774
if end_token_id is not None :
687
775
prompt = mask_tokens_after_end_token (
0 commit comments