45
45
_write_data ,
46
46
_ftype4scaled_finite ,
47
47
)
48
- from ..openers import Opener , BZ2File
48
+ from ..openers import Opener , BZ2File , HAVE_ZSTD
49
49
from ..casting import (floor_log2 , type_info , OK_FLOATS , shared_range )
50
50
51
51
from ..deprecator import ExpiredDeprecationError
56
56
57
57
from nibabel .testing import nullcontext , assert_dt_equal , assert_allclose_safely , suppress_warnings
58
58
59
+ # only import ZstdFile, if installed
60
+ if HAVE_ZSTD :
61
+ from ..openers import ZstdFile
62
+
59
63
#: convenience variables for numpy types
60
64
FLOAT_TYPES = np .sctypes ['float' ]
61
65
COMPLEX_TYPES = np .sctypes ['complex' ]
68
72
def test__is_compressed_fobj ():
69
73
# _is_compressed helper function
70
74
with InTemporaryDirectory ():
71
- for ext , opener , compressed in (('' , open , False ),
72
- ('.gz' , gzip .open , True ),
73
- ('.bz2' , BZ2File , True )):
75
+ file_openers = [('' , open , False ),
76
+ ('.gz' , gzip .open , True ),
77
+ ('.bz2' , BZ2File , True )]
78
+ if HAVE_ZSTD :
79
+ file_openers += [('.zst' , ZstdFile , True )]
80
+ for ext , opener , compressed in file_openers :
74
81
fname = 'test.bin' + ext
75
82
for mode in ('wb' , 'rb' ):
76
83
fobj = opener (fname , mode )
@@ -88,12 +95,15 @@ def make_array(n, bytes):
88
95
arr .flags .writeable = True
89
96
return arr
90
97
91
- # Check whether file, gzip file, bz2 file reread memory from cache
98
+ # Check whether file, gzip file, bz2, zst file reread memory from cache
92
99
fname = 'test.bin'
93
100
with InTemporaryDirectory ():
101
+ openers = [open , gzip .open , BZ2File ]
102
+ if HAVE_ZSTD :
103
+ openers += [ZstdFile ]
94
104
for n , opener in itertools .product (
95
105
(256 , 1024 , 2560 , 25600 ),
96
- ( open , gzip . open , BZ2File ) ):
106
+ openers ):
97
107
in_arr = np .arange (n , dtype = dtype )
98
108
# Write array to file
99
109
fobj_w = opener (fname , 'wb' )
@@ -230,7 +240,10 @@ def test_array_from_file_openers():
230
240
dtype = np .dtype (np .float32 )
231
241
in_arr = np .arange (24 , dtype = dtype ).reshape (shape )
232
242
with InTemporaryDirectory ():
233
- for ext , offset in itertools .product (('' , '.gz' , '.bz2' ),
243
+ extensions = ['' , '.gz' , '.bz2' ]
244
+ if HAVE_ZSTD :
245
+ extensions += ['.zst' ]
246
+ for ext , offset in itertools .product (extensions ,
234
247
(0 , 5 , 10 )):
235
248
fname = 'test.bin' + ext
236
249
with Opener (fname , 'wb' ) as out_buf :
@@ -251,9 +264,12 @@ def test_array_from_file_reread():
251
264
offset = 9
252
265
fname = 'test.bin'
253
266
with InTemporaryDirectory ():
267
+ openers = [open , gzip .open , bz2 .BZ2File , BytesIO ]
268
+ if HAVE_ZSTD :
269
+ openers += [ZstdFile ]
254
270
for shape , opener , dtt , order in itertools .product (
255
271
((64 ,), (64 , 65 ), (64 , 65 , 66 )),
256
- ( open , gzip . open , bz2 . BZ2File , BytesIO ) ,
272
+ openers ,
257
273
(np .int16 , np .float32 ),
258
274
('F' , 'C' )):
259
275
n_els = np .prod (shape )
@@ -901,7 +917,9 @@ def test_write_zeros():
901
917
def test_seek_tell ():
902
918
# Test seek tell routine
903
919
bio = BytesIO ()
904
- in_files = bio , 'test.bin' , 'test.gz' , 'test.bz2'
920
+ in_files = [bio , 'test.bin' , 'test.gz' , 'test.bz2' ]
921
+ if HAVE_ZSTD :
922
+ in_files += ['test.zst' ]
905
923
start = 10
906
924
end = 100
907
925
diff = end - start
@@ -920,9 +938,12 @@ def test_seek_tell():
920
938
fobj .write (b'\x01 ' * start )
921
939
assert fobj .tell () == start
922
940
# Files other than BZ2Files can seek forward on write, leaving
923
- # zeros in their wake. BZ2Files can't seek when writing, unless
924
- # we enable the write0 flag to seek_tell
925
- if not write0 and in_file == 'test.bz2' : # Can't seek write in bz2
941
+ # zeros in their wake. BZ2Files can't seek when writing,
942
+ # unless we enable the write0 flag to seek_tell
943
+ # ZstdFiles also does not support seek forward on write
944
+ if (not write0 and
945
+ (in_file == 'test.bz2' or
946
+ in_file == 'test.zst' )): # Can't seek write in bz2, zst
926
947
# write the zeros by hand for the read test below
927
948
fobj .write (b'\x00 ' * diff )
928
949
else :
@@ -946,7 +967,10 @@ def test_seek_tell():
946
967
# Check we have the expected written output
947
968
with ImageOpener (in_file , 'rb' ) as fobj :
948
969
assert fobj .read () == b'\x01 ' * start + b'\x00 ' * diff + b'\x02 ' * tail
949
- for in_file in ('test2.gz' , 'test2.bz2' ):
970
+ input_files = ['test2.gz' , 'test2.bz2' ]
971
+ if HAVE_ZSTD :
972
+ input_files += ['test2.zst' ]
973
+ for in_file in input_files :
950
974
# Check failure of write seek backwards
951
975
with ImageOpener (in_file , 'wb' ) as fobj :
952
976
fobj .write (b'g' * 10 )
0 commit comments