Skip to content

Commit 983c31a

Browse files
author
nickj
committed
TENSOR: Fix slices ref shen return value isn't scalar or vector. #41
1 parent fbbf554 commit 983c31a

File tree

2 files changed

+5
-11
lines changed

2 files changed

+5
-11
lines changed

pyttb/tensor.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1273,7 +1273,6 @@ def __getitem__(self, item):
12731273
kpdims = [] # dimensions to keep
12741274
rmdims = [] # dimensions to remove
12751275

1276-
# Determine the new size and what dimensions to keep
12771276
# Determine the new size and what dimensions to keep
12781277
for i in range(0, len(region)):
12791278
if isinstance(region[i], slice):
@@ -1291,19 +1290,11 @@ def __getitem__(self, item):
12911290

12921291
# If the size is zero, then the result is returned as a scalar
12931292
# otherwise, we convert the result to a tensor
1294-
12951293
if newsiz.size == 0:
12961294
a = newdata
12971295
else:
1298-
if rmdims.size == 0:
1299-
a = ttb.tensor.from_data(newdata)
1300-
else:
1301-
# If extracted data is a vector then no need to tranpose it
1302-
if len(newdata.shape) == 1:
1303-
a = ttb.tensor.from_data(newdata)
1304-
else:
1305-
a = ttb.tensor.from_data(np.transpose(newdata, np.concatenate((kpdims, rmdims))))
1306-
return ttb.tt_subsubsref(a, item)
1296+
a = ttb.tensor.from_data(newdata)
1297+
return a
13071298

13081299
# *** CASE 2a: Subscript indexing ***
13091300
if len(item) > 1 and isinstance(item[-1], str) and item[-1] == 'extract':

tests/test_tensor.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,9 @@ def test_tensor__getitem__(sample_tensor_2way):
280280
assert tensorInstance[0, 0] == params['data'][0, 0]
281281
# Case 1 Subtensor
282282
assert (tensorInstance[:, :] == tensorInstance).data.all()
283+
three_way_data = np.random.random((2, 3, 4))
284+
two_slices = (slice(None,None,None), 0, slice(None,None,None))
285+
assert (ttb.tensor.from_data(three_way_data)[two_slices].double() == three_way_data[two_slices]).all()
283286
# Case 1 Subtensor
284287
assert (tensorInstance[np.array([0, 1]), :].data == tensorInstance.data[[0, 1], :]).all()
285288
# Case 1 Subtensor

0 commit comments

Comments
 (0)