S3 with writing #174

Closed
wants to merge 4 commits into from
1 change: 1 addition & 0 deletions .travis.yml
@@ -50,6 +50,7 @@ install:
- source activate test-environment
- conda install pytest coverage tornado toolz dill futures dask ipywidgets psutil bokeh requests
- pip install git+https://github.com/dask/dask.git --upgrade
- pip install moto
- conda install -c pandas pandas=v0.18.0rc1

# Install distributed
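The new `pip install moto` line pulls in moto, an in-process mock of the S3 API, so the write-capable tests added below can run in CI without AWS credentials or network access. A minimal sketch of moto's decorator form (the test fixture further down uses the equivalent `start()`/`stop()` form); the bucket and key names here are made up:

```python
import boto3
from moto import mock_s3


@mock_s3
def roundtrip_example():
    # All boto3 S3 calls inside this function hit moto's in-memory backend.
    client = boto3.client('s3', region_name='us-east-1')
    client.create_bucket(Bucket='example-bucket')       # hypothetical bucket
    client.put_object(Bucket='example-bucket', Key='k', Body=b'data')
    body = client.get_object(Bucket='example-bucket', Key='k')['Body'].read()
    assert body == b'data'
```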
2 changes: 2 additions & 0 deletions distributed/s3.py
@@ -3,6 +3,8 @@
import logging
import io

from tornado import gen

from dask.imperative import Value, do
from dask.base import tokenize

121 changes: 105 additions & 16 deletions distributed/s3fs.py
@@ -66,11 +66,11 @@ def __init__(self, anon=True, key=None, secret=None, **kwargs):
self.key = key
self.secret = secret
self.kwargs = kwargs
self.connect(anon, key, secret, kwargs)
self.dirs = {}
self.s3 = self.connect(anon, key, secret, kwargs)
self.s3 = self.connect()

def connect(self, anon, key, secret, kwargs):
def connect(self):
anon, key, secret, kwargs = self.anon, self.key, self.secret, self.kwargs
tok = tokenize(anon, key, secret, kwargs)
if tok not in self._conn:
logger.debug("Open S3 connection. Anonymous: %s",
@@ -97,7 +97,7 @@ def __getstate__(self):

def __setstate__(self, state):
self.__dict__.update(state)
self.s3 = self.connect(self.anon, self.key, self.secret, self.kwargs)
self.s3 = self.connect()
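The connection refactor above is what makes the filesystem picklable: `connect()` now takes no arguments, reads the credentials off `self`, and memoizes the client in the class-level `_conn` dict keyed by `tokenize(anon, key, secret, kwargs)`, so `__setstate__` can simply call `self.connect()` after unpickling. A rough sketch of the pattern; the client construction is hedged, since the actual boto call sits in the folded part of `connect()`, and boto3 here is an assumption based on the client methods used elsewhere in the file:

```python
import boto3
from botocore import UNSIGNED
from botocore.client import Config
from dask.base import tokenize

_conn = {}  # class-level cache shared by all S3FileSystem instances


def connect(anon, key, secret, kwargs):
    tok = tokenize(anon, key, secret, kwargs)
    if tok not in _conn:
        if anon:
            # anonymous access: send unsigned requests (assumed handling)
            _conn[tok] = boto3.client('s3', config=Config(signature_version=UNSIGNED))
        else:
            _conn[tok] = boto3.client('s3', aws_access_key_id=key,
                                      aws_secret_access_key=secret, **kwargs)
    return _conn[tok]  # repeated calls with the same credentials reuse one client
```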

def open(self, path, mode='rb', block_size=4*1024**2):
""" Open a file for reading or writing
@@ -143,7 +143,10 @@ def _ls(self, path, refresh=False):
f['Size'] = 0
del f['Name']
else:
files = self.s3.list_objects(Bucket=bucket).get('Contents', [])
try:
files = self.s3.list_objects(Bucket=bucket).get('Contents', [])
except ClientError:
files = []
for f in files:
f['Key'] = "/".join([bucket, f['Key']])
self.dirs[bucket] = list(sorted(files, key=lambda x: x['Key']))
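`ClientError` is the exception botocore raises for missing buckets, denied access and similar failures; it is presumably imported from `botocore.exceptions` near the top of s3fs.py, outside this hunk. The effect of the change above is that listing a bucket that does not exist (or cannot be read) comes back as an empty listing rather than an exception, roughly:

```python
import boto3
from botocore.exceptions import ClientError

s3_client = boto3.client('s3')  # assumes credentials are configured in the environment

try:
    files = s3_client.list_objects(Bucket='no-such-bucket').get('Contents', [])
except ClientError:
    files = []  # a missing or unreadable bucket simply lists as empty
```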
@@ -162,7 +165,7 @@ def ls(self, path, detail=False):
if not files:
try:
files = [self.info(path)]
except (OSError, IOError):
except (OSError, IOError, ClientError):
files = []
if detail:
return files
@@ -245,6 +248,44 @@ def head(self, path, size=1024):
with self.open(path, 'rb', block_size=size) as f:
return f.read(size)

def mkdir(self, path):
self.touch(path)

def mv(self, path1, path2):
self.copy(path1, path2)
self.rm(path1)

def copy(self, path1, path2):
buc1, key1 = split_path(path1)
buc2, key2 = split_path(path2)
try:
self.s3.copy_object(Bucket=buc2, Key=key2, CopySource='/'.join([buc1, key1]))
except ClientError:
raise IOError('Copy failed on %s->%s', path1, path2)
self._ls(path2, refresh=True)

def rm(self, path, recursive=True):
if recursive:
for f in self.walk(path):
self.rm(f, recursive=False)
bucket, key = split_path(path)
if key:
try:
out = self.s3.delete_object(Bucket=bucket, Key=key)
except ClientError:
raise IOError('Delete key failed: (%s, %s)', bucket, key)
else:
try:
out = self.s3.delete_bucket(Bucket=bucket)
except ClientError:
raise IOError('Delete bucket failed: %s', bucket)
if out['ResponseMetadata']['HTTPStatusCode'] != 204:
raise IOError('rm failed on %s', path)
self._ls(path, refresh=True)

def touch(self, path):
self.open(path, mode='wb')
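
Taken together, the new methods give the filesystem basic management operations: `touch` creates an empty key by opening it in write mode, `copy` wraps `copy_object`, `mv` is copy-then-delete, and `rm` optionally walks a prefix before deleting the key or bucket itself. A rough usage sketch as shown below; the bucket and paths are made up, and write access (or the moto mock used in the tests) is assumed:

```python
from distributed.s3fs import S3FileSystem

s3 = S3FileSystem(anon=False)                     # needs writable credentials
s3.touch('my-bucket/tmp/marker')                  # empty key now exists
s3.copy('my-bucket/tmp/marker', 'my-bucket/tmp/marker2')
s3.mv('my-bucket/tmp/marker2', 'my-bucket/tmp/marker3')  # copy + rm of the source
assert s3.exists('my-bucket/tmp/marker3')
s3.rm('my-bucket/tmp', recursive=True)            # removes every key under the prefix
```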

def read_block(self, fn, offset, length, delimiter=None):
""" Read a block of bytes from an S3 file

@@ -314,12 +355,11 @@ def __init__(self, s3, path, mode='rb', block_size=4*2**20):
read-ahead size for finding delimiters
"""
self.mode = mode
if mode != 'rb':
raise NotImplementedError("File mode must be 'rb', not %s" % mode)
if mode not in {'rb', 'wb'}:
raise NotImplementedError("File mode must be 'rb' or 'wb', not %s" % mode)
self.path = path
bucket, key = split_path(path)
self.s3 = s3
self.size = self.info()['Size']
self.bucket = bucket
self.key = key
self.blocksize = block_size
@@ -328,6 +368,14 @@ def __init__(self, s3, path, mode='rb', block_size=4*2**20):
self.start = None
self.end = None
self.closed = False
if mode == 'wb':
self.buffer = io.BytesIO()
self.size = 0
else:
try:
self.size = self.info()['Size']
except ClientError:
raise IOError("File not accessible: %s", path)

def info(self):
return self.s3.info(self.path)
@@ -336,18 +384,25 @@ def tell(self):
return self.loc

def seek(self, loc, whence=0):
if not self.mode == 'rb':
raise ValueError('Seek only available in read mode')
if whence == 0:
self.loc = loc
nloc = loc
elif whence == 1:
self.loc += loc
nloc = self.loc + loc
elif whence == 2:
self.loc = self.size + loc
nloc = self.size + loc
else:
raise ValueError("invalid whence (%s, should be 0, 1 or 2)" % whence)
if self.loc < 0:
self.loc = 0
if nloc < 0:
raise ValueError('Seek before start of file')
self.loc = nloc
return self.loc
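
The reworked `seek` now mirrors `io.IOBase` semantics: whence 0/1/2 mean start/current/end, an unknown whence raises, and seeking before byte 0 raises `ValueError` instead of silently clamping to 0; it also remains read-mode only. A short sketch against one of the keys the test suite mirrors into the mock bucket (assumes the moto-backed setup from the fixture below, or a public bucket with that key):

```python
from distributed.s3fs import S3FileSystem

s3 = S3FileSystem(anon=True)
with s3.open('distributed-test/test/accounts.1.json', 'rb') as f:
    f.seek(-1, 2)           # one byte before the end of the object
    last = f.read(1)
    f.seek(5, 0)            # absolute position
    f.seek(-2, 1)           # relative: now at byte 3
    try:
        f.seek(-1)          # before the start of the file
    except ValueError:
        pass                # raises rather than clamping to 0
```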

def mv(self, path1, path2):
self.copy(path1, path2)
self.rm(path1)

def _fetch(self, start, end):
if self.start is None and self.end is None:
# First read
@@ -384,14 +439,49 @@ def read(self, length=-1):
self.loc += len(out)
return out

def write(self, data):
"""
Write data to buffer.

Buffer only sent to S3 on flush().
"""
if self.mode != 'wb':
raise ValueError('File not in write mode')
if self.closed:
raise ValueError('I/O operation on closed file.')
return self.buffer.write(data)

def flush(self):
pass
"""
Write buffered data to S3.
"""
if self.mode == 'wb':
try:
self.s3.s3.head_bucket(Bucket=self.bucket)
except ClientError:
try:
self.s3.s3.create_bucket(Bucket=self.bucket)
except ClientError:
raise IOError('Create bucket failed: %s', self.bucket)
pos = self.buffer.tell()
self.buffer.seek(0)
try:
out = self.s3.s3.put_object(Bucket=self.bucket, Key=self.key,
Body=self.buffer.read())
finally:
self.buffer.seek(pos)
self.s3._ls(self.bucket, refresh=True)
if out['ResponseMetadata']['HTTPStatusCode'] != 200:
raise IOError("Write failed: %s", out)

def close(self):
self.flush()
self.cache = None
self.closed = True

def __del__(self):
self.close()

def __str__(self):
return "<S3File %s/%s>" % (self.bucket, self.key)

@@ -404,7 +494,6 @@ def __exit__(self, *args):
self.close()



def _fetch_range(client, bucket, key, start, end, max_attempts=10):
logger.debug("Fetch: %s/%s, %s-%s", bucket, key, start, end)
for i in range(max_attempts):
119 changes: 117 additions & 2 deletions distributed/tests/test_s3fs.py
@@ -3,8 +3,8 @@
from distributed.s3fs import S3FileSystem
from distributed.s3 import seek_delimiter
from distributed.utils_test import slow
import moto

# These get mirrored on s3://distributed-test/
test_bucket_name = 'distributed-test'
files = {'test/accounts.1.json': (b'{"amount": 100, "name": "Alice"}\n'
b'{"amount": 200, "name": "Bob"}\n'
@@ -24,11 +24,64 @@
b'Dennis,400,4\n'
b'Edith,500,5\n'
b'Frank,600,6\n')}
text_files = {'nested/file1': b'hello\n',
'nested/file2': b'world',
'nested/nested2/file1': b'hello\n',
'nested/nested2/file2': b'world'}
a = 'tmp/test/a'
b = 'tmp/test/b'
c = 'tmp/test/c'
d = 'tmp/test/d'


@pytest.yield_fixture
def s3():
# could do with a bucket with write privileges.
# writable local S3 system
m = moto.mock_s3()
m.start()
import boto3
client = boto3.client('s3')
client.create_bucket(Bucket=test_bucket_name)
for flist in [files, csv_files, text_files]:
for f, data in flist.items():
client.put_object(Bucket=test_bucket_name, Key=f, Body=data)
yield S3FileSystem(anon=True)
m.stop()
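
Each test that takes the `s3` argument gets a fresh, function-scoped mock: moto is started, the `distributed-test` bucket is created and populated with the file dictionaries above, an `S3FileSystem` is handed to the test, and the mock is torn down afterwards, so tests can create and delete keys freely without touching real S3. A hypothetical extra test showing the pattern:

```python
def test_roundtrip(s3):                      # hypothetical test, not part of this PR
    with s3.open('tmp/test/roundtrip', 'wb') as f:
        f.write(b'some bytes')
    assert s3.cat('tmp/test/roundtrip') == b'some bytes'
```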


def test_simple(s3):
data = b'a' * (10 * 2**20)

with s3.open(a, 'wb') as f:
f.write(data)

with s3.open(a, 'rb') as f:
out = f.read(len(data))
assert len(data) == len(out)
assert out == data


def test_idempotent_connect(s3):
s3.connect()
s3.connect()


def test_ls_touch(s3):
assert not s3.ls('tmp/test')
s3.touch(a)
s3.touch(b)
L = s3.ls('tmp/test', True)
assert set(d['Key'] for d in L) == set([a, b])
L = s3.ls('tmp/test', False)
assert set(L) == set([a, b])


def test_rm(s3):
assert not s3.exists(a)
s3.touch(a)
assert s3.exists(a)
s3.rm(a)
assert not s3.exists(a)


def test_s3_file_access(s3):
@@ -103,6 +156,68 @@ def test_read_keys_from_bucket(s3):
s3.cat('s3://' + '/'.join([test_bucket_name, k])))


def test_seek(s3):
with s3.open(a, 'wb') as f:
f.write(b'123')

with s3.open(a) as f:
f.seek(1000)
with pytest.raises(ValueError):
f.seek(-1)
with pytest.raises(ValueError):
f.seek(-5, 2)
with pytest.raises(ValueError):
f.seek(0, 10)
f.seek(0)
assert f.read(1) == b'1'
f.seek(0)
assert f.read(1) == b'1'
f.seek(3)
assert f.read(1) == b''
f.seek(-1, 2)
assert f.read(1) == b'3'
f.seek(-1, 1)
f.seek(-1, 1)
assert f.read(1) == b'2'
for i in range(4):
assert f.seek(i) == i


def test_bad_open(s3):
with pytest.raises(IOError):
s3.open('')


def test_errors(s3):
with pytest.raises((IOError, OSError)):
s3.open('tmp/test/shfoshf', 'rb')

    ## This is fine, no need for intervening directories on S3
#with pytest.raises((IOError, OSError)):
# s3.touch('tmp/test/shfoshf/x')

with pytest.raises((IOError, OSError)):
s3.rm('tmp/test/shfoshf/x')

with pytest.raises((IOError, OSError)):
s3.mv('tmp/test/shfoshf/x', 'tmp/test/shfoshf/y')

#with pytest.raises((IOError, OSError)):
# s3.open('x', 'wb')

with pytest.raises((IOError, OSError)):
s3.open('x', 'rb')

#with pytest.raises(IOError):
# s3.chown('/unknown', 'someone', 'group')

#with pytest.raises(IOError):
# s3.chmod('/unknonwn', 'rb')

with pytest.raises(IOError):
s3.rm('unknown')


@slow
def test_seek_delimiter(s3):
fn = 'test/accounts.1.json'