Skip to content

Commit 9c6b0c3

Browse files
committed
Add type hints
1 parent e1c7c12 commit 9c6b0c3

File tree

3 files changed

+56
-34
lines changed

3 files changed

+56
-34
lines changed

setup.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
'Operating System :: MacOS :: MacOS X',
6464
# 'Operating System :: Microsoft :: Windows', -- Not tested yet
6565
'Operating System :: POSIX',
66+
'Typing :: Typed',
6667
'Programming Language :: Python :: 3',
6768
'Programming Language :: Python :: 3.8',
6869
'Programming Language :: Python :: 3.9',
@@ -75,4 +76,5 @@
7576
install_requires=install_requires,
7677
setup_requires=setup_requires,
7778
package_dir={'': 'src'},
79+
package_data={'snappy': ['py.typed']}
7880
)

src/snappy/py.typed

Whitespace-only changes.

src/snappy/snappy.py

Lines changed: 54 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,11 @@
3939
assert "some data" == snappy.uncompress(compressed)
4040
4141
"""
42-
from __future__ import absolute_import
42+
from __future__ import absolute_import, annotations
43+
44+
from typing import (
45+
Optional, Union, IO, BinaryIO, Protocol, Type, Any, overload,
46+
)
4347

4448
import cramjam
4549

@@ -57,7 +61,7 @@ class UncompressError(Exception):
5761
pass
5862

5963

60-
def isValidCompressed(data):
64+
def isValidCompressed(data: Union[str, bytes]) -> bool:
6165
if isinstance(data, str):
6266
data = data.encode('utf-8')
6367

@@ -69,12 +73,18 @@ def isValidCompressed(data):
6973
return ok
7074

7175

72-
def compress(data, encoding='utf-8'):
76+
def compress(data: Union[str, bytes], encoding: str = 'utf-8') -> bytes:
7377
if isinstance(data, str):
7478
data = data.encode(encoding)
7579

7680
return bytes(_compress(data))
7781

82+
@overload
83+
def uncompress(data: bytes) -> bytes: ...
84+
85+
@overload
86+
def uncompress(data: bytes, decoding: Optional[str] = None) -> Union[str, bytes]: ...
87+
7888
def uncompress(data, decoding=None):
7989
if isinstance(data, str):
8090
raise UncompressError("It's only possible to uncompress bytes")
@@ -89,6 +99,16 @@ def uncompress(data, decoding=None):
8999

90100
decompress = uncompress
91101

102+
103+
class Compressor(Protocol):
104+
def add_chunk(self, data) -> Any: ...
105+
106+
107+
class Decompressor(Protocol):
108+
def decompress(self, data) -> Any: ...
109+
def flush(self): ...
110+
111+
92112
class StreamCompressor():
93113

94114
"""This class implements the compressor-side of the proposed Snappy framing
@@ -109,7 +129,7 @@ class StreamCompressor():
109129
def __init__(self):
110130
self.c = cramjam.snappy.Compressor()
111131

112-
def add_chunk(self, data: bytes, compress=None):
132+
def add_chunk(self, data: bytes, compress=None) -> bytes:
113133
"""Add a chunk, returning a string that is framed and compressed.
114134
115135
Outputs a single snappy chunk; if it is the very start of the stream,
@@ -120,10 +140,10 @@ def add_chunk(self, data: bytes, compress=None):
120140

121141
compress = add_chunk
122142

123-
def flush(self):
143+
def flush(self) -> bytes:
124144
return bytes(self.c.flush())
125145

126-
def copy(self):
146+
def copy(self) -> 'StreamCompressor':
127147
"""This method exists for compatibility with the zlib compressobj.
128148
"""
129149
return self
@@ -157,7 +177,7 @@ def check_format(fin):
157177
except:
158178
return False
159179

160-
def decompress(self, data: bytes):
180+
def decompress(self, data: bytes) -> bytes:
161181
"""Decompress 'data', returning a string containing the uncompressed
162182
data corresponding to at least part of the data in string. This data
163183
should be concatenated to the output produced by any preceding calls to
@@ -189,15 +209,15 @@ def decompress(self, data: bytes):
189209
self.c.decompress(data)
190210
return self.flush()
191211

192-
def flush(self):
212+
def flush(self) -> bytes:
193213
return bytes(self.c.flush())
194214

195-
def copy(self):
215+
def copy(self) -> 'StreamDecompressor':
196216
return self
197217

198218

