From 925ce273e5830929ce21aaae47bc4a23d191c8bc Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Thu, 6 Apr 2023 10:05:54 -0400 Subject: [PATCH] Fix path normalization for StreamWriter-based save operation Follow up of #3243. Save compat module had different semantics than info and load, which requires different way of performing path normalization. --- torchaudio/_backend/utils.py | 2 +- torchaudio/io/_compat.py | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/torchaudio/_backend/utils.py b/torchaudio/_backend/utils.py index f790130056..70361b8aa4 100644 --- a/torchaudio/_backend/utils.py +++ b/torchaudio/_backend/utils.py @@ -122,7 +122,7 @@ def save( buffer_size: int = 4096, ) -> None: save_audio( - os.path.normpath(uri), + uri, src, sample_rate, channels_first, diff --git a/torchaudio/io/_compat.py b/torchaudio/io/_compat.py index 7471bfdb89..7b122cbcc8 100644 --- a/torchaudio/io/_compat.py +++ b/torchaudio/io/_compat.py @@ -204,8 +204,11 @@ def save_audio( bits_per_sample: Optional[int] = None, buffer_size: int = 4096, ) -> None: - if hasattr(uri, "write") and format is None: - raise RuntimeError("'format' is required when saving to file object.") + if hasattr(uri, "write"): + if format is None: + raise RuntimeError("'format' is required when saving to file object.") + else: + uri = os.path.normpath(uri) s = StreamWriter(uri, format=format, buffer_size=buffer_size) if format is None: tokens = str(uri).split(".")