5
5
"""
6
6
# std imports
7
7
from argparse import ArgumentParser
8
+ import contextlib
8
9
import json
9
10
import logging
10
11
import os
11
- import tempfile
12
12
from typing import Optional
13
13
14
14
# tpl imports
@@ -30,8 +30,11 @@ def get_args():
30
30
parser .add_argument ("input_json" , type = str , help = "Input JSON file containing the test cases." )
31
31
parser .add_argument ("-o" , "--output" , type = str , help = "Output JSON file containing the results." )
32
32
parser .add_argument ("--scratch-dir" , type = str , help = "If provided, put scratch files here." )
33
+ parser .add_argument ("--driver-root" , type = str , help = "Where to look for the driver files, if not in cwd." )
33
34
parser .add_argument ("--launch-configs" , type = str , default = "launch-configs.json" ,
34
35
help = "config for how to run samples." )
36
+ parser .add_argument ("--build-configs" , type = str , default = "build-configs.json" ,
37
+ help = "config for how to build samples. If not provided, will use the default build settings for each model." )
35
38
parser .add_argument ("--problem-sizes" , type = str , default = "problem-sizes.json" ,
36
39
help = "config for how to run samples." )
37
40
parser .add_argument ("--yes-to-all" , action = "store_true" , help = "If provided, automatically answer yes to all prompts." )
@@ -56,11 +59,19 @@ def get_args():
56
59
parser .add_argument ("--log-runs" , action = "store_true" , help = "Display the stderr and stdout of runs." )
57
60
return parser .parse_args ()
58
61
59
- def get_driver (prompt : dict , scratch_dir : Optional [os .PathLike ], launch_configs : dict , problem_sizes : dict , dry : bool , ** kwargs ) -> DriverWrapper :
62
+ def get_driver (
63
+ prompt : dict ,
64
+ scratch_dir : Optional [os .PathLike ],
65
+ launch_configs : dict ,
66
+ build_configs : dict ,
67
+ problem_sizes : dict ,
68
+ dry : bool ,
69
+ ** kwargs
70
+ ) -> DriverWrapper :
60
71
""" Get the language drive wrapper for this prompt """
61
72
driver_cls = LANGUAGE_DRIVERS [prompt ["language" ]]
62
73
return driver_cls (parallelism_model = prompt ["parallelism_model" ], launch_configs = launch_configs ,
63
- problem_sizes = problem_sizes , scratch_dir = scratch_dir , dry = dry , ** kwargs )
74
+ build_configs = build_configs , problem_sizes = problem_sizes , scratch_dir = scratch_dir , dry = dry , ** kwargs )
64
75
65
76
def already_has_results (prompt : dict ) -> bool :
66
77
""" Check if a prompt already has results stored in it. """
@@ -102,10 +113,25 @@ def main():
102
113
launch_configs = load_json (args .launch_configs )
103
114
logging .info (f"Loaded launch configs from { args .launch_configs } ." )
104
115
116
+ # load build configs
117
+ build_configs = load_json (args .build_configs )
118
+ logging .info (f"Loaded build configs from { args .build_configs } ." )
119
+
105
120
# load problem sizes
106
121
problem_sizes = load_json (args .problem_sizes )
107
122
logging .info (f"Loaded problem sizes from { args .problem_sizes } ." )
108
123
124
+ # set driver root; If provided, use user argument. If it's not provided, then check if the PAREVAL_ROOT environment
125
+ # variable is set, then use "${PAREVAL_ROOT}/drivers" as the root. If neither is set, then use the location of
126
+ # this script as the root.
127
+ if args .driver_root :
128
+ DRIVER_ROOT = args .driver_root
129
+ elif "PAREVAL_ROOT" in os .environ :
130
+ DRIVER_ROOT = os .path .join (os .environ ["PAREVAL_ROOT" ], "drivers" )
131
+ else :
132
+ DRIVER_ROOT = os .path .dirname (os .path .abspath (__file__ ))
133
+ logging .info (f"Using driver root: { DRIVER_ROOT } " )
134
+
109
135
# gather the list of parallelism models to test
110
136
models_to_test = args .include_models if args .include_models else ["serial" , "omp" , "mpi" , "mpi+omp" , "kokkos" , "cuda" , "hip" ]
111
137
if args .exclude_models :
@@ -139,15 +165,18 @@ def main():
139
165
prompt ,
140
166
args .scratch_dir ,
141
167
launch_configs ,
168
+ build_configs ,
142
169
problem_sizes ,
143
170
args .dry ,
144
171
display_build_errors = args .log_build_errors ,
145
172
display_runs = args .log_runs ,
146
173
early_exit_runs = args .early_exit_runs ,
147
174
build_timeout = args .build_timeout ,
148
- run_timeout = args .run_timeout
175
+ run_timeout = args .run_timeout ,
149
176
)
150
- driver .test_all_outputs_in_prompt (prompt )
177
+
178
+ with contextlib .chdir (DRIVER_ROOT ):
179
+ driver .test_all_outputs_in_prompt (prompt )
151
180
152
181
# go ahead and write out outputs now
153
182
if args .output and args .output != '-' :
0 commit comments