diff --git a/tests/lora/test_lora_functions.py b/tests/lora/test_lora_functions.py
index 204624a0540a..7ae33a848a0a 100644
--- a/tests/lora/test_lora_functions.py
+++ b/tests/lora/test_lora_functions.py
@@ -69,7 +69,7 @@ def run_check(fn, args, expected: list):
     run_check(llm.add_lora, make_lora_request(12), [12, 9, 10, 11])
     run_check(llm.add_lora, make_lora_request(13), [12, 13, 10, 11])
 
-    # Remove all LoRAs
+    # Remove all LoRAs.
     run_check(llm.remove_lora, 13, [12, 10, 11])
     run_check(llm.remove_lora, 12, [10, 11])
     run_check(llm.remove_lora, 11, [10])
diff --git a/tests/v1/sample/test_topk_topp_sampler.py b/tests/v1/sample/test_topk_topp_sampler.py
index a8a713d446b7..220f05c7ff1c 100644
--- a/tests/v1/sample/test_topk_topp_sampler.py
+++ b/tests/v1/sample/test_topk_topp_sampler.py
@@ -16,31 +16,40 @@
 FLASHINFER_ENABLED = current_platform.is_cuda() and is_flashinfer_available
 
 
+@pytest.fixture(autouse=True)
+def reset_default_device():
+    """
+    Tests here explicitly set the default device, which can affect subsequent
+    tests. This fixture restores the original default device after each test.
+    """
+    original_device = torch.get_default_device()
+    yield
+    torch.set_default_device(original_device)
+
+
 def test_topk_impl_equivalance():
 
-    with torch.device(DEVICE):
-        generator = Generator(device=DEVICE).manual_seed(33)
+    torch.set_default_device(DEVICE)
+    generator = Generator(device=DEVICE).manual_seed(33)
 
-        logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator)
+    logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator)
 
-        # Random top-k values between 1 and 9.
-        k = torch.randint(1, 10, (BATCH_SIZE, ), generator=generator)
+    # Random top-k values between 1 and 9.
+    k = torch.randint(1, 10, (BATCH_SIZE, ), generator=generator)
 
-        # Set k=vocab_size for ~50% of requests in the batch (top-k disabled).
-        k.masked_fill_(
-            torch.randint(0,
-                          2, (BATCH_SIZE, ),
-                          generator=generator,
-                          dtype=bool), VOCAB_SIZE)
+    # Set k=vocab_size for ~50% of requests in the batch (top-k disabled).
+    k.masked_fill_(
+        torch.randint(0, 2, (BATCH_SIZE, ), generator=generator, dtype=bool),
+        VOCAB_SIZE)
 
-        # Top-k only implementation
-        result1 = apply_top_k_top_p(logits=logits.clone(), k=k, p=None)
+    # Top-k only implementation
+    result1 = apply_top_k_top_p(logits=logits.clone(), k=k, p=None)
 
-        # Top-p + top-k
-        no_op_top_p = torch.tensor([1.0])
-        result2 = apply_top_k_top_p(logits=logits.clone(), k=k, p=no_op_top_p)
+    # Top-p + top-k
+    no_op_top_p = torch.tensor([1.0])
+    result2 = apply_top_k_top_p(logits=logits.clone(), k=k, p=no_op_top_p)
 
-        assert torch.allclose(result1, result2)
+    assert torch.allclose(result1, result2)
 
 
 def test_flashinfer_sampler():
@@ -58,50 +67,49 @@ def test_flashinfer_sampler():
         pytest.skip(
             "FlashInfer not installed or not available on this platform.")
 
-    with torch.device(DEVICE):
-        generator = Generator(device=DEVICE).manual_seed(42)
-
-        # Generate random logits
-        logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator)
-
-        # Generate various top-k and top-p values
-        k_values = torch.randint(1, 1000, (BATCH_SIZE, ), generator=generator)
-        p_values = torch.rand(
-            (BATCH_SIZE, ),
-            generator=generator) * 0.5 + 0.5  # range in [0.5, 1.0]
-
-        # Sometimes disable top-k (k=vocab_size)
-        k_values.masked_fill_(
-            torch.randint(0,
-                          2, (BATCH_SIZE, ),
-                          generator=generator,
-                          dtype=torch.bool), VOCAB_SIZE)
-
-        # Sometimes disable top-p (p=1.0)
-        p_values.masked_fill_(
-            torch.randint(0,
-                          2, (BATCH_SIZE, ),
-                          generator=generator,
-                          dtype=torch.bool), 1.0)
-
-        python_logits = apply_top_k_top_p(
-            logits=logits.clone(),
-            k=k_values,
-            p=p_values,
-        )
-        python_probs = torch.softmax(python_logits, dim=-1)
-
-        # FlashInfer only exposed renorm interfaces for probs so convert first
-        flashinfer_probs = torch.softmax(logits.clone(), dim=-1)
-        flashinfer_probs = top_k_renorm_probs(
-            probs=flashinfer_probs,
-            top_k=k_values,
-        )
-        flashinfer_probs = top_p_renorm_probs(
-            probs=flashinfer_probs,
-            top_p=p_values,
-        )
-
-        # Compare the results
-        assert torch.allclose(python_probs, flashinfer_probs, atol=2e-2), \
-            "FlashInfer and Python sampling implementations do not match!"
+    torch.set_default_device(DEVICE)
+    generator = Generator(device=DEVICE).manual_seed(42)
+
+    # Generate random logits
+    logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator)
+
+    # Generate various top-k and top-p values
+    k_values = torch.randint(1, 1000, (BATCH_SIZE, ), generator=generator)
+    p_values = torch.rand(
+        (BATCH_SIZE, ), generator=generator) * 0.5 + 0.5  # range in [0.5, 1.0]
+
+    # Sometimes disable top-k (k=vocab_size)
+    k_values.masked_fill_(
+        torch.randint(0,
+                      2, (BATCH_SIZE, ),
+                      generator=generator,
+                      dtype=torch.bool), VOCAB_SIZE)
+
+    # Sometimes disable top-p (p=1.0)
+    p_values.masked_fill_(
+        torch.randint(0,
+                      2, (BATCH_SIZE, ),
+                      generator=generator,
+                      dtype=torch.bool), 1.0)
+
+    python_logits = apply_top_k_top_p(
+        logits=logits.clone(),
+        k=k_values,
+        p=p_values,
+    )
+    python_probs = torch.softmax(python_logits, dim=-1)
+
+    # FlashInfer only exposed renorm interfaces for probs so convert first
+    flashinfer_probs = torch.softmax(logits.clone(), dim=-1)
+    flashinfer_probs = top_k_renorm_probs(
+        probs=flashinfer_probs,
+        top_k=k_values,
+    )
+    flashinfer_probs = top_p_renorm_probs(
+        probs=flashinfer_probs,
+        top_p=p_values,
+    )
+
+    # Compare the results
+    assert torch.allclose(python_probs, flashinfer_probs, atol=2e-2), \
+        "FlashInfer and Python sampling implementations do not match!"
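
For reference, here is a minimal standalone sketch (not part of the diff above) of the save/restore pattern the new `reset_default_device` fixture relies on: an autouse pytest fixture snapshots the global default device before each test and restores it afterwards, so a `torch.set_default_device(...)` call inside one test cannot leak into later tests. It assumes PyTorch 2.3+ (for `torch.get_default_device`) and pytest; the test name and the `"cpu"` device string below are placeholders, not taken from the change itself.

```python
# Illustrative sketch only; assumes PyTorch >= 2.3 and pytest.
import pytest
import torch


@pytest.fixture(autouse=True)
def reset_default_device():
    # Snapshot the process-wide default device before the test body runs.
    original_device = torch.get_default_device()
    yield
    # Restore it afterwards so per-test set_default_device() calls don't leak.
    torch.set_default_device(original_device)


def test_default_device_is_used():  # hypothetical test name
    torch.set_default_device("cpu")  # placeholder device string
    # New tensors are created on the current default device.
    assert torch.empty(1).device.type == "cpu"
```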