From c6d3bfa219d435303651c7bc9188af780ed5cc5f Mon Sep 17 00:00:00 2001
From: "helmi.nour"
Date: Wed, 26 Oct 2022 19:49:48 +0100
Subject: [PATCH 1/8] add polling to python sdk

---
 examples/create_engine.py | 15 ++--------
 railib/api.py             | 62 +++++++++++++++++++++++++++++++++------
 requirements.txt          |  2 +-
 setup.py                  |  2 +-
 tests/integration.py      | 23 +++------------
 5 files changed, 61 insertions(+), 43 deletions(-)

diff --git a/examples/create_engine.py b/examples/create_engine.py
index f925102..456667c 100644
--- a/examples/create_engine.py
+++ b/examples/create_engine.py
@@ -16,27 +16,16 @@
 from argparse import ArgumentParser
 import json
-import time
 from urllib.request import HTTPError
 
 from railib import api, config, show
 from railib.api import EngineSize
 
 
-# Answers if the given state is a terminal state.
-def is_term_state(state: str) -> bool:
-    return state == "PROVISIONED" or ("FAILED" in state)
-
-
 def run(engine: str, size: str, profile: str):
     cfg = config.read(profile=profile)
     ctx = api.Context(**cfg)
-    rsp = api.create_engine(ctx, engine, EngineSize(size))
-    while True:  # wait for request to reach terminal state
-        time.sleep(3)
-        rsp = api.get_engine(ctx, engine)
-        if is_term_state(rsp["state"]):
-            break
-    print(json.dumps(rsp, indent=2))
+    api.create_engine_wait(ctx, engine, EngineSize(size))
+    print(json.dumps(api.get_engine(ctx, engine), indent=2))
 
 
 if __name__ == "__main__":
diff --git a/railib/api.py b/railib/api.py
index 5ffcc53..0e594cf 100644
--- a/railib/api.py
+++ b/railib/api.py
@@ -19,6 +19,7 @@
 import time
 import re
 import io
+from datetime import datetime, timezone
 from enum import Enum, unique
 from typing import List, Union
 from requests_toolbelt import multipart
@@ -325,6 +326,38 @@ def _parse_arrow_results(files: List[TransactionAsyncFile]):
         results.append({"relationId": file.name, "table": table})
     return results
 
+# polling with specified overhead
+# delay is the overhead % of the time the transaction has been running so far
+
+
+def poll_with_specified_overhead(
+    f,
+    overhead_rate: float = 0.1,
+    start_time: datetime = datetime.now(timezone.utc),
+    timeout: int = None,
+    max_retries: int = None
+):
+    retries = 0
+    max_time = time.time() + timeout if timeout else None
+
+    while True:
+        if max_retries is not None and retries >= max_retries:
+            raise Exception(f'max retries {max_retries} exhausted')
+
+        if max_time is not None and time.time() >= max_time:
+            raise Exception(f'timed out after {timeout} seconds')
+
+        if f():
+            break
+
+        retries += 1
+        duration = (datetime.now(timezone.utc) - start_time).total_seconds() * overhead_rate
+        time.sleep(duration)
+
+
+def is_engine_term_state(state: str) -> bool:
+    return state == "PROVISIONED" or ("FAILED" in state)
+
 
 def create_engine(ctx: Context, engine: str, size: EngineSize = EngineSize.XS, **kwargs):
     data = {"region": ctx.region, "name": engine, "size": size.value}
@@ -333,6 +366,15 @@ def create_engine(ctx: Context, engine: str, size: EngineSize = EngineSize.XS, *
     return json.loads(rsp.read())
 
 
+def create_engine_wait(ctx: Context, engine: str, size: EngineSize = EngineSize.XS, **kwargs):
+    create_engine(ctx, engine, size, **kwargs)
+    poll_with_specified_overhead(
+        lambda: is_engine_term_state(get_engine(ctx, engine)["state"]),
+        timeout=30 * 60,
+    )
+    return get_engine(ctx, engine)
+
+
 def create_user(ctx: Context, email: str, roles: List[Role] = None, **kwargs):
     rs = roles or []
     data = {"email": email, "roles": [r.value for r in rs]}
@@ -844,15 +886,17 @@ def exec(
     rsp = TransactionAsyncResponse()
     txn = get_transaction(ctx, txn.transaction["id"], **kwargs)
 
-    while True:
-        time.sleep(1)
-        txn = get_transaction(ctx, txn["id"], **kwargs)
-        if is_txn_term_state(txn["state"]):
-            rsp.transaction = txn
-            rsp.metadata = get_transaction_metadata(ctx, txn["id"], **kwargs)
-            rsp.problems = get_transaction_problems(ctx, txn["id"], **kwargs)
-            rsp.results = get_transaction_results(ctx, txn["id"], **kwargs)
-            break
+    start_time = datetime.fromtimestamp(txn["created_on"] / 1000, tz=timezone.utc)
+
+    poll_with_specified_overhead(
+        lambda: is_txn_term_state(get_transaction(ctx, txn["id"], **kwargs)["state"]),
+        start_time=start_time,
+    )
+
+    rsp.transaction = get_transaction(ctx, txn["id"], **kwargs)
+    rsp.metadata = get_transaction_metadata(ctx, txn["id"], **kwargs)
+    rsp.problems = get_transaction_problems(ctx, txn["id"], **kwargs)
+    rsp.results = get_transaction_results(ctx, txn["id"], **kwargs)
 
     return rsp
 
diff --git a/requirements.txt b/requirements.txt
index 744088b..788c556 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,6 @@
 ed25519==1.5
 grpcio-tools==1.47.0
-protobuf==3.20.1
+protobuf==3.20.2
 pyarrow==6.0.1
 requests-toolbelt==0.9.1
 
diff --git a/setup.py b/setup.py
index 0b10bde..e428a84 100644
--- a/setup.py
+++ b/setup.py
@@ -35,7 +35,7 @@
         "ed25519==1.5",
         "pyarrow>=6.0.1",
         "requests-toolbelt==0.9.1",
-        "protobuf==3.20.1"],
+        "protobuf==3.20.2"],
     license="http://www.apache.org/licenses/LICENSE-2.0",
     long_description="Enables access to the RelationalAI REST APIs from Python",
     long_description_content_type="text/markdown",
diff --git a/tests/integration.py b/tests/integration.py
index a882b38..bce80f1 100644
--- a/tests/integration.py
+++ b/tests/integration.py
@@ -1,5 +1,4 @@
 import json
-from time import sleep
 import unittest
 import os
 import uuid
@@ -8,22 +7,6 @@ from pathlib import Path
 from railib import api, config
 
 
-# TODO: create_engine_wait should be added to API
-# with exponential backoff
-
-
-def create_engine_wait(ctx: api.Context, engine: str):
-    state = api.create_engine(ctx, engine, headers=custom_headers)["compute"]["state"]
-
-    count = 0
-    while not ("PROVISIONED" == state):
-        if count > 12:
-            return
-
-        count += 1
-        sleep(30)
-        state = api.get_engine(ctx, engine)["state"]
-
 
 # Get creds from env vars if exists
 client_id = os.getenv("CLIENT_ID")
@@ -58,8 +41,10 @@
 
 class TestTransactionAsync(unittest.TestCase):
     def setUp(self):
-        create_engine_wait(ctx, engine)
-        api.create_database(ctx, dbname)
+        rsp = api.create_engine_wait(ctx, engine, headers=custom_headers)
+        self.assertEqual("PROVISIONED", rsp["state"])
+        rsp = api.create_database(ctx, dbname)
+        self.assertEqual("CREATED", rsp["database"]["state"])
 
     def test_v2_exec(self):
         cmd = "x, x^2, x^3, x^4 from x in {1; 2; 3; 4; 5}"

From 119af89f744fd660ad82d5104edb76327b138c1d Mon Sep 17 00:00:00 2001
From: "helmi.nour"
Date: Thu, 27 Oct 2022 02:20:58 +0100
Subject: [PATCH 2/8] addressing PR comments

---
 .github/workflows/build.yaml |  2 +-
 railib/api.py                | 10 +++++-----
 tests/unit_tests.py          | 26 ++++++++++++++++++++++++++
 3 files changed, 32 insertions(+), 6 deletions(-)
 create mode 100644 tests/unit_tests.py

diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml
index 6ffd9ed..f094c77 100644
--- a/.github/workflows/build.yaml
+++ b/.github/workflows/build.yaml
@@ -35,4 +35,4 @@ jobs:
         CLIENT_SECRET: ${{ secrets.CLIENT_SECRET }}
         CLIENT_CREDENTIALS_URL: ${{ secrets.CLIENT_CREDENTIALS_URL }}
       run: |
-        python tests/integration.py
+        python tests/*.py
diff --git a/railib/api.py b/railib/api.py
index 0e594cf..3cc9e73 100644
--- a/railib/api.py
+++ b/railib/api.py
@@ -335,14 +335,14 @@ def poll_with_specified_overhead(
     overhead_rate: float = 0.1,
     start_time: datetime = datetime.now(timezone.utc),
     timeout: int = None,
-    max_retries: int = None
+    max_tries: int = None
 ):
-    retries = 0
+    tries = 0
     max_time = time.time() + timeout if timeout else None
 
     while True:
-        if max_retries is not None and retries >= max_retries:
-            raise Exception(f'max retries {max_retries} exhausted')
+        if max_tries is not None and tries >= max_tries:
+            raise Exception(f'max tries {max_tries} exhausted')
 
         if max_time is not None and time.time() >= max_time:
             raise Exception(f'timed out after {timeout} seconds')
@@ -350,7 +350,7 @@ def poll_with_specified_overhead(
         if f():
             break
 
-        retries += 1
+        tries += 1
         duration = (datetime.now(timezone.utc) - start_time).total_seconds() * overhead_rate
         time.sleep(duration)
 
diff --git a/tests/unit_tests.py b/tests/unit_tests.py
new file mode 100644
index 0000000..c8d20b4
--- /dev/null
+++ b/tests/unit_tests.py
@@ -0,0 +1,26 @@
+import unittest
+
+from railib import api
+
+class TestPolling(unittest.TestCase):
+    def test_timeout_exception(self):
+        try:
+            api.poll_with_specified_overhead(lambda: False, timeout=1)
+        except Exception as e:
+            self.assertEqual('timed out after 1 seconds', str(e))
+
+    def test_max_tries_exception(self):
+        try:
+            api.poll_with_specified_overhead(lambda: False, max_tries=1)
+        except Exception as e:
+            self.assertEqual('max tries 1 exhausted', str(e))
+
+    def test_validation(self):
+        api.poll_with_specified_overhead(lambda: True)
+        api.poll_with_specified_overhead(lambda: True, timeout=1)
+        api.poll_with_specified_overhead(lambda: True, max_tries=1)
+        api.poll_with_specified_overhead(lambda: True, timeout=1, max_tries=1)
+
+
+if __name__ == '__main__':
+    unittest.main()
\ No newline at end of file

From f59a209e2ca37abd4234e426bd679f5572c148d3 Mon Sep 17 00:00:00 2001
From: "helmi.nour"
Date: Thu, 27 Oct 2022 02:22:46 +0100
Subject: [PATCH 3/8] fix linter

---
 tests/unit_tests.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/tests/unit_tests.py b/tests/unit_tests.py
index c8d20b4..3c38e23 100644
--- a/tests/unit_tests.py
+++ b/tests/unit_tests.py
@@ -2,6 +2,7 @@
 
 from railib import api
 
+
 class TestPolling(unittest.TestCase):
     def test_timeout_exception(self):
         try:
@@ -23,4 +24,4 @@ def test_validation(self):
 
 
 if __name__ == '__main__':
-    unittest.main()
\ No newline at end of file
+    unittest.main()

From 1dfef991889266c3453a3f2e023d3d647df8494c Mon Sep 17 00:00:00 2001
From: "helmi.nour"
Date: Thu, 27 Oct 2022 02:26:39 +0100
Subject: [PATCH 4/8] run all tests

---
 .github/workflows/build.yaml                     | 2 +-
 {tests => test}/__init__.py                      | 0
 {tests => test}/metadata.pb                      | 0
 {tests => test}/metadata.pb.txt                  | 0
 tests/integration.py => test/test_integration.py | 0
 tests/unit_tests.py => test/test_unit.py         | 0
 6 files changed, 1 insertion(+), 1 deletion(-)
 rename {tests => test}/__init__.py (100%)
 rename {tests => test}/metadata.pb (100%)
 rename {tests => test}/metadata.pb.txt (100%)
 rename tests/integration.py => test/test_integration.py (100%)
 rename tests/unit_tests.py => test/test_unit.py (100%)

diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml
index f094c77..754d91e 100644
--- a/.github/workflows/build.yaml
+++ b/.github/workflows/build.yaml
@@ -35,4 +35,4 @@ jobs:
         CLIENT_SECRET: ${{ secrets.CLIENT_SECRET }}
         CLIENT_CREDENTIALS_URL: ${{ secrets.CLIENT_CREDENTIALS_URL }}
       run: |
-        python tests/*.py
+        python -m unittest
diff --git a/tests/__init__.py b/test/__init__.py
similarity index 100%
rename from tests/__init__.py
rename to test/__init__.py
diff --git a/tests/metadata.pb b/test/metadata.pb
similarity index 100%
rename from tests/metadata.pb
rename to test/metadata.pb
diff --git a/tests/metadata.pb.txt b/test/metadata.pb.txt
similarity index 100%
rename from tests/metadata.pb.txt
rename to test/metadata.pb.txt
diff --git a/tests/integration.py b/test/test_integration.py
similarity index 100%
rename from tests/integration.py
rename to test/test_integration.py
diff --git a/tests/unit_tests.py b/test/test_unit.py
similarity index 100%
rename from tests/unit_tests.py
rename to test/test_unit.py

From 4b78b1fc8e4e8118eee313592f2e151d1c238613 Mon Sep 17 00:00:00 2001
From: "helmi.nour"
Date: Fri, 28 Oct 2022 17:37:38 +0100
Subject: [PATCH 5/8] updated polling policies

---
 railib/api.py     | 21 +++++++++++++--------
 test/test_unit.py | 12 ++++++------
 2 files changed, 19 insertions(+), 14 deletions(-)

diff --git a/railib/api.py b/railib/api.py
index 3cc9e73..f89ae89 100644
--- a/railib/api.py
+++ b/railib/api.py
@@ -332,27 +332,31 @@ def _parse_arrow_results(files: List[TransactionAsyncFile]):
 
 def poll_with_specified_overhead(
     f,
-    overhead_rate: float = 0.1,
-    start_time: datetime = datetime.now(timezone.utc),
+    overhead_rate: float,
+    start_time: int = time.time(),
     timeout: int = None,
-    max_tries: int = None
+    max_tries: int = None,
+    max_delay: int = 120,
 ):
     tries = 0
     max_time = time.time() + timeout if timeout else None
 
     while True:
+        if f():
+            break
+
         if max_tries is not None and tries >= max_tries:
             raise Exception(f'max tries {max_tries} exhausted')
 
         if max_time is not None and time.time() >= max_time:
             raise Exception(f'timed out after {timeout} seconds')
 
-        if f():
-            break
-
         tries += 1
-        duration = (datetime.now(timezone.utc) - start_time).total_seconds() * overhead_rate
-        time.sleep(duration)
+        duration = min((time.time() - start_time) * overhead_rate, max_delay)
+        if tries == 1:
+            time.sleep(0.5)
+        else:
+            time.sleep(duration)
 
 
 def is_engine_term_state(state: str) -> bool:
@@ -370,6 +374,7 @@ def create_engine_wait(ctx: Context, engine: str, size: EngineSize = EngineSize.
     create_engine(ctx, engine, size, **kwargs)
     poll_with_specified_overhead(
         lambda: is_engine_term_state(get_engine(ctx, engine)["state"]),
+        overhead_rate=0.2,
         timeout=30 * 60,
     )
     return get_engine(ctx, engine)
diff --git a/test/test_unit.py b/test/test_unit.py
index 3c38e23..cf885d2 100644
--- a/test/test_unit.py
+++ b/test/test_unit.py
@@ -6,21 +6,21 @@
 class TestPolling(unittest.TestCase):
     def test_timeout_exception(self):
         try:
-            api.poll_with_specified_overhead(lambda: False, timeout=1)
+            api.poll_with_specified_overhead(lambda: False, overhead_rate=0.1, timeout=1)
         except Exception as e:
             self.assertEqual('timed out after 1 seconds', str(e))
 
     def test_max_tries_exception(self):
         try:
-            api.poll_with_specified_overhead(lambda: False, max_tries=1)
+            api.poll_with_specified_overhead(lambda: False, overhead_rate=0.1, max_tries=1)
         except Exception as e:
             self.assertEqual('max tries 1 exhausted', str(e))
 
     def test_validation(self):
-        api.poll_with_specified_overhead(lambda: True)
-        api.poll_with_specified_overhead(lambda: True, timeout=1)
-        api.poll_with_specified_overhead(lambda: True, max_tries=1)
-        api.poll_with_specified_overhead(lambda: True, timeout=1, max_tries=1)
+        api.poll_with_specified_overhead(lambda: True, overhead_rate=0.1)
+        api.poll_with_specified_overhead(lambda: True, overhead_rate=0.1, timeout=1)
+        api.poll_with_specified_overhead(lambda: True, overhead_rate=0.1, max_tries=1)
+        api.poll_with_specified_overhead(lambda: True, overhead_rate=0.1, timeout=1, max_tries=1)
 
 
 if __name__ == '__main__':

From 6ffbe74b848a6072d1b8d1f5b07d1171744bff39 Mon Sep 17 00:00:00 2001
From: "helmi.nour"
Date: Fri, 28 Oct 2022 17:43:35 +0100
Subject: [PATCH 6/8] fix

---
 railib/api.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/railib/api.py b/railib/api.py
index f89ae89..5b9284f 100644
--- a/railib/api.py
+++ b/railib/api.py
@@ -895,6 +895,7 @@ def exec(
 
     poll_with_specified_overhead(
         lambda: is_txn_term_state(get_transaction(ctx, txn["id"], **kwargs)["state"]),
+        overhead_rate=0.2,
         start_time=start_time,
     )
 

From 31f30a45cb9e77f15e87024ca2f917cf619a6680 Mon Sep 17 00:00:00 2001
From: "helmi.nour"
Date: Fri, 28 Oct 2022 17:46:59 +0100
Subject: [PATCH 7/8] start_time is a timestamp

---
 railib/api.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/railib/api.py b/railib/api.py
index 5b9284f..dd5a07c 100644
--- a/railib/api.py
+++ b/railib/api.py
@@ -883,6 +883,7 @@ def exec(
     readonly: bool = True,
     **kwargs
 ) -> TransactionAsyncResponse:
+    start_time = time.time()
     txn = exec_async(ctx, database, engine, command, inputs=inputs, readonly=readonly)
     # in case of if short-path, return results directly, no need to poll for
     # state
@@ -891,7 +892,6 @@ def exec(
     rsp = TransactionAsyncResponse()
     txn = get_transaction(ctx, txn.transaction["id"], **kwargs)
-    start_time = datetime.fromtimestamp(txn["created_on"] / 1000, tz=timezone.utc)
 
     poll_with_specified_overhead(
         lambda: is_txn_term_state(get_transaction(ctx, txn["id"], **kwargs)["state"]),
         overhead_rate=0.2,

From 818310813fea9055aad1a476aed653ffc4f94734 Mon Sep 17 00:00:00 2001
From: "helmi.nour"
Date: Fri, 28 Oct 2022 17:47:21 +0100
Subject: [PATCH 8/8] cleanup

---
 railib/api.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/railib/api.py b/railib/api.py
index dd5a07c..378b53d 100644
--- a/railib/api.py
+++ b/railib/api.py
@@ -19,7 +19,6 @@
 import time
 import re
 import io
-from datetime import datetime, timezone
 from enum import Enum, unique
 from typing import List, Union
 from requests_toolbelt import multipart
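
Notes on the polling behavior introduced by this series (these notes and the
sketches below are illustrative commentary, not part of the patches).

The strategy poll_with_specified_overhead converges on in patch 5 sleeps, on
every retry, for overhead_rate times the total time elapsed so far (0.5 s on
the first retry, capped at max_delay). Delays therefore grow geometrically,
the number of HTTP polls is logarithmic in the operation's duration, and the
terminal state is observed at most roughly overhead_rate later than it was
actually reached. Below is a self-contained simulation of that schedule,
assuming the check itself takes negligible time; simulate_schedule is not
part of the SDK.

def simulate_schedule(op_seconds, overhead_rate=0.2, max_delay=120):
    # Mirrors the delay sequence of poll_with_specified_overhead (patch 5):
    # the first retry sleeps 0.5 s, each later retry sleeps elapsed * overhead_rate,
    # capped at max_delay. No network calls and no real sleeping happen here.
    elapsed, tries, delays = 0.0, 0, []
    while elapsed < op_seconds:
        tries += 1
        delay = 0.5 if tries == 1 else min(elapsed * overhead_rate, max_delay)
        delays.append(round(delay, 2))
        elapsed += delay
    return delays


if __name__ == "__main__":
    delays = simulate_schedule(60)  # e.g. an engine that takes 60 s to provision
    print(len(delays), "polls, final delay of", delays[-1], "s")
    print("terminal state observed after ~", round(sum(delays), 1), "s")  # <= 60 * 1.2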
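
For callers, the end state of the series is that the SDK itself blocks until a
terminal state is reached: create_engine_wait polls get_engine with
overhead_rate=0.2 and a 30-minute timeout, and api.exec polls the transaction
the same way instead of sleeping a fixed second. A minimal usage sketch,
assuming a configured profile and an existing database; the profile, engine
and database names and the Rel query are placeholders, and error handling is
omitted.

import json

from railib import api, config

cfg = config.read(profile="default")  # placeholder profile name
ctx = api.Context(**cfg)

# Blocks until the engine reports "PROVISIONED" or a *FAILED* state
# (30-minute timeout inside create_engine_wait).
engine = api.create_engine_wait(ctx, "my-engine", api.EngineSize.XS)
print(json.dumps(engine, indent=2))

# exec() submits the transaction and then polls it with the same
# proportional backoff until it reaches a terminal state.
rsp = api.exec(ctx, "my-database", "my-engine", "def output = 1 + 2")
print(rsp.results)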
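
The timeout and max_tries guards on poll_with_specified_overhead raise a plain
Exception whose message the unit tests assert on, and after patch 5
overhead_rate is a required argument. A short sketch of catching a guard from
user code; the condition here is deliberately never satisfied.

from railib import api

try:
    # The condition never becomes true, so after 3 retries the guard raises.
    api.poll_with_specified_overhead(lambda: False, overhead_rate=0.1, max_tries=3)
except Exception as e:
    print(e)  # -> max tries 3 exhausted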