From 92dc1b50743a316735595d022c103d0470495ada Mon Sep 17 00:00:00 2001 From: Manal ML Date: Wed, 21 May 2025 03:47:17 +0000 Subject: [PATCH 01/10] add working x-codec --- docs/source/en/_toctree.yml | 2 + docs/source/en/model_doc/xcodec.md | 82 +++ src/transformers/models/__init__.py | 1 + .../models/auto/configuration_auto.py | 2 + .../models/auto/feature_extraction_auto.py | 1 + src/transformers/models/auto/modeling_auto.py | 1 + src/transformers/models/xcodec/__init__.py | 27 + .../models/xcodec/configuration_xcodec.py | 204 ++++++ .../xcodec/convert_xcodec_weights_to_hf.py | 237 +++++++ .../models/xcodec/modeling_xcodec.py | 591 ++++++++++++++++++ tests/models/xcodec/__init__.py | 0 tests/models/xcodec/test_modeling_xcodec.py | 478 ++++++++++++++ 12 files changed, 1626 insertions(+) create mode 100644 docs/source/en/model_doc/xcodec.md create mode 100644 src/transformers/models/xcodec/__init__.py create mode 100644 src/transformers/models/xcodec/configuration_xcodec.py create mode 100644 src/transformers/models/xcodec/convert_xcodec_weights_to_hf.py create mode 100644 src/transformers/models/xcodec/modeling_xcodec.py create mode 100644 tests/models/xcodec/__init__.py create mode 100644 tests/models/xcodec/test_modeling_xcodec.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 50567ebec463..e4d16783448a 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -669,6 +669,8 @@ title: UL2 - local: model_doc/umt5 title: UMT5 + - local: model_doc/xcodec + title: X-CODEC - local: model_doc/xmod title: X-MOD - local: model_doc/xglm diff --git a/docs/source/en/model_doc/xcodec.md b/docs/source/en/model_doc/xcodec.md new file mode 100644 index 000000000000..ce0287f46f17 --- /dev/null +++ b/docs/source/en/model_doc/xcodec.md @@ -0,0 +1,82 @@ + + +# X-Codec + +
+PyTorch +

## Overview

The X-Codec model was proposed in [Codec Does Matter: Exploring the Semantic Shortcoming of Codec for Audio Language Model](https://arxiv.org/abs/2408.17175) by Zhen Ye, Peiwen Sun, Jiahe Lei, Hongzhan Lin, Xu Tan, Zheqi Dai, Qiuqiang Kong, Jianyi Chen, Jiahao Pan, Qifeng Liu, Yike Guo, Wei Xue.

The X-Codec model is a neural audio codec that integrates semantic information from self-supervised models (e.g., HuBERT) alongside traditional acoustic information. This enables:

- **Music continuation**: Better modeling of musical semantics yields more coherent continuations.
- **Text-to-Sound Synthesis**: X-Codec captures semantic alignment between text prompts and generated audio.
- **Semantic-aware audio tokenization**: X-Codec is used as an audio tokenizer in the YuE lyrics-to-song generation model.

The abstract of the paper states the following:

*Recent advancements in audio generation have been significantly propelled by the capabilities of Large Language Models (LLMs). The existing research on audio LLM has primarily focused on enhancing the architecture and scale of audio language models, as well as leveraging larger datasets, and generally, acoustic codecs, such as EnCodec, are used for audio tokenization. However, these codecs were originally designed for audio compression, which may lead to suboptimal performance in the context of audio LLM. Our research aims to address the shortcomings of current audio LLM codecs, particularly their challenges in maintaining semantic integrity in generated audio. For instance, existing methods like VALL-E, which condition acoustic token generation on text transcriptions, often suffer from content inaccuracies and elevated word error rates (WER) due to semantic misinterpretations of acoustic tokens, resulting in word skipping and errors. To overcome these issues, we propose a straightforward yet effective approach called X-Codec. X-Codec incorporates semantic features from a pre-trained semantic encoder before the Residual Vector Quantization (RVQ) stage and introduces a semantic reconstruction loss after RVQ. By enhancing the semantic ability of the codec, X-Codec significantly reduces WER in speech synthesis tasks and extends these benefits to non-speech applications, including music and sound generation. Our experiments in text-to-speech, music continuation, and text-to-sound tasks demonstrate that integrating semantic information substantially improves the overall performance of language models in audio generation.*

Demos can be found in this [post](https://x-codec-audio.github.io/).

This model was contributed by [Manal El Aidouni](https://huggingface.co/Manel). The original code can be found [here](https://github.com/zhenye234/xcodec) and the original checkpoint [here](https://huggingface.co/ZhenYe234/xcodec/blob/main/xcodec_speech_hubert_librispeech.pth).
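
X-Codec quantizes the joint acoustic/semantic embedding with a residual vector quantizer (RVQ), and the `bandwidth` argument of `encode` controls how many codebooks are actually used: each quantizer contributes `log2(codebook_size) * frame_rate / 1000` kbps. The snippet below is a back-of-the-envelope sketch of that arithmetic using the default configuration values from this model (illustrative only, not part of the public API):

```python
import math

# defaults from XcodecConfig and its DAC sub-config (assumed here for illustration)
sample_rate = 16_000
upsampling_ratios = [8, 5, 4, 2]
codebook_size = 1024

frame_rate = math.ceil(sample_rate / math.prod(upsampling_ratios))  # 50 frames per second
kbps_per_quantizer = math.log2(codebook_size) * frame_rate / 1000   # 0.5 kbps per codebook

# at the maximum target bandwidth of 4.0 kbps, all 8 quantizers are used
print(int(4.0 // kbps_per_quantizer))  # 8
```

At the lowest supported bandwidth of 0.5 kbps only the first codebook is kept, trading reconstruction quality for fewer tokens per audio frame.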
+ + + +## Usage example + +Here is a quick example of how to encode and decode an audio using this model: + +```python +>>> from datasets import load_dataset, Audio +>>> from transformers import XCodecModel, AutoFeatureExtractor +>>> dummy_dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + +>>> # load model and feature extractor +>>> model = XCodecModel.from_pretrained("Manel/X-Codec") +>>> feature_extractor = AutoFeatureExtractor.from_pretrained("Manel/X-Codec") +>>> # load audio sample +>>> dummy_dataset = dummy_dataset.cast_column("audio", Audio(sampling_rate=feature_extractor.sampling_rate)) +>>> audio_sample = dummy_dataset[-1]["audio"]["array"] +>>> inputs = feature_extractor(raw_audio=audio_sample, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt") + +>>> encoder_outputs = model.encode(inputs["input_values"]) +>>> audio_codes = encoder_outputs.audio_codes +>>> decoder_outputs = model.decode(audio_codes) +>>> audio_values = decoder_outputs.audio_values + +>>> # or the equivalent with a forward pass +>>> outputs = model(inputs["input_values"]) +>>> audio_codes = outputs.audio_codes +>>> audio_values = outputs.audio_values +``` + +## XcodecConfig + +[[autodoc]] XcodecConfig + + +## XcodecModel + +[[autodoc]] XcodecModel + - decode + - encode + - forward \ No newline at end of file diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 6d2c5affad91..a20a9cbd9439 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -337,6 +337,7 @@ from .wavlm import * from .whisper import * from .x_clip import * + from .xcodec import * from .xglm import * from .xlm import * from .xlm_roberta import * diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 02eb31a503bd..f434caf013dc 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -379,6 +379,7 @@ ("wavlm", "WavLMConfig"), ("whisper", "WhisperConfig"), ("xclip", "XCLIPConfig"), + ("xcodec", "XcodecConfig"), ("xglm", "XGLMConfig"), ("xlm", "XLMConfig"), ("xlm-prophetnet", "XLMProphetNetConfig"), @@ -773,6 +774,7 @@ ("wavlm", "WavLM"), ("whisper", "Whisper"), ("xclip", "X-CLIP"), + ("xcodec", "X-CODEC"), ("xglm", "XGLM"), ("xlm", "XLM"), ("xlm-prophetnet", "XLM-ProphetNet"), diff --git a/src/transformers/models/auto/feature_extraction_auto.py b/src/transformers/models/auto/feature_extraction_auto.py index 5754b3bc1bb6..5569dce3d0c7 100644 --- a/src/transformers/models/auto/feature_extraction_auto.py +++ b/src/transformers/models/auto/feature_extraction_auto.py @@ -113,6 +113,7 @@ ("wavlm", "Wav2Vec2FeatureExtractor"), ("whisper", "WhisperFeatureExtractor"), ("xclip", "CLIPFeatureExtractor"), + ("xcodec", "EncodecFeatureExtractor"), ("yolos", "YolosFeatureExtractor"), ] ) diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index f6cb83d1ee51..08f831b23469 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -350,6 +350,7 @@ ("wavlm", "WavLMModel"), ("whisper", "WhisperModel"), ("xclip", "XCLIPModel"), + ("xcodec", "XcodecModel"), ("xglm", "XGLMModel"), ("xlm", "XLMModel"), ("xlm-prophetnet", "XLMProphetNetModel"), diff --git a/src/transformers/models/xcodec/__init__.py b/src/transformers/models/xcodec/__init__.py new file mode 100644 index 
000000000000..720f1612e2e8 --- /dev/null +++ b/src/transformers/models/xcodec/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_xcodec import * + from .modeling_xcodec import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) \ No newline at end of file diff --git a/src/transformers/models/xcodec/configuration_xcodec.py b/src/transformers/models/xcodec/configuration_xcodec.py new file mode 100644 index 000000000000..9839e8a1b651 --- /dev/null +++ b/src/transformers/models/xcodec/configuration_xcodec.py @@ -0,0 +1,204 @@ +# coding=utf-8 +# Copyright 2024 Descript and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Xcodec model configuration""" + +import math + +import numpy as np +from typing import Union, Dict, List, Optional + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + +from transformers import DacConfig, HubertConfig + + +logger = logging.get_logger(__name__) + + +class XcodecConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of an [`XcodecModel`]. It is used to instantiate a + Xcodec model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the + [ ](https://huggingface.co/ ) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + target_bandwidths (`List[float]`, *optional*, defaults to [0.5, 1.0, 1.5, 2.0, 4.0]): + The range of different bandwidths (in kbps) the model can encode audio with. + audio_channels (`int`, *optional*, defaults to 1): + Number of channels in the audio data. Either 1 for mono or 2 for stereo. + sample_rate (`int`, *optional*, defaults to 16000): + The sampling rate at which the audio waveform should be digitalized, in hertz (Hz). + input_channels (`int`, *optional*, defaults to 768): + Number of channels of the input to the first convolution in the semantic encoder. 
+ kernel_size (`int`, *optional*, defaults to 3): + Kernel size for the initial semantic convolution. + encoder_channels (`int`, *optional*, defaults to 768): + Number of hidden channels in each semantic encoder block. + channel_ratios (`List[float]`, *optional*, defaults to [1.0, 1.0]): + Expansion factors for the number of output channels in each semantic block. + strides (`List[int]`, *optional*, defaults to [1, 1]): + Strides for each semantic encoder block. + block_dilations (`List[int]`, *optional*, defaults to [1, 1]): + Dilation factors for the residual units in semantic blocks. + unit_kernel_size (`int`, *optional*, defaults to 3): + Kernel size inside each ResidualUnit in semantic blocks. + decoder_channels (`int`, *optional*, defaults to 768): + Number of hidden channels in each semantic decoder block. + output_channels (`int`, *optional*, defaults to 768): + Number of output channels in the semantic decoder. + codebook_size (`int`, *optional*, defaults to 1024): + Number of entries in each residual quantizer’s codebook. + num_quantizers (`int`, *optional*, defaults to 8): + Number of sequential quantizers (codebooks) in the RVQ stack. + codebook_dim (`int`, *optional*, defaults to 1024): + Dimensionality of each codebook vector. + initializer_range (`float`, *optional*, defaults to 0.02): + Standard deviation of the truncated normal initializer for all weight matrices. + hidden_dim (`int`, *optional*, defaults to 1024): + Dimensionality of the joint acoustic+semantic FC layer. + intermediate_dim (`int`, *optional*, defaults to 768): + Dimensionality of the next FC layer in the decoder path. + output_dim (`int`, *optional*, defaults to 256): + Dimensionality of the final FC layer before feeding into the acoustic decoder. + acoustic_model_config (`Union[Dict, DacConfig]`, *optional*): + An instance of the configuration for the acoustic (DAC) model. + semantic_model_config (`Union[Dict, HubertConfig]`, *optional*): + An instance of the configuration object for the semantic (HuBERT) model. 
+ + Example: + + ```python + >>> from transformers import XcodecModel, XcodecConfig + + >>> # Initializing a " " style configuration + >>> configuration = XcodecConfig() + + >>> # Initializing a model (with random weights) from the " " style configuration + >>> model = XcodecModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "xcodec" + + sub_configs = { + "acoustic_model_config": DacConfig, + "semantic_model_config": HubertConfig, + } + is_composition = True + + def __init__( + self, + target_bandwidths: List[float] = [0.5, 1, 1.5, 2, 4], + audio_channels=1, + sample_rate: int = 16000, + input_channels: int = 768, + encoder_channels: int = 768, + kernel_size: int = 3, + channel_ratios: List[float] = [1, 1], + strides: List[int] = [1, 1], + block_dilations: List[int] = [1, 1], + unit_kernel_size: int = 3, + decoder_channels: int = 768, + output_channels: int = 768, + codebook_size: int = 1024, + num_quantizers: int = 8, + codebook_dim: int = 1024, + initializer_range: float = 0.02, + hidden_dim: int = 1024, + intermediate_dim: int = 768, + output_dim: int = 256, + acoustic_model_config: Union[dict, DacConfig] = None, + semantic_model_config: Union[dict, HubertConfig] = None, + **kwargs, + ): + + super().__init__(**kwargs) + + if acoustic_model_config is None: + self.acoustic_model_config = DacConfig( + encoder_hidden_size = 64, + downsampling_ratios = [8, 5, 4, 2], + decoder_hidden_size = 1024, + upsampling_ratios = [8, 5, 4, 2], + hidden_size = 256, + ) + elif isinstance(acoustic_model_config, dict): + self.acoustic_model_config = DacConfig(**acoustic_model_config) + elif isinstance(acoustic_model_config, DacConfig): + self.acoustic_model_config = acoustic_model_config + + if semantic_model_config is None: + self.semantic_model_config = HubertConfig() + elif isinstance(semantic_model_config, dict): + self.semantic_model_config = HubertConfig(**semantic_model_config) + elif isinstance(semantic_model_config, HubertConfig): + self.semantic_model_config = semantic_model_config + + self.target_bandwidths = target_bandwidths + self.audio_channels = audio_channels + self.sample_rate = sample_rate + self.input_channels = input_channels + self.encoder_channels = encoder_channels + self.kernel_size = kernel_size + self.channel_ratios = channel_ratios + self.strides = strides + self.block_dilations = block_dilations + self.unit_kernel_size = unit_kernel_size + self.decoder_channels = decoder_channels + self.output_channels = output_channels + self.codebook_size = codebook_size + self.num_quantizers = num_quantizers + self.codebook_dim = codebook_dim + self.initializer_range = initializer_range + self.hidden_dim = hidden_dim + self.intermediate_dim = intermediate_dim + self.output_dim = output_dim + + + @classmethod + def from_sub_models_config(cls, acoustic_model_config: DacConfig, semantic_model_config: HubertConfig, **kwargs): + """ + Instantiate a [`XcodecConfig`] from acoustic model and semantic model. + + Returns: + [`XcodecConfig`]: The instantiated configuration. 
+ """ + return cls( + acoustic_model_config=acoustic_model_config.to_dict() if hasattr(acoustic_model_config, "to_dict") else acoustic_model_config, + semantic_model_config=semantic_model_config.to_dict() if hasattr(semantic_model_config, "to_dict") else semantic_model_config, + **kwargs, + ) + + @property + def frame_rate(self) -> int: + return math.ceil(self.sample_rate / np.prod(self.acoustic_model_config.upsampling_ratios)) + + @property + def bits_per_codebook(self) -> int: + return int(math.log2(self.codebook_size)) + + @property + def hop_length(self) -> int: + return int(np.prod(self.acoustic_model_config.downsampling_ratios)) + + +__all__ = ["XcodecConfig"] \ No newline at end of file diff --git a/src/transformers/models/xcodec/convert_xcodec_weights_to_hf.py b/src/transformers/models/xcodec/convert_xcodec_weights_to_hf.py new file mode 100644 index 000000000000..de7b6c35adba --- /dev/null +++ b/src/transformers/models/xcodec/convert_xcodec_weights_to_hf.py @@ -0,0 +1,237 @@ +# coding=utf-8 +# Copyright 2024 Descript and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import re + +import torch +from typing import Dict + + +from transformers import ( + XcodecConfig, + EncodecFeatureExtractor, + XcodecModel, + logging, +) + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + + + +MAPPING_ACOUSTIC_ENCODER = { + r"^block\.0": ["conv1"], + r"^block\.(\d+)\.block\.(\d+)\.block\.0": ["block", "res_unit", "snake1"], + r"^block\.(\d+)\.block\.(\d+)\.block\.1": ["block", "res_unit", "conv1"], + r"^block\.(\d+)\.block\.(\d+)\.block\.2": ["block", "res_unit", "snake2"], + r"^block\.(\d+)\.block\.(\d+)\.block\.3": ["block", "res_unit", "conv2"], + r"^block\.(\d+)\.block\.3": ["block", "snake1"], + r"^block\.(\d+)\.block\.4": ["block", "conv1"], + r"^block\.5": ["snake1"], + r"^block\.6": ["conv2"], +} + +MAPPING_ACOUSTIC_DECODER = { + r"^model\.0": ["conv1"], + r"^model\.(\d+)\.block\.0": ["block", "snake1"], + r"^model\.(\d+)\.block\.1": ["block", "conv_t1"], + r"^model\.(\d+)\.block\.(\d+)\.block\.0": ["block", "res_unit", "snake1"], + r"^model\.(\d+)\.block\.(\d+)\.block\.1": ["block", "res_unit", "conv1"], + r"^model\.(\d+)\.block\.(\d+)\.block\.2": ["block", "res_unit", "snake2"], + r"^model\.(\d+)\.block\.(\d+)\.block\.3": ["block", "res_unit", "conv2"], + r"^model\.5": ["snake1"], + r"^model\.6": ["conv2"], +} + +MAPPING_SEMANTIC_ENCODER = { + "conv.conv.": "conv.", + "conv1.conv.": "conv1.", + "conv2.conv.": "conv2.", +} + +MAPPING_SEMANTIC_DECODER = { + "conv1.conv.": "conv1.", + "conv2.conv.": "conv2.", + "conv.conv.": "conv.", +} + +MAPPING_QUANTIZER = { + "quantizer.vq.layers": "quantizer.quantizers", + "._codebook.": ".codebook.", +} + + +def _rewrite_weight_norm(key: str) -> str: + if key.endswith("weight_g"): + return key[:-len("weight_g")] + "parametrizations.weight.original0" + if key.endswith("weight_v"): + return key[:-len("weight_v")] + "parametrizations.weight.original1" + return key + + 
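+# Note: the original checkpoint stores weight-normalized convolutions as the legacy
+# `weight_g` / `weight_v` pairs, while the HF model (after `model.apply_weight_norm()`)
+# exposes them through the parametrization API. For example, a converted key ending in
+# `conv1.weight_g` is rewritten above to end in `conv1.parametrizations.weight.original0`,
+# and `conv1.weight_v` to `conv1.parametrizations.weight.original1`.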
+def convert_old_keys_to_new_keys(original_state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + converted_checkpoint: Dict[str, torch.Tensor] = {} + + for old_key, value in original_state_dict.items(): + + if old_key.startswith("encoder."): + layer_key = old_key[len("encoder."):] + for pattern, path_parts in MAPPING_ACOUSTIC_ENCODER.items(): + pattern_match = re.match(pattern, layer_key) + if pattern_match is None: + continue + + digit_strings = [g for g in pattern_match.groups() if g is not None] + digit_indices = [int(ds) for ds in digit_strings] + remainder = layer_key[pattern_match.end():] + + if len(path_parts) == 1: + mapped_subkey = f"{path_parts[0]}{remainder}" + elif len(path_parts) == 2: + encoder_layer = digit_indices[0] - 1 + mapped_subkey = f"{path_parts[0]}.{encoder_layer}.{path_parts[1]}{remainder}" + else: + encoder_layer, unit_idx = digit_indices + mapped_subkey = ( + f"{path_parts[0]}.{encoder_layer-1}." + f"{path_parts[1]}{unit_idx+1}." + f"{path_parts[2]}{remainder}" + ) + + new_key = f"acoustic_encoder.{_rewrite_weight_norm(mapped_subkey)}" + converted_checkpoint[new_key] = value + break + + elif old_key.startswith("decoder_2."): + layer_key = old_key[len("decoder_2."):] + + for pattern, path_parts in MAPPING_ACOUSTIC_DECODER.items(): + pattern_match = re.match(pattern, layer_key) + if pattern_match is None: + continue + digit_strings = [g for g in pattern_match.groups() if g is not None] + digit_indices = [int(ds) for ds in digit_strings] + remainder = layer_key[pattern_match.end():] + + if len(path_parts) == 1: + mapped_subkey = f"{path_parts[0]}{remainder}" + elif len(path_parts) == 2: + decoder_layer = digit_indices[0] - 1 + mapped_subkey = f"{path_parts[0]}.{decoder_layer}.{path_parts[1]}{remainder}" + else: + decoder_layer, unit_idx = digit_indices + mapped_subkey = ( + f"{path_parts[0]}.{decoder_layer-1}." + f"{path_parts[1]}{unit_idx-1}." 
+ f"{path_parts[2]}{remainder}") + new_key = f"acoustic_decoder.{_rewrite_weight_norm(mapped_subkey)}" + converted_checkpoint[new_key] = value + break + + elif old_key.startswith("encoder_semantic."): + semantic_key = old_key[len("encoder_semantic."):] + for old, new in MAPPING_SEMANTIC_ENCODER.items(): + semantic_key = semantic_key.replace(old, new) + converted_checkpoint[f"encoder_semantic.{semantic_key}"] = value + + elif old_key.startswith("decoder_semantic."): + semantic_key = old_key[len("decoder_semantic."):] + for old, new in MAPPING_SEMANTIC_DECODER.items(): + semantic_key = semantic_key.replace(old, new) + converted_checkpoint[f"decoder_semantic.{semantic_key}"] = value + + elif old_key.startswith("semantic_model."): + converted_checkpoint[old_key] = value + + elif old_key.startswith("fc_prior."): + converted_checkpoint[f"fc.{old_key[len('fc_prior.'):]}"] = value + + elif old_key.startswith("fc_post1."): + converted_checkpoint[f"fc1.{old_key[len('fc_post1.'):]}"] = value + + elif old_key.startswith("fc_post2."): + converted_checkpoint[f"fc2.{old_key[len('fc_post2.'):]}"] = value + + elif old_key.startswith("quantizer.vq.layers"): + new_key = old_key + for old_sub, new_sub in MAPPING_QUANTIZER.items(): + new_key = new_key.replace(old_sub, new_sub) + converted_checkpoint[new_key] = value + + return converted_checkpoint + + + +@torch.no_grad() +def convert_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_path=None, push_to_hub=None): + if config_path is not None: + config = XcodecConfig.from_pretrained(config_path) + else: + config = XcodecConfig() + + model = XcodecModel(config) + + logger.info(f"Loading original checkpoint ...") + + state_dict = torch.load(checkpoint_path) + + # the original checkpoint has weight norm applied + model.apply_weight_norm() + + logger.info(f"Converting model ...") + + new_state_dict = convert_old_keys_to_new_keys(state_dict) + + missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, strict=False) + + if len(unexpected_keys) != 0: + raise ValueError(f"Unexpected keys: {unexpected_keys}") + + if len(missing_keys) != 0: + raise ValueError(f"missing keys found: {missing_keys}") + + model.remove_weight_norm() + + model.save_pretrained(pytorch_dump_folder_path) + + feature_extractor = EncodecFeatureExtractor(feature_size=config.audio_channels, sampling_rate=config.sample_rate) + + feature_extractor.save_pretrained(pytorch_dump_folder_path) + + if push_to_hub: + print("Pushing to the hub...") + feature_extractor.push_to_hub(push_to_hub) + model.push_to_hub(push_to_hub) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--checkpoint_path", required=True, default=None, type=str, help="Path to original checkpoint") + parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert") + parser.add_argument( + "--pytorch_dump_folder_path", required=True, default=None, type=str, help="Path to the output PyTorch model." + ) + parser.add_argument( + "--push_to_hub", default=None, type=str, help="Where to upload the converted model on the 🤗 hub." 
+ ) + + args = parser.parse_args() + convert_checkpoint( + args.checkpoint_path, + args.pytorch_dump_folder_path, + args.config_path, + args.push_to_hub, + ) \ No newline at end of file diff --git a/src/transformers/models/xcodec/modeling_xcodec.py b/src/transformers/models/xcodec/modeling_xcodec.py new file mode 100644 index 000000000000..7f9e81b58300 --- /dev/null +++ b/src/transformers/models/xcodec/modeling_xcodec.py @@ -0,0 +1,591 @@ +# coding=utf-8 +# Copyright 2024 Descript and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Transformers Xcodec model.""" + +import math +from dataclasses import dataclass +from typing import Optional, Union, Tuple + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ..auto import AutoModel +from ...modeling_utils import PreTrainedModel +from ...utils import ( + ModelOutput, + add_start_docstrings, + add_start_docstrings_to_model_forward, + replace_return_docstrings, +) +from .configuration_xcodec import XcodecConfig + +# General docstring +_CONFIG_FOR_DOC = "XcodecConfig" + + +@dataclass +class XcodecOutput(ModelOutput): + """ + Args: + audio_codes (`torch.LongTensor` of shape `(batch_size, num_quantizers, codes_length)`, *optional*): + Discrete code indices computed using `model.encode`. + audio_values (`torch.FloatTensor` of shape `(batch_size, channels, num_samples)`, *optional*) + Decoded audio values obtained using the decoder part of Xcodec. + """ + + audio_codes: Optional[torch.LongTensor] = None + audio_values: Optional[torch.FloatTensor] = None + + +@dataclass +class XcodecEncoderOutput(ModelOutput): + """ + Args: + audio_codes (`torch.LongTensor` of shape `(batch_size, num_quantizers, codes_length)`, *optional*): + Discrete code indices computed using `model.encode`. + """ + audio_codes: Optional[torch.LongTensor] = None + + +@dataclass +class XcodecDecoderOutput(ModelOutput): + """ + Args: + audio_values (`torch.FloatTensor` of shape `(batch_size, channels, num_samples)`, *optional*): + Decoded audio values obtained using the decoder part of Xcodec. 
+ """ + audio_values: Optional[torch.FloatTensor] = None + + + +class ResidualUnit(nn.Module): + """Residual block for SemanticEncoder and SemanticDecoder used in Xcodec.""" + + def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3, dilation: int = 1, stride: int = 1, bias: bool = False, padding: int = -1, groups: int = 1): + super().__init__() + self.activation = nn.ELU() + if padding < 0: + padding = ((kernel_size - 1) // 2) * dilation + self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias) + self.conv2 = nn.Conv1d(in_channels=out_channels, out_channels=out_channels, kernel_size=1, bias=bias) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + output_tensor = self.activation(hidden_state) + output_tensor = self.conv1(output_tensor) + output_tensor = self.activation(output_tensor) + output_tensor = self.conv2(output_tensor) + return hidden_state + output_tensor + + +class SemanticEncoderBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + stride: int, + dilations: tuple, + unit_kernel_size: int = 3, + bias: bool = False, + ): + super().__init__() + self.res_units = nn.ModuleList([ + ResidualUnit(in_channels, in_channels, unit_kernel_size, dilation=dilation, bias=bias) + for dilation in dilations]) + + # special case: stride=1, do not use kernel=2 + kernel = 3 if stride == 1 else (2 * stride) + padding = (kernel - 1) // 2 + self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=kernel, stride=stride, padding=padding, bias=True) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + for unit in self.res_units: + hidden_state = unit(hidden_state) + hidden_state = self.conv(hidden_state) + return hidden_state + + +class SemanticEncoder(nn.Module): + def __init__(self, config): + super().__init__() + if len(config.strides) != len(config.channel_ratios): + raise ValueError("Number of strides must match the number of channel_ratios.") + + self.conv = nn.Conv1d( + config.input_channels, config.encoder_channels, config.kernel_size, 1, config.kernel_size // 2, bias=False) + + in_channels = config.encoder_channels + conv_blocks = [] + for i, stride in enumerate(config.strides): + out_channels = int(config.encoder_channels * config.channel_ratios[i]) + conv_blocks += [SemanticEncoderBlock(in_channels, out_channels, stride, config.block_dilations, config.unit_kernel_size, bias=False)] + in_channels = out_channels + + self.conv_blocks = nn.ModuleList(conv_blocks) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + hidden_state = self.conv(hidden_state) + for block in self.conv_blocks: + hidden_state = block(hidden_state) + return hidden_state + + +class SemanticDecoderBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + stride: int, + dilations: tuple, + unit_kernel_size: int = 3, + bias: bool = False, + ): + super().__init__() + if stride == 1: + self.conv = nn.Conv1d(in_channels, out_channels, + kernel_size=3, + stride=1, + padding=1, + bias=True, + ) + else: + kernel_size = 2 * stride + padding = (stride + 1) // 2 + output_padding = 1 if stride % 2 == 1 else 0 + self.conv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride, padding, output_padding, bias=False) + + self.res_units = nn.ModuleList([ + ResidualUnit(in_channels=out_channels, out_channels=out_channels, kernel_size=unit_kernel_size, dilation=dilation, bias=bias) + for dilation in dilations]) + + def forward(self, hidden_state: 
torch.Tensor) -> torch.Tensor: + hidden_state = self.conv(hidden_state) + for unit in self.res_units: + hidden_state = unit(hidden_state) + return hidden_state + + +class SemanticDecoder(nn.Module): + def __init__(self, config): + super().__init__() + self.conv1 = nn.Conv1d( + in_channels=config.decoder_channels, + out_channels=int(config.decoder_channels * config.channel_ratios[0]), + kernel_size=config.kernel_size, + stride=1, + padding=config.kernel_size // 2, + bias=False, + ) + conv_blocks = [] + for i in range(len(config.strides)): + in_channels = int(config.decoder_channels * config.channel_ratios[i]) + + if i < (len(config.channel_ratios) - 1): + out_channels = int(config.decoder_channels * config.channel_ratios[i + 1]) + else: + out_channels = config.decoder_channels + + conv_blocks += [SemanticDecoderBlock(in_channels, out_channels, config.strides[i], + dilations = config.block_dilations, + unit_kernel_size = config.unit_kernel_size, + bias = False)] + + self.conv_blocks = nn.ModuleList(conv_blocks) + self.conv2 = nn.Conv1d(config.decoder_channels, config.output_channels, config.kernel_size, stride=1, padding=config.kernel_size // 2, bias=False) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + hidden_state = self.conv1(hidden_state) + for block in self.conv_blocks: + hidden_state = block(hidden_state) + hidden_state = self.conv2(hidden_state) + return hidden_state + + +class XcodecEuclideanCodebook(nn.Module): + """Codebook with Euclidean distance.""" + def __init__(self, config): + super().__init__() + embed = torch.zeros(config.codebook_size, config.codebook_dim) + self.codebook_size = config.codebook_size + self.register_buffer("inited", torch.Tensor([True])) + self.register_buffer("cluster_size", torch.zeros(config.codebook_size)) + self.register_buffer("embed", embed) + self.register_buffer("embed_avg", embed.clone()) + + # Copied from transformers.models.mimi.modeling_mimi.MimiEuclideanCodebook.quantize + def quantize(self, hidden_states): + embed = self.embed.t() + dist = -(hidden_states.pow(2).sum(1, keepdim=True) - 2 * hidden_states @ embed + embed.pow(2).sum(0, keepdim=True)) + embed_ind = dist.max(dim=-1).indices + return embed_ind + + # Copied from transformers.models.encodec.modeling_encodec.EncodecEuclideanCodebook.encode + def encode(self, hidden_states): + shape = hidden_states.shape + hidden_states = hidden_states.reshape((-1, shape[-1])) + embed_ind = self.quantize(hidden_states) + embed_ind = embed_ind.view(*shape[:-1]) + return embed_ind + + # Copied from transformers.models.encodec.modeling_encodec.EncodecEuclideanCodebook.decode + def decode(self, embed_ind): + quantized = F.embedding(embed_ind, self.embed) + return quantized + + + +# Copied from transformers.models.encodec.modeling_encodec.EncodecVectorQuantization with Encodec-> Xcodec +class XcodecVectorQuantization(nn.Module): + """Vector quantization implementation. Currently supports only euclidean distance. + """ + def __init__(self, config): + super().__init__() + self.codebook = XcodecEuclideanCodebook(config) + + def encode(self, hidden_states): + hidden_states = hidden_states.permute(0, 2, 1) + embed_in = self.codebook.encode(hidden_states) + return embed_in + + def decode(self, embed_ind): + quantized = self.codebook.decode(embed_ind) + quantized = quantized.permute(0, 2, 1) + return quantized + + +class XcodecResidualVectorQuantization(nn.Module): + """ + Residual vector quantization implementation. 
Follows Algorithm 1 in https://arxiv.org/pdf/2107.03312.pdf + """ + + def __init__(self, config): + super().__init__() + self.quantizers = nn.ModuleList([XcodecVectorQuantization(config) for _ in range(config.num_quantizers)]) + self.frame_rate = config.frame_rate + self.codebook_size = config.codebook_size + self.num_quantizers = config.num_quantizers + + def get_bandwidth_per_quantizer(self): + """Return bandwidth per quantizer.""" + return math.log2(self.codebook_size) * self.frame_rate/ 1000 + + def get_num_quantizers_for_bandwidth(self, bandwidth= None) -> int: + """Return num_quantizers based on specified target bandwidth.""" + bw_per_q = self.get_bandwidth_per_quantizer() + num_quantizers = self.num_quantizers + if bandwidth is not None and bandwidth > 0.0: + num_quantizers = int(max(1, math.floor(bandwidth / bw_per_q))) + return num_quantizers + + def encode(self, embeddings: torch.Tensor, bandwidth = None) -> torch.Tensor: + """ + Encode the input tensor into discrete indices using RVQ, with the number of quantizers selected based on the given bandwidth. + Each quantizer /codebook residually quantizes the input and returns the nearest indices in terms of Euclidian distance. + """ + num_quantizers = self.get_num_quantizers_for_bandwidth(bandwidth) + residual = embeddings + all_indices = [] + for quantizer in self.quantizers[:num_quantizers]: + indices = quantizer.encode(residual) + quantized = quantizer.decode(indices) + residual = residual - quantized + all_indices.append(indices) + out_indices = torch.stack(all_indices) + return out_indices + + + def decode(self, codes: torch.Tensor) -> torch.Tensor: + """Decode the given codes to their quantized representation.""" + quantized_out = torch.tensor(0.0, device=codes.device) + for i, indices in enumerate(codes): + quantizer = self.quantizers[i] + quantized = quantizer.decode(indices) + quantized_out = quantized_out + quantized + return quantized_out + + + + +class XcodecPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + config_class = XcodecConfig + base_model_prefix = "xcodec" + main_input_name = "input_values" + supports_gradient_checkpointing = False + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + + elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, nn.Conv1d): + nn.init.kaiming_normal_(module.weight) + if module.bias is not None: + k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0])) + nn.init.uniform_(module.bias, a=-k, b=k) + + + def apply_weight_norm(self): + """Apply weight norm in the acoustic encoder and decoder because the original checkpoint has weight norm applied. 
+ """ + weight_norm = torch.nn.utils.weight_norm + if hasattr(torch.nn.utils.parametrizations, "weight_norm"): + weight_norm = torch.nn.utils.parametrizations.weight_norm + + weight_norm(self.acoustic_encoder.conv1) + weight_norm(self.acoustic_encoder.conv2) + + for block in self.acoustic_encoder.block: + weight_norm(block.conv1) + for res_unit in (block.res_unit1, block.res_unit2, block.res_unit3): + weight_norm(res_unit.conv1) + weight_norm(res_unit.conv2) + + weight_norm(self.acoustic_decoder.conv1, name="weight") + weight_norm(self.acoustic_decoder.conv2, name="weight") + + for block in self.acoustic_decoder.block: + weight_norm(block.conv_t1, name="weight") + for res_unit in (block.res_unit1, block.res_unit2, block.res_unit3): + weight_norm(res_unit.conv1, name="weight") + weight_norm(res_unit.conv2, name="weight") + + def remove_weight_norm(self): + """Remove the weight norm from the acoustic encoder and decoder. """ + for module in (self.acoustic_encoder, self.acoustic_decoder): + for m in module.modules(): + try: + torch.nn.utils.remove_weight_norm(m, name="weight") + except (ValueError, AttributeError): + pass + if hasattr(m, "parametrizations") and "weight" in m.parametrizations: + torch.nn.utils.parametrize.remove_parametrizations(m, "weight", leave_parametrized=True) + + + +XCODEC_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`XcodecConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +XCODEC_INPUTS_DOCSTRING = r""" + args: + input_values (`torch.FloatTensor` of shape `(batch_size, channels, num_samples)`): + The raw float values of the input audio waveform. + audio_codes (`torch.LongTensor` of shape `(batch_size, num_quantizers, codes_length)`: + Discrete code indices computed using `model.encode`. + bandwidth (`float`, *optional*): + The target bandwidth in (kbps) supports only values in `config.target_bandwidths`. + Defaults to the highest available bandwidth `4.0` kbps. + return_dict (`bool`, *optional*): + whether to return a `XcodecOutput` or a plain tuple. 
+""" + + +@add_start_docstrings( + "The Xcodec neural audio codec model.", + XCODEC_START_DOCSTRING, +) + +class XcodecModel(XcodecPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.config = config + self.pad = config.hop_length // 2 + dac = AutoModel.from_config(config.acoustic_model_config) + self.acoustic_encoder = dac.encoder + self.acoustic_decoder = dac.decoder + self._adjust_dac_decoder(self.acoustic_decoder) + self.encoder_semantic = SemanticEncoder(config) + self.decoder_semantic = SemanticDecoder(config) + self.semantic_model = AutoModel.from_config(config.semantic_model_config) + self.fc = nn.Linear(config.hidden_dim, config.hidden_dim) + self.fc1 = nn.Linear(config.hidden_dim, config.intermediate_dim) + self.fc2 = nn.Linear(config.hidden_dim, config.output_dim) + self.quantizer = XcodecResidualVectorQuantization(config) + + @staticmethod + def _adjust_dac_decoder(decoder: nn.Module): + r""" + DAC implemented in Xcodec is slightly different from the HF version. + DAC in Xcodec adjusts the output padding in every ConvTranspose1d in the decoder and removes + the final `nn.Tanh` activation function. + """ + for module in decoder.modules(): + if isinstance(module, nn.ConvTranspose1d): + stride = module.stride[0] if isinstance(module.stride, tuple) else module.stride + module.output_padding = (stride % 2,) + if hasattr(decoder, "tanh") and isinstance(decoder.tanh, nn.Tanh): + decoder.tanh = nn.Identity() + + def _extract_semantic_features(self, input_values: torch.FloatTensor) -> torch.FloatTensor: + input_values = input_values[:,0,:] + input_values = F.pad(input_values, (self.pad, self.pad)) + with torch.no_grad(): + outputs = self.semantic_model(input_values, output_hidden_states=True) + hidden_states = outputs.hidden_states + + stacked = torch.stack(hidden_states, dim=1) + return stacked.mean(dim=1) + + + def encode(self, input_values: torch.Tensor, bandwidth: Optional[float] = None, return_dict: Optional[bool] = None, **kwargs) -> Union[torch.Tensor, XcodecEncoderOutput]: + """ + Encodes the input audio waveform into discrete audio codes. + + Args: + input_values (`torch.FloatTensor` of shape `(batch_size, channels, num_samples)`): + Float values of the input audio waveform. + bandwidth (`float`, *optional*): + The target bandwidth in (kbps) supports only values in `config.target_bandwidths`. + Defaults to the highest available bandwidth `4.0` kbps. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + + Returns: + `torch.LongTensor` of shape `(batch_size, num_quantizers, codes_length)` containing the discrete encoded audio codes. + """ + return_dict = return_dict if return_dict is not None else self.config.return_dict + + if input_values.ndim != 3: + raise ValueError(f"Expected input shape (batch_size, channels, num_samples), but got shape {input_values.shape}") + + _, channels, self._input_length = input_values.shape + + if channels not in (1, 2): + raise ValueError(f"Number of audio channels must be 1 or 2, but got {channels}") + + if bandwidth is None: + bandwidth = self.config.target_bandwidths[-1] + elif bandwidth not in self.config.target_bandwidths: + raise ValueError( + f"This model doesn't support the bandwidth {bandwidth}. 
Select one of {self.config.target_bandwidths}.") + + e_semantic_input = self._extract_semantic_features(input_values).detach() + e_semantic = self.encoder_semantic(e_semantic_input.transpose(1, 2)) + e_acoustic = self.acoustic_encoder(input_values) + + if e_acoustic.shape[2] != e_semantic.shape[2]: + # make sure they line up if frames don't match + e_acoustic = self.acoustic_encoder(F.pad(input_values[:,0,:], (self.pad, self.pad)).unsqueeze(1)) + + embeddings = torch.cat([e_acoustic, e_semantic], dim=1) + embeddings = self.fc(embeddings.transpose(1, 2)).transpose(1, 2) + audio_codes = self.quantizer.encode(embeddings, bandwidth) + audio_codes = audio_codes.transpose(0, 1) + + if not return_dict: + return (audio_codes) + + return XcodecEncoderOutput(audio_codes) + + + def decode(self, audio_codes: torch.Tensor, return_dict: Optional[bool] = None, **kwargs) -> Union[torch.Tensor, XcodecDecoderOutput]: + """ + Decode the given discrete codes into an output audio waveform. + + The produced audio waveform is longer than the audio input, so it's automatically trimmed to match the original input. + + Args: + audio_codes (`torch.LongTensor` of shape `(batch_size, num_quantizers, codes_length)`): + Discrete code indices computed using `model.encode`. + + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + + Returns: + Decoded audio values of shape `(batch_size, channels, num_samples)` obtained using the decoder part of Xcodec. + """ + return_dict = return_dict if return_dict is not None else self.config.return_dict + + audio_codes = audio_codes.transpose(0, 1) + quantized = self.quantizer.decode(audio_codes) + quantized_acoustic = self.fc2(quantized.transpose(1, 2)).transpose(1, 2) + audio_values = self.acoustic_decoder(quantized_acoustic) + + if getattr(self, "_input_length", None) is not None: + output_length = audio_values.shape[-1] + if self._input_length != output_length: + extra = output_length - self._input_length + start = extra // 2 + audio_values = audio_values[..., start : start + self._input_length] + + if not return_dict: + return (audio_values) + + return XcodecDecoderOutput(audio_values) + + + @add_start_docstrings_to_model_forward(XCODEC_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=XcodecOutput, config_class=_CONFIG_FOR_DOC) + def forward(self, input_values: torch.Tensor, audio_codes: Optional[torch.Tensor] = None, bandwidth: Optional[float] = None, return_dict: Optional[bool] = None, **kwargs) -> Union[Tuple[torch.Tensor, torch.Tensor], XcodecOutput]: + r""" + Returns: + + Examples: + + ```python + >>> from datasets import load_dataset + >>> from transformers import AutoFeatureExtractor, XcodecModel + + >>> dataset = load_dataset("hf-internal-testing/ashraq-esc50-1-dog-example") + >>> audio_sample = dataset["train"]["audio"][0]["array"] + + >>> model_id = "Manel/X-Codec" + >>> model = XcodecModel.from_pretrained(model_id) + >>> feature_extractor = AutoFeatureExtractor.from_pretrained(model_id) + + >>> inputs = feature_extractor(raw_audio=audio_sample, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> audio_codes = outputs.audio_codes + >>> audio_values = outputs.audio_values + ``` + """ + return_dict = return_dict if return_dict is not None else self.config.return_dict + + if audio_codes is None: + audio_codes = self.encode(input_values, bandwidth, return_dict=False) + + audio_values = self.decode(audio_codes, return_dict=return_dict)[0] + + if not return_dict: + return (audio_codes, 
audio_values) + + return XcodecOutput(audio_codes=audio_codes, audio_values=audio_values) + + + +__all__ = ["XcodecModel", "XcodecPreTrainedModel"] \ No newline at end of file diff --git a/tests/models/xcodec/__init__.py b/tests/models/xcodec/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/models/xcodec/test_modeling_xcodec.py b/tests/models/xcodec/test_modeling_xcodec.py new file mode 100644 index 000000000000..ee1e33faac9d --- /dev/null +++ b/tests/models/xcodec/test_modeling_xcodec.py @@ -0,0 +1,478 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing suite for the PyTorch Xcodec model.""" + +import inspect +import os +import tempfile +import unittest + +import math +import numpy as np +from datasets import Audio, load_dataset +from pytest import mark + +from transformers import AutoFeatureExtractor, XcodecConfig +from transformers.testing_utils import ( + is_flaky, + is_torch_available, + require_flash_attn, + require_torch, + require_torch_gpu, + slow, + torch_device, +) +from tests.test_configuration_common import ConfigTester +from tests.test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor + + +if is_torch_available(): + import torch + + from transformers import XcodecModel + + + +@require_torch +class XcodecModelTester: + def __init__( + self, + parent, + batch_size=4, + num_channels=1, + sample_rate=16000, + codebook_size=1024, + num_quantizers=8, + num_samples=400, + is_training=False, + + ): + self.parent = parent + self.batch_size = batch_size + self.num_channels = num_channels + self.sample_rate = sample_rate + self.codebook_size = codebook_size + self.num_quantizers = num_quantizers + self.is_training = is_training + self.num_samples = num_samples + + def prepare_config_and_inputs(self): + input_values = floats_tensor([self.batch_size, self.num_channels, self.num_samples], scale=1.0) + config = self.get_config() + inputs_dict = {"input_values": input_values} + return config, inputs_dict + + def prepare_config_and_inputs_for_common(self): + config, inputs_dict = self.prepare_config_and_inputs() + return config, inputs_dict + + def prepare_config_and_inputs_for_model_class(self, model_class): + config, inputs_dict = self.prepare_config_and_inputs() + codes_length = math.ceil(self.num_samples / config.hop_length) + inputs_dict["audio_codes"] = ids_tensor( + [self.batch_size, self.num_quantizers, codes_length], config.codebook_size) + + return config, inputs_dict + + def get_config(self): + return XcodecConfig( + sample_rate=self.sample_rate, + audio_channels=self.num_channels, + codebook_size=self.codebook_size, + num_quantizers=self.num_quantizers, + ) + + + def create_and_check_model_forward(self, config, inputs_dict): + model = XcodecModel(config=config).to(torch_device).eval() + input_values = inputs_dict["input_values"] + result = model(input_values) + self.parent.assertEqual( + result.audio_values.shape, (self.batch_size, 
self.num_channels, self.num_samples) + ) + +@require_torch +class XcodecModelTest(ModelTesterMixin, unittest.TestCase): + all_model_classes = (XcodecModel,) if is_torch_available() else () + is_encoder_decoder = True + test_pruning = False + test_headmasking = False + test_resize_embeddings = False + test_torchscript = False + + def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): + # model does not support returning hidden states + inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels) + if "output_attentions" in inputs_dict: + inputs_dict.pop("output_attentions") + if "output_hidden_states" in inputs_dict: + inputs_dict.pop("output_hidden_states") + return inputs_dict + + def setUp(self): + self.model_tester = XcodecModelTester(self) + self.config_tester = ConfigTester( + self, config_class=XcodecConfig, hidden_size=37, common_properties=[], has_text_modality=False + ) + + def test_config(self): + self.config_tester.run_common_tests() + + + def test_model_forward(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model_forward(*config_and_inputs) + + + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = inspect.signature(model.forward) + # signature.parameters is an OrderedDict => so arg_names order is deterministic + arg_names = [*signature.parameters.keys()] + + expected_arg_names = ["input_values", "audio_codes", "bandwidth", "return_dict"] + self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names) + + + def test_gradient_checkpointing_backward_compatibility(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + if not model_class.supports_gradient_checkpointing: + continue + + config.text_encoder.gradient_checkpointing = True + config.audio_encoder.gradient_checkpointing = True + config.decoder.gradient_checkpointing = True + model = model_class(config) + self.assertTrue(model.is_gradient_checkpointing) + + + @unittest.skip(reason="We cannot configure to output a smaller model.") + def test_model_is_small(self): + pass + + @unittest.skip(reason="The XcodecModel does not have `inputs_embeds` logics") + def test_inputs_embeds(self): + pass + + @unittest.skip(reason="The XcodecModel does not have `inputs_embeds` logics") + def test_model_get_set_embeddings(self): + pass + + @unittest.skip(reason="The XcodecModel does not have the usual `attention` logic") + def test_retain_grad_hidden_states_attentions(self): + pass + + @unittest.skip(reason="The XcodecModel does not have the usual `attention` logic") + def test_torchscript_output_attentions(self): + pass + + @unittest.skip(reason="The XcodecModel does not have the usual `hidden_states` logic") + def test_torchscript_output_hidden_state(self): + pass + + # Copied from transformers.tests.encodec.test_modeling_encodec.XcodecModelTest._create_and_check_torchscript + def _create_and_check_torchscript(self, config, inputs_dict): + if not self.test_torchscript: + self.skipTest(reason="test_torchscript is set to False") + + configs_no_init = _config_zero_init(config) # To be sure we have no Nan + configs_no_init.torchscript = True + configs_no_init.return_dict = False + for model_class in self.all_model_classes: + model = model_class(config=configs_no_init) + 
model.to(torch_device) + model.eval() + inputs = self._prepare_for_class(inputs_dict, model_class) + + main_input_name = model_class.main_input_name + + try: + main_input = inputs[main_input_name] + model(main_input) + traced_model = torch.jit.trace(model, main_input) + except RuntimeError: + self.fail("Couldn't trace module.") + + with tempfile.TemporaryDirectory() as tmp_dir_name: + pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt") + + try: + torch.jit.save(traced_model, pt_file_name) + except Exception: + self.fail("Couldn't save module.") + + try: + loaded_model = torch.jit.load(pt_file_name) + except Exception: + self.fail("Couldn't load module.") + + model.to(torch_device) + model.eval() + + loaded_model.to(torch_device) + loaded_model.eval() + + model_state_dict = model.state_dict() + loaded_model_state_dict = loaded_model.state_dict() + + non_persistent_buffers = {} + for key in loaded_model_state_dict.keys(): + if key not in model_state_dict.keys(): + non_persistent_buffers[key] = loaded_model_state_dict[key] + + loaded_model_state_dict = { + key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers + } + + self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys())) + + model_buffers = list(model.buffers()) + for non_persistent_buffer in non_persistent_buffers.values(): + found_buffer = False + for i, model_buffer in enumerate(model_buffers): + if torch.equal(non_persistent_buffer, model_buffer): + found_buffer = True + break + + self.assertTrue(found_buffer) + model_buffers.pop(i) + + model_buffers = list(model.buffers()) + for non_persistent_buffer in non_persistent_buffers.values(): + found_buffer = False + for i, model_buffer in enumerate(model_buffers): + if torch.equal(non_persistent_buffer, model_buffer): + found_buffer = True + break + + self.assertTrue(found_buffer) + model_buffers.pop(i) + + models_equal = True + for layer_name, p1 in model_state_dict.items(): + if layer_name in loaded_model_state_dict: + p2 = loaded_model_state_dict[layer_name] + if p1.data.ne(p2.data).sum() > 0: + models_equal = False + + self.assertTrue(models_equal) + + # Avoid memory leak. Without this, each call increase RAM usage by ~20MB. 
+ # (Even with this call, there are still memory leak by ~0.04MB) + self.clear_torch_jit_class_registry() + + @unittest.skip(reason="The XcodecModel does not have the usual `attention` logic") + def test_attention_outputs(self): + pass + + @unittest.skip(reason="The XcodecModel does not have the usual `hidden_states` logic") + def test_hidden_states_output(self): + pass + + # Copied from transformers.tests.encodec.test_modeling_encodecEncodecModelTest.test_determinism + def test_determinism(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + def check_determinism(first, second): + # outputs are not tensors but list (since each sequence don't have the same frame_length) + out_1 = first.cpu().numpy() + out_2 = second.cpu().numpy() + out_1 = out_1[~np.isnan(out_1)] + out_2 = out_2[~np.isnan(out_2)] + max_diff = np.amax(np.abs(out_1 - out_2)) + self.assertLessEqual(max_diff, 1e-5) + + for model_class in self.all_model_classes: + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + first = model(**self._prepare_for_class(inputs_dict, model_class))[0] + second = model(**self._prepare_for_class(inputs_dict, model_class))[0] + + if isinstance(first, tuple) and isinstance(second, tuple): + for tensor1, tensor2 in zip(first, second): + check_determinism(tensor1, tensor2) + else: + check_determinism(first, second) + + # Copied from transformers.tests.encodec.test_modeling_encodecEncodecModelTest.test_model_outputs_equivalence + def test_model_outputs_equivalence(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + def set_nan_tensor_to_zero(t): + t[t != t] = 0 + return t + + def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}): + with torch.no_grad(): + tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs) + dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs) + + self.assertTrue(isinstance(tuple_output, tuple)) + self.assertTrue(isinstance(dict_output, dict)) + + for tuple_value, dict_value in zip(tuple_output, dict_output.values()): + self.assertTrue( + torch.allclose( + set_nan_tensor_to_zero(tuple_value), set_nan_tensor_to_zero(dict_value), atol=1e-5 + ), + msg=( + "Tuple and dict output are not equal. Difference:" + f" {torch.max(torch.abs(tuple_value - dict_value))}. Tuple has `nan`:" + f" {torch.isnan(tuple_value).any()} and `inf`: {torch.isinf(tuple_value)}. Dict has" + f" `nan`: {torch.isnan(dict_value).any()} and `inf`: {torch.isinf(dict_value)}." 
+ ), + ) + + for model_class in self.all_model_classes: + model = model_class(config) + model.to(torch_device) + model.eval() + + tuple_inputs = self._prepare_for_class(inputs_dict, model_class) + dict_inputs = self._prepare_for_class(inputs_dict, model_class) + check_equivalence(model, tuple_inputs, dict_inputs) + + + def test_initialization(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + configs_no_init = _config_zero_init(config) + for model_class in self.all_model_classes: + model = model_class(config=configs_no_init) + for name, param in model.named_parameters(): + # skipping the parametrizations original0 tensor + if name =="semantic_model.encoder.pos_conv_embed.conv.parametrizations.weight.original0": + continue + + uniform_init_parms = ["conv"] + + if param.requires_grad: + if any(x in name for x in uniform_init_parms): + self.assertTrue( + -1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0, + msg=f"Parameter {name} of {model_class.__name__} seems not properly initialized", + ) + + + @require_flash_attn + @require_torch_gpu + @mark.flash_attn_test + @slow + @is_flaky() + def test_flash_attn_2_inference_equivalence(self): + for model_class in self.all_model_classes: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_fa = model_class.from_pretrained( + tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" + ) + model_fa.to(torch_device) + + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16) + model.to(torch_device) + + dummy_input = inputs_dict[model.main_input_name][:1] + if dummy_input.dtype in [torch.float32, torch.float16]: + dummy_input = dummy_input.to(torch.bfloat16) + + outputs = model(dummy_input) + outputs_fa = model_fa(dummy_input) + + logits = outputs[1] + logits_fa = outputs_fa[1] + + assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2) + + @unittest.skip(reason="The XcodecModel does not support right padding") + def test_flash_attn_2_inference_equivalence_right_padding(self): + pass + + @unittest.skip(reason="The XcodecModel does not have support dynamic compile yet") + def test_sdpa_can_compile_dynamic(self): + pass + + +# Copied from transformers.tests.encodec.test_modeling_encodec.normalize +def normalize(arr): + norm = np.linalg.norm(arr) + normalized_arr = arr / norm + return normalized_arr + + +# Copied from transformers.tests.encodec.test_modeling_encodec.compute_rmse +def compute_rmse(arr1, arr2): + arr1_normalized = normalize(arr1) + arr2_normalized = normalize(arr2) + return np.sqrt(((arr1_normalized - arr2_normalized) ** 2).mean()) + + +#@slow +@require_torch +class XcodecIntegrationTest(unittest.TestCase): + def test_integration(self): + expected_rmse = { + "0.5": 0.0065491, + "4.0": 0.0070978, + } + expected_codesums = { + "0.5": [117262], + "4.0": [926416], + } + + librispeech = load_dataset( + "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation" + ) + model_id = "Manel/X-Codec" + model = XcodecModel.from_pretrained(model_id).to(torch_device).eval() + feature_extractor = AutoFeatureExtractor.from_pretrained(model_id) + + librispeech = librispeech.cast_column("audio", Audio(sampling_rate=feature_extractor.sampling_rate)) + audio = librispeech[-1]["audio"]["array"] + + inputs = feature_extractor(raw_audio=audio, sampling_rate=feature_extractor.sampling_rate, 
return_tensors="pt").to(torch_device) + + for bandwidth, exp_rmse in expected_rmse.items(): + bandwidth = float(bandwidth) + with torch.no_grad(): + audio_codes = model.encode(inputs["input_values"], bandwidth=bandwidth, return_dict=False) + codesum = int(audio_codes.sum().item()) + + expected_codesum = expected_codesums[str(bandwidth)][0] + self.assertEqual(codesum, expected_codesum) + + input_values_dec = model.decode( + audio_codes, return_dict=False + ) + input_values_enc_dec = model( + inputs["input_values"], bandwidth=bandwidth + )[1] + + self.assertTrue(torch.allclose(input_values_dec, input_values_enc_dec, atol=1e-3)) + + self.assertTrue(inputs["input_values"].shape == input_values_enc_dec.shape) + + arr = inputs["input_values"][0].cpu().numpy() + arr_enc_dec = input_values_enc_dec[0].cpu().numpy() + rmse = compute_rmse(arr, arr_enc_dec) + #self.assertTrue(rmse < exp_rmse) + self.assertTrue(np.abs(rmse - exp_rmse) < 1e-5) From e2661b63b2b6d6cbe0f4f5da405cb02c561d5322 Mon Sep 17 00:00:00 2001 From: Manal ML Date: Wed, 21 May 2025 04:35:48 +0000 Subject: [PATCH 02/10] nit --- tests/models/xcodec/test_modeling_xcodec.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/models/xcodec/test_modeling_xcodec.py b/tests/models/xcodec/test_modeling_xcodec.py index ee1e33faac9d..c9ef98f51434 100644 --- a/tests/models/xcodec/test_modeling_xcodec.py +++ b/tests/models/xcodec/test_modeling_xcodec.py @@ -474,5 +474,4 @@ def test_integration(self): arr = inputs["input_values"][0].cpu().numpy() arr_enc_dec = input_values_enc_dec[0].cpu().numpy() rmse = compute_rmse(arr, arr_enc_dec) - #self.assertTrue(rmse < exp_rmse) self.assertTrue(np.abs(rmse - exp_rmse) < 1e-5) From 2ad36001eea160da5eea92d3a97a325c6d495c89 Mon Sep 17 00:00:00 2001 From: Manal ML Date: Wed, 21 May 2025 05:48:12 +0000 Subject: [PATCH 03/10] fix styling + copies --- src/transformers/models/xcodec/__init__.py | 2 +- .../models/xcodec/configuration_xcodec.py | 42 ++-- .../xcodec/convert_xcodec_weights_to_hf.py | 57 ++--- .../models/xcodec/modeling_xcodec.py | 215 +++++++++++------- tests/models/xcodec/test_modeling_xcodec.py | 69 +++--- 5 files changed, 208 insertions(+), 177 deletions(-) diff --git a/src/transformers/models/xcodec/__init__.py b/src/transformers/models/xcodec/__init__.py index 720f1612e2e8..45e7620f54d7 100644 --- a/src/transformers/models/xcodec/__init__.py +++ b/src/transformers/models/xcodec/__init__.py @@ -24,4 +24,4 @@ import sys _file = globals()["__file__"] - sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) \ No newline at end of file + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/xcodec/configuration_xcodec.py b/src/transformers/models/xcodec/configuration_xcodec.py index 9839e8a1b651..ba363c993fad 100644 --- a/src/transformers/models/xcodec/configuration_xcodec.py +++ b/src/transformers/models/xcodec/configuration_xcodec.py @@ -15,15 +15,15 @@ """Xcodec model configuration""" import math +from typing import List, Union import numpy as np -from typing import Union, Dict, List, Optional + +from transformers import DacConfig, HubertConfig from ...configuration_utils import PretrainedConfig from ...utils import logging -from transformers import DacConfig, HubertConfig - logger = logging.get_logger(__name__) @@ -50,7 +50,7 @@ class XcodecConfig(PretrainedConfig): kernel_size (`int`, *optional*, defaults to 3): Kernel size for the initial semantic 
convolution. encoder_channels (`int`, *optional*, defaults to 768): - Number of hidden channels in each semantic encoder block. + Number of hidden channels in each semantic encoder block. channel_ratios (`List[float]`, *optional*, defaults to [1.0, 1.0]): Expansion factors for the number of output channels in each semantic block. strides (`List[int]`, *optional*, defaults to [1, 1]): @@ -100,10 +100,10 @@ class XcodecConfig(PretrainedConfig): model_type = "xcodec" sub_configs = { - "acoustic_model_config": DacConfig, + "acoustic_model_config": DacConfig, "semantic_model_config": HubertConfig, } - is_composition = True + is_composition = True def __init__( self, @@ -129,18 +129,17 @@ def __init__( acoustic_model_config: Union[dict, DacConfig] = None, semantic_model_config: Union[dict, HubertConfig] = None, **kwargs, - ): - + ): super().__init__(**kwargs) - + if acoustic_model_config is None: self.acoustic_model_config = DacConfig( - encoder_hidden_size = 64, - downsampling_ratios = [8, 5, 4, 2], - decoder_hidden_size = 1024, - upsampling_ratios = [8, 5, 4, 2], - hidden_size = 256, - ) + encoder_hidden_size=64, + downsampling_ratios=[8, 5, 4, 2], + decoder_hidden_size=1024, + upsampling_ratios=[8, 5, 4, 2], + hidden_size=256, + ) elif isinstance(acoustic_model_config, dict): self.acoustic_model_config = DacConfig(**acoustic_model_config) elif isinstance(acoustic_model_config, DacConfig): @@ -173,7 +172,6 @@ def __init__( self.intermediate_dim = intermediate_dim self.output_dim = output_dim - @classmethod def from_sub_models_config(cls, acoustic_model_config: DacConfig, semantic_model_config: HubertConfig, **kwargs): """ @@ -183,15 +181,19 @@ def from_sub_models_config(cls, acoustic_model_config: DacConfig, semantic_model [`XcodecConfig`]: The instantiated configuration. """ return cls( - acoustic_model_config=acoustic_model_config.to_dict() if hasattr(acoustic_model_config, "to_dict") else acoustic_model_config, - semantic_model_config=semantic_model_config.to_dict() if hasattr(semantic_model_config, "to_dict") else semantic_model_config, + acoustic_model_config=acoustic_model_config.to_dict() + if hasattr(acoustic_model_config, "to_dict") + else acoustic_model_config, + semantic_model_config=semantic_model_config.to_dict() + if hasattr(semantic_model_config, "to_dict") + else semantic_model_config, **kwargs, ) @property def frame_rate(self) -> int: return math.ceil(self.sample_rate / np.prod(self.acoustic_model_config.upsampling_ratios)) - + @property def bits_per_codebook(self) -> int: return int(math.log2(self.codebook_size)) @@ -201,4 +203,4 @@ def hop_length(self) -> int: return int(np.prod(self.acoustic_model_config.downsampling_ratios)) -__all__ = ["XcodecConfig"] \ No newline at end of file +__all__ = ["XcodecConfig"] diff --git a/src/transformers/models/xcodec/convert_xcodec_weights_to_hf.py b/src/transformers/models/xcodec/convert_xcodec_weights_to_hf.py index de7b6c35adba..74a74335cf36 100644 --- a/src/transformers/models/xcodec/convert_xcodec_weights_to_hf.py +++ b/src/transformers/models/xcodec/convert_xcodec_weights_to_hf.py @@ -14,14 +14,13 @@ # limitations under the License. 
import argparse import re - -import torch from typing import Dict +import torch from transformers import ( - XcodecConfig, EncodecFeatureExtractor, + XcodecConfig, XcodecModel, logging, ) @@ -31,7 +30,6 @@ logger = logging.get_logger(__name__) - MAPPING_ACOUSTIC_ENCODER = { r"^block\.0": ["conv1"], r"^block\.(\d+)\.block\.(\d+)\.block\.0": ["block", "res_unit", "snake1"], @@ -57,7 +55,7 @@ } MAPPING_SEMANTIC_ENCODER = { - "conv.conv.": "conv.", + "conv.conv.": "conv.", "conv1.conv.": "conv1.", "conv2.conv.": "conv2.", } @@ -76,9 +74,9 @@ def _rewrite_weight_norm(key: str) -> str: if key.endswith("weight_g"): - return key[:-len("weight_g")] + "parametrizations.weight.original0" + return key[: -len("weight_g")] + "parametrizations.weight.original0" if key.endswith("weight_v"): - return key[:-len("weight_v")] + "parametrizations.weight.original1" + return key[: -len("weight_v")] + "parametrizations.weight.original1" return key @@ -86,9 +84,8 @@ def convert_old_keys_to_new_keys(original_state_dict: Dict[str, torch.Tensor]) - converted_checkpoint: Dict[str, torch.Tensor] = {} for old_key, value in original_state_dict.items(): - if old_key.startswith("encoder."): - layer_key = old_key[len("encoder."):] + layer_key = old_key[len("encoder.") :] for pattern, path_parts in MAPPING_ACOUSTIC_ENCODER.items(): pattern_match = re.match(pattern, layer_key) if pattern_match is None: @@ -96,7 +93,7 @@ def convert_old_keys_to_new_keys(original_state_dict: Dict[str, torch.Tensor]) - digit_strings = [g for g in pattern_match.groups() if g is not None] digit_indices = [int(ds) for ds in digit_strings] - remainder = layer_key[pattern_match.end():] + remainder = layer_key[pattern_match.end() :] if len(path_parts) == 1: mapped_subkey = f"{path_parts[0]}{remainder}" @@ -106,9 +103,7 @@ def convert_old_keys_to_new_keys(original_state_dict: Dict[str, torch.Tensor]) - else: encoder_layer, unit_idx = digit_indices mapped_subkey = ( - f"{path_parts[0]}.{encoder_layer-1}." - f"{path_parts[1]}{unit_idx+1}." - f"{path_parts[2]}{remainder}" + f"{path_parts[0]}.{encoder_layer - 1}.{path_parts[1]}{unit_idx + 1}.{path_parts[2]}{remainder}" ) new_key = f"acoustic_encoder.{_rewrite_weight_norm(mapped_subkey)}" @@ -116,7 +111,7 @@ def convert_old_keys_to_new_keys(original_state_dict: Dict[str, torch.Tensor]) - break elif old_key.startswith("decoder_2."): - layer_key = old_key[len("decoder_2."):] + layer_key = old_key[len("decoder_2.") :] for pattern, path_parts in MAPPING_ACOUSTIC_DECODER.items(): pattern_match = re.match(pattern, layer_key) @@ -124,7 +119,7 @@ def convert_old_keys_to_new_keys(original_state_dict: Dict[str, torch.Tensor]) - continue digit_strings = [g for g in pattern_match.groups() if g is not None] digit_indices = [int(ds) for ds in digit_strings] - remainder = layer_key[pattern_match.end():] + remainder = layer_key[pattern_match.end() :] if len(path_parts) == 1: mapped_subkey = f"{path_parts[0]}{remainder}" @@ -134,21 +129,20 @@ def convert_old_keys_to_new_keys(original_state_dict: Dict[str, torch.Tensor]) - else: decoder_layer, unit_idx = digit_indices mapped_subkey = ( - f"{path_parts[0]}.{decoder_layer-1}." - f"{path_parts[1]}{unit_idx-1}." 
- f"{path_parts[2]}{remainder}") + f"{path_parts[0]}.{decoder_layer - 1}.{path_parts[1]}{unit_idx - 1}.{path_parts[2]}{remainder}" + ) new_key = f"acoustic_decoder.{_rewrite_weight_norm(mapped_subkey)}" converted_checkpoint[new_key] = value break elif old_key.startswith("encoder_semantic."): - semantic_key = old_key[len("encoder_semantic."):] + semantic_key = old_key[len("encoder_semantic.") :] for old, new in MAPPING_SEMANTIC_ENCODER.items(): semantic_key = semantic_key.replace(old, new) converted_checkpoint[f"encoder_semantic.{semantic_key}"] = value elif old_key.startswith("decoder_semantic."): - semantic_key = old_key[len("decoder_semantic."):] + semantic_key = old_key[len("decoder_semantic.") :] for old, new in MAPPING_SEMANTIC_DECODER.items(): semantic_key = semantic_key.replace(old, new) converted_checkpoint[f"decoder_semantic.{semantic_key}"] = value @@ -157,13 +151,13 @@ def convert_old_keys_to_new_keys(original_state_dict: Dict[str, torch.Tensor]) - converted_checkpoint[old_key] = value elif old_key.startswith("fc_prior."): - converted_checkpoint[f"fc.{old_key[len('fc_prior.'):]}"] = value + converted_checkpoint[f"fc.{old_key[len('fc_prior.') :]}"] = value elif old_key.startswith("fc_post1."): - converted_checkpoint[f"fc1.{old_key[len('fc_post1.'):]}"] = value + converted_checkpoint[f"fc1.{old_key[len('fc_post1.') :]}"] = value elif old_key.startswith("fc_post2."): - converted_checkpoint[f"fc2.{old_key[len('fc_post2.'):]}"] = value + converted_checkpoint[f"fc2.{old_key[len('fc_post2.') :]}"] = value elif old_key.startswith("quantizer.vq.layers"): new_key = old_key @@ -174,25 +168,24 @@ def convert_old_keys_to_new_keys(original_state_dict: Dict[str, torch.Tensor]) - return converted_checkpoint - @torch.no_grad() -def convert_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_path=None, push_to_hub=None): +def convert_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_path=None, push_to_hub=None): if config_path is not None: config = XcodecConfig.from_pretrained(config_path) else: config = XcodecConfig() model = XcodecModel(config) - - logger.info(f"Loading original checkpoint ...") - - state_dict = torch.load(checkpoint_path) + + logger.info("Loading original checkpoint ...") + + state_dict = torch.load(checkpoint_path) # the original checkpoint has weight norm applied model.apply_weight_norm() - logger.info(f"Converting model ...") - + logger.info("Converting model ...") + new_state_dict = convert_old_keys_to_new_keys(state_dict) missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, strict=False) @@ -234,4 +227,4 @@ def convert_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_path=No args.pytorch_dump_folder_path, args.config_path, args.push_to_hub, - ) \ No newline at end of file + ) diff --git a/src/transformers/models/xcodec/modeling_xcodec.py b/src/transformers/models/xcodec/modeling_xcodec.py index 7f9e81b58300..239577d5966c 100644 --- a/src/transformers/models/xcodec/modeling_xcodec.py +++ b/src/transformers/models/xcodec/modeling_xcodec.py @@ -16,14 +16,12 @@ import math from dataclasses import dataclass -from typing import Optional, Union, Tuple +from typing import Optional, Tuple, Union -import numpy as np import torch import torch.nn as nn import torch.nn.functional as F -from ..auto import AutoModel from ...modeling_utils import PreTrainedModel from ...utils import ( ModelOutput, @@ -31,8 +29,10 @@ add_start_docstrings_to_model_forward, replace_return_docstrings, ) +from ..auto import AutoModel from .configuration_xcodec 
import XcodecConfig + # General docstring _CONFIG_FOR_DOC = "XcodecConfig" @@ -58,6 +58,7 @@ class XcodecEncoderOutput(ModelOutput): audio_codes (`torch.LongTensor` of shape `(batch_size, num_quantizers, codes_length)`, *optional*): Discrete code indices computed using `model.encode`. """ + audio_codes: Optional[torch.LongTensor] = None @@ -68,14 +69,24 @@ class XcodecDecoderOutput(ModelOutput): audio_values (`torch.FloatTensor` of shape `(batch_size, channels, num_samples)`, *optional*): Decoded audio values obtained using the decoder part of Xcodec. """ - audio_values: Optional[torch.FloatTensor] = None + audio_values: Optional[torch.FloatTensor] = None class ResidualUnit(nn.Module): """Residual block for SemanticEncoder and SemanticDecoder used in Xcodec.""" - def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3, dilation: int = 1, stride: int = 1, bias: bool = False, padding: int = -1, groups: int = 1): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + dilation: int = 1, + stride: int = 1, + bias: bool = False, + padding: int = -1, + groups: int = 1, + ): super().__init__() self.activation = nn.ELU() if padding < 0: @@ -102,9 +113,12 @@ def __init__( bias: bool = False, ): super().__init__() - self.res_units = nn.ModuleList([ - ResidualUnit(in_channels, in_channels, unit_kernel_size, dilation=dilation, bias=bias) - for dilation in dilations]) + self.res_units = nn.ModuleList( + [ + ResidualUnit(in_channels, in_channels, unit_kernel_size, dilation=dilation, bias=bias) + for dilation in dilations + ] + ) # special case: stride=1, do not use kernel=2 kernel = 3 if stride == 1 else (2 * stride) @@ -125,13 +139,18 @@ def __init__(self, config): raise ValueError("Number of strides must match the number of channel_ratios.") self.conv = nn.Conv1d( - config.input_channels, config.encoder_channels, config.kernel_size, 1, config.kernel_size // 2, bias=False) + config.input_channels, config.encoder_channels, config.kernel_size, 1, config.kernel_size // 2, bias=False + ) in_channels = config.encoder_channels conv_blocks = [] for i, stride in enumerate(config.strides): out_channels = int(config.encoder_channels * config.channel_ratios[i]) - conv_blocks += [SemanticEncoderBlock(in_channels, out_channels, stride, config.block_dilations, config.unit_kernel_size, bias=False)] + conv_blocks += [ + SemanticEncoderBlock( + in_channels, out_channels, stride, config.block_dilations, config.unit_kernel_size, bias=False + ) + ] in_channels = out_channels self.conv_blocks = nn.ModuleList(conv_blocks) @@ -155,7 +174,9 @@ def __init__( ): super().__init__() if stride == 1: - self.conv = nn.Conv1d(in_channels, out_channels, + self.conv = nn.Conv1d( + in_channels, + out_channels, kernel_size=3, stride=1, padding=1, @@ -165,11 +186,22 @@ def __init__( kernel_size = 2 * stride padding = (stride + 1) // 2 output_padding = 1 if stride % 2 == 1 else 0 - self.conv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride, padding, output_padding, bias=False) + self.conv = nn.ConvTranspose1d( + in_channels, out_channels, kernel_size, stride, padding, output_padding, bias=False + ) - self.res_units = nn.ModuleList([ - ResidualUnit(in_channels=out_channels, out_channels=out_channels, kernel_size=unit_kernel_size, dilation=dilation, bias=bias) - for dilation in dilations]) + self.res_units = nn.ModuleList( + [ + ResidualUnit( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=unit_kernel_size, + dilation=dilation, + bias=bias, + ) 
+ for dilation in dilations + ] + ) def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: hidden_state = self.conv(hidden_state) @@ -182,7 +214,7 @@ class SemanticDecoder(nn.Module): def __init__(self, config): super().__init__() self.conv1 = nn.Conv1d( - in_channels=config.decoder_channels, + in_channels=config.decoder_channels, out_channels=int(config.decoder_channels * config.channel_ratios[0]), kernel_size=config.kernel_size, stride=1, @@ -198,13 +230,26 @@ def __init__(self, config): else: out_channels = config.decoder_channels - conv_blocks += [SemanticDecoderBlock(in_channels, out_channels, config.strides[i], - dilations = config.block_dilations, - unit_kernel_size = config.unit_kernel_size, - bias = False)] + conv_blocks += [ + SemanticDecoderBlock( + in_channels, + out_channels, + config.strides[i], + dilations=config.block_dilations, + unit_kernel_size=config.unit_kernel_size, + bias=False, + ) + ] self.conv_blocks = nn.ModuleList(conv_blocks) - self.conv2 = nn.Conv1d(config.decoder_channels, config.output_channels, config.kernel_size, stride=1, padding=config.kernel_size // 2, bias=False) + self.conv2 = nn.Conv1d( + config.decoder_channels, + config.output_channels, + config.kernel_size, + stride=1, + padding=config.kernel_size // 2, + bias=False, + ) def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: hidden_state = self.conv1(hidden_state) @@ -216,7 +261,8 @@ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: class XcodecEuclideanCodebook(nn.Module): """Codebook with Euclidean distance.""" - def __init__(self, config): + + def __init__(self, config: XcodecConfig): super().__init__() embed = torch.zeros(config.codebook_size, config.codebook_dim) self.codebook_size = config.codebook_size @@ -225,14 +271,14 @@ def __init__(self, config): self.register_buffer("embed", embed) self.register_buffer("embed_avg", embed.clone()) - # Copied from transformers.models.mimi.modeling_mimi.MimiEuclideanCodebook.quantize + # Copied from transformers.models.encodec.modeling_encodec.EncodecEuclideanCodebook.quantize def quantize(self, hidden_states): embed = self.embed.t() - dist = -(hidden_states.pow(2).sum(1, keepdim=True) - 2 * hidden_states @ embed + embed.pow(2).sum(0, keepdim=True)) + scaled_states = hidden_states.pow(2).sum(1, keepdim=True) + dist = -(scaled_states - 2 * hidden_states @ embed + embed.pow(2).sum(0, keepdim=True)) embed_ind = dist.max(dim=-1).indices return embed_ind - # Copied from transformers.models.encodec.modeling_encodec.EncodecEuclideanCodebook.encode def encode(self, hidden_states): shape = hidden_states.shape hidden_states = hidden_states.reshape((-1, shape[-1])) @@ -240,49 +286,44 @@ def encode(self, hidden_states): embed_ind = embed_ind.view(*shape[:-1]) return embed_ind - # Copied from transformers.models.encodec.modeling_encodec.EncodecEuclideanCodebook.decode - def decode(self, embed_ind): - quantized = F.embedding(embed_ind, self.embed) - return quantized - - - -# Copied from transformers.models.encodec.modeling_encodec.EncodecVectorQuantization with Encodec-> Xcodec class XcodecVectorQuantization(nn.Module): - """Vector quantization implementation. Currently supports only euclidean distance. """ - def __init__(self, config): + Vector quantization implementation. Currently supports only euclidean distance. 
+ """ + + def __init__(self, config: XcodecConfig): super().__init__() self.codebook = XcodecEuclideanCodebook(config) + # Copied from transformers.models.encodec.modeling_encodec.EncodecVectorQuantization.encode def encode(self, hidden_states): hidden_states = hidden_states.permute(0, 2, 1) embed_in = self.codebook.encode(hidden_states) return embed_in + # Copied from transformers.models.encodec.modeling_encodec.EncodecVectorQuantization.decode def decode(self, embed_ind): - quantized = self.codebook.decode(embed_ind) - quantized = quantized.permute(0, 2, 1) - return quantized + quantize = self.codebook.decode(embed_ind) + quantize = quantize.permute(0, 2, 1) + return quantize class XcodecResidualVectorQuantization(nn.Module): """ Residual vector quantization implementation. Follows Algorithm 1 in https://arxiv.org/pdf/2107.03312.pdf """ - - def __init__(self, config): + def __init__(self, config: XcodecConfig): super().__init__() self.quantizers = nn.ModuleList([XcodecVectorQuantization(config) for _ in range(config.num_quantizers)]) - self.frame_rate = config.frame_rate + self.frame_rate = config.frame_rate self.codebook_size = config.codebook_size self.num_quantizers = config.num_quantizers def get_bandwidth_per_quantizer(self): """Return bandwidth per quantizer.""" - return math.log2(self.codebook_size) * self.frame_rate/ 1000 + return math.log2(self.codebook_size) * self.frame_rate / 1000 - def get_num_quantizers_for_bandwidth(self, bandwidth= None) -> int: + def get_num_quantizers_for_bandwidth(self, bandwidth=None) -> int: """Return num_quantizers based on specified target bandwidth.""" bw_per_q = self.get_bandwidth_per_quantizer() num_quantizers = self.num_quantizers @@ -290,7 +331,7 @@ def get_num_quantizers_for_bandwidth(self, bandwidth= None) -> int: num_quantizers = int(max(1, math.floor(bandwidth / bw_per_q))) return num_quantizers - def encode(self, embeddings: torch.Tensor, bandwidth = None) -> torch.Tensor: + def encode(self, embeddings: torch.Tensor, bandwidth=None) -> torch.Tensor: """ Encode the input tensor into discrete indices using RVQ, with the number of quantizers selected based on the given bandwidth. Each quantizer /codebook residually quantizes the input and returns the nearest indices in terms of Euclidian distance. @@ -306,7 +347,6 @@ def encode(self, embeddings: torch.Tensor, bandwidth = None) -> torch.Tensor: out_indices = torch.stack(all_indices) return out_indices - def decode(self, codes: torch.Tensor) -> torch.Tensor: """Decode the given codes to their quantized representation.""" quantized_out = torch.tensor(0.0, device=codes.device) @@ -317,25 +357,24 @@ def decode(self, codes: torch.Tensor) -> torch.Tensor: return quantized_out - - class XcodecPreTrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models. 
""" + config_class = XcodecConfig base_model_prefix = "xcodec" main_input_name = "input_values" supports_gradient_checkpointing = False - def _init_weights(self, module): + def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() - + elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): module.bias.data.zero_() module.weight.data.fill_(1.0) @@ -344,11 +383,9 @@ def _init_weights(self, module): if module.bias is not None: k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0])) nn.init.uniform_(module.bias, a=-k, b=k) - def apply_weight_norm(self): - """Apply weight norm in the acoustic encoder and decoder because the original checkpoint has weight norm applied. - """ + """Apply weight norm in the acoustic encoder and decoder because the original checkpoint has weight norm applied.""" weight_norm = torch.nn.utils.weight_norm if hasattr(torch.nn.utils.parametrizations, "weight_norm"): weight_norm = torch.nn.utils.parametrizations.weight_norm @@ -372,7 +409,7 @@ def apply_weight_norm(self): weight_norm(res_unit.conv2, name="weight") def remove_weight_norm(self): - """Remove the weight norm from the acoustic encoder and decoder. """ + """Remove the weight norm from the acoustic encoder and decoder.""" for module in (self.acoustic_encoder, self.acoustic_decoder): for m in module.modules(): try: @@ -383,7 +420,6 @@ def remove_weight_norm(self): torch.nn.utils.parametrize.remove_parametrizations(m, "weight", leave_parametrized=True) - XCODEC_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads @@ -418,8 +454,7 @@ def remove_weight_norm(self): "The Xcodec neural audio codec model.", XCODEC_START_DOCSTRING, ) - -class XcodecModel(XcodecPreTrainedModel): +class XcodecModel(XcodecPreTrainedModel): def __init__(self, config): super().__init__(config) self.config = config @@ -438,11 +473,11 @@ def __init__(self, config): @staticmethod def _adjust_dac_decoder(decoder: nn.Module): - r""" + r""" DAC implemented in Xcodec is slightly different from the HF version. DAC in Xcodec adjusts the output padding in every ConvTranspose1d in the decoder and removes the final `nn.Tanh` activation function. 
- """ + """ for module in decoder.modules(): if isinstance(module, nn.ConvTranspose1d): stride = module.stride[0] if isinstance(module.stride, tuple) else module.stride @@ -451,17 +486,22 @@ def _adjust_dac_decoder(decoder: nn.Module): decoder.tanh = nn.Identity() def _extract_semantic_features(self, input_values: torch.FloatTensor) -> torch.FloatTensor: - input_values = input_values[:,0,:] + input_values = input_values[:, 0, :] input_values = F.pad(input_values, (self.pad, self.pad)) with torch.no_grad(): outputs = self.semantic_model(input_values, output_hidden_states=True) - hidden_states = outputs.hidden_states + hidden_states = outputs.hidden_states - stacked = torch.stack(hidden_states, dim=1) + stacked = torch.stack(hidden_states, dim=1) return stacked.mean(dim=1) - - def encode(self, input_values: torch.Tensor, bandwidth: Optional[float] = None, return_dict: Optional[bool] = None, **kwargs) -> Union[torch.Tensor, XcodecEncoderOutput]: + def encode( + self, + input_values: torch.Tensor, + bandwidth: Optional[float] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[torch.Tensor, XcodecEncoderOutput]: """ Encodes the input audio waveform into discrete audio codes. @@ -475,15 +515,17 @@ def encode(self, input_values: torch.Tensor, bandwidth: Optional[float] = None, Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. Returns: - `torch.LongTensor` of shape `(batch_size, num_quantizers, codes_length)` containing the discrete encoded audio codes. + `torch.LongTensor` of shape `(batch_size, num_quantizers, codes_length)` containing the discrete encoded audio codes. """ return_dict = return_dict if return_dict is not None else self.config.return_dict - + if input_values.ndim != 3: - raise ValueError(f"Expected input shape (batch_size, channels, num_samples), but got shape {input_values.shape}") - - _, channels, self._input_length = input_values.shape - + raise ValueError( + f"Expected input shape (batch_size, channels, num_samples), but got shape {input_values.shape}" + ) + + _, channels, self._input_length = input_values.shape + if channels not in (1, 2): raise ValueError(f"Number of audio channels must be 1 or 2, but got {channels}") @@ -491,33 +533,35 @@ def encode(self, input_values: torch.Tensor, bandwidth: Optional[float] = None, bandwidth = self.config.target_bandwidths[-1] elif bandwidth not in self.config.target_bandwidths: raise ValueError( - f"This model doesn't support the bandwidth {bandwidth}. Select one of {self.config.target_bandwidths}.") + f"This model doesn't support the bandwidth {bandwidth}. Select one of {self.config.target_bandwidths}." 
+ ) e_semantic_input = self._extract_semantic_features(input_values).detach() e_semantic = self.encoder_semantic(e_semantic_input.transpose(1, 2)) e_acoustic = self.acoustic_encoder(input_values) - + if e_acoustic.shape[2] != e_semantic.shape[2]: # make sure they line up if frames don't match - e_acoustic = self.acoustic_encoder(F.pad(input_values[:,0,:], (self.pad, self.pad)).unsqueeze(1)) - + e_acoustic = self.acoustic_encoder(F.pad(input_values[:, 0, :], (self.pad, self.pad)).unsqueeze(1)) + embeddings = torch.cat([e_acoustic, e_semantic], dim=1) embeddings = self.fc(embeddings.transpose(1, 2)).transpose(1, 2) audio_codes = self.quantizer.encode(embeddings, bandwidth) audio_codes = audio_codes.transpose(0, 1) if not return_dict: - return (audio_codes) + return audio_codes return XcodecEncoderOutput(audio_codes) - - def decode(self, audio_codes: torch.Tensor, return_dict: Optional[bool] = None, **kwargs) -> Union[torch.Tensor, XcodecDecoderOutput]: + def decode( + self, audio_codes: torch.Tensor, return_dict: Optional[bool] = None, **kwargs + ) -> Union[torch.Tensor, XcodecDecoderOutput]: """ Decode the given discrete codes into an output audio waveform. - + The produced audio waveform is longer than the audio input, so it's automatically trimmed to match the original input. - + Args: audio_codes (`torch.LongTensor` of shape `(batch_size, num_quantizers, codes_length)`): Discrete code indices computed using `model.encode`. @@ -543,14 +587,20 @@ def decode(self, audio_codes: torch.Tensor, return_dict: Optional[bool] = None, audio_values = audio_values[..., start : start + self._input_length] if not return_dict: - return (audio_values) + return audio_values return XcodecDecoderOutput(audio_values) - @add_start_docstrings_to_model_forward(XCODEC_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=XcodecOutput, config_class=_CONFIG_FOR_DOC) - def forward(self, input_values: torch.Tensor, audio_codes: Optional[torch.Tensor] = None, bandwidth: Optional[float] = None, return_dict: Optional[bool] = None, **kwargs) -> Union[Tuple[torch.Tensor, torch.Tensor], XcodecOutput]: + def forward( + self, + input_values: torch.Tensor, + audio_codes: Optional[torch.Tensor] = None, + bandwidth: Optional[float] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], XcodecOutput]: r""" Returns: @@ -587,5 +637,4 @@ def forward(self, input_values: torch.Tensor, audio_codes: Optional[torch.Tensor return XcodecOutput(audio_codes=audio_codes, audio_values=audio_values) - -__all__ = ["XcodecModel", "XcodecPreTrainedModel"] \ No newline at end of file +__all__ = ["XcodecModel", "XcodecPreTrainedModel"] diff --git a/tests/models/xcodec/test_modeling_xcodec.py b/tests/models/xcodec/test_modeling_xcodec.py index c9ef98f51434..d2e464d90837 100644 --- a/tests/models/xcodec/test_modeling_xcodec.py +++ b/tests/models/xcodec/test_modeling_xcodec.py @@ -14,15 +14,17 @@ """Testing suite for the PyTorch Xcodec model.""" import inspect +import math import os import tempfile import unittest -import math import numpy as np from datasets import Audio, load_dataset from pytest import mark +from tests.test_configuration_common import ConfigTester +from tests.test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor from transformers import AutoFeatureExtractor, XcodecConfig from transformers.testing_utils import ( is_flaky, @@ -33,8 +35,6 @@ slow, torch_device, ) -from tests.test_configuration_common import ConfigTester -from 
tests.test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor if is_torch_available(): @@ -43,7 +43,6 @@ from transformers import XcodecModel - @require_torch class XcodecModelTester: def __init__( @@ -56,7 +55,6 @@ def __init__( num_quantizers=8, num_samples=400, is_training=False, - ): self.parent = parent self.batch_size = batch_size @@ -81,26 +79,25 @@ def prepare_config_and_inputs_for_model_class(self, model_class): config, inputs_dict = self.prepare_config_and_inputs() codes_length = math.ceil(self.num_samples / config.hop_length) inputs_dict["audio_codes"] = ids_tensor( - [self.batch_size, self.num_quantizers, codes_length], config.codebook_size) - + [self.batch_size, self.num_quantizers, codes_length], config.codebook_size + ) + return config, inputs_dict def get_config(self): return XcodecConfig( - sample_rate=self.sample_rate, - audio_channels=self.num_channels, - codebook_size=self.codebook_size, - num_quantizers=self.num_quantizers, + sample_rate=self.sample_rate, + audio_channels=self.num_channels, + codebook_size=self.codebook_size, + num_quantizers=self.num_quantizers, ) - - + def create_and_check_model_forward(self, config, inputs_dict): model = XcodecModel(config=config).to(torch_device).eval() input_values = inputs_dict["input_values"] result = model(input_values) - self.parent.assertEqual( - result.audio_values.shape, (self.batch_size, self.num_channels, self.num_samples) - ) + self.parent.assertEqual(result.audio_values.shape, (self.batch_size, self.num_channels, self.num_samples)) + @require_torch class XcodecModelTest(ModelTesterMixin, unittest.TestCase): @@ -128,12 +125,10 @@ def setUp(self): def test_config(self): self.config_tester.run_common_tests() - def test_model_forward(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_model_forward(*config_and_inputs) - def test_forward_signature(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common() @@ -147,7 +142,6 @@ def test_forward_signature(self): expected_arg_names = ["input_values", "audio_codes", "bandwidth", "return_dict"] self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names) - def test_gradient_checkpointing_backward_compatibility(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() @@ -160,7 +154,6 @@ def test_gradient_checkpointing_backward_compatibility(self): config.decoder.gradient_checkpointing = True model = model_class(config) self.assertTrue(model.is_gradient_checkpointing) - @unittest.skip(reason="We cannot configure to output a smaller model.") def test_model_is_small(self): @@ -285,7 +278,7 @@ def test_attention_outputs(self): def test_hidden_states_output(self): pass - # Copied from transformers.tests.encodec.test_modeling_encodecEncodecModelTest.test_determinism + # Copied from transformers.tests.encodec.test_modeling_encodecEncodecModelTest.test_determinism def test_determinism(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() @@ -350,7 +343,6 @@ def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}): dict_inputs = self._prepare_for_class(inputs_dict, model_class) check_equivalence(model, tuple_inputs, dict_inputs) - def test_initialization(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() configs_no_init = _config_zero_init(config) @@ -358,11 +350,11 @@ def test_initialization(self): model = model_class(config=configs_no_init) for name, 
param in model.named_parameters(): # skipping the parametrizations original0 tensor - if name =="semantic_model.encoder.pos_conv_embed.conv.parametrizations.weight.original0": + if name == "semantic_model.encoder.pos_conv_embed.conv.parametrizations.weight.original0": continue uniform_init_parms = ["conv"] - + if param.requires_grad: if any(x in name for x in uniform_init_parms): self.assertTrue( @@ -370,7 +362,6 @@ def test_initialization(self): msg=f"Parameter {name} of {model_class.__name__} seems not properly initialized", ) - @require_flash_attn @require_torch_gpu @mark.flash_attn_test @@ -426,22 +417,20 @@ def compute_rmse(arr1, arr2): return np.sqrt(((arr1_normalized - arr2_normalized) ** 2).mean()) -#@slow +# @slow @require_torch class XcodecIntegrationTest(unittest.TestCase): def test_integration(self): expected_rmse = { - "0.5": 0.0065491, - "4.0": 0.0070978, + "0.5": 0.0065491, + "4.0": 0.0070978, } expected_codesums = { - "0.5": [117262], - "4.0": [926416], + "0.5": [117262], + "4.0": [926416], } - librispeech = load_dataset( - "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation" - ) + librispeech = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") model_id = "Manel/X-Codec" model = XcodecModel.from_pretrained(model_id).to(torch_device).eval() feature_extractor = AutoFeatureExtractor.from_pretrained(model_id) @@ -449,7 +438,9 @@ def test_integration(self): librispeech = librispeech.cast_column("audio", Audio(sampling_rate=feature_extractor.sampling_rate)) audio = librispeech[-1]["audio"]["array"] - inputs = feature_extractor(raw_audio=audio, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt").to(torch_device) + inputs = feature_extractor( + raw_audio=audio, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt" + ).to(torch_device) for bandwidth, exp_rmse in expected_rmse.items(): bandwidth = float(bandwidth) @@ -460,17 +451,13 @@ def test_integration(self): expected_codesum = expected_codesums[str(bandwidth)][0] self.assertEqual(codesum, expected_codesum) - input_values_dec = model.decode( - audio_codes, return_dict=False - ) - input_values_enc_dec = model( - inputs["input_values"], bandwidth=bandwidth - )[1] + input_values_dec = model.decode(audio_codes, return_dict=False) + input_values_enc_dec = model(inputs["input_values"], bandwidth=bandwidth)[1] self.assertTrue(torch.allclose(input_values_dec, input_values_enc_dec, atol=1e-3)) self.assertTrue(inputs["input_values"].shape == input_values_enc_dec.shape) - + arr = inputs["input_values"][0].cpu().numpy() arr_enc_dec = input_values_enc_dec[0].cpu().numpy() rmse = compute_rmse(arr, arr_enc_dec) From 62e7d5298f29caabc673c9e405386814a0ded0ea Mon Sep 17 00:00:00 2001 From: Manal ML Date: Wed, 21 May 2025 06:10:45 +0000 Subject: [PATCH 04/10] fix docstring --- src/transformers/models/xcodec/configuration_xcodec.py | 2 +- src/transformers/models/xcodec/modeling_xcodec.py | 8 +++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/xcodec/configuration_xcodec.py b/src/transformers/models/xcodec/configuration_xcodec.py index ba363c993fad..320dc15e955d 100644 --- a/src/transformers/models/xcodec/configuration_xcodec.py +++ b/src/transformers/models/xcodec/configuration_xcodec.py @@ -33,7 +33,7 @@ class XcodecConfig(PretrainedConfig): This is the configuration class to store the configuration of an [`XcodecModel`]. 
It is used to instantiate a Xcodec model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the - [ ](https://huggingface.co/ ) architecture. + [Manel/X-Codec](https://huggingface.co/Manel/X-Codec) architecture. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. diff --git a/src/transformers/models/xcodec/modeling_xcodec.py b/src/transformers/models/xcodec/modeling_xcodec.py index 239577d5966c..8d547806bf43 100644 --- a/src/transformers/models/xcodec/modeling_xcodec.py +++ b/src/transformers/models/xcodec/modeling_xcodec.py @@ -262,7 +262,7 @@ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: class XcodecEuclideanCodebook(nn.Module): """Codebook with Euclidean distance.""" - def __init__(self, config: XcodecConfig): + def __init__(self, config): super().__init__() embed = torch.zeros(config.codebook_size, config.codebook_dim) self.codebook_size = config.codebook_size @@ -286,6 +286,11 @@ def encode(self, hidden_states): embed_ind = embed_ind.view(*shape[:-1]) return embed_ind + def decode(self, embed_ind): + quantized = F.embedding(embed_ind, self.embed) + return quantized + + class XcodecVectorQuantization(nn.Module): """ Vector quantization implementation. Currently supports only euclidean distance. @@ -312,6 +317,7 @@ class XcodecResidualVectorQuantization(nn.Module): """ Residual vector quantization implementation. Follows Algorithm 1 in https://arxiv.org/pdf/2107.03312.pdf """ + def __init__(self, config: XcodecConfig): super().__init__() self.quantizers = nn.ModuleList([XcodecVectorQuantization(config) for _ in range(config.num_quantizers)]) From 6e093ec53b3a40eda66b141e9060fa9025dde994 Mon Sep 17 00:00:00 2001 From: Manal ML Date: Wed, 21 May 2025 17:07:10 +0000 Subject: [PATCH 05/10] fix docstring and config attribute --- .../models/xcodec/configuration_xcodec.py | 14 +++++++------- utils/check_config_attributes.py | 2 ++ 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/xcodec/configuration_xcodec.py b/src/transformers/models/xcodec/configuration_xcodec.py index 320dc15e955d..b38a6724fc70 100644 --- a/src/transformers/models/xcodec/configuration_xcodec.py +++ b/src/transformers/models/xcodec/configuration_xcodec.py @@ -39,7 +39,7 @@ class XcodecConfig(PretrainedConfig): documentation from [`PretrainedConfig`] for more information. Args: - target_bandwidths (`List[float]`, *optional*, defaults to [0.5, 1.0, 1.5, 2.0, 4.0]): + target_bandwidths (`List[float]`, *optional*, defaults to `[0.5, 1, 1.5, 2, 4]`): The range of different bandwidths (in kbps) the model can encode audio with. audio_channels (`int`, *optional*, defaults to 1): Number of channels in the audio data. Either 1 for mono or 2 for stereo. @@ -47,15 +47,15 @@ class XcodecConfig(PretrainedConfig): The sampling rate at which the audio waveform should be digitalized, in hertz (Hz). input_channels (`int`, *optional*, defaults to 768): Number of channels of the input to the first convolution in the semantic encoder. - kernel_size (`int`, *optional*, defaults to 3): - Kernel size for the initial semantic convolution. encoder_channels (`int`, *optional*, defaults to 768): Number of hidden channels in each semantic encoder block. 
- channel_ratios (`List[float]`, *optional*, defaults to [1.0, 1.0]): + kernel_size (`int`, *optional*, defaults to 3): + Kernel size for the initial semantic convolution. + channel_ratios (`List[float]`, *optional*, defaults to `[1, 1]`): Expansion factors for the number of output channels in each semantic block. - strides (`List[int]`, *optional*, defaults to [1, 1]): + strides (`List[int]`, *optional*, defaults to `[1, 1]`): Strides for each semantic encoder block. - block_dilations (`List[int]`, *optional*, defaults to [1, 1]): + block_dilations (`List[int]`, *optional*, defaults to `[1, 1]`): Dilation factors for the residual units in semantic blocks. unit_kernel_size (`int`, *optional*, defaults to 3): Kernel size inside each ResidualUnit in semantic blocks. @@ -108,7 +108,7 @@ class XcodecConfig(PretrainedConfig): def __init__( self, target_bandwidths: List[float] = [0.5, 1, 1.5, 2, 4], - audio_channels=1, + audio_channels: int = 1, sample_rate: int = 16000, input_channels: int = 768, encoder_channels: int = 768, diff --git a/utils/check_config_attributes.py b/utils/check_config_attributes.py index 6f5d95dfee24..b4fe7fb00209 100644 --- a/utils/check_config_attributes.py +++ b/utils/check_config_attributes.py @@ -61,6 +61,8 @@ "Phi3Config": ["embd_pdrop"], # used to compute the property `self.chunk_length` "EncodecConfig": ["overlap"], + # used to compute `frame_rate` + "XcodecConfig": ["sample_rate", "audio_channels"], # used to compute the property `self.layers_block_type` "RecurrentGemmaConfig": ["block_types"], # used as in the config to define `intermediate_size` From ad5a62cff50002054e187a6360cc8106b03d06c6 Mon Sep 17 00:00:00 2001 From: Manal ML Date: Thu, 29 May 2025 03:26:06 +0100 Subject: [PATCH 06/10] Update args + config --- docs/source/en/model_doc/xcodec.md | 42 +++++----- .../models/xcodec/configuration_xcodec.py | 25 ++---- .../models/xcodec/modeling_xcodec.py | 83 +++++-------------- 3 files changed, 46 insertions(+), 104 deletions(-) diff --git a/docs/source/en/model_doc/xcodec.md b/docs/source/en/model_doc/xcodec.md index ce0287f46f17..328b48af9230 100644 --- a/docs/source/en/model_doc/xcodec.md +++ b/docs/source/en/model_doc/xcodec.md @@ -46,27 +46,27 @@ This model was contributed by [Manal El Aidouni](https://huggingface.co/Manel). 
Here is a quick example of how to encode and decode an audio using this model: ```python ->>> from datasets import load_dataset, Audio ->>> from transformers import XCodecModel, AutoFeatureExtractor ->>> dummy_dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") - ->>> # load model and feature extractor ->>> model = XCodecModel.from_pretrained("Manel/X-Codec") ->>> feature_extractor = AutoFeatureExtractor.from_pretrained("Manel/X-Codec") ->>> # load audio sample ->>> dummy_dataset = dummy_dataset.cast_column("audio", Audio(sampling_rate=feature_extractor.sampling_rate)) ->>> audio_sample = dummy_dataset[-1]["audio"]["array"] ->>> inputs = feature_extractor(raw_audio=audio_sample, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt") - ->>> encoder_outputs = model.encode(inputs["input_values"]) ->>> audio_codes = encoder_outputs.audio_codes ->>> decoder_outputs = model.decode(audio_codes) ->>> audio_values = decoder_outputs.audio_values - ->>> # or the equivalent with a forward pass ->>> outputs = model(inputs["input_values"]) ->>> audio_codes = outputs.audio_codes ->>> audio_values = outputs.audio_values +from datasets import load_dataset, Audio +from transformers import XcodecModel, AutoFeatureExtractor +dummy_dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + +# load model and feature extractor +model = XcodecModel.from_pretrained("Manel/X-Codec") +feature_extractor = AutoFeatureExtractor.from_pretrained("Manel/X-Codec") +# load audio sample +dummy_dataset = dummy_dataset.cast_column("audio", Audio(sampling_rate=feature_extractor.sampling_rate)) +audio_sample = dummy_dataset[-1]["audio"]["array"] +inputs = feature_extractor(raw_audio=audio_sample, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt") + +encoder_outputs = model.encode(inputs["input_values"]) +audio_codes = encoder_outputs.audio_codes +decoder_outputs = model.decode(audio_codes) +audio_values = decoder_outputs.audio_values + +# or the equivalent with a forward pass +outputs = model(inputs["input_values"]) +audio_codes = outputs.audio_codes +audio_values = outputs.audio_values ``` ## XcodecConfig diff --git a/src/transformers/models/xcodec/configuration_xcodec.py b/src/transformers/models/xcodec/configuration_xcodec.py index b38a6724fc70..79f2e37cf779 100644 --- a/src/transformers/models/xcodec/configuration_xcodec.py +++ b/src/transformers/models/xcodec/configuration_xcodec.py @@ -15,7 +15,7 @@ """Xcodec model configuration""" import math -from typing import List, Union +from typing import List, Optional, Union import numpy as np @@ -107,7 +107,7 @@ class XcodecConfig(PretrainedConfig): def __init__( self, - target_bandwidths: List[float] = [0.5, 1, 1.5, 2, 4], + target_bandwidths: Optional[List[float]] = None, audio_channels: int = 1, sample_rate: int = 16000, input_channels: int = 768, @@ -152,6 +152,9 @@ def __init__( elif isinstance(semantic_model_config, HubertConfig): self.semantic_model_config = semantic_model_config + if target_bandwidths is None: + target_bandwidths = [0.5, 1, 1.5, 2, 4] + self.target_bandwidths = target_bandwidths self.audio_channels = audio_channels self.sample_rate = sample_rate @@ -172,24 +175,6 @@ def __init__( self.intermediate_dim = intermediate_dim self.output_dim = output_dim - @classmethod - def from_sub_models_config(cls, acoustic_model_config: DacConfig, semantic_model_config: HubertConfig, **kwargs): - """ - Instantiate a [`XcodecConfig`] from acoustic model and 
semantic model. - - Returns: - [`XcodecConfig`]: The instantiated configuration. - """ - return cls( - acoustic_model_config=acoustic_model_config.to_dict() - if hasattr(acoustic_model_config, "to_dict") - else acoustic_model_config, - semantic_model_config=semantic_model_config.to_dict() - if hasattr(semantic_model_config, "to_dict") - else semantic_model_config, - **kwargs, - ) - @property def frame_rate(self) -> int: return math.ceil(self.sample_rate / np.prod(self.acoustic_model_config.upsampling_ratios)) diff --git a/src/transformers/models/xcodec/modeling_xcodec.py b/src/transformers/models/xcodec/modeling_xcodec.py index 8d547806bf43..07e2b2ecaf44 100644 --- a/src/transformers/models/xcodec/modeling_xcodec.py +++ b/src/transformers/models/xcodec/modeling_xcodec.py @@ -76,23 +76,21 @@ class XcodecDecoderOutput(ModelOutput): class ResidualUnit(nn.Module): """Residual block for SemanticEncoder and SemanticDecoder used in Xcodec.""" - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: int = 3, - dilation: int = 1, - stride: int = 1, - bias: bool = False, - padding: int = -1, - groups: int = 1, - ): + def __init__(self, config: XcodecConfig, in_channels: int, out_channels: int, dilation: int): super().__init__() self.activation = nn.ELU() - if padding < 0: - padding = ((kernel_size - 1) // 2) * dilation - self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias) - self.conv2 = nn.Conv1d(in_channels=out_channels, out_channels=out_channels, kernel_size=1, bias=bias) + padding = ((config.unit_kernel_size - 1) // 2) * dilation + self.conv1 = nn.Conv1d( + in_channels, + out_channels, + config.unit_kernel_size, + stride=1, + padding=padding, + dilation=dilation, + groups=1, + bias=False, + ) + self.conv2 = nn.Conv1d(in_channels=out_channels, out_channels=out_channels, kernel_size=1, bias=False) def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: output_tensor = self.activation(hidden_state) @@ -103,21 +101,10 @@ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: class SemanticEncoderBlock(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - stride: int, - dilations: tuple, - unit_kernel_size: int = 3, - bias: bool = False, - ): + def __init__(self, config: XcodecConfig, in_channels: int, out_channels: int, stride: int): super().__init__() self.res_units = nn.ModuleList( - [ - ResidualUnit(in_channels, in_channels, unit_kernel_size, dilation=dilation, bias=bias) - for dilation in dilations - ] + [ResidualUnit(config, in_channels, in_channels, dilation) for dilation in config.block_dilations] ) # special case: stride=1, do not use kernel=2 @@ -146,11 +133,7 @@ def __init__(self, config): conv_blocks = [] for i, stride in enumerate(config.strides): out_channels = int(config.encoder_channels * config.channel_ratios[i]) - conv_blocks += [ - SemanticEncoderBlock( - in_channels, out_channels, stride, config.block_dilations, config.unit_kernel_size, bias=False - ) - ] + conv_blocks += [SemanticEncoderBlock(config, in_channels, out_channels, stride)] in_channels = out_channels self.conv_blocks = nn.ModuleList(conv_blocks) @@ -163,15 +146,7 @@ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: class SemanticDecoderBlock(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - stride: int, - dilations: tuple, - unit_kernel_size: int = 3, - bias: bool = False, - ): + def __init__(self, config: XcodecConfig, in_channels: int, out_channels: int, stride: 
int):
         super().__init__()
         if stride == 1:
             self.conv = nn.Conv1d(
@@ -191,16 +166,7 @@ def __init__(
             )
 
         self.res_units = nn.ModuleList(
-            [
-                ResidualUnit(
-                    in_channels=out_channels,
-                    out_channels=out_channels,
-                    kernel_size=unit_kernel_size,
-                    dilation=dilation,
-                    bias=bias,
-                )
-                for dilation in dilations
-            ]
+            [ResidualUnit(config, out_channels, out_channels, dilation) for dilation in config.block_dilations]
         )
 
     def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
@@ -222,7 +188,7 @@ def __init__(self, config):
             bias=False,
         )
         conv_blocks = []
-        for i in range(len(config.strides)):
+        for i, stride in enumerate(config.strides):
             in_channels = int(config.decoder_channels * config.channel_ratios[i])
 
             if i < (len(config.channel_ratios) - 1):
@@ -230,16 +196,7 @@ def __init__(self, config):
             else:
                 out_channels = config.decoder_channels
 
-            conv_blocks += [
-                SemanticDecoderBlock(
-                    in_channels,
-                    out_channels,
-                    config.strides[i],
-                    dilations=config.block_dilations,
-                    unit_kernel_size=config.unit_kernel_size,
-                    bias=False,
-                )
-            ]
+            conv_blocks += [SemanticDecoderBlock(config, in_channels, out_channels, stride)]
 
         self.conv_blocks = nn.ModuleList(conv_blocks)
         self.conv2 = nn.Conv1d(

From cbb4abdfea9ec8f45abc1629c71477a54c0c967c Mon Sep 17 00:00:00 2001
From: Manal ML
Date: Thu, 29 May 2025 22:21:51 +0100
Subject: [PATCH 07/10] update conversion script

---
 .../xcodec/convert_xcodec_weights_to_hf.py    | 18 +++++++++++++++---
 1 file changed, 15 insertions(+), 3 deletions(-)

diff --git a/src/transformers/models/xcodec/convert_xcodec_weights_to_hf.py b/src/transformers/models/xcodec/convert_xcodec_weights_to_hf.py
index 74a74335cf36..ad03d731e00c 100644
--- a/src/transformers/models/xcodec/convert_xcodec_weights_to_hf.py
+++ b/src/transformers/models/xcodec/convert_xcodec_weights_to_hf.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import argparse
+import io
 import re
 from typing import Dict
 
 import torch
@@ -30,6 +31,8 @@
 
 logger = logging.get_logger(__name__)
 
+torch.serialization.add_safe_globals([io.BytesIO])
+
 MAPPING_ACOUSTIC_ENCODER = {
     r"^block\.0": ["conv1"],
     r"^block\.(\d+)\.block\.(\d+)\.block\.0": ["block", "res_unit", "snake1"],
@@ -72,6 +75,14 @@
 }
 
 
+def safe_load(path: str) -> Dict[str, torch.Tensor]:
+    """
+    Load only the tensor objects from a checkpoint, skipping any BytesIO entries.
+    """
+    shard = torch.load(path, map_location="cpu", weights_only=True)
+    return {k: v for k, v in shard.items() if not isinstance(v, io.BytesIO)}
+
+
 def _rewrite_weight_norm(key: str) -> str:
     if key.endswith("weight_g"):
         return key[: -len("weight_g")] + "parametrizations.weight.original0"
@@ -175,11 +186,12 @@ def convert_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_path=No
     else:
         config = XcodecConfig()
 
-    model = XcodecModel(config)
+    with torch.device("meta"):
+        model = XcodecModel(config)
 
     logger.info("Loading original checkpoint ...")
 
-    state_dict = torch.load(checkpoint_path)
+    state_dict = safe_load(checkpoint_path)
 
     # the original checkpoint has weight norm applied
     model.apply_weight_norm()
@@ -188,7 +200,7 @@ def convert_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_path=No
 
     new_state_dict = convert_old_keys_to_new_keys(state_dict)
 
-    missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, strict=False)
+    missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, strict=True, assign=True)
 
     if len(unexpected_keys) != 0:
         raise ValueError(f"Unexpected keys: {unexpected_keys}")

From 1d1de8031f88d3694a4218455272a3c7436c85cd Mon Sep 17 00:00:00 2001
From: Manal ML
Date: Fri, 30 May 2025 03:12:25 +0100
Subject: [PATCH 08/10] update docs + cleanup

---
 docs/source/en/model_doc/xcodec.md            | 21 ++++++++++++++-----
 .../models/xcodec/configuration_xcodec.py     |  4 ----
 2 files changed, 16 insertions(+), 9 deletions(-)

diff --git a/docs/source/en/model_doc/xcodec.md b/docs/source/en/model_doc/xcodec.md
index 328b48af9230..959c6b71207a 100644
--- a/docs/source/en/model_doc/xcodec.md
+++ b/docs/source/en/model_doc/xcodec.md
@@ -59,16 +59,27 @@ audio_sample = dummy_dataset[-1]["audio"]["array"]
 inputs = feature_extractor(raw_audio=audio_sample, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt")
 
 encoder_outputs = model.encode(inputs["input_values"])
-audio_codes = encoder_outputs.audio_codes
-decoder_outputs = model.decode(audio_codes)
+decoder_outputs = model.decode(encoder_outputs.audio_codes)
 audio_values = decoder_outputs.audio_values
 
 # or the equivalent with a forward pass
-outputs = model(inputs["input_values"])
-audio_codes = outputs.audio_codes
-audio_values = outputs.audio_values
+audio_values = model(inputs["input_values"]).audio_values
+
 ```
+To listen to the original and reconstructed audio, run the snippet below and then open the generated `original.wav` and `reconstruction.wav` files in your music player to compare.
+ +```python +import soundfile as sf + +original = audio_sample +reconstruction = audio_values[0].cpu().detach().numpy() +sampling_rate = feature_extractor.sampling_rate + +sf.write("original.wav", original, sampling_rate) +sf.write("reconstruction.wav", reconstruction.T, sampling_rate) ``` + ## XcodecConfig [[autodoc]] XcodecConfig diff --git a/src/transformers/models/xcodec/configuration_xcodec.py b/src/transformers/models/xcodec/configuration_xcodec.py index 79f2e37cf779..967f423e40e3 100644 --- a/src/transformers/models/xcodec/configuration_xcodec.py +++ b/src/transformers/models/xcodec/configuration_xcodec.py @@ -179,10 +179,6 @@ def __init__( def frame_rate(self) -> int: return math.ceil(self.sample_rate / np.prod(self.acoustic_model_config.upsampling_ratios)) - @property - def bits_per_codebook(self) -> int: - return int(math.log2(self.codebook_size)) - @property def hop_length(self) -> int: return int(np.prod(self.acoustic_model_config.downsampling_ratios)) From 877afecaae36b5d72a51384ed1a2581d99228840 Mon Sep 17 00:00:00 2001 From: Manal ML Date: Wed, 25 Jun 2025 15:59:36 +0100 Subject: [PATCH 09/10] Ruff fix --- src/transformers/models/xcodec/configuration_xcodec.py | 10 +++++----- .../models/xcodec/convert_xcodec_weights_to_hf.py | 7 +++---- src/transformers/models/xcodec/modeling_xcodec.py | 4 ++-- 3 files changed, 10 insertions(+), 11 deletions(-) diff --git a/src/transformers/models/xcodec/configuration_xcodec.py b/src/transformers/models/xcodec/configuration_xcodec.py index 967f423e40e3..d93510c7d621 100644 --- a/src/transformers/models/xcodec/configuration_xcodec.py +++ b/src/transformers/models/xcodec/configuration_xcodec.py @@ -15,7 +15,7 @@ """Xcodec model configuration""" import math -from typing import List, Optional, Union +from typing import Optional, Union import numpy as np @@ -107,15 +107,15 @@ class XcodecConfig(PretrainedConfig): def __init__( self, - target_bandwidths: Optional[List[float]] = None, + target_bandwidths: Optional[list[float]] = None, audio_channels: int = 1, sample_rate: int = 16000, input_channels: int = 768, encoder_channels: int = 768, kernel_size: int = 3, - channel_ratios: List[float] = [1, 1], - strides: List[int] = [1, 1], - block_dilations: List[int] = [1, 1], + channel_ratios: list[float] = [1, 1], + strides: list[int] = [1, 1], + block_dilations: list[int] = [1, 1], unit_kernel_size: int = 3, decoder_channels: int = 768, output_channels: int = 768, diff --git a/src/transformers/models/xcodec/convert_xcodec_weights_to_hf.py b/src/transformers/models/xcodec/convert_xcodec_weights_to_hf.py index ad03d731e00c..30783234fc2b 100644 --- a/src/transformers/models/xcodec/convert_xcodec_weights_to_hf.py +++ b/src/transformers/models/xcodec/convert_xcodec_weights_to_hf.py @@ -15,7 +15,6 @@ import argparse import io import re -from typing import Dict import torch @@ -75,7 +74,7 @@ } -def safe_load(path: str) -> Dict[str, torch.Tensor]: +def safe_load(path: str) -> dict[str, torch.Tensor]: """ Load only the tensor objects from a checkpoint, skipping any BytesIO """ @@ -91,8 +90,8 @@ def _rewrite_weight_norm(key: str) -> str: return key -def convert_old_keys_to_new_keys(original_state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: - converted_checkpoint: Dict[str, torch.Tensor] = {} +def convert_old_keys_to_new_keys(original_state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + converted_checkpoint: dict[str, torch.Tensor] = {} for old_key, value in original_state_dict.items(): if old_key.startswith("encoder."): diff --git 
a/src/transformers/models/xcodec/modeling_xcodec.py b/src/transformers/models/xcodec/modeling_xcodec.py
index 07e2b2ecaf44..5534ec0d0275 100644
--- a/src/transformers/models/xcodec/modeling_xcodec.py
+++ b/src/transformers/models/xcodec/modeling_xcodec.py
@@ -16,7 +16,7 @@
 
 import math
 from dataclasses import dataclass
-from typing import Optional, Tuple, Union
+from typing import Optional, Union
 
 import torch
 import torch.nn as nn
@@ -563,7 +563,7 @@ def forward(
         bandwidth: Optional[float] = None,
         return_dict: Optional[bool] = None,
         **kwargs,
-    ) -> Union[Tuple[torch.Tensor, torch.Tensor], XcodecOutput]:
+    ) -> Union[tuple[torch.Tensor, torch.Tensor], XcodecOutput]:
         r"""
         Returns:
 

From 4854e5b3bf97b008265eaaa4387c0b9586335745 Mon Sep 17 00:00:00 2001
From: Manal ML
Date: Thu, 17 Jul 2025 17:21:45 +0100
Subject: [PATCH 10/10] fix docstrings

---
 .../models/xcodec/configuration_xcodec.py     |  1 -
 .../models/xcodec/modeling_xcodec.py          | 73 ++++++------------
 tests/models/xcodec/test_modeling_xcodec.py   |  5 ++
 3 files changed, 29 insertions(+), 50 deletions(-)

diff --git a/src/transformers/models/xcodec/configuration_xcodec.py b/src/transformers/models/xcodec/configuration_xcodec.py
index d93510c7d621..45bd1077af2e 100644
--- a/src/transformers/models/xcodec/configuration_xcodec.py
+++ b/src/transformers/models/xcodec/configuration_xcodec.py
@@ -103,7 +103,6 @@ class XcodecConfig(PretrainedConfig):
         "acoustic_model_config": DacConfig,
         "semantic_model_config": HubertConfig,
     }
-    is_composition = True
 
     def __init__(
         self,
diff --git a/src/transformers/models/xcodec/modeling_xcodec.py b/src/transformers/models/xcodec/modeling_xcodec.py
index 5534ec0d0275..a62d0d2952bf 100644
--- a/src/transformers/models/xcodec/modeling_xcodec.py
+++ b/src/transformers/models/xcodec/modeling_xcodec.py
@@ -23,20 +23,11 @@
 import torch.nn.functional as F
 
 from ...modeling_utils import PreTrainedModel
-from ...utils import (
-    ModelOutput,
-    add_start_docstrings,
-    add_start_docstrings_to_model_forward,
-    replace_return_docstrings,
-)
+from ...utils import ModelOutput, auto_docstring
 from ..auto import AutoModel
 from .configuration_xcodec import XcodecConfig
 
 
-# General docstring
-_CONFIG_FOR_DOC = "XcodecConfig"
-
-
 @dataclass
 class XcodecOutput(ModelOutput):
     """
@@ -320,6 +311,7 @@ def decode(self, codes: torch.Tensor) -> torch.Tensor:
         return quantized_out
 
 
+@auto_docstring
 class XcodecPreTrainedModel(PreTrainedModel):
     """
     An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
@@ -383,40 +375,7 @@ def remove_weight_norm(self):
                 torch.nn.utils.parametrize.remove_parametrizations(m, "weight", leave_parametrized=True)
 
 
-XCODEC_START_DOCSTRING = r"""
-    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
-    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
-    etc.)
-
-    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
-    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
-    and behavior.
-
-    Parameters:
-        config ([`XcodecConfig`]):
-            Model configuration class with all the parameters of the model. Initializing with a config file does not
-            load the weights associated with the model, only the configuration. Check out the
-            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
-""" - -XCODEC_INPUTS_DOCSTRING = r""" - args: - input_values (`torch.FloatTensor` of shape `(batch_size, channels, num_samples)`): - The raw float values of the input audio waveform. - audio_codes (`torch.LongTensor` of shape `(batch_size, num_quantizers, codes_length)`: - Discrete code indices computed using `model.encode`. - bandwidth (`float`, *optional*): - The target bandwidth in (kbps) supports only values in `config.target_bandwidths`. - Defaults to the highest available bandwidth `4.0` kbps. - return_dict (`bool`, *optional*): - whether to return a `XcodecOutput` or a plain tuple. -""" - - -@add_start_docstrings( - "The Xcodec neural audio codec model.", - XCODEC_START_DOCSTRING, -) +@auto_docstring(custom_intro="""The Xcodec neural audio codec model.""") class XcodecModel(XcodecPreTrainedModel): def __init__(self, config): super().__init__(config) @@ -458,6 +417,7 @@ def _extract_semantic_features(self, input_values: torch.FloatTensor) -> torch.F stacked = torch.stack(hidden_states, dim=1) return stacked.mean(dim=1) + @auto_docstring def encode( self, input_values: torch.Tensor, @@ -475,7 +435,7 @@ def encode( The target bandwidth in (kbps) supports only values in `config.target_bandwidths`. Defaults to the highest available bandwidth `4.0` kbps. return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + Whether or not to return a [`~utils.ModelOutput`]. Returns: `torch.LongTensor` of shape `(batch_size, num_quantizers, codes_length)` containing the discrete encoded audio codes. @@ -517,6 +477,7 @@ def encode( return XcodecEncoderOutput(audio_codes) + @auto_docstring def decode( self, audio_codes: torch.Tensor, return_dict: Optional[bool] = None, **kwargs ) -> Union[torch.Tensor, XcodecDecoderOutput]: @@ -530,7 +491,7 @@ def decode( Discrete code indices computed using `model.encode`. return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + Whether or not to return a [`~utils.ModelOutput`] Returns: Decoded audio values of shape `(batch_size, channels, num_samples)` obtained using the decoder part of Xcodec. @@ -554,8 +515,7 @@ def decode( return XcodecDecoderOutput(audio_values) - @add_start_docstrings_to_model_forward(XCODEC_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=XcodecOutput, config_class=_CONFIG_FOR_DOC) + @auto_docstring def forward( self, input_values: torch.Tensor, @@ -565,9 +525,24 @@ def forward( **kwargs, ) -> Union[tuple[torch.Tensor, torch.Tensor], XcodecOutput]: r""" + Encodes and quantizes the input audio into discrete codes, then decodes those codes back into an audio waveform. + Args: + input_values (`torch.FloatTensor` of shape `(batch_size, channels, num_samples)`): + The raw float values of the input audio waveform. + audio_codes (`torch.LongTensor` of shape `(batch_size, num_quantizers, codes_length)`: + Discrete code indices computed using `model.encode`. + bandwidth (`float`, *optional*): + Target bandwidth in kbps. Must be one of `config.target_bandwidths`. + Defaults to the highest available bandwidth. + return_dict (`bool`, *optional*): + Whether to return a [`XcodecOutput`] instead of a plain tuple. + Returns: + `XcodecOutput` or tuple `(audio_codes, audio_values)`: + - `audio_codes` of shape `(batch_size, num_quantizers, codes_length)`: the quantized discrete codes. + - `audio_values` of shape `(batch_size, channels, num_samples)`: the reconstructed audio waveform given the codes. 
- Examples: + Example: ```python >>> from datasets import load_dataset diff --git a/tests/models/xcodec/test_modeling_xcodec.py b/tests/models/xcodec/test_modeling_xcodec.py index d2e464d90837..0d2a2aade593 100644 --- a/tests/models/xcodec/test_modeling_xcodec.py +++ b/tests/models/xcodec/test_modeling_xcodec.py @@ -107,6 +107,7 @@ class XcodecModelTest(ModelTesterMixin, unittest.TestCase): test_headmasking = False test_resize_embeddings = False test_torchscript = False + test_can_init_all_missing_weights = False def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): # model does not support returning hidden states @@ -155,6 +156,10 @@ def test_gradient_checkpointing_backward_compatibility(self): model = model_class(config) self.assertTrue(model.is_gradient_checkpointing) + @unittest.skip("XcodecModel cannot be tested with meta device") + def test_can_load_with_meta_device_context_manager(self): + pass + @unittest.skip(reason="We cannot configure to output a smaller model.") def test_model_is_small(self): pass
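Side note on the loading path introduced in PATCH 07/10 (not part of the patch itself): the conversion script now instantiates the model under `torch.device("meta")` and loads the converted weights with `load_state_dict(..., assign=True)`, so no real parameter storage is allocated before the checkpoint tensors are adopted. A minimal, self-contained sketch of that pattern, using `nn.Linear` as a hypothetical stand-in for `XcodecModel`:

```python
import torch
import torch.nn as nn

# Build the module on the meta device: parameters carry shapes and dtypes but no storage,
# so construction is cheap and no random initialization is thrown away.
with torch.device("meta"):
    model = nn.Linear(4, 4)  # stand-in for XcodecModel(config)

# Stand-in for the converted checkpoint (in the real script, the remapped state dict).
state_dict = {"weight": torch.randn(4, 4), "bias": torch.zeros(4)}

# assign=True (PyTorch >= 2.1) swaps the checkpoint tensors in directly instead of
# copying into the existing meta parameters, which would fail for storage-less tensors.
model.load_state_dict(state_dict, strict=True, assign=True)

print(model.weight.device)  # cpu -- the parameters now are the checkpoint tensors
```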