### Bug description
When running the training script below on an XLA (TPU) runtime, the log is flooded with repeated `torch_xla` deprecation warnings coming from Lightning's internals:

```
WARNING:root:torch_xla.core.xla_model.xla_model.get_ordinal() will be removed in release 2.7. is deprecated. Use torch_xla.runtime.global_ordinal instead.
WARNING:root:torch_xla.core.xla_model.xla_model.get_ordinal() will be removed in release 2.7. is deprecated. Use torch_xla.runtime.global_ordinal instead.
WARNING:root:torch_xla.core.xla_model.xrt_world_size() will be removed in release 2.7. is deprecated. Use torch_xla.runtime.world_size instead.
WARNING:root:torch_xla.core.xla_model.xla_model.get_ordinal() will be removed in release 2.7. is deprecated. Use torch_xla.runtime.global_ordinal instead.
WARNING:root:torch_xla.core.xla_model.xrt_world_size() will be removed in release 2.7. is deprecated. Use torch_xla.runtime.world_size instead.
WARNING:root:torch_xla.core.xla_model.xrt_world_size() will be removed in release 2.7. is deprecated. Use torch_xla.runtime.world_size instead.
WARNING:root:torch_xla.core.xla_model.xla_model.get_ordinal() will be removed in release 2.7. is deprecated. Use torch_xla.runtime.global_ordinal instead.
WARNING:root:torch_xla.core.xla_model.xla_model.get_ordinal() will be removed in release 2.7. is deprecated. Use torch_xla.runtime.global_ordinal instead.
WARNING:root:torch_xla.core.xla_model.xrt_world_size() will be removed in release 2.7. is deprecated. Use torch_xla.runtime.world_size instead.
WARNING:root:torch_xla.core.xla_model.xrt_world_size() will be removed in release 2.7. is deprecated. Use torch_xla.runtime.world_size instead.
```
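Until the deprecated calls are replaced upstream, the spam can be silenced on the user side. A minimal sketch using a stdlib `logging` filter, assuming the warnings are emitted through the root logger (they show up as `WARNING:root:`) and that no other wanted messages contain the same substrings:

```python
import logging


class TorchXlaDeprecationFilter(logging.Filter):
    """Drop the repetitive torch_xla deprecation warnings logged to the root logger."""

    def filter(self, record):
        msg = record.getMessage()
        # Keep every record except the known deprecation spam.
        return not (
            "torch_xla.core.xla_model" in msg
            and "will be removed in release 2.7" in msg
        )


# The warnings appear as WARNING:root:..., so attach the filter to the root logger.
logging.getLogger().addFilter(TorchXlaDeprecationFilter())
```

This is a workaround for the noise only; it does not change which APIs Lightning calls.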
### What version are you seeing the problem on?

v2.5

### Reproduced in studio

_No response_
### How to reproduce the bug
```python
import torch
from torch.utils.data import DataLoader, Dataset
import lightning as pl
from lightning import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
import os

# Secrets redacted from the original report.
os.environ["WANDB_API_KEY"] = "<redacted>"
os.environ["HF_TOKEN"] = "<redacted>"

torch.set_float32_matmul_precision('high')


class ChatDataset(Dataset):
    def __init__(self, dataset, tokenizer, max_length=1024):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        data = self.dataset[idx]
        chat = ("<|im_start|>user\n" + data["input"] + "<|im_end|>\n" +
                "<|im_start|>assistant\n<think>\n \n</think>\n" + data["output"] + "<|im_end|>\n")
        encoding = self.tokenizer(
            chat,
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt'
        )
        input_ids = encoding['input_ids'].squeeze()
        attention_mask = encoding['attention_mask'].squeeze()
        labels = input_ids.clone()
        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels
        }


class LanguageModelLightning(pl.LightningModule):
    def __init__(self, model_name, learning_rate=2e-5, weight_decay=0.01):
        super().__init__()
        self.save_hyperparameters()
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype="auto"
        )
        self.model.train()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay

    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )
        return outputs

    def training_step(self, batch, batch_idx):
        outputs = self.forward(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            labels=batch['labels']
        )
        loss = outputs.loss
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        outputs = self.forward(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            labels=batch['labels']
        )
        loss = outputs.loss
        self.log('val_loss', loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.learning_rate,
            weight_decay=self.weight_decay
        )
        return optimizer


dataset = load_dataset('intexcp/russian-llm-training-dataset')
model = LanguageModelLightning("Qwen/Qwen3-0.6B")
# model = torch.compile(model)

train_dataset = ChatDataset(dataset["train"], model.tokenizer)
val_dataset = ChatDataset(dataset["test"], model.tokenizer)

train_loader = DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
    num_workers=16,
    pin_memory=True
)
val_loader = DataLoader(
    val_dataset,
    batch_size=8,
    shuffle=False,
    num_workers=16,
    pin_memory=True
)

wandb_logger = WandbLogger(
    project="IGen",
    name="IGen"
)
checkpoint_callback = ModelCheckpoint(
    dirpath="IGen/checkpoints",
    filename='{epoch}-{val_loss:.2f}',
    monitor='val_loss',
    mode='min',
    save_top_k=1,
    save_last=True
)

trainer = Trainer(
    max_epochs=2,
    precision="bf16-true",
    accelerator="auto",
    strategy="auto",
    devices="auto",
    callbacks=[checkpoint_callback],
    check_val_every_n_epoch=1,
    log_every_n_steps=50,
    enable_model_summary=True,
    enable_progress_bar=True,
)
trainer.fit(model, train_loader, val_loader)

model.model.save_pretrained("IGen/final_model")
model.tokenizer.save_pretrained("IGen/final_model")
```
### Error messages and logs
```
WARNING:root:torch_xla.core.xla_model.xla_model.get_ordinal() will be removed in release 2.7. is deprecated. Use torch_xla.runtime.global_ordinal instead.
WARNING:root:torch_xla.core.xla_model.xla_model.get_ordinal() will be removed in release 2.7. is deprecated. Use torch_xla.runtime.global_ordinal instead.
WARNING:root:torch_xla.core.xla_model.xrt_world_size() will be removed in release 2.7. is deprecated. Use torch_xla.runtime.world_size instead.
WARNING:root:torch_xla.core.xla_model.xla_model.get_ordinal() will be removed in release 2.7. is deprecated. Use torch_xla.runtime.global_ordinal instead.
WARNING:root:torch_xla.core.xla_model.xrt_world_size() will be removed in release 2.7. is deprecated. Use torch_xla.runtime.world_size instead.
WARNING:root:torch_xla.core.xla_model.xrt_world_size() will be removed in release 2.7. is deprecated. Use torch_xla.runtime.world_size instead.
WARNING:root:torch_xla.core.xla_model.xla_model.get_ordinal() will be removed in release 2.7. is deprecated. Use torch_xla.runtime.global_ordinal instead.
WARNING:root:torch_xla.core.xla_model.xla_model.get_ordinal() will be removed in release 2.7. is deprecated. Use torch_xla.runtime.global_ordinal instead.
WARNING:root:torch_xla.core.xla_model.xrt_world_size() will be removed in release 2.7. is deprecated. Use torch_xla.runtime.world_size instead.
WARNING:root:torch_xla.core.xla_model.xrt_world_size() will be removed in release 2.7. is deprecated. Use torch_xla.runtime.world_size instead.
```
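The warning text itself names the replacement API. A hedged sketch of the migration it asks for, guarded with a fallback so it also runs where `torch_xla` is not installed (the `torch_xla.runtime` names are taken from the warning message, not verified against every release):

```python
# Deprecated pattern (slated for removal in torch_xla 2.7):
#   import torch_xla.core.xla_model as xm
#   ordinal, world = xm.get_ordinal(), xm.xrt_world_size()
try:
    import torch_xla.runtime as xr  # replacement API named in the warning

    ordinal, world = xr.global_ordinal(), xr.world_size()
except ImportError:
    # Fallback for environments without torch_xla: single-process defaults.
    ordinal, world = 0, 1
```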
### Environment
<details>
<summary>Current environment</summary>
#- PyTorch Lightning Version (e.g., 2.5.0):
#- PyTorch Version (e.g., 2.5):
#- Python version (e.g., 3.12):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning (conda, pip, source):
</details>
### More info
_No response_