keras_nlp/utils/text_generation.py: 198 changes (143 additions, 55 deletions)
@@ -18,6 +18,8 @@
from absl import logging
from tensorflow import keras

+# TODO (@chenmoneygithub): Refactor code to reuse snippets.
+

def validate_prompt(prompt):
"""Helper function to validate input to text_generation utils."""
@@ -52,7 +54,11 @@ def mask_tokens_after_end_token(prompt, max_length, end_token_id, pad_token_id):
    # Find index of first end_token_id.
    end_indices = tf.math.argmax(prompt == end_token_id, -1)
    # Use max_length if no `end_token_id` is found.
-    end_indices = tf.where(end_indices == 0, max_length, end_indices)
+    end_indices = tf.where(
+        end_indices == 0,
+        tf.cast(max_length, dtype=end_indices.dtype),
+        end_indices,
+    )
    # Build a mask including end_token and replace tokens after end_token
    # with `pad_token_id`.
    valid_indices = tf.sequence_mask(end_indices + 1, maxlen=max_length)
Expand Down Expand Up @@ -128,29 +134,48 @@ def token_probability_fn(inputs):
```

"""
-    if not tf.executing_eagerly():
-        raise RuntimeError(
-            "`keras_nlp.utils.greedy_search` currently requires an eager "
-            "execution context. Please call `greedy_search` outside "
-            "tf.function or run `tf.config.run_functions_eagerly(True)` to run "
-            "tf.function in eager mode."
-        )
-
    prompt = validate_prompt(prompt)

    input_is_1d = prompt.shape.rank == 1
    if input_is_1d:
        prompt = prompt[tf.newaxis, :]
    validate_token_probability_fn(token_probability_fn, prompt)

-    i = prompt.shape[1]
-    while i < max_length:
-        # If the prompt has reached our desired length, exit while loop.
-        pred = token_probability_fn(prompt)
+    shape = tf.shape(prompt)
+    batch_size = shape[0]
+    length = shape[1]
+
+    # Pad the prompt with `pad_token_id` to `max_length`.
+    padding = tf.fill((batch_size, max_length - length), pad_token_id)
+    prompt = tf.concat((prompt, padding), axis=1)
+
+    def one_step(length, prompt):
+        pred = token_probability_fn(prompt[:, :length])
        next_token = tf.cast(tf.argmax(pred, axis=-1), dtype=prompt.dtype)

-        # Append the next token to current sequence.
-        prompt = tf.concat([prompt, next_token[:, tf.newaxis]], axis=-1)
-        i += 1
+        def add_token(args):
+            sequence, token = args
+            return tf.tensor_scatter_nd_update(
+                tensor=sequence, indices=[[length]], updates=[token]
+            )
+
+        prompt = tf.map_fn(
+            fn=add_token,
+            elems=(prompt, next_token),
+            fn_output_signature=tf.TensorSpec(
+                shape=(max_length), dtype=prompt.dtype
+            ),
+        )
+        length += 1
+        return (length, prompt)
+
+    # Run a while loop till text of length `max_length` has been generated.
+    length, prompt = tf.while_loop(
+        cond=lambda length, _: tf.less(length, max_length),
+        body=one_step,
+        loop_vars=(length, prompt),
+    )
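
The new pattern above pre-pads the prompt to `max_length` and then overwrites one pad position per step inside `tf.while_loop`, which keeps tensor shapes fixed so the loop can be traced, unlike the old eager Python `while` loop. A minimal standalone sketch of the same pattern, assuming a toy 10-token vocabulary and a hypothetical `token_probability_fn` that always predicts token 7 (not the library's API):

```python
import tensorflow as tf

max_length = 8
pad_token_id = 0

# Hypothetical stand-in for a model: puts all probability mass on token 7.
def token_probability_fn(inputs):
    batch_size = tf.shape(inputs)[0]
    return tf.one_hot(tf.fill((batch_size,), 7), depth=10)

prompt = tf.constant([[2, 3], [4, 5]])
batch_size = tf.shape(prompt)[0]
length = tf.shape(prompt)[1]

# Pre-pad to the full output length so shapes never change in the loop.
padding = tf.fill((batch_size, max_length - length), pad_token_id)
prompt = tf.concat((prompt, padding), axis=1)

def one_step(length, prompt):
    pred = token_probability_fn(prompt[:, :length])
    next_token = tf.cast(tf.argmax(pred, axis=-1), dtype=prompt.dtype)

    def add_token(args):
        sequence, token = args
        # Overwrite the pad token at position `length` with the new token.
        return tf.tensor_scatter_nd_update(
            tensor=sequence, indices=[[length]], updates=[token]
        )

    prompt = tf.map_fn(
        fn=add_token,
        elems=(prompt, next_token),
        fn_output_signature=tf.TensorSpec(shape=(max_length,), dtype=prompt.dtype),
    )
    return length + 1, prompt

length, prompt = tf.while_loop(
    cond=lambda length, _: tf.less(length, max_length),
    body=one_step,
    loop_vars=(length, prompt),
)
print(prompt.numpy())  # each row: the original prompt followed by 7s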

if end_token_id is not None:
prompt = mask_tokens_after_end_token(
[Review thread on mask_tokens_after_end_token]

Member: do we even need this function anymore, if we are just starting with the correct sized tensor filled with pad_token_id?

Member: hmm I guess we do to avoid random tokens after the end_token_id

Collaborator (author): Yep
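
To make the thread's point concrete: the loop now always runs to `max_length`, so positions after the first `end_token_id` contain whatever tokens were generated past it, and `mask_tokens_after_end_token` replaces them with `pad_token_id`. A small sketch of that masking logic on a toy batch (token ids made up; row 2 has no end token):

```python
import tensorflow as tf

max_length = 6
end_token_id = 2
pad_token_id = 0

# A "generated" batch where decoding ran past the end token.
prompt = tf.constant([[5, 2, 9, 4, 7, 1],
                      [5, 6, 7, 8, 9, 3]])

# First occurrence of end_token_id; argmax yields 0 when it is absent.
end_indices = tf.math.argmax(prompt == end_token_id, -1)
end_indices = tf.where(
    end_indices == 0,
    tf.cast(max_length, dtype=end_indices.dtype),
    end_indices,
)
# Keep everything up to and including the end token, pad the rest.
valid_indices = tf.sequence_mask(end_indices + 1, maxlen=max_length)
print(tf.where(valid_indices, prompt, pad_token_id).numpy())
# [[5 2 0 0 0 0]
#  [5 6 7 8 9 3]]
```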

@@ -382,33 +407,55 @@ def token_probability_fn(inputs):
```

"""
-    if not tf.executing_eagerly():
-        raise RuntimeError(
-            "`keras_nlp.utils.random_sampling` currently requires an eager "
-            "execution context. Please call `random_sampling` outside "
-            "tf.function or run `tf.config.run_functions_eagerly(True)` to run "
-            "tf.function in eager mode."
-        )
-
    prompt = validate_prompt(prompt)
    input_is_1d = prompt.shape.rank == 1
    if input_is_1d:
        prompt = prompt[tf.newaxis, :]
    validate_token_probability_fn(token_probability_fn, prompt)

-    i = prompt.shape[1]
-    while i < max_length:
-        # If the prompt has reached our desired length, exit while loop.
+    shape = tf.shape(prompt)
+    batch_size = shape[0]
+    length = shape[1]
+
+    # Pad the prompt with `pad_token_id` to `max_length`.
+    padding = tf.fill((batch_size, max_length - length), pad_token_id)
+    prompt = tf.concat((prompt, padding), axis=1)
+
+    def one_step(length, prompt):
        pred = token_probability_fn(prompt)
        if from_logits:
            pred = keras.activations.softmax(pred, axis=-1)
-        next_token = tf.cast(
-            tf.random.categorical(tf.math.log(pred), 1, seed=seed),
-            dtype=prompt.dtype,
+        next_token = tf.squeeze(
+            tf.cast(
+                tf.random.categorical(tf.math.log(pred), 1, seed=seed),
+                dtype=prompt.dtype,
+            ),
+            axis=1,
        )

-        # Append the next token to current sequence.
-        prompt = tf.concat([prompt, next_token], axis=-1)
-        i += 1
+        def add_token(args):
+            sequence, token = args
+            return tf.tensor_scatter_nd_update(
+                tensor=sequence, indices=[[length]], updates=[token]
+            )
+
+        prompt = tf.map_fn(
+            fn=add_token,
+            elems=(prompt, next_token),
+            fn_output_signature=tf.TensorSpec(
+                shape=(max_length), dtype=prompt.dtype
+            ),
+        )
+        length += 1
+        return (length, prompt)
+
+    # Run a while loop till text of length `max_length` has been generated.
+    length, prompt = tf.while_loop(
+        cond=lambda length, _: tf.less(length, max_length),
+        body=one_step,
+        loop_vars=(length, prompt),
+    )
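
A note on the sampling call above: `tf.random.categorical` expects unnormalized log-probabilities, which is why the softmax output goes through `tf.math.log` first; the added `tf.squeeze` drops the trailing sample dimension so `next_token` is a flat `[batch_size]` vector, matching what `add_token` expects. A toy illustration (distribution and seed made up):

```python
import tensorflow as tf

# Probabilities for a 3-token vocabulary, batch of one.
probs = tf.constant([[0.1, 0.2, 0.7]])

# categorical() treats its input as logits, so take the log of the probs.
next_token = tf.squeeze(
    tf.random.categorical(tf.math.log(probs), 1, seed=42), axis=1
)
print(next_token.numpy())  # e.g. [2]; token i is drawn with probability probs[0][i]
```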

if end_token_id is not None:
prompt = mask_tokens_after_end_token(
Expand Down Expand Up @@ -497,14 +544,6 @@ def token_probability_fn(inputs):
```

"""
-    if not tf.executing_eagerly():
-        raise RuntimeError(
-            "`keras_nlp.utils.top_k_search` currently requires an eager "
-            "execution context. Please call `top_k_search` outside "
-            "tf.function or run `tf.config.run_functions_eagerly(True)` to run "
-            "tf.function in eager mode."
-        )
-
prompt = validate_prompt(prompt)
input_is_1d = prompt.shape.rank == 1
if input_is_1d:
@@ -522,8 +561,15 @@
)
k = pred.shape[1]

-    i = prompt.shape[1]
-    while i < max_length:
+    shape = tf.shape(prompt)
+    batch_size = shape[0]
+    length = shape[1]
+
+    # Pad the prompt with `pad_token_id` to `max_length`.
+    padding = tf.fill((batch_size, max_length - length), pad_token_id)
+    prompt = tf.concat((prompt, padding), axis=1)
+
+    def one_step(length, prompt):
        pred = token_probability_fn(prompt)
        if from_logits:
            pred = keras.activations.softmax(pred, axis=-1)
@@ -534,12 +580,34 @@
        next_token = tf.random.categorical(
            tf.math.log(top_k_pred), 1, seed=seed
        )

        # Rearrange to get the next token idx from the original order.
        next_token = tf.gather_nd(top_k_indices, next_token, batch_dims=1)
        next_token = tf.cast(next_token, dtype=prompt.dtype)

-        # Append the next token to current sequence.
-        prompt = tf.concat([prompt, next_token[:, tf.newaxis]], axis=-1)
-        i += 1
+        def add_token(args):
+            sequence, token = args
+            return tf.tensor_scatter_nd_update(
+                tensor=sequence, indices=[[length]], updates=[token]
+            )
+
+        prompt = tf.map_fn(
+            fn=add_token,
+            elems=(prompt, next_token),
+            fn_output_signature=tf.TensorSpec(
+                shape=(max_length), dtype=prompt.dtype
+            ),
+        )
+        length += 1
+        return (length, prompt)
+
+    # Run a while loop till text of length `max_length` has been generated.
+    length, prompt = tf.while_loop(
+        cond=lambda length, _: tf.less(length, max_length),
+        body=one_step,
+        loop_vars=(length, prompt),
+    )
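
For context on the `tf.gather_nd` step kept above: sampling happens inside the filtered top-k distribution, so the drawn index points into the top-k list and must be mapped back to a vocabulary id, with `batch_dims=1` re-indexing each row independently. A toy sketch of just that trick (distributions, `k`, and seed made up):

```python
import tensorflow as tf

# Toy distributions over a 5-token vocabulary, batch of two.
pred = tf.constant([[0.10, 0.40, 0.30, 0.10, 0.10],
                    [0.05, 0.05, 0.10, 0.20, 0.60]])

# Keep the top-2 probabilities and remember their vocabulary ids.
top_k_pred, top_k_indices = tf.math.top_k(pred, k=2)

# Sample within the filtered distribution; indexes into the top-k list.
sampled = tf.random.categorical(tf.math.log(top_k_pred), 1, seed=7)

# Map each sampled position back to the original vocabulary id.
next_token = tf.gather_nd(top_k_indices, sampled, batch_dims=1)
print(next_token.numpy())  # e.g. [1 4]
```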

if end_token_id is not None:
prompt = mask_tokens_after_end_token(
@@ -630,13 +698,6 @@ def token_probability_fn(inputs):
```

"""
-    if not tf.executing_eagerly():
-        raise RuntimeError(
-            "`keras_nlp.utils.top_p_search` currently requires an eager "
-            "execution context. Please call `top_p_search` outside "
-            "tf.function or run `tf.config.run_functions_eagerly(True)` to run "
-            "tf.function in eager mode."
-        )
if p <= 0 or p >= 1:
raise ValueError(
f"`p` should be in the range (0, 1). Received: `p={p}`."
@@ -648,9 +709,15 @@
prompt = prompt[tf.newaxis, :]
validate_token_probability_fn(token_probability_fn, prompt)

-    i = prompt.shape[1]
-    while i < max_length:
-        # If the prompt has reached our desired length, exit while loop.
+    shape = tf.shape(prompt)
+    batch_size = shape[0]
+    length = shape[1]
+
+    # Pad the prompt with `pad_token_id` to `max_length`.
+    padding = tf.fill((batch_size, max_length - length), pad_token_id)
+    prompt = tf.concat((prompt, padding), axis=1)
+
+    def one_step(length, prompt):
        pred = token_probability_fn(prompt)
        if from_logits:
            pred = keras.activations.softmax(pred, axis=-1)
@@ -679,9 +746,30 @@
            sorted_indices, sorted_next_token, batch_dims=1
        )
        next_token = tf.cast(next_token, dtype=prompt.dtype)

-        # Append the next token to current sequence.
-        prompt = tf.concat([prompt, next_token[:, tf.newaxis]], axis=-1)
-        i += 1
+        def add_token(args):
+            sequence, token = args
+            return tf.tensor_scatter_nd_update(
+                tensor=sequence, indices=[[length]], updates=[token]
+            )
+
+        prompt = tf.map_fn(
+            fn=add_token,
+            elems=(prompt, next_token),
+            fn_output_signature=tf.TensorSpec(
+                shape=(max_length), dtype=prompt.dtype
+            ),
+        )
+        length += 1
+        return (length, prompt)
+
+    # Run a while loop till text of length `max_length` has been generated.
+    length, prompt = tf.while_loop(
+        cond=lambda length, _: tf.less(length, max_length),
+        body=one_step,
+        loop_vars=(length, prompt),
+    )
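
The collapsed part of this hunk holds the nucleus (top-p) filtering that produces `sorted_indices` and `sorted_next_token`. Assuming it follows the standard recipe (sort descending, take the cumulative sum, keep the smallest prefix that reaches `p`, renormalize, then sample), one common variant looks like this sketch (values, `p`, and seed made up):

```python
import tensorflow as tf

p = 0.9
pred = tf.constant([[0.5, 0.3, 0.1, 0.05, 0.05]])

# Sort the distribution in descending order, remembering vocabulary ids.
sorted_preds, sorted_indices = tf.math.top_k(pred, k=pred.shape[1], sorted=True)

# Keep the smallest prefix whose cumulative probability reaches `p`,
# always retaining the single most likely token.
cumulative = tf.math.cumsum(sorted_preds, axis=-1)
keep = cumulative <= p
keep = tf.concat([tf.ones_like(keep[:, :1]), keep[:, 1:]], axis=-1)

# Zero out the tail and renormalize before sampling.
filtered = tf.where(keep, sorted_preds, tf.zeros_like(sorted_preds))
filtered = filtered / tf.reduce_sum(filtered, axis=-1, keepdims=True)

# log(0) = -inf, so truncated tokens can never be drawn.
sorted_next_token = tf.random.categorical(tf.math.log(filtered), 1, seed=3)
next_token = tf.gather_nd(sorted_indices, sorted_next_token, batch_dims=1)
print(next_token.numpy())  # a token id drawn from {0, 1, 2}
```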

if end_token_id is not None:
prompt = mask_tokens_after_end_token(