
RuntimeError: 0 <= device.index() && device.index() < static_cast<c10::DeviceIndex>(device_ready_queues_.size()) INTERNAL ASSERT FAILED at "/build/pytorch/torch/csrc/autograd/engine.cpp":1418 #571

Open

SoldierWz

Describe the bug

When this problem first occurred, I tried disabling the CPU core and was then able to run, but the results were very poor: accuracy dropped sharply and training took much longer. I reported that in issue #565. When I restored the CPU core, the error above appeared.
Here is the part of the code where the problem occurs.
device = 'xpu'
for train_idx, test_idx in kf.split(X_tensor):
    X_train, X_test = X_tensor[train_idx], X_tensor[test_idx]
    y_train, y_test = y_tensor[train_idx], y_tensor[test_idx]

    train_dataset = CustomDataset(X_train, y_train)
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

    model = MLP(X_train.shape[1])
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    model = model.to("xpu")
    criterion = criterion.to("xpu")
    model, optimizer = ipex.optimize(model, optimizer=optimizer)
    for epoch in range(1000):
        model.train()
        for features, labels in train_loader:
            features, labels = features.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(features)
            loss = criterion(outputs, labels)
            loss.backward()  # <-- the RuntimeError is raised here
            optimizer.step()

Versions

wget https://github.com/raw/intel/intel-extension-for-pytorch/master/scripts/collect_env.py

For security purposes, please check the contents of collect_env.py before running it.

python collect_env.py

Activity

jgong5 commented on Mar 26, 2024

May I know what you mean by "disable CPU core"? It sounds like no GPU was found according to the error message. But we should report more meaningful error messages. cc @gujinghui
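
A quick way to confirm whether the XPU device is visible to PyTorch at all (which the error message suggests it is not) is to query the xpu runtime right after the imports. A minimal sketch, assuming an XPU-enabled build of intel_extension_for_pytorch that exposes the torch.xpu namespace:

import torch
import intel_extension_for_pytorch as ipex  # registers the 'xpu' backend with torch

# These calls assume an XPU build of intel_extension_for_pytorch; if no GPU is
# visible to the runtime, device_count() should report 0, and moving tensors to
# "xpu" or running backward() on them can fail with internal errors like the one above.
print("xpu available:", torch.xpu.is_available())
print("xpu device count:", torch.xpu.device_count())
if torch.xpu.is_available():
    print("device name:", torch.xpu.get_device_name(0))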

SoldierWz (Author) commented on Mar 27, 2024

> May I know what you mean by "disable CPU core"? It sounds like no GPU was found according to the error message. But we should report more meaningful error messages. cc @gujinghui

I edited the GRUB configuration file and changed
GRUB_CMDLINE_LINUX_DEFAULT="quiet splash"
to
GRUB_CMDLINE_LINUX_DEFAULT="nohz=off"
There is another line, GRUB_CMDLINE_LINUX="i915.enable_hangcheck=0", which I did not change.
After editing it like this, the GPU can be used.
But I just tried again and a new problem occurred; the error is reported below.

ImportError Traceback (most recent call last)
Cell In[2], line 8
6 import modin.pandas as pd
7 import numpy as np
----> 8 import torch
9 import intel_extension_for_pytorch as ipex
10 import torch.nn as nn

File ~/mambaforge/envs/pytorch-arc/lib/python3.11/site-packages/torch/__init__.py:235
233 if USE_GLOBAL_DEPS:
234 _load_global_deps()
--> 235 from torch._C import * # noqa: F403
237 # Appease the type checker; ordinarily this binding is inserted by the
238 # torch._C module initialization code in C
239 if TYPE_CHECKING:

ImportError: /home/wangzhen/mambaforge/envs/pytorch-arc/lib/python3.11/site-packages/torch/lib/libtorch_cpu.so: undefined symbol: iJIT_NotifyEvent

amontse commented on Jul 11, 2024

I have an observation: import intel_extension_for_pytorch for xpu immediately after import torch and before any invocation of torch methods, even if you are not going to do anything with xpu yet.

I encountered the same INTERNAL ASSERTION at loss.backward() if I did the following:

import torch
# INTERNAL ASSERTION at loss.backward() if the import below is not invoked
# import intel_extension_for_pytorch as ipex
…
data = …
model = …
loss_fn = …
optimizer = …
# training using cpu
for epoch in range(epochs):
    …
    loss.backward()
    …

# repeat the work above but on xpu
import intel_extension_for_pytorch as ipex
data = data.to('xpu')
model = model.to('xpu')
loss_fn = loss_fn.to('xpu')
model, optimizer = ipex.optimize(model, optimizer=optimizer)
# training using xpu
for epoch in range(epochs):
    …
    # INTERNAL ASSERTION at loss.backward() if import intel_extension_for_pytorch is not invoked immediately after import torch
    loss.backward()
    …

The problem is solved if the commented-out import statement near the top of the above code is uncommented.
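
To make the suggested ordering concrete, here is a minimal sketch of the workaround described above, assuming an XPU-enabled build of intel_extension_for_pytorch; the toy model and data are placeholders, not the original code:

# Import intel_extension_for_pytorch immediately after torch and before any
# torch call, even though the CPU phase runs first and xpu is used only later.
import torch
import intel_extension_for_pytorch as ipex

import torch.nn as nn
import torch.optim as optim

# Hypothetical toy model and data, only to keep the sketch self-contained.
model = nn.Linear(8, 2)
data = torch.randn(32, 8)
target = torch.randint(0, 2, (32,))
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Training pass on CPU first.
loss = loss_fn(model(data), target)
loss.backward()
optimizer.step()

# Then the same work on xpu; per the observation above, importing ipex before
# any torch work is what avoids the INTERNAL ASSERT in loss.backward().
if torch.xpu.is_available():
    data, target = data.to('xpu'), target.to('xpu')
    model = model.to('xpu')
    model, optimizer = ipex.optimize(model, optimizer=optimizer)
    optimizer.zero_grad()
    loss = loss_fn(model(data), target)
    loss.backward()
    optimizer.step()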


Metadata

Labels

ARC GPU

Participants

@jgong5 @gujinghui @amontse @ZhaoqiongZ @SoldierWz