Skip to content

[Serve] Use safe_cursor for serve state #3299

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 12, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
224 changes: 117 additions & 107 deletions sky/serve/serve_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,15 +214,15 @@ def add_service(name: str, controller_job_id: int, policy: str, version: int,
exists.
"""
try:
_DB.cursor.execute(
"""\
INSERT INTO services
(name, controller_job_id, status, policy,
requested_resources_str, current_version)
VALUES (?, ?, ?, ?, ?, ?)""",
(name, controller_job_id, status.value, policy,
requested_resources_str, version))
_DB.conn.commit()
with db_utils.safe_cursor(_DB_PATH) as cursor:
cursor.execute(
"""\
INSERT INTO services
(name, controller_job_id, status, policy,
requested_resources_str, current_version)
VALUES (?, ?, ?, ?, ?, ?)""",
(name, controller_job_id, status.value, policy,
requested_resources_str, version))
except sqlite3.IntegrityError as e:
if str(e) != _UNIQUE_CONSTRAINT_FAILED_ERROR_MSG:
raise RuntimeError('Unexpected database error') from e
Expand All @@ -232,48 +232,49 @@ def add_service(name: str, controller_job_id: int, policy: str, version: int,

def remove_service(service_name: str) -> None:
"""Removes a service from the database."""
_DB.cursor.execute("""\
DELETE FROM services WHERE name=(?)""", (service_name,))
_DB.conn.commit()
with db_utils.safe_cursor(_DB_PATH) as cursor:
cursor.execute("""\
DELETE FROM services WHERE name=(?)""", (service_name,))


def set_service_uptime(service_name: str, uptime: int) -> None:
"""Sets the uptime of a service."""
_DB.cursor.execute(
"""\
UPDATE services SET
uptime=(?) WHERE name=(?)""", (uptime, service_name))
_DB.conn.commit()
with db_utils.safe_cursor(_DB_PATH) as cursor:
cursor.execute(
"""\
UPDATE services SET
uptime=(?) WHERE name=(?)""", (uptime, service_name))


def set_service_status(service_name: str, status: ServiceStatus) -> None:
"""Sets the service status."""
_DB.cursor.execute(
"""\
UPDATE services SET
status=(?) WHERE name=(?)""", (status.value, service_name))
_DB.conn.commit()
with db_utils.safe_cursor(_DB_PATH) as cursor:
cursor.execute(
"""\
UPDATE services SET
status=(?) WHERE name=(?)""", (status.value, service_name))


def set_service_controller_port(service_name: str,
controller_port: int) -> None:
"""Sets the controller port of a service."""
_DB.cursor.execute(
"""\
UPDATE services SET
controller_port=(?) WHERE name=(?)""", (controller_port, service_name))
_DB.conn.commit()
with db_utils.safe_cursor(_DB_PATH) as cursor:
cursor.execute(
"""\
UPDATE services SET
controller_port=(?) WHERE name=(?)""",
(controller_port, service_name))


def set_service_load_balancer_port(service_name: str,
load_balancer_port: int) -> None:
"""Sets the load balancer port of a service."""
_DB.cursor.execute(
"""\
UPDATE services SET
load_balancer_port=(?) WHERE name=(?)""",
(load_balancer_port, service_name))
_DB.conn.commit()
with db_utils.safe_cursor(_DB_PATH) as cursor:
cursor.execute(
"""\
UPDATE services SET
load_balancer_port=(?) WHERE name=(?)""",
(load_balancer_port, service_name))


def _get_service_from_row(row) -> Dict[str, Any]:
Expand All @@ -299,7 +300,8 @@ def _get_service_from_row(row) -> Dict[str, Any]:

def get_services() -> List[Dict[str, Any]]:
"""Get all existing service records."""
rows = _DB.cursor.execute('SELECT * FROM services').fetchall()
with db_utils.safe_cursor(_DB_PATH) as cursor:
rows = cursor.execute('SELECT * FROM services').fetchall()
records = []
for row in rows:
records.append(_get_service_from_row(row))
Expand All @@ -308,7 +310,8 @@ def get_services() -> List[Dict[str, Any]]:

def get_service_from_name(service_name: str) -> Optional[Dict[str, Any]]:
"""Get all existing service records."""
rows = _DB.cursor.execute('SELECT * FROM services WHERE name=(?)',
with db_utils.safe_cursor(_DB_PATH) as cursor:
rows = cursor.execute('SELECT * FROM services WHERE name=(?)',
(service_name,)).fetchall()
for row in rows:
return _get_service_from_row(row)
Expand All @@ -317,10 +320,11 @@ def get_service_from_name(service_name: str) -> Optional[Dict[str, Any]]:

def get_service_versions(service_name: str) -> List[int]:
"""Gets all versions of a service."""
rows = _DB.cursor.execute(
"""\
SELECT DISTINCT version FROM version_specs
WHERE service_name=(?)""", (service_name,)).fetchall()
with db_utils.safe_cursor(_DB_PATH) as cursor:
rows = cursor.execute(
"""\
SELECT DISTINCT version FROM version_specs
WHERE service_name=(?)""", (service_name,)).fetchall()
return [row[0] for row in rows]


Expand All @@ -335,50 +339,52 @@ def get_glob_service_names(
Returns:
A list of non-duplicated service names.
"""
if service_names is None:
rows = _DB.cursor.execute('SELECT name FROM services').fetchall()
else:
rows = []
for service_name in service_names:
rows.extend(
_DB.cursor.execute(
'SELECT name FROM services WHERE name GLOB (?)',
(service_name,)).fetchall())
with db_utils.safe_cursor(_DB_PATH) as cursor:
if service_names is None:
rows = cursor.execute('SELECT name FROM services').fetchall()
else:
rows = []
for service_name in service_names:
rows.extend(
cursor.execute(
'SELECT name FROM services WHERE name GLOB (?)',
(service_name,)).fetchall())
return list({row[0] for row in rows})


# === Replica functions ===
def add_or_update_replica(service_name: str, replica_id: int,
replica_info: 'replica_managers.ReplicaInfo') -> None:
"""Adds a replica to the database."""
_DB.cursor.execute(
"""\
INSERT OR REPLACE INTO replicas
(service_name, replica_id, replica_info)
VALUES (?, ?, ?)""",
(service_name, replica_id, pickle.dumps(replica_info)))
_DB.conn.commit()
with db_utils.safe_cursor(_DB_PATH) as cursor:
cursor.execute(
"""\
INSERT OR REPLACE INTO replicas
(service_name, replica_id, replica_info)
VALUES (?, ?, ?)""",
(service_name, replica_id, pickle.dumps(replica_info)))


def remove_replica(service_name: str, replica_id: int) -> None:
"""Removes a replica from the database."""
_DB.cursor.execute(
"""\
DELETE FROM replicas
WHERE service_name=(?)
AND replica_id=(?)""", (service_name, replica_id))
_DB.conn.commit()
with db_utils.safe_cursor(_DB_PATH) as cursor:
cursor.execute(
"""\
DELETE FROM replicas
WHERE service_name=(?)
AND replica_id=(?)""", (service_name, replica_id))


def get_replica_info_from_id(
service_name: str,
replica_id: int) -> Optional['replica_managers.ReplicaInfo']:
"""Gets a replica info from the database."""
rows = _DB.cursor.execute(
"""\
SELECT replica_info FROM replicas
WHERE service_name=(?)
AND replica_id=(?)""", (service_name, replica_id)).fetchall()
with db_utils.safe_cursor(_DB_PATH) as cursor:
rows = cursor.execute(
"""\
SELECT replica_info FROM replicas
WHERE service_name=(?)
AND replica_id=(?)""", (service_name, replica_id)).fetchall()
for row in rows:
return pickle.loads(row[0])
return None
Expand All @@ -387,16 +393,18 @@ def get_replica_info_from_id(
def get_replica_infos(
service_name: str) -> List['replica_managers.ReplicaInfo']:
"""Gets all replica infos of a service."""
rows = _DB.cursor.execute(
"""\
SELECT replica_info FROM replicas
WHERE service_name=(?)""", (service_name,)).fetchall()
with db_utils.safe_cursor(_DB_PATH) as cursor:
rows = cursor.execute(
"""\
SELECT replica_info FROM replicas
WHERE service_name=(?)""", (service_name,)).fetchall()
return [pickle.loads(row[0]) for row in rows]


def total_number_provisioning_replicas() -> int:
"""Returns the total number of provisioning replicas."""
rows = _DB.cursor.execute('SELECT replica_info FROM replicas').fetchall()
with db_utils.safe_cursor(_DB_PATH) as cursor:
rows = cursor.execute('SELECT replica_info FROM replicas').fetchall()
provisioning_count = 0
for row in rows:
replica_info: 'replica_managers.ReplicaInfo' = pickle.loads(row[0])
Expand All @@ -409,62 +417,64 @@ def total_number_provisioning_replicas() -> int:
def add_version(service_name: str) -> int:
"""Adds a version to the database."""

_DB.cursor.execute(
"""\
INSERT INTO version_specs
(version, service_name, spec)
VALUES (
(SELECT COALESCE(MAX(version), 0) + 1 FROM
version_specs WHERE service_name = ?), ?, ?)
RETURNING version""", (service_name, service_name, pickle.dumps(None)))
with db_utils.safe_cursor(_DB_PATH) as cursor:
cursor.execute(
"""\
INSERT INTO version_specs
(version, service_name, spec)
VALUES (
(SELECT COALESCE(MAX(version), 0) + 1 FROM
version_specs WHERE service_name = ?), ?, ?)
RETURNING version""",
(service_name, service_name, pickle.dumps(None)))

inserted_version = _DB.cursor.fetchone()[0]
_DB.conn.commit()
inserted_version = cursor.fetchone()[0]

return inserted_version


def add_or_update_version(service_name: str, version: int,
spec: 'service_spec.SkyServiceSpec') -> None:
_DB.cursor.execute(
"""\
INSERT or REPLACE INTO version_specs
(service_name, version, spec)
VALUES (?, ?, ?)""", (service_name, version, pickle.dumps(spec)))
_DB.cursor.execute(
"""\
UPDATE services SET
current_version=(?) WHERE name=(?)""", (version, service_name))
_DB.conn.commit()
with db_utils.safe_cursor(_DB_PATH) as cursor:
cursor.execute(
"""\
INSERT or REPLACE INTO version_specs
(service_name, version, spec)
VALUES (?, ?, ?)""", (service_name, version, pickle.dumps(spec)))
cursor.execute(
"""\
UPDATE services SET
current_version=(?) WHERE name=(?)""", (version, service_name))


def remove_service_versions(service_name: str) -> None:
"""Removes a replica from the database."""
_DB.cursor.execute(
"""\
DELETE FROM version_specs
WHERE service_name=(?)""", (service_name,))
_DB.conn.commit()
with db_utils.safe_cursor(_DB_PATH) as cursor:
cursor.execute(
"""\
DELETE FROM version_specs
WHERE service_name=(?)""", (service_name,))


def get_spec(service_name: str,
version: int) -> Optional['service_spec.SkyServiceSpec']:
"""Gets spec from the database."""
rows = _DB.cursor.execute(
"""\
SELECT spec FROM version_specs
WHERE service_name=(?)
AND version=(?)""", (service_name, version)).fetchall()
with db_utils.safe_cursor(_DB_PATH) as cursor:
rows = cursor.execute(
"""\
SELECT spec FROM version_specs
WHERE service_name=(?)
AND version=(?)""", (service_name, version)).fetchall()
for row in rows:
return pickle.loads(row[0])
return None


def delete_version(service_name: str, version: int) -> None:
"""Deletes a version from the database."""
_DB.cursor.execute(
"""\
DELETE FROM version_specs
WHERE service_name=(?)
AND version=(?)""", (service_name, version))
_DB.conn.commit()
with db_utils.safe_cursor(_DB_PATH) as cursor:
cursor.execute(
"""\
DELETE FROM version_specs
WHERE service_name=(?)
AND version=(?)""", (service_name, version))