Float8 inference roofline benchmarks #1126

Draft · wants to merge 2 commits into base: main
74 changes: 74 additions & 0 deletions benchmarks/float8/dummy_float8.py
@@ -0,0 +1,74 @@
from transformers import AutoConfig
from diffusers import DiffusionPipeline
import warnings
import torch

def get_llm_mm_shapes(model_name, seq_len=512):
    """Extracts (m, k, n) matrix shapes for the attention and feed-forward matmuls of an LLM."""
    config = AutoConfig.from_pretrained(model_name)

    hidden_size = config.hidden_size
    num_attention_heads = config.num_attention_heads
    intermediate_size = getattr(config, "intermediate_size", hidden_size * 4)  # Typically 4x hidden size

    d_head = hidden_size // num_attention_heads

    matrix_shapes = {
        "Attention mm": (seq_len, seq_len, d_head),  # attention probabilities @ V, per head
        "Input -> Intermediate": (seq_len, hidden_size, intermediate_size),  # feed-forward up-projection
        "Intermediate -> Output": (seq_len, intermediate_size, hidden_size),  # feed-forward down-projection
    }

    return matrix_shapes.items()

def get_diffusion_mm_shapes(model_name, spatial_dim=64):
    """Extracts matrix shapes for attention and convolution operations in a diffusion model."""
    try:
        model = DiffusionPipeline.from_pretrained(model_name, torch_dtype=torch.bfloat16).to("cuda")
    except Exception as e:
        warnings.warn(f"Could not load {model_name}: {e}")
        return None
    print(f"Loaded {model_name}, model: {model}")

    # The channel metadata lives on the denoiser config, not on the pipeline config;
    # this assumes a UNet-based pipeline (transformer-based pipelines are not handled here).
    unet_config = model.unet.config

    # Example for attention layers (Q, K, V)
    hidden_size = unet_config.block_out_channels[-1]  # Highest channel dimension in the U-Net
    num_attention_heads = getattr(unet_config, "num_heads", 8)  # Typically defined or defaulted
    d_head = hidden_size // num_attention_heads

    # Attention shapes
    attention_shapes = {
        "QK^T": (spatial_dim * spatial_dim, spatial_dim * spatial_dim),  # Flattened spatial dimensions
        "Attention * V": (spatial_dim * spatial_dim, d_head),
    }

    # Convolution shapes: chain the U-Net input channels with the per-block output
    # channels, assuming the 3x3 convolutions commonly used in diffusion models
    block_out_channels = list(unet_config.block_out_channels)
    in_channels = [unet_config.in_channels] + block_out_channels[:-1]
    conv_shapes = [
        (in_ch, out_ch, 3, 3)
        for in_ch, out_ch in zip(in_channels, block_out_channels)
    ]

    return {
        "attention_shapes": attention_shapes,
        "conv_shapes": conv_shapes,
    }

# Example usage for extracting and printing matrix shapes
llm_model_names = ["bert-base-uncased", "gpt2", "t5-small", "meta-llama/Llama-3.2-3B", "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF"]
diffusion_model_names = ["stabilityai/stable-diffusion-3.5-large"]  # Example diffusion model

# Get LLM matrix shapes
for model_name in llm_model_names:
    print(f"\nMatrix shapes for LLM '{model_name}':")
    shapes = get_llm_mm_shapes(model_name)
    print("shapes:", shapes)

# Get diffusion model matrix shapes
for model_name in diffusion_model_names:
    print(f"\nMatrix shapes for Diffusion model '{model_name}':")
    shapes = get_diffusion_mm_shapes(model_name)
    if shapes is not None:  # get_diffusion_mm_shapes returns None if the model fails to load
        print("Attention Shapes:", shapes["attention_shapes"])
        print("Convolution Shapes:", shapes["conv_shapes"])
204 changes: 204 additions & 0 deletions benchmarks/float8/float8_inference_roofline.py
@@ -0,0 +1,204 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

"""
This is a script to estimate the benefit from converting a `torch.nn.Linear`
layer to float8, by estimating the difference in e2e GPU kernel time between:
1. bf16 gemms in fwd and bwd, and
2. float8 gemms in fwd and bwd, and float8 overhead

The gemm times are estimated either from direct measurements via benchmarks,
or with a roofline estimation based on TOPS and peak compute bandwidth of an
NVIDIA H100.

The float8 overhead times are estimated by counting memory reads and writes
based on the specified float8 scaling, and estimating that we can achieve
a certain % of machine peak memory bandwidth when performing these reads and writes.

"""

import torch
import torchao
import torch.nn as nn
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from torch.profiler import profile, record_function, ProfilerActivity
from torchao.quantization.quant_api import quantize_, float8_dynamic_activation_float8_weight, float8_weight_only
import copy
from utils import (
    get_name_to_shapes_iter,
    get_llm_mm_shapes,
)
import tqdm
from tabulate import tabulate
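
# A minimal, self-contained sketch of the roofline estimate described in the module
# docstring above; it is not used by the benchmarks below. The H100 peak numbers are
# approximate assumptions, and ACHIEVABLE_MEM_BW_FRACTION is an illustrative placeholder.
H100_BF16_PEAK_FLOPS = 990e12     # assumed dense bf16 tensor core peak, FLOP/s
H100_FP8_PEAK_FLOPS = 1979e12     # assumed dense fp8 tensor core peak, FLOP/s
H100_PEAK_MEM_BW = 3.35e12        # assumed peak HBM bandwidth, bytes/s
ACHIEVABLE_MEM_BW_FRACTION = 0.6  # assumed achievable fraction of peak bandwidth


def roofline_gemm_time_s(m, k, n, peak_flops):
    # compute-bound estimate for an (m, k) @ (k, n) gemm: 2*m*k*n FLOPs at peak throughput
    return 2 * m * k * n / peak_flops


def roofline_float8_overhead_time_s(m, k, n):
    # estimate the cost of casting the bf16 activation (m, k) and weight (k, n) to float8:
    # read 2 bytes/element, write 1 byte/element, at a fraction of peak memory bandwidth
    bytes_moved = (m * k + k * n) * (2 + 1)
    return bytes_moved / (H100_PEAK_MEM_BW * ACHIEVABLE_MEM_BW_FRACTION)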

# Set the device (GPU if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class ToyLinearModel(torch.nn.Module):
    def __init__(self, m=64, n=32, k=64, dtype=torch.bfloat16):
        super().__init__()
        self.dtype = dtype
        self.linear1 = torch.nn.Linear(k, n, bias=False).to(dtype)

    def example_inputs(self, m=1, device="cuda"):
        return (torch.randn(m, self.linear1.in_features, dtype=self.dtype, device=device),)

    def forward(self, x):
        x = self.linear1(x)
        return x

# Function to benchmark model evaluation with profiling
def benchmark_model_with_profiling(model, input_data, dtype):
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof:
        for _ in range(1):  # single measured iteration; increase for more stable timings
            with torch.no_grad():
                _ = model(*input_data)
        torch.cuda.synchronize()

    # Return the profiler output
    return prof


def get_gpu_kernel_times(profiler_chrome_trace, gpu_op_name):
    # Filter CUDA events
    event_data = [(event.key, event.device_time)
                  for event in profiler_chrome_trace.key_averages()
                  if event.device_type == torch.autograd.DeviceType.CUDA]

    # Calculate overhead time and op time
    gpu_op_time, gpu_overhead_time = 0, 0
    for event in event_data:
        if gpu_op_name in event[0]:
            gpu_op_time += event[1]
        else:
            gpu_overhead_time += event[1]
    return gpu_op_time, gpu_overhead_time

def run_gemm_benchmarks(performance_data, name_to_shapes, float8_dtype=torch.float8_e4m3fn, other_dtype=torch.bfloat16, quantization_technique=float8_weight_only, batch_size=1):
    # Run benchmarks for each input size
    for idx, (name, (m, k, n)) in enumerate(tqdm.tqdm(name_to_shapes)):
        print(f"Profiling model with input size: {batch_size, m, k, n} for quantization technique: {quantization_technique}, dtype: {float8_dtype} vs {other_dtype}")

        # Initialize the model with the specified dimensions; the gemm is (batch_size*m, k) x (k, n)
        model = ToyLinearModel(m=batch_size * m, n=n, k=k).eval().to(device)
        example_inputs = model.example_inputs(batch_size * m)
        model_bf16 = copy.deepcopy(model).to(device)  # Copy the model for the bf16 baseline
        model_bf16 = torch.compile(model_bf16)  # Compile the model

        # Quantize the model
        model_ref = copy.deepcopy(model).to(device)  # Copy the model for quantization
        quantize_(model_ref, quantization_technique())  # Quantize model to float8
        model_ref = torch.compile(model_ref)  # Compile the model

        # Profile float8 model evaluation
        prof_float8 = benchmark_model_with_profiling(model_ref, example_inputs, float8_dtype)
        prof_float8.export_chrome_trace(f"fp8_model_{example_inputs[0].size()[0]}.json")  # Save profiling details

        # Profile bf16 model evaluation
        prof_bf16 = benchmark_model_with_profiling(model_bf16, example_inputs, other_dtype)
        prof_bf16.export_chrome_trace(f"bf16_model_{example_inputs[0].size()[0]}.json")  # Save profiling details

        # Calculate and store GPU kernel times -> op time, overhead time
        float8_gpu_op_time, float8_gpu_overhead_time = get_gpu_kernel_times(prof_float8, 'gemm')
        bf16_gpu_op_time, bf16_gpu_overhead_time = get_gpu_kernel_times(prof_bf16, 'gemm')

        # Add the performance data to the dictionary
        # time/1000 -> convert from microseconds to milliseconds
        performance_data['Input Size'].append(f"{(m, k, n)}")
        performance_data['float8 Total Kernel Times (ms)'].append((float8_gpu_op_time + float8_gpu_overhead_time) / 1000)
        performance_data['bf16 Total Kernel Times (ms)'].append((bf16_gpu_op_time + bf16_gpu_overhead_time) / 1000)
        performance_data['float8 Op Kernel Times (ms)'].append(float8_gpu_op_time / 1000)
        performance_data['bf16 Op Kernel Times (ms)'].append(bf16_gpu_op_time / 1000)
        performance_data['float8 Overhead Kernel Times (ms)'].append(float8_gpu_overhead_time / 1000)
        performance_data['bf16 Overhead Kernel Times (ms)'].append(bf16_gpu_overhead_time / 1000)
        performance_data['Batch Size'].append(batch_size)

    return performance_data


def plot_performance_data(performance_data, x_col, plot_name='model_evaluation_gpu_kernel_performance'):
    # Plotting the results
    plt.figure(figsize=(10, 6))
    plt.plot(performance_data[x_col], performance_data['float8 Total Kernel Times (ms)'], marker='o', label='float8')
    plt.plot(performance_data[x_col], performance_data['bf16 Total Kernel Times (ms)'], marker='s', label='bf16')
    plt.xlabel(x_col)
    plt.ylabel('Kernel Time (ms)')
    plt.title(plot_name + ' performance: float8 vs bf16')
    plt.legend()
    plt.grid(True)
    plt.savefig(plot_name + '.png')


def plot_llm_performance_data_hf_model(model_name, quantization_dtype=torch.float8_e4m3fn, quantization_technique=float8_weight_only, baseline_dtype=torch.bfloat16, batch_sizes=[1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096]):
    # Dictionary to store performance data
    performance_data = {
        'Input Size': [],
        'float8 Op Kernel Times (ms)': [],
        'bf16 Op Kernel Times (ms)': [],
        'float8 Overhead Kernel Times (ms)': [],
        'bf16 Overhead Kernel Times (ms)': [],
        'float8 Total Kernel Times (ms)': [],
        'bf16 Total Kernel Times (ms)': [],
        'Batch Size': []
    }
    name_to_shapes = get_llm_mm_shapes(model_name, seq_len=128)
    print(f'For Model: {model_name}, and Shapes: {name_to_shapes}')
    for batch_size in batch_sizes:
        performance_data = run_gemm_benchmarks(
            performance_data=performance_data,
            name_to_shapes=name_to_shapes,
            float8_dtype=quantization_dtype,
            other_dtype=baseline_dtype,
            quantization_technique=quantization_technique,
            batch_size=batch_size,
        )
    df_performance_data = pd.DataFrame(performance_data)
    df_grouped = df_performance_data.groupby('Input Size')
    for name, group in df_grouped:
        print(f"Group: {name}")
        plot_performance_data(group, 'Batch Size', plot_name=f'{model_name.split("/")[-1]}_input_size_{name}_quant_{quantization_technique}')


if __name__ == '__main__':

    # Run benchmarks for LLMs
    llm_model_names = ["bert-base-uncased", "gpt2", "t5-small", "meta-llama/Llama-3.2-3B", "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF"]
    for model_name in llm_model_names:
        plot_llm_performance_data_hf_model(
            model_name,
            quantization_dtype=torch.float8_e4m3fn,
            quantization_technique=float8_weight_only,
            baseline_dtype=torch.bfloat16,
        )

    # Run benchmarks for different matrix shapes (m, k, n)
    name_to_shapes = get_name_to_shapes_iter("square", None, None, None)
    # Dictionary to store performance data
    performance_data = {
        'Input Size': [],
        'float8 Op Kernel Times (ms)': [],
        'bf16 Op Kernel Times (ms)': [],
        'float8 Overhead Kernel Times (ms)': [],
        'bf16 Overhead Kernel Times (ms)': [],
        'float8 Total Kernel Times (ms)': [],
        'bf16 Total Kernel Times (ms)': [],
        'Batch Size': []
    }
    performance_data = run_gemm_benchmarks(
        performance_data=performance_data,
        name_to_shapes=name_to_shapes,
        float8_dtype=torch.float8_e4m3fn,
        other_dtype=torch.bfloat16,
        quantization_technique=float8_weight_only,
    )
    plot_performance_data(performance_data, 'Input Size', plot_name='different_matrix_shapes')
    # print('Performance data: \n', tabulate(performance_data, headers=performance_data.keys()))
20 changes: 20 additions & 0 deletions benchmarks/float8/utils.py
@@ -10,6 +10,8 @@
from typing import Optional

from torch.profiler import profile, ProfilerActivity, record_function
from transformers import AutoConfig
from diffusers import DiffusionPipeline

def profiler_output_to_filtered_time_by_kernel_name(
prof,
@@ -202,6 +204,7 @@ def get_name_to_shapes_iter(

raise AssertionError(f'unknown shape_gen_name {shape_gen_name}')


# copy-pasta from https://github.com/vkuzo/pytorch_scripts/blob/main/add_inductor_metadata_to_perf_trace.py
def update_triton_kernels_in_prof_chome_trace_with_torch_logs(
perf_trace_file: str,
@@ -350,3 +353,20 @@ def get_gpu_kernel_gemm_time_s(f, *args, **kwargs):
raise AssertionError("unexpected format of data")


def get_llm_mm_shapes(model_name, seq_len=512):
    """Extracts (m, k, n) matrix shapes for the attention and feed-forward matmuls of an LLM."""
    config = AutoConfig.from_pretrained(model_name)

    hidden_size = config.hidden_size
    num_attention_heads = config.num_attention_heads
    intermediate_size = getattr(config, "intermediate_size", hidden_size * 4)  # Typically 4x hidden size

    d_head = hidden_size // num_attention_heads

    matrix_shapes = {
        "Attention mm": (seq_len, seq_len, d_head),  # attention probabilities @ V, per head
        "Input -> Intermediate": (seq_len, hidden_size, intermediate_size),  # feed-forward up-projection
        "Intermediate -> Output": (seq_len, intermediate_size, hidden_size),  # feed-forward down-projection
    }

    return matrix_shapes.items()