diff --git a/torchbenchmark/models/hf_Qwen2/__init__.py b/torchbenchmark/models/hf_Qwen2/__init__.py
new file mode 100644
index 0000000000..bd843ad144
--- /dev/null
+++ b/torchbenchmark/models/hf_Qwen2/__init__.py
@@ -0,0 +1,47 @@
+import torch
+from torchbenchmark.tasks import NLP
+from torchbenchmark.util.framework.huggingface.model_factory import HuggingFaceModel
+from transformers import AutoTokenizer, DynamicCache, AutoModelForCausalLM
+
+
+class Model(HuggingFaceModel):
+    task = NLP.LANGUAGE_MODELING
+    DEFAULT_EVAL_BSIZE = 1
+    DEFAULT_EVAL_CUDA_PRECISION = "fp16"
+
+    def __init__(self, test="inference", device="cuda", batch_size=None, extra_args=[]):
+        # self.device = "cuda"
+        super().__init__(
+            name="hf_Qwen2",
+            test=test,
+            device=device,
+            batch_size=batch_size,
+            extra_args=extra_args,
+        )
+
+        prompt = "What is the best way to debug python script?"
+        tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B")
+        inputs = tokenizer(prompt, return_tensors="pt")
+
+        input_ids = inputs.input_ids.cuda()
+        attention_mask = inputs.attention_mask.cuda()
+
+        self.example_inputs = ((), {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "past_key_values": DynamicCache(),
+            "use_cache": True
+        })
+        self.model.to(self.device)
+
+    def train(self):
+        raise NotImplementedError("Training is not implemented.")
+
+    def get_module(self):
+        return self.model, self.example_inputs
+
+    def eval(self):
+        example_inputs_args, example_inputs_kwargs = self.example_inputs
+        example_inputs_kwargs["past_key_values"] = DynamicCache()
+        self.model.eval()
+        self.model(*example_inputs_args, **example_inputs_kwargs)
diff --git a/torchbenchmark/models/hf_Qwen2/install.py b/torchbenchmark/models/hf_Qwen2/install.py
new file mode 100644
index 0000000000..4297eb3df8
--- /dev/null
+++ b/torchbenchmark/models/hf_Qwen2/install.py
@@ -0,0 +1,11 @@
+import os
+
+from torchbenchmark.util.framework.huggingface.patch_hf import (
+    cache_model,
+    patch_transformers,
+)
+
+if __name__ == "__main__":
+    patch_transformers()
+    model_name = os.path.basename(os.path.dirname(os.path.abspath(__file__)))
+    cache_model(model_name)
diff --git a/torchbenchmark/models/hf_Qwen2/metadata.yaml b/torchbenchmark/models/hf_Qwen2/metadata.yaml
new file mode 100644
index 0000000000..b7ae6697a0
--- /dev/null
+++ b/torchbenchmark/models/hf_Qwen2/metadata.yaml
@@ -0,0 +1,13 @@
+devices:
+  NVIDIA A100-SXM4-40GB:
+    eval_batch_size: 1
+eval_benchmark: false
+eval_deterministic: false
+eval_nograd: true
+not_implemented:
+- device: cpu
+- test: train
+- device: cuda
+- test: train
+train_benchmark: false
+train_deterministic: false
diff --git a/torchbenchmark/models/hf_Qwen2/requirements.txt b/torchbenchmark/models/hf_Qwen2/requirements.txt
new file mode 100644
index 0000000000..976a2b1f39
--- /dev/null
+++ b/torchbenchmark/models/hf_Qwen2/requirements.txt
@@ -0,0 +1 @@
+transformers
diff --git a/torchbenchmark/models/hf_minicpm/__init__.py b/torchbenchmark/models/hf_minicpm/__init__.py
new file mode 100644
index 0000000000..be30f42655
--- /dev/null
+++ b/torchbenchmark/models/hf_minicpm/__init__.py
@@ -0,0 +1,82 @@
+import torch
+from torchbenchmark.tasks import NLP
+from torchbenchmark.util.framework.huggingface.model_factory import HuggingFaceModel
+from transformers import AutoTokenizer, DynamicCache, AutoModelForCausalLM
+import librosa
+from contextlib import contextmanager
+from pathlib import Path
+import torch.utils._pytree as pytree
+
+def copy_tensors(inputs):
+    return pytree.tree_map_only(torch.Tensor, torch.clone, inputs)
+
+def add_sampling_hook(module, samples, hooks):
+    def _(module, args, kwargs):
+        print("INSIDE HOOK")
+        samples.append(copy_tensors((args, kwargs)))
+
+    hook = module.register_forward_pre_hook(_, prepend=True, with_kwargs=True)
+    hooks.append(hook)
+    return hook
+
+
+class Model(HuggingFaceModel):
+    task = NLP.LANGUAGE_MODELING
+    DEFAULT_EVAL_BSIZE = 1
+    DEFAULT_EVAL_CUDA_PRECISION = "fp16"
+
+    def __init__(self, test="inference", device="cuda", batch_size=None, extra_args=[]):
+        # self.device = "cuda"
+        super().__init__(
+            name="hf_minicpm",
+            test=test,
+            device=device,
+            batch_size=batch_size,
+            extra_args=extra_args,
+        )
+
+
+        prompt = "What is the best way to debug python script?"
+        tokenizer = AutoTokenizer.from_pretrained('openbmb/MiniCPM-o-2_6', trust_remote_code=True)
+        inputs = tokenizer(prompt, return_tensors="pt")
+
+        self.model.init_tts()
+        self.model.tts.float()
+
+        class WrapperModule(torch.nn.Module):
+            def __init__(self, model):
+                super().__init__()
+                self.model = model
+
+            def forward(self, *args, **kwargs):
+                return self.model.generate(*args, **kwargs)
+
+        self.model = WrapperModule(self.model.tts)
+
+        self.example_inputs = ((), {
+            "input_ids": torch.zeros((1, 303, 4), device=self.device),
+            "past_key_values": [
+                (
+                    torch.randn((1, self.model.model.config.num_attention_heads, 302, 64), device=self.device),
+                    torch.randn((1, self.model.model.config.num_attention_heads, 302, 64), device=self.device),
+                ) for _ in range(self.model.model.config.num_hidden_layers)
+            ],
+            "temperature": torch.tensor([0.1000, 0.3000, 0.1000, 0.3000], device=self.device),
+            "eos_token": torch.tensor([625], device=self.device),
+            "streaming_tts_text_mask": torch.ones([303], dtype=torch.int8, device=self.device),
+            "max_new_token": 25,
+        })
+
+        self.model.to(self.device)
+
+    def train(self):
+        raise NotImplementedError("Training is not implemented.")
+
+    def get_module(self):
+        return self.model, self.example_inputs
+
+    def eval(self):
+        example_inputs_args, example_inputs_kwargs = self.example_inputs
+        example_inputs_kwargs["past_key_values"] = DynamicCache()
+        self.model.eval()
+        self.model(*example_inputs_args, **example_inputs_kwargs)
diff --git a/torchbenchmark/models/hf_minicpm/audio_understanding.mp3 b/torchbenchmark/models/hf_minicpm/audio_understanding.mp3
new file mode 100644
index 0000000000..7a98f4c056
Binary files /dev/null and b/torchbenchmark/models/hf_minicpm/audio_understanding.mp3 differ
diff --git a/torchbenchmark/models/hf_minicpm/install.py b/torchbenchmark/models/hf_minicpm/install.py
new file mode 100644
index 0000000000..02c8abf2aa
--- /dev/null
+++ b/torchbenchmark/models/hf_minicpm/install.py
@@ -0,0 +1,15 @@
+import os
+import sys
+import subprocess
+
+from torchbenchmark.util.framework.huggingface.patch_hf import (
+    cache_model,
+    patch_transformers,
+)
+from utils.python_utils import pip_install_requirements
+
+if __name__ == "__main__":
+    patch_transformers()
+    pip_install_requirements()
+    model_name = os.path.basename(os.path.dirname(os.path.abspath(__file__)))
+    cache_model(model_name)
diff --git a/torchbenchmark/models/hf_minicpm/metadata.yaml b/torchbenchmark/models/hf_minicpm/metadata.yaml
new file mode 100644
index 0000000000..b7ae6697a0
--- /dev/null
+++ b/torchbenchmark/models/hf_minicpm/metadata.yaml
@@ -0,0 +1,13 @@
+devices:
+  NVIDIA A100-SXM4-40GB:
+    eval_batch_size: 1
+eval_benchmark: false
+eval_deterministic: false
+eval_nograd: true
+not_implemented:
+- device: cpu
+- test: train
+- device: cuda
+- test: train
+train_benchmark: false
+train_deterministic: false
diff --git a/torchbenchmark/models/hf_minicpm/requirements.txt b/torchbenchmark/models/hf_minicpm/requirements.txt
new file mode 100644
index 0000000000..cf614ff519
--- /dev/null
+++ b/torchbenchmark/models/hf_minicpm/requirements.txt
@@ -0,0 +1,3 @@
+flash_attn
+vector_quantize_pytorch
+vocos
diff --git a/torchbenchmark/models/hf_simplescaling/__init__.py b/torchbenchmark/models/hf_simplescaling/__init__.py
new file mode 100644
index 0000000000..55915d3f00
--- /dev/null
+++ b/torchbenchmark/models/hf_simplescaling/__init__.py
@@ -0,0 +1,56 @@
+import torch
+from torchbenchmark.tasks import NLP
+from torchbenchmark.util.framework.huggingface.model_factory import HuggingFaceModel
+from transformers import AutoTokenizer, DynamicCache, AutoModelForCausalLM
+
+
+class Model(HuggingFaceModel):
+    task = NLP.LANGUAGE_MODELING
+    DEFAULT_EVAL_BSIZE = 1
+    DEFAULT_EVAL_CUDA_PRECISION = "fp16"
+
+    def __init__(self, test="inference", device="cuda", batch_size=None, extra_args=[]):
+        # self.device = "cuda"
+        super().__init__(
+            name="hf_simplescaling",
+            test=test,
+            device=device,
+            batch_size=batch_size,
+            extra_args=extra_args,
+        )
+
+        tokenizer = AutoTokenizer.from_pretrained("simplescaling/s1.1-32B")
+
+        prompt = "How many r in raspberry"
+        messages = [
+            {"role": "system", "content": "You are a helpful and harmless assistant. You are Qwen developed by Alibaba. You should think step-by-step."},
+            {"role": "user", "content": prompt}
+        ]
+        text = tokenizer.apply_chat_template(
+            messages,
+            tokenize=False,
+            add_generation_prompt=True
+        )
+        model_inputs = tokenizer([text], return_tensors="pt").to(self.device)
+        self.example_inputs = {**model_inputs, "max_new_tokens":512}
+
+        class WrapperModel(torch.nn.Module):
+            def __init__(self, model):
+                super().__init__()
+                self.model = model
+
+            def forward(self, *args, **kwargs):
+                return self.model.generate(*args, **kwargs)
+
+        self.model = WrapperModel(self.model)
+
+    def train(self):
+        raise NotImplementedError("Training is not implemented.")
+
+    def get_module(self):
+        return self.model, self.example_inputs
+
+    def eval(self):
+        example_inputs = self.example_inputs
+        self.model.eval()
+        self.model(**example_inputs)
diff --git a/torchbenchmark/models/hf_simplescaling/install.py b/torchbenchmark/models/hf_simplescaling/install.py
new file mode 100644
index 0000000000..4297eb3df8
--- /dev/null
+++ b/torchbenchmark/models/hf_simplescaling/install.py
@@ -0,0 +1,11 @@
+import os
+
+from torchbenchmark.util.framework.huggingface.patch_hf import (
+    cache_model,
+    patch_transformers,
+)
+
+if __name__ == "__main__":
+    patch_transformers()
+    model_name = os.path.basename(os.path.dirname(os.path.abspath(__file__)))
+    cache_model(model_name)
diff --git a/torchbenchmark/models/hf_simplescaling/metadata.yaml b/torchbenchmark/models/hf_simplescaling/metadata.yaml
new file mode 100644
index 0000000000..e606ef84f0
--- /dev/null
+++ b/torchbenchmark/models/hf_simplescaling/metadata.yaml
@@ -0,0 +1,15 @@
+devices:
+  NVIDIA A100-SXM4-40GB:
+    eval_batch_size: 1
+eval_benchmark: false
+eval_deterministic: false
+eval_nograd: true
+not_implemented:
+- device: cpu
+- test: train
+- test: eval
+- device: cuda
+- test: train
+- test: eval
+train_benchmark: false
+train_deterministic: false
diff --git a/torchbenchmark/models/hf_simplescaling/requirements.txt b/torchbenchmark/models/hf_simplescaling/requirements.txt
new file mode 100644
index 0000000000..976a2b1f39
--- /dev/null
+++ b/torchbenchmark/models/hf_simplescaling/requirements.txt
@@ -0,0 +1 @@
+transformers
diff --git a/torchbenchmark/models/kokoro/__init__.py b/torchbenchmark/models/kokoro/__init__.py
new file mode 100644
index 0000000000..b26d35c29f
--- /dev/null
+++ b/torchbenchmark/models/kokoro/__init__.py
@@ -0,0 +1,77 @@
+from argparse import Namespace
+from pathlib import Path
+from typing import Tuple
+
+import torch
+import re
+from torchbenchmark.tasks import OTHER
+
+from ...util.model import BenchmarkModel
+
+torch.backends.cudnn.deterministic = True
+torch.backends.cudnn.benchmark = False
+from kokoro import KPipeline
+from huggingface_hub import hf_hub_download
+
+
+def load_single_voice(pipeline, voice: str):
+    if voice in pipeline.voices:
+        return pipeline.voices[voice]
+    if voice.endswith('.pt'):
+        f = voice
+    else:
+        f = hf_hub_download(repo_id="hexgrad/Kokoro-82M", filename=f'voices/{voice}.pt')
+        assert voice.startswith(pipeline.lang_code)
+    pack = torch.load(f, weights_only=True)
+    pipeline.voices[voice] = pack
+    return pack
+
+
+class Model(BenchmarkModel):
+    task = OTHER.OTHER_TASKS
+    DEFAULT_EVAL_BSIZE = 1
+
+    def __init__(self, test, device, batch_size=None, extra_args=[]):
+        super().__init__(
+            test=test, device=device, batch_size=batch_size, extra_args=extra_args
+        )
+
+        self.pipeline = KPipeline(lang_code='a')
+        self.model = self.pipeline.model
+
+        text = '''
+The sky above the port was the color of television, tuned to a dead channel.
+"It's not like I'm using," Case heard someone say, as he shouldered his way through the crowd around the door of the Chat. "It's like my body's developed this massive drug deficiency."
+It was a Sprawl voice and a Sprawl joke. The Chatsubo was a bar for professional expatriates; you could drink there for a week and never hear two words in Japanese.
+
+These were to have an enormous impact, not only because they were associated with Constantine, but also because, as in so many other areas, the decisions taken by Constantine (or in his name) were to have great significance for centuries to come. One of the main issues was the shape that Christian churches were to take, since there was not, apparently, a tradition of monumental church buildings when Constantine decided to help the Christian church build a series of truly spectacular structures. The main form that these churches took was that of the basilica, a multipurpose rectangular structure, based ultimately on the earlier Greek stoa, which could be found in most of the great cities of the empire. Christianity, unlike classical polytheism, needed a large interior space for the celebration of its religious services, and the basilica aptly filled that need. We naturally do not know the degree to which the emperor was involved in the design of new churches, but it is tempting to connect this with the secular basilica that Constantine completed in the Roman forum (the so-called Basilica of Maxentius) and the one he probably built in Trier, in connection with his residence in the city at a time when he was still caesar.
+
+[Kokoro](/kˈOkəɹO/) is an open-weight TTS model with 82 million parameters. Despite its lightweight architecture, it delivers comparable quality to larger models while being significantly faster and more cost-efficient. With Apache-licensed weights, [Kokoro](/kˈOkəɹO/) can be deployed anywhere from production environments to personal projects.
+'''
+
+        pack = load_single_voice(self.pipeline, "af_heart").to(self.device)
+        text = re.split(r'\n+', text.strip())
+
+        for graphemes in text:
+            _, tokens = self.pipeline.g2p(graphemes)
+            for gs, ps in self.pipeline.en_tokenize(tokens):
+                if not ps:
+                    continue
+                elif len(ps) > 510:
+                    print(f"TODO: Unexpected len(ps) == {len(ps)} > 510 and ps == '{ps}'")
+                    continue
+                input_ids = self.pipeline.p2ii(ps)
+                self.example_inputs = ((input_ids, pack[len(input_ids)-1], 1.0), {})
+                break
+        assert self.example_inputs is not None
+
+
+    def get_module(self):
+        return self.model, self.example_inputs
+
+    def eval(self) -> Tuple[torch.Tensor]:
+        out = self.model(*self.example_inputs[0], **self.example_inputs[1])
+        return (out,)
+
+    def train(self):
+        raise NotImplementedError("Training is not implemented.")
diff --git a/torchbenchmark/models/kokoro/install.py b/torchbenchmark/models/kokoro/install.py
new file mode 100644
index 0000000000..df4fdfae39
--- /dev/null
+++ b/torchbenchmark/models/kokoro/install.py
@@ -0,0 +1,4 @@
+from utils.python_utils import pip_install_requirements
+
+if __name__ == "__main__":
+    pip_install_requirements()
diff --git a/torchbenchmark/models/kokoro/metadata.yaml b/torchbenchmark/models/kokoro/metadata.yaml
new file mode 100644
index 0000000000..b7ae6697a0
--- /dev/null
+++ b/torchbenchmark/models/kokoro/metadata.yaml
@@ -0,0 +1,13 @@
+devices:
+  NVIDIA A100-SXM4-40GB:
+    eval_batch_size: 1
+eval_benchmark: false
+eval_deterministic: false
+eval_nograd: true
+not_implemented:
+- device: cpu
+- test: train
+- device: cuda
+- test: train
+train_benchmark: false
+train_deterministic: false
diff --git a/torchbenchmark/models/kokoro/requirements.txt b/torchbenchmark/models/kokoro/requirements.txt
new file mode 100644
index 0000000000..2d52a55cdb
--- /dev/null
+++ b/torchbenchmark/models/kokoro/requirements.txt
@@ -0,0 +1,2 @@
+kokoro
+soundfile
diff --git a/torchbenchmark/util/framework/huggingface/basic_configs.py b/torchbenchmark/util/framework/huggingface/basic_configs.py
index f0941f7df2..235035d3ab 100644
--- a/torchbenchmark/util/framework/huggingface/basic_configs.py
+++ b/torchbenchmark/util/framework/huggingface/basic_configs.py
@@ -13,6 +13,24 @@
         'AutoConfig.from_pretrained("gpt2")',
         "AutoModelForCausalLM",
     ),
