Skip to content

Commit 44ff896

Browse files
authored
API server code health pass - misc. refactors (#1059)
1 parent 1a0563b commit 44ff896

24 files changed

+72
-119
lines changed

src/server/_common.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,16 @@
22
import time
33

44
from flask import Flask, g, request
5-
from sqlalchemy import event
6-
from sqlalchemy.engine import Connection
5+
from sqlalchemy import create_engine, event
6+
from sqlalchemy.engine import Connection, Engine
77
from werkzeug.local import LocalProxy
88

99
from .utils.logger import get_structured_logger
10-
from ._config import SECRET
11-
from ._db import engine
10+
from ._config import SECRET, SQLALCHEMY_DATABASE_URI, SQLALCHEMY_ENGINE_OPTIONS
1211
from ._exceptions import DatabaseErrorException, EpiDataException
1312

13+
engine: Engine = create_engine(SQLALCHEMY_DATABASE_URI, **SQLALCHEMY_ENGINE_OPTIONS)
14+
1415
app = Flask("EpiData", static_url_path="")
1516
app.config["SECRET"] = SECRET
1617

src/server/_db.py

Lines changed: 0 additions & 26 deletions
This file was deleted.

src/server/_params.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88

99
from ._exceptions import ValidationFailedException
10-
from .utils import days_in_range, weeks_in_range, guess_time_value_is_day, guess_time_value_is_week, TimeValues, days_to_ranges, weeks_to_ranges
10+
from .utils import days_in_range, weeks_in_range, guess_time_value_is_day, guess_time_value_is_week, IntRange, TimeValues, days_to_ranges, weeks_to_ranges
1111

1212

1313
def _parse_common_multi_arg(key: str) -> List[Tuple[str, Union[bool, Sequence[str]]]]:
@@ -140,7 +140,7 @@ def to_ranges(self):
140140
return TimePair(self.time_type, days_to_ranges(self.time_values))
141141

142142

143-
def _verify_range(start: int, end: int) -> Union[int, Tuple[int, int]]:
143+
def _verify_range(start: int, end: int) -> IntRange:
144144
if start == end:
145145
# the first and last numbers are the same, just treat it as a single value
146146
return start
@@ -151,7 +151,7 @@ def _verify_range(start: int, end: int) -> Union[int, Tuple[int, int]]:
151151
raise ValidationFailedException(f"the given range {start}-{end} is inverted")
152152

153153

154-
def parse_week_value(time_value: str) -> Union[int, Tuple[int, int]]:
154+
def parse_week_value(time_value: str) -> IntRange:
155155
count_dashes = time_value.count("-")
156156
msg = f"{time_value} does not match a known format YYYYWW or YYYYWW-YYYYWW"
157157

@@ -171,7 +171,7 @@ def parse_week_value(time_value: str) -> Union[int, Tuple[int, int]]:
171171
raise ValidationFailedException(msg)
172172

173173

174-
def parse_day_value(time_value: str) -> Union[int, Tuple[int, int]]:
174+
def parse_day_value(time_value: str) -> IntRange:
175175
count_dashes = time_value.count("-")
176176
msg = f"{time_value} does not match a known format YYYYMMDD, YYYY-MM-DD, YYYYMMDD-YYYYMMDD, or YYYY-MM-DD--YYYY-MM-DD"
177177

src/server/_query.py

Lines changed: 32 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from ._exceptions import DatabaseErrorException
2121
from ._validate import extract_strings
2222
from ._params import GeoPair, SourceSignalPair, TimePair
23-
from .utils import time_values_to_ranges, TimeValues
23+
from .utils import time_values_to_ranges, IntRange, TimeValues
2424

2525

2626
def date_string(value: int) -> str:
@@ -34,7 +34,7 @@ def date_string(value: int) -> str:
3434

3535
def to_condition(
3636
field: str,
37-
value: Union[str, Tuple[int, int], int],
37+
value: Union[str, IntRange],
3838
param_key: str,
3939
params: Dict[str, Any],
4040
formatter=lambda x: x,
@@ -50,7 +50,7 @@ def to_condition(
5050

5151
def filter_values(
5252
field: str,
53-
values: Optional[Sequence[Union[str, Tuple[int, int], int]]],
53+
values: Optional[Sequence[Union[str, IntRange]]],
5454
param_key: str,
5555
params: Dict[str, Any],
5656
formatter=lambda x: x,
@@ -75,7 +75,7 @@ def filter_strings(
7575

7676
def filter_integers(
7777
field: str,
78-
values: Optional[Sequence[Union[Tuple[int, int], int]]],
78+
values: Optional[Sequence[IntRange]],
7979
param_key: str,
8080
params: Dict[str, Any],
8181
):
@@ -399,7 +399,7 @@ def _fq_field(self, field: str) -> str:
399399
def where_integers(
400400
self,
401401
field: str,
402-
values: Optional[Sequence[Union[Tuple[int, int], int]]],
402+
values: Optional[Sequence[IntRange]],
403403
param_key: Optional[str] = None,
404404
) -> "QueryBuilder":
405405
fq_field = self._fq_field(field)
@@ -466,25 +466,41 @@ def where_time_pair(
466466
)
467467
return self
468468

469+
def apply_lag_filter(self, history_table: str, lag: Optional[int]):
470+
if lag is not None:
471+
self.retable(history_table)
472+
# history_table has full spectrum of lag values to search from whereas the latest_table does not
473+
self.where(lag=lag)
474+
return self
475+
476+
def apply_issues_filter(self, history_table: str, issues: Optional[TimeValues]):
477+
if issues:
478+
self.retable(history_table)
479+
self.where_integers("issue", issues)
480+
return self
481+
482+
def apply_as_of_filter(self, history_table: str, as_of: Optional[int]):
483+
if as_of is not None:
484+
self.retable(history_table)
485+
sub_condition_asof = "(issue <= :as_of)"
486+
self.params["as_of"] = as_of
487+
sub_fields = "max(issue) max_issue, time_type, time_value, `source`, `signal`, geo_type, geo_value"
488+
sub_group = "time_type, time_value, `source`, `signal`, geo_type, geo_value"
489+
alias = self.alias
490+
sub_condition = f"x.max_issue = {alias}.issue AND x.time_type = {alias}.time_type AND x.time_value = {alias}.time_value AND x.source = {alias}.source AND x.signal = {alias}.signal AND x.geo_type = {alias}.geo_type AND x.geo_value = {alias}.geo_value"
491+
self.subquery = f"JOIN (SELECT {sub_fields} FROM {self.table} WHERE {self.conditions_clause} AND {sub_condition_asof} GROUP BY {sub_group}) x ON {sub_condition}"
492+
return self
493+
469494
def set_fields(self, *fields: Iterable[str]) -> "QueryBuilder":
470495
self.fields = [f"{self.alias}.{field}" for field_list in fields for field in field_list]
471496
return self
472497

473-
def set_order(self, *args: str, **kwargs: Union[str, bool]) -> "QueryBuilder":
498+
def set_sort_order(self, *args: str):
474499
"""
475500
sets the order for the given fields (as key word arguments), True = ASC, False = DESC
476501
"""
477502

478-
def to_asc(v: Union[str, bool]) -> str:
479-
if v is True:
480-
return "ASC"
481-
elif v is False:
482-
return "DESC"
483-
return cast(str, v)
484-
485-
args_order = [f"{self.alias}.{k} ASC" for k in args]
486-
kw_order = [f"{self.alias}.{k} {to_asc(v)}" for k, v in kwargs.items()]
487-
self.order = args_order + kw_order
503+
self.order = [f"{self.alias}.{k} ASC" for k in args]
488504
return self
489505

490506
def with_max_issue(self, *args: str) -> "QueryBuilder":

src/server/_validate.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from flask import request
44

55
from ._exceptions import UnAuthenticatedException, ValidationFailedException
6-
from .utils import TimeValues
6+
from .utils import IntRange, TimeValues
77

88

99
def resolve_auth_token() -> Optional[str]:
@@ -84,9 +84,6 @@ def extract_strings(key: Union[str, Sequence[str]]) -> Optional[List[str]]:
8484
return [v for vs in s for v in vs.split(",")]
8585

8686

87-
IntRange = Union[Tuple[int, int], int]
88-
89-
9087
def extract_integer(key: Union[str, Sequence[str]]) -> Optional[int]:
9188
s = _extract_value(key)
9289
if not s:

src/server/endpoints/covid_hosp_facility.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ def handle():
139139
q.set_fields(fields_string, fields_int, fields_float)
140140

141141
# basic query info
142-
q.set_order("collection_week", "hospital_pk", "publication_date")
142+
q.set_sort_order("collection_week", "hospital_pk", "publication_date")
143143

144144
# build the filter
145145
q.where_integers("collection_week", collection_weeks)

src/server/endpoints/covid_hosp_facility_lookup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def handle():
3333
]
3434
)
3535
# basic query info
36-
q.set_order("hospital_pk")
36+
q.set_sort_order("hospital_pk")
3737
# build the filter
3838
# these are all fast because the table has indexes on each of these fields
3939
if state:

src/server/endpoints/covid_hosp_state_timeseries.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def handle():
145145
]
146146

147147
q.set_fields(fields_string, fields_int, fields_float)
148-
q.set_order("date", "state", "issue")
148+
q.set_sort_order("date", "state", "issue")
149149

150150
# build the filter
151151
q.where_integers("date", dates)

src/server/endpoints/covidcast.py

Lines changed: 12 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -91,28 +91,6 @@ def parse_time_pairs() -> TimePair:
9191
return parse_time_arg()
9292

9393

94-
def _handle_lag_issues_as_of(q: QueryBuilder, issues: Optional[TimeValues] = None, lag: Optional[int] = None, as_of: Optional[int] = None):
95-
if issues:
96-
q.retable(history_table)
97-
q.where_integers("issue", issues)
98-
elif lag is not None:
99-
q.retable(history_table)
100-
# history_table has full spectrum of lag values to search from whereas the latest_table does not
101-
q.where(lag=lag)
102-
elif as_of is not None:
103-
# fetch the most recent issue as of a certain date (not to be confused w/ plain-old "most recent issue")
104-
q.retable(history_table)
105-
sub_condition_asof = "(issue <= :as_of)"
106-
q.params["as_of"] = as_of
107-
sub_fields = "max(issue) max_issue, time_type, time_value, `source`, `signal`, geo_type, geo_value"
108-
sub_group = "time_type, time_value, `source`, `signal`, geo_type, geo_value"
109-
sub_condition = f"x.max_issue = {q.alias}.issue AND x.time_type = {q.alias}.time_type AND x.time_value = {q.alias}.time_value AND x.source = {q.alias}.source AND x.signal = {q.alias}.signal AND x.geo_type = {q.alias}.geo_type AND x.geo_value = {q.alias}.geo_value"
110-
q.subquery = f"JOIN (SELECT {sub_fields} FROM {q.table} WHERE {q.conditions_clause} AND {sub_condition_asof} GROUP BY {sub_group}) x ON {sub_condition}"
111-
else:
112-
# else we are using the (standard/default) `latest_table`, to fetch the most recent issue quickly
113-
pass
114-
115-
11694
@bp.route("/", methods=("GET", "POST"))
11795
def handle():
11896
source_signal_pairs = parse_source_signal_pairs()
@@ -132,11 +110,11 @@ def handle():
132110
fields_float = ["value", "stderr", "sample_size"]
133111
is_compatibility = is_compatibility_mode()
134112
if is_compatibility:
135-
q.set_order("signal", "time_value", "geo_value", "issue")
113+
q.set_sort_order("signal", "time_value", "geo_value", "issue")
136114
else:
137115
# transfer also the new detail columns
138116
fields_string.extend(["source", "geo_type", "time_type"])
139-
q.set_order("source", "signal", "time_type", "time_value", "geo_type", "geo_value", "issue")
117+
q.set_sort_order("source", "signal", "time_type", "time_value", "geo_type", "geo_value", "issue")
140118
q.set_fields(fields_string, fields_int, fields_float)
141119

142120
# basic query info
@@ -147,7 +125,9 @@ def handle():
147125
q.where_geo_pairs("geo_type", "geo_value", geo_pairs)
148126
q.where_time_pair("time_type", "time_value", time_pair)
149127

150-
_handle_lag_issues_as_of(q, issues, lag, as_of)
128+
q.apply_issues_filter(history_table, issues)
129+
q.apply_lag_filter(history_table, lag)
130+
q.apply_as_of_filter(history_table, as_of)
151131

152132
def transform_row(row, proxy):
153133
if is_compatibility or not alias_mapper or "source" not in row:
@@ -195,15 +175,12 @@ def handle_trend():
195175
fields_int = ["time_value"]
196176
fields_float = ["value"]
197177
q.set_fields(fields_string, fields_int, fields_float)
198-
q.set_order("geo_type", "geo_value", "source", "signal", "time_value")
178+
q.set_sort_order("geo_type", "geo_value", "source", "signal", "time_value")
199179

200180
q.where_source_signal_pairs("source", "signal", source_signal_pairs)
201181
q.where_geo_pairs("geo_type", "geo_value", geo_pairs)
202182
q.where_time_pair("time_type", "time_value", time_window)
203183

204-
# fetch most recent issue fast
205-
_handle_lag_issues_as_of(q, None, None, None)
206-
207184
p = create_printer()
208185

209186
def gen(rows):
@@ -246,15 +223,12 @@ def handle_trendseries():
246223
fields_int = ["time_value"]
247224
fields_float = ["value"]
248225
q.set_fields(fields_string, fields_int, fields_float)
249-
q.set_order("geo_type", "geo_value", "source", "signal", "time_value")
226+
q.set_sort_order("geo_type", "geo_value", "source", "signal", "time_value")
250227

251228
q.where_source_signal_pairs("source", "signal", source_signal_pairs)
252229
q.where_geo_pairs("geo_type", "geo_value", geo_pairs)
253230
q.where_time_pair("time_type", "time_value", time_window)
254231

255-
# fetch most recent issue fast
256-
_handle_lag_issues_as_of(q, None, None, None)
257-
258232
p = create_printer()
259233

260234
shifter = lambda x: shift_day_value(x, -basis_shift)
@@ -303,7 +277,7 @@ def handle_correlation():
303277
fields_int = ["time_value"]
304278
fields_float = ["value"]
305279
q.set_fields(fields_string, fields_int, fields_float)
306-
q.set_order("geo_type", "geo_value", "source", "signal", "time_value")
280+
q.set_sort_order("geo_type", "geo_value", "source", "signal", "time_value")
307281

308282
q.where_source_signal_pairs(
309283
"source",
@@ -381,12 +355,12 @@ def handle_export():
381355
q = QueryBuilder(latest_table, "t")
382356

383357
q.set_fields(["geo_value", "signal", "time_value", "issue", "lag", "value", "stderr", "sample_size", "geo_type", "source"], [], [])
384-
q.set_order("time_value", "geo_value")
358+
q.set_sort_order("time_value", "geo_value")
385359
q.where_source_signal_pairs("source", "signal", source_signal_pairs)
386360
q.where_time_pair("time_type", "time_value", TimePair("day" if is_day else "week", [(start_day, end_day)]))
387361
q.where_geo_pairs("geo_type", "geo_value", [GeoPair(geo_type, True if geo_values == "*" else geo_values)])
388362

389-
_handle_lag_issues_as_of(q, None, None, as_of)
363+
q.apply_as_of_filter(history_table, as_of)
390364

391365
format_date = time_value_to_iso if is_day else lambda x: time_value_to_week(x).cdcformat()
392366
# tag as_of in filename, if it was specified
@@ -459,16 +433,13 @@ def handle_backfill():
459433
fields_int = ["time_value", "issue"]
460434
fields_float = ["value", "sample_size"]
461435
# sort by time value and issue asc
462-
q.set_order(time_value=True, issue=True)
436+
q.set_sort_order("time_value", "issue")
463437
q.set_fields(fields_string, fields_int, fields_float, ["is_latest_issue"])
464438

465439
q.where_source_signal_pairs("source", "signal", source_signal_pairs)
466440
q.where_geo_pairs("geo_type", "geo_value", [geo_pair])
467441
q.where_time_pair("time_type", "time_value", time_pair)
468442

469-
# no restriction of issues or dates since we want all issues
470-
# _handle_lag_issues_as_of(q, issues, lag, as_of)
471-
472443
p = create_printer()
473444

474445
def find_anchor_row(rows: List[Dict[str, Any]], issue: int) -> Optional[Dict[str, Any]]:
@@ -642,9 +613,7 @@ def handle_coverage():
642613
q.where_source_signal_pairs("source", "signal", source_signal_pairs)
643614
q.where_time_pair("time_type", "time_value", time_window)
644615
q.group_by = "c.source, c.signal, c.time_value"
645-
q.set_order("source", "signal", "time_value")
646-
647-
_handle_lag_issues_as_of(q, None, None, None)
616+
q.set_sort_order("source", "signal", "time_value")
648617

649618
def transform_row(row, proxy):
650619
if not alias_mapper or "source" not in row:

src/server/endpoints/covidcast_utils/model.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -236,11 +236,6 @@ def _load_data_signals(sources: List[DataSource]):
236236
data_signals_by_key[(source.db_source, d.signal)] = d
237237

238238

239-
240-
def get_related_signals(signal: DataSignal) -> List[DataSignal]:
241-
return [s for s in data_signals if s != signal and s.signal_basename == signal.signal_basename]
242-
243-
244239
def count_signal_time_types(source_signals: List[SourceSignalPair]) -> Tuple[int, int]:
245240
"""
246241
count the number of signals in this query for each time type

src/server/endpoints/dengue_nowcast.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def handle():
2222
fields_float = ["value", "std"]
2323
q.set_fields(fields_string, fields_int, fields_float)
2424

25-
q.set_order("epiweek", "location")
25+
q.set_sort_order("epiweek", "location")
2626

2727
# build the filter
2828
q.where_strings("location", locations)

src/server/endpoints/dengue_sensors.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def handle():
2626
fields_float = ["value"]
2727
q.set_fields(fields_string, fields_int, fields_float)
2828

29-
q.set_order('epiweek', 'name', 'location')
29+
q.set_sort_order('epiweek', 'name', 'location')
3030

3131
q.where_strings('name', names)
3232
q.where_strings('location', locations)

0 commit comments

Comments
 (0)