
Commit 775d5a2

Tiefen-boop authored and Eran Geva committed
[SW-174155] Fix race condition bug when reading scales
Implement an inter-process reader-writer lock.
Implement locking mechanism at save_file/load_file.

Change-Id: I140fdc05814286796bb47e6be8170b2ae9dd5154
1 parent a529cf4 commit 775d5a2
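
For context, the race in SW-174155 comes from one process reading a scales file while another process is rewriting it. The fix coordinates them through fcntl.flock advisory locks on a sidecar ".lock" file: readers take a shared lock (LOCK_SH) and the writer takes an exclusive lock (LOCK_EX). The snippet below is a minimal standalone sketch of that pattern, not code from the commit; the file name and helper functions are illustrative only.

    import fcntl

    SCALES_FILE = "scales.json"  # illustrative path, not from the commit

    def write_scales(text):
        # Writer side: exclusive flock blocks concurrent readers and writers.
        with open(SCALES_FILE + ".lock", "w") as lock:
            fcntl.flock(lock, fcntl.LOCK_EX)
            try:
                with open(SCALES_FILE, "w") as f:
                    f.write(text)
            finally:
                fcntl.flock(lock, fcntl.LOCK_UN)

    def read_scales():
        # Reader side: shared flock lets readers overlap but waits out a writer.
        with open(SCALES_FILE + ".lock", "w") as lock:
            fcntl.flock(lock, fcntl.LOCK_SH)
            try:
                with open(SCALES_FILE) as f:
                    return f.read()
            finally:
                fcntl.flock(lock, fcntl.LOCK_UN)

Note that flock locks are advisory: they only serialize processes that agree to take the lock, which is what the ProcessSafeReaderLock/ProcessSafeWriterLock context managers introduced in this commit do around load_file and save_file.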

File tree

1 file changed (+69 −24 lines changed)

  • neural_compressor/torch/algorithms/fp8_quant/_core


neural_compressor/torch/algorithms/fp8_quant/_core/common.py

Lines changed: 69 additions & 24 deletions
@@ -19,6 +19,7 @@
 
 import numpy as np
 import torch
+import fcntl
 
 from .._quant_common.helper_modules import *
 from .._quant_common.quant_config import get_hqt_config
@@ -105,22 +106,69 @@ def load_npz(fname):
     return d["arr_0"].item()
 
 
+class ProcessSafeReaderLock:
+    def __init__(self, file_path):
+        self.file_path = file_path
+
+    def __enter__(self):
+        self.lock = open(self.file_path + ".lock", 'w')
+        fcntl.flock(self.lock, fcntl.LOCK_SH)  # Shared lock for reading
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        fcntl.flock(self.lock, fcntl.LOCK_UN)  # Unlock the file
+        self.lock.close()
+
+
+class ProcessSafeWriterLock:
+    def __init__(self, file_path):
+        self.file_path = file_path
+
+    def __enter__(self):
+        self.lock = open(self.file_path + ".lock", 'w')
+        fcntl.flock(self.lock, fcntl.LOCK_EX)  # Exclusive lock for writing
+        return self
+
+    def __exit__(self, *args):
+        fcntl.flock(self.lock, fcntl.LOCK_UN)  # Unlock the file
+        self.lock.close()
+
+
 def save_file(model, d, source_format, fname, mode):
     config = get_hqt_config(model)
     logger.debug("Saving %s file: %s", mode, fname)
     ext = os.path.splitext(fname)[1]
-    target_format = file_functions[ext][0]
+    target_format = file_functions[ext]['format']
     dc = rec_fn(d, format_functions[(source_format, target_format)])
     df = {
         "GlobalRank": config.cfg["global_rank"],
         "LocalRank": config.cfg["local_rank"],
         "Mode": mode,
         "Nodes": dc,
     }
-    try:
-        file_functions[ext][1](df, fname)
-    except:
-        pass
+    with ProcessSafeWriterLock(fname):
+        try:
+            file_functions[ext]['save'](df, fname)
+        except:
+            pass
+
+
+def load_file(fname, target_format, fail_on_file_not_exist):
+    logger.debug("Loading file: %s", fname)
+    ext = os.path.splitext(fname)[1]
+    source_format = file_functions[ext]['format']
+    d = {}
+    if os.path.isfile(fname):
+        with ProcessSafeReaderLock(fname):
+            d = file_functions[ext]['load'](fname)
+    elif fail_on_file_not_exist:
+        raise FileNotFoundError(f"Failed to load file {fname}")
+    if "Nodes" in d:
+        dc = {k: ModuleConfig(**fix_fields(d["Nodes"][k])) for k in d["Nodes"]}
+        dc = {k: module_convert(dc[k], format_functions[(source_format, target_format)]) for k in dc}
+    else:
+        dc = {}
+    return dc
 
 
 # convert module config data to other format
@@ -147,29 +195,26 @@ def fix_fields(d):
     return d
 
 
-def load_file(fname, target_format, fail_on_file_not_exist):
-    logger.debug("Loading file: %s", fname)
-    ext = os.path.splitext(fname)[1]
-    source_format = file_functions[ext][0]
-    d = {}
-    if os.path.isfile(fname):
-        d = file_functions[ext][2](fname)
-    elif fail_on_file_not_exist:
-        raise FileNotFoundError(f"Failed to load file {fname}")
-    if "Nodes" in d:
-        dc = {k: ModuleConfig(**fix_fields(d["Nodes"][k])) for k in d["Nodes"]}
-        dc = {k: module_convert(dc[k], format_functions[(source_format, target_format)]) for k in dc}
-    else:
-        dc = {}
-    return dc
-
-
 def save_scales(model, d, source_format, fname):
+    """Saves scales measured of a given model.
+
+    Args:
+        model : The measured model.
+        d : Modules_names to configuration dictionary.
+        source_format : How the data is stored in memory.
+        fname : File to save the scales to.
+    """
     dc = {k: d[k].__dict__ for k in d}
     save_file(model, dc, source_format, fname, "Scale")
 
 
 def load_scales(fname, target_format):
+    """Loads scales from given file.
+
+    Args:
+        fname : File to load the scales from.
+        target_format: How the data is stored in file.
+    """
     logger.debug("Loading scales file %s", fname)
     d = load_file(fname, target_format, False)
     return d
@@ -184,8 +229,8 @@ def convert_scales_to_tensors_dict(scales_obj, scales_file_format, hp_dtype):
 
 
 file_functions = {
-    ".json": (list, save_json, load_json),
-    ".npz": (np.ndarray, save_npz, load_npz),
+    ".json": {'format': list, 'save': save_json, 'load': load_json},
+    ".npz": {'format': np.ndarray, 'save': save_npz, 'load': load_npz}
 }
 
 format_functions = {
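
As a hedged usage sketch (not part of the commit), the new reader lock and the dict-based file_functions registry compose roughly as follows; the scales path is hypothetical, and the import assumes common.py is reachable at its package path:

    import os
    from neural_compressor.torch.algorithms.fp8_quant._core.common import (
        ProcessSafeReaderLock,
        file_functions,
    )

    fname = "scales.npz"  # hypothetical scales file
    ext = os.path.splitext(fname)[1]
    if os.path.isfile(fname):
        with ProcessSafeReaderLock(fname):            # shared flock on fname + ".lock"
            raw = file_functions[ext]["load"](fname)  # dispatch via the new dict registry

In practice callers go through load_scales/save_scales, which now take the locks internally via load_file/save_file.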

0 commit comments
