From 7a07885f4ed841bd8fc11d8938fe9f84ca650312 Mon Sep 17 00:00:00 2001 From: jainapurva Date: Fri, 21 Feb 2025 15:44:04 -0800 Subject: [PATCH 01/14] Add files --- .../microbenchmarks/benchmark_config.yml | 19 +++++++++++++++++++ .../microbenchmarks/benchmark_inference.py | 6 ++++++ .../microbenchmarks/benchmark_runner.py | 15 +++++++++++++++ benchmarks/microbenchmarks/utils.py | 0 4 files changed, 40 insertions(+) create mode 100644 benchmarks/microbenchmarks/benchmark_config.yml create mode 100644 benchmarks/microbenchmarks/benchmark_inference.py create mode 100644 benchmarks/microbenchmarks/benchmark_runner.py create mode 100644 benchmarks/microbenchmarks/utils.py diff --git a/benchmarks/microbenchmarks/benchmark_config.yml b/benchmarks/microbenchmarks/benchmark_config.yml new file mode 100644 index 0000000000..e6a9064004 --- /dev/null +++ b/benchmarks/microbenchmarks/benchmark_config.yml @@ -0,0 +1,19 @@ +# Sample configuration for inference kernel benchmarks +quantizations: + - "baseline" + - "int8wo" + - "int4wo-128" + - "int4wo-128-hqq" + +model_params: + matrix_shapes: + - name: "custom" + shapes: [ + [1024, 1024, 1024], # [m, k, n] + [2048, 4096, 1024], + [4096, 4096, 1024] + ] + precision: "torch.bfloat16" + compile: false + device: "cuda" + model_type: "linear" diff --git a/benchmarks/microbenchmarks/benchmark_inference.py b/benchmarks/microbenchmarks/benchmark_inference.py new file mode 100644 index 0000000000..b86647b8dd --- /dev/null +++ b/benchmarks/microbenchmarks/benchmark_inference.py @@ -0,0 +1,6 @@ +""" +Inference benchmark runner + +This script runs inference benchmarks and generates a micro-benchmarking report for it. +- run() function is the main entry point for running inference benchmarks. + """ diff --git a/benchmarks/microbenchmarks/benchmark_runner.py b/benchmarks/microbenchmarks/benchmark_runner.py new file mode 100644 index 0000000000..1b9872075e --- /dev/null +++ b/benchmarks/microbenchmarks/benchmark_runner.py @@ -0,0 +1,15 @@ +""" +Benchmark Runner + +This is the main entry point for the benchmarking application. It reads the YAML configuration +file and orchestrates the entire benchmarking process by: +- Loading and validating benchmark configurations +- Executing benchmark scenarios +- Collecting and processing results +- Generating reports + +Usage: + python benchmark_runner.py [config.yaml] + +The YAML file should contain all necessary configuration parameters for the benchmarks. 
+""" diff --git a/benchmarks/microbenchmarks/utils.py b/benchmarks/microbenchmarks/utils.py new file mode 100644 index 0000000000..e69de29bb2 From c8ddfdbf5a337dea7bb82f2cf3a516affa6eb276 Mon Sep 17 00:00:00 2001 From: Apurva Jain Date: Mon, 24 Feb 2025 21:30:19 -0800 Subject: [PATCH 02/14] Add basic benchmarks for inference --- .../microbenchmarks/benchmark_config.yml | 4 +- .../microbenchmarks/benchmark_inference.py | 66 ++- .../microbenchmarks/benchmark_runner.py | 70 +++ .../microbenchmarks/benchmark_training.py | 13 + ...line_custom_m1024_k1024_n1024_results.json | 3 + ...line_custom_m2048_k4096_n1024_results.json | 3 + ...line_custom_m4096_k4096_n1024_results.json | 3 + ...line_linear_m1024_k1024_n1024_results.json | 3 + ...line_linear_m2048_k4096_n1024_results.json | 3 + ...line_linear_m4096_k4096_n1024_results.json | 3 + ...-hqq_custom_m1024_k1024_n1024_results.json | 3 + ...-hqq_custom_m2048_k4096_n1024_results.json | 3 + ...-hqq_custom_m4096_k4096_n1024_results.json | 3 + ...-hqq_linear_m1024_k1024_n1024_results.json | 3 + ...-hqq_linear_m2048_k4096_n1024_results.json | 3 + ...-hqq_linear_m4096_k4096_n1024_results.json | 3 + ...-128_custom_m1024_k1024_n1024_results.json | 3 + ...-128_custom_m2048_k4096_n1024_results.json | 3 + ...-128_custom_m4096_k4096_n1024_results.json | 3 + ...-128_linear_m1024_k1024_n1024_results.json | 3 + ...-128_linear_m2048_k4096_n1024_results.json | 3 + ...-128_linear_m4096_k4096_n1024_results.json | 3 + ...t8wo_custom_m1024_k1024_n1024_results.json | 3 + ...t8wo_custom_m2048_k4096_n1024_results.json | 3 + ...t8wo_custom_m4096_k4096_n1024_results.json | 3 + ...t8wo_linear_m1024_k1024_n1024_results.json | 3 + ...t8wo_linear_m2048_k4096_n1024_results.json | 3 + ...t8wo_linear_m4096_k4096_n1024_results.json | 3 + benchmarks/microbenchmarks/utils.py | 520 ++++++++++++++++++ 29 files changed, 742 insertions(+), 3 deletions(-) create mode 100644 benchmarks/microbenchmarks/benchmark_training.py create mode 100644 benchmarks/microbenchmarks/results/benchmark_baseline_custom_m1024_k1024_n1024_results.json create mode 100644 benchmarks/microbenchmarks/results/benchmark_baseline_custom_m2048_k4096_n1024_results.json create mode 100644 benchmarks/microbenchmarks/results/benchmark_baseline_custom_m4096_k4096_n1024_results.json create mode 100644 benchmarks/microbenchmarks/results/benchmark_baseline_linear_m1024_k1024_n1024_results.json create mode 100644 benchmarks/microbenchmarks/results/benchmark_baseline_linear_m2048_k4096_n1024_results.json create mode 100644 benchmarks/microbenchmarks/results/benchmark_baseline_linear_m4096_k4096_n1024_results.json create mode 100644 benchmarks/microbenchmarks/results/benchmark_int4wo-128-hqq_custom_m1024_k1024_n1024_results.json create mode 100644 benchmarks/microbenchmarks/results/benchmark_int4wo-128-hqq_custom_m2048_k4096_n1024_results.json create mode 100644 benchmarks/microbenchmarks/results/benchmark_int4wo-128-hqq_custom_m4096_k4096_n1024_results.json create mode 100644 benchmarks/microbenchmarks/results/benchmark_int4wo-128-hqq_linear_m1024_k1024_n1024_results.json create mode 100644 benchmarks/microbenchmarks/results/benchmark_int4wo-128-hqq_linear_m2048_k4096_n1024_results.json create mode 100644 benchmarks/microbenchmarks/results/benchmark_int4wo-128-hqq_linear_m4096_k4096_n1024_results.json create mode 100644 benchmarks/microbenchmarks/results/benchmark_int4wo-128_custom_m1024_k1024_n1024_results.json create mode 100644 benchmarks/microbenchmarks/results/benchmark_int4wo-128_custom_m2048_k4096_n1024_results.json 
create mode 100644 benchmarks/microbenchmarks/results/benchmark_int4wo-128_custom_m4096_k4096_n1024_results.json create mode 100644 benchmarks/microbenchmarks/results/benchmark_int4wo-128_linear_m1024_k1024_n1024_results.json create mode 100644 benchmarks/microbenchmarks/results/benchmark_int4wo-128_linear_m2048_k4096_n1024_results.json create mode 100644 benchmarks/microbenchmarks/results/benchmark_int4wo-128_linear_m4096_k4096_n1024_results.json create mode 100644 benchmarks/microbenchmarks/results/benchmark_int8wo_custom_m1024_k1024_n1024_results.json create mode 100644 benchmarks/microbenchmarks/results/benchmark_int8wo_custom_m2048_k4096_n1024_results.json create mode 100644 benchmarks/microbenchmarks/results/benchmark_int8wo_custom_m4096_k4096_n1024_results.json create mode 100644 benchmarks/microbenchmarks/results/benchmark_int8wo_linear_m1024_k1024_n1024_results.json create mode 100644 benchmarks/microbenchmarks/results/benchmark_int8wo_linear_m2048_k4096_n1024_results.json create mode 100644 benchmarks/microbenchmarks/results/benchmark_int8wo_linear_m4096_k4096_n1024_results.json diff --git a/benchmarks/microbenchmarks/benchmark_config.yml b/benchmarks/microbenchmarks/benchmark_config.yml index e6a9064004..003c43da0a 100644 --- a/benchmarks/microbenchmarks/benchmark_config.yml +++ b/benchmarks/microbenchmarks/benchmark_config.yml @@ -4,7 +4,7 @@ quantizations: - "int8wo" - "int4wo-128" - "int4wo-128-hqq" - +output_dir: "benchmarks/microbenchmarks/results" # Directory for results and plots model_params: matrix_shapes: - name: "custom" @@ -15,5 +15,5 @@ model_params: ] precision: "torch.bfloat16" compile: false - device: "cuda" + device: "cpu" # Change this to "cuda", "mps", "xpu", or "cpu" as needed model_type: "linear" diff --git a/benchmarks/microbenchmarks/benchmark_inference.py b/benchmarks/microbenchmarks/benchmark_inference.py index b86647b8dd..d57dadcac6 100644 --- a/benchmarks/microbenchmarks/benchmark_inference.py +++ b/benchmarks/microbenchmarks/benchmark_inference.py @@ -3,4 +3,68 @@ This script runs inference benchmarks and generates a micro-benchmarking report for it. - run() function is the main entry point for running inference benchmarks. 
- """ +""" +from copy import deepcopy +import json +from pathlib import Path + +import torch +from utils import ( + benchmark_model_inference_in_seconds, + clean_caches, + create_model_and_input, + quantize_model, + BenchmarkConfig, +) + +def run(config: BenchmarkConfig) -> None: + """Run inference benchmarks""" + clean_caches() # Clean caches + + # Create output directory if it doesn't exist + Path(config.output_dir).mkdir(parents=True, exist_ok=True) + + base_model, input_data = create_model_and_input( + config.model_type, + config.m, + config.k, + config.n, + dtype=config.precision, + device=config.device, + ) + print( + f"Starting benchmarking for model: {base_model.__class__.__name__} for quantization: {config.quantization}" + ) + + # Use quantize_ to apply each quantization function to the model + m_copy = deepcopy(base_model).eval().to(config.device) + m_copy = quantize_model(m_copy, config.quantization) + + if config.compile: + print("Compiling model....") + m_copy = torch.compile(m_copy, mode=config.compile, fullgraph=True) + + # Run benchmarks + results = {} + + # Benchmark time to run an inference call for quantized model + model_time = benchmark_model_inference_in_seconds( + model=m_copy, input_data=input_data + ) + results[f"benchmark_model_inference_in_seconds"] = model_time + print( + f"Time to run a {base_model.__class__.__name__}: {model_time:.2f} seconds quantized with {config.quantization}" + ) + + # 2. Benchmark time using profiler + # Profile dtype model evaluation + # prof_dtype = benchmark_model_op_with_profiler_in_microseconds(m_copy, input_data, quantized_dtype) + # prof_dtype.export_chrome_trace(f"{quantization}_model_{input_data[0].size()[0]}.json") # Save profiling details + + # 3. Benchmark gemm time using cuda graph + # gemm_time = benchmark_torch_function_in_microseconds(gemm_op, *args, **kwargs) + + # 4. Benchmark op with cuda graph + # time = benchmark_op_with_cuda_graph(op, args) + + return results diff --git a/benchmarks/microbenchmarks/benchmark_runner.py b/benchmarks/microbenchmarks/benchmark_runner.py index 1b9872075e..0303dc0ba3 100644 --- a/benchmarks/microbenchmarks/benchmark_runner.py +++ b/benchmarks/microbenchmarks/benchmark_runner.py @@ -13,3 +13,73 @@ The YAML file should contain all necessary configuration parameters for the benchmarks. 
""" +from itertools import product +from typing import Any, Dict, List, Tuple + +import torch +import yaml + +from utils import BenchmarkConfig + +def get_shapes_for_config(shape_config: Dict[str, Any]) -> List[Tuple[str, List[int]]]: + """Get shapes for a given configuration""" + name = shape_config["name"] + if name == "custom": + return [(name, shape) for shape in shape_config["shapes"]] + + +def load_benchmark_configs(config_path: str) -> List[BenchmarkConfig]: + """Load benchmark configurations from YAML file""" + with open(config_path, "r") as f: + config_data = yaml.safe_load(f) + + quantizations = config_data["quantizations"] + params = config_data["model_params"] + output_dir = config_data.get("output_dir", "benchmarks/microbenchmarks/results") + + configs = [] + # Process each shape configuration + for shape_config in params["matrix_shapes"]: + shapes = get_shapes_for_config(shape_config) + # Generate combinations for each shape + for quant, (shape_name, shape) in product(quantizations, shapes): + configs.append(BenchmarkConfig( + quantization=quant, + params=params, + shape_name=shape_name, + shape=shape, + output_dir=output_dir, + )) + + return configs + + +def run_benchmarks_from_config(config_path: str) -> None: + """Run benchmarks using configurations from YAML file""" + from benchmark_inference import run as run_inference + + configs = load_benchmark_configs(config_path) + results = [] + print(f"Benchmarking Inference ......") + for config in configs: + print(f"Running: {config.name}") + result = run_inference(config) # Pass the config object directly + results.append(result) + + # TODO: Convert results to csv + # 1. For different shapes for same model + # 2. For different quantizations for same model and shape + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Run benchmarks from config file") + parser.add_argument( + "--config", + type=str, + required=True, + help="Path to benchmark configuration file", + ) + args = parser.parse_args() + run_benchmarks_from_config(args.config) \ No newline at end of file diff --git a/benchmarks/microbenchmarks/benchmark_training.py b/benchmarks/microbenchmarks/benchmark_training.py new file mode 100644 index 0000000000..604f44f489 --- /dev/null +++ b/benchmarks/microbenchmarks/benchmark_training.py @@ -0,0 +1,13 @@ +""" +Training benchmark runner + +This script runs training benchmarks and generates a micro-benchmarking report for it. + - run() function is the main entry point for running training benchmarks. +""" + +from utils import BenchmarkConfig + + +def run(config: BenchmarkConfig) -> None: + """Run training benchmarks""" + raise NotImplementedError("Training benchmarks are not implemented yet. 
This is a placeholder function.") diff --git a/benchmarks/microbenchmarks/results/benchmark_baseline_custom_m1024_k1024_n1024_results.json b/benchmarks/microbenchmarks/results/benchmark_baseline_custom_m1024_k1024_n1024_results.json new file mode 100644 index 0000000000..1abee3d60b --- /dev/null +++ b/benchmarks/microbenchmarks/results/benchmark_baseline_custom_m1024_k1024_n1024_results.json @@ -0,0 +1,3 @@ +{ + "inference_time_us": 15562.51660000001 +} \ No newline at end of file diff --git a/benchmarks/microbenchmarks/results/benchmark_baseline_custom_m2048_k4096_n1024_results.json b/benchmarks/microbenchmarks/results/benchmark_baseline_custom_m2048_k4096_n1024_results.json new file mode 100644 index 0000000000..14c88528fa --- /dev/null +++ b/benchmarks/microbenchmarks/results/benchmark_baseline_custom_m2048_k4096_n1024_results.json @@ -0,0 +1,3 @@ +{ + "inference_time_us": 110424.1749999999 +} \ No newline at end of file diff --git a/benchmarks/microbenchmarks/results/benchmark_baseline_custom_m4096_k4096_n1024_results.json b/benchmarks/microbenchmarks/results/benchmark_baseline_custom_m4096_k4096_n1024_results.json new file mode 100644 index 0000000000..d1ce48c394 --- /dev/null +++ b/benchmarks/microbenchmarks/results/benchmark_baseline_custom_m4096_k4096_n1024_results.json @@ -0,0 +1,3 @@ +{ + "inference_time_us": 242173.30819999994 +} \ No newline at end of file diff --git a/benchmarks/microbenchmarks/results/benchmark_baseline_linear_m1024_k1024_n1024_results.json b/benchmarks/microbenchmarks/results/benchmark_baseline_linear_m1024_k1024_n1024_results.json new file mode 100644 index 0000000000..24ce509670 --- /dev/null +++ b/benchmarks/microbenchmarks/results/benchmark_baseline_linear_m1024_k1024_n1024_results.json @@ -0,0 +1,3 @@ +{ + "benchmark_model_inference_in_seconds": 0.013309416799999951 +} \ No newline at end of file diff --git a/benchmarks/microbenchmarks/results/benchmark_baseline_linear_m2048_k4096_n1024_results.json b/benchmarks/microbenchmarks/results/benchmark_baseline_linear_m2048_k4096_n1024_results.json new file mode 100644 index 0000000000..736461b641 --- /dev/null +++ b/benchmarks/microbenchmarks/results/benchmark_baseline_linear_m2048_k4096_n1024_results.json @@ -0,0 +1,3 @@ +{ + "benchmark_model_inference_in_seconds": 0.11298482500000002 +} \ No newline at end of file diff --git a/benchmarks/microbenchmarks/results/benchmark_baseline_linear_m4096_k4096_n1024_results.json b/benchmarks/microbenchmarks/results/benchmark_baseline_linear_m4096_k4096_n1024_results.json new file mode 100644 index 0000000000..67d64673c7 --- /dev/null +++ b/benchmarks/microbenchmarks/results/benchmark_baseline_linear_m4096_k4096_n1024_results.json @@ -0,0 +1,3 @@ +{ + "benchmark_model_inference_in_seconds": 0.24748350000000005 +} \ No newline at end of file diff --git a/benchmarks/microbenchmarks/results/benchmark_int4wo-128-hqq_custom_m1024_k1024_n1024_results.json b/benchmarks/microbenchmarks/results/benchmark_int4wo-128-hqq_custom_m1024_k1024_n1024_results.json new file mode 100644 index 0000000000..c1f3d71203 --- /dev/null +++ b/benchmarks/microbenchmarks/results/benchmark_int4wo-128-hqq_custom_m1024_k1024_n1024_results.json @@ -0,0 +1,3 @@ +{ + "inference_time_us": 14940.825000000046 +} \ No newline at end of file diff --git a/benchmarks/microbenchmarks/results/benchmark_int4wo-128-hqq_custom_m2048_k4096_n1024_results.json b/benchmarks/microbenchmarks/results/benchmark_int4wo-128-hqq_custom_m2048_k4096_n1024_results.json new file mode 100644 index 0000000000..f1769826ef --- 
/dev/null +++ b/benchmarks/microbenchmarks/results/benchmark_int4wo-128-hqq_custom_m2048_k4096_n1024_results.json @@ -0,0 +1,3 @@ +{ + "inference_time_us": 122228.56680000013 +} \ No newline at end of file diff --git a/benchmarks/microbenchmarks/results/benchmark_int4wo-128-hqq_custom_m4096_k4096_n1024_results.json b/benchmarks/microbenchmarks/results/benchmark_int4wo-128-hqq_custom_m4096_k4096_n1024_results.json new file mode 100644 index 0000000000..b0bc0a8a45 --- /dev/null +++ b/benchmarks/microbenchmarks/results/benchmark_int4wo-128-hqq_custom_m4096_k4096_n1024_results.json @@ -0,0 +1,3 @@ +{ + "inference_time_us": 260059.6167999999 +} \ No newline at end of file diff --git a/benchmarks/microbenchmarks/results/benchmark_int4wo-128-hqq_linear_m1024_k1024_n1024_results.json b/benchmarks/microbenchmarks/results/benchmark_int4wo-128-hqq_linear_m1024_k1024_n1024_results.json new file mode 100644 index 0000000000..8f16ac0420 --- /dev/null +++ b/benchmarks/microbenchmarks/results/benchmark_int4wo-128-hqq_linear_m1024_k1024_n1024_results.json @@ -0,0 +1,3 @@ +{ + "benchmark_model_inference_in_seconds": 0.018115591599999804 +} \ No newline at end of file diff --git a/benchmarks/microbenchmarks/results/benchmark_int4wo-128-hqq_linear_m2048_k4096_n1024_results.json b/benchmarks/microbenchmarks/results/benchmark_int4wo-128-hqq_linear_m2048_k4096_n1024_results.json new file mode 100644 index 0000000000..a595902e73 --- /dev/null +++ b/benchmarks/microbenchmarks/results/benchmark_int4wo-128-hqq_linear_m2048_k4096_n1024_results.json @@ -0,0 +1,3 @@ +{ + "benchmark_model_inference_in_seconds": 0.11708445820000009 +} \ No newline at end of file diff --git a/benchmarks/microbenchmarks/results/benchmark_int4wo-128-hqq_linear_m4096_k4096_n1024_results.json b/benchmarks/microbenchmarks/results/benchmark_int4wo-128-hqq_linear_m4096_k4096_n1024_results.json new file mode 100644 index 0000000000..cf3671aae2 --- /dev/null +++ b/benchmarks/microbenchmarks/results/benchmark_int4wo-128-hqq_linear_m4096_k4096_n1024_results.json @@ -0,0 +1,3 @@ +{ + "benchmark_model_inference_in_seconds": 0.24722661660000006 +} \ No newline at end of file diff --git a/benchmarks/microbenchmarks/results/benchmark_int4wo-128_custom_m1024_k1024_n1024_results.json b/benchmarks/microbenchmarks/results/benchmark_int4wo-128_custom_m1024_k1024_n1024_results.json new file mode 100644 index 0000000000..16521f36b6 --- /dev/null +++ b/benchmarks/microbenchmarks/results/benchmark_int4wo-128_custom_m1024_k1024_n1024_results.json @@ -0,0 +1,3 @@ +{ + "inference_time_us": 14625.07500000001 +} \ No newline at end of file diff --git a/benchmarks/microbenchmarks/results/benchmark_int4wo-128_custom_m2048_k4096_n1024_results.json b/benchmarks/microbenchmarks/results/benchmark_int4wo-128_custom_m2048_k4096_n1024_results.json new file mode 100644 index 0000000000..163625d2c9 --- /dev/null +++ b/benchmarks/microbenchmarks/results/benchmark_int4wo-128_custom_m2048_k4096_n1024_results.json @@ -0,0 +1,3 @@ +{ + "inference_time_us": 115082.14999999992 +} \ No newline at end of file diff --git a/benchmarks/microbenchmarks/results/benchmark_int4wo-128_custom_m4096_k4096_n1024_results.json b/benchmarks/microbenchmarks/results/benchmark_int4wo-128_custom_m4096_k4096_n1024_results.json new file mode 100644 index 0000000000..19379b2b00 --- /dev/null +++ b/benchmarks/microbenchmarks/results/benchmark_int4wo-128_custom_m4096_k4096_n1024_results.json @@ -0,0 +1,3 @@ +{ + "inference_time_us": 275312.9167999998 +} \ No newline at end of file diff --git 
a/benchmarks/microbenchmarks/results/benchmark_int4wo-128_linear_m1024_k1024_n1024_results.json b/benchmarks/microbenchmarks/results/benchmark_int4wo-128_linear_m1024_k1024_n1024_results.json new file mode 100644 index 0000000000..c69854bb60 --- /dev/null +++ b/benchmarks/microbenchmarks/results/benchmark_int4wo-128_linear_m1024_k1024_n1024_results.json @@ -0,0 +1,3 @@ +{ + "benchmark_model_inference_in_seconds": 0.013688116599999845 +} \ No newline at end of file diff --git a/benchmarks/microbenchmarks/results/benchmark_int4wo-128_linear_m2048_k4096_n1024_results.json b/benchmarks/microbenchmarks/results/benchmark_int4wo-128_linear_m2048_k4096_n1024_results.json new file mode 100644 index 0000000000..658fe43027 --- /dev/null +++ b/benchmarks/microbenchmarks/results/benchmark_int4wo-128_linear_m2048_k4096_n1024_results.json @@ -0,0 +1,3 @@ +{ + "benchmark_model_inference_in_seconds": 0.1140080165999997 +} \ No newline at end of file diff --git a/benchmarks/microbenchmarks/results/benchmark_int4wo-128_linear_m4096_k4096_n1024_results.json b/benchmarks/microbenchmarks/results/benchmark_int4wo-128_linear_m4096_k4096_n1024_results.json new file mode 100644 index 0000000000..ae6d7ae357 --- /dev/null +++ b/benchmarks/microbenchmarks/results/benchmark_int4wo-128_linear_m4096_k4096_n1024_results.json @@ -0,0 +1,3 @@ +{ + "benchmark_model_inference_in_seconds": 0.2766437331999999 +} \ No newline at end of file diff --git a/benchmarks/microbenchmarks/results/benchmark_int8wo_custom_m1024_k1024_n1024_results.json b/benchmarks/microbenchmarks/results/benchmark_int8wo_custom_m1024_k1024_n1024_results.json new file mode 100644 index 0000000000..c3a369b4e3 --- /dev/null +++ b/benchmarks/microbenchmarks/results/benchmark_int8wo_custom_m1024_k1024_n1024_results.json @@ -0,0 +1,3 @@ +{ + "inference_time_us": 14606.591800000146 +} \ No newline at end of file diff --git a/benchmarks/microbenchmarks/results/benchmark_int8wo_custom_m2048_k4096_n1024_results.json b/benchmarks/microbenchmarks/results/benchmark_int8wo_custom_m2048_k4096_n1024_results.json new file mode 100644 index 0000000000..5262ffae0e --- /dev/null +++ b/benchmarks/microbenchmarks/results/benchmark_int8wo_custom_m2048_k4096_n1024_results.json @@ -0,0 +1,3 @@ +{ + "inference_time_us": 120721.20819999998 +} \ No newline at end of file diff --git a/benchmarks/microbenchmarks/results/benchmark_int8wo_custom_m4096_k4096_n1024_results.json b/benchmarks/microbenchmarks/results/benchmark_int8wo_custom_m4096_k4096_n1024_results.json new file mode 100644 index 0000000000..e75a3f832b --- /dev/null +++ b/benchmarks/microbenchmarks/results/benchmark_int8wo_custom_m4096_k4096_n1024_results.json @@ -0,0 +1,3 @@ +{ + "inference_time_us": 266769.6750000001 +} \ No newline at end of file diff --git a/benchmarks/microbenchmarks/results/benchmark_int8wo_linear_m1024_k1024_n1024_results.json b/benchmarks/microbenchmarks/results/benchmark_int8wo_linear_m1024_k1024_n1024_results.json new file mode 100644 index 0000000000..b023cdea29 --- /dev/null +++ b/benchmarks/microbenchmarks/results/benchmark_int8wo_linear_m1024_k1024_n1024_results.json @@ -0,0 +1,3 @@ +{ + "benchmark_model_inference_in_seconds": 0.01730855819999988 +} \ No newline at end of file diff --git a/benchmarks/microbenchmarks/results/benchmark_int8wo_linear_m2048_k4096_n1024_results.json b/benchmarks/microbenchmarks/results/benchmark_int8wo_linear_m2048_k4096_n1024_results.json new file mode 100644 index 0000000000..d8f8a11a1d --- /dev/null +++ 
b/benchmarks/microbenchmarks/results/benchmark_int8wo_linear_m2048_k4096_n1024_results.json @@ -0,0 +1,3 @@ +{ + "benchmark_model_inference_in_seconds": 0.11659547500000009 +} \ No newline at end of file diff --git a/benchmarks/microbenchmarks/results/benchmark_int8wo_linear_m4096_k4096_n1024_results.json b/benchmarks/microbenchmarks/results/benchmark_int8wo_linear_m4096_k4096_n1024_results.json new file mode 100644 index 0000000000..64a616bd6a --- /dev/null +++ b/benchmarks/microbenchmarks/results/benchmark_int8wo_linear_m4096_k4096_n1024_results.json @@ -0,0 +1,3 @@ +{ + "benchmark_model_inference_in_seconds": 0.24648710819999983 +} \ No newline at end of file diff --git a/benchmarks/microbenchmarks/utils.py b/benchmarks/microbenchmarks/utils.py index e69de29bb2..5ff267f770 100644 --- a/benchmarks/microbenchmarks/utils.py +++ b/benchmarks/microbenchmarks/utils.py @@ -0,0 +1,520 @@ +import time +from typing import Optional, List, Dict, Any + +import torch +from torch.profiler import ProfilerActivity, profile + +from torchao.quantization import ( + MappingType, + PerRow, + PerTensor, + float8_dynamic_activation_float8_weight, + float8_weight_only, + fpx_weight_only, + int4_weight_only, + int8_dynamic_activation_int4_weight, + int8_dynamic_activation_int8_weight, + int8_weight_only, + quantize_, + uintx_weight_only, +) +from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_3, + TORCH_VERSION_AT_LEAST_2_5, + unwrap_tensor_subclass, +) + +try: + import triton + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + + +class BenchmarkConfig: + def __init__( + self, + quantization: str, + params: Dict[str, Any], + shape_name: str, + shape: List[int], + output_dir: str, + ): + self.quantization = quantization + self.m, self.k, self.n = shape + self.shape_name = shape_name + self.precision = self._parse_precision(params["precision"]) + self.compile = params.get("compile", False) + self.device = params.get("device", get_default_device()) + self.model_type = params.get("model_type", "linear") + self.output_dir = output_dir + self.name = f"benchmark_{self.quantization}_{self.model_type}_m{self.m}_k{self.k}_n{self.n}" + + @staticmethod + def _parse_precision(precision_str: str) -> torch.dtype: + """Convert string precision to torch dtype""" + return getattr(torch, precision_str.split(".")[-1]) + + def to_dict(self) -> Dict[str, Any]: + """Convert config to dictionary for main function""" + return { + "quantization": self.quantization, + "m": self.m, + "k": self.k, + "n": self.n, + "precision": self.precision, + "compile": self.compile, + "device": self.device, + "model_type": self.model_type, + "output_dir": self.output_dir, + } + + +# TODO: add more models +class ToyLinearModel(torch.nn.Module): + def __init__(self, k=64, n=32, dtype=torch.bfloat16): + super().__init__() + self.linear1 = torch.nn.Linear(k, n, bias=False).to(dtype) + + def forward(self, x): + x = self.linear1(x) + return x + + +class LNLinearSigmoid(torch.nn.Module): + def __init__(self, fc_dim1, fc_dim2, dtype=torch.bfloat16): + super().__init__() + self.ln = torch.nn.LayerNorm(fc_dim1, elementwise_affine=False) + self.fc = torch.nn.Linear(fc_dim1, fc_dim2, bias=False).to(dtype) + self.sigmoid = torch.nn.Sigmoid() + + def forward(self, x): + x = self.ln(x) + x = self.fc(x) + x = self.sigmoid(x) + return x + + +def get_default_device() -> str: + try: + if torch.cuda.is_available(): + return "cuda" + elif torch.xpu.is_available(): + return "xpu" + elif torch.backends.mps.is_available(): + return "mps" + except (AssertionError, 
AttributeError): + print("Warning: Running on CPU as no GPU support was found") + return "cpu" + + +def benchmark_func_call(func, *args, **kwargs): + start_time = time.time() + result = func(*args, **kwargs) + end_time = time.time() + elapsed_time = end_time - start_time + print(f"Time taken to run {func.__name__}: {elapsed_time:.2f} seconds") + return result + + +def ffn_only(mod, fqn): + return isinstance(mod, torch.nn.Linear) and "feed_forward" in fqn + + +def not_ffn_only(mod, fqn): + return isinstance(mod, torch.nn.Linear) and not ffn_only(mod, fqn) + + +def ffn_or_attn_only(mod, fqn): + return isinstance(mod, torch.nn.Linear) and ( + "feed_forward" in fqn or "attention" in fqn + ) + + +def quantize_model( + model: torch.nn.Module, + quantization: str, + **kwargs, +): + """Quantize a model inplace or return a new quantized model. + + Args: + model (torch.nn.Module): model to be quantized + quantization (str): quantization method to be used + kwargs: additional arguments to be passed to the quantization method + """ + if "int4wo" in quantization and not HAS_TRITON: + print("Warning: Triton not available, falling back to baseline") + return model + + # Define kwargs + sparsity = kwargs.get("sparsity", None) + precision = kwargs.get("precision", None) + + # Quantization techniques + if "baseline" in quantization: + return model + if "int8wo" in quantization: + quantize_(model, int8_weight_only()) + if "int8dq" in quantization: + if sparsity and "semi" in sparsity: + from torchao.dtypes import SemiSparseLayout + + quantize_( + model, + int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()), + filter_fn=ffn_only, + ) + quantize_( + model, int8_dynamic_activation_int8_weight(), filter_fn=not_ffn_only + ) + elif "int8dq_prefill_wo_decode" in quantization: + quantize_( + model, int8_dynamic_activation_int8_weight(weight_only_decode=True) + ) + else: + quantize_(model, int8_dynamic_activation_int8_weight()) + if "int4wo" in quantization: + use_hqq = False + if "hqq" in quantization: + use_hqq = True + group_size = int(quantization.split("-")[1]) + assert group_size in [ + 32, + 64, + 128, + 256, + ], f"int4wo group_size needs to be one of [32,64,128,256] but got {group_size}" + quantize_(model, int4_weight_only(group_size=group_size, use_hqq=use_hqq)) + elif "int8adq-int4w-symm" in quantization: + from torchao.dtypes import CutlassInt4PackedLayout + + quantize_( + model, + int8_dynamic_activation_int4_weight( + group_size=None, + mapping_type=MappingType.SYMMETRIC, + act_mapping_type=MappingType.SYMMETRIC, + layout=CutlassInt4PackedLayout(), + ), + ) + if "marlin" in quantization: + if "qqq" in quantization: + from torchao.dtypes import MarlinQQQLayout + + quantize_( + model, + int8_dynamic_activation_int4_weight( + group_size=128, + mapping_type=MappingType.SYMMETRIC, + act_mapping_type=MappingType.SYMMETRIC, + layout=MarlinQQQLayout(), + ), + ) + elif "semi" in sparsity: + from torchao.dtypes import MarlinSparseLayout + + quantize_( + model, + int4_weight_only(layout=MarlinSparseLayout()), + filter_fn=ffn_or_attn_only, + ) + if "fp6" in quantization: + quantize_(model, fpx_weight_only(3, 2)) + elif "embed-int8wo" in quantization: + quantize_( + model, + int8_weight_only(group_size=64), + filter_fn=lambda x, *args: isinstance(x, torch.nn.Embedding), + ) + elif "uintx" in quantization: + # uintx-nbits-group_size, e.g. 
"uintx-2-64" + if "hqq" in quantization: + # uintx-nbits-group_size-hqq + use_hqq = True + else: + use_hqq = False + _quant_args = quantization.split("-") + nbits = int(_quant_args[1]) + assert nbits >= 1 and nbits <= 8, "nbits must be 1 to 8" + _NBITS_TO_DTYPE = { + 1: torch.uint1, + 2: torch.uint2, + 3: torch.uint3, + 4: torch.uint4, + 5: torch.uint5, + 6: torch.uint6, + 7: torch.uint7, + 8: torch.uint8, + } + dtype = _NBITS_TO_DTYPE[nbits] + group_size = int(_quant_args[2]) + quantize_(model, uintx_weight_only(dtype, group_size, use_hqq=use_hqq)) + elif "int8_dynamic_activation_intx_weight" in quantization: + from torchao.experimental.quant_api import ( + int8_dynamic_activation_intx_weight, + ) + from torchao.quantization.granularity import PerGroup + + assert ( + precision == torch.float32 + ), "int8_dynamic_activation_intx_weight requires using precision=torch.float32" + + # Quantize model + _quant_args = quantization.split("-") + weight_dtype = getattr(torch, f"int{_quant_args[1]}") + granularity = PerGroup(int(_quant_args[2])) + has_weight_zeros = bool(_quant_args[3]) + quantize_( + model, + int8_dynamic_activation_intx_weight( + weight_dtype=weight_dtype, + granularity=granularity, + has_weight_zeros=has_weight_zeros, + ), + ) + elif "float8wo" in quantization: + quantize_(model, float8_weight_only()) + elif "float8dq" in quantization: + granularity = str(quantization.split("-")[-1]) + if granularity == "tensor": + granularity = PerTensor() + elif granularity == "row": + granularity = PerRow() + else: + granularity = PerTensor() + quantize_( + model, float8_dynamic_activation_float8_weight(granularity=granularity) + ) + else: + if not TORCH_VERSION_AT_LEAST_2_5: + unwrap_tensor_subclass(model) + return model + + +# Function to benchmark model evaluation - e2e eval run +def benchmark_model_inference_in_seconds(model, input_data): + # Returns model run time in seconds + if torch.cuda.is_available(): + torch.cuda.synchronize() + + # warm up + for _ in range(2): + model(input_data) + if torch.cuda.is_available(): + torch.cuda.synchronize() + + num_iters = 5 + start_time = time.perf_counter() + with torch.no_grad(): + for _ in range(num_iters): + _ = model(input_data) + if torch.cuda.is_available(): + torch.cuda.synchronize() + end_time = time.perf_counter() + + return (end_time - start_time) / num_iters + + +def benchmark_model_op_with_profiler_in_microseconds(model, input_data, op_name: str): + """Benchmarks model inference using PyTorch profiler to measure GPU kernel execution times. + + This function profiles the model execution and measures the time spent in specific GPU operations + versus overhead time. It performs warmup runs before profiling to ensure accurate measurements. 
+ + Args: + model (torch.nn.Module): PyTorch model to benchmark + input_data (torch.Tensor): Input tensor to run through the model + op_name (str): Name of the GPU operation to measure time for + + Returns: + tuple[float, float]: A tuple containing: + - gpu_op_time (float): Time spent in the specified GPU operation in microseconds + - gpu_overhead_time (float): Time spent in other GPU operations in microseconds + """ + # Warm up + for _ in range(2): + model(input_data) + + # Profile model execution + with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True + ) as prof: + with torch.no_grad(): + _ = model(input_data) + torch.cuda.synchronize() + + # Get event data from profiler + event_data = [ + (event.key, event.device_time) + for event in prof.key_averages() + if event.device_type == torch.autograd.DeviceType.CUDA + ] + + # Calculate op time and overhead time + gpu_op_time, gpu_overhead_time = 0, 0 + for event in event_data: + if op_name in event[0]: + gpu_op_time += event[1] + else: + gpu_overhead_time += event[1] + + return gpu_op_time, gpu_overhead_time + + +def create_model_and_input( + model_type: str, + m: int, + k: int, + n: int, + dtype: torch.dtype = torch.bfloat16, + device: str = get_default_device(), +): + """Create a model and input data for benchmarking. + + Args: + model_type (str): type of the model to be created + batch_size (int): batch size of the input data + device (str): device to run the model on + dtype (torch.dtype): data + m, k, n (int): dimensions of the model and input data + """ + if model_type == "linear": + model = ToyLinearModel(k, n, dtype).to(device) + input_data = torch.randn(m, k, device=device, dtype=dtype) + elif model_type == "ln_linear_sigmoid": + model = LNLinearSigmoid(k, n, dtype).to(device) + input_data = torch.randn(m, k, device=device, dtype=dtype) + else: + raise ValueError(f"Unknown model type: {model_type}") + return model, input_data + + +@torch.no_grad() +def benchmark_op_with_cuda_graph(op, *args, **kwargs): + """ + runs benchmark op(*args, **kwargs) avoiding torch.compile overhead + """ + rep = kwargs.pop("rep", 100) + warmup = kwargs.pop("warmup", 25) + with torch.no_grad(): + torch.cuda.synchronize() + stream = torch.cuda.Stream() + stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(stream): + op(*args, **kwargs) + stream.synchronize() + torch.cuda.current_stream().wait_stream(stream) + torch.cuda.synchronize() + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph, stream=stream): + op(*args, **kwargs) + if TORCH_VERSION_AT_LEAST_2_5: + from torch._inductor.runtime.benchmarking import benchmarker + + res = benchmarker.benchmark_gpu( + lambda: graph.replay(), warmup=warmup, rep=rep, return_mode="median" + ) + elif TORCH_VERSION_AT_LEAST_2_3: + from torch._inductor.runtime.runtime_utils import do_bench_gpu + + res = do_bench_gpu( + lambda: graph.replay(), warmup=warmup, rep=rep, return_mode="median" + ) + else: + from torch._inductor.utils import do_bench + + res = do_bench( + lambda: graph.replay(), warmup=warmup, rep=rep, return_mode="median" + ) + return res + + +def _is_interpolate_mode(mode): + if ( + isinstance(mode, list) + and mode[0] == "interpolate" + and len(mode) == 2 + and isinstance(mode[1], float) + ): + return True + return False + + +def clean_caches(): + import gc + + # Clear everything before starting + gc.collect() + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + + if hasattr(torch, '_dynamo'): + 
torch._dynamo.reset() + + +def get_name_to_shapes_iter( + shape_gen_name: str, + M: Optional[int], + K: Optional[int], + N: Optional[int], +): + if shape_gen_name == "llama": + assert ( + M == K == N == None + ), f"M, K, N arguments not supported for shape_gen_name {shape_gen_name}" + bsz, seq_len = 4, 4096 + M = bsz * seq_len + # LLaMa 2 70B single-node weight shapes + # assumes fused attn.wqkv and ffn.w13 + # source: https://fburl.com/gsheet/g8onr7rh + name_to_shapes_70b = { + "attn.wqkv": (M, 8192, 1280), + "attn.w0": (M, 1024, 8192), + "ffn.w13": (M, 8192, 7168), + "ffn.w2": (M, 3584, 8192), + } + return name_to_shapes_70b.items() + + elif shape_gen_name == "square": + assert ( + M == K == N == None + ), f"M, K, N arguments not supported for shape_gen_name {shape_gen_name}" + name_to_shapes = {} + min_power_of_2 = 8 # 256 + max_power_of_2 = 15 # 32,768 + for idx, power_of_2 in enumerate(range(min_power_of_2, max_power_of_2 + 1)): + val = 2**power_of_2 + name_to_shapes[idx] = val, val, val + return name_to_shapes.items() + + elif shape_gen_name == "sweep": + assert ( + M == K == N == None + ), f"M, K, N arguments not supported for shape_gen_name {shape_gen_name}" + name_to_shapes = {} + min_p2 = 8 # 256 + max_p2 = 15 # 32,768 + counter = 0 + for M_p2 in range(min_p2, max_p2 + 1): + M = 2**M_p2 + for K_p2 in range(min_p2, max_p2 + 1): + K = 2**K_p2 + for N_p2 in range(min_p2, max_p2 + 1): + N = 2**N_p2 + name_to_shapes[counter] = M, K, N + counter += 1 + return name_to_shapes.items() + + elif shape_gen_name == "custom": + assert ( + M is not None and K is not None and N is not None + ), "M, K, N must be specified for custom shape_gen" + name_to_shapes = { + 1: (M, K, N), + } + return name_to_shapes.items() + + raise AssertionError(f"unknown shape_gen_name {shape_gen_name}") \ No newline at end of file From a56f3abc9c249637b868a56ccab79783d775d267 Mon Sep 17 00:00:00 2001 From: jainapurva Date: Tue, 25 Feb 2025 11:00:37 -0800 Subject: [PATCH 03/14] Update new quantize_ api --- .../microbenchmarks/benchmark_config.yml | 2 +- .../microbenchmarks/benchmark_inference.py | 17 +++--- .../microbenchmarks/benchmark_runner.py | 9 +++- benchmarks/microbenchmarks/utils.py | 52 ++++++++++--------- 4 files changed, 44 insertions(+), 36 deletions(-) diff --git a/benchmarks/microbenchmarks/benchmark_config.yml b/benchmarks/microbenchmarks/benchmark_config.yml index 003c43da0a..b8ed4b05bb 100644 --- a/benchmarks/microbenchmarks/benchmark_config.yml +++ b/benchmarks/microbenchmarks/benchmark_config.yml @@ -15,5 +15,5 @@ model_params: ] precision: "torch.bfloat16" compile: false - device: "cpu" # Change this to "cuda", "mps", "xpu", or "cpu" as needed + device: "cuda" # Change this to "cuda", "mps", "xpu", or "cpu" as needed model_type: "linear" diff --git a/benchmarks/microbenchmarks/benchmark_inference.py b/benchmarks/microbenchmarks/benchmark_inference.py index d57dadcac6..ce52cb4e96 100644 --- a/benchmarks/microbenchmarks/benchmark_inference.py +++ b/benchmarks/microbenchmarks/benchmark_inference.py @@ -7,17 +7,18 @@ from copy import deepcopy import json from pathlib import Path +from typing import Dict import torch from utils import ( - benchmark_model_inference_in_seconds, + benchmark_model_inference_in_microseconds, clean_caches, create_model_and_input, quantize_model, BenchmarkConfig, ) -def run(config: BenchmarkConfig) -> None: +def run(config: BenchmarkConfig) -> Dict[str, float]: """Run inference benchmarks""" clean_caches() # Clean caches @@ -48,23 +49,23 @@ def run(config: BenchmarkConfig) -> 
None: results = {} # Benchmark time to run an inference call for quantized model - model_time = benchmark_model_inference_in_seconds( + model_time = benchmark_model_inference_in_microseconds( model=m_copy, input_data=input_data ) - results[f"benchmark_model_inference_in_seconds"] = model_time + results[f"benchmark_model_inference_in_microseconds"] = model_time print( - f"Time to run a {base_model.__class__.__name__}: {model_time:.2f} seconds quantized with {config.quantization}" + f"Time to run a {base_model.__class__.__name__}: {model_time:.2f} microseconds quantized with {config.quantization}" ) - # 2. Benchmark time using profiler + # TODO: Benchmark time using profiler # Profile dtype model evaluation # prof_dtype = benchmark_model_op_with_profiler_in_microseconds(m_copy, input_data, quantized_dtype) # prof_dtype.export_chrome_trace(f"{quantization}_model_{input_data[0].size()[0]}.json") # Save profiling details - # 3. Benchmark gemm time using cuda graph + # TODO: Benchmark gemm time using cuda graph # gemm_time = benchmark_torch_function_in_microseconds(gemm_op, *args, **kwargs) - # 4. Benchmark op with cuda graph + # TODO: Benchmark op with cuda graph # time = benchmark_op_with_cuda_graph(op, args) return results diff --git a/benchmarks/microbenchmarks/benchmark_runner.py b/benchmarks/microbenchmarks/benchmark_runner.py index 0303dc0ba3..9a79bd683a 100644 --- a/benchmarks/microbenchmarks/benchmark_runner.py +++ b/benchmarks/microbenchmarks/benchmark_runner.py @@ -67,8 +67,13 @@ def run_benchmarks_from_config(config_path: str) -> None: results.append(result) # TODO: Convert results to csv - # 1. For different shapes for same model + # Speedups: + # 1. For different shapes for same model and quantization # 2. For different quantizations for same model and shape + # 3. 
For different models for same quantization + + + if __name__ == "__main__": @@ -82,4 +87,4 @@ def run_benchmarks_from_config(config_path: str) -> None: help="Path to benchmark configuration file", ) args = parser.parse_args() - run_benchmarks_from_config(args.config) \ No newline at end of file + run_benchmarks_from_config(args.config) diff --git a/benchmarks/microbenchmarks/utils.py b/benchmarks/microbenchmarks/utils.py index 5ff267f770..738c30cfbd 100644 --- a/benchmarks/microbenchmarks/utils.py +++ b/benchmarks/microbenchmarks/utils.py @@ -8,16 +8,18 @@ MappingType, PerRow, PerTensor, - float8_dynamic_activation_float8_weight, - float8_weight_only, - fpx_weight_only, - int4_weight_only, - int8_dynamic_activation_int4_weight, - int8_dynamic_activation_int8_weight, - int8_weight_only, quantize_, - uintx_weight_only, + Float8DynamicActivationFloat8WeightConfig, + Float8WeightOnlyConfig, + FPXWeightOnlyConfig, + Int4WeightOnlyConfig, + Int8DynamicActivationInt4WeightConfig, + Int8DynamicActivationInt8WeightConfig, + Int4DynamicActivationInt4WeightConfig, + Int8WeightOnlyConfig, + UIntXWeightOnlyConfig, ) +from torchao.quantization.quant_api import Float8WeightOnlyConfig from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_5, @@ -155,25 +157,25 @@ def quantize_model( if "baseline" in quantization: return model if "int8wo" in quantization: - quantize_(model, int8_weight_only()) + quantize_(model, Int8WeightOnlyConfig()) if "int8dq" in quantization: if sparsity and "semi" in sparsity: from torchao.dtypes import SemiSparseLayout quantize_( model, - int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()), + Int8DynamicActivationInt8WeightConfig(layout=SemiSparseLayout()), filter_fn=ffn_only, ) quantize_( - model, int8_dynamic_activation_int8_weight(), filter_fn=not_ffn_only + model, Int8DynamicActivationInt8WeightConfig(), filter_fn=not_ffn_only ) elif "int8dq_prefill_wo_decode" in quantization: quantize_( - model, int8_dynamic_activation_int8_weight(weight_only_decode=True) + model, Int8DynamicActivationInt8WeightConfig(weight_only_decode=True) ) else: - quantize_(model, int8_dynamic_activation_int8_weight()) + quantize_(model, Int8DynamicActivationInt8WeightConfig()) if "int4wo" in quantization: use_hqq = False if "hqq" in quantization: @@ -185,13 +187,13 @@ def quantize_model( 128, 256, ], f"int4wo group_size needs to be one of [32,64,128,256] but got {group_size}" - quantize_(model, int4_weight_only(group_size=group_size, use_hqq=use_hqq)) + quantize_(model, Int4WeightOnlyConfig(group_size=group_size, use_hqq=use_hqq)) elif "int8adq-int4w-symm" in quantization: from torchao.dtypes import CutlassInt4PackedLayout quantize_( model, - int8_dynamic_activation_int4_weight( + Int8DynamicActivationInt4WeightConfig( group_size=None, mapping_type=MappingType.SYMMETRIC, act_mapping_type=MappingType.SYMMETRIC, @@ -204,7 +206,7 @@ def quantize_model( quantize_( model, - int8_dynamic_activation_int4_weight( + Int8DynamicActivationInt4WeightConfig( group_size=128, mapping_type=MappingType.SYMMETRIC, act_mapping_type=MappingType.SYMMETRIC, @@ -216,15 +218,15 @@ def quantize_model( quantize_( model, - int4_weight_only(layout=MarlinSparseLayout()), + Int4WeightOnlyConfig(layout=MarlinSparseLayout()), filter_fn=ffn_or_attn_only, ) if "fp6" in quantization: - quantize_(model, fpx_weight_only(3, 2)) + quantize_(model, FPXWeightOnlyConfig(3, 2)) elif "embed-int8wo" in quantization: quantize_( model, - int8_weight_only(group_size=64), + Int8WeightOnlyConfig(group_size=64), 
filter_fn=lambda x, *args: isinstance(x, torch.nn.Embedding), ) elif "uintx" in quantization: @@ -249,7 +251,7 @@ def quantize_model( } dtype = _NBITS_TO_DTYPE[nbits] group_size = int(_quant_args[2]) - quantize_(model, uintx_weight_only(dtype, group_size, use_hqq=use_hqq)) + quantize_(model, UIntXWeightOnlyConfig(dtype, group_size, use_hqq=use_hqq)) elif "int8_dynamic_activation_intx_weight" in quantization: from torchao.experimental.quant_api import ( int8_dynamic_activation_intx_weight, @@ -274,7 +276,7 @@ def quantize_model( ), ) elif "float8wo" in quantization: - quantize_(model, float8_weight_only()) + quantize_(model, Float8WeightOnlyConfig()) elif "float8dq" in quantization: granularity = str(quantization.split("-")[-1]) if granularity == "tensor": @@ -284,7 +286,7 @@ def quantize_model( else: granularity = PerTensor() quantize_( - model, float8_dynamic_activation_float8_weight(granularity=granularity) + model, Float8DynamicActivationFloat8WeightConfig(granularity=granularity) ) else: if not TORCH_VERSION_AT_LEAST_2_5: @@ -293,7 +295,7 @@ def quantize_model( # Function to benchmark model evaluation - e2e eval run -def benchmark_model_inference_in_seconds(model, input_data): +def benchmark_model_inference_in_microseconds(model, input_data): # Returns model run time in seconds if torch.cuda.is_available(): torch.cuda.synchronize() @@ -313,7 +315,7 @@ def benchmark_model_inference_in_seconds(model, input_data): torch.cuda.synchronize() end_time = time.perf_counter() - return (end_time - start_time) / num_iters + return ((end_time - start_time) / num_iters) * 1e6 def benchmark_model_op_with_profiler_in_microseconds(model, input_data, op_name: str): @@ -517,4 +519,4 @@ def get_name_to_shapes_iter( } return name_to_shapes.items() - raise AssertionError(f"unknown shape_gen_name {shape_gen_name}") \ No newline at end of file + raise AssertionError(f"unknown shape_gen_name {shape_gen_name}") From 24bea429b22f9a2215fc1756844a2af0daf5cadf Mon Sep 17 00:00:00 2001 From: jainapurva Date: Tue, 25 Feb 2025 12:12:39 -0800 Subject: [PATCH 04/14] Updates --- .../microbenchmarks/benchmark_config.yml | 2 +- .../microbenchmarks/benchmark_inference.py | 25 +- .../microbenchmarks/benchmark_runner.py | 32 +-- .../microbenchmarks/benchmark_training.py | 4 +- ...line_custom_m1024_k1024_n1024_results.json | 3 - ...line_custom_m2048_k4096_n1024_results.json | 3 - ...line_custom_m4096_k4096_n1024_results.json | 3 - ...line_linear_m1024_k1024_n1024_results.json | 3 - ...line_linear_m2048_k4096_n1024_results.json | 3 - ...line_linear_m4096_k4096_n1024_results.json | 3 - ...-hqq_custom_m1024_k1024_n1024_results.json | 3 - ...-hqq_custom_m2048_k4096_n1024_results.json | 3 - ...-hqq_custom_m4096_k4096_n1024_results.json | 3 - ...-hqq_linear_m1024_k1024_n1024_results.json | 3 - ...-hqq_linear_m2048_k4096_n1024_results.json | 3 - ...-hqq_linear_m4096_k4096_n1024_results.json | 3 - ...-128_custom_m1024_k1024_n1024_results.json | 3 - ...-128_custom_m2048_k4096_n1024_results.json | 3 - ...-128_custom_m4096_k4096_n1024_results.json | 3 - ...-128_linear_m1024_k1024_n1024_results.json | 3 - ...-128_linear_m2048_k4096_n1024_results.json | 3 - ...-128_linear_m4096_k4096_n1024_results.json | 3 - ...t8wo_custom_m1024_k1024_n1024_results.json | 3 - ...t8wo_custom_m2048_k4096_n1024_results.json | 3 - ...t8wo_custom_m4096_k4096_n1024_results.json | 3 - ...t8wo_linear_m1024_k1024_n1024_results.json | 3 - ...t8wo_linear_m2048_k4096_n1024_results.json | 3 - ...t8wo_linear_m4096_k4096_n1024_results.json | 3 - 
benchmarks/microbenchmarks/utils.py | 232 ++++-------------- 29 files changed, 74 insertions(+), 293 deletions(-) delete mode 100644 benchmarks/microbenchmarks/results/benchmark_baseline_custom_m1024_k1024_n1024_results.json delete mode 100644 benchmarks/microbenchmarks/results/benchmark_baseline_custom_m2048_k4096_n1024_results.json delete mode 100644 benchmarks/microbenchmarks/results/benchmark_baseline_custom_m4096_k4096_n1024_results.json delete mode 100644 benchmarks/microbenchmarks/results/benchmark_baseline_linear_m1024_k1024_n1024_results.json delete mode 100644 benchmarks/microbenchmarks/results/benchmark_baseline_linear_m2048_k4096_n1024_results.json delete mode 100644 benchmarks/microbenchmarks/results/benchmark_baseline_linear_m4096_k4096_n1024_results.json delete mode 100644 benchmarks/microbenchmarks/results/benchmark_int4wo-128-hqq_custom_m1024_k1024_n1024_results.json delete mode 100644 benchmarks/microbenchmarks/results/benchmark_int4wo-128-hqq_custom_m2048_k4096_n1024_results.json delete mode 100644 benchmarks/microbenchmarks/results/benchmark_int4wo-128-hqq_custom_m4096_k4096_n1024_results.json delete mode 100644 benchmarks/microbenchmarks/results/benchmark_int4wo-128-hqq_linear_m1024_k1024_n1024_results.json delete mode 100644 benchmarks/microbenchmarks/results/benchmark_int4wo-128-hqq_linear_m2048_k4096_n1024_results.json delete mode 100644 benchmarks/microbenchmarks/results/benchmark_int4wo-128-hqq_linear_m4096_k4096_n1024_results.json delete mode 100644 benchmarks/microbenchmarks/results/benchmark_int4wo-128_custom_m1024_k1024_n1024_results.json delete mode 100644 benchmarks/microbenchmarks/results/benchmark_int4wo-128_custom_m2048_k4096_n1024_results.json delete mode 100644 benchmarks/microbenchmarks/results/benchmark_int4wo-128_custom_m4096_k4096_n1024_results.json delete mode 100644 benchmarks/microbenchmarks/results/benchmark_int4wo-128_linear_m1024_k1024_n1024_results.json delete mode 100644 benchmarks/microbenchmarks/results/benchmark_int4wo-128_linear_m2048_k4096_n1024_results.json delete mode 100644 benchmarks/microbenchmarks/results/benchmark_int4wo-128_linear_m4096_k4096_n1024_results.json delete mode 100644 benchmarks/microbenchmarks/results/benchmark_int8wo_custom_m1024_k1024_n1024_results.json delete mode 100644 benchmarks/microbenchmarks/results/benchmark_int8wo_custom_m2048_k4096_n1024_results.json delete mode 100644 benchmarks/microbenchmarks/results/benchmark_int8wo_custom_m4096_k4096_n1024_results.json delete mode 100644 benchmarks/microbenchmarks/results/benchmark_int8wo_linear_m1024_k1024_n1024_results.json delete mode 100644 benchmarks/microbenchmarks/results/benchmark_int8wo_linear_m2048_k4096_n1024_results.json delete mode 100644 benchmarks/microbenchmarks/results/benchmark_int8wo_linear_m4096_k4096_n1024_results.json diff --git a/benchmarks/microbenchmarks/benchmark_config.yml b/benchmarks/microbenchmarks/benchmark_config.yml index b8ed4b05bb..a7a2b4c017 100644 --- a/benchmarks/microbenchmarks/benchmark_config.yml +++ b/benchmarks/microbenchmarks/benchmark_config.yml @@ -14,6 +14,6 @@ model_params: [4096, 4096, 1024] ] precision: "torch.bfloat16" - compile: false + compile: "max-autotune" device: "cuda" # Change this to "cuda", "mps", "xpu", or "cpu" as needed model_type: "linear" diff --git a/benchmarks/microbenchmarks/benchmark_inference.py b/benchmarks/microbenchmarks/benchmark_inference.py index ce52cb4e96..4bd171e376 100644 --- a/benchmarks/microbenchmarks/benchmark_inference.py +++ b/benchmarks/microbenchmarks/benchmark_inference.py @@ 
-4,27 +4,28 @@ This script runs inference benchmarks and generates a micro-benchmarking report for it. - run() function is the main entry point for running inference benchmarks. """ + from copy import deepcopy -import json from pathlib import Path from typing import Dict import torch from utils import ( + BenchmarkConfig, benchmark_model_inference_in_microseconds, clean_caches, create_model_and_input, quantize_model, - BenchmarkConfig, ) + def run(config: BenchmarkConfig) -> Dict[str, float]: """Run inference benchmarks""" clean_caches() # Clean caches - + # Create output directory if it doesn't exist Path(config.output_dir).mkdir(parents=True, exist_ok=True) - + base_model, input_data = create_model_and_input( config.model_type, config.m, @@ -33,10 +34,7 @@ def run(config: BenchmarkConfig) -> Dict[str, float]: dtype=config.precision, device=config.device, ) - print( - f"Starting benchmarking for model: {base_model.__class__.__name__} for quantization: {config.quantization}" - ) - + # Use quantize_ to apply each quantization function to the model m_copy = deepcopy(base_model).eval().to(config.device) m_copy = quantize_model(m_copy, config.quantization) @@ -46,16 +44,13 @@ def run(config: BenchmarkConfig) -> Dict[str, float]: m_copy = torch.compile(m_copy, mode=config.compile, fullgraph=True) # Run benchmarks - results = {} - + result = {**config.__dict__} + # Benchmark time to run an inference call for quantized model model_time = benchmark_model_inference_in_microseconds( model=m_copy, input_data=input_data ) - results[f"benchmark_model_inference_in_microseconds"] = model_time - print( - f"Time to run a {base_model.__class__.__name__}: {model_time:.2f} microseconds quantized with {config.quantization}" - ) + result["benchmark_model_inference_in_microseconds"] = model_time # TODO: Benchmark time using profiler # Profile dtype model evaluation @@ -68,4 +63,4 @@ def run(config: BenchmarkConfig) -> Dict[str, float]: # TODO: Benchmark op with cuda graph # time = benchmark_op_with_cuda_graph(op, args) - return results + return result diff --git a/benchmarks/microbenchmarks/benchmark_runner.py b/benchmarks/microbenchmarks/benchmark_runner.py index 9a79bd683a..1c60497b8f 100644 --- a/benchmarks/microbenchmarks/benchmark_runner.py +++ b/benchmarks/microbenchmarks/benchmark_runner.py @@ -13,13 +13,13 @@ The YAML file should contain all necessary configuration parameters for the benchmarks. 
""" + from itertools import product from typing import Any, Dict, List, Tuple -import torch import yaml +from utils import BenchmarkConfig, generate_results_csv -from utils import BenchmarkConfig def get_shapes_for_config(shape_config: Dict[str, Any]) -> List[Tuple[str, List[int]]]: """Get shapes for a given configuration""" @@ -43,13 +43,16 @@ def load_benchmark_configs(config_path: str) -> List[BenchmarkConfig]: shapes = get_shapes_for_config(shape_config) # Generate combinations for each shape for quant, (shape_name, shape) in product(quantizations, shapes): - configs.append(BenchmarkConfig( - quantization=quant, - params=params, - shape_name=shape_name, - shape=shape, - output_dir=output_dir, - )) + configs.append( + BenchmarkConfig( + quantization=quant, + params=params, + shape_name=shape_name, + shape=shape, + output_dir=output_dir, + ) + ) + print("Configs: ", configs[0].__dict__) return configs @@ -60,22 +63,21 @@ def run_benchmarks_from_config(config_path: str) -> None: configs = load_benchmark_configs(config_path) results = [] - print(f"Benchmarking Inference ......") + print("Benchmarking Inference ......") for config in configs: print(f"Running: {config.name}") result = run_inference(config) # Pass the config object directly results.append(result) - # TODO: Convert results to csv - # Speedups: + # Add results to csv + generate_results_csv(results, configs[0].output_dir) + + # TODO: Process results: Speedups: # 1. For different shapes for same model and quantization # 2. For different quantizations for same model and shape # 3. For different models for same quantization - - - if __name__ == "__main__": import argparse diff --git a/benchmarks/microbenchmarks/benchmark_training.py b/benchmarks/microbenchmarks/benchmark_training.py index 604f44f489..08e624976c 100644 --- a/benchmarks/microbenchmarks/benchmark_training.py +++ b/benchmarks/microbenchmarks/benchmark_training.py @@ -10,4 +10,6 @@ def run(config: BenchmarkConfig) -> None: """Run training benchmarks""" - raise NotImplementedError("Training benchmarks are not implemented yet. This is a placeholder function.") + raise NotImplementedError( + "Training benchmarks are not implemented yet. This is a placeholder function." 
+ ) diff --git a/benchmarks/microbenchmarks/results/benchmark_baseline_custom_m1024_k1024_n1024_results.json b/benchmarks/microbenchmarks/results/benchmark_baseline_custom_m1024_k1024_n1024_results.json deleted file mode 100644 index 1abee3d60b..0000000000 --- a/benchmarks/microbenchmarks/results/benchmark_baseline_custom_m1024_k1024_n1024_results.json +++ /dev/null @@ -1,3 +0,0 @@ -{ - "inference_time_us": 15562.51660000001 -} \ No newline at end of file diff --git a/benchmarks/microbenchmarks/results/benchmark_baseline_custom_m2048_k4096_n1024_results.json b/benchmarks/microbenchmarks/results/benchmark_baseline_custom_m2048_k4096_n1024_results.json deleted file mode 100644 index 14c88528fa..0000000000 --- a/benchmarks/microbenchmarks/results/benchmark_baseline_custom_m2048_k4096_n1024_results.json +++ /dev/null @@ -1,3 +0,0 @@ -{ - "inference_time_us": 110424.1749999999 -} \ No newline at end of file diff --git a/benchmarks/microbenchmarks/results/benchmark_baseline_custom_m4096_k4096_n1024_results.json b/benchmarks/microbenchmarks/results/benchmark_baseline_custom_m4096_k4096_n1024_results.json deleted file mode 100644 index d1ce48c394..0000000000 --- a/benchmarks/microbenchmarks/results/benchmark_baseline_custom_m4096_k4096_n1024_results.json +++ /dev/null @@ -1,3 +0,0 @@ -{ - "inference_time_us": 242173.30819999994 -} \ No newline at end of file diff --git a/benchmarks/microbenchmarks/results/benchmark_baseline_linear_m1024_k1024_n1024_results.json b/benchmarks/microbenchmarks/results/benchmark_baseline_linear_m1024_k1024_n1024_results.json deleted file mode 100644 index 24ce509670..0000000000 --- a/benchmarks/microbenchmarks/results/benchmark_baseline_linear_m1024_k1024_n1024_results.json +++ /dev/null @@ -1,3 +0,0 @@ -{ - "benchmark_model_inference_in_seconds": 0.013309416799999951 -} \ No newline at end of file diff --git a/benchmarks/microbenchmarks/results/benchmark_baseline_linear_m2048_k4096_n1024_results.json b/benchmarks/microbenchmarks/results/benchmark_baseline_linear_m2048_k4096_n1024_results.json deleted file mode 100644 index 736461b641..0000000000 --- a/benchmarks/microbenchmarks/results/benchmark_baseline_linear_m2048_k4096_n1024_results.json +++ /dev/null @@ -1,3 +0,0 @@ -{ - "benchmark_model_inference_in_seconds": 0.11298482500000002 -} \ No newline at end of file diff --git a/benchmarks/microbenchmarks/results/benchmark_baseline_linear_m4096_k4096_n1024_results.json b/benchmarks/microbenchmarks/results/benchmark_baseline_linear_m4096_k4096_n1024_results.json deleted file mode 100644 index 67d64673c7..0000000000 --- a/benchmarks/microbenchmarks/results/benchmark_baseline_linear_m4096_k4096_n1024_results.json +++ /dev/null @@ -1,3 +0,0 @@ -{ - "benchmark_model_inference_in_seconds": 0.24748350000000005 -} \ No newline at end of file diff --git a/benchmarks/microbenchmarks/results/benchmark_int4wo-128-hqq_custom_m1024_k1024_n1024_results.json b/benchmarks/microbenchmarks/results/benchmark_int4wo-128-hqq_custom_m1024_k1024_n1024_results.json deleted file mode 100644 index c1f3d71203..0000000000 --- a/benchmarks/microbenchmarks/results/benchmark_int4wo-128-hqq_custom_m1024_k1024_n1024_results.json +++ /dev/null @@ -1,3 +0,0 @@ -{ - "inference_time_us": 14940.825000000046 -} \ No newline at end of file diff --git a/benchmarks/microbenchmarks/results/benchmark_int4wo-128-hqq_custom_m2048_k4096_n1024_results.json b/benchmarks/microbenchmarks/results/benchmark_int4wo-128-hqq_custom_m2048_k4096_n1024_results.json deleted file mode 100644 index f1769826ef..0000000000 --- 
a/benchmarks/microbenchmarks/results/benchmark_int4wo-128-hqq_custom_m2048_k4096_n1024_results.json +++ /dev/null @@ -1,3 +0,0 @@ -{ - "inference_time_us": 122228.56680000013 -} \ No newline at end of file diff --git a/benchmarks/microbenchmarks/results/benchmark_int4wo-128-hqq_custom_m4096_k4096_n1024_results.json b/benchmarks/microbenchmarks/results/benchmark_int4wo-128-hqq_custom_m4096_k4096_n1024_results.json deleted file mode 100644 index b0bc0a8a45..0000000000 --- a/benchmarks/microbenchmarks/results/benchmark_int4wo-128-hqq_custom_m4096_k4096_n1024_results.json +++ /dev/null @@ -1,3 +0,0 @@ -{ - "inference_time_us": 260059.6167999999 -} \ No newline at end of file diff --git a/benchmarks/microbenchmarks/results/benchmark_int4wo-128-hqq_linear_m1024_k1024_n1024_results.json b/benchmarks/microbenchmarks/results/benchmark_int4wo-128-hqq_linear_m1024_k1024_n1024_results.json deleted file mode 100644 index 8f16ac0420..0000000000 --- a/benchmarks/microbenchmarks/results/benchmark_int4wo-128-hqq_linear_m1024_k1024_n1024_results.json +++ /dev/null @@ -1,3 +0,0 @@ -{ - "benchmark_model_inference_in_seconds": 0.018115591599999804 -} \ No newline at end of file diff --git a/benchmarks/microbenchmarks/results/benchmark_int4wo-128-hqq_linear_m2048_k4096_n1024_results.json b/benchmarks/microbenchmarks/results/benchmark_int4wo-128-hqq_linear_m2048_k4096_n1024_results.json deleted file mode 100644 index a595902e73..0000000000 --- a/benchmarks/microbenchmarks/results/benchmark_int4wo-128-hqq_linear_m2048_k4096_n1024_results.json +++ /dev/null @@ -1,3 +0,0 @@ -{ - "benchmark_model_inference_in_seconds": 0.11708445820000009 -} \ No newline at end of file diff --git a/benchmarks/microbenchmarks/results/benchmark_int4wo-128-hqq_linear_m4096_k4096_n1024_results.json b/benchmarks/microbenchmarks/results/benchmark_int4wo-128-hqq_linear_m4096_k4096_n1024_results.json deleted file mode 100644 index cf3671aae2..0000000000 --- a/benchmarks/microbenchmarks/results/benchmark_int4wo-128-hqq_linear_m4096_k4096_n1024_results.json +++ /dev/null @@ -1,3 +0,0 @@ -{ - "benchmark_model_inference_in_seconds": 0.24722661660000006 -} \ No newline at end of file diff --git a/benchmarks/microbenchmarks/results/benchmark_int4wo-128_custom_m1024_k1024_n1024_results.json b/benchmarks/microbenchmarks/results/benchmark_int4wo-128_custom_m1024_k1024_n1024_results.json deleted file mode 100644 index 16521f36b6..0000000000 --- a/benchmarks/microbenchmarks/results/benchmark_int4wo-128_custom_m1024_k1024_n1024_results.json +++ /dev/null @@ -1,3 +0,0 @@ -{ - "inference_time_us": 14625.07500000001 -} \ No newline at end of file diff --git a/benchmarks/microbenchmarks/results/benchmark_int4wo-128_custom_m2048_k4096_n1024_results.json b/benchmarks/microbenchmarks/results/benchmark_int4wo-128_custom_m2048_k4096_n1024_results.json deleted file mode 100644 index 163625d2c9..0000000000 --- a/benchmarks/microbenchmarks/results/benchmark_int4wo-128_custom_m2048_k4096_n1024_results.json +++ /dev/null @@ -1,3 +0,0 @@ -{ - "inference_time_us": 115082.14999999992 -} \ No newline at end of file diff --git a/benchmarks/microbenchmarks/results/benchmark_int4wo-128_custom_m4096_k4096_n1024_results.json b/benchmarks/microbenchmarks/results/benchmark_int4wo-128_custom_m4096_k4096_n1024_results.json deleted file mode 100644 index 19379b2b00..0000000000 --- a/benchmarks/microbenchmarks/results/benchmark_int4wo-128_custom_m4096_k4096_n1024_results.json +++ /dev/null @@ -1,3 +0,0 @@ -{ - "inference_time_us": 275312.9167999998 -} \ No newline at end of file 
diff --git a/benchmarks/microbenchmarks/results/benchmark_int4wo-128_linear_m1024_k1024_n1024_results.json b/benchmarks/microbenchmarks/results/benchmark_int4wo-128_linear_m1024_k1024_n1024_results.json deleted file mode 100644 index c69854bb60..0000000000 --- a/benchmarks/microbenchmarks/results/benchmark_int4wo-128_linear_m1024_k1024_n1024_results.json +++ /dev/null @@ -1,3 +0,0 @@ -{ - "benchmark_model_inference_in_seconds": 0.013688116599999845 -} \ No newline at end of file diff --git a/benchmarks/microbenchmarks/results/benchmark_int4wo-128_linear_m2048_k4096_n1024_results.json b/benchmarks/microbenchmarks/results/benchmark_int4wo-128_linear_m2048_k4096_n1024_results.json deleted file mode 100644 index 658fe43027..0000000000 --- a/benchmarks/microbenchmarks/results/benchmark_int4wo-128_linear_m2048_k4096_n1024_results.json +++ /dev/null @@ -1,3 +0,0 @@ -{ - "benchmark_model_inference_in_seconds": 0.1140080165999997 -} \ No newline at end of file diff --git a/benchmarks/microbenchmarks/results/benchmark_int4wo-128_linear_m4096_k4096_n1024_results.json b/benchmarks/microbenchmarks/results/benchmark_int4wo-128_linear_m4096_k4096_n1024_results.json deleted file mode 100644 index ae6d7ae357..0000000000 --- a/benchmarks/microbenchmarks/results/benchmark_int4wo-128_linear_m4096_k4096_n1024_results.json +++ /dev/null @@ -1,3 +0,0 @@ -{ - "benchmark_model_inference_in_seconds": 0.2766437331999999 -} \ No newline at end of file diff --git a/benchmarks/microbenchmarks/results/benchmark_int8wo_custom_m1024_k1024_n1024_results.json b/benchmarks/microbenchmarks/results/benchmark_int8wo_custom_m1024_k1024_n1024_results.json deleted file mode 100644 index c3a369b4e3..0000000000 --- a/benchmarks/microbenchmarks/results/benchmark_int8wo_custom_m1024_k1024_n1024_results.json +++ /dev/null @@ -1,3 +0,0 @@ -{ - "inference_time_us": 14606.591800000146 -} \ No newline at end of file diff --git a/benchmarks/microbenchmarks/results/benchmark_int8wo_custom_m2048_k4096_n1024_results.json b/benchmarks/microbenchmarks/results/benchmark_int8wo_custom_m2048_k4096_n1024_results.json deleted file mode 100644 index 5262ffae0e..0000000000 --- a/benchmarks/microbenchmarks/results/benchmark_int8wo_custom_m2048_k4096_n1024_results.json +++ /dev/null @@ -1,3 +0,0 @@ -{ - "inference_time_us": 120721.20819999998 -} \ No newline at end of file diff --git a/benchmarks/microbenchmarks/results/benchmark_int8wo_custom_m4096_k4096_n1024_results.json b/benchmarks/microbenchmarks/results/benchmark_int8wo_custom_m4096_k4096_n1024_results.json deleted file mode 100644 index e75a3f832b..0000000000 --- a/benchmarks/microbenchmarks/results/benchmark_int8wo_custom_m4096_k4096_n1024_results.json +++ /dev/null @@ -1,3 +0,0 @@ -{ - "inference_time_us": 266769.6750000001 -} \ No newline at end of file diff --git a/benchmarks/microbenchmarks/results/benchmark_int8wo_linear_m1024_k1024_n1024_results.json b/benchmarks/microbenchmarks/results/benchmark_int8wo_linear_m1024_k1024_n1024_results.json deleted file mode 100644 index b023cdea29..0000000000 --- a/benchmarks/microbenchmarks/results/benchmark_int8wo_linear_m1024_k1024_n1024_results.json +++ /dev/null @@ -1,3 +0,0 @@ -{ - "benchmark_model_inference_in_seconds": 0.01730855819999988 -} \ No newline at end of file diff --git a/benchmarks/microbenchmarks/results/benchmark_int8wo_linear_m2048_k4096_n1024_results.json b/benchmarks/microbenchmarks/results/benchmark_int8wo_linear_m2048_k4096_n1024_results.json deleted file mode 100644 index d8f8a11a1d..0000000000 --- 
a/benchmarks/microbenchmarks/results/benchmark_int8wo_linear_m2048_k4096_n1024_results.json +++ /dev/null @@ -1,3 +0,0 @@ -{ - "benchmark_model_inference_in_seconds": 0.11659547500000009 -} \ No newline at end of file diff --git a/benchmarks/microbenchmarks/results/benchmark_int8wo_linear_m4096_k4096_n1024_results.json b/benchmarks/microbenchmarks/results/benchmark_int8wo_linear_m4096_k4096_n1024_results.json deleted file mode 100644 index 64a616bd6a..0000000000 --- a/benchmarks/microbenchmarks/results/benchmark_int8wo_linear_m4096_k4096_n1024_results.json +++ /dev/null @@ -1,3 +0,0 @@ -{ - "benchmark_model_inference_in_seconds": 0.24648710819999983 -} \ No newline at end of file diff --git a/benchmarks/microbenchmarks/utils.py b/benchmarks/microbenchmarks/utils.py index 738c30cfbd..44c1f0782d 100644 --- a/benchmarks/microbenchmarks/utils.py +++ b/benchmarks/microbenchmarks/utils.py @@ -1,33 +1,32 @@ +import csv +import os import time -from typing import Optional, List, Dict, Any +from typing import Any, Dict, List import torch -from torch.profiler import ProfilerActivity, profile from torchao.quantization import ( - MappingType, - PerRow, - PerTensor, - quantize_, Float8DynamicActivationFloat8WeightConfig, Float8WeightOnlyConfig, FPXWeightOnlyConfig, Int4WeightOnlyConfig, Int8DynamicActivationInt4WeightConfig, Int8DynamicActivationInt8WeightConfig, - Int4DynamicActivationInt4WeightConfig, Int8WeightOnlyConfig, + MappingType, + PerRow, + PerTensor, UIntXWeightOnlyConfig, + quantize_, ) -from torchao.quantization.quant_api import Float8WeightOnlyConfig from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_5, unwrap_tensor_subclass, ) try: - import triton + import triton # noqa: F401 + HAS_TRITON = True except ImportError: HAS_TRITON = False @@ -50,7 +49,7 @@ def __init__( self.device = params.get("device", get_default_device()) self.model_type = params.get("model_type", "linear") self.output_dir = output_dir - self.name = f"benchmark_{self.quantization}_{self.model_type}_m{self.m}_k{self.k}_n{self.n}" + self.name = f"benchmark_{self.quantization}_{self.model_type}_m{self.m}_k{self.k}_n{self.n}{'_compile' if self.compile else ''}" @staticmethod def _parse_precision(precision_str: str) -> torch.dtype: @@ -72,7 +71,6 @@ def to_dict(self) -> Dict[str, Any]: } -# TODO: add more models class ToyLinearModel(torch.nn.Module): def __init__(self, k=64, n=32, dtype=torch.bfloat16): super().__init__() @@ -98,27 +96,17 @@ def forward(self, x): def get_default_device() -> str: - try: - if torch.cuda.is_available(): - return "cuda" - elif torch.xpu.is_available(): - return "xpu" - elif torch.backends.mps.is_available(): - return "mps" - except (AssertionError, AttributeError): + if torch.cuda.is_available(): + return "cuda" + elif torch.xpu.is_available(): + return "xpu" + elif torch.backends.mps.is_available(): + return "mps" + else: print("Warning: Running on CPU as no GPU support was found") return "cpu" -def benchmark_func_call(func, *args, **kwargs): - start_time = time.time() - result = func(*args, **kwargs) - end_time = time.time() - elapsed_time = end_time - start_time - print(f"Time taken to run {func.__name__}: {elapsed_time:.2f} seconds") - return result - - def ffn_only(mod, fqn): return isinstance(mod, torch.nn.Linear) and "feed_forward" in fqn @@ -148,7 +136,7 @@ def quantize_model( if "int4wo" in quantization and not HAS_TRITON: print("Warning: Triton not available, falling back to baseline") return model - + # Define kwargs sparsity = kwargs.get("sparsity", 
None) precision = kwargs.get("precision", None) @@ -318,52 +306,6 @@ def benchmark_model_inference_in_microseconds(model, input_data): return ((end_time - start_time) / num_iters) * 1e6 -def benchmark_model_op_with_profiler_in_microseconds(model, input_data, op_name: str): - """Benchmarks model inference using PyTorch profiler to measure GPU kernel execution times. - - This function profiles the model execution and measures the time spent in specific GPU operations - versus overhead time. It performs warmup runs before profiling to ensure accurate measurements. - - Args: - model (torch.nn.Module): PyTorch model to benchmark - input_data (torch.Tensor): Input tensor to run through the model - op_name (str): Name of the GPU operation to measure time for - - Returns: - tuple[float, float]: A tuple containing: - - gpu_op_time (float): Time spent in the specified GPU operation in microseconds - - gpu_overhead_time (float): Time spent in other GPU operations in microseconds - """ - # Warm up - for _ in range(2): - model(input_data) - - # Profile model execution - with profile( - activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True - ) as prof: - with torch.no_grad(): - _ = model(input_data) - torch.cuda.synchronize() - - # Get event data from profiler - event_data = [ - (event.key, event.device_time) - for event in prof.key_averages() - if event.device_type == torch.autograd.DeviceType.CUDA - ] - - # Calculate op time and overhead time - gpu_op_time, gpu_overhead_time = 0, 0 - for event in event_data: - if op_name in event[0]: - gpu_op_time += event[1] - else: - gpu_overhead_time += event[1] - - return gpu_op_time, gpu_overhead_time - - def create_model_and_input( model_type: str, m: int, @@ -392,57 +334,6 @@ def create_model_and_input( return model, input_data -@torch.no_grad() -def benchmark_op_with_cuda_graph(op, *args, **kwargs): - """ - runs benchmark op(*args, **kwargs) avoiding torch.compile overhead - """ - rep = kwargs.pop("rep", 100) - warmup = kwargs.pop("warmup", 25) - with torch.no_grad(): - torch.cuda.synchronize() - stream = torch.cuda.Stream() - stream.wait_stream(torch.cuda.current_stream()) - with torch.cuda.stream(stream): - op(*args, **kwargs) - stream.synchronize() - torch.cuda.current_stream().wait_stream(stream) - torch.cuda.synchronize() - graph = torch.cuda.CUDAGraph() - with torch.cuda.graph(graph, stream=stream): - op(*args, **kwargs) - if TORCH_VERSION_AT_LEAST_2_5: - from torch._inductor.runtime.benchmarking import benchmarker - - res = benchmarker.benchmark_gpu( - lambda: graph.replay(), warmup=warmup, rep=rep, return_mode="median" - ) - elif TORCH_VERSION_AT_LEAST_2_3: - from torch._inductor.runtime.runtime_utils import do_bench_gpu - - res = do_bench_gpu( - lambda: graph.replay(), warmup=warmup, rep=rep, return_mode="median" - ) - else: - from torch._inductor.utils import do_bench - - res = do_bench( - lambda: graph.replay(), warmup=warmup, rep=rep, return_mode="median" - ) - return res - - -def _is_interpolate_mode(mode): - if ( - isinstance(mode, list) - and mode[0] == "interpolate" - and len(mode) == 2 - and isinstance(mode[1], float) - ): - return True - return False - - def clean_caches(): import gc @@ -453,70 +344,33 @@ def clean_caches(): torch.cuda.empty_cache() torch.cuda.reset_peak_memory_stats() - if hasattr(torch, '_dynamo'): + if hasattr(torch, "_dynamo"): torch._dynamo.reset() -def get_name_to_shapes_iter( - shape_gen_name: str, - M: Optional[int], - K: Optional[int], - N: Optional[int], +def generate_results_csv( + results: 
List[Dict[str, Any]], + output_dir: str, + file_name: str = "results.csv", ): - if shape_gen_name == "llama": - assert ( - M == K == N == None - ), f"M, K, N arguments not supported for shape_gen_name {shape_gen_name}" - bsz, seq_len = 4, 4096 - M = bsz * seq_len - # LLaMa 2 70B single-node weight shapes - # assumes fused attn.wqkv and ffn.w13 - # source: https://fburl.com/gsheet/g8onr7rh - name_to_shapes_70b = { - "attn.wqkv": (M, 8192, 1280), - "attn.w0": (M, 1024, 8192), - "ffn.w13": (M, 8192, 7168), - "ffn.w2": (M, 3584, 8192), - } - return name_to_shapes_70b.items() + """Generate a CSV file with the results of the benchmarking. - elif shape_gen_name == "square": - assert ( - M == K == N == None - ), f"M, K, N arguments not supported for shape_gen_name {shape_gen_name}" - name_to_shapes = {} - min_power_of_2 = 8 # 256 - max_power_of_2 = 15 # 32,768 - for idx, power_of_2 in enumerate(range(min_power_of_2, max_power_of_2 + 1)): - val = 2**power_of_2 - name_to_shapes[idx] = val, val, val - return name_to_shapes.items() - - elif shape_gen_name == "sweep": - assert ( - M == K == N == None - ), f"M, K, N arguments not supported for shape_gen_name {shape_gen_name}" - name_to_shapes = {} - min_p2 = 8 # 256 - max_p2 = 15 # 32,768 - counter = 0 - for M_p2 in range(min_p2, max_p2 + 1): - M = 2**M_p2 - for K_p2 in range(min_p2, max_p2 + 1): - K = 2**K_p2 - for N_p2 in range(min_p2, max_p2 + 1): - N = 2**N_p2 - name_to_shapes[counter] = M, K, N - counter += 1 - return name_to_shapes.items() - - elif shape_gen_name == "custom": - assert ( - M is not None and K is not None and N is not None - ), "M, K, N must be specified for custom shape_gen" - name_to_shapes = { - 1: (M, K, N), - } - return name_to_shapes.items() - - raise AssertionError(f"unknown shape_gen_name {shape_gen_name}") + Args: + results (List[Dict[str, Any]]): List Dictionary containing the results of the benchmarking with the config. + output_dir (str): Directory to save the CSV file. + file_name (str, optional): Name of the CSV file. Defaults to "results.csv". 
+ """ + # Create the output directory if it doesn't exist + os.makedirs(output_dir, exist_ok=True) + file_path = os.path.join(output_dir, file_name) + + # Create a CSV file with the results + with open(file_path, "w", newline="") as csvfile: + writer = csv.writer(csvfile) + # Write the header row + header = results[0].keys() + writer.writerow(header) + for result in results: + writer.writerow(result.values()) + + print(f"Results saved to {file_path}") From 97cea12622d10bf93e383f3afd8734188df85864 Mon Sep 17 00:00:00 2001 From: jainapurva Date: Tue, 25 Feb 2025 15:31:59 -0800 Subject: [PATCH 05/14] New test folder --- benchmarks/microbenchmarks/benchmark_runner.py | 2 ++ .../microbenchmarks/{ => test}/benchmark_config.yml | 2 +- benchmarks/microbenchmarks/test/results/results.csv | 13 +++++++++++++ 3 files changed, 16 insertions(+), 1 deletion(-) rename benchmarks/microbenchmarks/{ => test}/benchmark_config.yml (83%) create mode 100644 benchmarks/microbenchmarks/test/results/results.csv diff --git a/benchmarks/microbenchmarks/benchmark_runner.py b/benchmarks/microbenchmarks/benchmark_runner.py index 1c60497b8f..77c0c2125e 100644 --- a/benchmarks/microbenchmarks/benchmark_runner.py +++ b/benchmarks/microbenchmarks/benchmark_runner.py @@ -26,6 +26,8 @@ def get_shapes_for_config(shape_config: Dict[str, Any]) -> List[Tuple[str, List[ name = shape_config["name"] if name == "custom": return [(name, shape) for shape in shape_config["shapes"]] + else: + NotImplementedError(f"Shape config {name} not supported. Currently only supports custom shapes.") def load_benchmark_configs(config_path: str) -> List[BenchmarkConfig]: diff --git a/benchmarks/microbenchmarks/benchmark_config.yml b/benchmarks/microbenchmarks/test/benchmark_config.yml similarity index 83% rename from benchmarks/microbenchmarks/benchmark_config.yml rename to benchmarks/microbenchmarks/test/benchmark_config.yml index a7a2b4c017..4a4fca7642 100644 --- a/benchmarks/microbenchmarks/benchmark_config.yml +++ b/benchmarks/microbenchmarks/test/benchmark_config.yml @@ -4,7 +4,7 @@ quantizations: - "int8wo" - "int4wo-128" - "int4wo-128-hqq" -output_dir: "benchmarks/microbenchmarks/results" # Directory for results and plots +output_dir: "benchmarks/microbenchmarks/test/results" # Directory for results and plots model_params: matrix_shapes: - name: "custom" diff --git a/benchmarks/microbenchmarks/test/results/results.csv b/benchmarks/microbenchmarks/test/results/results.csv new file mode 100644 index 0000000000..036d6b532c --- /dev/null +++ b/benchmarks/microbenchmarks/test/results/results.csv @@ -0,0 +1,13 @@ +quantization,m,k,n,shape_name,precision,compile,device,model_type,output_dir,name,benchmark_model_inference_in_microseconds +baseline,1024,1024,1024,custom,torch.bfloat16,max-autotune,cuda,linear,benchmarks/microbenchmarks/test/results,benchmark_baseline_linear_m1024_k1024_n1024_compile,64510.37060469389 +baseline,2048,4096,1024,custom,torch.bfloat16,max-autotune,cuda,linear,benchmarks/microbenchmarks/test/results,benchmark_baseline_linear_m2048_k4096_n1024_compile,53887.79062777758 +baseline,4096,4096,1024,custom,torch.bfloat16,max-autotune,cuda,linear,benchmarks/microbenchmarks/test/results,benchmark_baseline_linear_m4096_k4096_n1024_compile,36628.598207607865 +int8wo,1024,1024,1024,custom,torch.bfloat16,max-autotune,cuda,linear,benchmarks/microbenchmarks/test/results,benchmark_int8wo_linear_m1024_k1024_n1024_compile,56611.56056448817 
+int8wo,2048,4096,1024,custom,torch.bfloat16,max-autotune,cuda,linear,benchmarks/microbenchmarks/test/results,benchmark_int8wo_linear_m2048_k4096_n1024_compile,55212.84379065037 +int8wo,4096,4096,1024,custom,torch.bfloat16,max-autotune,cuda,linear,benchmarks/microbenchmarks/test/results,benchmark_int8wo_linear_m4096_k4096_n1024_compile,51695.895195007324 +int4wo-128,1024,1024,1024,custom,torch.bfloat16,max-autotune,cuda,linear,benchmarks/microbenchmarks/test/results,benchmark_int4wo-128_linear_m1024_k1024_n1024_compile,40540.05299694836 +int4wo-128,2048,4096,1024,custom,torch.bfloat16,max-autotune,cuda,linear,benchmarks/microbenchmarks/test/results,benchmark_int4wo-128_linear_m2048_k4096_n1024_compile,39183.96681547165 +int4wo-128,4096,4096,1024,custom,torch.bfloat16,max-autotune,cuda,linear,benchmarks/microbenchmarks/test/results,benchmark_int4wo-128_linear_m4096_k4096_n1024_compile,40781.22219070792 +int4wo-128-hqq,1024,1024,1024,custom,torch.bfloat16,max-autotune,cuda,linear,benchmarks/microbenchmarks/test/results,benchmark_int4wo-128-hqq_linear_m1024_k1024_n1024_compile,37873.45583550632 +int4wo-128-hqq,2048,4096,1024,custom,torch.bfloat16,max-autotune,cuda,linear,benchmarks/microbenchmarks/test/results,benchmark_int4wo-128-hqq_linear_m2048_k4096_n1024_compile,37539.9901997298 +int4wo-128-hqq,4096,4096,1024,custom,torch.bfloat16,max-autotune,cuda,linear,benchmarks/microbenchmarks/test/results,benchmark_int4wo-128-hqq_linear_m4096_k4096_n1024_compile,38310.51839515567 From 35b28406fd946b980106c0dfc71d8f993fb4c56b Mon Sep 17 00:00:00 2001 From: Apurva Jain Date: Tue, 25 Feb 2025 16:03:11 -0800 Subject: [PATCH 06/14] Added test cases --- benchmarks/microbenchmarks/__init__.py | 1 + .../microbenchmarks/benchmark_inference.py | 3 +- .../microbenchmarks/benchmark_runner.py | 11 +- .../microbenchmarks/benchmark_training.py | 2 +- benchmarks/microbenchmarks/test/__init__.py | 1 + .../test/test_benchmark_inference.py | 36 ++++++ .../test/test_benchmark_runner.py | 82 +++++++++++++ benchmarks/microbenchmarks/test/test_utils.py | 116 ++++++++++++++++++ benchmarks/microbenchmarks/utils.py | 6 +- 9 files changed, 248 insertions(+), 10 deletions(-) create mode 100644 benchmarks/microbenchmarks/__init__.py create mode 100644 benchmarks/microbenchmarks/test/__init__.py create mode 100644 benchmarks/microbenchmarks/test/test_benchmark_inference.py create mode 100644 benchmarks/microbenchmarks/test/test_benchmark_runner.py create mode 100644 benchmarks/microbenchmarks/test/test_utils.py diff --git a/benchmarks/microbenchmarks/__init__.py b/benchmarks/microbenchmarks/__init__.py new file mode 100644 index 0000000000..80c75f6509 --- /dev/null +++ b/benchmarks/microbenchmarks/__init__.py @@ -0,0 +1 @@ +# Empty file to mark directory as Python package diff --git a/benchmarks/microbenchmarks/benchmark_inference.py b/benchmarks/microbenchmarks/benchmark_inference.py index 4bd171e376..cefc70dbc6 100644 --- a/benchmarks/microbenchmarks/benchmark_inference.py +++ b/benchmarks/microbenchmarks/benchmark_inference.py @@ -10,7 +10,8 @@ from typing import Dict import torch -from utils import ( + +from benchmarks.microbenchmarks.utils import ( BenchmarkConfig, benchmark_model_inference_in_microseconds, clean_caches, diff --git a/benchmarks/microbenchmarks/benchmark_runner.py b/benchmarks/microbenchmarks/benchmark_runner.py index 77c0c2125e..dd3eb338f8 100644 --- a/benchmarks/microbenchmarks/benchmark_runner.py +++ b/benchmarks/microbenchmarks/benchmark_runner.py @@ -18,7 +18,8 @@ from typing import Any, Dict, List, 
Tuple import yaml -from utils import BenchmarkConfig, generate_results_csv + +from benchmarks.microbenchmarks.utils import BenchmarkConfig, generate_results_csv def get_shapes_for_config(shape_config: Dict[str, Any]) -> List[Tuple[str, List[int]]]: @@ -27,7 +28,9 @@ def get_shapes_for_config(shape_config: Dict[str, Any]) -> List[Tuple[str, List[ if name == "custom": return [(name, shape) for shape in shape_config["shapes"]] else: - NotImplementedError(f"Shape config {name} not supported. Currently only supports custom shapes.") + raise NotImplementedError( + f"Shape config {name} not supported. Currently only supports custom shapes." + ) def load_benchmark_configs(config_path: str) -> List[BenchmarkConfig]: @@ -54,14 +57,12 @@ def load_benchmark_configs(config_path: str) -> List[BenchmarkConfig]: output_dir=output_dir, ) ) - print("Configs: ", configs[0].__dict__) - return configs def run_benchmarks_from_config(config_path: str) -> None: """Run benchmarks using configurations from YAML file""" - from benchmark_inference import run as run_inference + from benchmarks.microbenchmarks.benchmark_inference import run as run_inference configs = load_benchmark_configs(config_path) results = [] diff --git a/benchmarks/microbenchmarks/benchmark_training.py b/benchmarks/microbenchmarks/benchmark_training.py index 08e624976c..1f7be1df2a 100644 --- a/benchmarks/microbenchmarks/benchmark_training.py +++ b/benchmarks/microbenchmarks/benchmark_training.py @@ -5,7 +5,7 @@ - run() function is the main entry point for running training benchmarks. """ -from utils import BenchmarkConfig +from benchmarks.microbenchmarks.utils import BenchmarkConfig def run(config: BenchmarkConfig) -> None: diff --git a/benchmarks/microbenchmarks/test/__init__.py b/benchmarks/microbenchmarks/test/__init__.py new file mode 100644 index 0000000000..80c75f6509 --- /dev/null +++ b/benchmarks/microbenchmarks/test/__init__.py @@ -0,0 +1 @@ +# Empty file to mark directory as Python package diff --git a/benchmarks/microbenchmarks/test/test_benchmark_inference.py b/benchmarks/microbenchmarks/test/test_benchmark_inference.py new file mode 100644 index 0000000000..d726a87539 --- /dev/null +++ b/benchmarks/microbenchmarks/test/test_benchmark_inference.py @@ -0,0 +1,36 @@ +import unittest + +from benchmarks.microbenchmarks.benchmark_inference import run +from benchmarks.microbenchmarks.utils import BenchmarkConfig + + +class TestBenchmarkInference(unittest.TestCase): + def setUp(self): + self.params = { + "precision": "torch.float32", # Use float32 for testing + "compile": False, + "device": "cpu", # Use CPU for testing + "model_type": "linear", + } + self.config = BenchmarkConfig( + quantization="baseline", + params=self.params, + shape_name="test", + shape=[16, 32, 8], # Small shape for testing + output_dir="benchmarks/microbenchmarks/test/test_output/", + ) + + def test_run_inference(self): + result = run(self.config) + + # Check result contains all config attributes + for key in self.config.__dict__: + self.assertIn(key, result) + + # Check benchmark result is present and reasonable + self.assertIn("benchmark_model_inference_in_microseconds", result) + self.assertGreater(result["benchmark_model_inference_in_microseconds"], 0) + + +if __name__ == "__main__": + unittest.main() diff --git a/benchmarks/microbenchmarks/test/test_benchmark_runner.py b/benchmarks/microbenchmarks/test/test_benchmark_runner.py new file mode 100644 index 0000000000..65b63513dd --- /dev/null +++ b/benchmarks/microbenchmarks/test/test_benchmark_runner.py @@ -0,0 
+1,82 @@ +import os +import tempfile +import unittest + +import yaml + +from benchmarks.microbenchmarks.benchmark_runner import ( + get_shapes_for_config, + load_benchmark_configs, + run_benchmarks_from_config, +) + + +class TestBenchmarkRunner(unittest.TestCase): + def setUp(self): + self.config = { + "quantizations": ["baseline", "int8wo"], + "output_dir": "benchmarks/microbenchmarks/test/test_output", + "model_params": { + "matrix_shapes": [ + { + "name": "custom", + "shapes": [[16, 32, 8]], # Small shape for testing + } + ], + "precision": "torch.float32", + "compile": False, + "device": "cpu", + "model_type": "linear", + }, + } + + # Create temporary config file + self.temp_dir = tempfile.mkdtemp() + self.config_path = os.path.join(self.temp_dir, "test_config.yml") + with open(self.config_path, "w") as f: + yaml.dump(self.config, f) + + # Create output directory if it doesn't exist + os.makedirs(self.config["output_dir"], exist_ok=True) + + def tearDown(self): + # Clean up temporary files + if os.path.exists(self.config_path): + os.unlink(self.config_path) + if os.path.exists(self.temp_dir): + os.rmdir(self.temp_dir) + + # Clean up test output directory + results_file = os.path.join(self.config["output_dir"], "results.csv") + if os.path.exists(results_file): + os.unlink(results_file) + if os.path.exists(self.config["output_dir"]): + os.rmdir(self.config["output_dir"]) + + def test_get_shapes_for_config(self): + shape_config = { + "name": "custom", + "shapes": [[1024, 1024, 1024], [2048, 2048, 2048]], + } + shapes = get_shapes_for_config(shape_config) + self.assertEqual(len(shapes), 2) + self.assertEqual(shapes[0], ("custom", [1024, 1024, 1024])) + self.assertEqual(shapes[1], ("custom", [2048, 2048, 2048])) + + with self.assertRaises(NotImplementedError): + get_shapes_for_config({"name": "unsupported", "shapes": []}) + + def test_load_benchmark_configs(self): + configs = load_benchmark_configs(self.config_path) + self.assertEqual(len(configs), 2) # 2 quantizations * 1 shape + self.assertEqual(configs[0].quantization, "baseline") + self.assertEqual(configs[1].quantization, "int8wo") + + def test_run_benchmarks_from_config(self): + run_benchmarks_from_config(self.config_path) + results_file = os.path.join(self.config["output_dir"], "results.csv") + self.assertTrue(os.path.exists(results_file)) + + +if __name__ == "__main__": + unittest.main() diff --git a/benchmarks/microbenchmarks/test/test_utils.py b/benchmarks/microbenchmarks/test/test_utils.py new file mode 100644 index 0000000000..f48cd4e67f --- /dev/null +++ b/benchmarks/microbenchmarks/test/test_utils.py @@ -0,0 +1,116 @@ +import os +import tempfile +import unittest + +import torch + +from benchmarks.microbenchmarks.utils import ( + BenchmarkConfig, + LNLinearSigmoid, + ToyLinearModel, + clean_caches, + create_model_and_input, + generate_results_csv, +) + + +class TestUtils(unittest.TestCase): + def test_benchmark_config(self): + params = { + "precision": "torch.bfloat16", + "compile": "max-autotune", + "device": "cuda", + "model_type": "linear", + } + config = BenchmarkConfig( + quantization="int8wo", + params=params, + shape_name="custom", + shape=[1024, 1024, 1024], + output_dir="test_output", + ) + + self.assertEqual(config.quantization, "int8wo") + self.assertEqual(config.m, 1024) + self.assertEqual(config.k, 1024) + self.assertEqual(config.n, 1024) + self.assertEqual(config.precision, torch.bfloat16) + self.assertEqual(config.compile, "max-autotune") + self.assertEqual(config.device, "cuda") + 
self.assertEqual(config.model_type, "linear") + self.assertEqual(config.output_dir, "test_output") + self.assertEqual( + config.name, "benchmark_int8wo_linear_m1024_k1024_n1024_compile" + ) + + def test_toy_linear_model(self): + model = ToyLinearModel(k=64, n=32, dtype=torch.float32) + x = torch.randn(16, 64) + out = model(x) + self.assertEqual(out.shape, (16, 32)) + self.assertEqual(out.dtype, torch.float32) + + def test_ln_linear_sigmoid(self): + model = LNLinearSigmoid(fc_dim1=64, fc_dim2=32, dtype=torch.float32) + x = torch.randn(16, 64) + out = model(x) + self.assertEqual(out.shape, (16, 32)) + self.assertEqual(out.dtype, torch.float32) + self.assertTrue( + torch.all((out >= 0) & (out <= 1)) + ) # Check sigmoid output range + + def test_create_model_and_input(self): + m, k, n = 16, 64, 32 + model, input_data = create_model_and_input( + model_type="linear", + m=m, + k=k, + n=n, + dtype=torch.float32, + device="cpu", + ) + self.assertIsInstance(model, ToyLinearModel) + self.assertEqual(input_data.shape, (m, k)) + + model, input_data = create_model_and_input( + model_type="ln_linear_sigmoid", + m=m, + k=k, + n=n, + dtype=torch.float32, + device="cpu", + ) + self.assertIsInstance(model, LNLinearSigmoid) + self.assertEqual(input_data.shape, (m, k)) + + def test_generate_results_csv(self): + results = [ + { + "quantization": "int8wo", + "m": 1024, + "k": 1024, + "n": 1024, + "time_us": 100.0, + }, + { + "quantization": "int4wo", + "m": 1024, + "k": 1024, + "n": 1024, + "time_us": 50.0, + }, + ] + + with tempfile.TemporaryDirectory() as tmp_dir: + generate_results_csv(results, tmp_dir) + csv_path = os.path.join(tmp_dir, "results.csv") + self.assertTrue(os.path.exists(csv_path)) + + def test_clean_caches(self): + # Just test that it runs without error + clean_caches() + + +if __name__ == "__main__": + unittest.main() diff --git a/benchmarks/microbenchmarks/utils.py b/benchmarks/microbenchmarks/utils.py index 44c1f0782d..2f3505b83e 100644 --- a/benchmarks/microbenchmarks/utils.py +++ b/benchmarks/microbenchmarks/utils.py @@ -246,9 +246,9 @@ def quantize_model( ) from torchao.quantization.granularity import PerGroup - assert ( - precision == torch.float32 - ), "int8_dynamic_activation_intx_weight requires using precision=torch.float32" + assert precision == torch.float32, ( + "int8_dynamic_activation_intx_weight requires using precision=torch.float32" + ) # Quantize model _quant_args = quantization.split("-") From 8b7291c5430ede91e41f559822292543b9e7f1c3 Mon Sep 17 00:00:00 2001 From: jainapurva Date: Tue, 25 Feb 2025 16:06:09 -0800 Subject: [PATCH 07/14] Lint fixes --- benchmarks/microbenchmarks/__init__.py | 1 - benchmarks/microbenchmarks/test/__init__.py | 1 - benchmarks/microbenchmarks/utils.py | 6 +++--- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/benchmarks/microbenchmarks/__init__.py b/benchmarks/microbenchmarks/__init__.py index 80c75f6509..e69de29bb2 100644 --- a/benchmarks/microbenchmarks/__init__.py +++ b/benchmarks/microbenchmarks/__init__.py @@ -1 +0,0 @@ -# Empty file to mark directory as Python package diff --git a/benchmarks/microbenchmarks/test/__init__.py b/benchmarks/microbenchmarks/test/__init__.py index 80c75f6509..e69de29bb2 100644 --- a/benchmarks/microbenchmarks/test/__init__.py +++ b/benchmarks/microbenchmarks/test/__init__.py @@ -1 +0,0 @@ -# Empty file to mark directory as Python package diff --git a/benchmarks/microbenchmarks/utils.py b/benchmarks/microbenchmarks/utils.py index 2f3505b83e..44c1f0782d 100644 --- a/benchmarks/microbenchmarks/utils.py 
+++ b/benchmarks/microbenchmarks/utils.py @@ -246,9 +246,9 @@ def quantize_model( ) from torchao.quantization.granularity import PerGroup - assert precision == torch.float32, ( - "int8_dynamic_activation_intx_weight requires using precision=torch.float32" - ) + assert ( + precision == torch.float32 + ), "int8_dynamic_activation_intx_weight requires using precision=torch.float32" # Quantize model _quant_args = quantization.split("-") From 8f8a6ff571ce4715fc19bf60cadb1ca911629793 Mon Sep 17 00:00:00 2001 From: jainapurva Date: Thu, 27 Feb 2025 14:29:16 -0800 Subject: [PATCH 08/14] Updates --- benchmarks/microbenchmarks/README.md | 84 ++++++++++++++++++ .../microbenchmarks/benchmark_inference.py | 12 +-- .../microbenchmarks/benchmark_runner.py | 6 +- .../microbenchmarks/test/benchmark_config.yml | 7 +- .../test/test_benchmark_inference.py | 4 +- .../test/test_benchmark_runner.py | 4 +- benchmarks/microbenchmarks/test/test_utils.py | 10 ++- benchmarks/microbenchmarks/utils.py | 48 +++------- .../microbenchmarking_process_diagram.png | Bin 0 -> 24536 bytes .../microbenchmarks_code_flow_diagram.png | Bin 0 -> 31113 bytes 10 files changed, 123 insertions(+), 52 deletions(-) create mode 100644 benchmarks/microbenchmarks/README.md create mode 100644 docs/static/microbenchmarking_process_diagram.png create mode 100644 docs/static/microbenchmarks_code_flow_diagram.png diff --git a/benchmarks/microbenchmarks/README.md b/benchmarks/microbenchmarks/README.md new file mode 100644 index 0000000000..32333144d0 --- /dev/null +++ b/benchmarks/microbenchmarks/README.md @@ -0,0 +1,84 @@ +# Microbenchmarks + +This directory contains microbenchmarking tools for measuring inference performance across different quantization methods and model architectures. + +## Overview + +The microbenchmarking system works as follows: + +![Microbenchmarks Process Flow](../../docs/static/microbenchmarking_process_diagram.png) + +## Components + +![Microbenchmarks Flow](../../docs/static/microbenchmarks_code_flow_diagram.png) + +- **benchmark_runner.py**: Main entry point that orchestrates the benchmarking process +- **benchmark_inference.py**: Handles model creation and inference benchmarking +- **utils.py**: Contains utility functions and configuration classes +- **test\/**: Test files and sample configurations + +## Usage + +1. Create a configuration YAML file (see example below) +2. 
Run the benchmark using: + +```bash +python -m benchmarks.microbenchmarks.benchmark_runner --config path/to/config.yml +``` + +### Example Configuration + +```yaml +# Sample configuration for inference benchmarks +quantization_config_recipe_names: + - "baseline" + - "int8wo" + - "int4wo-128" + - "int4wo-128-hqq" + +output_dir: "benchmarks/microbenchmarks/results" + +model_params: + matrix_shapes: + - name: "custom" + shapes: [ + [1024, 1024, 1024], # [m, k, n] + [2048, 4096, 1024], + [4096, 4096, 1024] + ] + high_precision_dtype: "torch.bfloat16" + compile: true + compile_mode: "max-autotune" + device: "cuda" # Options: "cuda", "mps", "xpu", "cpu" + model_type: "linear" # Options: "linear", "ln_linear_sigmoid" +``` + +## Configuration Options + +### Quantization Methods +- `baseline`: No quantization +- `int8wo`: 8-bit weight-only quantization +- `int4wo-{group_size}`: 4-bit weight-only quantization with specified group size +- `int4wo-{group_size}-hqq`: 4-bit weight-only quantization with HQQ + +### Model Types +- `linear`: Simple linear layer +- `ln_linear_sigmoid`: LayerNorm + Linear + Sigmoid + +### Device Options +- `cuda`: NVIDIA GPU +- `xpu`: Intel GPU +- `mps`: Apple Silicon GPU +- `cpu`: CPU fallback + +## Output + +Results are saved to a CSV file in the specified output directory + +## Running Tests + +To run the test suite: + +```bash +python -m unittest discover benchmarks/microbenchmarks/test +``` diff --git a/benchmarks/microbenchmarks/benchmark_inference.py b/benchmarks/microbenchmarks/benchmark_inference.py index cefc70dbc6..aa1a2dc12a 100644 --- a/benchmarks/microbenchmarks/benchmark_inference.py +++ b/benchmarks/microbenchmarks/benchmark_inference.py @@ -16,7 +16,7 @@ benchmark_model_inference_in_microseconds, clean_caches, create_model_and_input, - quantize_model, + quantization_string_to_quantized_model, ) @@ -32,20 +32,22 @@ def run(config: BenchmarkConfig) -> Dict[str, float]: config.m, config.k, config.n, - dtype=config.precision, + high_precision_dtype=config.high_precision_dtype, device=config.device, ) # Use quantize_ to apply each quantization function to the model m_copy = deepcopy(base_model).eval().to(config.device) - m_copy = quantize_model(m_copy, config.quantization) + m_copy = quantization_string_to_quantized_model( + m_copy, config.quantization, high_precision_dtype=config.high_precision_dtype + ) if config.compile: print("Compiling model....") - m_copy = torch.compile(m_copy, mode=config.compile, fullgraph=True) + m_copy = torch.compile(m_copy, mode=config.compile_mode, fullgraph=True) # Run benchmarks - result = {**config.__dict__} + result = {**config.to_dict()} # Benchmark time to run an inference call for quantized model model_time = benchmark_model_inference_in_microseconds( diff --git a/benchmarks/microbenchmarks/benchmark_runner.py b/benchmarks/microbenchmarks/benchmark_runner.py index dd3eb338f8..f980b17aae 100644 --- a/benchmarks/microbenchmarks/benchmark_runner.py +++ b/benchmarks/microbenchmarks/benchmark_runner.py @@ -38,7 +38,7 @@ def load_benchmark_configs(config_path: str) -> List[BenchmarkConfig]: with open(config_path, "r") as f: config_data = yaml.safe_load(f) - quantizations = config_data["quantizations"] + quantization_config_recipe_names = config_data["quantization_config_recipe_names"] params = config_data["model_params"] output_dir = config_data.get("output_dir", "benchmarks/microbenchmarks/results") @@ -47,7 +47,9 @@ def load_benchmark_configs(config_path: str) -> List[BenchmarkConfig]: for shape_config in params["matrix_shapes"]: 
shapes = get_shapes_for_config(shape_config) # Generate combinations for each shape - for quant, (shape_name, shape) in product(quantizations, shapes): + for quant, (shape_name, shape) in product( + quantization_config_recipe_names, shapes + ): configs.append( BenchmarkConfig( quantization=quant, diff --git a/benchmarks/microbenchmarks/test/benchmark_config.yml b/benchmarks/microbenchmarks/test/benchmark_config.yml index 4a4fca7642..4ba6e99a4f 100644 --- a/benchmarks/microbenchmarks/test/benchmark_config.yml +++ b/benchmarks/microbenchmarks/test/benchmark_config.yml @@ -1,5 +1,5 @@ # Sample configuration for inference kernel benchmarks -quantizations: +quantization_config_recipe_names: - "baseline" - "int8wo" - "int4wo-128" @@ -13,7 +13,8 @@ model_params: [2048, 4096, 1024], [4096, 4096, 1024] ] - precision: "torch.bfloat16" - compile: "max-autotune" + high_precision_dtype: "torch.bfloat16" + compile: true + compile_mode: "max-autotune" device: "cuda" # Change this to "cuda", "mps", "xpu", or "cpu" as needed model_type: "linear" diff --git a/benchmarks/microbenchmarks/test/test_benchmark_inference.py b/benchmarks/microbenchmarks/test/test_benchmark_inference.py index d726a87539..ceb3a25a74 100644 --- a/benchmarks/microbenchmarks/test/test_benchmark_inference.py +++ b/benchmarks/microbenchmarks/test/test_benchmark_inference.py @@ -7,7 +7,7 @@ class TestBenchmarkInference(unittest.TestCase): def setUp(self): self.params = { - "precision": "torch.float32", # Use float32 for testing + "high_precision_dtype": "torch.float32", # Use float32 for testing "compile": False, "device": "cpu", # Use CPU for testing "model_type": "linear", @@ -24,7 +24,7 @@ def test_run_inference(self): result = run(self.config) # Check result contains all config attributes - for key in self.config.__dict__: + for key in self.config.to_dict(): self.assertIn(key, result) # Check benchmark result is present and reasonable diff --git a/benchmarks/microbenchmarks/test/test_benchmark_runner.py b/benchmarks/microbenchmarks/test/test_benchmark_runner.py index 65b63513dd..7bbb9609d2 100644 --- a/benchmarks/microbenchmarks/test/test_benchmark_runner.py +++ b/benchmarks/microbenchmarks/test/test_benchmark_runner.py @@ -14,7 +14,7 @@ class TestBenchmarkRunner(unittest.TestCase): def setUp(self): self.config = { - "quantizations": ["baseline", "int8wo"], + "quantization_config_recipe_names": ["baseline", "int8wo"], "output_dir": "benchmarks/microbenchmarks/test/test_output", "model_params": { "matrix_shapes": [ @@ -23,7 +23,7 @@ def setUp(self): "shapes": [[16, 32, 8]], # Small shape for testing } ], - "precision": "torch.float32", + "high_precision_dtype": "torch.float32", "compile": False, "device": "cpu", "model_type": "linear", diff --git a/benchmarks/microbenchmarks/test/test_utils.py b/benchmarks/microbenchmarks/test/test_utils.py index f48cd4e67f..54d0b96393 100644 --- a/benchmarks/microbenchmarks/test/test_utils.py +++ b/benchmarks/microbenchmarks/test/test_utils.py @@ -17,8 +17,9 @@ class TestUtils(unittest.TestCase): def test_benchmark_config(self): params = { - "precision": "torch.bfloat16", - "compile": "max-autotune", + "high_precision_dtype": "torch.bfloat16", + "compile": True, + "compile_mode": "max-autotune", "device": "cuda", "model_type": "linear", } @@ -34,8 +35,9 @@ def test_benchmark_config(self): self.assertEqual(config.m, 1024) self.assertEqual(config.k, 1024) self.assertEqual(config.n, 1024) - self.assertEqual(config.precision, torch.bfloat16) - self.assertEqual(config.compile, "max-autotune") + 
self.assertEqual(config.high_precision_dtype, torch.bfloat16) + self.assertEqual(config.compile, True) + self.assertEqual(config.compile_mode, "max-autotune") self.assertEqual(config.device, "cuda") self.assertEqual(config.model_type, "linear") self.assertEqual(config.output_dir, "test_output") diff --git a/benchmarks/microbenchmarks/utils.py b/benchmarks/microbenchmarks/utils.py index 44c1f0782d..bde2f5cd9a 100644 --- a/benchmarks/microbenchmarks/utils.py +++ b/benchmarks/microbenchmarks/utils.py @@ -44,8 +44,11 @@ def __init__( self.quantization = quantization self.m, self.k, self.n = shape self.shape_name = shape_name - self.precision = self._parse_precision(params["precision"]) + self.high_precision_dtype = self._parse_precision( + params["high_precision_dtype"] + ) self.compile = params.get("compile", False) + self.compile_mode = params.get("compile_mode", "default") self.device = params.get("device", get_default_device()) self.model_type = params.get("model_type", "linear") self.output_dir = output_dir @@ -63,8 +66,9 @@ def to_dict(self) -> Dict[str, Any]: "m": self.m, "k": self.k, "n": self.n, - "precision": self.precision, + "high_precision_dtype": self.high_precision_dtype, "compile": self.compile, + "compile_mode": "default", "device": self.device, "model_type": self.model_type, "output_dir": self.output_dir, @@ -121,44 +125,28 @@ def ffn_or_attn_only(mod, fqn): ) -def quantize_model( - model: torch.nn.Module, - quantization: str, - **kwargs, -): +def quantization_string_to_quantized_model( + model: torch.nn.Module, quantization: str, **kwargs +) -> torch.nn.Module: """Quantize a model inplace or return a new quantized model. Args: model (torch.nn.Module): model to be quantized quantization (str): quantization method to be used - kwargs: additional arguments to be passed to the quantization method + **kwargs: additional arguments to be passed to the quantization method """ + high_precision_dtype = kwargs.get("high_precision_dtype", torch.bfloat16) if "int4wo" in quantization and not HAS_TRITON: print("Warning: Triton not available, falling back to baseline") return model - # Define kwargs - sparsity = kwargs.get("sparsity", None) - precision = kwargs.get("precision", None) - # Quantization techniques if "baseline" in quantization: return model if "int8wo" in quantization: quantize_(model, Int8WeightOnlyConfig()) if "int8dq" in quantization: - if sparsity and "semi" in sparsity: - from torchao.dtypes import SemiSparseLayout - - quantize_( - model, - Int8DynamicActivationInt8WeightConfig(layout=SemiSparseLayout()), - filter_fn=ffn_only, - ) - quantize_( - model, Int8DynamicActivationInt8WeightConfig(), filter_fn=not_ffn_only - ) - elif "int8dq_prefill_wo_decode" in quantization: + if "int8dq_prefill_wo_decode" in quantization: quantize_( model, Int8DynamicActivationInt8WeightConfig(weight_only_decode=True) ) @@ -201,14 +189,6 @@ def quantize_model( layout=MarlinQQQLayout(), ), ) - elif "semi" in sparsity: - from torchao.dtypes import MarlinSparseLayout - - quantize_( - model, - Int4WeightOnlyConfig(layout=MarlinSparseLayout()), - filter_fn=ffn_or_attn_only, - ) if "fp6" in quantization: quantize_(model, FPXWeightOnlyConfig(3, 2)) elif "embed-int8wo" in quantization: @@ -247,8 +227,8 @@ def quantize_model( from torchao.quantization.granularity import PerGroup assert ( - precision == torch.float32 - ), "int8_dynamic_activation_intx_weight requires using precision=torch.float32" + high_precision_dtype == torch.float32 + ), "int8_dynamic_activation_intx_weight requires using 
high_precision_dtype=torch.float32" # Quantize model _quant_args = quantization.split("-") diff --git a/docs/static/microbenchmarking_process_diagram.png b/docs/static/microbenchmarking_process_diagram.png new file mode 100644 index 0000000000000000000000000000000000000000..d00d9f882123f5aa220ab5396cf1d94d90a535c6 GIT binary patch literal 24536 zcmdqHWmp_Rvj&O=m*DQfg9e8s0fKvQf-Wq<-91=v3+}<)-Q6X)6Ck*|!yP_5=ehUC zcYfb}c6MsJtE#Kts;QoqU`2TeG-M)VC@3g2DM>M9D5y7m0Ih-e2KW^h5qS>PnGaBooG~#=oOR68zk}&>Q zU=hX2>`>$~9kGgE#S|f=n3V?Vc)iD3cqoia2Zo(xElAJ~3+in=caZIIVUuURo#jSj zsLc9{+({A@Y@e@E>CCkA8C)z3NuT%O==D1?urtr$?HK478A&=gCSF=wnvpP#y+=P4 zUX7orH9|qrAa8ug`Wl6D_6h2g)>kDS0ZIwk^`g*K`JFCWWdtfeqHP?ONr6o) zR(|1W94MuMfTc$PilX@9t_l_ELlfor5WdkzQvNVJUMYV=q#iHKhriqMcp`%MhQ!}s z=W>{?7EFbAjUMGVEfr(O#|e6hH13i4J~Ah4pXR@0?A@-t$72KkkN|8$%#vO#^^pRPO;OK3QgMs{>xX{;sp{Pz(Dky z$J}s`k{5GnmosSMl{AID2#OEBaIwWa`1JK%3_{=aQfic31fj+aq$zB8Zztoz*b_v5 z)y`10$;ap^EAXY~V6LA8g<}(oGEamgidFuHj|S`pqN|gIM!TH!Ct4oE4)MGnN_~dV zm9jKQ*Xn!V+`X#IM6ZXm#&h5KO6(*@~I9MjWZ6~Vm(IB4sq8dTHwb?a! z1EIMNm((XlLGH-F4EFJ`N%vkWK>REEe1NGo7>vNDdH^-~(EpQh;SWPkBL(W;j+YS7 zyTMEZq(4YTJujy}Fp8o2qF{^d9@j2l51j*HlTN+F*(<0*LHbgtG;c6Pe|!?tUWU5q zY(2G^zwg`+tG12wvULNNsN7#l$B_aQpZuxFU z#m`9xKU|}3L z(0K}2>DMNLysNUtCWJ1Yh8Q^+oQrJt9qxY-4P(m3TgP4e6JnSacSMo8!eEBMT&v;>Z{=h$eq(Fp^ra-6h4W4+AZ5TJ!XG}!e z_N7vUiTC5mB(Z?e{)9V^cJ1dj0Vn^x`VUqEjFnHb&5eT(ekO4XY{pwSO3DxZzPr@4 zKi(Ne@={y$ARvi#MHA2Tq30t(e%VynFilT_@PHX=4Zn1!Z!qqRVkV6>sID|D!71)Bs}y%hD>pdOz`uS=w}F}(AkaZd^1Xi(#se8d ztoJ=(Q`Lpki>N(NxTpEd8jh$lN?+lW<`i)gk6hrke-8gdA2KPzn@Qf6 z3(tlq?I;Zw)J}6wyHE3&nVCV%$jpA58P+zM8JoG!j1>#1JQS3bTIG+aDk!NJu&ShjmZgw` zKOwV+u$!=_lZ+$BBv(eOXI}>i*fto|P8QNaEmhKr)@sy}sc|yXnQ$wUo2;8epT;v| zomegMlzw1`1t)SgphAN|VB$^GY|)?saB26{;3UKEh0>{_snX)ogI_6`0&?{tHw9J# z*kCPo*AbF!j%}}Px$S%d@owI3v2Hw!Vp1o^dH2?JH!nA;R>M{yH=aw-UiBsHRrQtn z?%cr|PY2o|>MyiBB8v~a;0dm;CJl2oou?b?4?!=GA4H3UeFS4{%VzV2^I#(j919vQ zEfa^Kg5gBdLz7hFmc`pd+kB+%o_;Zls>QzzgXGLQctLU9jNmkgV+v7nU?QVzePUg* z7k4#xC=b(zC9u<4s@WK~0*{{Ms>Q0Mfg$^>VczHPUh8-qnL6XVEv;eu_>BM)VgEI{MrYTb^ zN{iVd61Wj&Bv*?IW14@yu6uZXuymN%<;x}(Z?~&-e%ol_(0DXtLwe9OQM+(eoYI^! 
zvd}(MO4;(B5-pK6KtVcehNenY1NASW>WN!eWy|0l@FSIjUz6QIe)M>XhZ6geut%R1 zQAYc3rzbSkCR@Ph$k8$SC&p`7O3W>GtGh5OI5EDNVEZzvHOKZmlJ4~UOfxSKiggEq zon#%5VAEW|7$M{l`*^?!f|49wT)_YGdn=RAO6EEKYf`_lka&|7ZnI`*)!&l~*_+Il z9&Hpd5?saj7MtHwqDi=4M$h<0^3J}nxk==R7*7q~>gvIU;l};zgC4JD6Y3dXkg1lr zkAIP+chpvp6GKS!=Z1^bFa+W4%|TNrrg#2e0I1wd`Y12iI>b* zg5l%cPv+IsMoB3#2j!K&%YMIoeET3RqkdpIcWby*m9qTviDIXCgZa^v0~315f)l92 zmJ8DD{4+ge-XK6W#RU8dw_NrH@(I5_pJ!aNP~*_qf6!E^Zgr`(!z*nRGt@+i~PUuUZ!Uc&oa> zBPZDB$ZDhvvGK3cEfSVc$OoWvc_&hT`Uj zPSsQjPtI}YlRXQFgs4i(PzcM9aI2|xe`5SsnL%4_(YOSQ6g%(yMhrvJwgjR#(gVi2 zb0<~Q;yk-{BEl1UD)sZtoj;xv}bd&6gsGJ@p(On<0yKxo`8T=raSjb zZ585M(t8S>mG@bsT5O?hWp|>wH?H7LVb0xZff(3t>4OQ_ZfUTv9hJ>6NzUy44wrCY zV-LL&@QtZ_xA7AXKK*iGSeUcm85-Jp!ZjT%f#80Y_+GX`D+~_QUgJMyuxzO=n4jc0 z6iS60_@spxk%-XkI5(f;Ro5-td9_;>@&S;zZa1!EeInrgnEvYQtar`spwX&x`q9m^ z`+jR=o%V3e06rNKDke7DKlObyM=a!;vSS^v#h^W_SgCoJsv;j6U$#2`Bmyd%xhi^% z6?uTsN!=h{edv|1Vz{Cf4LkK3c>Tk~ISr?406Zj6y;k7rvNF!x28lc{i~fcI(p4Z6a$+f+7j zwS>4L7vkrT1dGHLXh3tEO_q_V8cNt0lgUD57^AZJ`28Oi*VlJkjBJp-a&PCKhV$C^ zA144X&nrbWHXZ>m{Ts~ro&p8oRzlM2&`C{2RCX)V;g5Uhnd`zG^*arY$lqN$^gw3Y zx|VznqE6zxKM@WM{YGETMJNTx^;C*;w^ku=zslM-V$X6Sk@cxr4^9@LAP@X zg(CfKiU~v%c=*uwg|L~5b|?S1{L#i|^7y_xuwL_KATD1PB*0STHdR3T?J+X$0WE^9 z6T|9bgX=h;E>pcc$mrvOnrxgHtd^Sf3S!e^YGW8?ciijgX zso2PQwOhu3Zv4pJmprVtN4&9|{0C&C~TRkpIP6bpp)s@av1rU1X0=ebk zByh>-5V;F@DF-1Fy3I&1{j(;pa!>DtZl>)mU{L%U2S$M?2V9UtrlfYO3K+!*K+(K} zM*!tH6c1Iu#R^afq;@!3Z(x&ESorC=!lP(`UA9`wsn-E9rp1?Jcd!G}-74cFQWY|R zYnj}p2rZ(QAdgQy3UsmLfcTP#l(=MtZ=imX=~{e+AY0xSI2EFLZUBkPiMaO#=&2Xp z{d9hyW6f34Xebd4a{z;lYT^Y?cmP6RLJq^~4B+U^6&vLZuLf~Aw?U>k0cLw@D!P<( zu|%LKf)sTwNq|Ug2{1+ygFZ?4HHVw*Q#ioaJq#<9S_ly87QI52M4;{p$Jkuro1iWP z(Ta0`m9vjFV(C)Do&l1YxqJO8G{6W@aK%)1?+ywUZ1CsSU{lS-Q#FtaS$^-9Z9mnOVCi9L2Wt2vVf03hYYAL z@6iIFgPs^)8}kJ**W1<@X9`{iA;&ijiaSIDOpS|skz7}E04n7M zl!WX!1)xHzB?o> z3wN>>ocJ8DI3$mVjEfgEV1v1FEpqq15KVgU?rUZ1Tzmcpk2wK4tt`$!F z0OWq{20Tr{8nh^63o97|FvqYf%xM#`0@`H92t=0nfv2~ynt}X>4Zz_9`I}FMEDG>i zBgox0Wg-!P8;c5m zl^$5zODstLhyx&iko2Cf5OAh!-pcR?z)1lhF?g40iCqRYjmQ|}_{XwpUDv~-Am~Bq zZ6U)RC>>c8$2(k}(*F}rlvH+Aj)9kBl?!&O#}5q1)rso4V6ch0${EdQn;aGFs9exO z8JH;1a{hq`I+Spuj5U0lVN0KCuCx2-^7gcKjez-Yj`^&}%R5rwfML^Ivq3?@Dy0RZ ztxJ*c`g?2tU^5bO9im~bcm89K4`Ah*W4+P%_Rd^M>o+o;mRrwDtW3BTa=U(CK6_Gd zXQINiIap5h@D8rON?Nt0*6$&LBHUY(rsB;8!{d|ST1>u-Uu(;bUg%Jhdj-sXiGoP= zs7M$5w2FEExXILN@5zGs8IsT%xS2zTWc_g|y zT>sW=4d})eb_BRQu3e^BnM;A_Uw=?A+)=&~SN{tbgxlcZ4)lHy_ctBJtQElXreK%) zuNZ&{bKvdnXg!K9}J)?)t&Xnz62@Bfc=%KvU= zuB(^8;6Nad_BY+VE-a;XWAO>a_W2|H6|LD8akP lrMFtCe*vQ_kocS}YQpeS@)N-{AbF8MT}4~D_JL)@e*oF`?ymp< literal 0 HcmV?d00001 diff --git a/docs/static/microbenchmarks_code_flow_diagram.png b/docs/static/microbenchmarks_code_flow_diagram.png new file mode 100644 index 0000000000000000000000000000000000000000..9cd72b707669f2811f2f91207db0b46b7f3c4c39 GIT binary patch literal 31113 zcmeFZ2VB(IvL{Ybg2YD2si8>|I>;y-g{^7yZgVpZ{P0c&5XjYPOhp`r|Ns^OoWb>3JDP%5f&B}i7H%C z7Yhq}2K*r-I1MOXcsIU-AJ`tcD)Lw*y$nlOSWIY7C1Xz)Ut0%f8!UDqg`dB%3-a5$ zdU&!6DY6R+B2XwkYX^j#JHo|-4{75GNT3{HXM;Scpy}%G;OvZG7gQGE;|E1tng}Zg zq^GNgJ-d)R_*O-F+Bk!sfDHcD)&qYTgFgcN*9Anbi*bWrw_km__^P-g zQ1+Ux)<2kJ?RzpIz@MLUviAE)5fKsoNw9POHO|Sf7=oUI|F5N-w6pVau(t6y>FsAr zPghrGPY2ZB60Ka3NE@pYS)TNWaCdk0`P*u?uFfZ`JE?;L3;&zmF$BGTU@3-b&pe)91bH{D%bL4Usq z?f2qem0WxkY@9I~`HPPy;{e6^pKtpIH{Ac&5dThha3@}SOH%_EH&FVxj?`r3YL^vz`NtJhZ^+Hfn)%*wD|7Y^@{$1Zhoq^=!tsT7C1x4*Jzk;7X zNS1$3uyy^xq(7BY)XfXCE#)xrkt%Ekj{o2Sh$WyI*>KNtTm<@!yt{?O%Ly8BC) z1qAp6MSiNc&!6U75Yq)_DfYjaZZR>wlbXL2|MwmCuW0@M1`Ymo#{Z{UNN@zW`x&2@ z%O4*m;DaBqM@at1ry@pJ81wt{3&U*8*H51Q|FWznCi8dG`p>rXf1+vq-!AySJVL~N 
z^73y?*Wc~E|Jq08UkLOs7kZ-L|1SyjuNC)%A%7#$pGNlIw&B05^&e-3j=QUsjfV#? zaDSM0@H=pkZ5`}>RsO3p#?l>>{x}u?Xm{P+3pm#9|2LlrLSiS@{*Oc9?|I^*_f3aCG@UU^$wsgeAq#*7`Sc16t zN5t@J{{bH#fDQas_6T>6lW5|X1L@^y%PS7*{hdGgt0maxzxiOlYT3Gi;KiEP%GKG` zofl#4=!HQ(F#ajh6$t_ea5w`40Lp$6-R&&7fOcariC;gse^&b8-P$5t9Gt;+!L*0l zID6Z8I#?loEB_1cm;Xn@6W{RXxb7|puxQN4KgzrTs(}Emv$KO8!qdwg6V?CIdR7Qb zjQ{sloI!AB9eQYe996Wh{ z%*4af-POs4_r(1C&6q!2=pTImnPEzhVCY|5v_x1r+5K?2e~S5M4U8xSMZnI&2=n(} zP6Ya6NNXD_S9cIgx*~Z!?H#O~keK7)#{dpU2TzAzBm8~c-^BS(>;0x#zt{VvoBw#K z{5*^J{6JLxXQ*NMtKLuN$Jqht^xu2tzxT|4@0lO?!S#P-&-}sy|Hr5AzZzSgfQ!Ec z`v7eOs4M0y`A-e>|5>;%^vgv%aqR?t1@^zWb^^cZ{O;O`{+q7duSo$f`d@^sK+WG= zf`2Xr%niohfxCYxP@$iO@^=~iiIgh-iH$1#N&WZn(;qQE1|9eV05HLP@QM8|fd~F_ zoqrKb{=e!r{%eB&$6!*$21IEHPt5I-f}SDx+s6UiKVTqa4=-m=5DNZ%nDBpQGh;OS z4?Oq})J{O?WILX0W1(N)zi(sVlREz^xADJ)*!`udp5PgO6QTVK^MS{T1OoYUv;Ojq zemc(oF`oQqRR4!i0iq19H-1JDAkI*>!(74vVJm~%CwB+m-+Zf|qyK`L|0^ql(BEOW zf2j<=?Dv1RGW<}O|H1X#i887Dc@g@%mH$!t>+1Dq!AS$mt@TgHJE!OU4J<4cELBB0 zeP8pntW%Y+n*O^swwQl5($mt`QZ1!y=e^#%As1>ZV-yl9N=r-2O0~2NIR3(YtI<9H zzpkL(Zb-;zVRdZkYhC^T?-QhfDldtM{o4m}T1o^2*iaU{T=@?#57y~5vV(B(Nq3Lu zHd?QXe!JGFwYxFhH2>W3u;F^s&0uM>jj2McX#LcOis*PpWlKVA=-c}QA$SgLbVqNF zvS8J7$BBwGcZP|$N7&Y3`-57a$2N~Euc?n1S}(uCj$2HE*HTv&$q<2NhImm)Ec^#P zyzU~FM1)n#)=HNZK5FeUD}^X(PCt9Ua4A?FmWeOjcmXd&R+~*H)>%PUV(&w+v~r_{ zaJL);B63X1m_%M+&SD>r(#|0FGAAGi;vt~%qj94_Y}P1MuukxVlO8!2bEW6pjjxap zDXxLM%V3?U-G(}Z+$Iml!;b7Fm6#*Hcm;mCk^G@jZm7&EZ<7>lSP)LkJ zo+=e!jAW=RyiU26;E|_$r~W|c%v%W}#=NIWuUKcqJ5%x?IPwoj2rL?UQq4SQ-13tq zux_?!rZVx$S2tO4+Qg$;aOH*pm&h;TROMb%s@qzAX3(AB^i_be9A0a{yCMw+I3Ijl zsUf*qztW6$_#CS}4ST0@7we3?evukxq9oW9uFiB33fVBIkH1w4Ts+1N$_)y*2!c>q z?_?MU=v?JKc$bPC~3mUz{v1+gI}LS<>Ml)-CVmh}|C3@Yzng_1!Gw10oX)eJ zry`^-h7NNAp(6;$527B6yM%WM!fmqX+3-oACPR^xns#Q_?J9fSyD z6dfM)3pa*k5?M;az>repUEgTK;kAsWOH-l)3H$^$ZcGY_O6W$ZmqQ*B6~1nDGS|CxGQQK4*`Iqdec^Kpnao~u;LflS@?!+2_iU7qR=k1T)0N1J zFE@>&L`IB}1FD=o+ej(qn}<8^rY2mQ)*fg3JBoeoVyk~=9?+vH+^>k7PqRZ7(0EPw zAMGzpXI`vQCKDdaYc{Lme6(A|zK7b{FWtnV_O8{O$}MZv>DVjbFV% zkBw6MbBV7Z#fU~T~aFDC&0od4b5CW^x>J!M`0o_Y!F{v5M%>r;kzJc?sy5dj4AO7_8hE^*k5rr6X_wB&N(`iPn9j>D4l*U3TN znn`!_p^c{F_qPW91~tX}9%(<_ENvzANYoIF(4K}wL|;DILvFQG`z&O7#2g%cot~;& zDOstWKuU>y3}^c$X&$&^;PW8}zm7BK#z2AQlYm4`(XA)fBSXGU-@5e`edgjGf#c&e zTlY>zpUe!8UWq{cszEI&zZbSyUWPs`r>JY=mFc?Is@|EuJYtKzG|+f_bnx?)k@fDB&$J(_)Tbo9Nif8!4I$f1{l?3Q zEVmBYBfoFj6VXJ692uxj_l+cVCn{^^)14&w8VhcG=@COV9qxMgOuB^~eE-rn=0?_C^O+ic*BTJUe6u!SDL-X$`PuXJLA2O=-$&2!{m}A^JowHp_vev~ zJ0oA3B)8uxOtq3-UKJg+z^xgzjI`;HxkzZo_?mS_kmug=SsZy5N*oiC6a(baqG>cD zd?CYCMw`QD%Ilp)Y^g2L!}4S)pz$L~JB5#8+`)3|4;CNAPA2JTNwf()E2(?R*60Y! zykSwVMV%3Tl;hIj0_W!%6SgAjKetq$4t#j>6s5wVh8hkSM zZ5We(m&&*Lwd#ykhX>Pf1rluJqKey-9(_BAuV5O(7POngCOvBxF}_0` zNs!@4a7!~QA~J7p4D9z6y>6OAMxV)kN($?$AzgL9(eQZ|E`m5XI5nAX?zNDJO`-zZ zN`W{Igt?XsbxRc^;_yBJ^Zw;l;y~yk6ro#1)kvo*v(-k4OhtrlDc&Z&ese$nRv+=9 zOU=~$r}D0MqFm534v1WCf^&ievb9$&5jzzx9hWT3C`mg{yF;L27ZmH~i@*RENds>H zYb&ECE8P89$oi`Hn9Wp=*!1eNV2z%!2iKF49E=TpG6#Eh8d|ItU1hX{3cHK=jL7OS z+n1ya!tb9ep`ppZI2T_RZFWGri36*0329Q+r;hGQvbl5tf)th92>>imQ(v@lc;yyWhkQ6GfD`Fz%vobA7Y)UjOXvgqTr-} zac^5A)>1^K-;M=sqCC08MZ}@q*FK7KTb~(7*j%mXTdz2qa@AdRn$J;pO0jnbK3e+d zxwTZ6ektGOzDr+*U^BL(3+#yeN9tnKjH@Dq9sgBWMPF);1X7;7g^?%4lbQOI~``f=fqxvo|$su0UO0a|qs3>LcICPcdK4q7ev!ByA05~_P>KTH>$3Z}N! 
z5H`I4;mZiASa;QdU9A+k5(wN1Yh{*aGv%kz(I(^q$~lGi-Frmd?VY`#oMIoP&viP> zK3cv_xCM_=m>r?}a&<3X77_;yadqlYQbrFcwW(faIr}n*gZB-oCDAb_uQLoB6l+ol z-OjuUs73P)R0aTOU&^t#NIB}uCjd$(IqL6mQzdg*`K)i?9<7lHGwBaCCB=R0y#*2%PT zIWfq7G|8$y+$6D$hE#g{?fR2~?9UG71RX>6(L`P|#+$15&*eImK6=ZSgH4fT`cN7*qj!OA zh8#aKWbBQS>GVj53vXs$Ra}aur}O9BO|iO1jI5U1@{$vtwt)+jJZ7CTGa-}>!SZFA z7T@cs2*p}42Z;%pvT{7Ce?)Ny6E#3^iZdJOvl)BbQu6Lepxu+MDAQ`34o{li>e&$y zC9k}jRoRFivnm-%Q;_9tnnTK0KKfXqD?{YQ%N27}%c(}(a|y_D&5=IUrV3>eLTZEZ zOh(NNlSdZus4MBZ(m;FWlyBS9u=cupSMs6Y_);0K;C=5dSd)&jj4YuM&>NqXphUL8 z_ie;jj!g6VV!T)b8Sn?X73Y-h7r`<`h=A{!&ppfr97!mFwRC_H>?*-(zy2K=RbG$H z)7SRRnJ!v4cW~bLxes*3cQl+&?F9HdOY4!xEa#NU5I%3t9k`6Xf(8 z-0_pj8ulZe7)(c*Z|}SWp6^{Y7FIq6T?1nXjvs+>JW5sVN)=}MDzb{1%nHLgd^moD z#;~g;H}&g(I`0$os&icKGusd$=@`idf2=|Q0+I2rBF;1u6HJwP=x z1b@eB(SMOqBS0Hv%CCrah#&9(lTr-2o%@k%=q;r@&HiT}6B8N`uMx9*=@NJ{uj>cpt=aTJ?H+t5D`4!>MowglNWq!HQqdIu*_Q} zl*MvA9`)qB!6z_)ClU5R-jg?84O%;zYgA!av%ymN`xBYfx!#xT;&qo6dF+==fQ1HO z7OlxarE0&MjlR*-2Z$l3>2Kb5Zk!`bd1`CG2ZW|UFj}C=I+t2GAY)?5_P%uCUdV(P z=}-dTAO#CKM392~*0-G%qtTf&6rN=ZN~LEjcr`-`%)A7bi}S&v`pAs1S@8U+gW}@k zg5V|&D~>T?o3bP;a_3?s&k-7VG^jda&IwtT$swC7Aj3# znI8O$UZ6@HCHBPVIpr?3>psfwX0&YCCf`cj<89A8#U78Ms@yID^O26eo&Ow)!xkKe zRau?s>T4WAe^c`8$I0ludm&$&@om0u@QiLeq5j!nUt{D!hMIK#MYj$AT6 ztxo%m-#fe(k6NWmxB)n*4+)BU!0TQeW4t}-a-}lw9!SaD zcJfwAvM(&6xnlU=}(=IIJ_e7Ci<};NsLb{7rt) z81WR*TT5y%h8!(eJZL8W;4GlzH>W?-{UdIae;`Q^9x^+4A1*GF1!n~;50T{t6!jnx zqkPtxtD8$8vII0Qa&p4Xf(!Ae8orj&R|Q_Qf+8 z+S6dLNSSL9U~&;K04>N8B;C0UuRUw*VQ?POdn(5q1l>Rg_}E}hbr3K)N>vcI0`K$` zpm3MYKh=k2o{?$f!O#jx0ekj5+5e{*MvjL8Mtn9~@V^eCG#M62SF!_0!CBh$q$ftlS zw|XjHK`{Vkbw0#bh5l$j5@ao0#8oR?C}hL~=_O7x2SWkwkOlG68J=Z`Jp#;LlurPf z0+_A9@qHUW3K48JE;k3^GTzc#-N?q4Wn7cQ zTiv8Qkh;W^Z)lg1P^x(kmFvmqxEx2`6(sBXQKy11ta33b}3Xx&%ji*lCl*H7^!eb$eZP=5^M){BvaLS%0j}h9_6u71_TOw z-a76XD^h6_Q8iA?JHQY1H)g^a$I@dqb3W7dEK1IAR$HVq0?Rw-!|wZl4MvS-DBL&^ zrgci~=`vKub->^`XFHr{LYFw^X?`8SP&r|6d?|;j*Zm9Kvx>5c->)|st7%c7GH3O5 ztH^@#G#A+r6sUuw2g&|tg3LKBDNwG_K0DPCIGe?55u30g)d2=!X(--j7ILk+OFQ~o z*Tz!LJ8BB1MBH|{VH-G_$j^*g;prEI8cj9E$ciud6P}sKdu#;>P+eJeqQiS5!0gz3 z8~aFJwRku$NVAV1XoO<1-QM^LYGqc*)msclR7?>)RcyDz$BO44ALrk{6>i_9niNh^ zx;m`K|6~;nxoJM7}#*t|v5~T<_Y@kI1Nh}nYWwPa)ueiEK+wO>71k?tC zC50Cp0n%NN7wrp zhQNo|Uu5#x2{pj$>a*!$C<_EhBG$-`XR7?*9sJ3=Zxf%J-$qJDfIZAPT%zRzIT??$ zGkLs(2~1;oNdscUa6)YlD`;~S0#s4llKBwX=kdDFcZ}Gs5?BLEk$H+&`T`!2?9Xn^ zn#r%6lMVshh9wo3bOimZD;9HwgZnAQV4P1+$h;Q(lUn0+`l;jq6&?}i4B6%HZv(gc zRm}GUIvECg_*;ns8T~LgQ|#rV#pBC|+k^E0r+F<6aMpDYR=NX3B8T28N??4zfzA(n zOuf5)>jz|YQHAr>h35?)0gh!9+aAsqGfGHtBLhHC*`1c<-5lz7l`cp=y19J}0Zj`qXJ<{1`+FV!%u6cu6ab-?7Y=zU-sl=Pnwmnt^!J}u z(_M3I+_^-0yw`bo@-)5kTm6_K$=0oO_gCD3*SQ;Sy4X_bIzEjex)hh+RFLiebOPXT zp^L+ok-YQXQ(i_j9du6m4?264ZmgDdHZ-LI3ZqfaJwS)V4RZ8KaZr;m1PJTKn=el~_hG!i{3b+B47r6JTamxY&^!m9oR7WKea z!dD!KwARcb=+rlBF^CLU-IUSwz)#iFHv_^tCrVf$N6Z5zs^ZUjPm?k%e*3zy(ZPJ{ zh7W)z^I5-nM+a+FCf-ewgkJA=v0Zu|5AvdBqs4a9axoxQBZ#TZi!dNapEwn-ozw~? 
z;W617jp`o26MSy%zrH>R;DxmqnQ!)DoeUnn6H4AbebQffJPVVQ(T}1<(+>c85m@;G z^dgmUb=QM!MpJaG?eJjw*c-%vC?)RkTQ>`XHMqXT>h9(4jTPm2sguJ5T_G~raysAfp(mv*}P(~}RDdsAE z%xin%I_^-gvk@mK2~>M3TI~Je&Al1?nl^~KI6?S!Pbc%O)#`D_3o63u2fk`uqz~5s z#u#|z@wE(ktO9BG2qieYer=d(s82C2X+}|O#y>>Qc1z9_D`nWzn z+4NP74+f}=ijiFA@>$68sq^TQe)rn185@6F=}X=h01aNxk$iZYo+eTLN!WCnUABZV zEUMwE%;i^Hk9_+)hpwnF61VEu@j!*qK|?z?57rj4eLoV0t*a_sckgDG5z-}_@OR%` zdDr*a1MJ*`B=&FP0Mz|9pwQKxx<8n6G{^UH&ypEIF7c2o*M~?Q>tLg)DJ4vp#mb|i zM|4HKba{%M==t4TLT;M_4p;&8mA?#oYf%z|39h@0m`v_@5VZ1OmovQubeb29f0p_LQpw#ej!MfVF%rhn}%I zz<^Of?e4F67do|ziLb-aib2#@Q9&EjF`az%s~?Fg1Uu`M(D+Qg7C9ua06#P#XYGsm z^w$$?;o%?bM@;N?z994Bl|Ors76G$``U*_Mfs9~59CK7@{HV0^1O`}LxYVc9b@S%e zDIX*N#y)#fNLMh>D5C);nIg4TIr#m{n|4zZw6gaO*bMgKq=h*$mh}CUSB0lM2mPcX z4m*++8n$yjsR`zPXlCw!R^OFZB<^fx9Q_0^rjdyP45G{E-UWS9VCDd8cy<1nuTosL z@&m7T*WLj%-9@J6s2P`RFXo+^${;tE8WpR>bjg{)tj&sqaAjz{qx+-dX$7UTwA z>f;YLEX0n!y=(GcNS7CFo53YpYls>ca1%(N!ty z(ww^qD(}CfOtWgDgg8)YkHc}vqP%FAPl@RPw63j1G3@H`_l3&=Y5azYZbPRY zcVQv!;F~Xl%Zd%LvzuY=;Mn{o1KUzSF3%l;3y*;;U|f&(n~o1RWHNM~4YAGmi(d^` zt9n1fD%ecHrsf*L+ob-g+c?!eH=;pYk!~m~CZ2lMr~RDK_RNdf7%BgSGj^T2Dn$xH zp{W{*Z=+N@pp>DXKc!s%CM+Rd!4Uz)!}Yu;fVlBtiTX-;oG!(hS6vLxA&HRfn=?^a zN(S4wc9e``-#bbZmu~ZP>4=o$TBm}&E}D8g)RM`uEJxMB`Xzi{qlMdgDY`4d4L6?u zVk;zp{|bK$%PeJB>OPqtg=dn)d#FDFVFC7mtiJkLm4`Vzr)UTja^GBbraLud@j8`x z)b888_0ngnxR&OY8@zD>@*a(DI^+yZR&UMfD0Hw>S_r$pc*s`U$DQTC4Z-ee6NH3*2%|@eYri*hZLHFknAO}8 zB2v`w$KQvWY*K`}B=i#H_r)i^LX+}JsGD~qv1@kkvY;~5y`Il=2m2A?q_o>HE0)3)_X=&&;(;SLcPx|Jfvl|2!;HldO6e- z9E$4<6K}jY?9DhuDOQF)L`+ty*D-&* zt4K3f^l-)?494&(>XIsC-mc5l>6)ta(7lkA8J)-w0o&+O_VU{yt&##n^LK^((CBdy zEp3Q4s?%4GKVe?(Q33FPZ-f&fweUNp26A@Ed+DBRv#Nm;H@8viU3`9%GJ26@An(KX zkkO3(YGQ)EyfR$tM>HzM%Q;)&hiKanZi6WP4jOce1mxDJA#kXJSsvYxq zW5kwN-VuPgMM&T|BgBt3qGFDL`b=2if5fo zo^{7!k3OFx+5>*IXjssGGX~10*4TH;Su-_)ct+MSQCsE(Cl4_;Eb2QkhxLr>Kq2w@ z)2zp?is-ml68i&UqOjW{?7A{Va#Wh=JcC8D(~({AEw+s_BLFfJMUIkszoI~m4;gGq%F+3H#EQZEjuaI7^Uc{4(IGCvXmG0@< ztu>RDwW%ocsiGB&@_ZB^wny=6x4XD`#>ys>d7@X(uPhq6O3OTW-vE2_8TzeVOK`W8 zL|jvx0Wg#)f!0l((PYlhNt5O-;>H#7Y<|U8c@JfaR=Kq3>~2ycAFwL5wLRdMFkUG% z0CyL8*#fo6b?o`-PpN5HY#0a!_=%0#u!(O9K{(Fc6(d7XPEp+dC^u235D$yOZk8x5 zWV{%?V>IA;Huq-ktm)p}94>_OC`m-7s9VC`Cm`K2<MX29u$kJHK2$=fYP9^(5RB)_+0d{UTjH|d^ZMdFV2lIt^AW(Vn9|K%^? 
zJZEP~$}J8Pc2q`gxtR+rS6@YsT7DPA_j@nOKFo%=sa2`|I8HuEN#9`C_o5(#qe?S0 zQX+h8VkI<{LHRK*Pjgy+{#?1SQ2lrmnX{E)Hhiz5Pv#4rMvJ)) zkJNx0#MX|1i2|4Md7JL-`9Qj6RY*s$g%GX-(Iu#v*+Y-KL{*>C586a*X-w3JGL3OJ%STi(s12~@gS<*PjTray8cgYIxUV+LziKO-TlMRB$* zP8}YC1EoV6wd=Gcgc+TEA=N$1-JX?=>uS|l%UrMMR&Qs)B&M@#tBV zH9w?`hddEjFaEIo@dKpFvf%BtTC0(P%I=$^F9>PeGyUedPEEADJlpkjbqfi5BU$6J zC!7E?wSI9SEylKoF2H`tzWj{%it>Z%A;s&9^daGl)jfrVm(w^sOf_qbi|Pe`BEwRIK+B}uW%JkpJqX~_`A_MahkyW_8uc> zZkr1<>w~iyJ?s0BFE>ynDq3l3N4rsVkGCTtLRoS0)s+HkX%_qN8Y<-|se{I~F23^P z9XKIwubpTTQqq9u3?X%phmcg9(J>x@vvXU==+^}fF{|{(3yZN>EH7#vm>R z2P`0iP5v{8b-f zzCU1mI?d-bn%0+PT4M?v&O-;C%+$mx)I2wHy9bMGh2``LM^goIH~Z#w9-JO{1c@0# zbh=HLO&NsB%&g*wYGORDs=K<<*6h_t@OaqtPlXgQo^h zN!UhDq>GD0R;yC7QHm|2bnwjvR1e(?aU8DY8fDLf@YmO-cO~higlqZ5>Y>H&-fd+s z%%FVg%QRMdjnIz=5~?Hdk!1fxLYc|EbdT(P@$5xHi5L4JlH&QNUx+_T*Sppnl9-E0 zbGo%DSuTRENY1Qm*@_-&6?4mUY*ShCo-%o~+t2Ji0QBu6Ic`kk9D!sqGhKH=CS(z7 zRRWn5JNp90z zAI#iqyVNdIoq7xg!yP)W#X9jxF?@yqwo4tfkewnj>BZ8UhVS!QeH;Zc+^ zs~?b20--6YmMV#E!Z2&nCHEMi(9B3!twUIoz(#!tNO|08GQ~k`u0U!Ui)%CKL9+aakbNblryi1E+L|SpY(-yAX z+Z3o4_N&R16`J8OAy6`SjAcrb#|URUad0=O*uqzDM?x{YnRD`Ja-k1;zR6Bi;tOeo zDfDhEa^V?YR3ugst7{ji%h~HpeP`|-))RYka3sdH1gCUfaVC$?ihb2NUbHX{6K2yM zPx0r0R4i>zGNJYeF!LdO9p)7d|UJF*dQ?Q*u`!Hrmtas{}oh zGWwxFGq>+tyvmK~H6l0z`?uJzH!n&jL+``3g=we{H40RpEFK9zdyY|Xz-aOGvjb>D z@)=#!hctTY$<&4DRfT%zGt}d6slIfS2V1XpRa3Ok6NdTZEtK(k!#L##MHB6XNC?F$ z9#yU4n)G1wS6h5NDumFYw&1p<7$d^NEq%`~3ZuClPAvfsN9S6aFM@nRne;gy?@h6O zzPwYTisEr7aV#Bz_mC^p!T6A7Cv~bXA6Tr^FAC6Zg-7hNQ0HD%zlG|@4(s%|$@4`2 zW~XAcrcGD-sdI67q*%v~LtqRLUdr}t=)Jph0|`nAArfMe4O{r96x3)Q#D%P#uk!az zC1^rM@3S&vkzPGRy`?jdfRfHEe3l*9tm#Ik8ci7fh!iPB&RnU_%Exc#Xj)+E(nrJJ@5sMQ0eSYg1j;C;jFsz^A z3bj*P96trUgS{?%2QDtW4BWEXi5>vDDUYi{s+}N9_$p~`SW@L7FpVMQ94^r-;apqI z(ZYi?uVI8??w@Z;LqZGbXChj02s1MCUc%M;nG+}GTHP$BXib)@`s$0)CQi{KrqjFc zq-4SnWMMeO1$Q zWSK(_>1J`J?1=jGXH40O_)x-N+Lh0F-Op=3-g~(f+uH0?e8XqbKH;Ith%@>bD^%;7 z+P9*3q@S27+WFFq?S;8K^0d;lSXDSekA9^`T(lB*^1y9ySlI1>R$VCfmX$pSq|fWM30_#dM~g<39WO zRXcOYW9A-bSvm;m*T7UOElE21bI&k#h}JmIc+|7?H0y}?SZcc|4+V`>EH?EWOBX!c z*QUIeL4ZfjAaxKPW&F&Bk__)G^38cx$h?;IC9Cyfn?bxKw{a`pbF^bD_hB`gQi!fQ zRm+*;EV+LgvfwY-_<8EE|AkdiFMiil!Zc~f*ys;u>j=pt<_ z69fr6^W^3wcT@`(NM(A~ibC4waGh=#BT>b+t5@z_$8!ZoF88-AZ2_#%`Da+%K!7 ztwrZJ1lQ&VICQLoAy@2l(>2=m4OR8eLL({(FpZ?s@svBdm+5&~P((1YD}we;n&rar zn-`nDf3&e3C_j_l-EvwmoDe?=mxMt+ znFLH}nV_I`B#M5fo7-Gfq9s0SNwLDP3&s!S2C04Hoe$!laY?Dw;VKti7u_shqB9f#Jj^9MWRrkP9(U~i65)51!2@wq69OaHi`3W>A45!sB*7} z@{t?dH?hZc4_(Ek_1jO}O>nhtRp9(&`n=S<_iK=!+6CW)v=4fmk5TvN%1o6Rhh%QKwOXOmH3J?!;I@z9rRj38Orb#KwgCSP5?k6d=0 zz^|v%xZ4bJ$d@l)w>eDVF)MF#rc^{f(Ju-Mk#!^Rq~|UZ?StSJ2o1B&B3E49*yKbA zgTnbxaf0A>nB}cYo}w@V6Mi}jQ}{eeZ>K_vdEj*oMB1&a>`E722}RB3GY> zR*-WMf&1+%@TW}BI7m=s-q{IL?^o-?H;VM$h0MpBY!C%`gR~VS{wqfX0z3kcmi9Ex zh&N0wzBIIv+&%`}DN88(%M+ZvCtDNGPhEtdf}+(6ZAvw3H-Ul+k##Ot{x)S(IT*z+s<+DVoE zDA|NfHr+awD_TO4B;5lUC!bj3l9L8xj7lo^oi>th&Z)?Y5ClO%2TuAR8Hu%-BG%bf z&7;w>CXQOg^i{X(EQ_AreYxh<79HHv#55s8(ca>QkBuPu1d36z=~HbSFX7^rN@~JT z9W0zc*bM;Kld7#`QNJyD3R1BmWiKY6l4ZWGHtTc-K(I6D6-elW?c;Qnr6o=eErG#8 zW@RHQRjzo0_e|`krj@-4xuI{}@KR?-@wCzDT#WC+@D|=kpmZ1y>j*yuPCMIM zB5XD#F(nY&wZ~o*yJEBC+vXI&XWMbRn5$F90SUT8xZ z%7SRC(1h(Em6xtk50#ULGh?eau{tE5^&G80MuvQv>%5F)N4y{y@_6NxP)Z?iOQG$1 z*}(JbP44UlYw{275D+tnzkMEP+8rRPR09rJv%Q+#2l}W|^*Id!VUJfU38>L~pGx;l z?Ck|rbHmOkeE@GIyucNf-9xXZ4<-5 z^}dT+TM512u1liC_I6FV&c5NP;q~W>?-*_fU3LHbOjMWM6@u-7?h$aEPl^n=AQ|f4 z4mZ-5e$k$M(?g+<$dToN+qvmdOb%C~KEy#V|IU+gGS#O02G@q}I9+WoV8SGJsj|m& z=$%hJe-O8AW-_eATiao9W0~T=oBTJ&{%=#5C*1weFout$GJ79hPP20kf>%llAhY9o z{CoHtTY@fQQ3ZIDcJS?0Wx6EEgD7}y#a;n2t6SKj8S5BdP?U{ie;wQiBf;Y}=BPY9 
zqayrAiXdY#-}F3Z(yH?J^OS_0&mXnqPU}ocEMyKgf+UM|kg($(e46B`jyB9u2$=02 z@_~&()@I{`DR`veK@KHh4fB^|GvVi=CP7M_yhcUTj}$+$-9z_(l4d=kK787h4t`1W|y9C$Oh9*ccMLshTX$MJB z;#8WRu?*$(Zp*qbh63|+HFmbCv$<-QSq=?N1N(`dyPeB%R75il5KezJr4>h6OgFXW zgyjw&`0Z&lwo0ONzOKA`zHyFmy&w9iz}~P!@DyBn8}5FT(PM8*oqqV9MZ-$U*KqgI zo*XBHk7*D}aG%gBvE_&qKTtYBL!s0mypE|mjPfk41}R56L0wx=d@cyRh~nzzC^EuS zEj4@lOm|se?WU-_tj?V-$x6l{W*H;(mv2KiZ;bEOE8gezQh-fDoP(R%ESwc9i4`+= z%D@>EtXcWigj|A}A1-JJh4c~_&&7r*PP~D}P8kl}sY_#{klb%zW@@2-v8+a|4a=mk zq<409=SP=g8@6{379CV|No4AzYxIyiuYCnsxTwgPV%Q7^;K~T^ zZeD?^ahur@Uhb(}>(Y5LN>vJXk4YtPB%J>eVcK89@I{b9`?^o6SH&T>0moc6Y=(b| z$5r(3!^Mx?n?p>I9QO)qCYC`epYo$B1`nnBtbV0*L*5lzpJJo_$J=KK%@iN&>!#0K zTNipL_Dr2Cw9OQDb$#!Q6ZKX#*@~c<+IZ4a>&d$85RsNw47=?^#X>Gh==P{Q=PHU2-s*eT9Gu!;ST!VNOvwW-vt=K;@&%7(k3h0VJxE>534D%uzq$^{mxm7<3VqgT?`Zv|<#bTzxxc(U-R6W&jbP>`ks%YAiAd3b` zk(D|0XBoyYuK@eVIs(Nx%T0KI2mN?jO6OEG zgL$fqdGc%BZ#2^f@`0i;PgC7MmZafw9c{=elkKMYjjeVW_U&<}=Pxj+E(P99esdIG zL9S8e_R-XvqjIXfK%YX1W2xH2sF95 z|7nWK%z0+wA%D}?_W`b7N3CLCEEVLu2dQG-?+G8@0I5Ukm4jLym>fBv+|!e47mDYy zKE{Q!Axyx5A^{lTlL(J zQ_T%hd-77`^Ph~=Q(5@;~Xa^na4c#mZRgyk*s87k1|R|g|g?t>0}&y zLPF-Tk`ae+5Rt41he*g)D)qa4fA@cIzwY~Sz2Db$JzuX19_@}4`PTz4qW&G$Mjgk{ zY3OTnl%H~ZI{U-Qy46j9Yy4;TspGVDqxQ{b(WpPWoqs)^I!#(CF>#4$F!)tBO$OUK z-L-Mjpjf?i_o}OFR_)1N!X;KH?8zcnq`(jD(LQHCb7pGyhip{IUh$;e*qI)ZA2kC`7p z7@ks{1uanE+ZKs$OtV9r$=aw^JWpxcDM4)L(f($)Po%I88s^`J)!JH7w#H8Z$knNp zUB{W*_|dyAx<08fF)sg-h4W+PgQ+Cc24Fg^YKa-PWf5YImhC#F!J+9yNOnSXNho(2~Yu*aSxu9SCW>%3f6SUt>vz-*1FG=#UAN2K-r=3wQk`|BpQ?Gu+BahuP z23xGG$;AD;xkyDJ%Yw?Gj)SZS(jql0oW*JZ;87oP?P!NaCxRYD+vPH7_kZRXtbAGX zP1Kf!`%$4!SuVsL*>}e30(|t3&q1350RH&r%F&;hEz$q))JJ`K`n6m1#7GSMw|&65 zwgj;6JH}7rW0MHkfYVF~ffnSj)9np%trmLGMS7wr%2|doSydTpK@^5jPR;sadd2(K z%w-}u9DH6V!LRpzS8d0&+ZNwo<1IMJ{|Bz9CmIlZE9|Aje4bkB9`}N|t;Az7F&K%v zOiDfiQlcUqW?TCIh2>yW%(!9h)sLz1!I@u9xe0JqVm7N_SqXGGBKS%Rzgf1fgjb27 zJ4Pz{>r>1ba9X1PmhLQb^a1Ay6LUN5{S~`;482*=k=Yvf3xF=i>jq0`_PO9I&akK2 zD2b*$HguK_8^S0_MHXgFX$*Y`wckz#^kz2=|8Z8oj_3RcEKc)SXdg@~-HBqM@6R(d zHa}ns{~sOgbxvm)4{|bOb8pCs3l;o_ujD7&4_lHRuyFSaBnSze0%0ux-5%L0d?G)C z#d)r(8Jrnwc|J^iy`%^TwBPxVxL(I@Of2oTPp8c^0{AFs1QwDNTc6<$sg@|=4KSlj zTR>9YOaD5Ya(+-3JrbDgAh%_ARLQ|A;7SmT1`dXhg*+K6g|^r3?eB8u^6g zXJV`kr6!6oItT|ft(1-)Q?Gfuyxws&zb5jU#5+hOBddVwlEnG(BFrNyqNtr){L`ym zW49@ls7>5$?Mgt8vDzrS64$*S!mP`!8Sr34Y!eB$ zvSQIeQ`wT&m%DzK6uQsGUErHb>_QjpjsPC{sVT2e1?63eQ|e{8CW2+=J_JkowNdLF z!ofSAB^o`Afz2s$0_{ncVVVMaSt22Amc}!Nh!;7_*fe%MW4xyhdUtUfA z!aqLa9xuTt=>u`=L*yzGDH5%U$x@0gx`yAU`34fJgoX?!sGVT4EzgbWh0IEenn)>8 z=2aEHmT4t=2DX2|`SJ3Bd&V+nJy{{nZA9lNMsFH5egtq`;mHp~tK@*4#B9Y>k%Ll; zIi*=4J~5{&+ECFTG|NnHeWw)Kx1#*DwrZdCq5;K32+qXa%ivQ-97Lct6Pd2V^nf#4 z;wm5eF1aUzk7cEM<)sM)uO%ip3JO>NRCuVUIDpnmI0X{H^m{A!{6D+UFBvEcC0+Mf z2EslnWG?Nsc@_sUiQbU)yBgNmnA8q%X^e3GVJUeCyA<puRSi<%tIM2B(t6Fy|_$mxC@l-1nf( zGgDECJnH_$=GtyIqmbfCKhZ@i@Ztcy+fZGk`;ytRA#obH$;5lUtqUEnEC_7WVBQ0y zMS5J^#Ks{)I^-Ov`Ky3qzp$nOU9;#KU1V}k4Ch;^`Z;_HL5TB`CB+1b+Q*wYP5&VC z@KgxnIE{-+T?9$^^Kq(918H*GXprl;B8%N`*E0#3lG%ye)TF@N|A^_HN_EwRhg@u6 zJVE|3kxGp}uNH>1-=Jnfu3#>7KIj8CK2p~2s9MK^X#bAn;G@Gp!i+9XZ8Wo}J|>oM z)F*O#^Pb$43|=F9Xwsz%+|%n`7M=aPAgvDlKsh%%Q?W>S!a>W=W1O&F-3+rVyk|^l z8k3n7bU5h$^U^l)Vy<3Qg+qN+-ibI3Ch|6mi&yryv6+Ik-GawHl+n~x`hxZncY@Gw zgF95=3!_UDxWG!Pxxu5fhffI|#Nlb?#B{zEs);*S5^vJJ4b%SErdV(412@(*OOAJ) z5(@TbNc5ZV*aWAW%EB|On<9S8(>$AA9e?2-SR=}`BIMjaA&0~R*Y%X6nP`t!u2ZbO z_~+#~;)8V!kSUadkIRqcS4i7<2!=yjPn+WKSsXs(UNRplhB53NR`drGHDlG_jY&$Kv3Tul<3hjB0EQDIe84Yb`T|xnAVtMt60j^1@3X+Lp z0XK6yhNR>+!7D|Z+>HGY>^|BRwYEQaz1I04POn)Y{*J~k36J!7`Csr}a^H;u_;uos zOwX=wxxUaVcL)>I&ndCXh$U<~chmes#zC?IZ0TqNCr6Ccx6dwx2Z>yh0o8D3koK^t 
zrcJE$9y3qBqV4+3z!L2AO~iBLcJU7lCYyiLD)dYSV3AiClPgPX7zYKT#*zF^@nmT*8Q999xaM5*7A8r0^4|n&_+RGzqVe z5}ud7tF=W}w70`4w=lAI@{l(UT=~--5iTrRWFAbCONeZI$v38_^<6ET&GobHL=W}b zm$2(2rrv9vNo2C(ShnJS0zp?IKO0z0{DzRaFMYosI6_svFFUH<-Be>xYQe?0L_3S! zqfJP2om;vi$3?6AopG#KxV(%ypqS-E#`X9yYP2RxN@-ZdXH&w8K7_okAk_O5=kKf& zI|&{f*+2OVa8%8$y40QT!6Z-YC$?JQfP|Xlui^>B!1qa_6V`;era-oh)jy2#QA_Wi zg%XrR|Cr+g)tzknT_N`|T)V-ZGKk&4IG1Wap5q;o(YS^m(MS#TH&cd+B0P>#W|1aw z80kRHxWR{g)R$BrLD1B;KZTDcHM$diKkV7y_=LG4Ulz(*RZ zJYG*4;#Y2A&6Tim--s%09jE$zU1yk}EYgnZ)~zgN@k5l)bKoAnz%Iuw^A0 z57YCVQTzP0`Co+;$eT?yo(y&63o>l-Y*gO~>#HimuP^kQD{)KcrE16a$1&kAt{BGH z1reak^4=~FXEdmgWtOp^65-MzVd)=W>jY67YEv|xr(d4-hsNmmEjJW(DR0@$;+#Lj zm{Q0&*z$p{KckR5+i1{&`rR=#_lCts?7uoB8b^Julj1!Dv=*-}GV0?=w75;4hb4c5 zCgc(GR_aQ@V;pO=Ppx~#-${jWPL#rEUNY^D{w!qu^3sD%u{EzxCd#41x8+(jzC8`8 zG$NtmKh(IqWnEH1UcU6{ZbO-X`~Y-uvFcXgkmF$Jj*JoIf66Ea=AN0`v)UFmW)fn4 zWkfO_xn9%tmNjf1$_kbaxfUSEia_Vk4ZC}g6*F54o?lQvZcB{*emIj2ue~4fCX6xp z_1Z^yIo9psxW!iP9F4aSkCC7y`nV$GLXMxYv#hiM77{Qtx5{F|at?$9t$uebwl-IK zXU;c*Bpcs0E*K7y&wloR?YAZ>Sq<)*DB23OAMeI1E8*wvwlKuhv?fY;fc*gxc1;?bH8O@d?zmYcqc1TWah~H^w|M$reyyG6pTM z&pet=7qHPDHi*HnJ%(_%@UmuC(jX4F8I1h+)GEu0XnDc?H;$wRtef4(<}Gt0d1qGZ zQu|*F_HK!WUG#1+id5R= zi<94>ree3G-9%O4t|@BAI_k1xQXBmPhQFT3+5&A2_ouV~Ug^CcKV0ySd((Q~`k@lw z{ye3MAT}TGfdb1cvjw9gG)@J?C0dG)6`?w1!~$7ym$u7+3zk~D8}qt)F8z1wQO8;S z1Do{N6mz35HdkJ5i?Awp3lRvw8!PRa3C%dLzy~#X3uW{9?24ck?^IJxXKm{6rF#^h zm~JMnYP1!wD&!ic=`=nbdG15Snh#f#s%R}%hI@gwa6_D!nk1a3&j|48p%%s2in zU+wE-n^#`>^PqUnw6`?5p|mW(n%fGx4Vov{n8D8`sluc$do);qV1`VSzV_hNg(`mV zlQ=kLh0}_rpKTbAm)&bRbV!UH$Vq%!uGWf_NsXn3W_~P|E50^aNDa&+drm1e)Ai^7 za}>@?Bpzt^?3teQbh{V{l`=i!Jv77u^WwJA5!dX^&xf*bjOOdixzY|J(yMR0FUoJg zF*|e>`&^gV@u(4o80@5sW5Alu%?s`Ot-~5NXb=<6WfNjStpfP|#{$t9*d$#;OQm{c zwEjyqI?9(1!fdhqs~3uKiIu!)ULixQ_NAD><3CKW_=806Z20gmWZOMN!5y6C-gV4v ztBZH*T$2wSiK&Td)Yv6K{W{e9%?%cG&WUJa_dZ`J=7a8U9lLKZ2GhJe6}h<3B;n&b z751+S_!P8W&1n0cExwtcewg({IpKvEtC4SMlC=9@AUsl4;ihD(ebHJ9Pft>(5oPn{ z4B*;(;N%*4+Rr(&OW7I7hj<2-{q(;Wuucf5kOfvL9D!R04&xWEj`C;dzZ}EvW|jBl z@Opc+XbT}yb>LnL^<`U<`Le*Egx2}a_H>+s*pGNRXq02zZ(QaI{erpL-!d>@S^_zmy66I1;T6eby^X*?x2Nt5vb=M zFVC*Y{`zuUl}l(0P>HA+hx>~HcX@ug!f)Q<^rhVw{12}1(!b;JR=s%Z+U?voEu&BH z1q2lz-%yDnU`L_LpWmb(Y)Ub?Uz6Q*a#35nXZ)HiUU9I7R{D}6vDKK0#AIP_{&(bP zByoMA;lhiiyxAY;7Wf*Ba_=Z*%C0+OTZe=@U`Sr1)%gk_oU&N0)>KpUB!2H6Cty*H z{?Z^Sv*lnnswHvo7thkGf-WucmE_Uv6?<oaxD_+N+< z>UD&c;Ypm5)DOb+&zxKut`MCjs+-DN+zcc<+s2>3WtGxBzItUFjaaL-AKU?gTT&f+ zs*`WO#dOAd*=E-bp#J*Y%#M8%@U#s^_<1y?;kp`n`WL1Iky+%PWl*tl^NX`tUb<(Vw&cYDmmi z0p=?w#V8x0WH>n+gqan>5i3Err?U&;t_)qF*2+tXeSDexdP|fEJZ|Le@=M_%qb{xl z*SU*^^4x2(s#6V!3Eq&wRx9zyyjUX_L+)u7`x}A8_1!mKChgTrWi2!liAHXsiS_Jj zonLGXUh`iEQeyp5OB4+!WBqi;vP7VpX;=u|)OCl`2otEwQOuh(zQDOJjcYxsxz9Lu zn@+TB(n5jI9YK*LItgMH`Ay{A@$+cg_zV4KKD>N#sjJvslJjz!u&XzNcL0CB=u-6g z-}e+|K_(QroamL(h67cbyCO&gBH1EKk$)=fvk!ZZVJ%CHE@K{VNXtx;X{EXdMcMS# zQA6PRDCy}95mVR_y^5IW&|&*h4#G(%Lxsd=-omeV<*JFLq)RUc3tEeJM78Z#t@%iR zoQT^=mB*!sCxLzzXEFnmJfDeaAedcHGye2aZ_SmVMj6WY7e+o{|LPu?T9jQBjt1G3 zSH|hD0}0IMK(;97p5vm%v*G@x(>FTO)f&XnSzrE%>1= literal 0 HcmV?d00001 From cfe268820e6ac7fefbaffa42ae0cd03ff9f9cf27 Mon Sep 17 00:00:00 2001 From: Apurva Jain Date: Mon, 10 Mar 2025 14:53:34 -0700 Subject: [PATCH 09/14] Updates --- .../microbenchmarks/benchmark_inference.py | 12 +- .../microbenchmarks/benchmark_runner.py | 11 +- .../microbenchmarks/test/benchmark_config.yml | 1 - .../test/test_benchmark_runner.py | 2 +- benchmarks/microbenchmarks/test/test_utils.py | 4 +- benchmarks/microbenchmarks/utils.py | 231 ++++++++++-------- 6 files changed, 151 insertions(+), 110 deletions(-) diff --git 
a/benchmarks/microbenchmarks/benchmark_inference.py b/benchmarks/microbenchmarks/benchmark_inference.py index aa1a2dc12a..9aa921fd8a 100644 --- a/benchmarks/microbenchmarks/benchmark_inference.py +++ b/benchmarks/microbenchmarks/benchmark_inference.py @@ -11,13 +11,14 @@ import torch -from benchmarks.microbenchmarks.utils import ( +from utils import ( BenchmarkConfig, benchmark_model_inference_in_microseconds, clean_caches, create_model_and_input, - quantization_string_to_quantized_model, + quantization_string_to_quantization_config, ) +from torchao.quantization import quantize_ def run(config: BenchmarkConfig) -> Dict[str, float]: @@ -38,10 +39,11 @@ def run(config: BenchmarkConfig) -> Dict[str, float]: # Use quantize_ to apply each quantization function to the model m_copy = deepcopy(base_model).eval().to(config.device) - m_copy = quantization_string_to_quantized_model( - m_copy, config.quantization, high_precision_dtype=config.high_precision_dtype + quantization_config = quantization_string_to_quantization_config( + config.quantization, high_precision_dtype=config.high_precision_dtype ) - + if quantization_config: + quantize_(m_copy, quantization_config) if config.compile: print("Compiling model....") m_copy = torch.compile(m_copy, mode=config.compile_mode, fullgraph=True) diff --git a/benchmarks/microbenchmarks/benchmark_runner.py b/benchmarks/microbenchmarks/benchmark_runner.py index f980b17aae..d5a0e09e0d 100644 --- a/benchmarks/microbenchmarks/benchmark_runner.py +++ b/benchmarks/microbenchmarks/benchmark_runner.py @@ -19,7 +19,11 @@ import yaml -from benchmarks.microbenchmarks.utils import BenchmarkConfig, generate_results_csv +from utils import ( + BenchmarkConfig, + generate_results_csv, + print_results, +) def get_shapes_for_config(shape_config: Dict[str, Any]) -> List[Tuple[str, List[int]]]: @@ -64,7 +68,7 @@ def load_benchmark_configs(config_path: str) -> List[BenchmarkConfig]: def run_benchmarks_from_config(config_path: str) -> None: """Run benchmarks using configurations from YAML file""" - from benchmarks.microbenchmarks.benchmark_inference import run as run_inference + from benchmark_inference import run as run_inference configs = load_benchmark_configs(config_path) results = [] @@ -77,6 +81,9 @@ def run_benchmarks_from_config(config_path: str) -> None: # Add results to csv generate_results_csv(results, configs[0].output_dir) + # Print results + print_results(results) + # TODO: Process results: Speedups: # 1. For different shapes for same model and quantization # 2. 
For different quantizations for same model and shape diff --git a/benchmarks/microbenchmarks/test/benchmark_config.yml b/benchmarks/microbenchmarks/test/benchmark_config.yml index 4ba6e99a4f..ed5c328681 100644 --- a/benchmarks/microbenchmarks/test/benchmark_config.yml +++ b/benchmarks/microbenchmarks/test/benchmark_config.yml @@ -3,7 +3,6 @@ quantization_config_recipe_names: - "baseline" - "int8wo" - "int4wo-128" - - "int4wo-128-hqq" output_dir: "benchmarks/microbenchmarks/test/results" # Directory for results and plots model_params: matrix_shapes: diff --git a/benchmarks/microbenchmarks/test/test_benchmark_runner.py b/benchmarks/microbenchmarks/test/test_benchmark_runner.py index 7bbb9609d2..e88f45994b 100644 --- a/benchmarks/microbenchmarks/test/test_benchmark_runner.py +++ b/benchmarks/microbenchmarks/test/test_benchmark_runner.py @@ -15,7 +15,7 @@ class TestBenchmarkRunner(unittest.TestCase): def setUp(self): self.config = { "quantization_config_recipe_names": ["baseline", "int8wo"], - "output_dir": "benchmarks/microbenchmarks/test/test_output", + "output_dir": "tmp", "model_params": { "matrix_shapes": [ { diff --git a/benchmarks/microbenchmarks/test/test_utils.py b/benchmarks/microbenchmarks/test/test_utils.py index 54d0b96393..b27e181921 100644 --- a/benchmarks/microbenchmarks/test/test_utils.py +++ b/benchmarks/microbenchmarks/test/test_utils.py @@ -69,7 +69,7 @@ def test_create_model_and_input(self): m=m, k=k, n=n, - dtype=torch.float32, + high_precision_dtype=torch.float32, device="cpu", ) self.assertIsInstance(model, ToyLinearModel) @@ -80,7 +80,7 @@ def test_create_model_and_input(self): m=m, k=k, n=n, - dtype=torch.float32, + high_precision_dtype=torch.float32, device="cpu", ) self.assertIsInstance(model, LNLinearSigmoid) diff --git a/benchmarks/microbenchmarks/utils.py b/benchmarks/microbenchmarks/utils.py index bde2f5cd9a..7fb8bf3eb4 100644 --- a/benchmarks/microbenchmarks/utils.py +++ b/benchmarks/microbenchmarks/utils.py @@ -1,10 +1,12 @@ import csv import os -import time from typing import Any, Dict, List import torch +from tabulate import tabulate +from torch.utils.benchmark import Timer +from torchao.core.config import AOBaseConfig from torchao.quantization import ( Float8DynamicActivationFloat8WeightConfig, Float8WeightOnlyConfig, @@ -17,11 +19,6 @@ PerRow, PerTensor, UIntXWeightOnlyConfig, - quantize_, -) -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, - unwrap_tensor_subclass, ) try: @@ -48,7 +45,13 @@ def __init__( params["high_precision_dtype"] ) self.compile = params.get("compile", False) - self.compile_mode = params.get("compile_mode", "default") + # Handle compile_mode based on compile flag + if not self.compile: + self.compile_mode = None + else: + # Use provided compile_mode if exists, else use "default" + self.compile_mode = params.get("compile_mode", "default") + self.device = params.get("device", get_default_device()) self.model_type = params.get("model_type", "linear") self.output_dir = output_dir @@ -68,7 +71,7 @@ def to_dict(self) -> Dict[str, Any]: "n": self.n, "high_precision_dtype": self.high_precision_dtype, "compile": self.compile, - "compile_mode": "default", + "compile_mode": self.compile_mode, "device": self.device, "model_type": self.model_type, "output_dir": self.output_dir, @@ -111,47 +114,33 @@ def get_default_device() -> str: return "cpu" -def ffn_only(mod, fqn): - return isinstance(mod, torch.nn.Linear) and "feed_forward" in fqn - - -def not_ffn_only(mod, fqn): - return isinstance(mod, torch.nn.Linear) and not ffn_only(mod, fqn) - 
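# Illustrative sketch, not part of the patch: the runner expands every quantization
# recipe against every entry in `matrix_shapes`, so a recipe list like the test YAML's
# combined with three [m, k, n] shapes yields one benchmark run per pair. The helper
# below mimics that cross-product; its name and structure are assumptions made for
# illustration, not the exact get_shapes_for_config()/load_benchmark_configs() code.
from itertools import product
from typing import Any, Dict, List, Tuple

def expand_runs(recipes: List[str], shape_config: Dict[str, Any]) -> List[Tuple[str, str, List[int]]]:
    # Pair each recipe with each named [m, k, n] shape from a "custom" matrix_shapes entry.
    shapes = [(shape_config["name"], list(s)) for s in shape_config["shapes"]]
    return [(recipe, name, mkn) for recipe, (name, mkn) in product(recipes, shapes)]

runs = expand_runs(
    ["baseline", "int8wo"],
    {"name": "custom", "shapes": [[1024, 1024, 1024], [2048, 4096, 1024], [4096, 4096, 1024]]},
)
assert len(runs) == 6  # one benchmark config per (recipe, shape) pair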
- -def ffn_or_attn_only(mod, fqn): - return isinstance(mod, torch.nn.Linear) and ( - "feed_forward" in fqn or "attention" in fqn - ) - - -def quantization_string_to_quantized_model( - model: torch.nn.Module, quantization: str, **kwargs -) -> torch.nn.Module: - """Quantize a model inplace or return a new quantized model. +def quantization_string_to_quantization_config( + quantization: str, **kwargs +) -> AOBaseConfig: + """Get quantization config based on quantization string. Args: - model (torch.nn.Module): model to be quantized quantization (str): quantization method to be used **kwargs: additional arguments to be passed to the quantization method + + Returns: + AOBaseConfig: Quantization configuration object """ high_precision_dtype = kwargs.get("high_precision_dtype", torch.bfloat16) if "int4wo" in quantization and not HAS_TRITON: print("Warning: Triton not available, falling back to baseline") - return model + return None # Quantization techniques if "baseline" in quantization: - return model + return None if "int8wo" in quantization: - quantize_(model, Int8WeightOnlyConfig()) + return Int8WeightOnlyConfig() if "int8dq" in quantization: if "int8dq_prefill_wo_decode" in quantization: - quantize_( - model, Int8DynamicActivationInt8WeightConfig(weight_only_decode=True) - ) + return Int8DynamicActivationInt8WeightConfig(weight_only_decode=True) else: - quantize_(model, Int8DynamicActivationInt8WeightConfig()) + return Int8DynamicActivationInt8WeightConfig() if "int4wo" in quantization: use_hqq = False if "hqq" in quantization: @@ -163,40 +152,28 @@ def quantization_string_to_quantized_model( 128, 256, ], f"int4wo group_size needs to be one of [32,64,128,256] but got {group_size}" - quantize_(model, Int4WeightOnlyConfig(group_size=group_size, use_hqq=use_hqq)) + return Int4WeightOnlyConfig(group_size=group_size, use_hqq=use_hqq) elif "int8adq-int4w-symm" in quantization: from torchao.dtypes import CutlassInt4PackedLayout - quantize_( - model, - Int8DynamicActivationInt4WeightConfig( - group_size=None, - mapping_type=MappingType.SYMMETRIC, - act_mapping_type=MappingType.SYMMETRIC, - layout=CutlassInt4PackedLayout(), - ), + return Int8DynamicActivationInt4WeightConfig( + group_size=None, + mapping_type=MappingType.SYMMETRIC, + act_mapping_type=MappingType.SYMMETRIC, + layout=CutlassInt4PackedLayout(), ) if "marlin" in quantization: if "qqq" in quantization: from torchao.dtypes import MarlinQQQLayout - quantize_( - model, - Int8DynamicActivationInt4WeightConfig( - group_size=128, - mapping_type=MappingType.SYMMETRIC, - act_mapping_type=MappingType.SYMMETRIC, - layout=MarlinQQQLayout(), - ), + return Int8DynamicActivationInt4WeightConfig( + group_size=128, + mapping_type=MappingType.SYMMETRIC, + act_mapping_type=MappingType.SYMMETRIC, + layout=MarlinQQQLayout(), ) if "fp6" in quantization: - quantize_(model, FPXWeightOnlyConfig(3, 2)) - elif "embed-int8wo" in quantization: - quantize_( - model, - Int8WeightOnlyConfig(group_size=64), - filter_fn=lambda x, *args: isinstance(x, torch.nn.Embedding), - ) + return FPXWeightOnlyConfig(3, 2) elif "uintx" in quantization: # uintx-nbits-group_size, e.g. 
"uintx-2-64" if "hqq" in quantization: @@ -219,10 +196,10 @@ def quantization_string_to_quantized_model( } dtype = _NBITS_TO_DTYPE[nbits] group_size = int(_quant_args[2]) - quantize_(model, UIntXWeightOnlyConfig(dtype, group_size, use_hqq=use_hqq)) + return UIntXWeightOnlyConfig(dtype, group_size, use_hqq=use_hqq) elif "int8_dynamic_activation_intx_weight" in quantization: from torchao.experimental.quant_api import ( - int8_dynamic_activation_intx_weight, + Int8DynamicActivationIntxWeightConfig, ) from torchao.quantization.granularity import PerGroup @@ -235,16 +212,13 @@ def quantization_string_to_quantized_model( weight_dtype = getattr(torch, f"int{_quant_args[1]}") granularity = PerGroup(int(_quant_args[2])) has_weight_zeros = bool(_quant_args[3]) - quantize_( - model, - int8_dynamic_activation_intx_weight( - weight_dtype=weight_dtype, - granularity=granularity, - has_weight_zeros=has_weight_zeros, - ), + return Int8DynamicActivationIntxWeightConfig( + weight_dtype=weight_dtype, + granularity=granularity, + has_weight_zeros=has_weight_zeros, ) elif "float8wo" in quantization: - quantize_(model, Float8WeightOnlyConfig()) + return Float8WeightOnlyConfig() elif "float8dq" in quantization: granularity = str(quantization.split("-")[-1]) if granularity == "tensor": @@ -253,37 +227,37 @@ def quantization_string_to_quantized_model( granularity = PerRow() else: granularity = PerTensor() - quantize_( - model, Float8DynamicActivationFloat8WeightConfig(granularity=granularity) - ) - else: - if not TORCH_VERSION_AT_LEAST_2_5: - unwrap_tensor_subclass(model) - return model + return Float8DynamicActivationFloat8WeightConfig(granularity=granularity) + return None -# Function to benchmark model evaluation - e2e eval run +@torch.no_grad() def benchmark_model_inference_in_microseconds(model, input_data): - # Returns model run time in seconds - if torch.cuda.is_available(): - torch.cuda.synchronize() + """Benchmark model inference time without compile overhead. + + Args: + model: The model to benchmark + input_data: Input data for the model + + Returns: + float: Median inference time in microseconds + """ + # First run to trigger any compilation/lazy initialization - # warm up - for _ in range(2): - model(input_data) - if torch.cuda.is_available(): - torch.cuda.synchronize() + timer = Timer( + stmt="model(input_data)", + globals={"model": model, "input_data": input_data}, + num_threads=1, + ) - num_iters = 5 - start_time = time.perf_counter() - with torch.no_grad(): - for _ in range(num_iters): - _ = model(input_data) - if torch.cuda.is_available(): - torch.cuda.synchronize() - end_time = time.perf_counter() + # warmup + timer.timeit(number=100) + # actual measurement + measurement = timer.timeit(number=100) + res = measurement.mean - return ((end_time - start_time) / num_iters) * 1e6 + # Convert to microseconds + return res * 1e6 def create_model_and_input( @@ -291,7 +265,7 @@ def create_model_and_input( m: int, k: int, n: int, - dtype: torch.dtype = torch.bfloat16, + high_precision_dtype: torch.dtype = torch.bfloat16, device: str = get_default_device(), ): """Create a model and input data for benchmarking. 
@@ -300,15 +274,15 @@ def create_model_and_input( model_type (str): type of the model to be created batch_size (int): batch size of the input data device (str): device to run the model on - dtype (torch.dtype): data + high_precision_dtype (torch.dtype): data type of the model m, k, n (int): dimensions of the model and input data """ if model_type == "linear": - model = ToyLinearModel(k, n, dtype).to(device) - input_data = torch.randn(m, k, device=device, dtype=dtype) + model = ToyLinearModel(k, n, high_precision_dtype).to(device) + input_data = torch.randn(m, k, device=device, dtype=high_precision_dtype) elif model_type == "ln_linear_sigmoid": - model = LNLinearSigmoid(k, n, dtype).to(device) - input_data = torch.randn(m, k, device=device, dtype=dtype) + model = LNLinearSigmoid(k, n, high_precision_dtype).to(device) + input_data = torch.randn(m, k, device=device, dtype=high_precision_dtype) else: raise ValueError(f"Unknown model type: {model_type}") return model, input_data @@ -354,3 +328,62 @@ def generate_results_csv( writer.writerow(result.values()) print(f"Results saved to {file_path}") + + +def print_results(results: List[Dict[str, Any]]): + """Print benchmark results in a formatted table. + + Args: + results (List[Dict[str, Any]]): List of benchmark results + """ + if not results: + print("No results to display") + return + + # Extract relevant columns for display + display_columns = [ + "quantization", + "model_type", + "m", + "k", + "n", + "benchmark_model_inference_in_microseconds", + "compile", + ] + + # Format data for tabulate + headers = { + "quantization": "Quantization", + "model_type": "Model Type", + "m": "M", + "k": "K", + "n": "N", + "benchmark_model_inference_in_microseconds": "Time (μs)", + "compile": "Compile", + } + + # Extract and format data + table_data = [] + for result in results: + row = [] + for col in display_columns: + value = result.get(col, "N/A") + if col == "benchmark_model_inference_in_microseconds": + value = f"{value:.2f}" if isinstance(value, (int, float)) else value + elif col == "compile": + # Show compile mode if compile is True, otherwise show False + value = result.get("compile_mode", "default") if value else "False" + row.append(value) + table_data.append(row) + + # Print formatted table + print("\nBenchmark Results:") + print( + tabulate( + table_data, + headers=[headers[col] for col in display_columns], + tablefmt="grid", + floatfmt=".2f", + ) + ) + print() From d3935acaf4ae6badfa41a0c57172f9a0c6ed0158 Mon Sep 17 00:00:00 2001 From: Apurva Jain Date: Tue, 11 Mar 2025 15:10:41 -0700 Subject: [PATCH 10/14] Add sparsity support to microbenchmarking --- .gitignore | 3 ++ .../microbenchmarks/benchmark_inference.py | 4 +- .../microbenchmarks/test/benchmark_config.yml | 4 +- benchmarks/microbenchmarks/utils.py | 54 +++++++++++++++++-- 4 files changed, 59 insertions(+), 6 deletions(-) diff --git a/.gitignore b/.gitignore index d8c3199a1e..5b04d6e287 100644 --- a/.gitignore +++ b/.gitignore @@ -376,3 +376,6 @@ checkpoints/ # Experimental torchao/experimental/cmake-out torchao/experimental/deps + +# Benchmark outputs +benchmarks/microbenchmarks/test/results/ diff --git a/benchmarks/microbenchmarks/benchmark_inference.py b/benchmarks/microbenchmarks/benchmark_inference.py index 9aa921fd8a..550713971f 100644 --- a/benchmarks/microbenchmarks/benchmark_inference.py +++ b/benchmarks/microbenchmarks/benchmark_inference.py @@ -40,7 +40,9 @@ def run(config: BenchmarkConfig) -> Dict[str, float]: # Use quantize_ to apply each quantization function to the model 
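# Illustrative sketch, not part of the patch: what the recipe-string-to-config step
# amounts to for two simple recipes. Only Int8WeightOnlyConfig, Int4WeightOnlyConfig,
# and quantize_ are taken from torchao (they are imported by the patch itself); the
# parse_recipe helper is a simplified stand-in for
# quantization_string_to_quantization_config and ignores hqq/sparsity handling.
import torch
from torchao.quantization import Int4WeightOnlyConfig, Int8WeightOnlyConfig, quantize_

def parse_recipe(recipe: str):
    if recipe == "baseline":
        return None  # leave the model in high precision
    if recipe == "int8wo":
        return Int8WeightOnlyConfig()
    if recipe.startswith("int4wo-"):
        # "int4wo-128" -> group_size 128
        return Int4WeightOnlyConfig(group_size=int(recipe.split("-")[1]))
    raise ValueError(f"Unhandled recipe: {recipe}")

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).eval().to(torch.bfloat16)
config = parse_recipe("int8wo")
if config is not None:
    quantize_(model, config)  # swaps Linear weights for int8 weight-only tensors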
m_copy = deepcopy(base_model).eval().to(config.device) quantization_config = quantization_string_to_quantization_config( - config.quantization, high_precision_dtype=config.high_precision_dtype + config.quantization, + config.sparsity, + high_precision_dtype=config.high_precision_dtype ) if quantization_config: quantize_(m_copy, quantization_config) diff --git a/benchmarks/microbenchmarks/test/benchmark_config.yml b/benchmarks/microbenchmarks/test/benchmark_config.yml index ed5c328681..10fc594df7 100644 --- a/benchmarks/microbenchmarks/test/benchmark_config.yml +++ b/benchmarks/microbenchmarks/test/benchmark_config.yml @@ -1,8 +1,7 @@ # Sample configuration for inference kernel benchmarks quantization_config_recipe_names: - "baseline" - - "int8wo" - - "int4wo-128" + - "int8dq" output_dir: "benchmarks/microbenchmarks/test/results" # Directory for results and plots model_params: matrix_shapes: @@ -17,3 +16,4 @@ model_params: compile_mode: "max-autotune" device: "cuda" # Change this to "cuda", "mps", "xpu", or "cpu" as needed model_type: "linear" + sparsity: "2:4" diff --git a/benchmarks/microbenchmarks/utils.py b/benchmarks/microbenchmarks/utils.py index 7fb8bf3eb4..b2d6becd71 100644 --- a/benchmarks/microbenchmarks/utils.py +++ b/benchmarks/microbenchmarks/utils.py @@ -7,6 +7,10 @@ from torch.utils.benchmark import Timer from torchao.core.config import AOBaseConfig +from torchao.dtypes import ( + MarlinSparseLayout, + SemiSparseLayout, +) from torchao.quantization import ( Float8DynamicActivationFloat8WeightConfig, Float8WeightOnlyConfig, @@ -52,10 +56,18 @@ def __init__( # Use provided compile_mode if exists, else use "default" self.compile_mode = params.get("compile_mode", "default") + # Add sparsity configuration + self.sparsity = params.get("sparsity", None) # Default to None if not specified + self.device = params.get("device", get_default_device()) self.model_type = params.get("model_type", "linear") self.output_dir = output_dir - self.name = f"benchmark_{self.quantization}_{self.model_type}_m{self.m}_k{self.k}_n{self.n}{'_compile' if self.compile else ''}" + # Update name to include sparsity info + self.name = ( + f"benchmark_{self.quantization}_{self.model_type}_m{self.m}_k{self.k}_n{self.n}" + + (f"_sparse_{self.sparsity}" if self.sparsity else "") + + ("_compile" if self.compile else "") + ) @staticmethod def _parse_precision(precision_str: str) -> torch.dtype: @@ -75,6 +87,7 @@ def to_dict(self) -> Dict[str, Any]: "device": self.device, "model_type": self.model_type, "output_dir": self.output_dir, + "sparsity": self.sparsity, } @@ -115,7 +128,7 @@ def get_default_device() -> str: def quantization_string_to_quantization_config( - quantization: str, **kwargs + quantization: str, sparsity: str, **kwargs ) -> AOBaseConfig: """Get quantization config based on quantization string. 
@@ -127,6 +140,7 @@ def quantization_string_to_quantization_config( AOBaseConfig: Quantization configuration object """ high_precision_dtype = kwargs.get("high_precision_dtype", torch.bfloat16) + if "int4wo" in quantization and not HAS_TRITON: print("Warning: Triton not available, falling back to baseline") return None @@ -137,11 +151,17 @@ def quantization_string_to_quantization_config( if "int8wo" in quantization: return Int8WeightOnlyConfig() if "int8dq" in quantization: + if sparsity and ("2:4" in sparsity or "semi" in sparsity): + return Int8DynamicActivationInt8WeightConfig(layout=SemiSparseLayout()) if "int8dq_prefill_wo_decode" in quantization: return Int8DynamicActivationInt8WeightConfig(weight_only_decode=True) else: return Int8DynamicActivationInt8WeightConfig() if "int4wo" in quantization: + if sparsity and ("2:4" in sparsity or "semi" in sparsity): + layout = MarlinSparseLayout() + else: + layout = None use_hqq = False if "hqq" in quantization: use_hqq = True @@ -152,7 +172,9 @@ def quantization_string_to_quantization_config( 128, 256, ], f"int4wo group_size needs to be one of [32,64,128,256] but got {group_size}" - return Int4WeightOnlyConfig(group_size=group_size, use_hqq=use_hqq) + return Int4WeightOnlyConfig( + group_size=group_size, use_hqq=use_hqq, layout=layout + ) elif "int8adq-int4w-symm" in quantization: from torchao.dtypes import CutlassInt4PackedLayout @@ -231,6 +253,30 @@ def quantization_string_to_quantization_config( return None +def sparsity_string_to_sparsity_config(sparsity: str) -> Dict[str, Any]: + """Convert sparsity string to sparsity config. + + Args: + sparsity (str): sparsity string to be converted + + """ + from torchao.sparsity import ( + block_sparse_weight, + semi_sparse_weight, + ) + + if sparsity is None: + return None + + # Parse the sparsity string + if "2:4" in sparsity or "semi" in sparsity: + return semi_sparse_weight + if "block" in sparsity: + return block_sparse_weight + else: + raise ValueError(f"Unknown sparsity: {sparsity}") + + @torch.no_grad() def benchmark_model_inference_in_microseconds(model, input_data): """Benchmark model inference time without compile overhead.
@@ -347,6 +393,7 @@ def print_results(results: List[Dict[str, Any]]): "m", "k", "n", + "sparsity", "benchmark_model_inference_in_microseconds", "compile", ] @@ -358,6 +405,7 @@ def print_results(results: List[Dict[str, Any]]): "m": "M", "k": "K", "n": "N", + "sparsity": "Sparsity", "benchmark_model_inference_in_microseconds": "Time (μs)", "compile": "Compile", } From 08c5a0c638f027147a86afbd8fcf4013e6801603 Mon Sep 17 00:00:00 2001 From: Apurva Jain Date: Tue, 11 Mar 2025 15:10:41 -0700 Subject: [PATCH 11/14] Add sparsity support to microbenchmarking --- benchmarks/microbenchmarks/test/benchmark_config.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/benchmarks/microbenchmarks/test/benchmark_config.yml b/benchmarks/microbenchmarks/test/benchmark_config.yml index 10fc594df7..a2715cba4d 100644 --- a/benchmarks/microbenchmarks/test/benchmark_config.yml +++ b/benchmarks/microbenchmarks/test/benchmark_config.yml @@ -2,6 +2,7 @@ quantization_config_recipe_names: - "baseline" - "int8dq" + - "int4wo-128" output_dir: "benchmarks/microbenchmarks/test/results" # Directory for results and plots model_params: matrix_shapes: @@ -12,8 +13,7 @@ model_params: [4096, 4096, 1024] ] high_precision_dtype: "torch.bfloat16" - compile: true - compile_mode: "max-autotune" + compile: "max-autotune" device: "cuda" # Change this to "cuda", "mps", "xpu", or "cpu" as needed model_type: "linear" sparsity: "2:4" From 22b3ddde67c8381ffbeb4c8dc22ccc08a062b640 Mon Sep 17 00:00:00 2001 From: jainapurva Date: Tue, 11 Mar 2025 14:03:38 -0700 Subject: [PATCH 12/14] Minor fix --- .gitignore | 3 +++ benchmarks/microbenchmarks/README.md | 3 +-- benchmarks/microbenchmarks/benchmark_inference.py | 9 ++++++--- benchmarks/microbenchmarks/benchmark_runner.py | 1 - benchmarks/microbenchmarks/test/benchmark_config.yml | 3 +-- benchmarks/microbenchmarks/test/test_utils.py | 6 ++---- benchmarks/microbenchmarks/utils.py | 12 ++---------- 7 files changed, 15 insertions(+), 22 deletions(-) diff --git a/.gitignore b/.gitignore index d8c3199a1e..730efa6e85 100644 --- a/.gitignore +++ b/.gitignore @@ -376,3 +376,6 @@ checkpoints/ # Experimental torchao/experimental/cmake-out torchao/experimental/deps + +# Benchmarking results +benchmarks/microbenchmarks/test/results/ diff --git a/benchmarks/microbenchmarks/README.md b/benchmarks/microbenchmarks/README.md index 32333144d0..b80858b499 100644 --- a/benchmarks/microbenchmarks/README.md +++ b/benchmarks/microbenchmarks/README.md @@ -47,8 +47,7 @@ model_params: [4096, 4096, 1024] ] high_precision_dtype: "torch.bfloat16" - compile: true - compile_mode: "max-autotune" + compile: "max-autotune" # Options: "default", "max-autotune", "false" device: "cuda" # Options: "cuda", "mps", "xpu", "cpu" model_type: "linear" # Options: "linear", "ln_linear_sigmoid" ``` diff --git a/benchmarks/microbenchmarks/benchmark_inference.py b/benchmarks/microbenchmarks/benchmark_inference.py index 9aa921fd8a..44339c94f5 100644 --- a/benchmarks/microbenchmarks/benchmark_inference.py +++ b/benchmarks/microbenchmarks/benchmark_inference.py @@ -10,7 +10,6 @@ from typing import Dict import torch - from utils import ( BenchmarkConfig, benchmark_model_inference_in_microseconds, @@ -18,6 +17,7 @@ create_model_and_input, quantization_string_to_quantization_config, ) + from torchao.quantization import quantize_ @@ -44,9 +44,12 @@ def run(config: BenchmarkConfig) -> Dict[str, float]: ) if quantization_config: quantize_(m_copy, quantization_config) - if config.compile: + if config.compile is not "false": 
print("Compiling model....") - m_copy = torch.compile(m_copy, mode=config.compile_mode, fullgraph=True) + if compile == "true": + m_copy = torch.compile(m_copy, fullgraph=True) + else: + m_copy = torch.compile(m_copy, mode=config.compile, fullgraph=True) # Run benchmarks result = {**config.to_dict()} diff --git a/benchmarks/microbenchmarks/benchmark_runner.py b/benchmarks/microbenchmarks/benchmark_runner.py index d5a0e09e0d..a167108806 100644 --- a/benchmarks/microbenchmarks/benchmark_runner.py +++ b/benchmarks/microbenchmarks/benchmark_runner.py @@ -18,7 +18,6 @@ from typing import Any, Dict, List, Tuple import yaml - from utils import ( BenchmarkConfig, generate_results_csv, diff --git a/benchmarks/microbenchmarks/test/benchmark_config.yml b/benchmarks/microbenchmarks/test/benchmark_config.yml index ed5c328681..5357d60db4 100644 --- a/benchmarks/microbenchmarks/test/benchmark_config.yml +++ b/benchmarks/microbenchmarks/test/benchmark_config.yml @@ -13,7 +13,6 @@ model_params: [4096, 4096, 1024] ] high_precision_dtype: "torch.bfloat16" - compile: true - compile_mode: "max-autotune" + compile: "max-autotune" device: "cuda" # Change this to "cuda", "mps", "xpu", or "cpu" as needed model_type: "linear" diff --git a/benchmarks/microbenchmarks/test/test_utils.py b/benchmarks/microbenchmarks/test/test_utils.py index b27e181921..9f663fefb6 100644 --- a/benchmarks/microbenchmarks/test/test_utils.py +++ b/benchmarks/microbenchmarks/test/test_utils.py @@ -18,8 +18,7 @@ class TestUtils(unittest.TestCase): def test_benchmark_config(self): params = { "high_precision_dtype": "torch.bfloat16", - "compile": True, - "compile_mode": "max-autotune", + "compile": "max-autotune", "device": "cuda", "model_type": "linear", } @@ -36,8 +35,7 @@ def test_benchmark_config(self): self.assertEqual(config.k, 1024) self.assertEqual(config.n, 1024) self.assertEqual(config.high_precision_dtype, torch.bfloat16) - self.assertEqual(config.compile, True) - self.assertEqual(config.compile_mode, "max-autotune") + self.assertEqual(config.compile, "max-autotune") self.assertEqual(config.device, "cuda") self.assertEqual(config.model_type, "linear") self.assertEqual(config.output_dir, "test_output") diff --git a/benchmarks/microbenchmarks/utils.py b/benchmarks/microbenchmarks/utils.py index 7fb8bf3eb4..fa816dc905 100644 --- a/benchmarks/microbenchmarks/utils.py +++ b/benchmarks/microbenchmarks/utils.py @@ -44,14 +44,7 @@ def __init__( self.high_precision_dtype = self._parse_precision( params["high_precision_dtype"] ) - self.compile = params.get("compile", False) - # Handle compile_mode based on compile flag - if not self.compile: - self.compile_mode = None - else: - # Use provided compile_mode if exists, else use "default" - self.compile_mode = params.get("compile_mode", "default") - + self.compile = params.get("compile", "false") self.device = params.get("device", get_default_device()) self.model_type = params.get("model_type", "linear") self.output_dir = output_dir @@ -71,7 +64,6 @@ def to_dict(self) -> Dict[str, Any]: "n": self.n, "high_precision_dtype": self.high_precision_dtype, "compile": self.compile, - "compile_mode": self.compile_mode, "device": self.device, "model_type": self.model_type, "output_dir": self.output_dir, @@ -372,7 +364,7 @@ def print_results(results: List[Dict[str, Any]]): value = f"{value:.2f}" if isinstance(value, (int, float)) else value elif col == "compile": # Show compile mode if compile is True, otherwise show False - value = result.get("compile_mode", "default") if value else "False" + value = 
result.get("compile", "default") if value else "False" row.append(value) table_data.append(row) From 0a2499c8f5b3eaa001ca41d466fa66e1f43021de Mon Sep 17 00:00:00 2001 From: jainapurva Date: Wed, 12 Mar 2025 23:48:15 -0700 Subject: [PATCH 13/14] Updates --- .gitignore | 3 - benchmarks/__init__.py | 0 benchmarks/microbenchmarks/README.md | 1 + .../microbenchmarks/benchmark_inference.py | 22 +++--- .../microbenchmarks/benchmark_runner.py | 28 ++++++-- .../microbenchmarks/test/benchmark_config.yml | 3 +- .../test/test_benchmark_inference.py | 10 +-- .../test/test_benchmark_runner.py | 8 +-- benchmarks/microbenchmarks/test/test_utils.py | 39 ++++++----- benchmarks/microbenchmarks/utils.py | 68 +++++++++++++------ 10 files changed, 112 insertions(+), 70 deletions(-) create mode 100644 benchmarks/__init__.py diff --git a/.gitignore b/.gitignore index 730efa6e85..d8c3199a1e 100644 --- a/.gitignore +++ b/.gitignore @@ -376,6 +376,3 @@ checkpoints/ # Experimental torchao/experimental/cmake-out torchao/experimental/deps - -# Benchmarking results -benchmarks/microbenchmarks/test/results/ diff --git a/benchmarks/__init__.py b/benchmarks/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/benchmarks/microbenchmarks/README.md b/benchmarks/microbenchmarks/README.md index b80858b499..a95dc53755 100644 --- a/benchmarks/microbenchmarks/README.md +++ b/benchmarks/microbenchmarks/README.md @@ -55,6 +55,7 @@ model_params: ## Configuration Options ### Quantization Methods +Currently, quantization string is in same format as the one being passed in llama/generate.py. - `baseline`: No quantization - `int8wo`: 8-bit weight-only quantization - `int4wo-{group_size}`: 4-bit weight-only quantization with specified group size diff --git a/benchmarks/microbenchmarks/benchmark_inference.py b/benchmarks/microbenchmarks/benchmark_inference.py index 44339c94f5..9bcb6c3efe 100644 --- a/benchmarks/microbenchmarks/benchmark_inference.py +++ b/benchmarks/microbenchmarks/benchmark_inference.py @@ -7,21 +7,21 @@ from copy import deepcopy from pathlib import Path -from typing import Dict import torch -from utils import ( + +from benchmarks.microbenchmarks.utils import ( BenchmarkConfig, - benchmark_model_inference_in_microseconds, + BenchmarkResult, clean_caches, create_model_and_input, + model_inference_time_in_ms, quantization_string_to_quantization_config, ) - from torchao.quantization import quantize_ -def run(config: BenchmarkConfig) -> Dict[str, float]: +def run(config: BenchmarkConfig) -> BenchmarkResult: """Run inference benchmarks""" clean_caches() # Clean caches @@ -44,21 +44,17 @@ def run(config: BenchmarkConfig) -> Dict[str, float]: ) if quantization_config: quantize_(m_copy, quantization_config) - if config.compile is not "false": + if config.use_torch_compile: print("Compiling model....") - if compile == "true": - m_copy = torch.compile(m_copy, fullgraph=True) - else: - m_copy = torch.compile(m_copy, mode=config.compile, fullgraph=True) + m_copy = torch.compile(m_copy, mode=config.torch_compile_mode, fullgraph=True) # Run benchmarks - result = {**config.to_dict()} + result = BenchmarkResult(config=config) # Benchmark time to run an inference call for quantized model - model_time = benchmark_model_inference_in_microseconds( + result.model_inference_time_in_ms = model_inference_time_in_ms( model=m_copy, input_data=input_data ) - result["benchmark_model_inference_in_microseconds"] = model_time # TODO: Benchmark time using profiler # Profile dtype model evaluation diff --git 
index a167108806..86ed82cf09 100644
--- a/benchmarks/microbenchmarks/benchmark_runner.py
+++ b/benchmarks/microbenchmarks/benchmark_runner.py
@@ -18,7 +18,8 @@
 from typing import Any, Dict, List, Tuple

 import yaml
-from utils import (
+
+from benchmarks.microbenchmarks.utils import (
     BenchmarkConfig,
     generate_results_csv,
     print_results,
@@ -65,9 +66,9 @@ def load_benchmark_configs(config_path: str) -> List[BenchmarkConfig]:
     return configs


-def run_benchmarks_from_config(config_path: str) -> None:
+def run_inference_benchmarks_from_config(config_path: str) -> None:
     """Run benchmarks using configurations from YAML file"""
-    from benchmark_inference import run as run_inference
+    from benchmarks.microbenchmarks.benchmark_inference import run as run_inference

     configs = load_benchmark_configs(config_path)
     results = []
@@ -99,5 +100,24 @@ def run_benchmarks_from_config(config_path: str) -> None:
         required=True,
         help="Path to benchmark configuration file",
     )
+    parser.add_argument(
+        "--benchmark_mode",
+        "-m",
+        type=str,
+        default="inference",
+        choices=["inference", "training"],
+        help="Benchmark mode to run: inference or training",
+    )
     args = parser.parse_args()
-    run_benchmarks_from_config(args.config)
+
+    # Run benchmarks
+    if args.benchmark_mode == "inference":
+        run_inference_benchmarks_from_config(args.config)
+    elif args.benchmark_mode == "training":
+        print("Training mode not implemented yet")
+    else:
+        raise ValueError(
+            f"Invalid benchmark mode: {args.benchmark_mode}, choose from inference or training"
+        )
+
+    # TODO: Add support for args to override config values and run smaller benchmarks
diff --git a/benchmarks/microbenchmarks/test/benchmark_config.yml b/benchmarks/microbenchmarks/test/benchmark_config.yml
index 5357d60db4..4355bf0b0a 100644
--- a/benchmarks/microbenchmarks/test/benchmark_config.yml
+++ b/benchmarks/microbenchmarks/test/benchmark_config.yml
@@ -13,6 +13,7 @@ model_params:
       [4096, 4096, 1024]
     ]
   high_precision_dtype: "torch.bfloat16"
-  compile: "max-autotune"
+  use_torch_compile: true
+  torch_compile_mode: "max-autotune"
   device: "cuda" # Change this to "cuda", "mps", "xpu", or "cpu" as needed
   model_type: "linear"
diff --git a/benchmarks/microbenchmarks/test/test_benchmark_inference.py b/benchmarks/microbenchmarks/test/test_benchmark_inference.py
index ceb3a25a74..9f030cfb08 100644
--- a/benchmarks/microbenchmarks/test/test_benchmark_inference.py
+++ b/benchmarks/microbenchmarks/test/test_benchmark_inference.py
@@ -8,7 +8,7 @@ class TestBenchmarkInference(unittest.TestCase):
     def setUp(self):
         self.params = {
             "high_precision_dtype": "torch.float32",  # Use float32 for testing
-            "compile": False,
+            "use_torch_compile": False,
             "device": "cpu",  # Use CPU for testing
             "model_type": "linear",
         }
@@ -23,13 +23,9 @@ def setUp(self):
     def test_run_inference(self):
         result = run(self.config)
-        # Check result contains all config attributes
-        for key in self.config.to_dict():
-            self.assertIn(key, result)
-
-        # Check benchmark result is present and reasonable
-        self.assertIn("benchmark_model_inference_in_microseconds", result)
-        self.assertGreater(result["benchmark_model_inference_in_microseconds"], 0)
+        self.assertTrue(hasattr(result, "model_inference_time_in_ms"))
+        self.assertGreater(result.model_inference_time_in_ms, 0)


 if __name__ == "__main__":
diff --git a/benchmarks/microbenchmarks/test/test_benchmark_runner.py b/benchmarks/microbenchmarks/test/test_benchmark_runner.py
index e88f45994b..b626f8a929 100644
--- a/benchmarks/microbenchmarks/test/test_benchmark_runner.py
+++ b/benchmarks/microbenchmarks/test/test_benchmark_runner.py
@@ -7,7 +7,7 @@
 from benchmarks.microbenchmarks.benchmark_runner import (
     get_shapes_for_config,
     load_benchmark_configs,
-    run_benchmarks_from_config,
+    run_inference_benchmarks_from_config,
 )
@@ -24,7 +24,7 @@ def setUp(self):
                 }
             ],
             "high_precision_dtype": "torch.float32",
-            "compile": False,
+            "use_torch_compile": False,
             "device": "cpu",
             "model_type": "linear",
         },
@@ -72,8 +72,8 @@ def test_load_benchmark_configs(self):
         self.assertEqual(configs[0].quantization, "baseline")
         self.assertEqual(configs[1].quantization, "int8wo")

-    def test_run_benchmarks_from_config(self):
-        run_benchmarks_from_config(self.config_path)
+    def test_run_inference_benchmarks_from_config(self):
+        run_inference_benchmarks_from_config(self.config_path)
         results_file = os.path.join(self.config["output_dir"], "results.csv")
         self.assertTrue(os.path.exists(results_file))
diff --git a/benchmarks/microbenchmarks/test/test_utils.py b/benchmarks/microbenchmarks/test/test_utils.py
index 9f663fefb6..272dac5e51 100644
--- a/benchmarks/microbenchmarks/test/test_utils.py
+++ b/benchmarks/microbenchmarks/test/test_utils.py
@@ -6,6 +6,7 @@
 from benchmarks.microbenchmarks.utils import (
     BenchmarkConfig,
+    BenchmarkResult,
     LNLinearSigmoid,
     ToyLinearModel,
     clean_caches,
@@ -18,7 +19,8 @@ class TestUtils(unittest.TestCase):
     def test_benchmark_config(self):
         params = {
             "high_precision_dtype": "torch.bfloat16",
-            "compile": "max-autotune",
+            "use_torch_compile": True,
+            "torch_compile_mode": "max-autotune",
             "device": "cuda",
             "model_type": "linear",
         }
@@ -35,7 +37,8 @@ def test_benchmark_config(self):
         self.assertEqual(config.k, 1024)
         self.assertEqual(config.n, 1024)
         self.assertEqual(config.high_precision_dtype, torch.bfloat16)
-        self.assertEqual(config.compile, "max-autotune")
+        self.assertEqual(config.use_torch_compile, True)
+        self.assertEqual(config.torch_compile_mode, "max-autotune")
         self.assertEqual(config.device, "cuda")
         self.assertEqual(config.model_type, "linear")
         self.assertEqual(config.output_dir, "test_output")
@@ -86,20 +89,24 @@ def test_create_model_and_input(self):
     def test_generate_results_csv(self):
         results = [
-            {
-                "quantization": "int8wo",
-                "m": 1024,
-                "k": 1024,
-                "n": 1024,
-                "time_us": 100.0,
-            },
-            {
-                "quantization": "int4wo",
-                "m": 1024,
-                "k": 1024,
-                "n": 1024,
-                "time_us": 50.0,
-            },
+            BenchmarkResult(
+                BenchmarkConfig(
+                    quantization="int8wo",
+                    params={},
+                    shape_name="custom",
+                    shape=[1024, 1024, 1024],
+                    output_dir="test_output",
+                ),
+            ),
+            BenchmarkResult(
+                BenchmarkConfig(
+                    quantization="int4wo",
+                    params={},
+                    shape_name="custom",
+                    shape=[1024, 1024, 1024],
+                    output_dir="test_output",
+                ),
+            ),
         ]

         with tempfile.TemporaryDirectory() as tmp_dir:
diff --git a/benchmarks/microbenchmarks/utils.py b/benchmarks/microbenchmarks/utils.py
index fa816dc905..0f6dab9ece 100644
--- a/benchmarks/microbenchmarks/utils.py
+++ b/benchmarks/microbenchmarks/utils.py
@@ -32,7 +32,7 @@ class BenchmarkConfig:
     def __init__(
         self,
-        quantization: str,
+        quantization: str,  # Quantization string format is similar to the format being used for llama/generate.py
         params: Dict[str, Any],
         shape_name: str,
         shape: List[int],
@@ -42,13 +42,14 @@ def __init__(
         self.m, self.k, self.n = shape
         self.shape_name = shape_name
         self.high_precision_dtype = self._parse_precision(
-            params["high_precision_dtype"]
+            params.get("high_precision_dtype", "torch.bfloat16")
         )
-        self.compile = params.get("compile", "false")
params.get("compile", "false") + self.use_torch_compile = bool(params.get("use_torch_compile", False)) + self.torch_compile_mode = params.get("torch_compile_mode", "default") self.device = params.get("device", get_default_device()) self.model_type = params.get("model_type", "linear") self.output_dir = output_dir - self.name = f"benchmark_{self.quantization}_{self.model_type}_m{self.m}_k{self.k}_n{self.n}{'_compile' if self.compile else ''}" + self.name = f"benchmark_{self.quantization}_{self.model_type}_m{self.m}_k{self.k}_n{self.n}{'_compile' if self.use_torch_compile else ''}" @staticmethod def _parse_precision(precision_str: str) -> torch.dtype: @@ -63,13 +64,31 @@ def to_dict(self) -> Dict[str, Any]: "k": self.k, "n": self.n, "high_precision_dtype": self.high_precision_dtype, - "compile": self.compile, + "use_torch_compile": self.use_torch_compile, + "torch_compile_mode": self.torch_compile_mode, "device": self.device, "model_type": self.model_type, "output_dir": self.output_dir, } +class BenchmarkResult: + def __init__( + self, + config: BenchmarkConfig, + ): + self.config = config + self.output_dir = config.output_dir + self.model_inference_time_in_ms = 0.0 + + def to_dict(self) -> Dict[str, Any]: + """Convert result to dictionary for main function""" + return { + **self.config.to_dict(), + "model_inference_time_in_ms": self.model_inference_time_in_ms, + } + + class ToyLinearModel(torch.nn.Module): def __init__(self, k=64, n=32, dtype=torch.bfloat16): super().__init__() @@ -112,8 +131,8 @@ def quantization_string_to_quantization_config( """Get quantization config based on quantization string. Args: - quantization (str): quantization method to be used - **kwargs: additional arguments to be passed to the quantization method + quantization (str): Quantization method to be used. The quantiation string format is similar to the format being used for llama/generate.py. + **kwargs: Additional arguments to be passed to the quantization method Returns: AOBaseConfig: Quantization configuration object @@ -224,7 +243,7 @@ def quantization_string_to_quantization_config( @torch.no_grad() -def benchmark_model_inference_in_microseconds(model, input_data): +def model_inference_time_in_ms(model, input_data): """Benchmark model inference time without compile overhead. Args: @@ -295,14 +314,14 @@ def clean_caches(): def generate_results_csv( - results: List[Dict[str, Any]], + results: List[BenchmarkResult], output_dir: str, file_name: str = "results.csv", ): """Generate a CSV file with the results of the benchmarking. Args: - results (List[Dict[str, Any]]): List Dictionary containing the results of the benchmarking with the config. + results (List[BenchmarkResult]): List Dictionary containing the results of the benchmarking with the config. output_dir (str): Directory to save the CSV file. file_name (str, optional): Name of the CSV file. Defaults to "results.csv". """ @@ -314,19 +333,19 @@ def generate_results_csv( with open(file_path, "w", newline="") as csvfile: writer = csv.writer(csvfile) # Write the header row - header = results[0].keys() + header = results[0].to_dict().keys() writer.writerow(header) for result in results: - writer.writerow(result.values()) + writer.writerow(result.to_dict().values()) print(f"Results saved to {file_path}") -def print_results(results: List[Dict[str, Any]]): +def print_results(results: List[BenchmarkResult]): """Print benchmark results in a formatted table. 
     Args:
-        results (List[Dict[str, Any]]): List of benchmark results
+        results (List[BenchmarkResult]): List of benchmark results
     """
     if not results:
         print("No results to display")
@@ -339,8 +358,8 @@
         "m",
         "k",
         "n",
-        "benchmark_model_inference_in_microseconds",
-        "compile",
+        "model_inference_time_in_ms",
+        "use_torch_compile",
     ]

     # Format data for tabulate
@@ -350,21 +369,26 @@
         "quantization": "Quantization",
         "m": "M",
         "k": "K",
         "n": "N",
-        "benchmark_model_inference_in_microseconds": "Time (μs)",
-        "compile": "Compile",
+        "model_inference_time_in_ms": "Time (ms)",
+        "use_torch_compile": "Compile Mode",
     }

     # Extract and format data
     table_data = []
     for result in results:
+        result_dict = result.to_dict()
         row = []
         for col in display_columns:
-            value = result.get(col, "N/A")
-            if col == "benchmark_model_inference_in_microseconds":
+            value = result_dict.get(col, "N/A")
+            if col == "model_inference_time_in_ms":
                 value = f"{value:.2f}" if isinstance(value, (int, float)) else value
-            elif col == "compile":
+            elif col == "use_torch_compile":
                 # Show compile mode if compile is True, otherwise show False
-                value = result.get("compile", "default") if value else "False"
+                value = (
+                    result_dict.get("torch_compile_mode", "default")
+                    if result_dict.get("use_torch_compile")
+                    else "False"
+                )
             row.append(value)
         table_data.append(row)

From 2193189292fe7ed1d270bb0a2b2a91850445da14 Mon Sep 17 00:00:00 2001
From: jainapurva
Date: Thu, 13 Mar 2025 10:30:48 -0700
Subject: [PATCH 14/14] Updates

---
 benchmarks/microbenchmarks/benchmark_inference.py | 6 +++---
 benchmarks/microbenchmarks/utils.py               | 4 +---
 2 files changed, 4 insertions(+), 6 deletions(-)

diff --git a/benchmarks/microbenchmarks/benchmark_inference.py b/benchmarks/microbenchmarks/benchmark_inference.py
index 9bcb6c3efe..2950e4529b 100644
--- a/benchmarks/microbenchmarks/benchmark_inference.py
+++ b/benchmarks/microbenchmarks/benchmark_inference.py
@@ -16,7 +16,7 @@
     clean_caches,
     create_model_and_input,
     model_inference_time_in_ms,
-    quantization_string_to_quantization_config,
+    string_to_config,
 )
 from torchao.quantization import quantize_

@@ -39,10 +39,10 @@ def run(config: BenchmarkConfig) -> BenchmarkResult:

     # Use quantize_ to apply each quantization function to the model
     m_copy = deepcopy(base_model).eval().to(config.device)
-    quantization_config = quantization_string_to_quantization_config(
+    quantization_config = string_to_config(
         config.quantization, high_precision_dtype=config.high_precision_dtype
     )
-    if quantization_config:
+    if quantization_config is not None:
         quantize_(m_copy, quantization_config)
     if config.use_torch_compile:
         print("Compiling model....")
diff --git a/benchmarks/microbenchmarks/utils.py b/benchmarks/microbenchmarks/utils.py
index 0f6dab9ece..95f3b57da5 100644
--- a/benchmarks/microbenchmarks/utils.py
+++ b/benchmarks/microbenchmarks/utils.py
@@ -125,9 +125,7 @@ def get_default_device() -> str:
     return "cpu"


-def quantization_string_to_quantization_config(
-    quantization: str, **kwargs
-) -> AOBaseConfig:
+def string_to_config(quantization: str, **kwargs) -> AOBaseConfig:
     """Get quantization config based on quantization string.

     Args: