19
19
20
20
import numpy as np
21
21
import torch
22
+ import fcntl
22
23
23
24
from .._quant_common .helper_modules import *
24
25
from .._quant_common .quant_config import get_hqt_config
@@ -105,22 +106,69 @@ def load_npz(fname):
105
106
return d ["arr_0" ].item ()
106
107
107
108
109
class ProcessSafeReaderLock:
    """Context manager that holds a shared ``flock`` while a file is read.

    The lock is not taken on ``file_path`` itself but on a sidecar file named
    ``file_path + ".lock"``, which is created if it does not exist.

    Args:
        file_path: Path of the file being protected; only used to derive the
            sidecar lock-file name.
    """

    def __init__(self, file_path):
        self.file_path = file_path
        # Handle of the sidecar lock file; populated by __enter__.
        self.lock = None

    def __enter__(self):
        lock_file = open(self.file_path + ".lock", 'w')
        try:
            fcntl.flock(lock_file, fcntl.LOCK_SH)  # Shared lock for reading
        except BaseException:
            # Do not leak the descriptor when the lock cannot be acquired.
            lock_file.close()
            raise
        self.lock = lock_file
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        try:
            fcntl.flock(self.lock, fcntl.LOCK_UN)  # Unlock the file
        finally:
            # Always release the descriptor, even if unlocking fails.
            self.lock.close()
121
+
122
+
123
class ProcessSafeWriterLock:
    """Context manager that holds an exclusive ``flock`` while a file is written.

    The lock is not taken on ``file_path`` itself but on a sidecar file named
    ``file_path + ".lock"``, which is created if it does not exist.

    Args:
        file_path: Path of the file being protected; only used to derive the
            sidecar lock-file name.
    """

    def __init__(self, file_path):
        self.file_path = file_path
        # Handle of the sidecar lock file; populated by __enter__.
        self.lock = None

    def __enter__(self):
        lock_file = open(self.file_path + ".lock", 'w')
        try:
            fcntl.flock(lock_file, fcntl.LOCK_EX)  # Exclusive lock for writing
        except BaseException:
            # Do not leak the descriptor when the lock cannot be acquired.
            lock_file.close()
            raise
        self.lock = lock_file
        return self

    def __exit__(self, *args):
        try:
            fcntl.flock(self.lock, fcntl.LOCK_UN)  # Unlock the file
        finally:
            # Always release the descriptor, even if unlocking fails.
            self.lock.close()
135
+
136
+
108
137
def save_file(model, d, source_format, fname, mode):
    """Convert ``d`` to the on-disk format and write it to ``fname``.

    The extension of ``fname`` selects the entry of ``file_functions`` that
    supplies the target format and the save routine. The payload is wrapped
    together with the rank information read from the model's config, and the
    write happens while holding ``ProcessSafeWriterLock(fname)``.

    Saving is best-effort: a failure of the save routine is logged and does
    not propagate to the caller.

    Args:
        model: Model passed to ``get_hqt_config`` to obtain the rank values.
        d: Data to save; converted with ``rec_fn`` before writing.
        source_format: Format ``d`` is currently held in; together with the
            target format it selects the converter from ``format_functions``.
        fname: Destination path; its extension must be a key of
            ``file_functions``.
        mode: Label stored under the ``"Mode"`` key of the saved payload.

    Raises:
        KeyError: If the extension or the (source, target) format pair has no
            registered handler.
    """
    config = get_hqt_config(model)
    logger.debug("Saving %s file: %s", mode, fname)
    ext = os.path.splitext(fname)[1]
    target_format = file_functions[ext]['format']
    dc = rec_fn(d, format_functions[(source_format, target_format)])
    df = {
        "GlobalRank": config.cfg["global_rank"],
        "LocalRank": config.cfg["local_rank"],
        "Mode": mode,
        "Nodes": dc,
    }
    with ProcessSafeWriterLock(fname):
        try:
            file_functions[ext]['save'](df, fname)
        except Exception as exc:
            # Keep the save best-effort, but leave a trace instead of hiding
            # the failure; KeyboardInterrupt/SystemExit are no longer swallowed.
            logger.warning("Failed to save %s file %s: %s", mode, fname, exc)
