Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/actions/test/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -61,5 +61,5 @@ runs:
CUSTOM_HEADERS: ${{ inputs.custom_headers }}
run: |
mkdir -p ~/.rai
python -m unittest
python -m unittest -v
shell: bash
17 changes: 11 additions & 6 deletions railib/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,27 +331,32 @@ def poll_with_specified_overhead(
max_tries: int = None,
max_delay: int = 120,
):
if overhead_rate < 0:
raise ValueError("overhead_rate must be non-negative")

if start_time is None:
start_time = time.time()

tries = 0
max_time = time.time() + timeout if timeout else None

while True:
if f():
break

current_time = time.time()

if max_tries is not None and tries >= max_tries:
raise Exception(f'max tries {max_tries} exhausted')

if max_time is not None and time.time() >= max_time:
if max_time is not None and current_time >= max_time:
raise Exception(f'timed out after {timeout} seconds')

duration = (current_time - start_time) * overhead_rate
duration = min(duration, max_delay)

time.sleep(duration)
tries += 1
duration = min((time.time() - start_time) * overhead_rate, max_delay)
if tries == 1:
time.sleep(0.5)
else:
time.sleep(duration)


def is_engine_term_state(state: str) -> bool:
Expand Down
43 changes: 43 additions & 0 deletions test/test_unit.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import socket
import time
import unittest
from unittest.mock import patch, MagicMock
from urllib.error import URLError
Expand Down Expand Up @@ -27,6 +28,48 @@ def test_validation(self):
api.poll_with_specified_overhead(lambda: True, overhead_rate=0.1, max_tries=1)
api.poll_with_specified_overhead(lambda: True, overhead_rate=0.1, timeout=1, max_tries=1)

@patch('time.sleep', return_value=None)
@patch('time.time')
def test_initial_delay(self, mock_time, mock_sleep):
start_time = 100 # Fixed start time
increment_time = 0.0001
mock_time.side_effect = [start_time, start_time + increment_time] # Simulate time progression

try:
api.poll_with_specified_overhead(lambda: False, overhead_rate=0.1, max_tries=2)
except Exception as e:
pass # Ignore the exception for this test

expected_sleep_time = (start_time + increment_time - start_time) * 0.1
mock_sleep.assert_called_with(expected_sleep_time)

def test_max_delay_cap(self):
with patch('time.sleep') as mock_sleep:
try:
api.poll_with_specified_overhead(lambda: False, overhead_rate=1, max_tries=2, start_time=time.time() - 200)
except Exception:
pass
# Ensure that the maximum delay of 120 seconds is not exceeded
mock_sleep.assert_called_with(120)

def test_polling_success(self):
with patch('time.sleep', return_value=None) as mock_sleep:
api.poll_with_specified_overhead(lambda: True, overhead_rate=0.1)
mock_sleep.assert_not_called()

def test_negative_overhead_rate(self):
with self.assertRaises(ValueError):
api.poll_with_specified_overhead(lambda: True, overhead_rate=-0.1)

def test_realistic_scenario(self):
responses = [False, False, True]
def side_effect():
return responses.pop(0)

with patch('time.sleep') as mock_sleep:
api.poll_with_specified_overhead(side_effect, overhead_rate=0.1, max_tries=3)
# Ensure that `time.sleep` was called twice (for two false returns)
self.assertEqual(mock_sleep.call_count, 2)

@patch('railib.rest.urlopen')
class TestURLOpenWithRetry(unittest.TestCase):
Expand Down