Commit 34c0e27

Make Decoding Functions Graph-compatible (with XLA Support!) (#271)
* Make text gen functions graph compat
* Minor edit
* Address review comments - I
* Changes for XLA support
* Address review comments - II
* Some polishing up
* Format code
* Address review comments - III
* Minor edit
* Add TODO comment
* Fix format
1 parent bf2110e commit 34c0e27
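
What the change enables, in practice: the decoding utilities no longer raise when called inside `tf.function`. A minimal sketch, assuming a toy `token_probability_fn`, the import path implied by the error messages removed below, and the standard `jit_compile=True` flag for XLA (none of which appear in this commit):

```python
import tensorflow as tf
from keras_nlp.utils import greedy_search

# Stand-in for a real model's next-token distribution: uniform
# probabilities over a toy vocabulary of 10 tokens.
def token_probability_fn(inputs):
    batch_size = tf.shape(inputs)[0]
    return tf.ones((batch_size, 10)) / 10.0

prompt = tf.constant([[1, 2, 3], [4, 5, 6]])

# Before this commit this raised a RuntimeError; now the decoding loop
# traces into a graph and can additionally be compiled with XLA.
@tf.function(jit_compile=True)
def generate(prompt):
    return greedy_search(
        token_probability_fn, prompt, max_length=8, pad_token_id=0
    )

print(generate(prompt))
```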

File tree

2 files changed: +421 -72 lines changed

keras_nlp/utils/text_generation.py

Lines changed: 143 additions & 55 deletions
@@ -18,6 +18,8 @@
 from absl import logging
 from tensorflow import keras
 
+# TODO (@chenmoneygithub): Refactor code to reuse snippets.
+
 
 def validate_prompt(prompt):
     """Helper function to validate input to text_generation utils."""
@@ -52,7 +54,11 @@ def mask_tokens_after_end_token(prompt, max_length, end_token_id, pad_token_id):
     # Find index of first end_token_id.
     end_indices = tf.math.argmax(prompt == end_token_id, -1)
     # Use max_length if no `end_token_id` is found.
-    end_indices = tf.where(end_indices == 0, max_length, end_indices)
+    end_indices = tf.where(
+        end_indices == 0,
+        tf.cast(max_length, dtype=end_indices.dtype),
+        end_indices,
+    )
     # Build a mask including end_token and replace tokens after end_token
     # with `pad_token_id`.
     valid_indices = tf.sequence_mask(end_indices + 1, maxlen=max_length)
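
The reshaped `tf.where` call is a dtype fix, not just formatting: `tf.math.argmax` returns `int64`, while a plain Python `max_length` is converted to an `int32` tensor by default, and `tf.where` requires both branches to agree. A minimal reproduction with assumed values:

```python
import tensorflow as tf

end_indices = tf.constant([0, 3], dtype=tf.int64)  # tf.math.argmax yields int64
max_length = 8  # Python int; becomes an int32 tensor when converted

# tf.where(end_indices == 0, max_length, end_indices) can fail with a
# dtype mismatch; casting max_length to int64 makes the branches agree.
end_indices = tf.where(
    end_indices == 0,
    tf.cast(max_length, dtype=end_indices.dtype),
    end_indices,
)
print(end_indices)  # tf.Tensor([8 3], shape=(2,), dtype=int64)
```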
@@ -128,29 +134,48 @@ def token_probability_fn(inputs):
     ```
 
     """
-    if not tf.executing_eagerly():
-        raise RuntimeError(
-            "`keras_nlp.utils.greedy_search` currently requires an eager "
-            "execution context. Please call `greedy_search` outside "
-            "tf.function or run `tf.config.run_functions_eagerly(True)` to run "
-            "tf.function in eager mode."
-        )
-
     prompt = validate_prompt(prompt)
 
     input_is_1d = prompt.shape.rank == 1
     if input_is_1d:
         prompt = prompt[tf.newaxis, :]
     validate_token_probability_fn(token_probability_fn, prompt)
 
-    i = prompt.shape[1]
-    while i < max_length:
-        # If the prompt has reached our desired length, exit while loop.
-        pred = token_probability_fn(prompt)
+    shape = tf.shape(prompt)
+    batch_size = shape[0]
+    length = shape[1]
+
+    # Pad the prompt with `pad_token_id` to `max_length`.
+    padding = tf.fill((batch_size, max_length - length), pad_token_id)
+    prompt = tf.concat((prompt, padding), axis=1)
+
+    def one_step(length, prompt):
+        pred = token_probability_fn(prompt[:, :length])
         next_token = tf.cast(tf.argmax(pred, axis=-1), dtype=prompt.dtype)
+
         # Append the next token to current sequence.
-        prompt = tf.concat([prompt, next_token[:, tf.newaxis]], axis=-1)
-        i += 1
+        def add_token(args):
+            sequence, token = args
+            return tf.tensor_scatter_nd_update(
+                tensor=sequence, indices=[[length]], updates=[token]
+            )
+
+        prompt = tf.map_fn(
+            fn=add_token,
+            elems=(prompt, next_token),
+            fn_output_signature=tf.TensorSpec(
+                shape=(max_length), dtype=prompt.dtype
+            ),
+        )
+        length += 1
+        return (length, prompt)
+
+    # Run a while loop till text of length `max_length` has been generated.
+    length, prompt = tf.while_loop(
+        cond=lambda length, _: tf.less(length, max_length),
+        body=one_step,
+        loop_vars=(length, prompt),
+    )
 
     if end_token_id is not None:
         prompt = mask_tokens_after_end_token(
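
The core idea of this rewrite: the old Python `while` loop grew `prompt` with `tf.concat`, so the tensor's shape changed every iteration and the loop could not become a single traced graph. The new code pre-pads to `max_length` and overwrites one position per step, so every `tf.while_loop` variable keeps a fixed shape, which graph tracing and XLA require. The `add_token` step in isolation, with toy values (all values assumed):

```python
import tensorflow as tf

prompt = tf.constant([[7, 8, 0, 0], [9, 3, 0, 0]])  # padded to max_length=4
next_token = tf.constant([5, 6])  # one new token per sequence
length = 2  # first still-padded position

def add_token(args):
    sequence, token = args
    # Scatter keeps the output shape identical to the input sequence.
    return tf.tensor_scatter_nd_update(
        tensor=sequence, indices=[[length]], updates=[token]
    )

prompt = tf.map_fn(
    fn=add_token,
    elems=(prompt, next_token),
    fn_output_signature=tf.TensorSpec(shape=(4,), dtype=prompt.dtype),
)
print(prompt)  # [[7 8 5 0], [9 3 6 0]]
```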
@@ -382,33 +407,55 @@ def token_probability_fn(inputs):
     ```
 
     """
-    if not tf.executing_eagerly():
-        raise RuntimeError(
-            "`keras_nlp.utils.random_sampling` currently requires an eager "
-            "execution context. Please call `random_sampling` outside "
-            "tf.function or run `tf.config.run_functions_eagerly(True)` to run "
-            "tf.function in eager mode."
-        )
-
     prompt = validate_prompt(prompt)
     input_is_1d = prompt.shape.rank == 1
     if input_is_1d:
         prompt = prompt[tf.newaxis, :]
     validate_token_probability_fn(token_probability_fn, prompt)
 
-    i = prompt.shape[1]
-    while i < max_length:
-        # If the prompt has reached our desired length, exit while loop.
+    shape = tf.shape(prompt)
+    batch_size = shape[0]
+    length = shape[1]
+
+    # Pad the prompt with `pad_token_id` to `max_length`.
+    padding = tf.fill((batch_size, max_length - length), pad_token_id)
+    prompt = tf.concat((prompt, padding), axis=1)
+
+    def one_step(length, prompt):
         pred = token_probability_fn(prompt)
         if from_logits:
             pred = keras.activations.softmax(pred, axis=-1)
-        next_token = tf.cast(
-            tf.random.categorical(tf.math.log(pred), 1, seed=seed),
-            dtype=prompt.dtype,
+        next_token = tf.squeeze(
+            tf.cast(
+                tf.random.categorical(tf.math.log(pred), 1, seed=seed),
+                dtype=prompt.dtype,
+            ),
+            axis=1,
         )
+
         # Append the next token to current sequence.
-        prompt = tf.concat([prompt, next_token], axis=-1)
-        i += 1
+        def add_token(args):
+            sequence, token = args
+            return tf.tensor_scatter_nd_update(
+                tensor=sequence, indices=[[length]], updates=[token]
+            )
+
+        prompt = tf.map_fn(
+            fn=add_token,
+            elems=(prompt, next_token),
+            fn_output_signature=tf.TensorSpec(
+                shape=(max_length), dtype=prompt.dtype
+            ),
+        )
+        length += 1
+        return (length, prompt)
+
+    # Run a while loop till text of length `max_length` has been generated.
+    length, prompt = tf.while_loop(
+        cond=lambda length, _: tf.less(length, max_length),
+        body=one_step,
+        loop_vars=(length, prompt),
+    )
 
     if end_token_id is not None:
         prompt = mask_tokens_after_end_token(
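
Besides the shared loop rewrite, this hunk adds a `tf.squeeze(..., axis=1)`: `tf.random.categorical` returns shape `(batch, num_samples)`, and while the old `tf.concat` could append that `(batch, 1)` column directly, the scatter-based update needs one scalar token per sequence. A short illustration with assumed values:

```python
import tensorflow as tf

pred = tf.constant([[0.1, 0.9], [0.8, 0.2]])
sampled = tf.random.categorical(tf.math.log(pred), 1, seed=42)
print(sampled.shape)                      # (2, 1)
print(tf.squeeze(sampled, axis=1).shape)  # (2,) -- what the scatter expects
```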
@@ -497,14 +544,6 @@ def token_probability_fn(inputs):
     ```
 
     """
-    if not tf.executing_eagerly():
-        raise RuntimeError(
-            "`keras_nlp.utils.top_k_search` currently requires an eager "
-            "execution context. Please call `top_k_search` outside "
-            "tf.function or run `tf.config.run_functions_eagerly(True)` to run "
-            "tf.function in eager mode."
-        )
-
     prompt = validate_prompt(prompt)
     input_is_1d = prompt.shape.rank == 1
     if input_is_1d:
@@ -522,8 +561,15 @@ def token_probability_fn(inputs):
         )
         k = pred.shape[1]
 
-    i = prompt.shape[1]
-    while i < max_length:
+    shape = tf.shape(prompt)
+    batch_size = shape[0]
+    length = shape[1]
+
+    # Pad the prompt with `pad_token_id` to `max_length`.
+    padding = tf.fill((batch_size, max_length - length), pad_token_id)
+    prompt = tf.concat((prompt, padding), axis=1)
+
+    def one_step(length, prompt):
         pred = token_probability_fn(prompt)
         if from_logits:
             pred = keras.activations.softmax(pred, axis=-1)
@@ -534,12 +580,34 @@ def token_probability_fn(inputs):
         next_token = tf.random.categorical(
             tf.math.log(top_k_pred), 1, seed=seed
         )
+
         # Rearrange to get the next token idx from the original order.
         next_token = tf.gather_nd(top_k_indices, next_token, batch_dims=1)
         next_token = tf.cast(next_token, dtype=prompt.dtype)
+
         # Append the next token to current sequence.
-        prompt = tf.concat([prompt, next_token[:, tf.newaxis]], axis=-1)
-        i += 1
+        def add_token(args):
+            sequence, token = args
+            return tf.tensor_scatter_nd_update(
+                tensor=sequence, indices=[[length]], updates=[token]
+            )
+
+        prompt = tf.map_fn(
+            fn=add_token,
+            elems=(prompt, next_token),
+            fn_output_signature=tf.TensorSpec(
+                shape=(max_length), dtype=prompt.dtype
+            ),
+        )
+        length += 1
+        return (length, prompt)
+
+    # Run a while loop till text of length `max_length` has been generated.
+    length, prompt = tf.while_loop(
+        cond=lambda length, _: tf.less(length, max_length),
+        body=one_step,
+        loop_vars=(length, prompt),
+    )
 
     if end_token_id is not None:
         prompt = mask_tokens_after_end_token(
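
The context lines above show why the `tf.gather_nd` is needed: sampling over `top_k_pred` yields a position within the top-k slice, not a vocabulary id, so it must be mapped back through `top_k_indices`. The re-indexing step in isolation, with toy values (all values assumed):

```python
import tensorflow as tf

pred = tf.constant([[0.1, 0.6, 0.3], [0.5, 0.1, 0.4]])
top_k_pred, top_k_indices = tf.math.top_k(pred, k=2)
# Suppose sampling picked slice position 0 for the first sequence and
# slice position 1 for the second:
next_token = tf.constant([[0], [1]], dtype=tf.int64)
next_token = tf.gather_nd(top_k_indices, next_token, batch_dims=1)
print(next_token)  # [1 2] -- vocabulary ids, not slice positions
```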
@@ -630,13 +698,6 @@ def token_probability_fn(inputs):
     ```
 
     """
-    if not tf.executing_eagerly():
-        raise RuntimeError(
-            "`keras_nlp.utils.top_p_search` currently requires an eager "
-            "execution context. Please call `top_p_search` outside "
-            "tf.function or run `tf.config.run_functions_eagerly(True)` to run "
-            "tf.function in eager mode."
-        )
     if p <= 0 or p >= 1:
         raise ValueError(
             f"`p` should be in the range (0, 1). Received: `p={p}`."
@@ -648,9 +709,15 @@ def token_probability_fn(inputs):
         prompt = prompt[tf.newaxis, :]
     validate_token_probability_fn(token_probability_fn, prompt)
 
-    i = prompt.shape[1]
-    while i < max_length:
-        # If the prompt has reached our desired length, exit while loop.
+    shape = tf.shape(prompt)
+    batch_size = shape[0]
+    length = shape[1]
+
+    # Pad the prompt with `pad_token_id` to `max_length`.
+    padding = tf.fill((batch_size, max_length - length), pad_token_id)
+    prompt = tf.concat((prompt, padding), axis=1)
+
+    def one_step(length, prompt):
         pred = token_probability_fn(prompt)
         if from_logits:
             pred = keras.activations.softmax(pred, axis=-1)
@@ -679,9 +746,30 @@ def token_probability_fn(inputs):
             sorted_indices, sorted_next_token, batch_dims=1
         )
         next_token = tf.cast(next_token, dtype=prompt.dtype)
+
         # Append the next token to current sequence.
-        prompt = tf.concat([prompt, next_token[:, tf.newaxis]], axis=-1)
-        i += 1
+        def add_token(args):
+            sequence, token = args
+            return tf.tensor_scatter_nd_update(
+                tensor=sequence, indices=[[length]], updates=[token]
+            )
+
+        prompt = tf.map_fn(
+            fn=add_token,
+            elems=(prompt, next_token),
+            fn_output_signature=tf.TensorSpec(
+                shape=(max_length), dtype=prompt.dtype
+            ),
+        )
+        length += 1
+        return (length, prompt)
+
+    # Run a while loop till text of length `max_length` has been generated.
+    length, prompt = tf.while_loop(
+        cond=lambda length, _: tf.less(length, max_length),
+        body=one_step,
+        loop_vars=(length, prompt),
+    )
 
     if end_token_id is not None:
         prompt = mask_tokens_after_end_token(
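
For context on the `sorted_indices`/`sorted_next_token` names above: top-p sampling sorts the distribution, samples within the smallest prefix whose mass reaches `p`, then maps the sorted position back to a vocabulary id. A minimal sketch assuming the standard nucleus construction (the exact masking inside `top_p_search` sits outside this hunk):

```python
import tensorflow as tf

p = 0.9
pred = tf.constant([[0.5, 0.1, 0.4]])
sorted_pred = tf.sort(pred, direction="DESCENDING")
sorted_indices = tf.argsort(pred, direction="DESCENDING")
# Keep the smallest prefix of tokens whose cumulative mass reaches p,
# zero out the rest, and renormalize.
keep = tf.math.cumsum(sorted_pred, axis=-1, exclusive=True) < p
filtered = tf.where(keep, sorted_pred, 0.0)
filtered = filtered / tf.reduce_sum(filtered, axis=-1, keepdims=True)
sorted_next_token = tf.random.categorical(tf.math.log(filtered), 1)
# Map the sorted position back to the original vocabulary id.
next_token = tf.gather_nd(sorted_indices, sorted_next_token, batch_dims=1)
```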
