|
1 | 1 | # Copyright (c) Alibaba, Inc. and its affiliates.
|
2 | 2 | import ast
|
3 | 3 | import os
|
| 4 | +from multiprocessing import shared_memory |
4 | 5 | from typing import Any, Callable, Dict, List, Literal, Optional, Union
|
5 | 6 |
|
| 7 | +import numpy as np |
6 | 8 | from datasets import Dataset as HfDataset
|
7 | 9 | from datasets import IterableDataset as HfIterableDataset
|
8 | 10 | from tqdm import tqdm
|
@@ -30,29 +32,50 @@ def _reduce_columns(cls: type) -> type:
|
30 | 32 | cls._patching = True
|
31 | 33 |
|
32 | 34 | def new_call_func(self, dataset: DATASET_TYPE) -> DATASET_TYPE:
|
33 |
| - self.column_state = set(['images', 'videos', 'audios']) |
| 35 | + self.key_mapping = {k: i for i, k in enumerate(self.empty_row.keys())} |
| 36 | + num_proc = int(os.environ.get('DATASET_MAP_NPROC', '1')) |
| 37 | + self.shared_shm_name = None |
| 38 | + shm, buffer = None, None |
| 39 | + if num_proc > 1: # multiprocess |
| 40 | + shm = shared_memory.SharedMemory(create=True, size=len(self.key_mapping)) |
| 41 | + self.shared_shm_name = shm.name |
| 42 | + buffer = shm.buf |
| 43 | + self.column_state = np.ndarray((len(self.key_mapping), ), dtype=np.bool_, buffer=buffer) |
34 | 44 | dataset = call_func(self, dataset)
|
35 | 45 | if isinstance(dataset, HfIterableDataset) and dataset.features is None:
|
36 | 46 | features = next(iter(dataset)).keys()
|
37 | 47 | else:
|
38 | 48 | features = dataset.features.keys()
|
39 | 49 | for k in features:
|
40 |
| - if k not in self.column_state: |
| 50 | + if k in ['images', 'videos', 'audios']: |
| 51 | + continue |
| 52 | + k_i = self.key_mapping.get(k, -1) |
| 53 | + if k_i == -1 or not self.column_state[k_i]: |
41 | 54 | dataset = dataset.remove_columns([k])
|
| 55 | + if shm: |
| 56 | + shm.close() |
| 57 | + shm.unlink() |
42 | 58 | return dataset
|
43 | 59 |
|
44 | 60 | def new_preprocess(self, row: Dict[str, Any]) -> Dict[str, Any]:
|
| 61 | + if self.shared_shm_name is not None: # multiprocess |
| 62 | + shm = shared_memory.SharedMemory(name=self.shared_shm_name) |
| 63 | + column_state = np.ndarray((len(self.key_mapping), ), dtype=np.bool_, buffer=shm.buf) |
| 64 | + else: |
| 65 | + column_state = self.column_state |
45 | 66 | row = preprocess(self, row)
|
46 | 67 | for k, v in row.items():
|
| 68 | + k_i = self.key_mapping[k] |
| 69 | + if column_state[k_i]: |
| 70 | + continue |
47 | 71 | if k == 'query_role':
|
48 |
| - if k not in self.column_state and v and v != 'user': |
49 |
| - self.column_state.add(k) |
| 72 | + if v and v != 'user': |
| 73 | + column_state[k_i] = True |
50 | 74 | elif k == 'history_roles':
|
51 |
| - if k not in self.column_state and v and any(_v[0] != 'user' or _v[1] != 'assistant' for _v in v): |
52 |
| - self.column_state.add(k) |
53 |
| - else: |
54 |
| - if v: |
55 |
| - self.column_state.add(k) |
| 75 | + if v and any(_v[0] != 'user' or _v[1] != 'assistant' for _v in v): |
| 76 | + column_state[k_i] = True |
| 77 | + elif v: |
| 78 | + column_state[k_i] = True |
56 | 79 | return row
|
57 | 80 |
|
58 | 81 | cls.__call__ = new_call_func
|
@@ -142,6 +165,7 @@ def __call__(self, dataset: DATASET_TYPE) -> DATASET_TYPE:
|
142 | 165 | return dataset
|
143 | 166 |
|
144 | 167 |
|
| 168 | +@_reduce_columns |
145 | 169 | class AlpacaPreprocessor(MediaMixin, RowPreprocessMixin):
|
146 | 170 |
|
147 | 171 | def __init__(self, concat_inst_inp: Optional[Callable[[str, str], str]] = None, **kwargs):
|
@@ -194,6 +218,7 @@ def _default_repair_conversations(s: Union[str, Any]) -> Any:
|
194 | 218 | return s
|
195 | 219 |
|
196 | 220 |
|
| 221 | +@_reduce_columns |
197 | 222 | class ConversationsPreprocessor(MediaMixin, RowPreprocessMixin):
|
198 | 223 |
|
199 | 224 | def __init__(self,
|
|
0 commit comments