diff --git a/examples/models/moshi/mimi/test_mimi.py b/examples/models/moshi/mimi/test_mimi.py index 350595e9cf7..69859fa39bc 100644 --- a/examples/models/moshi/mimi/test_mimi.py +++ b/examples/models/moshi/mimi/test_mimi.py @@ -59,7 +59,7 @@ def setUpClass(cls): """Setup once for all tests: Load model and prepare test data.""" # Get environment variables (if set), otherwise use default values - mimi_weight = os.getenv("MIMI_WEIGHT", None) + cls.mimi_weight = os.getenv("MIMI_WEIGHT", None) hf_repo = os.getenv("HF_REPO", loaders.DEFAULT_REPO) device = "cuda" if torch.cuda.device_count() else "cpu" @@ -75,15 +75,15 @@ def seed_all(seed): seed_all(42424242) - if mimi_weight is None: + if cls.mimi_weight is None: try: - mimi_weight = hf_hub_download(hf_repo, loaders.MIMI_NAME) + cls.mimi_weight = hf_hub_download(hf_repo, loaders.MIMI_NAME) except: - mimi_weight = hf_hub_download( + cls.mimi_weight = hf_hub_download( hf_repo, loaders.MIMI_NAME, proxies=proxies ) - cls.mimi = loaders.get_mimi(mimi_weight, device) + cls.mimi = loaders.get_mimi(cls.mimi_weight, device) cls.device = device cls.sample_pcm, cls.sample_sr = read_mp3_from_url( "https://huggingface.co/lmz/moshi-swift/resolve/main/bria-24khz.mp3" @@ -182,8 +182,8 @@ def forward(self, x): return out emb_input = torch.rand(1, 1, 512, device="cpu") - - mimi_decode = MimiDecode(self.mimi) + mimi_cpu = loaders.get_mimi(self.mimi_weight, "cpu") + mimi_decode = MimiDecode(mimi_cpu) mimi_decode.eval() mimi_decode(emb_input) @@ -225,7 +225,9 @@ def forward(self, x): # Compare results sqnr = compute_sqnr(eager_res, res[0]) print(f"SQNR: {sqnr}") - torch.testing.assert_close(eager_res, res[0], atol=4e-3, rtol=1e-3) + # Don't check for exact equality, but check that the SQNR is high enough + # torch.testing.assert_close(eager_res, res[0], atol=4e-3, rtol=1e-3) + self.assertGreater(sqnr, 25.0) if __name__ == "__main__":