Skip to content

TUCKER_ALS: TTM with negative values is broken in ttensor (#62) #66

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyttb/pyttb_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def tt_dimscheck(

# Fix "minus" case
if np.max(dims) < 0:
# Check that all memebers in range
# Check that all members in range
if not np.all(np.isin(-dims, np.arange(0, N + 1))):
assert False, "Invalid magnitude for negative dims selection"
dims = np.setdiff1d(np.arange(1, N + 1), -dims) - 1
Expand Down
4 changes: 3 additions & 1 deletion pyttb/ttensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,9 @@ def ttm(self, matrix, dims=None, transpose=False):
dims = np.arange(self.ndims)
elif isinstance(dims, list):
dims = np.array(dims)
elif np.isscalar(dims) or isinstance(dims, list):
elif np.isscalar(dims):
if dims < 0:
raise ValueError("Negative dims is currently unsupported, see #62")
dims = np.array([dims])

if not isinstance(matrix, list):
Expand Down
15 changes: 5 additions & 10 deletions pyttb/tucker_als.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,14 +124,11 @@ def tucker_als(

# Iterate over all N modes of the tensor
for n in dimorder:
if (
n == 0
): # TODO proposal to change ttm to include_dims and exclude_dims to resolve -0 ambiguity
dims = np.arange(1, tensor.ndims)
Utilde = tensor.ttm(U, dims, True)
else:
Utilde = tensor.ttm(U, -n, True)

# TODO proposal to change ttm to include_dims and exclude_dims to resolve -0 ambiguity
dims = np.arange(0, tensor.ndims)
dims = dims[dims != n]
Utilde = tensor.ttm(U, dims, True)
print(f"Utilde[{n}] = {Utilde}")
# Maximize norm(Utilde x_n W') wrt W and
# maintain orthonormality of W
U[n] = Utilde.nvecs(n, rank[n])
Expand All @@ -140,13 +137,11 @@ def tucker_als(
core = Utilde.ttm(U, n, True)

# Compute fit
# TODO this abs is missing from MATLAB, but I get negative numbers for trivial examples
normresidual = np.sqrt(abs(normX**2 - core.norm() ** 2))
fit = 1 - (normresidual / normX) # fraction explained by model
fitchange = abs(fitold - fit)

if iter % printitn == 0:
print(f" NormX: {normX} Core norm: {core.norm()}")
print(f" Iter {iter}: fit = {fit:e} fitdelta = {fitchange:7.1e}\n")

# Check for convergence
Expand Down
10 changes: 8 additions & 2 deletions tests/test_ttensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,9 +310,15 @@ def test_ttensor_ttm(random_ttensor):

# Negative Tests
big_wrong_size = 123
matrices[0] = np.random.random((big_wrong_size, big_wrong_size))
bad_matrices = matrices.copy()
bad_matrices[0] = np.random.random((big_wrong_size, big_wrong_size))
with pytest.raises(ValueError):
_ = ttensorInstance.ttm(matrices, np.arange(len(matrices)))
_ = ttensorInstance.ttm(bad_matrices, np.arange(len(bad_matrices)))

with pytest.raises(ValueError):
# Negative dims currently broken, ensure we catch early and
# remove once resolved
ttensorInstance.ttm(matrices, -1)


@pytest.mark.indevelopment
Expand Down
1 change: 1 addition & 0 deletions tests/test_tucker_als.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def test_tucker_als_tensor_default_init(capsys, sample_tensor):
(Solution, Uinit, output) = ttb.tucker_als(T, 2)
capsys.readouterr()
assert pytest.approx(output["fit"], 1) == 0
assert np.all(np.isclose(Solution.double(), T.double()))

(Solution, Uinit, output) = ttb.tucker_als(T, 2, init=Uinit)
capsys.readouterr()
Expand Down