JIT compile inference func in LLM pretrain tutorial #197

Open
wants to merge 3 commits into base: main
702 changes: 375 additions & 327 deletions docs/source/JAX_for_LLM_pretraining.ipynb

Large diffs are not rendered by default.

118 changes: 68 additions & 50 deletions docs/source/JAX_for_LLM_pretraining.md
@@ -40,7 +40,7 @@ JAX installation is covered in [this guide](https://jax.readthedocs.io/en/latest
colab:
base_uri: https://localhost:8080/
id: 6zMsOIc7ouCO
outputId: 037d56a9-b18f-4504-f80a-3a4fa2945068
outputId: ad486e3b-dd63-405f-d786-79b0b6d60cbd
---
!pip install -Uq tiktoken grain matplotlib
```
@@ -56,7 +56,7 @@ Check the available JAX devices, or [`jax.Device`](https://jax.readthedocs.io/en
colab:
base_uri: https://localhost:8080/
id: LS9sQEY3n0mB
outputId: 9ffcf3a6-20ef-4f80-b006-f5d3c5644a15
outputId: c18a63d0-696e-4c93-f8d5-8c7942045005
---
import jax
jax.devices()
@@ -71,7 +71,7 @@ Get the [TinyStories dataset from Hugging Face](https://huggingface.co/datasets/
colab:
base_uri: https://localhost:8080/
id: wUjQsgQEmI1N
outputId: e6eff24e-5578-4277-a0f9-24e27bd91ee0
outputId: 8a75571a-8339-4f1a-958b-5805aa285bb9
---
!wget https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStories-train.txt?download=true -O TinyStories-train.txt
```
@@ -318,48 +318,31 @@ class MiniGPT(nnx.Module):
outputs = self.output_layer(x)
return outputs
# Text generation.
def generate_text(self, max_tokens: int, start_tokens: [int], top_k=10):
# Sample the next token from a probability distribution based on
# `logits` and `tok_k` (top-k) sampling strategy.
def sample_from(logits):
logits, indices = jax.lax.top_k(logits, k=top_k)
# Convert logits to probabilities (using `flax.nnx.softmax`).
logits = nnx.softmax(logits)
return jax.random.choice(jax.random.PRNGKey(0), indices, p=logits)
# Generate text one token at a time until the maximum token limit is reached (`maxlen`).
def generate_step(start_tokens):
pad_len = maxlen - len(start_tokens)
# Index of the last token in the current sequence.
sample_index = len(start_tokens) - 1
# If the input is longer than `maxlen`, then truncate it.
if pad_len < 0:
x = jnp.array(start_tokens[:maxlen])
sample_index = maxlen - 1
# If the input is shorter than `maxlen`, then pad it (`pad_len`).
elif pad_len > 0:
x = jnp.array(start_tokens + [0] * pad_len)
else:
x = jnp.array(start_tokens)
# Add a batch dimension.
x = x[None, :]
logits = self(x)
next_token = sample_from(logits[0][sample_index])
return next_token
# Store generated tokens.
@nnx.jit
def sample_from(self, logits):
logits, indices = jax.lax.top_k(logits, k=top_k)
logits = nnx.softmax(logits)
return jax.random.choice(jax.random.PRNGKey(0), indices, p=logits)
@nnx.jit
def generate_step(self, padded_tokens, sample_index):
logits = self(padded_tokens)
next_token = self.sample_from(logits[0][sample_index])
return next_token
def generate_text(self, max_tokens, start_tokens):
generated = []
# Generate tokens until the end-of-text token is encountered or the maximum token limit is reached.
for _ in range(max_tokens):
next_token = generate_step(start_tokens + generated)
# Truncate whatever is after '<|endoftext|>' (stop word)
print(tokenizer.decode(start_tokens), flush=True, end='')
for i in range(max_tokens):
sample_index = len(start_tokens) + len(generated) - 1
padded_tokens = jnp.array((start_tokens + generated + [0] * (maxlen - len(start_tokens) - len(generated))))[None, :]
next_token = int(self.generate_step(padded_tokens, sample_index))
if next_token == tokenizer.encode('<|endoftext|>', allowed_special={'<|endoftext|>'})[0]:
# Stop text generation if the end-of-text token is encountered.
break
generated.append(int(next_token))
# Decode the generated token IDs into text.
generated.append(next_token)
# decode and print next_token
print(tokenizer.decode([next_token]), flush=True, end='')
return tokenizer.decode(start_tokens + generated)
# Creates the miniGPT model with 4 transformer blocks.
@@ -382,6 +365,7 @@ num_heads = 8
feed_forward_dim = 256
batch_size = 256 # You can set a bigger batch size if you use Kaggle's Cloud TPU.
num_epochs = 1
top_k = 10
```

+++ {"id": "mI1ci-HyMspJ"}
@@ -473,7 +457,7 @@ We are also using the `jax.vmap` transformation to produce the target sequences
colab:
base_uri: https://localhost:8080/
id: Ysl6CsfENeJN
outputId: 5dd06dca-f030-4927-a9b6-35d412da535c
outputId: b713cc63-a4ff-4fea-96ea-c234348e0ea4
---
model = create_model(rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adam(1e-3))
@@ -484,11 +468,10 @@ rng = jax.random.PRNGKey(0)
start_prompt = "Once upon a time"
start_tokens = tokenizer.encode(start_prompt)[:maxlen]
print(f"Initial generated text:")
generated_text = model.generate_text(
maxlen, start_tokens
)
print(f"Initial generated text:\n{generated_text}\n")
metrics_history = {
'train_loss': [],
@@ -512,20 +495,21 @@ for epoch in range(num_epochs):
metrics.reset()
elapsed_time = time.time() - start_time
print(f"Step {step + 1}, Loss: {metrics_history['train_loss'][-1]}, Elapsed Time: {elapsed_time:.2f} seconds")
print(f"\n\nStep {step + 1}, Loss: {metrics_history['train_loss'][-1]}, Elapsed Time: {elapsed_time:.2f} seconds")
start_time = time.time()
print(f"Generated text:")
generated_text = model.generate_text(
maxlen, start_tokens
)
print(f"Generated text:\n{generated_text}\n")
step += 1
# Final text generation
print(f"Final generated text:")
generated_text = model.generate_text(
maxlen, start_tokens
)
print(f"Final generated text:\n{generated_text}")
```

+++ {"id": "thaLs6TD0lt5"}
@@ -538,7 +522,7 @@ colab:
base_uri: https://localhost:8080/
height: 472
id: B6Eg1Cz2y_iP
outputId: 7cafe711-1ae4-4eb9-fd37-e1bde54cbfc5
outputId: fb96b456-23a8-448f-ebc6-23d807a626d2
---
import matplotlib.pyplot as plt
plt.plot(metrics_history['train_loss'])
@@ -563,7 +547,7 @@ Save the model checkpoint.
colab:
base_uri: https://localhost:8080/
id: EkoFGCgSZ1yz
outputId: 3467b8ba-ce05-42f0-fb89-75922cc91e31
outputId: 13434f77-2d73-4176-db24-7c0b2e232fe6
---
import orbax.checkpoint as orbax
@@ -576,21 +560,37 @@ checkpointer.save('/content/save', state)
!ls /content/save/
```
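
Most of this checkpointing cell is elided in the diff; for orientation, a minimal Orbax save/restore round trip looks roughly like the sketch below. The use of `nnx.state`, `PyTreeCheckpointer`, and the restore step are illustrative assumptions, not necessarily the notebook's exact code.

```python
import orbax.checkpoint as orbax
from flax import nnx

# Extract the model's state (parameters and other variables) as a pytree.
state = nnx.state(model)

# Save the pytree to the same path listed by `!ls /content/save/` above.
checkpointer = orbax.PyTreeCheckpointer()
checkpointer.save('/content/save', state)

# Restoring is the mirror image: load the pytree, then merge it back into the model.
restored_state = checkpointer.restore('/content/save')
nnx.update(model, restored_state)
```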

+++ {"id": "3813cbf2"}

## Profiling for hyperparameter tuning

```{code-cell}
---
colab:
base_uri: https://localhost:8080/
id: b5d933c6
outputId: b1233c24-0832-49f8-8568-e4ccbd73ca83
---
!pip install -Uq tensorboard-plugin-profile tensorflow tensorboard
```

+++ {"id": "2ac5fc4d"}

Load the TensorBoard Colab extension.

```{code-cell}
:id: 74f0c212
%load_ext tensorboard
```

+++ {"id": "17c6131f"}

Since we're going to run this model a number of times, we need some scaffolding to compare runs more easily. For a baseline, we first perform some warmup to guarantee that our code is JIT-compiled and that our TPUs are warm. For improved comparability, we only start tracing after the warmup has finished.

```{code-cell}
:id: ddfd576e
trace_dir = "/tmp/jax-trace/"
def loop_step(batch, step):
@@ -611,9 +611,13 @@ def generate_trace():
jax.profiler.stop_trace()
```
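
Most of `generate_trace` is elided in this hunk; conceptually it runs a few warmup steps first and only then captures a profile between `jax.profiler.start_trace` and `jax.profiler.stop_trace`. A rough sketch under that assumption (the step counts here are illustrative, not the notebook's values):

```python
def generate_trace(warmup_steps: int = 5, profiled_steps: int = 10):
    """Warm up the JIT-compiled train step, then capture a profiler trace."""
    step = 0
    # Warmup: keep compilation and device initialization out of the trace
    # so different runs stay comparable.
    for batch in text_dl:
        if step >= warmup_steps:
            break
        loop_step(batch, step)
        step += 1

    # Only now start tracing.
    jax.profiler.start_trace(trace_dir)
    for batch in text_dl:
        if step >= warmup_steps + profiled_steps:
            break
        loop_step(batch, step)
        step += 1
    jax.profiler.stop_trace()
```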

+++ {"id": "de70f5b7"}

Now we'll capture some traces to compare the results of different batch sizes. This will take several minutes, as we need to reprocess our input data to prepare new batches each time.

```{code-cell}
:id: bc9452a6
trace_dir = "/tmp/jax-trace-batch-comparison/"
batch_size = 64
@@ -625,16 +629,22 @@ text_dl = iter(load_and_preprocess_data('TinyStories-train.txt', batch_size, max
generate_trace()
```
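
The lines elided from this hunk presumably repeat the same procedure at the larger batch size, so that both runs land under the same `trace_dir`. A hypothetical sketch of that second run (the exact cell contents are not shown in this diff):

```python
# Capture a second trace at the larger batch size for comparison.
batch_size = 256
text_dl = iter(load_and_preprocess_data('TinyStories-train.txt', batch_size, maxlen))
generate_trace()
```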

+++ {"id": "ea379965"}

Run TensorBoard with the Profiler plugin to compare our runs. Runs are listed in order from newest to oldest, so the top run in the list will have `batch_size = 256`.

The key metrics to focus on for this hyperparameter are FLOPS Utilization and Average Step Time.

In general, we want to maximize FLOPS Utilization while minimizing the step time per training example. In this case, increasing the batch size from 64 to 256 achieves both: FLOPS Utilization increases from 16% to 27%, and while the Average Step Time increases from 100 ms to 260 ms, we also increased our batch size by 300%. This means we move from about 1.56 ms per training example to about 1.02 ms per training example.
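
A quick way to check those per-example numbers (average step time divided by batch size):

```python
# Average step time / batch size = time per training example.
print(f"batch_size=64:  {100 / 64:.2f} ms per example")   # ~1.56 ms
print(f"batch_size=256: {260 / 256:.2f} ms per example")  # ~1.02 ms
```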

```{code-cell}
:id: b86c565a
%tensorboard --logdir=$trace_dir
```

+++ {"id": "657967a5"}

Next, we can explore alternative parallelism methods. In cell #4, we used 4-way data parallelism and 2-way tensor parallelism. 8-way data parallelism is another popular configuration. Let's compare results between them. To switch to 8-way data parallelism, we'll replace the `Mesh` definition with:

`mesh = Mesh(mesh_utils.create_device_mesh((8, 1)), ('batch', 'model'))`
@@ -644,6 +654,8 @@ JAX will automatically figure out how to shard the model and data to use the new
How simple and powerful this is! And that's the beauty of JAX's automatic parallelism.

```{code-cell}
:id: 80daa8dc
trace_dir = "/tmp/jax-trace-parallelism-comparison/"
mesh = Mesh(mesh_utils.create_device_mesh((4, 2)), ('batch', 'model'))
@@ -653,14 +665,20 @@ mesh = Mesh(mesh_utils.create_device_mesh((8, 1)), ('batch', 'model'))
generate_trace()
```

+++ {"id": "ad96e72b"}

Once again, we'll run TensorBoard.

Looking at the results, we see that the step times are nearly the same; however, FLOPS Utilization is at 13% for 8-way data parallelism compared to 27% for 4-way data parallelism.

By looking under each TPU's ops in the Trace Viewer tool, we can see that the TPUs spend a large amount of time idle waiting for the host, as well as a good amount of time in `reduce_sum` operations.

```{code-cell}
:id: 780e9c72
%tensorboard --logdir=$trace_dir
```

+++ {"id": "deca486e"}

By changing hyperparameters and comparing profiles, we're able to gain significant insights into our bottlenecks and limitations. These are just two examples of hyperparameters to tune, but many others also have significant effects on training speed and resource utilization.