23
23
import argparse
24
24
import re
25
25
import sys
26
- import uuid
26
+ import os
27
27
28
28
# C type: const int
29
29
ERROR_CLASSES = [
@@ -1042,6 +1042,30 @@ def need_bigcount(self):
1042
1042
return any ('COUNT' in param .type_ for param in self .params )
1043
1043
1044
1044
1045
+ class TemplateParseError (Exception ):
1046
+ """Error raised during parsing."""
1047
+ pass
1048
+
1049
+
1050
+ def validate_body (body ):
1051
+ """Validate the body of a template."""
1052
+ # Just do a simple bracket balance test determine the bounds of the
1053
+ # function body. All lines after the function body should be blank. There
1054
+ # are cases where this will break, such as if someone puts code all on one
1055
+ # line.
1056
+ bracket_balance = 0
1057
+ line_count = 0
1058
+ for line in body :
1059
+ line = line .strip ()
1060
+ if bracket_balance == 0 and line_count > 0 and line :
1061
+ raise TemplateParserError ('Extra code found in template; only one function body is allowed' )
1062
+
1063
+ update = line .count ('{' ) - line .count ('}' )
1064
+ bracket_balance += update
1065
+ if bracket_balance != 0 :
1066
+ line_count += 1
1067
+
1068
+
1045
1069
class SourceTemplate :
1046
1070
"""Source template for a single API function."""
1047
1071
@@ -1051,8 +1075,10 @@ def __init__(self, prototype, header, body):
1051
1075
self .body = body
1052
1076
1053
1077
@staticmethod
1054
- def load (fname ):
1078
+ def load (fname , prefix = None ):
1055
1079
"""Load a template file and return the SourceTemplate."""
1080
+ if prefix is not None :
1081
+ fname = os .path .join (prefix , fname )
1056
1082
with open (fname ) as fp :
1057
1083
header = []
1058
1084
prototype = []
@@ -1061,11 +1087,12 @@ def load(fname):
1061
1087
for line in fp :
1062
1088
line = line .rstrip ()
1063
1089
if prototype and line .startswith ('PROTOTYPE' ):
1064
- raise RuntimeError ('more than one prototype found in template file' )
1090
+ raise TemplateParseError ('more than one prototype found in template file' )
1065
1091
elif ((prototype and not any (')' in s for s in prototype ))
1066
1092
or line .startswith ('PROTOTYPE' )):
1067
1093
prototype .append (line )
1068
1094
elif prototype :
1095
+ # Validate bracket balance
1069
1096
body .append (line )
1070
1097
else :
1071
1098
header .append (line )
@@ -1082,6 +1109,8 @@ def load(fname):
1082
1109
params = [param .strip () for param in prototype [i + 1 :j ].split (',' ) if param .strip ()]
1083
1110
params = [Parameter (param ) for param in params ]
1084
1111
prototype = Prototype (name , return_type , params )
1112
+ # Ensure the body contains only one function
1113
+ validate_body (body )
1085
1114
return SourceTemplate (prototype , header , body )
1086
1115
1087
1116
def print_header (self , file = sys .stdout ):
@@ -1148,7 +1177,7 @@ def standard_abi(base_name, template):
1148
1177
print (f'#include "{ ABI_INTERNAL_HEADER } "' )
1149
1178
1150
1179
# Static internal function (add a random component to avoid conflicts)
1151
- internal_name = f'ompi_ { template .prototype .name } _ { uuid . uuid4 (). hex [: 10 ] } '
1180
+ internal_name = f'ompi_abi_ { template .prototype .name } '
1152
1181
internal_sig = template .prototype .signature ('ompi' , internal_name ,
1153
1182
count_type = 'MPI_Count' )
1154
1183
print ('static inline' , internal_sig )
@@ -1190,7 +1219,7 @@ def generate_function(prototype, fn_name, internal_fn, count_type='int'):
1190
1219
1191
1220
def gen_header (args ):
1192
1221
"""Generate an ABI header and conversion code."""
1193
- prototypes = [SourceTemplate .load (file_ ).prototype for file_ in args .file ]
1222
+ prototypes = [SourceTemplate .load (file_ , args . srcdir ).prototype for file_ in args .file ]
1194
1223
1195
1224
builder = ABIHeaderBuilder (prototypes , external = args .external )
1196
1225
builder .dump_header ()
@@ -1219,6 +1248,7 @@ def main():
1219
1248
parser_header = subparsers .add_parser ('header' )
1220
1249
parser_header .add_argument ('file' , nargs = '+' , help = 'list of template source files' )
1221
1250
parser_header .add_argument ('--external' , action = 'store_true' , help = 'generate external mpi.h header file' )
1251
+ parser_header .add_argument ('--srcdir' , help = 'source directory' )
1222
1252
parser_header .set_defaults (func = gen_header )
1223
1253
1224
1254
parser_gen = subparsers .add_parser ('source' )
0 commit comments