@@ -76,14 +76,14 @@ def __init__(
76
76
if buffer_size is not None and buffer_size <= 0 :
77
77
raise ValueError ("'buffer_size' is required to be either None or a positive integer." )
78
78
self .buffer_size : int = buffer_size
79
+ self .buffer : OrderedDict = OrderedDict ()
79
80
80
81
def __iter__ (self ) -> Iterator :
81
- buffer : OrderedDict = OrderedDict ()
82
82
ref_it = iter (self .ref_datapipe )
83
83
warn_once_flag = True
84
84
for data in self .source_datapipe :
85
85
key = self .key_fn (data )
86
- while key not in buffer :
86
+ while key not in self . buffer :
87
87
try :
88
88
ref_data = next (ref_it )
89
89
except StopIteration :
@@ -92,18 +92,18 @@ def __iter__(self) -> Iterator:
92
92
"Please consider increasing the buffer size."
93
93
)
94
94
ref_key = self .ref_key_fn (ref_data )
95
- if ref_key in buffer :
95
+ if ref_key in self . buffer :
96
96
raise ValueError ("Duplicate key is found in reference DataPipe" )
97
- if self .buffer_size is not None and len (buffer ) > self .buffer_size :
97
+ if self .buffer_size is not None and len (self . buffer ) > self .buffer_size :
98
98
if warn_once_flag :
99
99
warn_once_flag = False
100
100
warnings .warn (
101
101
"Buffer reaches the upper limit, so reference key-data pair begins to "
102
102
"be removed from buffer in FIFO order. Please consider increase buffer size."
103
103
)
104
- buffer .popitem (last = False )
105
- buffer [ref_key ] = ref_data
106
- res = self .merge_fn (data , buffer .pop (key )) if self .merge_fn else (data , buffer .pop (key ))
104
+ self . buffer .popitem (last = False )
105
+ self . buffer [ref_key ] = ref_data
106
+ res = self .merge_fn (data , self . buffer .pop (key )) if self .merge_fn else (data , self . buffer .pop (key ))
107
107
if self .keep_key :
108
108
yield key , res
109
109
else :
@@ -112,6 +112,38 @@ def __iter__(self) -> Iterator:
112
112
def __len__ (self ) -> int :
113
113
return len (self .source_datapipe )
114
114
115
+ def reset (self ) -> None :
116
+ self .buffer = OrderedDict ()
117
+
118
+ def __getstate__ (self ):
119
+ if IterDataPipe .getstate_hook is not None :
120
+ return IterDataPipe .getstate_hook (self )
121
+ state = (
122
+ self .source_datapipe ,
123
+ self .ref_datapipe ,
124
+ self .key_fn ,
125
+ self .ref_key_fn ,
126
+ self .keep_key ,
127
+ self .merge_fn ,
128
+ self .buffer_size ,
129
+ )
130
+ return state
131
+
132
+ def __setstate__ (self , state ):
133
+ (
134
+ self .source_datapipe ,
135
+ self .ref_datapipe ,
136
+ self .key_fn ,
137
+ self .ref_key_fn ,
138
+ self .keep_key ,
139
+ self .merge_fn ,
140
+ self .buffer_size ,
141
+ ) = state
142
+ self .buffer = OrderedDict ()
143
+
144
+ def __del__ (self ):
145
+ self .buffer .clear ()
146
+
115
147
116
148
@functional_datapipe ("zip_with_map" )
117
149
class MapKeyZipperIterDataPipe (IterDataPipe [T_co ]):
0 commit comments