diff --git a/nibabel/arrayops.py b/nibabel/arrayops.py new file mode 100644 index 0000000000..591f2a22d7 --- /dev/null +++ b/nibabel/arrayops.py @@ -0,0 +1,130 @@ +import operator + +import numpy as np + +from .orientations import aff2axcodes +# support_np_type = ( +# np.int8, +# np.int64, +# np.float16, +# np.float32, +# np.float64, +# np.complex128) + + +class OperableImage: + def _binop(self, val, *, op): + """Apply operator to Nifti1Image. + + Arithmetic and logical operation on Nifti image. + Currently support: +, -, *, /, //, &, | + The nifit image should contain the same header information and affine. + Images should be the same shape. + + Parameters + ---------- + op : + Python operator. + """ + affine, header = self.affine, self.header + self_, val_ = _input_validation(self, val) + # numerical operator should work work + + if op.__name__ in ["add", "sub", "mul", "truediv", "floordiv"]: + dataobj = op(self_, val_) + if op.__name__ in ["and_", "or_"]: + self_ = self_.astype(bool) + val_ = val_.astype(bool) + dataobj = op(self_, val_).astype(int) + return self.__class__(dataobj, affine, header) + + def _unop(self, *, op): + """ + Parameters + ---------- + op : + Python operator. + """ + # _type_check(self) + if op.__name__ in ["pos", "neg", "abs"]: + dataobj = op(np.asanyarray(self.dataobj)) + return self.__class__(dataobj, self.affine, self.header) + + def __add__(self, other): + return self._binop(other, op=operator.__add__) + + def __sub__(self, other): + return self._binop(other, op=operator.__sub__) + + def __mul__(self, other): + return self._binop(other, op=operator.__mul__) + + def __truediv__(self, other): + return self._binop(other, op=operator.__truediv__) + + def __floordiv__(self, other): + return self._binop(other, op=operator.__floordiv__) + + def __and__(self, other): + return self._binop(other, op=operator.__and__) + + def __or__(self, other): + return self._binop(other, op=operator.__or__) + + def __pos__(self): + return self._unop(op=operator.__pos__) + + def __neg__(self): + return self._unop(op=operator.__neg__) + + def __abs__(self): + return self._unop(op=operator.__abs__) + + +def _input_validation(self, val): + """Check images orientation, affine, and shape muti-images operation.""" + # _type_check(self) + if isinstance(val, self.__class__): + # _type_check(val) + # Check orientations are the same + if aff2axcodes(self.affine) != aff2axcodes(val.affine): + raise ValueError("Two images should have the same orientation") + # Check affine + if (self.affine != val.affine).any(): + raise ValueError("Two images should have the same affine.") + + # Check shape. + if self.shape[:3] != val.shape[:3]: + raise ValueError("Two images should have the same shape except " + "the time dimension.") + + # if 4th dim exist in a image, + # reshape the 3d image to ensure valid projection + ndims = (len(self.shape), len(val.shape)) + if 4 not in ndims: + self_ = np.asanyarray(self.dataobj) + val_ = np.asanyarray(val.dataobj) + return self_, val_ + + reference = None + imgs = [] + for ndim, img in zip(ndims, (self, val)): + img_ = np.asanyarray(img.dataobj) + if ndim == 3: + reference = tuple(list(img.shape) + [1]) + img_ = np.reshape(img_, reference) + imgs.append(img_) + return imgs + else: + self_ = np.asanyarray(self.dataobj) + val_ = val + return self_, val_ + + +# def _type_check(*args): +# """Ensure image contains correct nifti data type.""" +# # Check types +# dtypes = [img.get_data_dtype().type for img in args] +# # check allowed dtype based on the operator +# if set(support_np_type).union(dtypes) == 0: +# raise ValueError("Image contains illegal datatype for Nifti1Image.") diff --git a/nibabel/nifti1.py b/nibabel/nifti1.py index 799377f282..8220b4d53e 100644 --- a/nibabel/nifti1.py +++ b/nibabel/nifti1.py @@ -26,6 +26,7 @@ from .spm99analyze import SpmAnalyzeHeader from .casting import have_binary128 from .pydicom_compat import have_dicom, pydicom as pdcm +from .arrayops import OperableImage # nifti1 flat header definition for Analyze-like first 348 bytes # first number in comments indicates offset in file header in bytes @@ -2011,7 +2012,7 @@ def as_reoriented(self, ornt): return img -class Nifti1Image(Nifti1Pair, SerializableImage): +class Nifti1Image(Nifti1Pair, SerializableImage, OperableImage): """ Class for single file NIfTI1 format image """ header_class = Nifti1Header diff --git a/nibabel/tests/test_arrayops.py b/nibabel/tests/test_arrayops.py new file mode 100644 index 0000000000..be389c9e6f --- /dev/null +++ b/nibabel/tests/test_arrayops.py @@ -0,0 +1,94 @@ +import numpy as np +from .. import Nifti1Image +from numpy.testing import assert_array_equal +import pytest + + +def test_binary_operations(): + data1 = np.random.rand(5, 5, 2) + data2 = np.random.rand(5, 5, 2) + data1[0, 0, :] = 0 + img1 = Nifti1Image(data1, np.eye(4)) + img2 = Nifti1Image(data2, np.eye(4)) + + output = img1 + 2 + assert_array_equal(output.dataobj, data1 + 2) + + output = img1 + img2 + assert_array_equal(output.dataobj, data1 + data2) + + output = img1 + img2 + img2 + assert_array_equal(output.dataobj, data1 + data2 + data2) + + output = img1 - img2 + assert_array_equal(output.dataobj, data1 - data2) + + output = img1 * img2 + assert_array_equal(output.dataobj, data1 * data2) + + output = img1 / img2 + assert_array_equal(output.dataobj, data1 / data2) + + output = img1 // img2 + assert_array_equal(output.dataobj, data1 // data2) + + output = img2 / img1 + assert_array_equal(output.dataobj, data2 / data1) + + output = img2 // img1 + assert_array_equal(output.dataobj, data2 // data1) + + output = img1 & img2 + assert_array_equal(output.dataobj, (data1.astype(bool) & data2.astype(bool)).astype(int)) + + output = img1 | img2 + assert_array_equal(output.dataobj, (data1.astype(bool) | data2.astype(bool)).astype(int)) + + +def test_binary_operations_4d(): + data1 = np.random.rand(5, 5, 2, 3) + data2 = np.random.rand(5, 5, 2) + img1 = Nifti1Image(data1, np.eye(4)) + img2 = Nifti1Image(data2, np.eye(4)) + data2_ = np.reshape(data2, (5, 5, 2, 1)) + + output = img1 * img2 + assert_array_equal(output.dataobj, data1 * data2_) + + +def test_unary_operations(): + data = np.random.rand(5, 5, 2) - 0.5 + img = Nifti1Image(data, np.eye(4)) + + output = +img + assert_array_equal(output.dataobj, +data) + + output = -img + assert_array_equal(output.dataobj, -data) + + output = abs(img) + assert_array_equal(output.dataobj, abs(data)) + + +def test_error_catching(): + data1 = np.random.rand(5, 5, 1) + data2 = np.random.rand(5, 5, 2) + img1 = Nifti1Image(data1, np.eye(4)) + img2 = Nifti1Image(data2, np.eye(4)) + with pytest.raises(ValueError, match=r'should have the same shape'): + img1 + img2 + + data1 = np.random.rand(5, 5, 2) + data2 = np.random.rand(5, 5, 2) + img1 = Nifti1Image(data1, np.eye(4) * 2) + img2 = Nifti1Image(data2, np.eye(4)) + with pytest.raises(ValueError, match=r'should have the same affine'): + img1 + img2 + + data = np.random.rand(5, 5, 2) + aff1 = [[0,1,0,10],[-1,0,0,20],[0,0,1,30],[0,0,0,1]] + aff2 = np.eye(4) + img1 = Nifti1Image(data, aff1) + img2 = Nifti1Image(data, aff2) + with pytest.raises(ValueError, match=r'should have the same orientation'): + img1 + img2