199219
class HadoopStreamCompressor():
200-
def add_chunk(self, data: bytes, compress=None):
220+
def add_chunk(self, data: bytes, compress=None) -> bytes:
201221
"""Add a chunk, returning a string that is framed and compressed.
202222
203223
Outputs a single snappy chunk; if it is the very start of the stream,
@@ -208,11 +228,11 @@ def add_chunk(self, data: bytes, compress=None):
208228

209229
compress = add_chunk
210230

211-
def flush(self):
231+
def flush(self) -> bytes:
212232
# never maintains a buffer
213233
return b""
214234

215-
def copy(self):
235+
def copy(self) -> 'HadoopStreamCompressor':
216236
"""This method exists for compatibility with the zlib compressobj.
217237
"""
218238
return self
@@ -239,7 +259,7 @@ def check_format(fin):
239259
except:
240260
return False
241261

242-
def decompress(self, data: bytes):
262+
def decompress(self, data: bytes) -> bytes:
243263
"""Decompress 'data', returning a string containing the uncompressed
244264
data corresponding to at least part of the data in string. This data
245265
should be concatenated to the output produced by any preceding calls to
@@ -262,18 +282,18 @@ def decompress(self, data: bytes):
262282
data = data[8 + chunk_length:]
263283
return b"".join(out)
264284

265-
def flush(self):
285+
def flush(self) -> bytes:
266286
return b""
267287

268-
def copy(self):
288+
def copy(self) -> 'HadoopStreamDecompressor':
269289
return self
270290

271291

272292

273-
def stream_compress(src,
274-
dst,
275-
blocksize=_STREAM_TO_STREAM_BLOCK_SIZE,
276-
compressor_cls=StreamCompressor):
293+
def stream_compress(src: IO,
294+
dst: IO,
295+
blocksize: int = _STREAM_TO_STREAM_BLOCK_SIZE,
296+
compressor_cls: Type[Compressor] = StreamCompressor) -> None:
277297
"""Takes an incoming file-like object and an outgoing file-like object,
278298
reads data from src, compresses it, and writes it to dst. 'src' should
279299
support the read method, and 'dst' should support the write method.
@@ -288,11 +308,11 @@ def stream_compress(src,
288308
if buf: dst.write(buf)
289309

290310

291-
def stream_decompress(src,
292-
dst,
293-
blocksize=_STREAM_TO_STREAM_BLOCK_SIZE,
294-
decompressor_cls=StreamDecompressor,
295-
start_chunk=None):
311+
def stream_decompress(src: IO,
312+
dst: IO,
313+
blocksize: int = _STREAM_TO_STREAM_BLOCK_SIZE,
314+
decompressor_cls: Type[Decompressor] = StreamDecompressor,
315+
start_chunk=None) -> None:
296316
"""Takes an incoming file-like object and an outgoing file-like object,
297317
reads data from src, decompresses it, and writes it to dst. 'src' should
298318
support the read method, and 'dst' should support the write method.
@@ -317,10 +337,10 @@ def stream_decompress(src,
317337

318338

319339
def hadoop_stream_decompress(
320-
src,
321-
dst,
322-
blocksize=_STREAM_TO_STREAM_BLOCK_SIZE,
323-
):
340+
src: BinaryIO,
341+
dst: BinaryIO,
342+
blocksize: int = _STREAM_TO_STREAM_BLOCK_SIZE,
343+
) -> None:
324344
c = HadoopStreamDecompressor()
325345
while True:
326346
data = src.read(blocksize)
@@ -333,10 +353,10 @@ def hadoop_stream_decompress(
333353

334354

335355
def hadoop_stream_compress(
336-
src,
337-
dst,
338-
blocksize=_STREAM_TO_STREAM_BLOCK_SIZE,
339-
):
356+
src: BinaryIO,
357+
dst: BinaryIO,
358+
blocksize: int = _STREAM_TO_STREAM_BLOCK_SIZE,
359+
) -> None:
340360
c = HadoopStreamCompressor()
341361
while True:
342362
data = src.read(blocksize)
@@ -348,11 +368,11 @@ def hadoop_stream_compress(
348368
dst.flush()
349369

350370

351-
def raw_stream_decompress(src, dst):
371+
def raw_stream_decompress(src: BinaryIO, dst: BinaryIO) -> None:
352372
data = src.read()
353373
dst.write(decompress(data))
354374

355375

356-
def raw_stream_compress(src, dst):
376+
def raw_stream_compress(src: BinaryIO, dst: BinaryIO) -> None:
357377
data = src.read()
358378
dst.write(compress(data))

0 commit comments

Comments
 (0)