Skip to content

Commit e48b746

Browse files
authored
Merge pull request #597 from skoudoro/fix-array-sequence
[Fix] buffer_size on ArraySequence
2 parents 33767c9 + ddd7240 commit e48b746

File tree

2 files changed

+15
-4
lines changed

2 files changed

+15
-4
lines changed

nibabel/streamlines/array_sequence.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def __init__(self, arr_seq, common_shape, dtype):
3737
self.common_shape = common_shape
3838
n_in_row = reduce(mul, common_shape, 1)
3939
bytes_per_row = n_in_row * dtype.itemsize
40-
self.rows_per_buf = bytes_per_row / self.bytes_per_buf
40+
self.rows_per_buf = max(1, self.bytes_per_buf // bytes_per_row)
4141

4242
def update_seq(self, arr_seq):
4343
arr_seq._offsets = np.array(self.offsets)
@@ -185,6 +185,7 @@ def finalize_append(self):
185185
return
186186
self._build_cache.update_seq(self)
187187
self._build_cache = None
188+
self.shrink_data()
188189

189190
def _resize_data_to(self, n_rows, build_cache):
190191
""" Resize data array if required """

nibabel/streamlines/tests/test_array_sequence.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import sys
33
import unittest
44
import tempfile
5+
import itertools
56
import numpy as np
67

78
from nose.tools import assert_equal, assert_raises, assert_true
@@ -91,11 +92,20 @@ def test_creating_arraysequence_from_list(self):
9192
SEQ_DATA['data'])
9293

9394
def test_creating_arraysequence_from_generator(self):
94-
gen = (e for e in SEQ_DATA['data'])
95-
check_arr_seq(ArraySequence(gen), SEQ_DATA['data'])
95+
gen_1, gen_2 = itertools.tee((e for e in SEQ_DATA['data']))
96+
seq = ArraySequence(gen_1)
97+
seq_with_buffer = ArraySequence(gen_2, buffer_size=256)
98+
99+
# Check buffer size effect
100+
assert_equal(seq_with_buffer.data.shape, seq.data.shape)
101+
assert_true(seq_with_buffer._buffer_size > seq._buffer_size)
102+
103+
# Check generator result
104+
check_arr_seq(seq, SEQ_DATA['data'])
105+
check_arr_seq(seq_with_buffer, SEQ_DATA['data'])
96106

97107
# Already consumed generator
98-
check_empty_arr_seq(ArraySequence(gen))
108+
check_empty_arr_seq(ArraySequence(gen_1))
99109

100110
def test_creating_arraysequence_from_arraysequence(self):
101111
seq = ArraySequence(SEQ_DATA['data'])

0 commit comments

Comments
 (0)