Skip to content

Commit 5ecfaae

Browse files
authored
fix multiprocess remove_columns (#2088)
1 parent 9021947 commit 5ecfaae

File tree

4 files changed, with 51 additions and 17 deletions.

swift/llm/utils/dataset.py

Lines changed: 15 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -46,15 +46,23 @@ def _update_fingerprint_mac(*args, **kwargs):
4646
datasets.arrow_dataset.update_fingerprint = _update_fingerprint_mac
4747

4848

49-
def partialed_map(self, *args, **kwargs):
50-
if 'num_proc' not in kwargs:
51-
num_proc = os.environ.get('DATASET_MAP_NPROC')
52-
kwargs['num_proc'] = int(num_proc) if num_proc else num_proc
53-
return self._origin_map(*args, **kwargs)
49+
def patch_num_proc(func_name: str):
50+
_origin_func_name = f'_origin_{func_name}'
51+
_old_func = getattr(HfDataset, func_name)
5452

53+
def new_func(self, *args, **kwargs):
54+
if 'num_proc' not in kwargs:
55+
num_proc = os.environ.get('DATASET_MAP_NPROC')
56+
if num_proc:
57+
kwargs['num_proc'] = int(num_proc)
58+
return _old_func(self, *args, **kwargs)
5559

56-
datasets.Dataset._origin_map = datasets.Dataset.map
57-
datasets.Dataset.map = partialed_map
60+
setattr(HfDataset, _origin_func_name, _old_func)
61+
setattr(HfDataset, func_name, new_func)
62+
63+
64+
for func_name in ['map', 'filter']:
65+
patch_num_proc(func_name)
5866

5967
standard_keys = {
6068
'query', 'query_role', 'response', 'rejected_response', 'system', 'history', 'history_roles', 'images', 'objects',

swift/llm/utils/preprocess.py

Lines changed: 34 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,10 @@
11
# Copyright (c) Alibaba, Inc. and its affiliates.
22
import ast
33
import os
4+
from multiprocessing import shared_memory
45
from typing import Any, Callable, Dict, List, Literal, Optional, Union
56

7+
import numpy as np
68
from datasets import Dataset as HfDataset
79
from datasets import IterableDataset as HfIterableDataset
810
from tqdm import tqdm
@@ -30,29 +32,50 @@ def _reduce_columns(cls: type) -> type:
3032
cls._patching = True
3133

3234
def new_call_func(self, dataset: DATASET_TYPE) -> DATASET_TYPE:
33-
self.column_state = set(['images', 'videos', 'audios'])
35+
self.key_mapping = {k: i for i, k in enumerate(self.empty_row.keys())}
36+
num_proc = int(os.environ.get('DATASET_MAP_NPROC', '1'))
37+
self.shared_shm_name = None
38+
shm, buffer = None, None
39+
if num_proc > 1: # multiprocess
40+
shm = shared_memory.SharedMemory(create=True, size=len(self.key_mapping))
41+
self.shared_shm_name = shm.name
42+
buffer = shm.buf
43+
self.column_state = np.ndarray((len(self.key_mapping), ), dtype=np.bool_, buffer=buffer)
3444
dataset = call_func(self, dataset)
3545
if isinstance(dataset, HfIterableDataset) and dataset.features is None:
3646
features = next(iter(dataset)).keys()
3747
else:
3848
features = dataset.features.keys()
3949
for k in features:
40-
if k not in self.column_state:
50+
if k in ['images', 'videos', 'audios']:
51+
continue
52+
k_i = self.key_mapping.get(k, -1)
53+
if k_i == -1 or not self.column_state[k_i]:
4154
dataset = dataset.remove_columns([k])
55+
if shm:
56+
shm.close()
57+
shm.unlink()
4258
return dataset
4359

4460
def new_preprocess(self, row: Dict[str, Any]) -> Dict[str, Any]:
61+
if self.shared_shm_name is not None: # multiprocess
62+
shm = shared_memory.SharedMemory(name=self.shared_shm_name)
63+
column_state = np.ndarray((len(self.key_mapping), ), dtype=np.bool_, buffer=shm.buf)
64+
else:
65+
column_state = self.column_state
4566
row = preprocess(self, row)
4667
for k, v in row.items():
68+
k_i = self.key_mapping[k]
69+
if column_state[k_i]:
70+
continue
4771
if k == 'query_role':
48-
if k not in self.column_state and v and v != 'user':
49-
self.column_state.add(k)
72+
if v and v != 'user':
73+
column_state[k_i] = True
5074
elif k == 'history_roles':
51-
if k not in self.column_state and v and any(_v[0] != 'user' or _v[1] != 'assistant' for _v in v):
52-
self.column_state.add(k)
53-
else:
54-
if v:
55-
self.column_state.add(k)
75+
if v and any(_v[0] != 'user' or _v[1] != 'assistant' for _v in v):
76+
column_state[k_i] = True
77+
elif v:
78+
column_state[k_i] = True
5679
return row
5780

5881
cls.__call__ = new_call_func
@@ -142,6 +165,7 @@ def __call__(self, dataset: DATASET_TYPE) -> DATASET_TYPE:
142165
return dataset
143166

144167

168+
@_reduce_columns
145169
class AlpacaPreprocessor(MediaMixin, RowPreprocessMixin):
146170

147171
def __init__(self, concat_inst_inp: Optional[Callable[[str, str], str]] = None, **kwargs):
@@ -194,6 +218,7 @@ def _default_repair_conversations(s: Union[str, Any]) -> Any:
194218
return s
195219

196220

221+
@_reduce_columns
197222
class ConversationsPreprocessor(MediaMixin, RowPreprocessMixin):
198223

199224
def __init__(self,

swift/llm/utils/template.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -638,6 +638,7 @@ def _concat_context_list(
638638
new_str_list = [system, query, round0, round1]
639639
for (old_str, new_str) in zip(old_str_list, new_str_list):
640640
if new_str is not None and old_str in context:
641+
assert isinstance(new_str, str), f'new_str: {new_str}'
641642
context = context.replace(old_str, new_str)
642643
if len(context) == 0:
643644
continue

swift/llm/utils/utils.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -431,7 +431,7 @@ def _find_module_list(vision_tower) -> Optional[nn.ModuleList]:
431431
return
432432
if isinstance(m, nn.ModuleList) and len(m) >= 10:
433433
module_lists.append(m)
434-
if module_lists is not None:
434+
if module_lists:
435435
return max(module_lists, key=lambda x: len(x))
436436

437437

0 commit comments

Comments
 (0)