Initial setup for export benchmarking suite and some models #2596

Open · wants to merge 7 commits into base: main
47 changes: 47 additions & 0 deletions torchbenchmark/models/hf_Qwen2/__init__.py
@@ -0,0 +1,47 @@
from torchbenchmark.tasks import NLP
from torchbenchmark.util.framework.huggingface.model_factory import HuggingFaceModel
from transformers import AutoTokenizer, DynamicCache


class Model(HuggingFaceModel):
task = NLP.LANGUAGE_MODELING
DEFAULT_EVAL_BSIZE = 1
DEFAULT_EVAL_CUDA_PRECISION = "fp16"

def __init__(self, test="inference", device="cuda", batch_size=None, extra_args=[]):
super().__init__(
name="hf_Qwen2",
test=test,
device=device,
batch_size=batch_size,
extra_args=extra_args,
)

prompt = "What is the best way to debug python script?"
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B")
inputs = tokenizer(prompt, return_tensors="pt")

        input_ids = inputs.input_ids.to(self.device)
        attention_mask = inputs.attention_mask.to(self.device)

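        # example_inputs follows the (args, kwargs) convention returned by
        # get_module(): the empty tuple means all inputs are keyword-only.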
self.example_inputs = ((), {
"input_ids": input_ids,
"attention_mask": attention_mask,
"past_key_values": DynamicCache(),
"use_cache": True
})
self.model.to(self.device)

def train(self):
raise NotImplementedError("Training is not implemented.")

def get_module(self):
return self.model, self.example_inputs

def eval(self):
example_inputs_args, example_inputs_kwargs = self.example_inputs
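        # Start from a fresh cache each call so repeated eval() runs do not
        # keep appending to the previous call's KV cache.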
example_inputs_kwargs["past_key_values"] = DynamicCache()
self.model.eval()
self.model(*example_inputs_args, **example_inputs_kwargs)
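
Note: a minimal sketch (not part of the diff) of how an export harness could consume get_module()'s (args, kwargs) pair; targeting torch.export is an assumption inferred from the PR title, not something this file does:

    import torch
    from torchbenchmark.models.hf_Qwen2 import Model

    model, (args, kwargs) = Model(test="inference", device="cuda").get_module()
    exported = torch.export.export(model, args, kwargs)
    print(exported)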
11 changes: 11 additions & 0 deletions torchbenchmark/models/hf_Qwen2/install.py
@@ -0,0 +1,11 @@
import os

from torchbenchmark.util.framework.huggingface.patch_hf import (
cache_model,
patch_transformers,
)

if __name__ == "__main__":
patch_transformers()
model_name = os.path.basename(os.path.dirname(os.path.abspath(__file__)))
cache_model(model_name)
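
Note: each install.py derives the model name from its own directory before caching weights. Assuming the usual torchbenchmark layout, it can be run from the repo root with the root on PYTHONPATH:

    PYTHONPATH=. python torchbenchmark/models/hf_Qwen2/install.py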
13 changes: 13 additions & 0 deletions torchbenchmark/models/hf_Qwen2/metadata.yaml
@@ -0,0 +1,13 @@
devices:
  NVIDIA A100-SXM4-40GB:
    eval_batch_size: 1
eval_benchmark: false
eval_deterministic: false
eval_nograd: true
not_implemented:
- device: cpu
  test: train
- device: cuda
  test: train
train_benchmark: false
train_deterministic: false
1 change: 1 addition & 0 deletions torchbenchmark/models/hf_Qwen2/requirements.txt
@@ -0,0 +1 @@
transformers
82 changes: 82 additions & 0 deletions torchbenchmark/models/hf_minicpm/__init__.py
@@ -0,0 +1,82 @@
import torch
import torch.utils._pytree as pytree
from torchbenchmark.tasks import NLP
from torchbenchmark.util.framework.huggingface.model_factory import HuggingFaceModel
from transformers import AutoTokenizer

def copy_tensors(inputs):
return pytree.tree_map_only(torch.Tensor, torch.clone, inputs)

def add_sampling_hook(module, samples, hooks):
    """Record cloned (args, kwargs) for every forward call on `module`."""
    def _capture(module, args, kwargs):
        samples.append(copy_tensors((args, kwargs)))

    hook = module.register_forward_pre_hook(_capture, prepend=True, with_kwargs=True)
    hooks.append(hook)
    return hook
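
# Sketch (assumed usage; nothing in this file calls it yet): capture real
# inputs for a submodule by hooking it, running the model, then detaching:
#
#     samples, hooks = [], []
#     add_sampling_hook(model.tts, samples, hooks)
#     model.generate(...)  # samples now holds cloned (args, kwargs) pairs
#     for h in hooks:
#         h.remove()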


class Model(HuggingFaceModel):
task = NLP.LANGUAGE_MODELING
DEFAULT_EVAL_BSIZE = 1
DEFAULT_EVAL_CUDA_PRECISION = "fp16"

def __init__(self, test="inference", device="cuda", batch_size=None, extra_args=[]):
super().__init__(
name="hf_minicpm",
test=test,
device=device,
batch_size=batch_size,
extra_args=extra_args,
)


prompt = "What is the best way to debug python script?"
tokenizer = AutoTokenizer.from_pretrained('openbmb/MiniCPM-o-2_6', trust_remote_code=True)
inputs = tokenizer(prompt, return_tensors="pt")

self.model.init_tts()
self.model.tts.float()

class WrapperModule(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model

def forward(self, *args, **kwargs):
return self.model.generate(*args, **kwargs)

self.model = WrapperModule(self.model.tts)

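        # Shapes mirror one streaming TTS step: a (1, 303, 4) input with a
        # per-layer (key, value) cache of 302 past positions and head dim 64.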
self.example_inputs = ((), {
"input_ids": torch.zeros((1, 303, 4), device=self.device),
"past_key_values": [
(
torch.randn((1, self.model.model.config.num_attention_heads, 302, 64), device=self.device),
torch.randn((1, self.model.model.config.num_attention_heads, 302, 64), device=self.device),
) for _ in range(self.model.model.config.num_hidden_layers)
],
"temperature": torch.tensor([0.1000, 0.3000, 0.1000, 0.3000], device=self.device),
"eos_token": torch.tensor([625], device=self.device),
"streaming_tts_text_mask": torch.ones([303], dtype=torch.int8, device=self.device),
"max_new_token": 25,
})

self.model.to(self.device)

def train(self):
raise NotImplementedError("Training is not implemented.")

def get_module(self):
return self.model, self.example_inputs

    def eval(self):
        # example_inputs already carries a populated legacy-format cache sized
        # to the 303-token input, so it is reused as-is rather than being
        # reset to an empty DynamicCache (which would not match these shapes).
        example_inputs_args, example_inputs_kwargs = self.example_inputs
        self.model.eval()
        self.model(*example_inputs_args, **example_inputs_kwargs)
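
Note: if a newer transformers version rejects the legacy list-of-(key, value) cache above, it can be wrapped explicitly (an assumption about the target transformers API, not something this diff does):

    from transformers import DynamicCache
    cache = DynamicCache.from_legacy_cache(tuple(past_key_values))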
Binary file not shown.
15 changes: 15 additions & 0 deletions torchbenchmark/models/hf_minicpm/install.py
@@ -0,0 +1,15 @@
import os
import sys
import subprocess

from torchbenchmark.util.framework.huggingface.patch_hf import (
cache_model,
patch_transformers,
)
from utils.python_utils import pip_install_requirements

if __name__ == "__main__":
patch_transformers()
pip_install_requirements()
model_name = os.path.basename(os.path.dirname(os.path.abspath(__file__)))
cache_model(model_name)
13 changes: 13 additions & 0 deletions torchbenchmark/models/hf_minicpm/metadata.yaml
@@ -0,0 +1,13 @@
devices:
  NVIDIA A100-SXM4-40GB:
    eval_batch_size: 1
eval_benchmark: false
eval_deterministic: false
eval_nograd: true
not_implemented:
- device: cpu
  test: train
- device: cuda
  test: train
train_benchmark: false
train_deterministic: false
3 changes: 3 additions & 0 deletions torchbenchmark/models/hf_minicpm/requirements.txt
@@ -0,0 +1,3 @@
flash_attn
vector_quantize_pytorch
vocos
56 changes: 56 additions & 0 deletions torchbenchmark/models/hf_simplescaling/__init__.py
@@ -0,0 +1,56 @@
import torch
from torchbenchmark.tasks import NLP
from torchbenchmark.util.framework.huggingface.model_factory import HuggingFaceModel
from transformers import AutoTokenizer


class Model(HuggingFaceModel):
task = NLP.LANGUAGE_MODELING
DEFAULT_EVAL_BSIZE = 1
DEFAULT_EVAL_CUDA_PRECISION = "fp16"

def __init__(self, test="inference", device="cuda", batch_size=None, extra_args=[]):
super().__init__(
name="hf_simplescaling",
test=test,
device=device,
batch_size=batch_size,
extra_args=extra_args,
)

tokenizer = AutoTokenizer.from_pretrained("simplescaling/s1.1-32B")

prompt = "How many r in raspberry"
messages = [
{"role": "system", "content": "You are a helpful and harmless assistant. You are Qwen developed by Alibaba. You should think step-by-step."},
{"role": "user", "content": prompt}
]
text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
model_inputs = tokenizer([text], return_tensors="pt").to(self.device)
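        # Note: unlike the other models in this PR, example_inputs here is a
        # plain kwargs dict, consumed as model(**example_inputs) in eval().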
        self.example_inputs = {**model_inputs, "max_new_tokens": 512}

class WrapperModel(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model

def forward(self, *args, **kwargs):
return self.model.generate(*args, **kwargs)

self.model = WrapperModel(self.model)

def train(self):
raise NotImplementedError("Training is not implemented.")

def get_module(self):
return self.model, self.example_inputs

def eval(self):
example_inputs = self.example_inputs
self.model.eval()
self.model(**example_inputs)
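
Note: a quick smoke test for this benchmark, assuming torchbenchmark's run.py entry point and enough GPU memory for the 32B checkpoint:

    python run.py hf_simplescaling -d cuda -t eval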
11 changes: 11 additions & 0 deletions torchbenchmark/models/hf_simplescaling/install.py
@@ -0,0 +1,11 @@
import os

from torchbenchmark.util.framework.huggingface.patch_hf import (
cache_model,
patch_transformers,
)

if __name__ == "__main__":
patch_transformers()
model_name = os.path.basename(os.path.dirname(os.path.abspath(__file__)))
cache_model(model_name)
15 changes: 15 additions & 0 deletions torchbenchmark/models/hf_simplescaling/metadata.yaml
@@ -0,0 +1,15 @@
devices:
  NVIDIA A100-SXM4-40GB:
    eval_batch_size: 1
eval_benchmark: false
eval_deterministic: false
eval_nograd: true
not_implemented:
- device: cpu
  test: train
- device: cpu
  test: eval
- device: cuda
  test: train
- device: cuda
  test: eval
train_benchmark: false
train_deterministic: false
1 change: 1 addition & 0 deletions torchbenchmark/models/hf_simplescaling/requirements.txt
@@ -0,0 +1 @@
transformers
77 changes: 77 additions & 0 deletions torchbenchmark/models/kokoro/__init__.py
@@ -0,0 +1,77 @@
from typing import Tuple

import torch
import re
from torchbenchmark.tasks import OTHER

from ...util.model import BenchmarkModel

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
from kokoro import KPipeline
from huggingface_hub import hf_hub_download


def load_single_voice(pipeline, voice: str):
if voice in pipeline.voices:
return pipeline.voices[voice]
if voice.endswith('.pt'):
f = voice
else:
f = hf_hub_download(repo_id="hexgrad/Kokoro-82M", filename=f'voices/{voice}.pt')
assert voice.startswith(pipeline.lang_code)
pack = torch.load(f, weights_only=True)
pipeline.voices[voice] = pack
return pack


class Model(BenchmarkModel):
task = OTHER.OTHER_TASKS
DEFAULT_EVAL_BSIZE = 1

def __init__(self, test, device, batch_size=None, extra_args=[]):
super().__init__(
test=test, device=device, batch_size=batch_size, extra_args=extra_args
)

self.pipeline = KPipeline(lang_code='a')
self.model = self.pipeline.model

text = '''
The sky above the port was the color of television, tuned to a dead channel.
"It's not like I'm using," Case heard someone say, as he shouldered his way through the crowd around the door of the Chat. "It's like my body's developed this massive drug deficiency."
It was a Sprawl voice and a Sprawl joke. The Chatsubo was a bar for professional expatriates; you could drink there for a week and never hear two words in Japanese.

These were to have an enormous impact, not only because they were associated with Constantine, but also because, as in so many other areas, the decisions taken by Constantine (or in his name) were to have great significance for centuries to come. One of the main issues was the shape that Christian churches were to take, since there was not, apparently, a tradition of monumental church buildings when Constantine decided to help the Christian church build a series of truly spectacular structures. The main form that these churches took was that of the basilica, a multipurpose rectangular structure, based ultimately on the earlier Greek stoa, which could be found in most of the great cities of the empire. Christianity, unlike classical polytheism, needed a large interior space for the celebration of its religious services, and the basilica aptly filled that need. We naturally do not know the degree to which the emperor was involved in the design of new churches, but it is tempting to connect this with the secular basilica that Constantine completed in the Roman forum (the so-called Basilica of Maxentius) and the one he probably built in Trier, in connection with his residence in the city at a time when he was still caesar.

[Kokoro](/kˈOkəɹO/) is an open-weight TTS model with 82 million parameters. Despite its lightweight architecture, it delivers comparable quality to larger models while being significantly faster and more cost-efficient. With Apache-licensed weights, [Kokoro](/kˈOkəɹO/) can be deployed anywhere from production environments to personal projects.
'''

        pack = load_single_voice(self.pipeline, "af_heart").to(self.device)
        paragraphs = re.split(r'\n+', text.strip())

        # Take the first usable phoneme chunk as the example input.
        self.example_inputs = None
        for graphemes in paragraphs:
            _, tokens = self.pipeline.g2p(graphemes)
            for gs, ps in self.pipeline.en_tokenize(tokens):
                if not ps:
                    continue
                if len(ps) > 510:
                    print(f"Skipping over-long chunk: len(ps) == {len(ps)} > 510")
                    continue
                input_ids = self.pipeline.p2ii(ps)
                # pack is indexed by phoneme count: pack[i] holds the voice
                # embedding for an input of length i + 1.
                self.example_inputs = ((input_ids, pack[len(input_ids) - 1], 1.0), {})
                break
            if self.example_inputs is not None:
                break
        assert self.example_inputs is not None


def get_module(self):
return self.model, self.example_inputs

def eval(self) -> Tuple[torch.Tensor]:
out = self.model(*self.example_inputs[0], **self.example_inputs[1])
return (out,)

def train(self):
        raise NotImplementedError("Training is not implemented for Kokoro.")
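
Note: a minimal listening check for the benchmark above, assuming the model call returns a waveform tensor at Kokoro's documented 24 kHz rate (soundfile is already in requirements.txt):

    import soundfile as sf
    from torchbenchmark.models.kokoro import Model

    m = Model(test="eval", device="cuda")
    (audio,) = m.eval()
    sf.write("kokoro_sample.wav", audio.squeeze().detach().cpu().numpy(), 24000)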
4 changes: 4 additions & 0 deletions torchbenchmark/models/kokoro/install.py
@@ -0,0 +1,4 @@
from utils.python_utils import pip_install_requirements

if __name__ == "__main__":
pip_install_requirements()
13 changes: 13 additions & 0 deletions torchbenchmark/models/kokoro/metadata.yaml
@@ -0,0 +1,13 @@
devices:
  NVIDIA A100-SXM4-40GB:
    eval_batch_size: 1
eval_benchmark: false
eval_deterministic: false
eval_nograd: true
not_implemented:
- device: cpu
  test: train
- device: cuda
  test: train
train_benchmark: false
train_deterministic: false
2 changes: 2 additions & 0 deletions torchbenchmark/models/kokoro/requirements.txt
@@ -0,0 +1,2 @@
kokoro
soundfile