diff --git a/docs/3x/benchmark.md b/docs/3x/benchmark.md new file mode 100644 index 00000000000..571e0f83f80 --- /dev/null +++ b/docs/3x/benchmark.md @@ -0,0 +1,61 @@ +Benchmark +--- + +1. [Introduction](#introduction) + +2. [Supported Matrix](#supported-matrix) + +3. [Usage](#usage) + +## Introduction + +Intel Neural Compressor provides a command `incbench` to launch the Intel CPU performance benchmark. + +To get peak performance on Intel Xeon CPUs, a single instance should avoid crossing NUMA nodes. +Therefore, by default, `incbench` will trigger 1 instance on the first NUMA node. + +## Supported Matrix + +| Platform | Status | |:---:|:---:| | Linux | ✔ | | Windows | ✔ | + +## Usage + +| Parameters | Default | Comments | |:----------------------:|:------------------------:|:-------------------------------------:| | num_instances | 1 | Number of instances | | num_cores_per_instance | None | Number of cores in each instance | | C, cores | 0-${num_cores_on_NUMA-1} | Decides the visible core range | | cross_memory | False | Whether to allocate memory across NUMA nodes | + +> Note: Set cross_memory to True only when the memory for one instance is insufficient. + +### General Use Cases + +1. `incbench main.py`: run 1 instance on NUMA:0. +2. `incbench --num_i 2 main.py`: run 2 instances on NUMA:0. +3. `incbench --num_c 2 main.py`: run multiple instances with 2 cores per instance on NUMA:0. +4. `incbench -C 24-47 main.py`: run 1 instance on COREs:24-47. +5. `incbench -C 24-47 --num_c 4 main.py`: run multiple instances with 4 COREs per instance on COREs:24-47. + +> Note: + > - `num_i` works the same as `num_instances` + > - `num_c` works the same as `num_cores_per_instance` + +### Dump Throughput and Latency Summary + +To merge benchmark results from multiple instances, `incbench` automatically scans each instance's log file for "throughput" and "latency" messages that match the patterns shown after the sketch below.
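The following is a minimal, hedged sketch of the kind of per-instance timing loop that emits such lines. It mirrors the warmup-then-measure pattern used by the example scripts in this PR; the model object, input tensor, and iteration counts are illustrative placeholders, not part of the `incbench` API.

```python
import time

import torch


def report_benchmark(model, example_inputs, total_iters=100, warmup_iters=5, batch_size=1):
    """Time the forward pass and print lines that `incbench` can parse from the instance log."""
    with torch.no_grad():
        for i in range(total_iters):
            if i == warmup_iters:  # start the clock only after the warmup iterations
                start = time.time()
            model(example_inputs)
        end = time.time()
    measured_samples = (total_iters - warmup_iters) * batch_size
    latency = (end - start) / measured_samples      # seconds per sample
    throughput = measured_samples / (end - start)   # samples per second
    print("Latency: {:.3f} ms".format(latency * 10**3))
    print("Throughput: {:.3f} samples/sec".format(throughput))
```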
+ +```python +throughput_pattern = r"[T,t]hroughput:\s*([0-9]*\.?[0-9]+)\s*([a-zA-Z/]*)" +latency_pattern = r"[L,l]atency:\s*([0-9]*\.?[0-9]+)\s*([a-zA-Z/]*)" +``` + +#### Demo usage + +```python +print("Throughput: {:.3f} samples/sec".format(throughput)) +print("Latency: {:.3f} ms".format(latency * 10**3)) +``` diff --git a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/smooth_quant/run_benchmark.sh b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/smooth_quant/run_benchmark.sh index 61c50611090..7b60727b047 100644 --- a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/smooth_quant/run_benchmark.sh +++ b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/smooth_quant/run_benchmark.sh @@ -75,22 +75,34 @@ function run_benchmark { if [ "${topology}" = "opt_125m_ipex_sq" ]; then model_name_or_path="facebook/opt-125m" - extra_cmd=$extra_cmd" --ipex --sq --alpha 0.5" + extra_cmd=$extra_cmd" --ipex" elif [ "${topology}" = "llama2_7b_ipex_sq" ]; then model_name_or_path="meta-llama/Llama-2-7b-hf" - extra_cmd=$extra_cmd" --ipex --sq --alpha 0.8" + extra_cmd=$extra_cmd" --ipex" elif [ "${topology}" = "gpt_j_ipex_sq" ]; then model_name_or_path="EleutherAI/gpt-j-6b" - extra_cmd=$extra_cmd" --ipex --sq --alpha 1.0" + extra_cmd=$extra_cmd" --ipex" fi - python -u run_clm_no_trainer.py \ - --model ${model_name_or_path} \ - --approach ${approach} \ - --output_dir ${tuned_checkpoint} \ - --task ${task} \ - --batch_size ${batch_size} \ - ${extra_cmd} ${mode_cmd} + if [[ ${mode} == "accuracy" ]]; then + python -u run_clm_no_trainer.py \ + --model ${model_name_or_path} \ + --approach ${approach} \ + --output_dir ${tuned_checkpoint} \ + --task ${task} \ + --batch_size ${batch_size} \ + ${extra_cmd} ${mode_cmd} + elif [[ ${mode} == "performance" ]]; then + incbench --num_cores_per_instance 4 run_clm_no_trainer.py \ + --model ${model_name_or_path} \ + --approach ${approach} \ + --batch_size ${batch_size} \ + --output_dir ${tuned_checkpoint} \ + ${extra_cmd} ${mode_cmd} + else + echo "Error: No such mode: ${mode}" + exit 1 + fi } main "$@" diff --git a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/smooth_quant/run_clm_no_trainer.py b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/smooth_quant/run_clm_no_trainer.py index ef0590e2982..94acc14344f 100644 --- a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/smooth_quant/run_clm_no_trainer.py +++ b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/smooth_quant/run_clm_no_trainer.py @@ -2,7 +2,7 @@ import os import sys -sys.path.append('./') +sys.path.append("./") import time import re import torch @@ -12,15 +12,11 @@ from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer parser = argparse.ArgumentParser() +parser.add_argument("--model", nargs="?", default="EleutherAI/gpt-j-6b") +parser.add_argument("--trust_remote_code", default=True, help="Transformers parameter: use the external repo") parser.add_argument( - "--model", nargs="?", default="EleutherAI/gpt-j-6b" + "--revision", default=None, help="Transformers parameter: set the model hub commit number" ) -parser.add_argument( - "--trust_remote_code", default=True, - help="Transformers parameter: use the external repo") -parser.add_argument( - "--revision", default=None, - help="Transformers parameter: set the model hub commit number") 
parser.add_argument("--dataset", nargs="?", default="NeelNanda/pile-10k", const="NeelNanda/pile-10k") parser.add_argument("--output_dir", nargs="?", default="./saved_results") parser.add_argument("--quantize", action="store_true") @@ -29,29 +25,26 @@ action="store_true", help="By default it is int8-fp32 mixed, to enable int8 mixed amp bf16 (work on platforms like SPR)", ) +parser.add_argument("--seed", type=int, default=42, help="Seed for sampling the calibration data.") parser.add_argument( - '--seed', - type=int, default=42, help='Seed for sampling the calibration data.' + "--approach", type=str, default="static", help="Select from ['dynamic', 'static', 'weight-only']" ) -parser.add_argument("--approach", type=str, default='static', - help="Select from ['dynamic', 'static', 'weight-only']") parser.add_argument("--int8", action="store_true") parser.add_argument("--ipex", action="store_true", help="Use intel extension for pytorch.") parser.add_argument("--load", action="store_true", help="Load quantized model.") parser.add_argument("--accuracy", action="store_true") parser.add_argument("--performance", action="store_true") -parser.add_argument("--iters", default=100, type=int, - help="For accuracy measurement only.") -parser.add_argument("--batch_size", default=1, type=int, - help="For accuracy measurement only.") -parser.add_argument("--save_accuracy_path", default=None, - help="Save accuracy results path.") -parser.add_argument("--pad_max_length", default=512, type=int, - help="Pad input ids to max length.") -parser.add_argument("--calib_iters", default=512, type=int, - help="calibration iters.") -parser.add_argument("--tasks", default="lambada_openai,hellaswag,winogrande,piqa,wikitext", - type=str, help="tasks for accuracy validation") +parser.add_argument("--iters", default=100, type=int, help="For accuracy measurement only.") +parser.add_argument("--batch_size", default=1, type=int, help="For accuracy measurement only.") +parser.add_argument("--save_accuracy_path", default=None, help="Save accuracy results path.") +parser.add_argument("--pad_max_length", default=512, type=int, help="Pad input ids to max length.") +parser.add_argument("--calib_iters", default=512, type=int, help="calibration iters.") +parser.add_argument( + "--tasks", + default="lambada_openai,hellaswag,winogrande,piqa,wikitext", + type=str, + help="tasks for accuracy validation", +) parser.add_argument("--peft_model_id", type=str, default=None, help="model_name_or_path of peft model") # ============SmoothQuant configs============== parser.add_argument("--sq", action="store_true") @@ -91,7 +84,7 @@ def collate_batch(self, batch): pad_len = self.pad_max - input_ids.shape[0] last_ind.append(input_ids.shape[0] - 1) if self.is_calib: - input_ids = input_ids[:self.pad_max] if len(input_ids) > self.pad_max else input_ids + input_ids = input_ids[: self.pad_max] if len(input_ids) > self.pad_max else input_ids else: input_ids = pad(input_ids, (0, pad_len), value=self.pad_val) input_ids_padded.append(input_ids) @@ -144,6 +137,7 @@ def get_user_model(): if args.peft_model_id is not None: from peft import PeftModel + user_model = PeftModel.from_pretrained(user_model, args.peft_model_id) # to channels last @@ -158,7 +152,9 @@ def get_user_model(): calib_dataset = load_dataset(args.dataset, split="train") # calib_dataset = datasets.load_from_disk('/your/local/dataset/pile-10k/') # use this if trouble with connecting to HF calib_dataset = calib_dataset.shuffle(seed=args.seed) - calib_evaluator = Evaluator(calib_dataset, tokenizer, 
args.batch_size, pad_max=args.pad_max_length, is_calib=True) + calib_evaluator = Evaluator( + calib_dataset, tokenizer, args.batch_size, pad_max=args.pad_max_length, is_calib=True + ) calib_dataloader = DataLoader( calib_evaluator.dataset, batch_size=calib_size, @@ -167,6 +163,7 @@ def get_user_model(): ) from neural_compressor.torch.quantization import SmoothQuantConfig + args.alpha = eval(args.alpha) excluded_precisions = [] if args.int8_bf16_mixed else ["bf16"] quant_config = SmoothQuantConfig(alpha=args.alpha, folding=False, excluded_precisions=excluded_precisions) @@ -176,6 +173,7 @@ def get_user_model(): from neural_compressor.torch.algorithms.smooth_quant import move_input_to_device from tqdm import tqdm + def run_fn(model): calib_iter = 0 for batch in tqdm(calib_dataloader, total=args.calib_iters): @@ -186,16 +184,18 @@ def run_fn(model): model(**batch) else: model(batch) - + calib_iter += 1 if calib_iter >= args.calib_iters: break return from utils import get_example_inputs + example_inputs = get_example_inputs(user_model, calib_dataloader) from neural_compressor.torch.quantization import prepare, convert + user_model = prepare(model=user_model, quant_config=quant_config, example_inputs=example_inputs) run_fn(user_model) user_model = convert(user_model) @@ -207,6 +207,7 @@ def run_fn(model): if args.int8 or args.int8_bf16_mixed: print("load int8 model") from neural_compressor.torch.quantization import load + tokenizer = AutoTokenizer.from_pretrained(args.model) config = AutoConfig.from_pretrained(args.model) user_model = load(os.path.abspath(os.path.expanduser(args.output_dir))) @@ -218,6 +219,7 @@ def run_fn(model): if args.accuracy: user_model.eval() from intel_extension_for_transformers.transformers.llm.evaluation.lm_eval import evaluate, LMEvalParser + eval_args = LMEvalParser( model="hf", user_model=user_model, @@ -233,32 +235,25 @@ def run_fn(model): else: acc = results["results"][task_name]["acc,none"] print("Accuracy: %.5f" % acc) - print('Batch size = %d' % args.batch_size) + print("Batch size = %d" % args.batch_size) if args.performance: user_model.eval() - from intel_extension_for_transformers.transformers.llm.evaluation.lm_eval import evaluate, LMEvalParser + batch_size, input_leng = args.batch_size, 512 + example_inputs = torch.ones((batch_size, input_leng), dtype=torch.long) + print("Batch size = {:d}".format(batch_size)) + print("The length of input tokens = {:d}".format(input_leng)) import time - samples = args.iters * args.batch_size - eval_args = LMEvalParser( - model="hf", - user_model=user_model, - tokenizer=tokenizer, - batch_size=args.batch_size, - tasks=args.tasks, - limit=samples, - device="cpu", - ) - start = time.time() - results = evaluate(eval_args) - end = time.time() - for task_name in args.tasks.split(","): - if task_name == "wikitext": - acc = results["results"][task_name]["word_perplexity,none"] - else: - acc = results["results"][task_name]["acc,none"] - print("Accuracy: %.5f" % acc) - print('Throughput: %.3f samples/sec' % (samples / (end - start))) - print('Latency: %.3f ms' % ((end - start) * 1000 / samples)) - print('Batch size = %d' % args.batch_size) + total_iters = args.iters + warmup_iters = 5 + with torch.no_grad(): + for i in range(total_iters): + if i == warmup_iters: + start = time.time() + user_model(example_inputs) + end = time.time() + latency = (end - start) / ((total_iters - warmup_iters) * args.batch_size) + throughput = ((total_iters - warmup_iters) * args.batch_size) / (end - start) + print("Latency: {:.3f} ms".format(latency * 
10**3)) + print("Throughput: {:.3f} samples/sec".format(throughput)) diff --git a/examples/3.x_api/pytorch/recommendation/dlrm/static_quant/ipex/dlrm_s_pytorch.py b/examples/3.x_api/pytorch/recommendation/dlrm/static_quant/ipex/dlrm_s_pytorch.py index 12936c64165..2af63ea4b98 100644 --- a/examples/3.x_api/pytorch/recommendation/dlrm/static_quant/ipex/dlrm_s_pytorch.py +++ b/examples/3.x_api/pytorch/recommendation/dlrm/static_quant/ipex/dlrm_s_pytorch.py @@ -394,7 +394,7 @@ def dash_separated_ints(value): return value -def trace_model(args, dlrm, test_ld, inplace=True): +def trace_or_load_model(args, dlrm, test_ld, inplace=True): dlrm.eval() for j, inputBatch in enumerate(test_ld): X, lS_o, lS_i, _, _, _ = unpack_batch(inputBatch) @@ -462,7 +462,7 @@ def inference( total_time = 0 total_iter = 0 if args.inference_only and trace: - dlrm = trace_model(args, dlrm, test_ld) + dlrm = trace_or_load_model(args, dlrm, test_ld) if args.share_weight_instance != 0: run_throughput_benchmark(args, dlrm, test_ld) with torch.cpu.amp.autocast(enabled=args.bf16): @@ -833,11 +833,11 @@ def eval_func(model): # calibration def calib_fn(model): - calib_number = 0 + calib_iter = 0 for X_test, lS_o_test, lS_i_test, T in train_ld: - if calib_number < 100: + if calib_iter < 100: model(X_test, lS_o_test, lS_i_test) - calib_number += 1 + calib_iter += 1 else: break @@ -857,8 +857,22 @@ def calib_fn(model): dlrm.save(args.save_model) exit(0) if args.benchmark: - # To do - print('Not implemented yet') + dlrm = trace_or_load_model(args, dlrm, test_ld, inplace=True) + import time + X_test, lS_o_test, lS_i_test, T = next(iter(test_ld)) + total_iters = 100 + warmup_iters = 5 + with torch.no_grad(): + for i in range(total_iters): + if i == warmup_iters: + start = time.time() + dlrm(X_test, lS_o_test, lS_i_test) + end = time.time() + latency = (end - start) / ((total_iters - warmup_iters) * args.mini_batch_size) + throughput = ((total_iters - warmup_iters) * args.mini_batch_size) / (end - start) + print('Batch size = {:d}'.format(args.mini_batch_size)) + print('Latency: {:.3f} ms'.format(latency * 10**3)) + print('Throughput: {:.3f} samples/sec'.format(throughput)) exit(0) if args.accuracy_only: @@ -934,7 +948,7 @@ def update_training_performance(time, iters, training_record=training_record): training_record[0] += time training_record[1] += 1 - def print_training_performance( training_record=training_record): + def print_training_performance(training_record=training_record): if training_record[0] == 0: print("num-batches larger than warm up iters, please increase num-batches or decrease warmup iters") exit() diff --git a/examples/3.x_api/pytorch/recommendation/dlrm/static_quant/ipex/run_benchmark.sh b/examples/3.x_api/pytorch/recommendation/dlrm/static_quant/ipex/run_benchmark.sh index 3089868c3a0..dc593308678 100755 --- a/examples/3.x_api/pytorch/recommendation/dlrm/static_quant/ipex/run_benchmark.sh +++ b/examples/3.x_api/pytorch/recommendation/dlrm/static_quant/ipex/run_benchmark.sh @@ -80,7 +80,7 @@ function run_tuning { --save-model ${tuned_checkpoint} --test-freq=2048 --print-auc $ARGS \ --load-model=${input_model} --accuracy_only elif [[ ${mode} == "performance" ]]; then - python -u $MODEL_SCRIPT \ + incbench --num_cores_per_instance 4 -u $MODEL_SCRIPT \ --raw-data-file=${dataset_location}/day --processed-data-file=${dataset_location}/terabyte_processed.npz \ --data-set=terabyte --benchmark \ --memory-map --mlperf-bin-loader --round-targets=True --learning-rate=1.0 \ diff --git a/neural_compressor/common/base_config.py 
b/neural_compressor/common/base_config.py index 267a1ed5deb..2a2b6579c56 100644 --- a/neural_compressor/common/base_config.py +++ b/neural_compressor/common/base_config.py @@ -19,6 +19,7 @@ import inspect import json +import os import re from abc import ABC, abstractmethod from collections import OrderedDict @@ -620,6 +621,7 @@ class Options: def __init__(self, random_seed=1978, workspace=DEFAULT_WORKSPACE, resume_from=None, tensorboard=False): """Init an Option object.""" + os.makedirs(workspace, exist_ok=True) self.random_seed = random_seed self.workspace = workspace self.resume_from = resume_from @@ -639,6 +641,7 @@ def random_seed(self, random_seed): @property def workspace(self): """Get workspace.""" + os.makedirs(self._workspace, exist_ok=True) return self._workspace @workspace.setter diff --git a/neural_compressor/common/benchmark.py b/neural_compressor/common/benchmark.py new file mode 100644 index 00000000000..f732eab5fd8 --- /dev/null +++ b/neural_compressor/common/benchmark.py @@ -0,0 +1,520 @@ +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +import re +import subprocess +import sys + +import psutil + +from neural_compressor.common.utils import Statistics, get_workspace, logger + +description = """ +################################################################################################################## +This is the command used to launch the Intel CPU performance benchmark, supports both Linux and Windows platform. +To get the peak performance on Intel Xeon CPU, we should avoid crossing NUMA node in one instance. +By default, `incbench` will trigger 1 instance on the first NUMA node. + +Params in `incbench`: + - num_instances Default to 1. + - num_cores_per_instance Default to None. + - C, cores Default to 0-${num_cores_on_NUMA-1}, decides the visible core range. + - cross_memory Default to False, decides whether to allocate memory cross NUMA. + Note: Use it only when memory for instance is not enough. + +# General use cases: +1. `incbench main.py`: run 1 instance on NUMA:0. +2. `incbench --num_i 2 main.py`: run 2 instances on NUMA:0. +3. `incbench --num_c 2 main.py`: run multi-instances with 2 cores per instance on NUMA:0. +4. `incbench -C 24-47 main.py`: run 1 instance on COREs:24-47. +5. `incbench -C 24-47 --num_c 4 main.py`: run multi-instances with 4 COREs per instance on COREs:24-47. + +Note: + - `num_i` works the same as `num_instances` + - `num_c` works the same as `num_cores_per_instance` +################################################################################################################## +""" + + +def get_linux_numa_info(): + """Collect numa/socket information on linux system. + + Returns: + numa_info (dict): demo: {numa_index: {"physical_cpus": "xxx"; "logical_cpus": "xxx"}} + E.g. 
numa_info = { + 0: {"physical_cpus": "0-23", "logical_cpus": "0-23,48-71"}, + 1: {"physical_cpus": "24-47", "logical_cpus": "24-47,72-95"} + } + """ + result = subprocess.run(["lscpu"], capture_output=True, text=True) + output = result.stdout + + numa_info = {} + for line in output.splitlines(): + # demo: "NUMA node0 CPU(s): 0-3" + node_match = re.match(r"^NUMA node(\d+) CPU\(s\):\s+(.*)$", line) + if node_match: + node_id = int(node_match.group(1)) + cpus = node_match.group(2).strip() + numa_info[node_id] = { + "physical_cpus": cpus.split(",")[0], + "logical_cpus": ",".join(cpus.split(",")), + } + + # if numa_info is not collected, we go back to socket_info + if not numa_info: # pragma: no cover + for line in output.splitlines(): + # demo: "Socket(s): 2" + socket_match = re.match(r"^Socket\(s\):\s+(.*)$", line) + if socket_match: + num_socket = int(socket_match.group(1)) + # process big cores (w/ physical cores) and small cores (w/o physical cores) + physical_cpus = psutil.cpu_count(logical=False) + logical_cpus = psutil.cpu_count(logical=True) + physical_cpus_per_socket = physical_cpus // num_socket + logical_cpus_per_socket = logical_cpus // num_socket + for i in range(num_socket): + physical_cpus_str = str(i * physical_cpus_per_socket) + "-" + str((i + 1) * physical_cpus_per_socket - 1) + if num_socket == 1: + logical_cpus_str = str(i * logical_cpus_per_socket) + "-" + str((i + 1) * logical_cpus_per_socket - 1) + else: + remain_cpus = logical_cpus_per_socket - physical_cpus_per_socket + logical_cpus_str = ( + physical_cpus_str + + "," + + str(i * (remain_cpus) + physical_cpus) + + "-" + + str((i + 1) * remain_cpus + physical_cpus - 1) + ) + numa_info[i] = { + "physical_cpus": physical_cpus_str, + "logical_cpus": logical_cpus_str, + } + return numa_info + + +def get_windows_numa_info(): + """Collect socket information on Windows system due to no available numa info. + + Returns: + numa_info (dict): demo: {numa_index: {"physical_cpus": "xxx"; "logical_cpus": "xxx"}} + E.g. numa_info = { + 0: {"physical_cpus": "0-23", "logical_cpus": "0-23,48-71"}, + 1: {"physical_cpus": "24-47", "logical_cpus": "24-47,72-95"} + } + """ + # pylint: disable=import-error + # pragma: no cover + import wmi + + c = wmi.WMI() + processors = c.Win32_Processor() + socket_designations = set() + for processor in processors: + socket_designations.add(processor.SocketDesignation) + num_socket = len(socket_designations) + physical_cpus = sum(processor.NumberOfCores for processor in processors) + logical_cpus = sum(processor.NumberOfLogicalProcessors for processor in processors) + physical_cpus_per_socket = physical_cpus // num_socket + logical_cpus_per_socket = logical_cpus // num_socket + + numa_info = {} + for i in range(num_socket): + physical_cpus_str = str(i * physical_cpus_per_socket) + "-" + str((i + 1) * physical_cpus_per_socket - 1) + if num_socket == 1: + logical_cpus_str = str(i * logical_cpus_per_socket) + "-" + str((i + 1) * logical_cpus_per_socket - 1) + else: + remain_cpus = logical_cpus_per_socket - physical_cpus_per_socket + logical_cpus_str = ( + physical_cpus_str + + "," + + str(i * (remain_cpus) + physical_cpus) + + "-" + + str((i + 1) * remain_cpus + physical_cpus - 1) + ) + numa_info[i] = { + "physical_cpus": physical_cpus_str, + "logical_cpus": logical_cpus_str, + } + return numa_info + + +def dump_numa_info(): + """Fetch NUMA info and dump stats in shell, return numa_info. 
+ + Returns: + numa_info (dict): {numa_node_index: list of Physical CPUs in this numa node, ...} + """ + if psutil.WINDOWS: # pragma: no cover + numa_info = get_windows_numa_info() + elif psutil.LINUX: + numa_info = get_linux_numa_info() + else: # pragma: no cover + logger.error(f"Unsupported platform detected: {sys.platform}, only supported on Linux and Windows") + + # dump stats to shell + field_names = ["NUMA node", "Physical CPUs", "Logical CPUs"] + output_data = [] + for op_type in numa_info.keys(): + field_results = [op_type, numa_info[op_type]["physical_cpus"], numa_info[op_type]["logical_cpus"]] + output_data.append(field_results) + Statistics(output_data, header="CPU Information", field_names=field_names).print_stat() + + # parse numa_info for ease-of-use + for n in numa_info: + numa_info[n] = parse_str2list(numa_info[n]["physical_cpus"]) + return numa_info + + +def parse_str2list(cpu_ranges): + """Parse '0-4,7,8' into [0,1,2,3,4,7,8] for machine readable.""" + cpus = [] + ranges = cpu_ranges.split(",") + for r in ranges: + if "-" in r: + try: + start, end = r.split("-") + cpus.extend(range(int(start), int(end) + 1)) + except ValueError: # pragma: no cover + raise ValueError(f"Invalid range: {r}") + else: + try: + cpus.append(int(r)) + except ValueError: # pragma: no cover + raise ValueError(f"Invalid number: {r}") + return cpus + + +def format_list2str(cpus): + """Format [0,1,2,3,4,7,8] back to '0-4,7,8' for human readable.""" + if not cpus: # pragma: no cover + return "" + cpus = sorted(set(cpus)) + ranges = [] + start = cpus[0] + end = start + for i in range(1, len(cpus)): + if cpus[i] == end + 1: + end = cpus[i] + else: + if start == end: + ranges.append(f"{start}") + else: + ranges.append(f"{start}-{end}") + start = cpus[i] + end = start + if start == end: + ranges.append(f"{start}") + else: + ranges.append(f"{start}-{end}") + return ",".join(ranges) + + +def get_reversed_numa_info(numa_info): + """Reverse numa_info.""" + reversed_numa_info = {} + for n, cpu_info in numa_info.items(): + for i in cpu_info: + reversed_numa_info[i] = n + return reversed_numa_info + + +def get_numa_node(core_list, reversed_numa_info): + """Return numa node used in current core_list.""" + numa_set = set() + for c in core_list: + assert c in reversed_numa_info, "Cores should be in physical CPUs" + numa_set.add(reversed_numa_info[c]) + return numa_set + + +def set_cores_for_instance(args, numa_info): + """All use cases are listed below: + Params: a=num_instance; b=num_cores_per_instance; c=cores; + - no a, b, c: a=1, c=numa:0 + - no a, b: a=1, c=c + - no a, c: a=numa:0/b, c=numa:0 + - no b, c: a=a, c=numa:0 + - no a: a=numa:0/b, c=c + - no b: a=a, c=c + - no c: a=a, c=a*b + - a, b, c: a=a, c=a*b + + Args: + args (argparse): arguments for setting different configurations + numa_info (dict): {numa_node_index: list of Physical CPUs in this numa node, ...} + + Returns: + core_list_per_instance (dict): {"instance_index": ["node_index", "cpu_index", num_cpu]} + """ + available_cores_list = [] + for n in numa_info: + available_cores_list.extend(numa_info[n]) + # preprocess args.cores to set default values + if args.cores is None: + if args.num_cores_per_instance and args.num_instances: + target_cores = args.num_instances * args.num_cores_per_instance + assert target_cores <= len(available_cores_list), ( + "Invalid configuration: num_instances * num_cores_per_instance = " + + "{} exceeds the number of physical CPUs = {}.".format(target_cores, len(available_cores_list)) + ) + cores_list = 
list(range(target_cores)) + # log for cores in use + logger.info("num_instances * num_cores_per_instance = {} cores are used.".format(target_cores)) + else: + # default behavior, only use numa:0 + cores_list = numa_info[0] + # log for cores in use + logger.info("By default, Intel Neural Compressor uses all cores on numa:0.") + else: + cores_list = parse_str2list(args.cores) + # log for cores available + logger.info("{} cores are available.".format(len(cores_list))) + if args.num_cores_per_instance and args.num_instances: + target_cores = args.num_instances * args.num_cores_per_instance + assert target_cores <= len(cores_list), ( + "Invalid configuration: num_instances * num_cores_per_instance = " + + "{} exceeds the number of available CPUs = {}.".format(target_cores, len(cores_list)) + ) + cores_list = cores_list[:target_cores] + + # preprocess args.num_instances to set default values + if args.num_instances is None: + if args.num_cores_per_instance: + assert args.num_cores_per_instance <= len(cores_list), ( + "Invalid configuration: num_cores_per_instance = " + + "{} exceeds the number of available CPUs = {}.".format(args.num_cores_per_instance, len(cores_list)) + ) + args.num_instances = len(cores_list) // args.num_cores_per_instance + target_cores = args.num_instances * args.num_cores_per_instance + cores_list = cores_list[:target_cores] + else: + args.num_instances = 1 + logger.info("By default, Intel Neural Compressor triggers only one instance.") + else: + assert args.num_instances <= len( + cores_list + ), "Invalid configuration: num_instances = " + "{} exceeds the number of available CPUs = {}.".format( + args.num_instances, len(cores_list) + ) + + ### log for instances number and cores in use + if args.num_instances == 1: + logger.info("1 instance is triggered.") + else: + logger.info("{} instances are triggered.".format(args.num_instances)) + if len(cores_list) == 1: + logger.info("1 core is in use.") + else: + logger.info("{} cores are in use.".format(len(cores_list))) + + # only need to process num_cores_per_instance now + core_list_per_instance = {} + # num_cores_per_instance = all_cores / num_instances + num_cores_per_instance = len(cores_list) // args.num_instances + for i in range(args.num_instances): + core_list_per_instance[i] = cores_list[i * num_cores_per_instance : (i + 1) * num_cores_per_instance] + if len(cores_list) % args.num_instances != 0: # pragma: no cover + last_index = args.num_instances - 1 + core_list_per_instance[last_index] = cores_list[last_index * num_cores_per_instance :] + + # convert core_list_per_instance = {"instance_index": cpu_index_list} + # -> {"instance_index": ["node_index", "cpu_index", num_cores]} + reversed_numa_info = get_reversed_numa_info(numa_info) + for i, core_list in core_list_per_instance.items(): + core_list_per_instance[i] = [ + format_list2str(get_numa_node(core_list, reversed_numa_info)), + format_list2str(core_list), + len(core_list), + ] + + # dump stats to shell + field_names = ["Instance", "NUMA node", "Physical CPUs", "Number of cores"] + output_data = [] + for i, core_list in core_list_per_instance.items(): + field_results = [i + 1, core_list[0], core_list[1], core_list[2]] + output_data.append(field_results) + Statistics(output_data, header="Instance Binding Information", field_names=field_names).print_stat() + return core_list_per_instance + + +def generate_prefix(args, core_list): + """Generate the command prefix with `numactl` (Linux) or `start` (Windows) command. 
+ + Args: + args (argparse): arguments for setting different configurations + core_list: ["node_index", "cpu_index", num_cpu] + Returns: + command_prefix (str): command_prefix with specific core list for Linux or Windows. + """ + if sys.platform in ["linux"] and os.system("numactl --show >/dev/null 2>&1") == 0: + if args.cross_memory: + return "OMP_NUM_THREADS={} numactl -l -C {}".format(core_list[2], core_list[1]) + else: + return "OMP_NUM_THREADS={} numactl -m {} -C {}".format(core_list[2], core_list[0], core_list[1]) + elif sys.platform in ["win32"]: # pragma: no cover + socket_id = core_list[0] + from functools import reduce + + hex_core = hex(reduce(lambda x, y: x | y, [1 << p for p in parse_str2list(core_list[1])])) + return "start /B /WAIT /node {} /affinity {}".format(socket_id, hex_core) + else: # pragma: no cover + return "" + + +def run_multi_instance_command(args, core_list_per_instance, raw_cmd): + """Build and trigger commands for multi-instances with subprocess. + + Args: + args (argparse): arguments for setting different configurations + core_list_per_instance (dict): {"instance_index": ["node_index", "cpu_index", num_cpu]} + raw_cmd (str): script.py and parameters for this script + """ + instance_cmd = "" + if not os.getenv("PYTHON_PATH"): # pragma: no cover + logger.info("The interpreter path is not set, using string `python` as command.") + logger.info("To replace it, use `export PYTHON_PATH=xxx`.") + interpreter = os.getenv("PYTHON_PATH", "python") + workspace_dir = get_workspace() + logfile_process_map = {} + logfile_dict = {} + for i, core_list in core_list_per_instance.items(): + # build cmd and log file path + prefix = generate_prefix(args, core_list) + instance_cmd = "{} {} {}".format(prefix, interpreter, raw_cmd) + logger.info(f"Instance {i+1}: {instance_cmd}") + instance_log_file = "{}_{}_{}C.log".format(i + 1, len(core_list_per_instance), core_list[2]) + instance_log_file = os.path.join(workspace_dir, instance_log_file) + # trigger subprocess + p = subprocess.Popen( + instance_cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, shell=True + ) # nosec + # log_file_path: [process_object, instance_command, instance_index] + logfile_process_map[instance_log_file] = [p, instance_cmd, i + 1] + logfile_dict[i + 1] = instance_log_file + + # Dump each instance's standard output to the corresponding log file + for instance_log_file, p_cmd_i in logfile_process_map.items(): + # p.communicate() reads std to avoid dead-lock, p.wait() only return. 
+ stdout, stderr = p_cmd_i[0].communicate() # stderr is merged to stdout, so it's None + with open(instance_log_file, "w", 1, encoding="utf-8") as log_file: + log_file.write(f"[COMMAND]: {p_cmd_i[1]}\n") + log_file.write(stdout.decode()) + logger.info(f"The log of instance {p_cmd_i[2]} is saved to {instance_log_file}") + + return logfile_dict + + +def summary_latency_throughput(logfile_dict): + """Get the summary of the benchmark.""" + throughput_pattern = r"[T,t]hroughput:\s*([0-9]*\.?[0-9]+)\s*([a-zA-Z/]*)" + latency_pattern = r"[L,l]atency:\s*([0-9]*\.?[0-9]+)\s*([a-zA-Z/]*)" + + latency_list = [] + throughput_list = [] + latency_unit_name = "" + throughput_unit_name = "" + for idx, logfile in logfile_dict.items(): + with open(logfile, "r") as f: + for line in f: + re_latency = re.search(latency_pattern, line) + re_throughput = re.search(throughput_pattern, line) + if re_latency: + latency_list.append(float(re_latency.group(1))) + if not latency_unit_name: + latency_unit_name = re_latency.group(2) + if re_throughput: + throughput_list.append(float(re_throughput.group(1))) + if not throughput_unit_name: + throughput_unit_name = re_throughput.group(2) + if throughput_list and latency_list: + assert ( + len(latency_list) == len(throughput_list) == len(logfile_dict) + ), "Multiple instance benchmark failed with some instances!" + + # dump collected latency and throughput info + header = "Multiple Instance Benchmark Summary" + field_names = [ + "Instance", + "Latency ({})".format(latency_unit_name), + "Throughput ({})".format(throughput_unit_name), + ] + output_data = [] + for idx, (latency, throughput) in enumerate(zip(latency_list, throughput_list)): + output_data.append([idx + 1, round(latency, 3), round(throughput, 3)]) + Statistics(output_data, header=header, field_names=field_names).print_stat() + # show summary info + logger.info("Average latency: {} {}".format(round(sum(latency_list) / len(latency_list), 3), latency_unit_name)) + logger.info("Total throughput: {} {}".format(round(sum(throughput_list), 3), throughput_unit_name)) + elif throughput_list: + assert len(throughput_list) == len(logfile_dict), "Multiple instance benchmark failed with some instances!" + + # dump collected throughput info + header = "Multiple Instance Benchmark Summary" + field_names = [ + "Instance", + "Throughput ({})".format(throughput_unit_name), + ] + output_data = [] + for idx, throughput in enumerate(throughput_list): + output_data.append([idx + 1, round(throughput, 3)]) + Statistics(output_data, header=header, field_names=field_names).print_stat() + # show summary info + logger.info("Total throughput: {} {}".format(round(sum(throughput_list), 3), throughput_unit_name)) + elif latency_list: + assert len(latency_list) == len(logfile_dict), "Multiple instance benchmark failed with some instances!"
+ + # dump collected latency info + header = "Multiple Instance Benchmark Summary" + field_names = [ + "Instance", + "Latency ({})".format(latency_unit_name), + ] + output_data = [] + for idx, latency in enumerate(latency_list): + output_data.append([idx + 1, round(latency, 3)]) + Statistics(output_data, header=header, field_names=field_names).print_stat() + # show summary info + logger.info("Average latency: {} {}".format(round(sum(latency_list) / len(latency_list), 3), latency_unit_name)) + + +def benchmark(): + """Benchmark API interface.""" + logger.info("Start benchmark with Intel Neural Compressor.") + logger.info("Intel Neural Compressor only uses physical CPUs for the best performance.") + + parser = argparse.ArgumentParser(description=description, formatter_class=argparse.RawTextHelpFormatter) + parser.add_argument("--num_instances", type=int, default=None, help="Determine the number of instances.") + parser.add_argument( + "--num_cores_per_instance", + type=int, + default=None, + help="Determine the number of cores in 1 instance.", + ) + parser.add_argument("-C", "--cores", type=str, default=None, help="Determine the visible core range.") + parser.add_argument("--cross_memory", action="store_true", help="Determine the visible core range.") + parser.add_argument("script", type=str, help="The path to the script to launch.") + parser.add_argument("parameters", nargs=argparse.REMAINDER, help="arguments to the script.") + + args = parser.parse_args() + + assert sys.platform in ["linux", "win32"], "only support platform windows and linux..." + + numa_info = dump_numa_info() # show numa info and current usage of cores + core_list_per_instance = set_cores_for_instance(args, numa_info=numa_info) + script_and_parameters = args.script + " " + " ".join(args.parameters) + logfile_dict = run_multi_instance_command(args, core_list_per_instance, raw_cmd=script_and_parameters) + summary_latency_throughput(logfile_dict) diff --git a/neural_compressor/common/utils/utility.py b/neural_compressor/common/utils/utility.py index 82f24243a9b..35e511fcea0 100644 --- a/neural_compressor/common/utils/utility.py +++ b/neural_compressor/common/utils/utility.py @@ -23,11 +23,13 @@ import cpuinfo import psutil +from prettytable import PrettyTable from neural_compressor.common.utils import Mode, TuningLogger, logger __all__ = [ "set_workspace", + "get_workspace", "set_random_seed", "set_resume_from", "set_tensorboard", @@ -38,6 +40,7 @@ "CpuInfo", "default_tuning_logger", "call_counter", + "Statistics", ] @@ -110,14 +113,6 @@ def __init__(self): b"\xB8\x07\x00\x00\x00" b"\x0f\xa2" b"\xC3", # mov eax, 7 # cpuid # ret ) self._bf16 = bool(eax & (1 << 5)) - # TODO: The implementation will be refined in the future. 
- # https://github.com/intel/neural-compressor/tree/detect_sockets - if "arch" in info and "ARM" in info["arch"]: # pragma: no cover - self._sockets = 1 - else: - self._sockets = self.get_number_of_sockets() - self._cores = psutil.cpu_count(logical=False) - self._cores_per_socket = int(self._cores / self._sockets) @property def bf16(self): @@ -129,32 +124,6 @@ def vnni(self): """Get whether it is vnni.""" return self._vnni - @property - def cores_per_socket(self): - """Get the cores per socket.""" - return self._cores_per_socket - - def get_number_of_sockets(self) -> int: - """Get number of sockets in platform.""" - cmd = "cat /proc/cpuinfo | grep 'physical id' | sort -u | wc -l" - if psutil.WINDOWS: - cmd = r'wmic cpu get DeviceID | C:\Windows\System32\find.exe /C "CPU"' - elif psutil.MACOS: # pragma: no cover - cmd = "sysctl -n machdep.cpu.core_count" - - with subprocess.Popen( - args=cmd, - shell=True, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - universal_newlines=False, - ) as proc: - proc.wait() - if proc.stdout: - for line in proc.stdout: - return int(line.decode("utf-8", errors="ignore").strip()) - return 0 - def dump_elapsed_time(customized_msg=""): """Get the elapsed time for decorated functions. @@ -193,6 +162,13 @@ def set_workspace(workspace: str): options.workspace = workspace +def get_workspace(): + """Get the workspace in config.""" + from neural_compressor.common import options + + return options.workspace + + def set_resume_from(resume_from: str): """Set the resume_from in config.""" from neural_compressor.common import options @@ -240,3 +216,45 @@ def wrapper(*args, **kwargs): return func(*args, **kwargs) return wrapper + + +class Statistics: + """The statistics printer.""" + + def __init__(self, data, header, field_names, output_handle=logger.info): + """Init a Statistics object. 
+ + Args: + data: The statistics data + header: The table header + field_names: The field names + output_handle: The output logging method + """ + self.field_names = field_names + self.header = header + self.data = data + self.output_handle = output_handle + self.tb = PrettyTable(min_table_width=40) + + def print_stat(self): + """Print the statistics.""" + valid_field_names = [] + for index, value in enumerate(self.field_names): + if index < 2: + valid_field_names.append(value) + continue + + if any(i[index] for i in self.data): + valid_field_names.append(value) + self.tb.field_names = valid_field_names + for i in self.data: + tmp_data = [] + for index, value in enumerate(i): + if self.field_names[index] in valid_field_names: + tmp_data.append(value) + if any(tmp_data[1:]): + self.tb.add_row(tmp_data) + lines = self.tb.get_string().split("\n") + self.output_handle("|" + self.header.center(len(lines[0]) - 2, "*") + "|") + for i in lines: + self.output_handle(i) diff --git a/neural_compressor/tensorflow/utils/utility.py b/neural_compressor/tensorflow/utils/utility.py index a7671da1f1e..6cbd109bef5 100644 --- a/neural_compressor/tensorflow/utils/utility.py +++ b/neural_compressor/tensorflow/utils/utility.py @@ -24,11 +24,10 @@ import cpuinfo import numpy as np -import prettytable as pt import psutil from pkg_resources import parse_version -from neural_compressor.common import logger +from neural_compressor.common.utils import Statistics, logger # Dictionary to store a mapping between algorithm names and corresponding algo implementation(function) algos_mapping: Dict[str, Callable] = {} @@ -268,48 +267,6 @@ def get_number_of_sockets(self) -> int: return 0 -class Statistics: - """The statistics printer.""" - - def __init__(self, data, header, field_names, output_handle=logger.info): - """Init a Statistics object. - - Args: - data: The statistics data - header: The table header - field_names: The field names - output_handle: The output logging method - """ - self.field_names = field_names - self.header = header - self.data = data - self.output_handle = output_handle - self.tb = pt.PrettyTable(min_table_width=40) - - def print_stat(self): - """Print the statistics.""" - valid_field_names = [] - for index, value in enumerate(self.field_names): - if index < 2: - valid_field_names.append(value) - continue - - if any(i[index] for i in self.data): - valid_field_names.append(value) - self.tb.field_names = valid_field_names - for i in self.data: - tmp_data = [] - for index, value in enumerate(i): - if self.field_names[index] in valid_field_names: - tmp_data.append(value) - if any(tmp_data[1:]): - self.tb.add_row(tmp_data) - lines = self.tb.get_string().split("\n") - self.output_handle("|" + self.header.center(len(lines[0]) - 2, "*") + "|") - for i in lines: - self.output_handle(i) - - class CaptureOutputToFile(object): """Not displayed in API Docs. 
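With this change the `Statistics` table printer is defined once in `neural_compressor.common.utils`, and the per-framework copies (TensorFlow above, PyTorch below) are dropped in favor of importing it. The snippet below is a minimal usage sketch of the shared printer; the row values are made up for illustration and simply mimic the instance-binding table that `incbench` prints.

```python
from neural_compressor.common.utils import Statistics

# Illustrative rows only: one entry per benchmark instance.
data = [
    [1, "0-3", 4],
    [2, "4-7", 4],
]
field_names = ["Instance", "Physical CPUs", "Number of cores"]

# Renders a PrettyTable and emits it line by line through logger.info (the default output handle).
Statistics(data, header="Instance Binding Information", field_names=field_names).print_stat()
```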
diff --git a/neural_compressor/torch/algorithms/static_quant/utility.py b/neural_compressor/torch/algorithms/static_quant/utility.py index 64fec8de785..23ac16630a4 100644 --- a/neural_compressor/torch/algorithms/static_quant/utility.py +++ b/neural_compressor/torch/algorithms/static_quant/utility.py @@ -24,12 +24,11 @@ try: import intel_extension_for_pytorch as ipex - import prettytable as pt except: # pragma: no cover pass from neural_compressor.common.utils import DEFAULT_WORKSPACE, CpuInfo -from neural_compressor.torch.utils import get_ipex_version, get_torch_version, logger +from neural_compressor.torch.utils import Statistics, get_ipex_version, get_torch_version, logger version = get_torch_version() ipex_ver = get_ipex_version() @@ -567,48 +566,6 @@ def get_quantizable_ops_from_cfgs(ops_name, op_infos_from_cfgs, input_tensor_ids return quantizable_ops -class Statistics: # pragma: no cover - """The statistics printer.""" - - def __init__(self, data, header, field_names, output_handle=logger.info): - """Init a Statistics object. - - Args: - data: The statistics data - header: The table header - field_names: The field names - output_handle: The output logging method - """ - self.field_names = field_names - self.header = header - self.data = data - self.output_handle = output_handle - self.tb = pt.PrettyTable(min_table_width=40) - - def print_stat(self): - """Print the statistics.""" - valid_field_names = [] - for index, value in enumerate(self.field_names): - if index < 2: - valid_field_names.append(value) - continue - - if any(i[index] for i in self.data): - valid_field_names.append(value) - self.tb.field_names = valid_field_names - for i in self.data: - tmp_data = [] - for index, value in enumerate(i): - if self.field_names[index] in valid_field_names: - tmp_data.append(value) - if any(tmp_data[1:]): - self.tb.add_row(tmp_data) - lines = self.tb.get_string().split("\n") - self.output_handle("|" + self.header.center(len(lines[0]) - 2, "*") + "|") - for i in lines: - self.output_handle(i) - - class TransformerBasedModelBlockPatternDetector: # pragma: no cover """Detect the attention block and FFN block in transformer-based model.""" diff --git a/neural_compressor/torch/utils/utility.py b/neural_compressor/torch/utils/utility.py index e312a9c388b..bf1bb2a77b1 100644 --- a/neural_compressor/torch/utils/utility.py +++ b/neural_compressor/torch/utils/utility.py @@ -16,10 +16,9 @@ from typing import Callable, Dict, List, Tuple, Union import torch -from prettytable import PrettyTable from typing_extensions import TypeAlias -from neural_compressor.common.utils import LazyImport, Mode, logger +from neural_compressor.common.utils import Mode, Statistics, logger OP_NAME_AND_TYPE_TUPLE_TYPE: TypeAlias = Tuple[str, Union[torch.nn.Module, Callable]] @@ -170,48 +169,6 @@ def postprocess_model(model, mode, quantizer): del model.quantizer -class Statistics: # pragma: no cover - """The statistics printer.""" - - def __init__(self, data, header, field_names, output_handle=logger.info): - """Init a Statistics object. 
- - Args: - data: The statistics data - header: The table header - field_names: The field names - output_handle: The output logging method - """ - self.field_names = field_names - self.header = header - self.data = data - self.output_handle = output_handle - self.tb = PrettyTable(min_table_width=40) - - def print_stat(self): - """Print the statistics.""" - valid_field_names = [] - for index, value in enumerate(self.field_names): - if index < 2: - valid_field_names.append(value) - continue - - if any(i[index] for i in self.data): - valid_field_names.append(value) - self.tb.field_names = valid_field_names - for i in self.data: - tmp_data = [] - for index, value in enumerate(i): - if self.field_names[index] in valid_field_names: - tmp_data.append(value) - if any(tmp_data[1:]): - self.tb.add_row(tmp_data) - lines = self.tb.get_string().split("\n") - self.output_handle("|" + self.header.center(len(lines[0]) - 2, "*") + "|") - for i in lines: - self.output_handle(i) - - def dump_model_op_stats(mode, tune_cfg): """This is a function to dump quantizable ops of model to user. diff --git a/neural_insights/components/model/onnxrt/model.py b/neural_insights/components/model/onnxrt/model.py index a421966f501..e60c0a1d457 100644 --- a/neural_insights/components/model/onnxrt/model.py +++ b/neural_insights/components/model/onnxrt/model.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. """Onnxrt model class.""" +# pylint: disable=import-error +# pylint: disable=no-name-in-module import os import re import sys diff --git a/neural_insights/components/model/tensorflow/model.py b/neural_insights/components/model/tensorflow/model.py index fd4a5499695..663556bcdc3 100644 --- a/neural_insights/components/model/tensorflow/model.py +++ b/neural_insights/components/model/tensorflow/model.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Abstract Tensorflow model class.""" +# pylint: disable=import-error +# pylint: disable=no-name-in-module import os.path from typing import Any, List, Optional diff --git a/requirements_ort.txt b/requirements_ort.txt index e438a675333..23f608859d1 100644 --- a/requirements_ort.txt +++ b/requirements_ort.txt @@ -2,6 +2,7 @@ numpy < 2.0 onnx onnxruntime onnxruntime-extensions +prettytable psutil py-cpuinfo pydantic diff --git a/setup.py b/setup.py index 071d56da9f6..949d53f910f 100644 --- a/setup.py +++ b/setup.py @@ -199,8 +199,12 @@ def get_build_version(): include_packages = PKG_INSTALL_CFG[cfg_key].get("include_packages") or {} package_data = PKG_INSTALL_CFG[cfg_key].get("package_data") or {} install_requires = PKG_INSTALL_CFG[cfg_key].get("install_requires") or [] - entry_points = PKG_INSTALL_CFG[cfg_key].get("entry_points") or {} extras_require = PKG_INSTALL_CFG[cfg_key].get("extras_require") or {} + entry_points = { + "console_scripts": [ + "incbench = neural_compressor.common.benchmark:benchmark", + ] + } setup( name=project_name, diff --git a/test/3x/common/test_benchmark.py b/test/3x/common/test_benchmark.py new file mode 100644 index 00000000000..78e94f70ace --- /dev/null +++ b/test/3x/common/test_benchmark.py @@ -0,0 +1,171 @@ +import os +import re +import shutil +import subprocess + +from neural_compressor.common.utils import DEFAULT_WORKSPACE + +# build files during test process to test benchmark +tmp_file_dict = {} +tmp = """ +print("test benchmark") +""" +tmp_file_dict["./tmp/tmp.py"] = tmp + +tmp = """ +print("test benchmark") +print("Throughput: 1 samples/sec") +print("Latency: 1000 ms") +""" +tmp_file_dict["./tmp/throughput_latency.py"] = tmp + +tmp = """ +print("test benchmark") +print("Throughput: 2 tokens/sec") +""" +tmp_file_dict["./tmp/throughput.py"] = tmp + +tmp = """ +print("test benchmark") +print("Latency: 10 ms") +""" +tmp_file_dict["./tmp/latency.py"] = tmp + + +def build_tmp_file(): + os.makedirs("./tmp") + for tmp_path, tmp in tmp_file_dict.items(): + f = open(tmp_path, "w") + f.write(tmp) + f.close() + + +def trigger_process(cmd): + # trigger subprocess + p = subprocess.Popen( + cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, shell=True + ) # nosec + return p + + +def check_main_process(message): + num_i_pattern = r"(.*) (\d+) instance(.*) triggered" + num_c_pattern = r"(.*) (\d+) core(.*) in use" + log_file_pattern = r"(.*) The log of instance 1 is saved to (.*)" + num_i = re.search(num_i_pattern, message, flags=re.DOTALL).group(2) + all_c = re.search(num_c_pattern, message).group(2) + log_file_path = re.search(log_file_pattern, message).group(2) + return int(num_i), int(all_c), log_file_path + + +def check_log_file(log_file_path): + output_pattern = r"(.*)test benchmark(.*)" + with open(log_file_path, "r") as f: + output = f.read() + f.close() + return re.match(output_pattern, output, flags=re.DOTALL) + + +class TestBenchmark: + def setup_class(self): + build_tmp_file() + + def teardown_class(self): + shutil.rmtree("./tmp") + shutil.rmtree("nc_workspace") + + def test_default(self): + cmd = "incbench tmp/tmp.py" + p = trigger_process(cmd) + stdout, _ = p.communicate() + num_i, all_c, log_file_path = check_main_process(stdout.decode()) + assert num_i == 1, "the number of instance should be 1." + assert check_log_file(log_file_path), "instance output is not correct." 
+ + def test_only_num_i(self): + cmd = "incbench --num_i 2 tmp/tmp.py" + p = trigger_process(cmd) + stdout, _ = p.communicate() + num_i, all_c, log_file_path = check_main_process(stdout.decode()) + assert num_i == 2, "the number of instance should be 2." + assert check_log_file(log_file_path), "instance output is not correct." + + def test_only_num_c(self): + cmd = "incbench --num_c 1 tmp/tmp.py" + p = trigger_process(cmd) + stdout, _ = p.communicate() + num_i, all_c, log_file_path = check_main_process(stdout.decode()) + assert num_i == all_c, "the number of instance should equal the number of available cores." + assert check_log_file(log_file_path), "instance output is not correct." + + def test_only_cores(self): + cmd = "incbench -C 0-1 tmp/tmp.py" + p = trigger_process(cmd) + stdout, _ = p.communicate() + num_i, all_c, log_file_path = check_main_process(stdout.decode()) + assert num_i == 1, "the number of instance should be 1." + assert all_c == 2, "the number of available cores should be 2." + assert check_log_file(log_file_path), "instance output is not correct." + + def test_num_i_num_c(self): + cmd = "incbench --num_i 2 --num_c 2 tmp/tmp.py" + p = trigger_process(cmd) + stdout, _ = p.communicate() + num_i, all_c, log_file_path = check_main_process(stdout.decode()) + assert num_i == 2, "the number of instance should be 2." + assert all_c == 4, "the number of available cores should be 4." + assert check_log_file(log_file_path), "instance output is not correct." + + def test_num_i_cores(self): + cmd = "incbench --num_i 2 -C 0-2,5,8 tmp/tmp.py" + p = trigger_process(cmd) + stdout, _ = p.communicate() + num_i, all_c, log_file_path = check_main_process(stdout.decode()) + assert num_i == 2, "the number of instance should be 2." + assert all_c == 5, "the number of available cores should be 5." + assert check_log_file(log_file_path), "instance output is not correct." + + def test_num_c_cores(self): + cmd = "incbench --num_c 2 -C 0-6 tmp/tmp.py" + p = trigger_process(cmd) + stdout, _ = p.communicate() + num_i, all_c, log_file_path = check_main_process(stdout.decode()) + assert num_i == 3, "the number of instance should be all_c//num_c=3." + assert all_c == 6, "the number of available cores should be (all_c//num_c)*num_c=6." + assert check_log_file(log_file_path), "instance output is not correct." + + def test_cross_memory(self): + cmd = "incbench --num_c 1 -C 0 --cross_memory tmp/tmp.py" + p = trigger_process(cmd) + stdout, _ = p.communicate() + num_i, all_c, log_file_path = check_main_process(stdout.decode()) + assert num_i == 1, "the number of instance should be all_c//num_c=1." + assert all_c == 1, "the number of available cores should be 1." + assert check_log_file(log_file_path), "instance output is not correct." + + def test_throughput_latency(self): + cmd = "incbench --num_i 2 --num_c 2 -C 0-7 tmp/throughput_latency.py" + p = trigger_process(cmd) + stdout, _ = p.communicate() + num_i, all_c, log_file_path = check_main_process(stdout.decode()) + assert num_i == 2, "the number of instance should be 2." + assert all_c == 4, "the number of available cores should be num_i*num_c=4." + assert check_log_file(log_file_path), "instance output is not correct." + + def test_throughput(self): + cmd = "incbench --num_i 2 --num_c 2 -C 0-7 tmp/throughput.py" + p = trigger_process(cmd) + stdout, _ = p.communicate() + num_i, all_c, log_file_path = check_main_process(stdout.decode()) + assert num_i == 2, "the number of instance should be 2." 
+ assert all_c == 4, "the number of available cores should be num_i*num_c=4." + assert check_log_file(log_file_path), "instance output is not correct." + + def test_latency(self): + cmd = "incbench --num_i 2 --num_c 2 -C 0-7 tmp/latency.py" + p = trigger_process(cmd) + stdout, _ = p.communicate() + num_i, all_c, log_file_path = check_main_process(stdout.decode()) + assert num_i == 2, "the number of instance should be 2." + assert all_c == 4, "the number of available cores should be num_i*num_c=4." + assert check_log_file(log_file_path), "instance output is not correct." diff --git a/test/3x/common/test_utility.py b/test/3x/common/test_utility.py index 527f74a4a13..b605b3b506b 100644 --- a/test/3x/common/test_utility.py +++ b/test/3x/common/test_utility.py @@ -19,6 +19,7 @@ Mode, default_tuning_logger, dump_elapsed_time, + get_workspace, log_process, set_random_seed, set_resume_from, @@ -43,6 +44,8 @@ def test_set_workspace(self): workspace = "/path/to/workspace" set_workspace(workspace) self.assertEqual(options.workspace, workspace) + returned_workspace = get_workspace() + self.assertEqual(returned_workspace, workspace) # non String type workspace = 12345 @@ -73,7 +76,6 @@ def test_set_tensorboard(self): class TestCPUInfo(unittest.TestCase): def test_cpu_info(self): cpu_info = CpuInfo() - assert cpu_info.cores_per_socket > 0, "CPU count should be greater than 0" assert isinstance(cpu_info.bf16, bool), "bf16 should be a boolean" assert isinstance(cpu_info.vnni, bool), "avx512 should be a boolean"
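As a closing reference, the throughput/latency patterns that `summary_latency_throughput` relies on can be sanity-checked in isolation against log lines like the ones the temporary test scripts above print. This standalone snippet is illustrative only and is not part of the test suite.

```python
import re

# Patterns copied from neural_compressor/common/benchmark.py.
throughput_pattern = r"[T,t]hroughput:\s*([0-9]*\.?[0-9]+)\s*([a-zA-Z/]*)"
latency_pattern = r"[L,l]atency:\s*([0-9]*\.?[0-9]+)\s*([a-zA-Z/]*)"

# Sample lines matching what tmp/throughput.py and tmp/latency.py print in test_benchmark.py.
log_lines = ["Throughput: 2 tokens/sec", "Latency: 10 ms"]

for line in log_lines:
    tp = re.search(throughput_pattern, line)
    lat = re.search(latency_pattern, line)
    if tp:
        print("throughput:", float(tp.group(1)), tp.group(2))  # -> 2.0 tokens/sec
    if lat:
        print("latency:", float(lat.group(1)), lat.group(2))   # -> 10.0 ms
```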