154
+
155
+
156
def load_file(fname, target_format, fail_on_file_not_exist):
    """Read ``fname`` and return its nodes converted to ``target_format``.

    The extension of ``fname`` selects the entry of ``file_functions`` that
    supplies the on-disk format and the load routine. The file is read while
    holding ``ProcessSafeReaderLock(fname)``.

    Args:
        fname: Path of the file to load.
        target_format: Format the loaded module configs are converted to.
        fail_on_file_not_exist: When true, a missing file raises instead of
            yielding an empty result.

    Returns:
        Dict mapping each key of the file's ``"Nodes"`` entry to its converted
        ``ModuleConfig``; empty when the file is missing or has no ``"Nodes"``.

    Raises:
        FileNotFoundError: If the file is missing and
            ``fail_on_file_not_exist`` is true.
    """
    logger.debug("Loading file: %s", fname)
    ext = os.path.splitext(fname)[1]
    handlers = file_functions[ext]
    source_format = handlers['format']

    if os.path.isfile(fname):
        with ProcessSafeReaderLock(fname):
            loaded = handlers['load'](fname)
    elif fail_on_file_not_exist:
        raise FileNotFoundError(f"Failed to load file { fname } ")
    else:
        loaded = {}

    if "Nodes" not in loaded:
        return {}

    # First pass: rebuild a ModuleConfig for every stored node.
    nodes = loaded["Nodes"]
    configs = {}
    for name in nodes:
        configs[name] = ModuleConfig(**fix_fields(nodes[name]))

    # Second pass: convert each config from the file format to the target one.
    converted = {}
    for name, module_cfg in configs.items():
        converted[name] = module_convert(module_cfg, format_functions[(source_format, target_format)])
    return converted
124
172
125
173
126
174
# convert module config data to other format
@@ -147,29 +195,26 @@ def fix_fields(d):
147
195
return d
148
196
149
197
150
def load_file(fname, target_format, fail_on_file_not_exist):
    """Read ``fname`` and return its nodes converted to ``target_format``.

    This revision indexes the ``file_functions`` entry positionally: item 0 is
    used as the on-disk format and item 2 as the load routine.

    Args:
        fname: Path of the file to load.
        target_format: Format the loaded module configs are converted to.
        fail_on_file_not_exist: When true, a missing file raises instead of
            yielding an empty result.

    Returns:
        Dict mapping each key of the file's ``"Nodes"`` entry to its converted
        ``ModuleConfig``; empty when the file is missing or has no ``"Nodes"``.

    Raises:
        FileNotFoundError: If the file is missing and
            ``fail_on_file_not_exist`` is true.
    """
    logger.debug("Loading file: %s", fname)
    ext = os.path.splitext(fname)[1]
    source_format = file_functions[ext][0]

    if os.path.isfile(fname):
        loaded = file_functions[ext][2](fname)
    elif fail_on_file_not_exist:
        raise FileNotFoundError(f"Failed to load file { fname } ")
    else:
        loaded = {}

    if "Nodes" not in loaded:
        return {}

    # First pass: rebuild a ModuleConfig for every stored node.
    nodes = loaded["Nodes"]
    configs = {}
    for name in nodes:
        configs[name] = ModuleConfig(**fix_fields(nodes[name]))

    # Second pass: convert each config from the file format to the target one.
    converted = {}
    for name, module_cfg in configs.items():
        converted[name] = module_convert(module_cfg, format_functions[(source_format, target_format)])
    return converted
165
-
166
-
167
198
def save_scales(model, d, source_format, fname):
    """Save the scales of a model to a file.

    Each value of ``d`` is replaced by its ``__dict__`` and the result is
    handed to ``save_file`` with the mode label ``"Scale"``.

    Args:
        model: The model the scales belong to.
        d: Dictionary mapping module names to their configuration objects.
        source_format: How the data is stored in memory.
        fname: File to save the scales to.
    """
    plain_configs = {}
    for module_name in d:
        plain_configs[module_name] = d[module_name].__dict__
    save_file(model, plain_configs, source_format, fname, "Scale")
170
209
171
210
172
211
def load_scales(fname, target_format):
    """Load scales from the given file.

    Delegates to ``load_file`` with ``fail_on_file_not_exist=False``.

    Args:
        fname: File to load the scales from.
        target_format: Format passed to ``load_file`` as the conversion target
            for the loaded data.

    Returns:
        Whatever ``load_file`` returns for ``fname``.
    """
    logger.debug("Loading scales file %s", fname)
    return load_file(fname, target_format, False)
@@ -184,8 +229,8 @@ def convert_scales_to_tensors_dict(scales_obj, scales_file_format, hp_dtype):
184
229
185
230
186
231
# Per-extension I/O handlers, keyed by file suffix. Each entry provides:
#   'format' - the type identifying the on-disk data format,
#   'save'   - callable invoked as save(data, fname),
#   'load'   - callable invoked as load(fname).
file_functions = {
    ".json": dict(format=list, save=save_json, load=load_json),
    ".npz": dict(format=np.ndarray, save=save_npz, load=load_npz),
}
190
235
191
236
format_functions = {
0 commit comments