From e5e55febaee1381ec7f14ab63719dfb5ded8c7f3 Mon Sep 17 00:00:00 2001 From: Jayson Francis Date: Wed, 30 Oct 2024 03:07:20 +0000 Subject: [PATCH 1/8] simple inference, wip --- test/generate/generation.py | 74 +++++++++ test/generate/run_llama_pred.sh | 36 +++++ test/generate/test_generate.py | 159 +++++++++++++++++++ test/generate/test_generate_dist.py | 233 ++++++++++++++++++++++++++++ 4 files changed, 502 insertions(+) create mode 100644 test/generate/generation.py create mode 100644 test/generate/run_llama_pred.sh create mode 100644 test/generate/test_generate.py create mode 100644 test/generate/test_generate_dist.py diff --git a/test/generate/generation.py b/test/generate/generation.py new file mode 100644 index 0000000000..5437ad1c66 --- /dev/null +++ b/test/generate/generation.py @@ -0,0 +1,74 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional, Tuple + +import torch + + +def multinomial_sample_one(probs: torch.Tensor) -> torch.Tensor: + q = torch.empty_like(probs).exponential_(1) + return torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.long) + + +def logits_to_probs( + logits: torch.Tensor, + temperature: float = 1.0, + top_k: Optional[int] = None, +) -> torch.Tensor: + + logits = logits / max(temperature, 1e-5) + + if top_k is not None: + v, _ = torch.topk(logits, k=min(top_k, logits.size(-1))) # (k,) + pivot = v.select(dim=-1, index=-1).unsqueeze(-1) # (1,) + logits = torch.where(logits < pivot, -float("Inf"), logits) # (vocab_size, ) + + probs = torch.nn.functional.softmax(logits, dim=-1) + return probs + + +def generate_next_token( + model, + x: torch.Tensor, + *, + temperature: float = 1.0, + top_k: Optional[int] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + + logits = model(x) # (B, T, vocab_size) + probs = logits_to_probs(logits[:, -1, :], temperature, top_k) + next_token = multinomial_sample_one(probs) + return next_token, probs + + +@torch.no_grad() +def generate( + model, + prompt: torch.Tensor, + *, + max_new_tokens: int, + temperature: float = 1.0, + top_k: Optional[int] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + + if prompt.ndim == 1: + prompt = prompt.unsqueeze(0) + + generated_tokens = prompt.clone() + + for i in range(max_new_tokens): + + tokens, logits = generate_next_token( + model, + x=generated_tokens.clone(), + temperature=temperature, + top_k=top_k, + ) + + generated_tokens = torch.cat([generated_tokens, tokens], dim=-1) + + return generated_tokens, logits diff --git a/test/generate/run_llama_pred.sh b/test/generate/run_llama_pred.sh new file mode 100644 index 0000000000..e52454f12c --- /dev/null +++ b/test/generate/run_llama_pred.sh @@ -0,0 +1,36 @@ +#!/usr/bin/bash +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +set -ex + +# use envs as local overrides for convenience +# e.g. +# LOG_RANK=0,1 NGPU=4 ./run_llama_train.sh +NGPU=${NGPU:-"2"} +LOG_RANK=${LOG_RANK:-0,1} +CONFIG_FILE=${CONFIG_FILE:-"./train_configs/debug_model.toml"} +CHECKPOINT_DIR=${CHECKPOINT_DIR:-"./outputs/checkpoint/"} +PROMPT=${PROMPT:-"Hello!"} + +overrides="" +if [ $# -ne 0 ]; then + overrides="$*" +fi + +# export NCCL_DEBUG=INFO +# export NCCL_DEBUG_SUBSYS=ALL +# export NCCL_BLOCKING_WAIT=1 +# export NCCL_ASYNC_ERROR_HANDLING=1 + +torchrun --standalone \ + --nproc_per_node="${NGPU}" \ + --local-ranks-filter="${LOG_RANK}" \ + test/generate/test_generate_dist.py \ + --config="${CONFIG_FILE}" \ + --checkpoint="${CHECKPOINT_DIR}" \ + --prompt="${PROMPT}" \ + ${overrides} diff --git a/test/generate/test_generate.py b/test/generate/test_generate.py new file mode 100644 index 0000000000..793f552a70 --- /dev/null +++ b/test/generate/test_generate.py @@ -0,0 +1,159 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import time +from typing import Optional + +import torch +import torch.distributed.checkpoint as dcp + +from generation import generate +from torchtitan import utils + +from torchtitan.config_manager import JobConfig +from torchtitan.datasets import build_tokenizer +from torchtitan.logging import init_logger, logger +from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config + + +def example_generate( + config_path: str, + checkpoint_path: str, + prompt: str, + *, + device: str = "cuda", + temperature: float = 1.0, + max_new_tokens: int = 32, + top_k: Optional[int] = None, + seed: Optional[int] = None, +): + init_logger() + color = utils.Color + + # Load configuration from toml file + config = JobConfig() + config.parse_args([f"--job.config_file={config_path}"]) + config._validate_config() + + # Load tokenizer and model configuration + tokenizer = build_tokenizer( + model_name_to_tokenizer[config.model.name], config.model.tokenizer_path + ) + model_cls = model_name_to_cls[config.model.name] + model_config = models_config[config.model.name][config.model.flavor] + model_config.vocab_size = tokenizer.n_words + + # Load model and checkpoint + with torch.device(device): + model = model_cls.from_model_args(model_config) + + model_param_count = utils.get_num_params(model) + logger.info(f"Model Params: {model_param_count:,}") + + state_dict = {"model": model.state_dict()} + + precompute = False + if "freqs_cis" in state_dict["model"]: + del state_dict["model"]["freqs_cis"] + precompute = True + + begin = time.monotonic() + logger.info(f"Loading checkpoint at: {checkpoint_path}") + dcp.load(state_dict, checkpoint_id=checkpoint_path) + logger.info( + f"Finished loading the checkpoint in {time.monotonic() - begin:.2f} seconds." + ) + + # Precompute frequency if required + if precompute: + model.freqs_cis = model._precompute_freqs_cis().to(args.device) + + # Encode input prompt and generate response + input_ids = torch.tensor( + tokenizer.encode(prompt, bos=False, eos=False), dtype=torch.long + ).to(device) + + begin = time.monotonic() + responses = generate( + model, + input_ids, + temperature=temperature, + max_new_tokens=max_new_tokens, + top_k=top_k, + ) + + logger.info(f"Generation completed in {time.monotonic() - begin:.2f} seconds.") + logger.info(f"{color.red}Input tokens: {len(input_ids)}{color.reset}") + logger.info( + f"{color.blue}Output tokens: {len(responses[0])-len(input_ids)}{color.reset}" + ) + + response = tokenizer.decode( + [token.item() for token in responses[0][len(input_ids) :]] + ) + + logger.info(f"{color.red}{prompt}{color.blue}{response}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Test generation") + parser.add_argument( + "--config", type=str, required=True, help="TOML config file path (required)" + ) + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Checkpoint path to load (required)", + ) + parser.add_argument( + "--device", + type=str, + default="cuda", + help="Device to load model on. Default is 'cuda'", + ) + parser.add_argument( + "--temperature", + type=float, + default=1.0, + help="Sampling temperature. Default is 1.0", + ) + parser.add_argument( + "--max_new_tokens", + type=int, + default=32, + help="Max number of tokens to generate. Default is 32", + ) + parser.add_argument( + "--batch_size", type=int, default=1, help="Number of samples to run in batch" + ) + parser.add_argument( + "--top_k", type=int, help="Prune to select from top_k probabilities. Optional" + ) + parser.add_argument( + "--seed", type=int, default=42, help="Random seed for reproducibility" + ) + + parser.add_argument( + "--prompt", + type=str, + default="Hello! How are", + help="Input prompt for generation", + ) + + args = parser.parse_args() + + example_generate( + config_path=args.config, + checkpoint_path=args.checkpoint, + prompt=args.prompt, + device=args.device, + temperature=args.temperature, + max_new_tokens=args.max_new_tokens, + top_k=args.top_k, + seed=args.seed, + ) diff --git a/test/generate/test_generate_dist.py b/test/generate/test_generate_dist.py new file mode 100644 index 0000000000..8db93dd690 --- /dev/null +++ b/test/generate/test_generate_dist.py @@ -0,0 +1,233 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import json +import os +import sys +import time +from pathlib import Path + +from typing import Optional + +import torch +import torch.distributed.checkpoint as dcp + +from torchtitan import utils + +from torchtitan.config_manager import JobConfig +from torchtitan.datasets import build_tokenizer +from torchtitan.logging import init_logger, logger +from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config +from torchtitan.parallelisms import models_parallelize_fns, ParallelDims + +# support running w/o installing as package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) + +from generate.generation import generate + + +def example_generate( + config_path: str, + checkpoint_path: str, + prompt: str, + *, + device: str = "cuda", + temperature: float = 1.0, + max_new_tokens: int = 32, + batch_size: int = 1, + top_k: Optional[int] = None, + seed: Optional[int] = None, +): + init_logger() + color = utils.Color + + # Load configuration from toml file + config = JobConfig() + config.parse_args([f"--job.config_file={config_path}"]) + config._validate_config() + + utils.set_determinism(seed) + + world_size = int(os.environ.get("WORLD_SIZE", 1)) + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + device = torch.device(f"cuda:{local_rank}") + torch.cuda.set_device(device) + + model_name = config.model.name + + # Init distributed env + if world_size > 1: + utils.init_distributed(config) + parallel_dims = ParallelDims( + dp_replicate=1, + dp_shard=-1, + cp=1, + tp=world_size, + pp=1, + world_size=world_size, + enable_loss_parallel=False, + ) + # Build world mesh for parallelism + world_mesh = parallel_dims.build_mesh(device_type="cuda") + + logger.info(f"World Size: {world_size}, Local Rank: {local_rank} on {device}") + + # Tokenizer setup + tokenizer = build_tokenizer( + model_name_to_tokenizer[model_name], config.model.tokenizer_path + ) + + model_config = models_config[model_name][config.model.flavor] + model_config.norm_type = config.model.norm_type + model_config.max_seq_len = config.training.seq_len + model_config.vocab_size = tokenizer.n_words + + model_cls = model_name_to_cls[model_name] + with torch.device("meta"): + model = model_cls.from_model_args(model_config) + + if world_size > 1: + models_parallelize_fns[model_name](model, world_mesh, parallel_dims, config) + + # materalize model + model.to_empty(device="cuda") + model.eval() + + state_dict = {"model": model.state_dict()} + + precompute = False + if "freqs_cis" in state_dict["model"]: + del state_dict["model"]["freqs_cis"] + precompute = True + + # Checkpoint Loading + begin = time.monotonic() + logger.info(f"Loading chkpt at: {checkpoint_path}") + dcp.load(state_dict, checkpoint_id=checkpoint_path) + logger.info(f"Finished loading chkpt in {time.monotonic() - begin:.2f} seconds.") + + if precompute: + model.freqs_cis = model._precompute_freqs_cis().to("cuda") + + # Tokenize prompt and repeat batch_size times + input_ids = ( + ( + torch.tensor( + tokenizer.encode(prompt, bos=True, eos=False), dtype=torch.long + ) + .view(1, -1) + .repeat(batch_size, 1) + ) + .cuda() + .detach() + ) + + # Inference + begin = time.monotonic() + responses, _ = generate( + model, + input_ids, + temperature=temperature, + max_new_tokens=max_new_tokens, + top_k=top_k, + ) + end = time.monotonic() + + prompt_len = input_ids.size(1) # num tokens + + if local_rank == 0: + logger.info(f"Generation completed in {end-begin:.2f} seconds.") + + r, b = color.red, color.blue + + output_data = [] + + for i, response in enumerate(responses): + + inp_tok = response[:prompt_len].tolist() + out_tok = response[prompt_len:].tolist() + + input_text = tokenizer.decode(inp_tok) + output_text = tokenizer.decode(out_tok) + + response_data = { + "response_idx": i, + "input_n_tokens": len(inp_tok), + "output_n_tokens": len(out_tok), + "input_text": input_text, + "output_text": output_text, + } + output_data.append(response_data) + + logger.info(f"{r}\n{input_text}{b}{output_text}\n{color.reset}") + + print(json.dumps(output_data, indent=4)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Test generation") + parser.add_argument( + "--config", type=str, required=True, help="TOML config file path (required)" + ) + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Checkpoint path to load (required)", + ) + parser.add_argument( + "--device", + type=str, + default="cuda", + help="Device to load model on. Default is 'cuda'", + ) + parser.add_argument( + "--temperature", + type=float, + default=1.0, + help="Sampling temperature. Default is 1.0", + ) + parser.add_argument( + "--max_new_tokens", + type=int, + default=32, + help="Max number of tokens to generate. Default is 32", + ) + parser.add_argument( + "--batch_size", type=int, default=1, help="Number of samples to run in batch" + ) + parser.add_argument( + "--top_k", type=int, help="Prune to select from top_k probabilities. Optional" + ) + parser.add_argument( + "--seed", type=int, default=42, help="Random seed for reproducibility" + ) + + parser.add_argument( + "--prompt", + type=str, + default="Hello! How are", + help="Input prompt for generation", + ) + + args = parser.parse_args() + + example_generate( + config_path=args.config, + checkpoint_path=args.checkpoint, + prompt=args.prompt, + device=args.device, + temperature=args.temperature, + max_new_tokens=args.max_new_tokens, + batch_size=args.batch_size, + top_k=args.top_k, + seed=args.seed, + ) + + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() From 0a9dd6d97949b7d7bf79de682fca4a2fbdb20537 Mon Sep 17 00:00:00 2001 From: Jayson Francis Date: Wed, 30 Oct 2024 17:51:15 +0000 Subject: [PATCH 2/8] [wip] consolidated, removed freq_cis precompute --- test/generate/run_llama_pred.sh | 2 +- test/generate/test_generate.py | 144 ++++++++++++----- test/generate/test_generate_dist.py | 233 ---------------------------- 3 files changed, 106 insertions(+), 273 deletions(-) mode change 100644 => 100755 test/generate/run_llama_pred.sh delete mode 100644 test/generate/test_generate_dist.py diff --git a/test/generate/run_llama_pred.sh b/test/generate/run_llama_pred.sh old mode 100644 new mode 100755 index e52454f12c..f797c78795 --- a/test/generate/run_llama_pred.sh +++ b/test/generate/run_llama_pred.sh @@ -29,7 +29,7 @@ fi torchrun --standalone \ --nproc_per_node="${NGPU}" \ --local-ranks-filter="${LOG_RANK}" \ - test/generate/test_generate_dist.py \ + test/generate/test_generate.py \ --config="${CONFIG_FILE}" \ --checkpoint="${CHECKPOINT_DIR}" \ --prompt="${PROMPT}" \ diff --git a/test/generate/test_generate.py b/test/generate/test_generate.py index 793f552a70..2433c5e83a 100644 --- a/test/generate/test_generate.py +++ b/test/generate/test_generate.py @@ -5,19 +5,30 @@ # LICENSE file in the root directory of this source tree. import argparse +import json +import os +import sys import time +from pathlib import Path + from typing import Optional import torch import torch.distributed.checkpoint as dcp -from generation import generate from torchtitan import utils from torchtitan.config_manager import JobConfig from torchtitan.datasets import build_tokenizer from torchtitan.logging import init_logger, logger from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config +from torchtitan.parallelisms import models_parallelize_fns, ParallelDims + +# support running w/o installing as package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) + +from generate.generation import generate def example_generate( @@ -28,6 +39,7 @@ def example_generate( device: str = "cuda", temperature: float = 1.0, max_new_tokens: int = 32, + batch_size: int = 1, top_k: Optional[int] = None, seed: Optional[int] = None, ): @@ -39,64 +51,116 @@ def example_generate( config.parse_args([f"--job.config_file={config_path}"]) config._validate_config() - # Load tokenizer and model configuration + utils.set_determinism(seed) + + world_size = int(os.environ.get("WORLD_SIZE", 1)) + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + device = torch.device(f"cuda:{local_rank}") + torch.cuda.set_device(device) + + model_name = config.model.name + + # Init distributed env + if world_size > 1: + utils.init_distributed(config) + parallel_dims = ParallelDims( + dp_replicate=1, + dp_shard=-1, + cp=1, + tp=world_size, + pp=1, + world_size=world_size, + enable_loss_parallel=False, + ) + # Build world mesh for parallelism + world_mesh = parallel_dims.build_mesh(device_type="cuda") + + logger.info(f"World Size: {world_size}, Local Rank: {local_rank} on {device}") + + # Tokenizer setup tokenizer = build_tokenizer( - model_name_to_tokenizer[config.model.name], config.model.tokenizer_path + model_name_to_tokenizer[model_name], config.model.tokenizer_path ) - model_cls = model_name_to_cls[config.model.name] - model_config = models_config[config.model.name][config.model.flavor] + + model_config = models_config[model_name][config.model.flavor] + model_config.norm_type = config.model.norm_type + model_config.max_seq_len = config.training.seq_len model_config.vocab_size = tokenizer.n_words - # Load model and checkpoint - with torch.device(device): + model_cls = model_name_to_cls[model_name] + init_device = "meta" if world_size > 1 else device + with torch.device(init_device): + logger.info(f"Init model on init_device: {init_device}") model = model_cls.from_model_args(model_config) - model_param_count = utils.get_num_params(model) - logger.info(f"Model Params: {model_param_count:,}") + if world_size > 1: + models_parallelize_fns[model_name](model, world_mesh, parallel_dims, config) - state_dict = {"model": model.state_dict()} + # materalize model + model.to_empty(device="cuda") + model.eval() - precompute = False - if "freqs_cis" in state_dict["model"]: - del state_dict["model"]["freqs_cis"] - precompute = True + state_dict = {"model": model.state_dict()} + # Checkpoint Loading begin = time.monotonic() - logger.info(f"Loading checkpoint at: {checkpoint_path}") + logger.info(f"Loading chkpt at: {checkpoint_path}") dcp.load(state_dict, checkpoint_id=checkpoint_path) - logger.info( - f"Finished loading the checkpoint in {time.monotonic() - begin:.2f} seconds." + logger.info(f"Finished loading chkpt in {time.monotonic() - begin:.2f} seconds.") + + # Tokenize prompt and repeat batch_size times + input_ids = ( + ( + torch.tensor( + tokenizer.encode(prompt, bos=True, eos=False), dtype=torch.long + ) + .view(1, -1) + .repeat(batch_size, 1) + ) + .cuda() + .detach() ) - # Precompute frequency if required - if precompute: - model.freqs_cis = model._precompute_freqs_cis().to(args.device) - - # Encode input prompt and generate response - input_ids = torch.tensor( - tokenizer.encode(prompt, bos=False, eos=False), dtype=torch.long - ).to(device) - + # Inference begin = time.monotonic() - responses = generate( + responses, _ = generate( model, input_ids, temperature=temperature, max_new_tokens=max_new_tokens, top_k=top_k, ) + end = time.monotonic() - logger.info(f"Generation completed in {time.monotonic() - begin:.2f} seconds.") - logger.info(f"{color.red}Input tokens: {len(input_ids)}{color.reset}") - logger.info( - f"{color.blue}Output tokens: {len(responses[0])-len(input_ids)}{color.reset}" - ) + prompt_len = input_ids.size(1) # num tokens - response = tokenizer.decode( - [token.item() for token in responses[0][len(input_ids) :]] - ) + if local_rank == 0: + logger.info(f"Generation completed in {end-begin:.2f} seconds.") - logger.info(f"{color.red}{prompt}{color.blue}{response}") + r, b = color.red, color.blue + + output_data = [] + + for i, response in enumerate(responses): + + inp_tok = response[:prompt_len].tolist() + out_tok = response[prompt_len:].tolist() + + input_text = tokenizer.decode(inp_tok) + output_text = tokenizer.decode(out_tok) + + response_data = { + "response_idx": i, + "input_n_tokens": len(inp_tok), + "output_n_tokens": len(out_tok), + "input_text": input_text, + "output_text": output_text, + } + output_data.append(response_data) + + logger.info(f"{r}\n{input_text}{b}{output_text}\n{color.reset}") + + print(json.dumps(output_data, indent=4)) if __name__ == "__main__": @@ -134,9 +198,7 @@ def example_generate( parser.add_argument( "--top_k", type=int, help="Prune to select from top_k probabilities. Optional" ) - parser.add_argument( - "--seed", type=int, default=42, help="Random seed for reproducibility" - ) + parser.add_argument("--seed", type=int, help="Random seed for reproducibility") parser.add_argument( "--prompt", @@ -154,6 +216,10 @@ def example_generate( device=args.device, temperature=args.temperature, max_new_tokens=args.max_new_tokens, + batch_size=args.batch_size, top_k=args.top_k, seed=args.seed, ) + + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() diff --git a/test/generate/test_generate_dist.py b/test/generate/test_generate_dist.py deleted file mode 100644 index 8db93dd690..0000000000 --- a/test/generate/test_generate_dist.py +++ /dev/null @@ -1,233 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import argparse -import json -import os -import sys -import time -from pathlib import Path - -from typing import Optional - -import torch -import torch.distributed.checkpoint as dcp - -from torchtitan import utils - -from torchtitan.config_manager import JobConfig -from torchtitan.datasets import build_tokenizer -from torchtitan.logging import init_logger, logger -from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config -from torchtitan.parallelisms import models_parallelize_fns, ParallelDims - -# support running w/o installing as package -wd = Path(__file__).parent.parent.resolve() -sys.path.append(str(wd)) - -from generate.generation import generate - - -def example_generate( - config_path: str, - checkpoint_path: str, - prompt: str, - *, - device: str = "cuda", - temperature: float = 1.0, - max_new_tokens: int = 32, - batch_size: int = 1, - top_k: Optional[int] = None, - seed: Optional[int] = None, -): - init_logger() - color = utils.Color - - # Load configuration from toml file - config = JobConfig() - config.parse_args([f"--job.config_file={config_path}"]) - config._validate_config() - - utils.set_determinism(seed) - - world_size = int(os.environ.get("WORLD_SIZE", 1)) - local_rank = int(os.environ.get("LOCAL_RANK", 0)) - device = torch.device(f"cuda:{local_rank}") - torch.cuda.set_device(device) - - model_name = config.model.name - - # Init distributed env - if world_size > 1: - utils.init_distributed(config) - parallel_dims = ParallelDims( - dp_replicate=1, - dp_shard=-1, - cp=1, - tp=world_size, - pp=1, - world_size=world_size, - enable_loss_parallel=False, - ) - # Build world mesh for parallelism - world_mesh = parallel_dims.build_mesh(device_type="cuda") - - logger.info(f"World Size: {world_size}, Local Rank: {local_rank} on {device}") - - # Tokenizer setup - tokenizer = build_tokenizer( - model_name_to_tokenizer[model_name], config.model.tokenizer_path - ) - - model_config = models_config[model_name][config.model.flavor] - model_config.norm_type = config.model.norm_type - model_config.max_seq_len = config.training.seq_len - model_config.vocab_size = tokenizer.n_words - - model_cls = model_name_to_cls[model_name] - with torch.device("meta"): - model = model_cls.from_model_args(model_config) - - if world_size > 1: - models_parallelize_fns[model_name](model, world_mesh, parallel_dims, config) - - # materalize model - model.to_empty(device="cuda") - model.eval() - - state_dict = {"model": model.state_dict()} - - precompute = False - if "freqs_cis" in state_dict["model"]: - del state_dict["model"]["freqs_cis"] - precompute = True - - # Checkpoint Loading - begin = time.monotonic() - logger.info(f"Loading chkpt at: {checkpoint_path}") - dcp.load(state_dict, checkpoint_id=checkpoint_path) - logger.info(f"Finished loading chkpt in {time.monotonic() - begin:.2f} seconds.") - - if precompute: - model.freqs_cis = model._precompute_freqs_cis().to("cuda") - - # Tokenize prompt and repeat batch_size times - input_ids = ( - ( - torch.tensor( - tokenizer.encode(prompt, bos=True, eos=False), dtype=torch.long - ) - .view(1, -1) - .repeat(batch_size, 1) - ) - .cuda() - .detach() - ) - - # Inference - begin = time.monotonic() - responses, _ = generate( - model, - input_ids, - temperature=temperature, - max_new_tokens=max_new_tokens, - top_k=top_k, - ) - end = time.monotonic() - - prompt_len = input_ids.size(1) # num tokens - - if local_rank == 0: - logger.info(f"Generation completed in {end-begin:.2f} seconds.") - - r, b = color.red, color.blue - - output_data = [] - - for i, response in enumerate(responses): - - inp_tok = response[:prompt_len].tolist() - out_tok = response[prompt_len:].tolist() - - input_text = tokenizer.decode(inp_tok) - output_text = tokenizer.decode(out_tok) - - response_data = { - "response_idx": i, - "input_n_tokens": len(inp_tok), - "output_n_tokens": len(out_tok), - "input_text": input_text, - "output_text": output_text, - } - output_data.append(response_data) - - logger.info(f"{r}\n{input_text}{b}{output_text}\n{color.reset}") - - print(json.dumps(output_data, indent=4)) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Test generation") - parser.add_argument( - "--config", type=str, required=True, help="TOML config file path (required)" - ) - parser.add_argument( - "--checkpoint", - type=str, - required=True, - help="Checkpoint path to load (required)", - ) - parser.add_argument( - "--device", - type=str, - default="cuda", - help="Device to load model on. Default is 'cuda'", - ) - parser.add_argument( - "--temperature", - type=float, - default=1.0, - help="Sampling temperature. Default is 1.0", - ) - parser.add_argument( - "--max_new_tokens", - type=int, - default=32, - help="Max number of tokens to generate. Default is 32", - ) - parser.add_argument( - "--batch_size", type=int, default=1, help="Number of samples to run in batch" - ) - parser.add_argument( - "--top_k", type=int, help="Prune to select from top_k probabilities. Optional" - ) - parser.add_argument( - "--seed", type=int, default=42, help="Random seed for reproducibility" - ) - - parser.add_argument( - "--prompt", - type=str, - default="Hello! How are", - help="Input prompt for generation", - ) - - args = parser.parse_args() - - example_generate( - config_path=args.config, - checkpoint_path=args.checkpoint, - prompt=args.prompt, - device=args.device, - temperature=args.temperature, - max_new_tokens=args.max_new_tokens, - batch_size=args.batch_size, - top_k=args.top_k, - seed=args.seed, - ) - - if torch.distributed.is_initialized(): - torch.distributed.destroy_process_group() From b88eefe1ea91f81cfae56acbc8df741652280614 Mon Sep 17 00:00:00 2001 From: Jayson Francis Date: Wed, 30 Oct 2024 20:20:47 +0000 Subject: [PATCH 3/8] [wip] tp/output updates --- test/generate/generation.py | 22 ++++++------ test/generate/run_llama_pred.sh | 6 ++-- test/generate/test_generate.py | 59 +++++++++++++++++++++++++-------- 3 files changed, 59 insertions(+), 28 deletions(-) diff --git a/test/generate/generation.py b/test/generate/generation.py index 5437ad1c66..1d8913a506 100644 --- a/test/generate/generation.py +++ b/test/generate/generation.py @@ -23,9 +23,9 @@ def logits_to_probs( logits = logits / max(temperature, 1e-5) if top_k is not None: - v, _ = torch.topk(logits, k=min(top_k, logits.size(-1))) # (k,) - pivot = v.select(dim=-1, index=-1).unsqueeze(-1) # (1,) - logits = torch.where(logits < pivot, -float("Inf"), logits) # (vocab_size, ) + v, _ = torch.topk(logits, k=min(top_k, logits.size(-1))) + pivot = v.select(dim=-1, index=-1).unsqueeze(-1) + logits = torch.where(logits < pivot, -float("Inf"), logits) probs = torch.nn.functional.softmax(logits, dim=-1) return probs @@ -48,27 +48,27 @@ def generate_next_token( @torch.no_grad() def generate( model, - prompt: torch.Tensor, + input_ids: torch.Tensor, *, max_new_tokens: int, temperature: float = 1.0, top_k: Optional[int] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: - if prompt.ndim == 1: - prompt = prompt.unsqueeze(0) + # ensure batch dimension (T,) --> (B, T) + if input_ids.ndim == 1: + input_ids = input_ids.unsqueeze(0) - generated_tokens = prompt.clone() + generated_tokens = input_ids.clone() for i in range(max_new_tokens): - - tokens, logits = generate_next_token( + next_token, logits = generate_next_token( model, - x=generated_tokens.clone(), + x=generated_tokens, temperature=temperature, top_k=top_k, ) - generated_tokens = torch.cat([generated_tokens, tokens], dim=-1) + generated_tokens = torch.cat([generated_tokens, next_token], dim=1) return generated_tokens, logits diff --git a/test/generate/run_llama_pred.sh b/test/generate/run_llama_pred.sh index f797c78795..a5160c415a 100755 --- a/test/generate/run_llama_pred.sh +++ b/test/generate/run_llama_pred.sh @@ -21,10 +21,10 @@ if [ $# -ne 0 ]; then overrides="$*" fi -# export NCCL_DEBUG=INFO +# export NCCL_DEBUG=INFO # TRACE # export NCCL_DEBUG_SUBSYS=ALL -# export NCCL_BLOCKING_WAIT=1 -# export NCCL_ASYNC_ERROR_HANDLING=1 +# export TORCH_NCCL_BLOCKING_WAIT=1 +# export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 torchrun --standalone \ --nproc_per_node="${NGPU}" \ diff --git a/test/generate/test_generate.py b/test/generate/test_generate.py index 2433c5e83a..8370589943 100644 --- a/test/generate/test_generate.py +++ b/test/generate/test_generate.py @@ -21,6 +21,7 @@ from torchtitan.config_manager import JobConfig from torchtitan.datasets import build_tokenizer from torchtitan.logging import init_logger, logger +from torchtitan.metrics import build_gpu_memory_monitor from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config from torchtitan.parallelisms import models_parallelize_fns, ParallelDims @@ -57,6 +58,7 @@ def example_generate( local_rank = int(os.environ.get("LOCAL_RANK", 0)) device = torch.device(f"cuda:{local_rank}") torch.cuda.set_device(device) + gpu_memory_monitor = build_gpu_memory_monitor() model_name = config.model.name @@ -108,6 +110,13 @@ def example_generate( dcp.load(state_dict, checkpoint_id=checkpoint_path) logger.info(f"Finished loading chkpt in {time.monotonic() - begin:.2f} seconds.") + gpu_mem_stats = gpu_memory_monitor.get_peak_stats() + logger.info( + f"GPU memory usage for model: " + f"{gpu_mem_stats.max_reserved_gib:.2f}GiB" + f"({gpu_mem_stats.max_reserved_pct:.2f}%)" + ) + # Tokenize prompt and repeat batch_size times input_ids = ( ( @@ -120,9 +129,10 @@ def example_generate( .cuda() .detach() ) + gpu_memory_monitor.reset_peak_stats() - # Inference - begin = time.monotonic() + # Run generation + t0 = time.monotonic() responses, _ = generate( model, input_ids, @@ -130,36 +140,57 @@ def example_generate( max_new_tokens=max_new_tokens, top_k=top_k, ) - end = time.monotonic() + t1 = time.monotonic() + elapsed_sec = t1 - t0 - prompt_len = input_ids.size(1) # num tokens + # Post process + B, T = responses.size() # B: batch_size, T: total seq length + input_n_tokens = input_ids.size(1) + generated_n_tokens = T - input_n_tokens # == max_new_tokens if local_rank == 0: - logger.info(f"Generation completed in {end-begin:.2f} seconds.") + logger.info(f"Generation completed in {elapsed_sec:.2f} seconds.") r, b = color.red, color.blue - output_data = [] - - for i, response in enumerate(responses): + output_data = { + "metadata": {}, + "responses": [], + } - inp_tok = response[:prompt_len].tolist() - out_tok = response[prompt_len:].tolist() + for i, tokens in enumerate(responses): + inp_tok = tokens[:input_n_tokens].tolist() + out_tok = tokens[input_n_tokens:].tolist() input_text = tokenizer.decode(inp_tok) output_text = tokenizer.decode(out_tok) - response_data = { + _data = { "response_idx": i, - "input_n_tokens": len(inp_tok), - "output_n_tokens": len(out_tok), "input_text": input_text, "output_text": output_text, } - output_data.append(response_data) + output_data["responses"].append(_data) logger.info(f"{r}\n{input_text}{b}{output_text}\n{color.reset}") + gpu_mem_stats = gpu_memory_monitor.get_peak_stats() + output_data["metadata"] = { + "generated_n_tokens": generated_n_tokens, + "input_n_tokens": input_n_tokens, + "generation_time_sec": f"{elapsed_sec:.2f}", + "tokens_per_sec": (B * T) / elapsed_sec, + "batch_size": B, + "seed": seed, + "timestamp": time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime()), + "memory/max_active(GiB)": gpu_mem_stats.max_active_gib, + "memory/max_active(%)": gpu_mem_stats.max_active_pct, + "memory/max_reserved(GiB)": gpu_mem_stats.max_reserved_gib, + "memory/max_reserved(%)": gpu_mem_stats.max_reserved_pct, + "memory/num_alloc_retries": gpu_mem_stats.num_alloc_retries, + "memory/num_ooms": gpu_mem_stats.num_ooms, + "torch_version": torch.__version__, + } print(json.dumps(output_data, indent=4)) From 76ce752d254e2ba502009c4033d10747b2166aa4 Mon Sep 17 00:00:00 2001 From: Jayson Francis Date: Fri, 1 Nov 2024 20:36:35 +0000 Subject: [PATCH 4/8] linting, small clean up --- test/generate/generation.py | 14 +++++++------- test/generate/run_llama_pred.sh | 4 ++-- test/generate/test_generate.py | 13 +++++++++++-- 3 files changed, 20 insertions(+), 11 deletions(-) diff --git a/test/generate/generation.py b/test/generate/generation.py index 1d8913a506..969551479e 100644 --- a/test/generate/generation.py +++ b/test/generate/generation.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Optional, Tuple +from typing import Optional import torch @@ -37,12 +37,12 @@ def generate_next_token( *, temperature: float = 1.0, top_k: Optional[int] = None, -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> torch.Tensor: logits = model(x) # (B, T, vocab_size) probs = logits_to_probs(logits[:, -1, :], temperature, top_k) next_token = multinomial_sample_one(probs) - return next_token, probs + return next_token @torch.no_grad() @@ -53,7 +53,7 @@ def generate( max_new_tokens: int, temperature: float = 1.0, top_k: Optional[int] = None, -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> torch.Tensor: # ensure batch dimension (T,) --> (B, T) if input_ids.ndim == 1: @@ -61,8 +61,8 @@ def generate( generated_tokens = input_ids.clone() - for i in range(max_new_tokens): - next_token, logits = generate_next_token( + for _ in range(max_new_tokens): + next_token = generate_next_token( model, x=generated_tokens, temperature=temperature, @@ -71,4 +71,4 @@ def generate( generated_tokens = torch.cat([generated_tokens, next_token], dim=1) - return generated_tokens, logits + return generated_tokens diff --git a/test/generate/run_llama_pred.sh b/test/generate/run_llama_pred.sh index a5160c415a..7e5a03e6e7 100755 --- a/test/generate/run_llama_pred.sh +++ b/test/generate/run_llama_pred.sh @@ -9,7 +9,7 @@ set -ex # use envs as local overrides for convenience # e.g. -# LOG_RANK=0,1 NGPU=4 ./run_llama_train.sh +# LOG_RANK=0,1 NGPU=4 ./run_llama_pred.sh NGPU=${NGPU:-"2"} LOG_RANK=${LOG_RANK:-0,1} CONFIG_FILE=${CONFIG_FILE:-"./train_configs/debug_model.toml"} @@ -21,7 +21,7 @@ if [ $# -ne 0 ]; then overrides="$*" fi -# export NCCL_DEBUG=INFO # TRACE +# export NCCL_DEBUG=TRACE # INFO # export NCCL_DEBUG_SUBSYS=ALL # export TORCH_NCCL_BLOCKING_WAIT=1 # export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 diff --git a/test/generate/test_generate.py b/test/generate/test_generate.py index 8370589943..631beea9c1 100644 --- a/test/generate/test_generate.py +++ b/test/generate/test_generate.py @@ -15,6 +15,7 @@ import torch import torch.distributed.checkpoint as dcp +from torch.distributed.elastic.multiprocessing.errors import record from torchtitan import utils @@ -32,6 +33,7 @@ from generate.generation import generate +@record def example_generate( config_path: str, checkpoint_path: str, @@ -54,6 +56,11 @@ def example_generate( utils.set_determinism(seed) + if seed is None: + logger.info("Deterministic off") + else: + logger.info(f"Deterministic on. Using seed: {seed}") + world_size = int(os.environ.get("WORLD_SIZE", 1)) local_rank = int(os.environ.get("LOCAL_RANK", 0)) device = torch.device(f"cuda:{local_rank}") @@ -129,11 +136,12 @@ def example_generate( .cuda() .detach() ) + gpu_memory_monitor.reset_peak_stats() # Run generation t0 = time.monotonic() - responses, _ = generate( + responses = generate( model, input_ids, temperature=temperature, @@ -178,7 +186,7 @@ def example_generate( output_data["metadata"] = { "generated_n_tokens": generated_n_tokens, "input_n_tokens": input_n_tokens, - "generation_time_sec": f"{elapsed_sec:.2f}", + "generation_time_sec": elapsed_sec, "tokens_per_sec": (B * T) / elapsed_sec, "batch_size": B, "seed": seed, @@ -189,6 +197,7 @@ def example_generate( "memory/max_reserved(%)": gpu_mem_stats.max_reserved_pct, "memory/num_alloc_retries": gpu_mem_stats.num_alloc_retries, "memory/num_ooms": gpu_mem_stats.num_ooms, + "world_size": world_size, "torch_version": torch.__version__, } print(json.dumps(output_data, indent=4)) From febda824c92b8f4f9bc5f3b66641b06e517c8b7e Mon Sep 17 00:00:00 2001 From: Jayson Francis Date: Sat, 2 Nov 2024 02:56:08 +0000 Subject: [PATCH 5/8] Adding apply_torchchat_tp --- test/generate/generation.py | 57 ++++++++++++++++++++++++++++++---- test/generate/test_generate.py | 14 ++++++--- 2 files changed, 61 insertions(+), 10 deletions(-) diff --git a/test/generate/generation.py b/test/generate/generation.py index 969551479e..571ec30227 100644 --- a/test/generate/generation.py +++ b/test/generate/generation.py @@ -7,10 +7,51 @@ from typing import Optional import torch +import torch.nn as nn +from torch.distributed import DeviceMesh +from torch.distributed._tensor import Replicate +from torch.distributed.tensor.parallel import ( + ColwiseParallel, + parallelize_module, + RowwiseParallel, +) + + +def apply_torchchat_tp(model: nn.Module, tp_mesh: DeviceMesh): + # As implemented in torchchat + # https://github.com/pytorch/torchchat/blob/main/torchchat/model.py#L679 + + parallelize_module( + model, + tp_mesh, + { + "tok_embeddings": RowwiseParallel(input_layouts=Replicate()), + "output": ColwiseParallel(output_layouts=Replicate()), + }, + ) + + for layer_id, transformer_block in model.layers.items(): + layer_plan = { + "attention.wq": ColwiseParallel(), + "attention.wk": ColwiseParallel(), + "attention.wv": ColwiseParallel(), + "attention.wo": RowwiseParallel(), + "feed_forward.w1": ColwiseParallel(), + "feed_forward.w2": RowwiseParallel(), + "feed_forward.w3": ColwiseParallel(), + } + + parallelize_module( + module=transformer_block, + device_mesh=tp_mesh, + parallelize_plan=layer_plan, + ) -def multinomial_sample_one(probs: torch.Tensor) -> torch.Tensor: - q = torch.empty_like(probs).exponential_(1) +def multinomial_sample_one( + probs: torch.Tensor, rng: Optional[torch.Generator] = None +) -> torch.Tensor: + q = torch.empty_like(probs).exponential_(1, generator=rng) return torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.long) @@ -19,7 +60,6 @@ def logits_to_probs( temperature: float = 1.0, top_k: Optional[int] = None, ) -> torch.Tensor: - logits = logits / max(temperature, 1e-5) if top_k is not None: @@ -37,11 +77,11 @@ def generate_next_token( *, temperature: float = 1.0, top_k: Optional[int] = None, + rng: Optional[torch.Generator] = None, ) -> torch.Tensor: - logits = model(x) # (B, T, vocab_size) probs = logits_to_probs(logits[:, -1, :], temperature, top_k) - next_token = multinomial_sample_one(probs) + next_token = multinomial_sample_one(probs, rng=rng) return next_token @@ -53,12 +93,16 @@ def generate( max_new_tokens: int, temperature: float = 1.0, top_k: Optional[int] = None, + seed: Optional[int] = None, ) -> torch.Tensor: - # ensure batch dimension (T,) --> (B, T) if input_ids.ndim == 1: input_ids = input_ids.unsqueeze(0) + rng = None + if seed is not None: + rng = torch.Generator(input_ids.device).manual_seed(seed) + generated_tokens = input_ids.clone() for _ in range(max_new_tokens): @@ -67,6 +111,7 @@ def generate( x=generated_tokens, temperature=temperature, top_k=top_k, + rng=rng, ) generated_tokens = torch.cat([generated_tokens, next_token], dim=1) diff --git a/test/generate/test_generate.py b/test/generate/test_generate.py index 631beea9c1..6d0c74860a 100644 --- a/test/generate/test_generate.py +++ b/test/generate/test_generate.py @@ -30,7 +30,7 @@ wd = Path(__file__).parent.parent.resolve() sys.path.append(str(wd)) -from generate.generation import generate +from generate.generation import apply_torchchat_tp, generate @record @@ -57,9 +57,9 @@ def example_generate( utils.set_determinism(seed) if seed is None: - logger.info("Deterministic off") + logger.info("Deterministic sampling off") else: - logger.info(f"Deterministic on. Using seed: {seed}") + logger.info(f"Deterministic sampling on. Using seed: {seed}") world_size = int(os.environ.get("WORLD_SIZE", 1)) local_rank = int(os.environ.get("LOCAL_RANK", 0)) @@ -103,7 +103,12 @@ def example_generate( model = model_cls.from_model_args(model_config) if world_size > 1: - models_parallelize_fns[model_name](model, world_mesh, parallel_dims, config) + + use_torchchat_tp = False + if use_torchchat_tp: + apply_torchchat_tp(model, world_mesh["tp"]) # Working + else: + models_parallelize_fns[model_name](model, world_mesh, parallel_dims, config) # materalize model model.to_empty(device="cuda") @@ -147,6 +152,7 @@ def example_generate( temperature=temperature, max_new_tokens=max_new_tokens, top_k=top_k, + seed=seed, ) t1 = time.monotonic() elapsed_sec = t1 - t0 From d89e0ea56bd9815ba721ff74b77e0b813343fca3 Mon Sep 17 00:00:00 2001 From: Jayson Francis Date: Tue, 12 Nov 2024 22:08:34 +0000 Subject: [PATCH 6/8] added generate.md --- docs/generate.md | 37 ++++++++++++++++++++++++++++++++++ test/generate/test_generate.py | 36 +++++++++++++++++++++++---------- 2 files changed, 62 insertions(+), 11 deletions(-) create mode 100644 docs/generate.md diff --git a/docs/generate.md b/docs/generate.md new file mode 100644 index 0000000000..8a08084004 --- /dev/null +++ b/docs/generate.md @@ -0,0 +1,37 @@ +# Model Generation Check + +The `test_generate` script provides a straightforward way to validate models, tokenizers, checkpoints, and device compatibility by running a single forward pass. This script functions as a sanity check to ensure everything is set up correctly. + +While **torchtitan** focuses on advanced features for distributed pre-training, this tool acts as a lightweight integration test to verify runtime setup. For more extensive inference and generation capabilities, consider tools like [pytorch/torchchat](https://github.com/pytorch/torchchat/). + +## Purpose and Use Case + +This script is ideal for users who need to: + +- **Run Sanity Checks**: Confirm that models, tokenizers, and checkpoints load without errors. +- **Test Compatibility**: Execute a forward pass to assess model response and memory usage. +- **Evaluate Device Scaling**: Optionally test distributed generation using tensor parallel (TP) to confirm multi-device functionality. + +## Usage Instructions + +#### Run on a single GPU. + +```bash +NGPU=1 CONFIG_FILE=./train_configs/llama3_8b.toml CHECKPOINT_DIR=./outputs/checkpoint/ \ +PROMPT="What is the meaning of life?" \ +./test/generate/run_llama_pred.sh --max_new_tokens=32 --temperature=0.8 --seed=3 +``` + +#### Run on 4 GPUs + +```bash +NGPU=4 CONFIG_FILE=./train_configs/llama3_8b.toml CHECKPOINT_DIR=./outputs/checkpoint/ \ +PROMPT="What is the meaning of life?" \ +./test/generate/run_llama_pred.sh --max_new_tokens=32 --temperature=0.8 --seed=3 +``` + +#### View Available Arguments + +```bash +> python ./test/generate/test_generate.py --help +``` diff --git a/test/generate/test_generate.py b/test/generate/test_generate.py index 6d0c74860a..3e50d85018 100644 --- a/test/generate/test_generate.py +++ b/test/generate/test_generate.py @@ -39,7 +39,6 @@ def example_generate( checkpoint_path: str, prompt: str, *, - device: str = "cuda", temperature: float = 1.0, max_new_tokens: int = 32, batch_size: int = 1, @@ -103,7 +102,6 @@ def example_generate( model = model_cls.from_model_args(model_config) if world_size > 1: - use_torchchat_tp = False if use_torchchat_tp: apply_torchchat_tp(model, world_mesh["tp"]) # Working @@ -209,6 +207,28 @@ def example_generate( print(json.dumps(output_data, indent=4)) +def load_prompt(prompt): + prompt_path = Path(prompt) + + if prompt_path.exists(): + if prompt_path.is_file(): + try: + content = prompt_path.read_text() + if content: # Ensure the file is not empty + return content + print("Error: Prompt file is empty.") + except IOError as e: + print(f"Error: Unable to read file '{prompt_path}'. {e}") + else: + print(f"Error: Path '{prompt}' is not a file.") + # If not empty, streat as a string + elif prompt: + return prompt + + print("Error: Provided prompt is empty or file does not exist") + sys.exit(1) + + if __name__ == "__main__": parser = argparse.ArgumentParser(description="Test generation") parser.add_argument( @@ -220,12 +240,6 @@ def example_generate( required=True, help="Checkpoint path to load (required)", ) - parser.add_argument( - "--device", - type=str, - default="cuda", - help="Device to load model on. Default is 'cuda'", - ) parser.add_argument( "--temperature", type=float, @@ -250,16 +264,16 @@ def example_generate( "--prompt", type=str, default="Hello! How are", - help="Input prompt for generation", + help="Input prompt for generation, either as a string or a path to a .txt file", ) args = parser.parse_args() + prompt_text = load_prompt(args.prompt) example_generate( config_path=args.config, checkpoint_path=args.checkpoint, - prompt=args.prompt, - device=args.device, + prompt=prompt_text, temperature=args.temperature, max_new_tokens=args.max_new_tokens, batch_size=args.batch_size, From 5c3d2f93733e837b92a51ab5db554fb2b5f94dd5 Mon Sep 17 00:00:00 2001 From: Jayson Francis Date: Thu, 21 Nov 2024 01:11:43 +0000 Subject: [PATCH 7/8] refactor to scripts/ --- .../generate.md => scripts/generate/README.md | 10 +- .../generate/_generation.py | 39 ------ scripts/generate/run_llama_generate.sh | 44 +++++++ {test => scripts}/generate/test_generate.py | 118 +++++++++++------- test/generate/run_llama_pred.sh | 36 ------ 5 files changed, 120 insertions(+), 127 deletions(-) rename docs/generate.md => scripts/generate/README.md (67%) rename test/generate/generation.py => scripts/generate/_generation.py (64%) create mode 100755 scripts/generate/run_llama_generate.sh rename {test => scripts}/generate/test_generate.py (81%) delete mode 100755 test/generate/run_llama_pred.sh diff --git a/docs/generate.md b/scripts/generate/README.md similarity index 67% rename from docs/generate.md rename to scripts/generate/README.md index 8a08084004..df252dfc5e 100644 --- a/docs/generate.md +++ b/scripts/generate/README.md @@ -2,7 +2,7 @@ The `test_generate` script provides a straightforward way to validate models, tokenizers, checkpoints, and device compatibility by running a single forward pass. This script functions as a sanity check to ensure everything is set up correctly. -While **torchtitan** focuses on advanced features for distributed pre-training, this tool acts as a lightweight integration test to verify runtime setup. For more extensive inference and generation capabilities, consider tools like [pytorch/torchchat](https://github.com/pytorch/torchchat/). +While **torchtitan** focuses on advanced features for distributed pre-training, this script acts as a lightweight integration test to verify runtime setup. For more extensive inference and generation capabilities, consider tools like [pytorch/torchchat](https://github.com/pytorch/torchchat/). ## Purpose and Use Case @@ -19,19 +19,19 @@ This script is ideal for users who need to: ```bash NGPU=1 CONFIG_FILE=./train_configs/llama3_8b.toml CHECKPOINT_DIR=./outputs/checkpoint/ \ PROMPT="What is the meaning of life?" \ -./test/generate/run_llama_pred.sh --max_new_tokens=32 --temperature=0.8 --seed=3 +./scripts/generate/run_llama_generate.sh --max_new_tokens=32 --temperature=0.8 --seed=3 ``` -#### Run on 4 GPUs +#### Run on 4 GPUs and pipe results to a json file. ```bash NGPU=4 CONFIG_FILE=./train_configs/llama3_8b.toml CHECKPOINT_DIR=./outputs/checkpoint/ \ PROMPT="What is the meaning of life?" \ -./test/generate/run_llama_pred.sh --max_new_tokens=32 --temperature=0.8 --seed=3 +./scripts/generate/run_llama_generate.sh --max_new_tokens=32 --temperature=0.8 --seed=3 --out > output.json ``` #### View Available Arguments ```bash -> python ./test/generate/test_generate.py --help +> python ./scripts/generate/test_generate.py --help ``` diff --git a/test/generate/generation.py b/scripts/generate/_generation.py similarity index 64% rename from test/generate/generation.py rename to scripts/generate/_generation.py index 571ec30227..6cd3a844d7 100644 --- a/test/generate/generation.py +++ b/scripts/generate/_generation.py @@ -7,45 +7,6 @@ from typing import Optional import torch -import torch.nn as nn -from torch.distributed import DeviceMesh -from torch.distributed._tensor import Replicate -from torch.distributed.tensor.parallel import ( - ColwiseParallel, - parallelize_module, - RowwiseParallel, -) - - -def apply_torchchat_tp(model: nn.Module, tp_mesh: DeviceMesh): - # As implemented in torchchat - # https://github.com/pytorch/torchchat/blob/main/torchchat/model.py#L679 - - parallelize_module( - model, - tp_mesh, - { - "tok_embeddings": RowwiseParallel(input_layouts=Replicate()), - "output": ColwiseParallel(output_layouts=Replicate()), - }, - ) - - for layer_id, transformer_block in model.layers.items(): - layer_plan = { - "attention.wq": ColwiseParallel(), - "attention.wk": ColwiseParallel(), - "attention.wv": ColwiseParallel(), - "attention.wo": RowwiseParallel(), - "feed_forward.w1": ColwiseParallel(), - "feed_forward.w2": RowwiseParallel(), - "feed_forward.w3": ColwiseParallel(), - } - - parallelize_module( - module=transformer_block, - device_mesh=tp_mesh, - parallelize_plan=layer_plan, - ) def multinomial_sample_one( diff --git a/scripts/generate/run_llama_generate.sh b/scripts/generate/run_llama_generate.sh new file mode 100755 index 0000000000..ad2946dbcd --- /dev/null +++ b/scripts/generate/run_llama_generate.sh @@ -0,0 +1,44 @@ +#!/usr/bin/bash +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +set -e + +# use envs as local overrides for convenience +# e.g. +# LOG_RANK=0,1 NGPU=4 ./run_llama_generate.sh +NGPU=${NGPU:-"1"} +LOG_RANK=${LOG_RANK:-0} +CONFIG_FILE=${CONFIG_FILE:-"./train_configs/debug_model.toml"} +CHECKPOINT_DIR=${CHECKPOINT_DIR:-"./outputs/checkpoint/"} +PROMPT=${PROMPT:-""} + +overrides=() +if [ $# -ne 0 ]; then + for arg in "$@"; do + # special case to handle prompt in quotes + if [[ "$arg" == --prompt=* ]]; then + PROMPT="${arg#--prompt=}" + # check if file + if [[ -f "$PROMPT" ]]; then + PROMPT=$(<"$PROMPT") + fi + else + # handle other args + overrides+=("$arg") + fi + done +fi + +set -x +torchrun --standalone \ + --nproc_per_node="${NGPU}" \ + --local-ranks-filter="${LOG_RANK}" \ + scripts/generate/test_generate.py \ + --config="${CONFIG_FILE}" \ + --checkpoint="${CHECKPOINT_DIR}" \ + --prompt="${PROMPT}" \ + "${overrides[@]}" diff --git a/test/generate/test_generate.py b/scripts/generate/test_generate.py similarity index 81% rename from test/generate/test_generate.py rename to scripts/generate/test_generate.py index 3e50d85018..25bf70f3a7 100644 --- a/test/generate/test_generate.py +++ b/scripts/generate/test_generate.py @@ -15,7 +15,15 @@ import torch import torch.distributed.checkpoint as dcp +import torch.nn as nn +from torch.distributed import DeviceMesh +from torch.distributed._tensor import Replicate from torch.distributed.elastic.multiprocessing.errors import record +from torch.distributed.tensor.parallel import ( + ColwiseParallel, + parallelize_module, + RowwiseParallel, +) from torchtitan import utils @@ -30,11 +38,42 @@ wd = Path(__file__).parent.parent.resolve() sys.path.append(str(wd)) -from generate.generation import apply_torchchat_tp, generate +from generate._generation import generate + + +def apply_torchchat_tp(model: nn.Module, tp_mesh: DeviceMesh): + # As implemented in torchchat + # https://github.com/pytorch/torchchat/blob/main/torchchat/model.py#L679 + + parallelize_module( + model, + tp_mesh, + { + "tok_embeddings": RowwiseParallel(input_layouts=Replicate()), + "output": ColwiseParallel(output_layouts=Replicate()), + }, + ) + + for layer_id, transformer_block in model.layers.items(): + layer_plan = { + "attention.wq": ColwiseParallel(), + "attention.wk": ColwiseParallel(), + "attention.wv": ColwiseParallel(), + "attention.wo": RowwiseParallel(), + "feed_forward.w1": ColwiseParallel(), + "feed_forward.w2": RowwiseParallel(), + "feed_forward.w3": ColwiseParallel(), + } + + parallelize_module( + module=transformer_block, + device_mesh=tp_mesh, + parallelize_plan=layer_plan, + ) @record -def example_generate( +def test_generate( config_path: str, checkpoint_path: str, prompt: str, @@ -53,6 +92,11 @@ def example_generate( config.parse_args([f"--job.config_file={config_path}"]) config._validate_config() + if len(args.prompt) == 0: + logger.warning( + "The input prompt is empty, model will respond from a empty sequence." + ) + utils.set_determinism(seed) if seed is None: @@ -68,21 +112,6 @@ def example_generate( model_name = config.model.name - # Init distributed env - if world_size > 1: - utils.init_distributed(config) - parallel_dims = ParallelDims( - dp_replicate=1, - dp_shard=-1, - cp=1, - tp=world_size, - pp=1, - world_size=world_size, - enable_loss_parallel=False, - ) - # Build world mesh for parallelism - world_mesh = parallel_dims.build_mesh(device_type="cuda") - logger.info(f"World Size: {world_size}, Local Rank: {local_rank} on {device}") # Tokenizer setup @@ -101,8 +130,22 @@ def example_generate( logger.info(f"Init model on init_device: {init_device}") model = model_cls.from_model_args(model_config) + # Init distributed env if world_size > 1: - use_torchchat_tp = False + utils.init_distributed(config) + parallel_dims = ParallelDims( + dp_replicate=1, + dp_shard=-1, + cp=1, + tp=world_size, + pp=1, + world_size=world_size, + enable_loss_parallel=False, + ) + # Build world mesh for parallelism + world_mesh = parallel_dims.build_mesh(device_type="cuda") + + use_torchchat_tp = True if use_torchchat_tp: apply_torchchat_tp(model, world_mesh["tp"]) # Working else: @@ -204,29 +247,9 @@ def example_generate( "world_size": world_size, "torch_version": torch.__version__, } - print(json.dumps(output_data, indent=4)) - -def load_prompt(prompt): - prompt_path = Path(prompt) - - if prompt_path.exists(): - if prompt_path.is_file(): - try: - content = prompt_path.read_text() - if content: # Ensure the file is not empty - return content - print("Error: Prompt file is empty.") - except IOError as e: - print(f"Error: Unable to read file '{prompt_path}'. {e}") - else: - print(f"Error: Path '{prompt}' is not a file.") - # If not empty, streat as a string - elif prompt: - return prompt - - print("Error: Provided prompt is empty or file does not exist") - sys.exit(1) + if args.out: + print(json.dumps(output_data, indent=4)) if __name__ == "__main__": @@ -260,20 +283,21 @@ def load_prompt(prompt): ) parser.add_argument("--seed", type=int, help="Random seed for reproducibility") + parser.add_argument("--prompt", type=str, help="Input prompt") + parser.add_argument( - "--prompt", - type=str, - default="Hello! How are", - help="Input prompt for generation, either as a string or a path to a .txt file", + "--out", + action="store_true", + default=False, + help="If specified, prints the report to stdout. Defaults to no output.", ) args = parser.parse_args() - prompt_text = load_prompt(args.prompt) - example_generate( + test_generate( config_path=args.config, checkpoint_path=args.checkpoint, - prompt=prompt_text, + prompt=args.prompt, temperature=args.temperature, max_new_tokens=args.max_new_tokens, batch_size=args.batch_size, diff --git a/test/generate/run_llama_pred.sh b/test/generate/run_llama_pred.sh deleted file mode 100755 index 7e5a03e6e7..0000000000 --- a/test/generate/run_llama_pred.sh +++ /dev/null @@ -1,36 +0,0 @@ -#!/usr/bin/bash -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. - -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -set -ex - -# use envs as local overrides for convenience -# e.g. -# LOG_RANK=0,1 NGPU=4 ./run_llama_pred.sh -NGPU=${NGPU:-"2"} -LOG_RANK=${LOG_RANK:-0,1} -CONFIG_FILE=${CONFIG_FILE:-"./train_configs/debug_model.toml"} -CHECKPOINT_DIR=${CHECKPOINT_DIR:-"./outputs/checkpoint/"} -PROMPT=${PROMPT:-"Hello!"} - -overrides="" -if [ $# -ne 0 ]; then - overrides="$*" -fi - -# export NCCL_DEBUG=TRACE # INFO -# export NCCL_DEBUG_SUBSYS=ALL -# export TORCH_NCCL_BLOCKING_WAIT=1 -# export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 - -torchrun --standalone \ - --nproc_per_node="${NGPU}" \ - --local-ranks-filter="${LOG_RANK}" \ - test/generate/test_generate.py \ - --config="${CONFIG_FILE}" \ - --checkpoint="${CHECKPOINT_DIR}" \ - --prompt="${PROMPT}" \ - ${overrides} From ef1a60ec9dbed917b45e58ff23cbdbd3b6032e38 Mon Sep 17 00:00:00 2001 From: Jayson Francis Date: Thu, 21 Nov 2024 19:08:20 +0000 Subject: [PATCH 8/8] rebase, update devices --- scripts/generate/test_generate.py | 55 ++++++++++++++----------------- 1 file changed, 24 insertions(+), 31 deletions(-) diff --git a/scripts/generate/test_generate.py b/scripts/generate/test_generate.py index 25bf70f3a7..b069e8bdd8 100644 --- a/scripts/generate/test_generate.py +++ b/scripts/generate/test_generate.py @@ -26,13 +26,14 @@ ) from torchtitan import utils +from torchtitan.utils import device_module, device_type from torchtitan.config_manager import JobConfig from torchtitan.datasets import build_tokenizer from torchtitan.logging import init_logger, logger -from torchtitan.metrics import build_gpu_memory_monitor +from torchtitan.metrics import build_device_memory_monitor, build_metric_logger from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config -from torchtitan.parallelisms import models_parallelize_fns, ParallelDims +from torchtitan.parallelisms import ParallelDims # support running w/o installing as package wd = Path(__file__).parent.parent.resolve() @@ -41,10 +42,7 @@ from generate._generation import generate -def apply_torchchat_tp(model: nn.Module, tp_mesh: DeviceMesh): - # As implemented in torchchat - # https://github.com/pytorch/torchchat/blob/main/torchchat/model.py#L679 - +def apply_tp_minus_sp(model: nn.Module, tp_mesh: DeviceMesh): parallelize_module( model, tp_mesh, @@ -106,9 +104,9 @@ def test_generate( world_size = int(os.environ.get("WORLD_SIZE", 1)) local_rank = int(os.environ.get("LOCAL_RANK", 0)) - device = torch.device(f"cuda:{local_rank}") - torch.cuda.set_device(device) - gpu_memory_monitor = build_gpu_memory_monitor() + device = torch.device(f"{device_type}:{local_rank}") + device_module.set_device(device) + device_memory_monitor = build_device_memory_monitor() model_name = config.model.name @@ -143,16 +141,13 @@ def test_generate( enable_loss_parallel=False, ) # Build world mesh for parallelism - world_mesh = parallel_dims.build_mesh(device_type="cuda") + world_mesh = parallel_dims.build_mesh(device_type=device_type) - use_torchchat_tp = True - if use_torchchat_tp: - apply_torchchat_tp(model, world_mesh["tp"]) # Working - else: - models_parallelize_fns[model_name](model, world_mesh, parallel_dims, config) + # apply_tp (with Sequence Parallel) on unevenly sharded sequences would require https://github.com/pytorch/torchtitan/pull/686 + apply_tp_minus_sp(model, world_mesh["tp"]) # materalize model - model.to_empty(device="cuda") + model.to_empty(device=device_type) model.eval() state_dict = {"model": model.state_dict()} @@ -163,11 +158,11 @@ def test_generate( dcp.load(state_dict, checkpoint_id=checkpoint_path) logger.info(f"Finished loading chkpt in {time.monotonic() - begin:.2f} seconds.") - gpu_mem_stats = gpu_memory_monitor.get_peak_stats() + device_mem_stats = device_memory_monitor.get_peak_stats() logger.info( - f"GPU memory usage for model: " - f"{gpu_mem_stats.max_reserved_gib:.2f}GiB" - f"({gpu_mem_stats.max_reserved_pct:.2f}%)" + f"{device_type.upper()} memory usage for model: " + f"{device_mem_stats.max_reserved_gib:.2f}GiB" + f"({device_mem_stats.max_reserved_pct:.2f}%)" ) # Tokenize prompt and repeat batch_size times @@ -179,11 +174,9 @@ def test_generate( .view(1, -1) .repeat(batch_size, 1) ) - .cuda() - .detach() - ) + ).to(device_type) - gpu_memory_monitor.reset_peak_stats() + device_memory_monitor.reset_peak_stats() # Run generation t0 = time.monotonic() @@ -229,7 +222,7 @@ def test_generate( logger.info(f"{r}\n{input_text}{b}{output_text}\n{color.reset}") - gpu_mem_stats = gpu_memory_monitor.get_peak_stats() + device_mem_stats = device_memory_monitor.get_peak_stats() output_data["metadata"] = { "generated_n_tokens": generated_n_tokens, "input_n_tokens": input_n_tokens, @@ -238,12 +231,12 @@ def test_generate( "batch_size": B, "seed": seed, "timestamp": time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime()), - "memory/max_active(GiB)": gpu_mem_stats.max_active_gib, - "memory/max_active(%)": gpu_mem_stats.max_active_pct, - "memory/max_reserved(GiB)": gpu_mem_stats.max_reserved_gib, - "memory/max_reserved(%)": gpu_mem_stats.max_reserved_pct, - "memory/num_alloc_retries": gpu_mem_stats.num_alloc_retries, - "memory/num_ooms": gpu_mem_stats.num_ooms, + "memory/max_active(GiB)": device_mem_stats.max_active_gib, + "memory/max_active(%)": device_mem_stats.max_active_pct, + "memory/max_reserved(GiB)": device_mem_stats.max_reserved_gib, + "memory/max_reserved(%)": device_mem_stats.max_reserved_pct, + "memory/num_alloc_retries": device_mem_stats.num_alloc_retries, + "memory/num_ooms": device_mem_stats.num_ooms, "world_size": world_size, "torch_version": torch.__version__, }