From 3b305f5ef694ae82ebb9dd5fe6e4c7397076eb5c Mon Sep 17 00:00:00 2001 From: Nick Johnson <24689722+ntjohnson1@users.noreply.github.com> Date: Sat, 11 Mar 2023 14:18:48 -0500 Subject: [PATCH] TUCKER_ALS: TTM with negative values is broken in ttensor (#62) * Replace usage in tucker_als * Update test for tucker_als to ensure result matches expectation * Add early error handling in ttensor ttm for negative dims --- pyttb/pyttb_utils.py | 2 +- pyttb/ttensor.py | 4 +++- pyttb/tucker_als.py | 15 +++++---------- tests/test_ttensor.py | 10 ++++++++-- tests/test_tucker_als.py | 1 + 5 files changed, 18 insertions(+), 14 deletions(-) diff --git a/pyttb/pyttb_utils.py b/pyttb/pyttb_utils.py index ab72c479..c043f989 100644 --- a/pyttb/pyttb_utils.py +++ b/pyttb/pyttb_utils.py @@ -198,7 +198,7 @@ def tt_dimscheck( # Fix "minus" case if np.max(dims) < 0: - # Check that all memebers in range + # Check that all members in range if not np.all(np.isin(-dims, np.arange(0, N + 1))): assert False, "Invalid magnitude for negative dims selection" dims = np.setdiff1d(np.arange(1, N + 1), -dims) - 1 diff --git a/pyttb/ttensor.py b/pyttb/ttensor.py index cbf53c76..a346d1f2 100644 --- a/pyttb/ttensor.py +++ b/pyttb/ttensor.py @@ -434,7 +434,9 @@ def ttm(self, matrix, dims=None, transpose=False): dims = np.arange(self.ndims) elif isinstance(dims, list): dims = np.array(dims) - elif np.isscalar(dims) or isinstance(dims, list): + elif np.isscalar(dims): + if dims < 0: + raise ValueError("Negative dims is currently unsupported, see #62") dims = np.array([dims]) if not isinstance(matrix, list): diff --git a/pyttb/tucker_als.py b/pyttb/tucker_als.py index 71227b30..ed824f0a 100644 --- a/pyttb/tucker_als.py +++ b/pyttb/tucker_als.py @@ -124,14 +124,11 @@ def tucker_als( # Iterate over all N modes of the tensor for n in dimorder: - if ( - n == 0 - ): # TODO proposal to change ttm to include_dims and exclude_dims to resolve -0 ambiguity - dims = np.arange(1, tensor.ndims) - Utilde = tensor.ttm(U, dims, True) - else: - Utilde = tensor.ttm(U, -n, True) - + # TODO proposal to change ttm to include_dims and exclude_dims to resolve -0 ambiguity + dims = np.arange(0, tensor.ndims) + dims = dims[dims != n] + Utilde = tensor.ttm(U, dims, True) + print(f"Utilde[{n}] = {Utilde}") # Maximize norm(Utilde x_n W') wrt W and # maintain orthonormality of W U[n] = Utilde.nvecs(n, rank[n]) @@ -140,13 +137,11 @@ def tucker_als( core = Utilde.ttm(U, n, True) # Compute fit - # TODO this abs is missing from MATLAB, but I get negative numbers for trivial examples normresidual = np.sqrt(abs(normX**2 - core.norm() ** 2)) fit = 1 - (normresidual / normX) # fraction explained by model fitchange = abs(fitold - fit) if iter % printitn == 0: - print(f" NormX: {normX} Core norm: {core.norm()}") print(f" Iter {iter}: fit = {fit:e} fitdelta = {fitchange:7.1e}\n") # Check for convergence diff --git a/tests/test_ttensor.py b/tests/test_ttensor.py index 686dfd04..3dd12d82 100644 --- a/tests/test_ttensor.py +++ b/tests/test_ttensor.py @@ -310,9 +310,15 @@ def test_ttensor_ttm(random_ttensor): # Negative Tests big_wrong_size = 123 - matrices[0] = np.random.random((big_wrong_size, big_wrong_size)) + bad_matrices = matrices.copy() + bad_matrices[0] = np.random.random((big_wrong_size, big_wrong_size)) with pytest.raises(ValueError): - _ = ttensorInstance.ttm(matrices, np.arange(len(matrices))) + _ = ttensorInstance.ttm(bad_matrices, np.arange(len(bad_matrices))) + + with pytest.raises(ValueError): + # Negative dims currently broken, ensure we catch early and + # remove once resolved + ttensorInstance.ttm(matrices, -1) @pytest.mark.indevelopment diff --git a/tests/test_tucker_als.py b/tests/test_tucker_als.py index bf599338..b72d1251 100644 --- a/tests/test_tucker_als.py +++ b/tests/test_tucker_als.py @@ -19,6 +19,7 @@ def test_tucker_als_tensor_default_init(capsys, sample_tensor): (Solution, Uinit, output) = ttb.tucker_als(T, 2) capsys.readouterr() assert pytest.approx(output["fit"], 1) == 0 + assert np.all(np.isclose(Solution.double(), T.double())) (Solution, Uinit, output) = ttb.tucker_als(T, 2, init=Uinit) capsys.readouterr()