+    "hf_Qwen2": (
+        512,
+        32768,
+        'AutoConfig.from_pretrained("Qwen/Qwen2-7B")',
+        "AutoModelForCausalLM"
+    ),
+    "hf_minicpm": (
+        512,
+        32768,
+        'AutoConfig.from_pretrained("openbmb/MiniCPM-o-2_6", trust_remote_code=True)',
+        "AutoModelForCausalLM"
+    ),
+    "hf_simplescaling": (
+        512,
+        1024,
+        'AutoConfig.from_pretrained("simplescaling/s1-32B")',
+        'AutoModelForCausalLM'
+    ),
     "hf_GPT2_large": (
         512,
         1024,
@@ -199,6 +217,7 @@
 HUGGINGFACE_MODELS_REQUIRING_TRUST_REMOTE_CODE = [
     "hf_Falcon_7b",
     "hf_MPT_7b_instruct",
+    "hf_minicpm",
     "phi_1_5",
     "phi_2",
     "hf_Yi",
diff --git a/userbenchmark/export_new_models/__init__.py b/userbenchmark/export_new_models/__init__.py
new file mode 100644
index 0000000000..81d4e6361f
--- /dev/null
+++ b/userbenchmark/export_new_models/__init__.py
@@ -0,0 +1 @@
+BM_NAME = "export_new_models"
diff --git a/userbenchmark/export_new_models/result_audio_understanding.wav b/userbenchmark/export_new_models/result_audio_understanding.wav
new file mode 100644
index 0000000000..40b359175a
Binary files /dev/null and b/userbenchmark/export_new_models/result_audio_understanding.wav differ
diff --git a/userbenchmark/export_new_models/run.py b/userbenchmark/export_new_models/run.py
new file mode 100644
index 0000000000..f2c6a433ff
--- /dev/null
+++ b/userbenchmark/export_new_models/run.py
@@ -0,0 +1,92 @@
+import torch
+import importlib
+import sys
+import pprint
+from pathlib import Path
+import torch.utils._pytree as pytree
+
+# Makes sure we setup torchbenchmark
+repo = Path(__file__).parent.parent.parent
+sys.path.append(str(repo))
+
+from userbenchmark.utils import dump_output
+from userbenchmark.export_new_models import BM_NAME
+
+models = [
+    "hf_Qwen2",
+    #"hf_simplescaling",
+    "hf_minicpm",
+    "kokoro",
+]
+
+
+def assert_equal(a, b):
+    if a != b:
+        raise AssertionError("not equal")
+
+
+def compare_output(eager, export):
+    flat_orig_outputs = pytree.tree_leaves(eager)
+    flat_loaded_outputs = pytree.tree_leaves(export)
+
+    for orig, loaded in zip(flat_orig_outputs, flat_loaded_outputs):
+        assert_equal(type(orig), type(loaded))
+
+        # torch.allclose doesn't work for float8
+        if isinstance(orig, torch.Tensor) and orig.dtype not in [
+            torch.float8_e4m3fn,
+            torch.float8_e5m2,
+        ]:
+            if orig.is_meta:
+                assert_equal(orig, loaded)
+            else:
+                if not torch.allclose(orig, loaded):
+                    raise AssertionError("not equal")
+        else:
+            assert_equal(orig, loaded)
+
+
+def get_model(name):
+    model_module_ = importlib.import_module(f"torchbenchmark.models.{name}")
+    model_cls = getattr(model_module_, "Model")
+    model = model_cls(device="cuda", test="eval")
+    return model
+
+def run():
+    metrics = {}
+    errors = {}
+    count_success = 0
+    for model_name in models:
+        print(f"Testing {model_name}")
+        model = get_model(model_name)
+        model, example_inputs = model.get_module()
+        try:
+            with torch.inference_mode():
+                ep = torch.export.export(model, example_inputs[0], example_inputs[1], strict=False).module()
+        except Exception as e:
+            errors[model_name] = str(e)
+            continue
+
+        try:
+            with torch.inference_mode():
+                compare_output(model(*example_inputs[0], **example_inputs[1]), ep(*example_inputs[0], **example_inputs[1]))
+        except Exception as e:
+            errors[model_name] = str(e)
+            continue
+        count_success += 1
+
+    metrics["success_rate"] = count_success / len(models)
+    metrics["errors"] = errors
+
+    result = {
+        "name": BM_NAME,
+        "environ": {
+            "pytorch_git_version": torch.version.git_version,
+        },
+        "metrics": metrics,
+    }
+    pprint.pprint(result)
+    dump_output(BM_NAME, result)
+
+if __name__ == "__main__":
+    run()