
Commit 9af29a7

added simple generation
1 parent 0edd2fb commit 9af29a7

File tree

2 files changed: +242 -0 lines changed

test/generate/generation.py

Lines changed: 104 additions & 0 deletions
@@ -0,0 +1,104 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Callable, Optional, Tuple

import torch


def sample(
    logits: torch.Tensor,
    *,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
) -> torch.Tensor:
    """Sample from a probability distribution.

    Args:
        logits (torch.Tensor): logits from which to sample (vocab_size,)
        temperature (float): value to scale logits by, default 1.0.
        top_k (Optional[int]): if specified, prune sampling to only tokens within the top_k probs.

    Returns:
        torch.Tensor: sampled token id
    """

    # scale
    logits = logits / max(temperature, 1e-5)

    # top-k
    if top_k is not None:
        v, _ = torch.topk(logits, k=min(top_k, logits.size(-1)))  # (k,)
        # select last value from top_k above as the pivot
        pivot = v.select(dim=-1, index=-1).unsqueeze(-1)  # (1,)
        # mask values smaller than pivot to -inf since these should be pruned
        logits = torch.where(logits < pivot, -float("Inf"), logits)  # (vocab_size,)

    # normalize
    probs = torch.nn.functional.softmax(logits, dim=-1)

    # note: argmax picks the single most likely token from the scaled, pruned
    # distribution, so this is greedy rather than true multinomial sampling
    return torch.argmax(probs, dim=-1, keepdim=True).to(dtype=torch.int)


def generate_next_token(
    model,
    x: torch.Tensor,
    *,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    logits = model(x)  # (B, T, vocab_size)
    return (
        sample(
            logits[0, -1, :].clone(), temperature=temperature, top_k=top_k
        ).unsqueeze(-1),
        logits,
    )


@torch.inference_mode()
def generate(
    model,
    prompt: torch.Tensor,
    *,
    max_generated_tokens: int,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    custom_generate_next_token: Optional[Callable] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Autoregressively generate max_generated_tokens tokens after the prompt.

    Returns the full token sequence (prompt plus generated tokens) and the
    logits from the final forward pass.
    """

    prompt = prompt.view(1, -1) if prompt.ndim == 1 else prompt

    if custom_generate_next_token is None:
        _generate_next_token = generate_next_token
    else:
        _generate_next_token = custom_generate_next_token

    generated_tokens = prompt.clone()

    # the first token is always produced with the default generate_next_token
    tokens, generated_logits = generate_next_token(
        model,
        x=prompt,
        temperature=temperature,
        top_k=top_k,
    )

    generated_tokens = torch.cat([generated_tokens, tokens], dim=-1)

    # remaining tokens re-run the model on the full sequence generated so far
    for _ in range(max_generated_tokens - 1):
        tokens = generated_tokens.clone()
        tokens, logits = _generate_next_token(
            model,
            x=tokens.clone(),
            temperature=temperature,
            top_k=top_k,
        )

        generated_tokens = torch.cat([generated_tokens, tokens], dim=-1)
        generated_logits = logits

    return generated_tokens, generated_logits
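
As a rough sketch of how these helpers fit together, the snippet below drives generate with a toy stand-in model. The ToyModel class, its sizes, and the prompt are invented for illustration and are not part of this commit; it assumes the snippet is run from test/generate/ so that generation is importable, just as the test script below does.

# Illustrative only: ToyModel is a made-up stand-in, not a torchtitan model.
import torch

from generation import generate


class ToyModel(torch.nn.Module):
    """Maps (B, T) token ids to (B, T, vocab_size) logits, as generate() expects."""

    def __init__(self, vocab_size: int = 32, dim: int = 16):
        super().__init__()
        self.embed = torch.nn.Embedding(vocab_size, dim)
        self.head = torch.nn.Linear(dim, vocab_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.head(self.embed(x))


prompt = torch.tensor([1, 2, 3])  # 1-D prompts are reshaped to (1, T) inside generate
tokens, logits = generate(
    ToyModel(), prompt, max_generated_tokens=4, temperature=0.8, top_k=5
)
print(tokens.shape)  # torch.Size([1, 7]): 3 prompt tokens + 4 generated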

test/generate/test_generate.py

Lines changed: 138 additions & 0 deletions
@@ -0,0 +1,138 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse
from typing import Optional

import torch
import torch.distributed.checkpoint as dcp

from generation import generate
from torchtitan import utils

from torchtitan.config_manager import JobConfig
from torchtitan.datasets import build_tokenizer
from torchtitan.logging import init_logger, logger
from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config


def example_generate(
    config_path: str,
    checkpoint_path: str,
    prompt: str,
    *,
    device: str = "cuda",
    temperature: float = 1.0,
    max_generated_tokens: int = 32,
    top_k: Optional[int] = None,
):
    init_logger()
    color = utils.Color

    # Load configuration from toml file
    config = JobConfig()
    config.parse_args([f"--job.config_file={config_path}"])
    config._validate_config()

    # Load tokenizer and model configuration
    tokenizer = build_tokenizer(
        model_name_to_tokenizer[config.model.name], config.model.tokenizer_path
    )
    model_cls = model_name_to_cls[config.model.name]
    model_config = models_config[config.model.name][config.model.flavor]
    model_config.vocab_size = tokenizer.n_words

    # Load model and checkpoint
    with torch.device(device):
        model = model_cls.from_model_args(model_config)
    state_dict = model.state_dict()

    # freqs_cis is recomputed at runtime rather than loaded from the checkpoint,
    # so drop it from the state dict before loading and precompute it afterwards
    precompute = False
    if "freqs_cis" in state_dict:
        del state_dict["freqs_cis"]
        precompute = True

    logger.info(f"Loading checkpoint at: {checkpoint_path}")
    dcp.load(state_dict, checkpoint_id=checkpoint_path)

    # Precompute frequencies if required
    if precompute:
        model.freqs_cis = model._precompute_freqs_cis().to(device)

    # Encode input prompt and generate response
    input_ids = torch.tensor(
        tokenizer.encode(prompt, bos=False, eos=False), dtype=torch.long
    ).to(device)
    logger.info(f"{color.red}Input tokens: {len(input_ids)}{color.reset}")

    responses, _ = generate(
        model,
        input_ids,
        temperature=temperature,
        max_generated_tokens=max_generated_tokens,
        top_k=top_k,
    )
    logger.info(
        f"{color.blue}Output tokens: {len(responses[0]) - len(input_ids)}{color.reset}"
    )

    response = tokenizer.decode(
        [token.item() for token in responses[0][len(input_ids) :]]
    )
    logger.info(f"{color.red}{prompt}{color.blue}{response}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Test generation")
    parser.add_argument(
        "--config", type=str, required=True, help="TOML config file path (required)"
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        required=True,
        help="Checkpoint path to load (required)",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        choices=["cpu", "cuda"],
        help="Device to load model on. Default is 'cuda'",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=1.0,
        help="Sampling temperature. Default is 1.0",
    )
    parser.add_argument(
        "--max_generated_tokens",
        type=int,
        default=32,
        help="Max number of tokens to generate. Default is 32",
    )
    parser.add_argument(
        "--top_k", type=int, help="Prune to select from top_k probabilities. Optional"
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default="Hello! How are you?",
        help="Input prompt for generation",
    )

    args = parser.parse_args()

    example_generate(
        config_path=args.config,
        checkpoint_path=args.checkpoint,
        prompt=args.prompt,
        device=args.device,
        temperature=args.temperature,
        max_generated_tokens=args.max_generated_tokens,
        top_k=args.top_k,
    )
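
For reference, the same entry point can also be called directly from Python instead of through the argparse main; the config and checkpoint paths below are placeholders, not files that ship with this commit.

# Placeholder paths for illustration; substitute a real job TOML and DCP checkpoint.
from test_generate import example_generate

example_generate(
    config_path="./train_configs/my_job.toml",
    checkpoint_path="./outputs/checkpoint/step-1000",
    prompt="Hello! How are you?",
    device="cuda",
    temperature=1.0,
    max_generated_tokens=32,
    top_k=None,
)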
