
[Bug]: vLLM engine crashes then restarts and loads the model on sleep if a chat request is made #15483

@erdaltoprak

Description


Your current environment

The output of `python collect_env.py`
INFO 03-25 16:14:33 [__init__.py:256] Automatically detected platform cuda.
Collecting environment information...
PyTorch version: 2.6.0+cu124
Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.4 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.31.6
Libc version: glibc-2.35

Python version: 3.12.9 (main, Mar 17 2025, 21:01:58) [Clang 20.1.0 ] (64-bit runtime)
Python platform: Linux-6.1.0-30-amd64-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.4.131
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 4090
Nvidia driver version: 570.86.10
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                         x86_64
CPU op-mode(s):                       32-bit, 64-bit
Address sizes:                        48 bits physical, 48 bits virtual
Byte Order:                           Little Endian
CPU(s):                               20
On-line CPU(s) list:                  0-19
Vendor ID:                            AuthenticAMD
Model name:                           AMD Ryzen 9 5950X 16-Core Processor
CPU family:                           25
Model:                                33
Thread(s) per core:                   1
Core(s) per socket:                   20
Socket(s):                            1
Stepping:                             2
BogoMIPS:                             6800.38
Flags:                                fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm rep_good nopl cpuid extd_apicid tsc_known_freq pni pclmulqdq ssse3 fma cx16 sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy svm cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw perfctr_core ssbd ibrs ibpb stibp vmmcall fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves clzero xsaveerptr wbnoinvd arat npt lbrv nrip_save tsc_scale vmcb_clean flushbyasid pausefilter pfthreshold v_vmsave_vmload vgif umip pku ospke vaes vpclmulqdq rdpid fsrm arch_capabilities
Virtualization:                       AMD-V
Hypervisor vendor:                    KVM
Virtualization type:                  full
L1d cache:                            1.3 MiB (20 instances)
L1i cache:                            1.3 MiB (20 instances)
L2 cache:                             10 MiB (20 instances)
L3 cache:                             320 MiB (20 instances)
NUMA node(s):                         1
NUMA node0 CPU(s):                    0-19
Vulnerability Gather data sampling:   Not affected
Vulnerability Itlb multihit:          Not affected
Vulnerability L1tf:                   Not affected
Vulnerability Mds:                    Not affected
Vulnerability Meltdown:               Not affected
Vulnerability Mmio stale data:        Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed:               Not affected
Vulnerability Spec rstack overflow:   Mitigation; safe RET, no microcode
Vulnerability Spec store bypass:      Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:             Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:             Mitigation; Retpolines; IBPB conditional; IBRS_FW; STIBP disabled; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds:                  Not affected
Vulnerability Tsx async abort:        Not affected

Versions of relevant libraries:
[pip3] flashinfer-python==0.2.1.post2+cu124torch2.6
[pip3] numpy==1.26.4
[pip3] nvidia-cublas-cu12==12.4.5.8
[pip3] nvidia-cuda-cupti-cu12==12.4.127
[pip3] nvidia-cuda-nvrtc-cu12==12.4.127
[pip3] nvidia-cuda-runtime-cu12==12.4.127
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.2.1.3
[pip3] nvidia-curand-cu12==10.3.5.147
[pip3] nvidia-cusolver-cu12==11.6.1.9
[pip3] nvidia-cusparse-cu12==12.3.1.170
[pip3] nvidia-cusparselt-cu12==0.6.2
[pip3] nvidia-nccl-cu12==2.21.5
[pip3] nvidia-nvjitlink-cu12==12.4.127
[pip3] nvidia-nvtx-cu12==12.4.127
[pip3] pyzmq==26.3.0
[pip3] torch==2.6.0
[pip3] torchaudio==2.6.0
[pip3] torchvision==0.21.0
[pip3] transformers==4.49.0
[pip3] triton==3.2.0
[conda] Could not collect
ROCM Version: Could not collect
Neuron SDK Version: N/A
vLLM Version: 0.8.1
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
GPU0    CPU Affinity    NUMA Affinity   GPU NUMA ID
GPU0     X      0-19    0               N/A

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

LD_LIBRARY_PATH=/opt/venv/lib/python3.12/site-packages/cv2/../../lib64:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
VLLM_ALLOW_LONG_MAX_MODEL_LEN=1
CUDA_VERSION=12.4.0
NVIDIA_REQUIRE_CUDA=cuda>=12.4 brand=tesla,driver>=470,driver<471 brand=unknown,driver>=470,driver<471 brand=nvidia,driver>=470,driver<471 brand=nvidiartx,driver>=470,driver<471 brand=geforce,driver>=470,driver<471 brand=geforcertx,driver>=470,driver<471 brand=quadro,driver>=470,driver<471 brand=quadrortx,driver>=470,driver<471 brand=titan,driver>=470,driver<471 brand=titanrtx,driver>=470,driver<471 brand=tesla,driver>=525,driver<526 brand=unknown,driver>=525,driver<526 brand=nvidia,driver>=525,driver<526 brand=nvidiartx,driver>=525,driver<526 brand=geforce,driver>=525,driver<526 brand=geforcertx,driver>=525,driver<526 brand=quadro,driver>=525,driver<526 brand=quadrortx,driver>=525,driver<526 brand=titan,driver>=525,driver<526 brand=titanrtx,driver>=525,driver<526 brand=tesla,driver>=535,driver<536 brand=unknown,driver>=535,driver<536 brand=nvidia,driver>=535,driver<536 brand=nvidiartx,driver>=535,driver<536 brand=geforce,driver>=535,driver<536 brand=geforcertx,driver>=535,driver<536 brand=quadro,driver>=535,driver<536 brand=quadrortx,driver>=535,driver<536 brand=titan,driver>=535,driver<536 brand=titanrtx,driver>=535,driver<536
NVIDIA_DRIVER_CAPABILITIES=compute,utility
NVIDIA_PRODUCT_NAME=CUDA
VLLM_USAGE_SOURCE=production-docker-image
VLLM_USE_V1=0
NVIDIA_DISABLE_REQUIRE=1
NVIDIA_VISIBLE_DEVICES=all
NCCL_VERSION=2.20.5-1
VLLM_SERVER_DEV_MODE=1
NCCL_CUMEM_ENABLE=0
TORCHINDUCTOR_COMPILE_THREADS=1
CUDA_MODULE_LOADING=LAZY

🐛 Describe the bug

🐛 Bug: Server crashes when making completion requests to a sleeping model

When using vLLM's sleep mode functionality through the /sleep, /wake_up, and /is_sleeping endpoints, the server crashes with a CUDA error if a completion request is made while the model is in sleep mode.

Steps to Reproduce

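  1. Start a vLLM server with sleep mode enabled (`--enable-sleep-mode`; the /sleep, /wake_up, and /is_sleeping endpoints are only exposed when `VLLM_SERVER_DEV_MODE=1`, as in the environment above)
  2. Put the model to sleep with a POST request to /sleep
  3. Send a completion request to /v1/chat/completions or /v1/completions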

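The engine immediately crashes with a CUDA illegal memory access error, and the server process exits ("MQLLMEngine is already dead, terminating server process"), so it has to be restarted.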
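The steps above can be scripted with the Python standard library. This is a sketch, not part of the original report. It assumes the server listens on vLLM's default port 8000, that /sleep accepts an optional `level` query parameter, and that /is_sleeping is a GET endpoint. `MODEL_NAME` is a hypothetical placeholder that must be replaced with the served model name.

```python
import json
import urllib.error
import urllib.request

# Assumptions (not from the report): vLLM's default port and a placeholder model name.
BASE_URL = "http://localhost:8000"
MODEL_NAME = "your-served-model-name"  # hypothetical placeholder, replace before running


def build_repro_requests(base_url: str, model: str) -> list[urllib.request.Request]:
    """Return the requests for steps 2 and 3, in the order they should be sent."""
    # Step 2: put the model to sleep (level=1 is assumed to be the default level).
    sleep_req = urllib.request.Request(f"{base_url}/sleep?level=1", method="POST")
    # Sanity check: the engine should now report that it is sleeping.
    check_req = urllib.request.Request(f"{base_url}/is_sleeping", method="GET")
    # Step 3: an ordinary completion request sent while the model is asleep.
    payload = {"model": model, "prompt": "Hello", "max_tokens": 8}
    completion_req = urllib.request.Request(
        f"{base_url}/v1/completions",
        data=json.dumps(payload).encode(),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    return [sleep_req, check_req, completion_req]


def run_repro(base_url: str = BASE_URL, model: str = MODEL_NAME) -> None:
    """Send the repro requests one by one and print what the server answers."""
    for req in build_repro_requests(base_url, model):
        try:
            with urllib.request.urlopen(req, timeout=120) as resp:
                print(req.get_method(), req.full_url, resp.status, resp.read()[:200])
        except urllib.error.HTTPError as err:
            # In the reported log, the completion request returns HTTP 500.
            print(req.get_method(), req.full_url, "HTTP", err.code)
        except urllib.error.URLError as err:
            # Once the engine dies the server exits, so later calls may be refused.
            print(req.get_method(), req.full_url, "unreachable:", err.reason)
```

Calling `run_repro()` against a running server sends the three requests in order; building the requests separately from sending them keeps the request shapes easy to inspect.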

Expected Behavior

When making a completion request to a sleeping model, the server should either:

  • Automatically and gracefully wake up the model first
  • Return a proper error response indicating the model is sleeping and needs to be woken up
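Both expected behaviors can be expressed as one small decision function. This is a minimal sketch that assumes an engine client exposing `is_sleeping()` and `wake_up()` coroutines, the two calls used in the proposed solution below. `SleepGuard`, `SleepAwareClient`, and `sleeping_error` are hypothetical names. For the auto wake-up option, it adds an `asyncio.Lock` so that a burst of requests arriving while the model sleeps triggers a single `wake_up()` call instead of one per request.

```python
import asyncio
from typing import Optional, Protocol


class SleepAwareClient(Protocol):
    """The two engine-client coroutines used in the proposed solution."""

    async def is_sleeping(self) -> bool: ...

    async def wake_up(self) -> None: ...


def sleeping_error() -> dict:
    """Error body for a request that arrives while the model is asleep."""
    return {
        "error": {
            "message": "Model is currently in sleep mode. "
                       "Please wake it up first with a POST request to /wake_up",
            "type": "ModelSleepingError",
            "code": 503,
        }
    }


class SleepGuard:
    """Decides what happens to a generation request while the engine may be asleep.

    auto_wake=True  -> wake the engine (once), then let the request proceed.
    auto_wake=False -> reject the request with a 503 error payload.
    """

    def __init__(self, client: SleepAwareClient, auto_wake: bool) -> None:
        self.client = client
        self.auto_wake = auto_wake
        # Serializes wake-ups so concurrent requests trigger a single wake_up() call
        self._wake_lock = asyncio.Lock()

    async def check(self) -> Optional[tuple[int, dict]]:
        """Return None if the request may proceed, else (status_code, error_body)."""
        if not await self.client.is_sleeping():
            return None
        if not self.auto_wake:
            return 503, sleeping_error()
        async with self._wake_lock:
            # Re-check under the lock: an earlier request may already have woken it
            if await self.client.is_sleeping():
                await self.client.wake_up()
        return None
```

The re-check inside the lock matters: every request that saw the sleeping state queues on the lock, and only the first one should call `wake_up()`.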

Actual Behavior

The server crashes with an error like:

CUDA error: an illegal memory access was encountered
RuntimeError: CUDA error: an illegal memory access was encountered

Full crash log

vllm  | INFO 03-25 16:01:40 [engine.py:310] Added request cmpl-0750d012ef7140d1b80c8768ae5ef193-0.
vllm  | CRITICAL 03-25 16:01:40 [launcher.py:116] MQLLMEngine is already dead, terminating server process
vllm  | INFO:     172.18.0.1:33248 - "POST /v1/completions HTTP/1.1" 500 Internal Server Error
vllm  | ERROR 03-25 16:01:40 [engine.py:160] RuntimeError('CUDA error: an illegal memory access was encountered\nCUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.\nFor debugging consider passing CUDA_LAUNCH_BLOCKING=1\nCompile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.\n')
vllm  | ERROR 03-25 16:01:40 [engine.py:160] Traceback (most recent call last):
vllm  | ERROR 03-25 16:01:40 [engine.py:160]   File "/opt/venv/lib/python3.12/site-packages/vllm/engine/multiprocessing/engine.py", line 158, in start
vllm  | ERROR 03-25 16:01:40 [engine.py:160]     self.run_engine_loop()
vllm  | ERROR 03-25 16:01:40 [engine.py:160]   File "/opt/venv/lib/python3.12/site-packages/vllm/engine/multiprocessing/engine.py", line 221, in run_engine_loop
vllm  | ERROR 03-25 16:01:40 [engine.py:160]     request_outputs = self.engine_step()
vllm  | ERROR 03-25 16:01:40 [engine.py:160]                       ^^^^^^^^^^^^^^^^^^
vllm  | ERROR 03-25 16:01:40 [engine.py:160]   File "/opt/venv/lib/python3.12/site-packages/vllm/engine/multiprocessing/engine.py", line 247, in engine_step
vllm  | ERROR 03-25 16:01:40 [engine.py:160]     raise e
vllm  | ERROR 03-25 16:01:40 [engine.py:160]   File "/opt/venv/lib/python3.12/site-packages/vllm/engine/multiprocessing/engine.py", line 230, in engine_step
vllm  | ERROR 03-25 16:01:40 [engine.py:160]     return self.engine.step()
vllm  | ERROR 03-25 16:01:40 [engine.py:160]            ^^^^^^^^^^^^^^^^^^
vllm  | ERROR 03-25 16:01:40 [engine.py:160]   File "/opt/venv/lib/python3.12/site-packages/vllm/engine/llm_engine.py", line 1435, in step
vllm  | ERROR 03-25 16:01:40 [engine.py:160]     outputs = self.model_executor.execute_model(
vllm  | ERROR 03-25 16:01:40 [engine.py:160]               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
vllm  | ERROR 03-25 16:01:40 [engine.py:160]   File "/opt/venv/lib/python3.12/site-packages/vllm/executor/executor_base.py", line 139, in execute_model
vllm  | ERROR 03-25 16:01:40 [engine.py:160]     output = self.collective_rpc("execute_model",
vllm  | ERROR 03-25 16:01:40 [engine.py:160]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
vllm  | ERROR 03-25 16:01:40 [engine.py:160]   File "/opt/venv/lib/python3.12/site-packages/vllm/executor/uniproc_executor.py", line 56, in collective_rpc
vllm  | ERROR 03-25 16:01:40 [engine.py:160]     answer = run_method(self.driver_worker, method, args, kwargs)
vllm  | ERROR 03-25 16:01:40 [engine.py:160]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
vllm  | ERROR 03-25 16:01:40 [engine.py:160]   File "/opt/venv/lib/python3.12/site-packages/vllm/utils.py", line 2216, in run_method
vllm  | ERROR 03-25 16:01:40 [engine.py:160]     return func(*args, **kwargs)
vllm  | ERROR 03-25 16:01:40 [engine.py:160]            ^^^^^^^^^^^^^^^^^^^^^
vllm  | ERROR 03-25 16:01:40 [engine.py:160]   File "/opt/venv/lib/python3.12/site-packages/vllm/worker/worker_base.py", line 420, in execute_model
vllm  | ERROR 03-25 16:01:40 [engine.py:160]     output = self.model_runner.execute_model(
vllm  | ERROR 03-25 16:01:40 [engine.py:160]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
vllm  | ERROR 03-25 16:01:40 [engine.py:160]   File "/opt/venv/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
vllm  | ERROR 03-25 16:01:40 [engine.py:160]     return func(*args, **kwargs)
vllm  | ERROR 03-25 16:01:40 [engine.py:160]            ^^^^^^^^^^^^^^^^^^^^^
vllm  | ERROR 03-25 16:01:40 [engine.py:160]   File "/opt/venv/lib/python3.12/site-packages/vllm/worker/model_runner.py", line 1742, in execute_model
vllm  | ERROR 03-25 16:01:40 [engine.py:160]     hidden_or_intermediate_states = model_executable(
vllm  | ERROR 03-25 16:01:40 [engine.py:160]                                     ^^^^^^^^^^^^^^^^^
vllm  | ERROR 03-25 16:01:40 [engine.py:160]   File "/opt/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
vllm  | ERROR 03-25 16:01:40 [engine.py:160]     return self._call_impl(*args, **kwargs)
vllm  | ERROR 03-25 16:01:40 [engine.py:160]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
vllm  | ERROR 03-25 16:01:40 [engine.py:160]   File "/opt/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
vllm  | ERROR 03-25 16:01:40 [engine.py:160]     return forward_call(*args, **kwargs)
vllm  | ERROR 03-25 16:01:40 [engine.py:160]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
vllm  | ERROR 03-25 16:01:40 [engine.py:160]   File "/opt/venv/lib/python3.12/site-packages/vllm/model_executor/models/qwen2_5_vl.py", line 1071, in forward
vllm  | ERROR 03-25 16:01:40 [engine.py:160]     hidden_states = self.language_model.model(
vllm  | ERROR 03-25 16:01:40 [engine.py:160]                     ^^^^^^^^^^^^^^^^^^^^^^^^^^
vllm  | ERROR 03-25 16:01:40 [engine.py:160]   File "/opt/venv/lib/python3.12/site-packages/vllm/compilation/decorators.py", line 172, in __call__
vllm  | ERROR 03-25 16:01:40 [engine.py:160]     return self.forward(*args, **kwargs)
vllm  | ERROR 03-25 16:01:40 [engine.py:160]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
vllm  | ERROR 03-25 16:01:40 [engine.py:160]   File "/opt/venv/lib/python3.12/site-packages/vllm/model_executor/models/qwen2.py", line 338, in forward
vllm  | ERROR 03-25 16:01:40 [engine.py:160]     hidden_states, residual = layer(
vllm  | ERROR 03-25 16:01:40 [engine.py:160]                               ^^^^^^
vllm  | ERROR 03-25 16:01:40 [engine.py:160]   File "/opt/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
vllm  | ERROR 03-25 16:01:40 [engine.py:160]     return self._call_impl(*args, **kwargs)
vllm  | ERROR 03-25 16:01:40 [engine.py:160]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
vllm  | ERROR 03-25 16:01:40 [engine.py:160]   File "/opt/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
vllm  | ERROR 03-25 16:01:40 [engine.py:160]     return forward_call(*args, **kwargs)
vllm  | ERROR 03-25 16:01:40 [engine.py:160]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
vllm  | ERROR 03-25 16:01:40 [engine.py:160]   File "/opt/venv/lib/python3.12/site-packages/vllm/model_executor/models/qwen2.py", line 243, in forward
vllm  | ERROR 03-25 16:01:40 [engine.py:160]     hidden_states = self.self_attn(
vllm  | ERROR 03-25 16:01:40 [engine.py:160]                     ^^^^^^^^^^^^^^^
vllm  | ERROR 03-25 16:01:40 [engine.py:160]   File "/opt/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
vllm  | ERROR 03-25 16:01:40 [engine.py:160]     return self._call_impl(*args, **kwargs)
vllm  | ERROR 03-25 16:01:40 [engine.py:160]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
vllm  | ERROR 03-25 16:01:40 [engine.py:160]   File "/opt/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
vllm  | ERROR 03-25 16:01:40 [engine.py:160]     return forward_call(*args, **kwargs)
vllm  | ERROR 03-25 16:01:40 [engine.py:160]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
vllm  | ERROR 03-25 16:01:40 [engine.py:160]   File "/opt/venv/lib/python3.12/site-packages/vllm/model_executor/models/qwen2.py", line 177, in forward
vllm  | ERROR 03-25 16:01:40 [engine.py:160]     attn_output = self.attn(q, k, v)
vllm  | ERROR 03-25 16:01:40 [engine.py:160]                   ^^^^^^^^^^^^^^^^^^
vllm  | ERROR 03-25 16:01:40 [engine.py:160]   File "/opt/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
vllm  | ERROR 03-25 16:01:40 [engine.py:160]     return self._call_impl(*args, **kwargs)
vllm  | ERROR 03-25 16:01:40 [engine.py:160]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
vllm  | ERROR 03-25 16:01:40 [engine.py:160]   File "/opt/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
vllm  | ERROR 03-25 16:01:40 [engine.py:160]     return forward_call(*args, **kwargs)
vllm  | ERROR 03-25 16:01:40 [engine.py:160]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
vllm  | ERROR 03-25 16:01:40 [engine.py:160]   File "/opt/venv/lib/python3.12/site-packages/vllm/attention/layer.py", line 214, in forward
vllm  | ERROR 03-25 16:01:40 [engine.py:160]     torch.ops.vllm.unified_attention_with_output(
vllm  | ERROR 03-25 16:01:40 [engine.py:160]   File "/opt/venv/lib/python3.12/site-packages/torch/_ops.py", line 1123, in __call__
vllm  | ERROR 03-25 16:01:40 [engine.py:160]     return self._op(*args, **(kwargs or {}))
vllm  | ERROR 03-25 16:01:40 [engine.py:160]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
vllm  | ERROR 03-25 16:01:40 [engine.py:160]   File "/opt/venv/lib/python3.12/site-packages/vllm/attention/layer.py", line 363, in unified_attention_with_output
vllm  | ERROR 03-25 16:01:40 [engine.py:160]     self.impl.forward(self,
vllm  | ERROR 03-25 16:01:40 [engine.py:160]   File "/opt/venv/lib/python3.12/site-packages/vllm/attention/backends/flash_attn.py", line 756, in forward
vllm  | ERROR 03-25 16:01:40 [engine.py:160]     flash_attn_varlen_func(
vllm  | ERROR 03-25 16:01:40 [engine.py:160]   File "/opt/venv/lib/python3.12/site-packages/vllm/vllm_flash_attn/flash_attn_interface.py", line 173, in flash_attn_varlen_func
vllm  | ERROR 03-25 16:01:40 [engine.py:160]     out, softmax_lse = torch.ops._vllm_fa2_C.varlen_fwd(
vllm  | ERROR 03-25 16:01:40 [engine.py:160]                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
vllm  | ERROR 03-25 16:01:40 [engine.py:160]   File "/opt/venv/lib/python3.12/site-packages/torch/_ops.py", line 1123, in __call__
vllm  | ERROR 03-25 16:01:40 [engine.py:160]     return self._op(*args, **(kwargs or {}))
vllm  | ERROR 03-25 16:01:40 [engine.py:160]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
vllm  | ERROR 03-25 16:01:40 [engine.py:160] RuntimeError: CUDA error: an illegal memory access was encountered
vllm  | ERROR 03-25 16:01:40 [engine.py:160] CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
vllm  | ERROR 03-25 16:01:40 [engine.py:160] For debugging consider passing CUDA_LAUNCH_BLOCKING=1
vllm  | ERROR 03-25 16:01:40 [engine.py:160] Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
vllm  | ERROR 03-25 16:01:40 [engine.py:160] 
vllm  | INFO:     Shutting down
vllm  | INFO:     Waiting for application shutdown.
vllm  | INFO:     Application shutdown complete.
vllm  | INFO:     Finished server process [1]
vllm exited with code 0

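This appears to happen because the completion endpoints do not check the sleep state: the request is scheduled as usual, and the forward pass (here, the FlashAttention kernel) reads GPU memory whose contents were offloaded or discarded when the model was put to sleep.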

Proposed solution based on my current understanding

@router.post("/v1/chat/completions", dependencies=[Depends(validate_json_request)])
@with_cancellation
@load_aware_call
async def create_chat_completion(request: ChatCompletionRequest, raw_request: Request):
    # First check whether the engine is sleeping: its weights and KV cache are
    # not resident on the GPU, so running the model now would crash the engine.
    client = engine_client(raw_request)
    try:
        is_sleeping = await client.is_sleeping()
    except Exception as e:
        # If the sleep check itself fails, log it and continue with request processing
        logger.warning("Failed to check if model is sleeping: %s", e)
        is_sleeping = False

    if is_sleeping:
        # Option 1: Auto wake-up
        # await client.wake_up()
        # logger.info("Model was sleeping; automatically waking up before processing completion request")

        # Option 2: Return an error response (503 Service Unavailable)
        return JSONResponse(
            content={
                "error": {
                    "message": "Model is currently in sleep mode. Please wake it up first with a POST request to /wake_up",
                    "type": "ModelSleepingError",
                    "code": 503
                }
            },
            status_code=503
        )

    # Original handler code continues...
    handler = chat(raw_request)
    if handler is None:
        return base(raw_request).create_error_response(
            message="The model does not support Chat Completions API")
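Instead of adding the check to each handler, the same guard could be applied once to every generation route. This is a sketch under stated assumptions. It is a plain ASGI middleware with no vLLM or FastAPI imports, and `is_sleeping` is a hypothetical zero-argument coroutine wrapping the engine client's `is_sleeping()`. The guarded path set covers only the two endpoints named in this report. Like the proposal above, it lets the request through if the sleep check itself fails.

```python
import json
import logging
from typing import Any, Awaitable, Callable, MutableMapping

logger = logging.getLogger(__name__)

# Hypothetical signature: a zero-argument coroutine returning the engine's sleep state.
IsSleepingFn = Callable[[], Awaitable[bool]]
Scope = MutableMapping[str, Any]

# Assumed set of routes that run the model; extend as needed.
GENERATION_PATHS = {"/v1/completions", "/v1/chat/completions"}


class SleepGuardMiddleware:
    """ASGI middleware that answers 503 to generation requests while the model sleeps."""

    def __init__(self, app: Callable[..., Awaitable[None]], is_sleeping: IsSleepingFn) -> None:
        self.app = app
        self.is_sleeping = is_sleeping

    async def __call__(self, scope: Scope, receive, send) -> None:
        guarded = (
            scope["type"] == "http"
            and scope.get("method") == "POST"
            and scope.get("path") in GENERATION_PATHS
        )
        asleep = False
        if guarded:
            try:
                asleep = await self.is_sleeping()
            except Exception as exc:
                # Mirror the proposal: log the failed check and let the request through
                logger.warning("Failed to check if model is sleeping: %s", exc)
        if not asleep:
            await self.app(scope, receive, send)
            return
        body = json.dumps({
            "error": {
                "message": "Model is currently in sleep mode. "
                           "Please wake it up first with a POST request to /wake_up",
                "type": "ModelSleepingError",
                "code": 503,
            }
        }).encode()
        await send({
            "type": "http.response.start",
            "status": 503,
            "headers": [
                (b"content-type", b"application/json"),
                (b"content-length", str(len(body)).encode()),
            ],
        })
        await send({"type": "http.response.body", "body": body})
```

With FastAPI/Starlette, a class of this shape can be registered with `app.add_middleware(SleepGuardMiddleware, is_sleeping=...)`. Keeping the check out of the individual handlers is the main benefit. The cost is that any other model-running endpoint (for example /v1/embeddings) must be added to `GENERATION_PATHS` explicitly.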


Metadata

Assignees: No one assigned
Labels: bug (Something isn't working), unstale (Received activity after being labelled stale)
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests