Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/build.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,4 @@ jobs:
CLIENT_SECRET: ${{ secrets.CLIENT_SECRET }}
CLIENT_CREDENTIALS_URL: ${{ secrets.CLIENT_CREDENTIALS_URL }}
run: |
python tests/integration.py
python -m unittest
15 changes: 2 additions & 13 deletions examples/create_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,27 +16,16 @@

from argparse import ArgumentParser
import json
import time
from urllib.request import HTTPError
from railib import api, config, show
from railib.api import EngineSize


# Answers if the given state is a terminal state.
def is_term_state(state: str) -> bool:
return state == "PROVISIONED" or ("FAILED" in state)


def run(engine: str, size: str, profile: str):
cfg = config.read(profile=profile)
ctx = api.Context(**cfg)
rsp = api.create_engine(ctx, engine, EngineSize(size))
while True: # wait for request to reach terminal state
time.sleep(3)
rsp = api.get_engine(ctx, engine)
if is_term_state(rsp["state"]):
break
print(json.dumps(rsp, indent=2))
api.create_engine_wait(ctx, engine, EngineSize(size))
print(json.dumps(api.get_engine(ctx, engine), indent=2))


if __name__ == "__main__":
Expand Down
67 changes: 58 additions & 9 deletions railib/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,42 @@ def _parse_arrow_results(files: List[TransactionAsyncFile]):
results.append({"relationId": file.name, "table": table})
return results

# polling with specified overhead
# delay is the overhead % of the time the transaction has been running so far


def poll_with_specified_overhead(
f,
overhead_rate: float,
start_time: int = time.time(),
timeout: int = None,
max_tries: int = None,
max_delay: int = 120,
):
tries = 0
max_time = time.time() + timeout if timeout else None

while True:
if f():
break

if max_tries is not None and tries >= max_tries:
raise Exception(f'max tries {max_tries} exhausted')

if max_time is not None and time.time() >= max_time:
raise Exception(f'timed out after {timeout} seconds')

tries += 1
duration = min((time.time() - start_time) * overhead_rate, max_delay)
if tries == 1:
time.sleep(0.5)
else:
time.sleep(duration)


def is_engine_term_state(state: str) -> bool:
return state == "PROVISIONED" or ("FAILED" in state)


def create_engine(ctx: Context, engine: str, size: EngineSize = EngineSize.XS, **kwargs):
data = {"region": ctx.region, "name": engine, "size": size.value}
Expand All @@ -333,6 +369,16 @@ def create_engine(ctx: Context, engine: str, size: EngineSize = EngineSize.XS, *
return json.loads(rsp.read())


def create_engine_wait(ctx: Context, engine: str, size: EngineSize = EngineSize.XS, **kwargs):
create_engine(ctx, engine, size, **kwargs)
poll_with_specified_overhead(
lambda: is_engine_term_state(get_engine(ctx, engine)["state"]),
overhead_rate=0.2,
timeout=30 * 60,
)
return get_engine(ctx, engine)


def create_user(ctx: Context, email: str, roles: List[Role] = None, **kwargs):
rs = roles or []
data = {"email": email, "roles": [r.value for r in rs]}
Expand Down Expand Up @@ -836,6 +882,7 @@ def exec(
readonly: bool = True,
**kwargs
) -> TransactionAsyncResponse:
start_time = time.time()
txn = exec_async(ctx, database, engine, command, inputs=inputs, readonly=readonly)
# in case of if short-path, return results directly, no need to poll for
# state
Expand All @@ -844,15 +891,17 @@ def exec(

rsp = TransactionAsyncResponse()
txn = get_transaction(ctx, txn.transaction["id"], **kwargs)
while True:
time.sleep(1)
txn = get_transaction(ctx, txn["id"], **kwargs)
if is_txn_term_state(txn["state"]):
rsp.transaction = txn
rsp.metadata = get_transaction_metadata(ctx, txn["id"], **kwargs)
rsp.problems = get_transaction_problems(ctx, txn["id"], **kwargs)
rsp.results = get_transaction_results(ctx, txn["id"], **kwargs)
break

poll_with_specified_overhead(
lambda: is_txn_term_state(get_transaction(ctx, txn["id"], **kwargs)["state"]),
overhead_rate=0.2,
start_time=start_time,
)

rsp.transaction = get_transaction(ctx, txn["id"], **kwargs)
rsp.metadata = get_transaction_metadata(ctx, txn["id"], **kwargs)
rsp.problems = get_transaction_problems(ctx, txn["id"], **kwargs)
rsp.results = get_transaction_results(ctx, txn["id"], **kwargs)

return rsp

Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
ed25519==1.5
grpcio-tools==1.47.0
protobuf==3.20.1
protobuf==3.20.2
pyarrow==6.0.1
requests-toolbelt==0.9.1

2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
"ed25519==1.5",
"pyarrow>=6.0.1",
"requests-toolbelt==0.9.1",
"protobuf==3.20.1"],
"protobuf==3.20.2"],
license="http://www.apache.org/licenses/LICENSE-2.0",
long_description="Enables access to the RelationalAI REST APIs from Python",
long_description_content_type="text/markdown",
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
23 changes: 4 additions & 19 deletions tests/integration.py → test/test_integration.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import json
from time import sleep
import unittest
import os
import uuid
Expand All @@ -8,22 +7,6 @@
from pathlib import Path
from railib import api, config

# TODO: create_engine_wait should be added to API
# with exponential backoff


def create_engine_wait(ctx: api.Context, engine: str):
state = api.create_engine(ctx, engine, headers=custom_headers)["compute"]["state"]

count = 0
while not ("PROVISIONED" == state):
if count > 12:
return

count += 1
sleep(30)
state = api.get_engine(ctx, engine)["state"]


# Get creds from env vars if exists
client_id = os.getenv("CLIENT_ID")
Expand Down Expand Up @@ -58,8 +41,10 @@ def create_engine_wait(ctx: api.Context, engine: str):

class TestTransactionAsync(unittest.TestCase):
def setUp(self):
create_engine_wait(ctx, engine)
api.create_database(ctx, dbname)
rsp = api.create_engine_wait(ctx, engine, headers=custom_headers)
self.assertEqual("PROVISIONED", rsp["state"])
rsp = api.create_database(ctx, dbname)
self.assertEqual("CREATED", rsp["database"]["state"])

def test_v2_exec(self):
cmd = "x, x^2, x^3, x^4 from x in {1; 2; 3; 4; 5}"
Expand Down
27 changes: 27 additions & 0 deletions test/test_unit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import unittest

from railib import api


class TestPolling(unittest.TestCase):
def test_timeout_exception(self):
try:
api.poll_with_specified_overhead(lambda: False, overhead_rate=0.1, timeout=1)
except Exception as e:
self.assertEqual('timed out after 1 seconds', str(e))

def test_max_tries_exception(self):
try:
api.poll_with_specified_overhead(lambda: False, overhead_rate=0.1, max_tries=1)
except Exception as e:
self.assertEqual('max tries 1 exhausted', str(e))

def test_validation(self):
api.poll_with_specified_overhead(lambda: True, overhead_rate=0.1)
api.poll_with_specified_overhead(lambda: True, overhead_rate=0.1, timeout=1)
api.poll_with_specified_overhead(lambda: True, overhead_rate=0.1, max_tries=1)
api.poll_with_specified_overhead(lambda: True, overhead_rate=0.1, timeout=1, max_tries=1)


if __name__ == '__main__':
unittest.main()