37 changes: 37 additions & 0 deletions scripts/generate/README.md
@@ -0,0 +1,37 @@
# Model Generation Check

The `test_generate` script provides a straightforward way to validate models, tokenizers, checkpoints, and device compatibility by running a short generation loop. It functions as a sanity check to ensure everything is set up correctly.

While **torchtitan** focuses on advanced features for distributed pre-training, this script acts as a lightweight integration test to verify runtime setup. For more extensive inference and generation capabilities, consider tools like [pytorch/torchchat](https://github.com/pytorch/torchchat/).

## Purpose and Use Case

This script is ideal for users who need to:

- **Run Sanity Checks**: Confirm that models, tokenizers, and checkpoints load without errors.
- **Test Compatibility**: Run a short generation to gauge model output and memory usage.
- **Evaluate Device Scaling**: Optionally test distributed generation using tensor parallel (TP) to confirm multi-device functionality.

## Usage Instructions

#### Run on a single GPU

```bash
NGPU=1 CONFIG_FILE=./train_configs/llama3_8b.toml CHECKPOINT_DIR=./outputs/checkpoint/ \
PROMPT="What is the meaning of life?" \
./scripts/generate/run_llama_generate.sh --max_new_tokens=32 --temperature=0.8 --seed=3
```

#### Run on 4 GPUs and pipe results to a JSON file

```bash
NGPU=4 CONFIG_FILE=./train_configs/llama3_8b.toml CHECKPOINT_DIR=./outputs/checkpoint/ \
PROMPT="What is the meaning of life?" \
./scripts/generate/run_llama_generate.sh --max_new_tokens=32 --temperature=0.8 --seed=3 --out > output.json
```

#### View Available Arguments

```bash
python ./scripts/generate/test_generate.py --help
```
80 changes: 80 additions & 0 deletions scripts/generate/_generation.py
@@ -0,0 +1,80 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

import torch


def multinomial_sample_one(
    probs: torch.Tensor, rng: Optional[torch.Generator] = None
) -> torch.Tensor:
    """Sample one token id per row via the exponential-race (Gumbel-max style)
    trick: argmax(p_i / E_i) with E_i ~ Exp(1) is distributed as Categorical(p).
    """
    q = torch.empty_like(probs).exponential_(1, generator=rng)
    return torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.long)


def logits_to_probs(
    logits: torch.Tensor,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
) -> torch.Tensor:
    # Temperature scaling; the floor guards against division by zero.
    logits = logits / max(temperature, 1e-5)

    if top_k is not None:
        # Keep only the k largest logits; mask the rest to -inf before softmax.
        v, _ = torch.topk(logits, k=min(top_k, logits.size(-1)))
        pivot = v.select(dim=-1, index=-1).unsqueeze(-1)
        logits = torch.where(logits < pivot, -float("Inf"), logits)

    probs = torch.nn.functional.softmax(logits, dim=-1)
    return probs


def generate_next_token(
    model,
    x: torch.Tensor,
    *,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    rng: Optional[torch.Generator] = None,
) -> torch.Tensor:
    logits = model(x)  # (B, T, vocab_size)
    # Sample from the distribution at the last position only.
    probs = logits_to_probs(logits[:, -1, :], temperature, top_k)
    next_token = multinomial_sample_one(probs, rng=rng)
    return next_token


@torch.no_grad()
def generate(
    model,
    input_ids: torch.Tensor,
    *,
    max_new_tokens: int,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    seed: Optional[int] = None,
) -> torch.Tensor:
    # ensure batch dimension (T,) --> (B, T)
    if input_ids.ndim == 1:
        input_ids = input_ids.unsqueeze(0)

    rng = None
    if seed is not None:
        rng = torch.Generator(input_ids.device).manual_seed(seed)

    generated_tokens = input_ids.clone()

    # No KV cache: each step re-runs the model over the full sequence so far,
    # which is acceptable for a short sanity check.
    for _ in range(max_new_tokens):
        next_token = generate_next_token(
            model,
            x=generated_tokens,
            temperature=temperature,
            top_k=top_k,
            rng=rng,
        )

        generated_tokens = torch.cat([generated_tokens, next_token], dim=1)

    return generated_tokens
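
Not part of the diff, but as a quick smoke test of the helper above: `generate` only needs a module that maps `(B, T)` token ids to `(B, T, vocab_size)` logits, so it can be exercised with a tiny random-weight stand-in (`ToyLM` below is illustrative, not a torchtitan model; run from the repo root so the import resolves):

```python
import torch

from scripts.generate._generation import generate


class ToyLM(torch.nn.Module):
    """Minimal stand-in LM: (B, T) token ids -> (B, T, vocab_size) logits."""

    def __init__(self, vocab_size: int = 128, dim: int = 32) -> None:
        super().__init__()
        self.embed = torch.nn.Embedding(vocab_size, dim)
        self.head = torch.nn.Linear(dim, vocab_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.head(self.embed(x))


model = ToyLM().eval()
prompt = torch.randint(0, 128, (6,))  # (T,); generate() adds the batch dim
out = generate(model, prompt, max_new_tokens=8, temperature=0.8, top_k=10, seed=3)
print(out.shape)  # torch.Size([1, 14]): 6 prompt tokens + 8 generated
```

With a fixed `seed`, repeated runs produce identical samples, which makes the check reproducible.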
44 changes: 44 additions & 0 deletions scripts/generate/run_llama_generate.sh
@@ -0,0 +1,44 @@
#!/usr/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

set -e

# use envs as local overrides for convenience
# e.g.
# LOG_RANK=0,1 NGPU=4 ./scripts/generate/run_llama_generate.sh
NGPU=${NGPU:-"1"}
LOG_RANK=${LOG_RANK:-0}
CONFIG_FILE=${CONFIG_FILE:-"./train_configs/debug_model.toml"}
CHECKPOINT_DIR=${CHECKPOINT_DIR:-"./outputs/checkpoint/"}
PROMPT=${PROMPT:-""}

overrides=()
for arg in "$@"; do
    # special case: --prompt may carry a quoted string or a path to a prompt file
    if [[ "$arg" == --prompt=* ]]; then
        PROMPT="${arg#--prompt=}"
        # if the value is a readable file, substitute its contents
        if [[ -f "$PROMPT" ]]; then
            PROMPT=$(<"$PROMPT")
        fi
    else
        # pass all other args through to test_generate.py
        overrides+=("$arg")
    fi
done

set -x
torchrun --standalone \
    --nproc_per_node="${NGPU}" \
    --local-ranks-filter="${LOG_RANK}" \
    scripts/generate/test_generate.py \
    --config="${CONFIG_FILE}" \
    --checkpoint="${CHECKPOINT_DIR}" \
    --prompt="${PROMPT}" \
    "${overrides[@]}"