diff --git a/.travis.yml b/.travis.yml
index dd8f7cf02b4..3ecf50bbd28 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -50,6 +50,7 @@ install:
   - source activate test-environment
   - conda install pytest coverage tornado toolz dill futures dask ipywidgets psutil bokeh requests
   - pip install git+https://github.com/dask/dask.git --upgrade
+  - pip install moto
   - conda install -c pandas pandas=v0.18.0rc1
 
 # Install distributed
diff --git a/distributed/s3.py b/distributed/s3.py
index 7900ac3c871..190f933f5ca 100644
--- a/distributed/s3.py
+++ b/distributed/s3.py
@@ -3,6 +3,8 @@
 import logging
 import io
 
+from tornado import gen
+
 from dask.imperative import Value, do
 from dask.base import tokenize
diff --git a/distributed/s3fs.py b/distributed/s3fs.py
index f707bffe78b..db858e313c8 100644
--- a/distributed/s3fs.py
+++ b/distributed/s3fs.py
@@ -66,11 +66,11 @@ def __init__(self, anon=True, key=None, secret=None, **kwargs):
         self.key = key
         self.secret = secret
         self.kwargs = kwargs
-        self.connect(anon, key, secret, kwargs)
         self.dirs = {}
-        self.s3 = self.connect(anon, key, secret, kwargs)
+        self.s3 = self.connect()
 
-    def connect(self, anon, key, secret, kwargs):
+    def connect(self):
+        anon, key, secret, kwargs = self.anon, self.key, self.secret, self.kwargs
         tok = tokenize(anon, key, secret, kwargs)
         if tok not in self._conn:
             logger.debug("Open S3 connection. Anonymous: %s",
@@ -97,7 +97,7 @@ def __getstate__(self):
 
     def __setstate__(self, state):
         self.__dict__.update(state)
-        self.s3 = self.connect(self.anon, self.key, self.secret, self.kwargs)
+        self.s3 = self.connect()
 
     def open(self, path, mode='rb', block_size=4*1024**2):
         """ Open a file for reading or writing
@@ -143,7 +143,10 @@ def _ls(self, path, refresh=False):
                 f['Size'] = 0
                 del f['Name']
         else:
-            files = self.s3.list_objects(Bucket=bucket).get('Contents', [])
+            try:
+                files = self.s3.list_objects(Bucket=bucket).get('Contents', [])
+            except ClientError:
+                files = []
             for f in files:
                 f['Key'] = "/".join([bucket, f['Key']])
             self.dirs[bucket] = list(sorted(files, key=lambda x: x['Key']))
@@ -162,7 +165,7 @@ def ls(self, path, detail=False):
         if not files:
             try:
                 files = [self.info(path)]
-            except (OSError, IOError):
+            except (OSError, IOError, ClientError):
                 files = []
         if detail:
             return files
@@ -245,6 +248,44 @@ def head(self, path, size=1024):
         with self.open(path, 'rb', block_size=size) as f:
             return f.read(size)
 
+    def mkdir(self, path):
+        self.touch(path)
+
+    def mv(self, path1, path2):
+        self.copy(path1, path2)
+        self.rm(path1)
+
+    def copy(self, path1, path2):
+        buc1, key1 = split_path(path1)
+        buc2, key2 = split_path(path2)
+        try:
+            self.s3.copy_object(Bucket=buc2, Key=key2,
+                                CopySource='/'.join([buc1, key1]))
+        except ClientError:
+            raise IOError('Copy failed on %s->%s' % (path1, path2))
+        self._ls(path2, refresh=True)
+
+    def rm(self, path, recursive=True):
+        if recursive:
+            for f in self.walk(path):
+                self.rm(f, recursive=False)
+        bucket, key = split_path(path)
+        if key:
+            try:
+                out = self.s3.delete_object(Bucket=bucket, Key=key)
+            except ClientError:
+                raise IOError('Delete key failed: (%s, %s)' % (bucket, key))
+        else:
+            try:
+                out = self.s3.delete_bucket(Bucket=bucket)
+            except ClientError:
+                raise IOError('Delete bucket failed: %s' % bucket)
+        if out['ResponseMetadata']['HTTPStatusCode'] != 204:
+            raise IOError('rm failed on %s' % path)
+        self._ls(path, refresh=True)
+
+    def touch(self, path):
+        self.open(path, mode='wb')
+
     def read_block(self, fn, offset, length, delimiter=None):
         """ Read a block of bytes from an S3 file
@@ -314,12 +355,11 @@ def __init__(self, s3, path, mode='rb', block_size=4*2**20):
             read-ahead size for finding delimiters
         """
         self.mode = mode
-        if mode != 'rb':
-            raise NotImplementedError("File mode must be 'rb', not %s" % mode)
+        if mode not in {'rb', 'wb'}:
+            raise NotImplementedError("File mode must be 'rb' or 'wb', not %s" % mode)
         self.path = path
         bucket, key = split_path(path)
         self.s3 = s3
-        self.size = self.info()['Size']
         self.bucket = bucket
         self.key = key
         self.blocksize = block_size
@@ -328,6 +368,14 @@ def __init__(self, s3, path, mode='rb', block_size=4*2**20):
         self.start = None
         self.end = None
         self.closed = False
+        if mode == 'wb':
+            self.buffer = io.BytesIO()
+            self.size = 0
+        else:
+            try:
+                self.size = self.info()['Size']
+            except ClientError:
+                raise IOError("File not accessible: %s" % path)
 
     def info(self):
         return self.s3.info(self.path)
@@ -336,18 +384,25 @@ def tell(self):
         return self.loc
 
     def seek(self, loc, whence=0):
+        if not self.mode == 'rb':
+            raise ValueError('Seek only available in read mode')
         if whence == 0:
-            self.loc = loc
+            nloc = loc
         elif whence == 1:
-            self.loc += loc
+            nloc = self.loc + loc
        elif whence == 2:
-            self.loc = self.size + loc
+            nloc = self.size + loc
         else:
             raise ValueError("invalid whence (%s, should be 0, 1 or 2)" % whence)
-        if self.loc < 0:
-            self.loc = 0
+        if nloc < 0:
+            raise ValueError('Seek before start of file')
+        self.loc = nloc
         return self.loc
 
+    def mv(self, path1, path2):
+        self.copy(path1, path2)
+        self.rm(path1)
+
     def _fetch(self, start, end):
         if self.start is None and self.end is None:
             # First read
@@ -384,14 +439,49 @@ def read(self, length=-1):
         self.loc += len(out)
         return out
 
+    def write(self, data):
+        """
+        Write data to buffer.
+
+        Buffer only sent to S3 on flush().
+        """
+        if self.mode != 'wb':
+            raise ValueError('File not in write mode')
+        if self.closed:
+            raise ValueError('I/O operation on closed file.')
+        return self.buffer.write(data)
+
     def flush(self):
-        pass
+        """
+        Write buffered data to S3.
+        """
+        if self.mode == 'wb':
+            try:
+                self.s3.s3.head_bucket(Bucket=self.bucket)
+            except ClientError:
+                try:
+                    self.s3.s3.create_bucket(Bucket=self.bucket)
+                except ClientError:
+                    raise IOError('Create bucket failed: %s' % self.bucket)
+            pos = self.buffer.tell()
+            self.buffer.seek(0)
+            try:
+                out = self.s3.s3.put_object(Bucket=self.bucket, Key=self.key,
+                                            Body=self.buffer.read())
+            finally:
+                self.buffer.seek(pos)
+            self.s3._ls(self.bucket, refresh=True)
+            if out['ResponseMetadata']['HTTPStatusCode'] != 200:
+                raise IOError("Write failed: %s" % out)
 
     def close(self):
         self.flush()
         self.cache = None
         self.closed = True
 
+    def __del__(self):
+        self.close()
+
     def __str__(self):
         return "<S3File %s/%s>" % (self.bucket, self.key)
@@ -404,7 +494,6 @@ def __exit__(self, *args):
         self.close()
 
-
 def _fetch_range(client, bucket, key, start, end, max_attempts=10):
     logger.debug("Fetch: %s/%s, %s-%s", bucket, key, start, end)
     for i in range(max_attempts):
diff --git a/distributed/tests/test_s3fs.py b/distributed/tests/test_s3fs.py
index 3f71e0ff6c2..958fb633bff 100644
--- a/distributed/tests/test_s3fs.py
+++ b/distributed/tests/test_s3fs.py
@@ -3,8 +3,8 @@
 from distributed.s3fs import S3FileSystem
 from distributed.s3 import seek_delimiter
 from distributed.utils_test import slow
+import moto
 
-# These get mirrored on s3://distributed-test/
 test_bucket_name = 'distributed-test'
 files = {'test/accounts.1.json': (b'{"amount": 100, "name": "Alice"}\n'
                                   b'{"amount": 200, "name": "Bob"}\n'
@@ -24,11 +24,64 @@
                                   b'Dennis,400,4\n'
                                   b'Edith,500,5\n'
                                   b'Frank,600,6\n')}
+text_files = {'nested/file1': b'hello\n',
+              'nested/file2': b'world',
+              'nested/nested2/file1': b'hello\n',
+              'nested/nested2/file2': b'world'}
+a = 'tmp/test/a'
+b = 'tmp/test/b'
+c = 'tmp/test/c'
+d = 'tmp/test/d'
+
 
 @pytest.yield_fixture
 def s3():
-    # could do with a bucket with write privileges.
+    # writable local S3 system
+    m = moto.mock_s3()
+    m.start()
+    import boto3
+    client = boto3.client('s3')
+    client.create_bucket(Bucket=test_bucket_name)
+    for flist in [files, csv_files, text_files]:
+        for f, data in flist.items():
+            client.put_object(Bucket=test_bucket_name, Key=f, Body=data)
     yield S3FileSystem(anon=True)
+    m.stop()
+
+
+def test_simple(s3):
+    data = b'a' * (10 * 2**20)
+
+    with s3.open(a, 'wb') as f:
+        f.write(data)
+
+    with s3.open(a, 'rb') as f:
+        out = f.read(len(data))
+        assert len(data) == len(out)
+        assert out == data
+
+
+def test_idempotent_connect(s3):
+    s3.connect()
+    s3.connect()
+
+
+def test_ls_touch(s3):
+    assert not s3.ls('tmp/test')
+    s3.touch(a)
+    s3.touch(b)
+    L = s3.ls('tmp/test', True)
+    assert set(d['Key'] for d in L) == set([a, b])
+    L = s3.ls('tmp/test', False)
+    assert set(L) == set([a, b])
+
+
+def test_rm(s3):
+    assert not s3.exists(a)
+    s3.touch(a)
+    assert s3.exists(a)
+    s3.rm(a)
+    assert not s3.exists(a)
 
 
 def test_s3_file_access(s3):
@@ -103,6 +156,68 @@ def test_read_keys_from_bucket(s3):
                 s3.cat('s3://' + '/'.join([test_bucket_name, k])))
 
 
+def test_seek(s3):
+    with s3.open(a, 'wb') as f:
+        f.write(b'123')
+
+    with s3.open(a) as f:
+        f.seek(1000)
+        with pytest.raises(ValueError):
+            f.seek(-1)
+        with pytest.raises(ValueError):
+            f.seek(-5, 2)
+        with pytest.raises(ValueError):
+            f.seek(0, 10)
+        f.seek(0)
+        assert f.read(1) == b'1'
+        f.seek(0)
+        assert f.read(1) == b'1'
+        f.seek(3)
+        assert f.read(1) == b''
+        f.seek(-1, 2)
+        assert f.read(1) == b'3'
+        f.seek(-1, 1)
+        f.seek(-1, 1)
+        assert f.read(1) == b'2'
+        for i in range(4):
+            assert f.seek(i) == i
+
+
+def test_bad_open(s3):
+    with pytest.raises(IOError):
+        s3.open('')
+
+
+def test_errors(s3):
+    with pytest.raises((IOError, OSError)):
+        s3.open('tmp/test/shfoshf', 'rb')
+
+    ## This is fine, no need for intervening directories on S3
+    #with pytest.raises((IOError, OSError)):
+    #    s3.touch('tmp/test/shfoshf/x')
+
+    with pytest.raises((IOError, OSError)):
+        s3.rm('tmp/test/shfoshf/x')
+
+    with pytest.raises((IOError, OSError)):
+        s3.mv('tmp/test/shfoshf/x', 'tmp/test/shfoshf/y')
+
+    #with pytest.raises((IOError, OSError)):
+    #    s3.open('x', 'wb')
+
+    with pytest.raises((IOError, OSError)):
+        s3.open('x', 'rb')
+
+    #with pytest.raises(IOError):
+    #    s3.chown('/unknown', 'someone', 'group')
+
+    #with pytest.raises(IOError):
+    #    s3.chmod('/unknown', 'rb')
+
+    with pytest.raises(IOError):
+        s3.rm('unknown')
+
+
 @slow
 def test_seek_delimiter(s3):
     fn = 'test/accounts.1.json'
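For context, here is a minimal sketch of the write-then-read round trip this patch enables, driven against the same moto-mocked S3 that the new `s3` fixture stands up. The bucket and key names (`demo-bucket`, `out.bin`) are illustrative, not from the patch:

```python
import boto3
import moto

from distributed.s3fs import S3FileSystem

# Stand up an in-process mock of S3, as the new test fixture does.
m = moto.mock_s3()
m.start()
boto3.client('s3').create_bucket(Bucket='demo-bucket')  # hypothetical bucket

s3 = S3FileSystem(anon=True)

# 'wb' mode buffers writes in a BytesIO; the single put_object happens on
# flush(), which close() (and hence the context manager exit) triggers.
with s3.open('demo-bucket/out.bin', 'wb') as f:
    f.write(b'hello ')
    f.write(b'world')

assert s3.cat('demo-bucket/out.bin') == b'hello world'
m.stop()
```

Note that `flush()` will also create the bucket if it does not yet exist, so the explicit `create_bucket` call above is belt-and-braces rather than required.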