Skip to content

Commit 2340a92

Browse files
committed
added testing for dtype at cmdline, assuming cmdline functionality. changed value output in data_diff function to allow for validation of data type
1 parent 97ead00 commit 2340a92

File tree

2 files changed

+15
-4
lines changed

2 files changed

+15
-4
lines changed

nibabel/cmdline/diff.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -208,8 +208,8 @@ def get_data_diff(files, max_abs=0, max_rel=0, dtype=np.float64):
208208

209209
diff_rec = OrderedDict() # so that abs goes before relative
210210

211-
diff_rec['abs'] = max_abs_diff
212-
diff_rec['rel'] = max_rel_diff
211+
diff_rec['abs'] = max_abs_diff.astype(dtype)
212+
diff_rec['rel'] = max_rel_diff.astype(dtype)
213213
diffs1.append(diff_rec)
214214
else:
215215
diffs1.append(None)
@@ -274,7 +274,8 @@ def display_diff(files, diff):
274274
return output
275275

276276

277-
def diff(files, header_fields='all', data_max_abs_diff=None, data_max_rel_diff=None, dtype=np.float64):
277+
def diff(files, header_fields='all', data_max_abs_diff=None, data_max_rel_diff=None,
278+
dtype=np.float64):
278279
assert len(files) >= 2, "Please enter at least two files"
279280

280281
file_headers = [nib.load(f).header for f in files]

nibabel/cmdline/tests/test_utils.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import nibabel as nib
1212
import numpy as np
1313
from nibabel.cmdline.utils import *
14-
from nibabel.cmdline.diff import get_headers_diff, display_diff, main, get_data_hash_diff, get_data_diff
14+
from nibabel.cmdline.diff import *
1515
from os.path import (join as pjoin)
1616
from nibabel.testing import data_path
1717
from collections import OrderedDict
@@ -146,6 +146,16 @@ def test_get_data_diff():
146146
OrderedDict([('DATA(diff 1:)', [None, {'CMP': 'incompat'}, {'CMP': 'incompat'}]),
147147
('DATA(diff 2:)', [None, None, {'CMP': 'incompat'}])]))
148148

149+
test_return = get_data_diff([test_array, test_array_2], dtype=np.float32)
150+
assert_equal(type(test_return['DATA(diff 1:)'][1]['abs']), np.float32)
151+
assert_equal(type(test_return['DATA(diff 1:)'][1]['rel']), np.float32)
152+
153+
test_return_2 = get_data_diff([test_array, test_array_2, test_array_3])
154+
assert_equal(type(test_return_2['DATA(diff 1:)'][1]['abs']), np.float64)
155+
assert_equal(type(test_return_2['DATA(diff 1:)'][1]['rel']), np.float64)
156+
assert_equal(type(test_return_2['DATA(diff 2:)'][2]['abs']), np.float64)
157+
assert_equal(type(test_return_2['DATA(diff 2:)'][2]['rel']), np.float64)
158+
149159

150160
def test_main():
151161
test_names = [pjoin(data_path, f)

0 commit comments

Comments
 (0)