
1. add -t for multiple test rounds 2. separate time for each test phase #2617

Open
wants to merge 1 commit into main
68 changes: 62 additions & 6 deletions test.py
@@ -8,6 +8,9 @@
 import gc
 import os
 import unittest
+import argparse
+import sys
+import time  # Add time module import

 import torch
 from torchbenchmark import (
@@ -26,6 +29,15 @@
 # unresponsive for 5 minutes the parent will presume it dead / incapacitated.)
 TIMEOUT = int(os.getenv("TIMEOUT", 300))  # Seconds

+# Add argument parser
+parser = argparse.ArgumentParser(description='Run benchmark tests', add_help=False)
+parser.add_argument('-t', '--iterations', type=int, default=300,
+                    help='Number of iterations to run each test (default: 300)')
+# Parse only known arguments to avoid interfering with unittest
+args, unknown = parser.parse_known_args()
+
+# Store iterations in a global variable
+ITERATIONS = args.iterations

 class TestBenchmark(unittest.TestCase):
     def setUp(self):
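
The parser above uses add_help=False and parse_known_args() so that the new -t/--iterations flag does not collide with unittest's own command-line options. A minimal standalone sketch of that split (illustrative only, not part of the diff; the test selector below is hypothetical):

    import argparse

    # Claim only -t/--iterations; leave every other argument for unittest.
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument("-t", "--iterations", type=int, default=300)

    # e.g. a command line like: python test.py -t 50 -v -k some_test
    args, unknown = parser.parse_known_args(["-t", "50", "-v", "-k", "some_test"])

    print(args.iterations)  # 50
    print(unknown)          # ['-v', '-k', 'some_test'] -- later handed to unittest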
@@ -55,6 +67,8 @@ def _create_example_model_instance(task: ModelTask, device: str):

 def _load_test(path, device):
     model_name = os.path.basename(path)
+    print(f"Loading test for model {model_name} on {device}")
+

     def _skip_cuda_memory_check_p(metadata):
         if device != "cuda":
@@ -94,12 +108,32 @@ def train_fn(self):
             skip=_skip_cuda_memory_check_p(metadata), assert_equal=self.assertEqual
         ):
             try:
+                # Measure model initialization time
+                init_start_time = time.time()
                 task.make_model_instance(
                     test="train", device=device, batch_size=batch_size
                 )
-                task.invoke()
-                task.check_details_train(device=device, md=metadata)
+                init_time = time.time() - init_start_time
+                print(f"\nModel initialization time: {init_time:.2f} seconds")
+
+                # Measure training time
+                train_start_time = time.time()
+                # Run training for specified number of iterations
+                for _ in range(ITERATIONS):
+                    task.invoke()
+                    task.check_details_train(device=device, md=metadata)
+                train_time = time.time() - train_start_time
+                print(f"Training time: {train_time:.2f} seconds")
+
+                # Measure cleanup time
+                cleanup_start_time = time.time()
                 task.del_model_instance()
+                cleanup_time = time.time() - cleanup_start_time
+                print(f"Cleanup time: {cleanup_time:.2f} seconds")
+
+                # Print total time
+                total_time = init_time + train_time + cleanup_time
+                print(f"Total time: {total_time:.2f} seconds")
             except NotImplementedError as e:
                 self.skipTest(
                     f'Method train on {device} is not implemented because "{e}", skipping...'
@@ -117,13 +151,33 @@ def eval_fn(self):
             skip=_skip_cuda_memory_check_p(metadata), assert_equal=self.assertEqual
         ):
             try:
+                # Measure model initialization time
+                init_start_time = time.time()
                 task.make_model_instance(
                     test="eval", device=device, batch_size=batch_size
                 )
-                task.invoke()
-                task.check_details_eval(device=device, md=metadata)
-                task.check_eval_output()
+                init_time = time.time() - init_start_time
+                print(f"\nModel initialization time: {init_time:.2f} seconds")
+
+                # Measure evaluation time
+                eval_start_time = time.time()
+                # Run inference for specified number of iterations
+                for _ in range(ITERATIONS):
+                    task.invoke()
+                    task.check_details_eval(device=device, md=metadata)
+                    task.check_eval_output()
+                eval_time = time.time() - eval_start_time
+                print(f"Evaluation time: {eval_time:.2f} seconds")
+
+                # Measure cleanup time
+                cleanup_start_time = time.time()
                 task.del_model_instance()
+                cleanup_time = time.time() - cleanup_start_time
+                print(f"Cleanup time: {cleanup_time:.2f} seconds")
+
+                # Print total time
+                total_time = init_time + eval_time + cleanup_time
+                print(f"Total time: {total_time:.2f} seconds")
             except NotImplementedError as e:
                 self.skipTest(
                     f'Method eval on {device} is not implemented because "{e}", skipping...'
@@ -187,4 +241,6 @@ def _load_tests():

 _load_tests()
 if __name__ == "__main__":
-    unittest.main()
+    # Pass unknown arguments to unittest
+    sys.argv[1:] = unknown
+    unittest.main()
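
With this change, an invocation such as python test.py -t 50 -v would run each generated train/eval test for 50 iterations while -v is still forwarded to unittest through sys.argv. Note that because the parser is built with add_help=False, -h/--help is also treated as an unknown argument and handed to unittest, so it prints unittest's usage text rather than documenting -t.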