
CI fails on a few test_training_gradient_checkpointing tests for LLAMA #34722

@dvrogozh

Description

With:

On:

  • Nvidia A10
  • Intel Max Series

The following 3 tests are failing:

  • tests/models/llama/test_modeling_llama.py::LlamaModelTest::test_training_gradient_checkpointing
  • tests/models/llama/test_modeling_llama.py::LlamaModelTest::test_training_gradient_checkpointing_use_reentrant
  • tests/models/llama/test_modeling_llama.py::LlamaModelTest::test_training_gradient_checkpointing_use_reentrant_false

See the failure log below. This is a regression introduced by commit 19d58d3 (Add MLLama, PR #33703), as identified by git bisect:

$ git bisect log
git bisect start
# bad: [33eef992503689ba1af98090e26d3e98865b2a9b] Agents: Small fixes in streaming to gradio + add tests (#34549)
git bisect bad 33eef992503689ba1af98090e26d3e98865b2a9b
# good: [984bc11b0882ff1e5b34ba717ea357e069ceced9] Revert "fixes to properly shard FSDP across cpu and meta for cpu_effcient_loading for prequantized 4bit (#32276)" (#32477)
git bisect good 984bc11b0882ff1e5b34ba717ea357e069ceced9
# good: [80b90e7b2f7466ffb1d9036986e0699880d34284] Add codestral mamba2 (#32080)
git bisect good 80b90e7b2f7466ffb1d9036986e0699880d34284
# bad: [1ec7a70fef4158ab1ed660cba5126c8cde08c7e8] fix trainer tr_loss add error (#33651)
git bisect bad 1ec7a70fef4158ab1ed660cba5126c8cde08c7e8
# good: [7ed9789e210d8eca797fc21b9c783b1ce718ecb5] Fix: `num_logits_to_keep` in composite models (#33168)
git bisect good 7ed9789e210d8eca797fc21b9c783b1ce718ecb5
# good: [bcf8946f0acb578c534b1d33d534450d1fc88507] Fix number of patch check for different vision feature select strategy (#32494)
git bisect good bcf8946f0acb578c534b1d33d534450d1fc88507
# good: [dc8b6eaeeeb59dd3089b478cc09b577f2c62a297] Fix contrastive search to correctly handle input with padding (#33507)
git bisect good dc8b6eaeeeb59dd3089b478cc09b577f2c62a297
# good: [fa0bb0fe762c757203565a940c6e59a8d27537c4] Fix ByteLevel alphabet missing when Sequence pretokenizer is used (#33556)
git bisect good fa0bb0fe762c757203565a940c6e59a8d27537c4
# bad: [19d58d31f19049e8280ccb62a5b098d89909bf5a] Add MLLama (#33703)
git bisect bad 19d58d31f19049e8280ccb62a5b098d89909bf5a
# good: [7e638ef2b8650aaa3e3a8e575bb63af262a43d95] fix code quality after merge
git bisect good 7e638ef2b8650aaa3e3a8e575bb63af262a43d95
# good: [61e98cb957862d679c4a338319a386da197b8073] Add SDPA support for M2M100 (#33309)
git bisect good 61e98cb957862d679c4a338319a386da197b8073
# good: [ade9e0fe41a414c6a24a03a79c15798db609a6c9] Corrected max number for bf16 in transformer/docs (#33658)
git bisect good ade9e0fe41a414c6a24a03a79c15798db609a6c9
# good: [94f18cf23c128055a984ffbe9c57df133c1f6cc7] Add OmDet-Turbo (#31843)
git bisect good 94f18cf23c128055a984ffbe9c57df133c1f6cc7
# first bad commit: [19d58d31f19049e8280ccb62a5b098d89909bf5a] Add MLLama (#33703)

Log for one of the failures (others are similar):

$ python3 -m pytest --pspec tests/models/llama/test_modeling_llama.py::LlamaModelTest::test_training_gradient_checkpointing
============================================================================================ test session starts ============================================================================================
platform linux -- Python 3.10.12, pytest-7.4.4, pluggy-1.5.0
rootdir: /home/dvrogozh/git/huggingface/transformers
configfile: pyproject.toml
plugins: hypothesis-6.111.1, subtests-0.13.1, rich-0.1.1, dash-2.17.1, xdist-3.6.1, pspec-0.0.4, timeout-2.3.1
collected 1 item

tests/models/llama/test_modeling_llama.py                                                                                                                                                                    
Llama Model Test
 » training gradient checkpointing
                                                                                                                                                                                                      [100%]

================================================================================================== ERRORS ===================================================================================================
_________________________________________________________________ ERROR at teardown of LlamaModelTest.test_training_gradient_checkpointing __________________________________________________________________

self = <tests.models.llama.test_modeling_llama.LlamaModelTest testMethod=test_training_gradient_checkpointing>

    def test_training_gradient_checkpointing(self):
        # Scenario - 1 default behaviour
>       self.check_training_gradient_checkpointing()

tests/test_modeling_common.py:899:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = <tests.models.llama.test_modeling_llama.LlamaModelTest testMethod=test_training_gradient_checkpointing>, gradient_checkpointing_kwargs = None

    def check_training_gradient_checkpointing(self, gradient_checkpointing_kwargs=None):
        if not self.model_tester.is_training:
            self.skipTest(reason="ModelTester is not configured to run training tests")

        for model_class in self.all_model_classes:
            with self.subTest(model_class.__name__):
                if (
                    model_class.__name__
                    in [
                        *get_values(MODEL_MAPPING_NAMES),
                        *get_values(MODEL_FOR_BACKBONE_MAPPING_NAMES),
                    ]
                    or not model_class.supports_gradient_checkpointing
                ):
                    self.skipTest(reason=f"`supports_gradient_checkpointing` is False for {model_class.__name__}.")

            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
            config.use_cache = False
            config.return_dict = True
            model = model_class(config)

            model.to(torch_device)
            model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)
            model.train()

            # unfreeze additional layers
            for p in model.parameters():
                p.requires_grad_(True)

            optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

            inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
>           loss = model(**inputs).loss
E           AttributeError: 'BaseModelOutputWithPast' object has no attribute 'loss'

tests/test_modeling_common.py:868: AttributeError
============================================================================================= warnings summary ==============================================================================================
../../../../../usr/lib/python3.10/distutils/command/build_ext.py:13
  /usr/lib/python3.10/distutils/command/build_ext.py:13: DeprecationWarning: The distutils.sysconfig module is deprecated, use sysconfig instead
    from distutils.sysconfig import customize_compiler, get_python_version

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
========================================================================================== short test summary info ==========================================================================================
ERROR tests/models/llama/test_modeling_llama.py::LlamaModelTest::test_training_gradient_checkpointing - AttributeError: 'BaseModelOutputWithPast' object has no attribute 'loss'
================================================================================== 1 skipped, 1 warning, 1 error in 3.91s ===================================================================================
(pytorch.cuda) dvrogozh@cg-cyp-03:~/git/huggingface/transformers$
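
For context on the error type (this only illustrates which output classes carry a loss attribute, not what the MLLama commit changed): the bare LlamaModel returns a BaseModelOutputWithPast, which has no loss field; only head models such as LlamaForCausalLM compute and return a loss. A minimal sketch with a tiny, arbitrary config (not the test's config) that shows the same AttributeError outside the test suite:

import torch
from transformers import LlamaConfig, LlamaModel, LlamaForCausalLM

# Tiny config purely for illustration; values are arbitrary.
config = LlamaConfig(
    vocab_size=128, hidden_size=32, intermediate_size=64,
    num_hidden_layers=2, num_attention_heads=4,
)
input_ids = torch.randint(0, config.vocab_size, (1, 8))

base = LlamaModel(config)                        # bare model, no LM head
out = base(input_ids=input_ids)
print(type(out).__name__)                        # BaseModelOutputWithPast
# out.loss                                       # would raise AttributeError, as in the log above

lm = LlamaForCausalLM(config)                    # head model returns a loss when labels are given
out = lm(input_ids=input_ids, labels=input_ids)
print(type(out).__name__)                        # CausalLMOutputWithPast
print(out.loss is not None)                      # True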

CC: @amyeroberts, @ArthurZucker
