diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml
index e082443a5..9ed1cfde9 100644
--- a/.github/workflows/ci.yaml
+++ b/.github/workflows/ci.yaml
@@ -62,7 +62,7 @@ jobs:
       - name: Start services
         run: |
           docker network create --driver bridge delphi-net
-          docker run --rm -d -p 13306:3306 --network delphi-net --name delphi_database_epidata delphi_database_epidata
+          docker run --rm -d -p 13306:3306 --network delphi-net --name delphi_database_epidata --cap-add=sys_nice delphi_database_epidata
           docker run --rm -d -p 10080:80 --env "SQLALCHEMY_DATABASE_URI=mysql+mysqldb://user:pass@delphi_database_epidata:3306/epidata" --env "FLASK_SECRET=abc" --env "FLASK_PREFIX=/epidata" --network delphi-net --name delphi_web_epidata delphi_web_epidata
           docker ps

diff --git a/dev/docker/database/epidata/Dockerfile b/dev/docker/database/epidata/Dockerfile
index 07346229f..bd0ac37b5 100644
--- a/dev/docker/database/epidata/Dockerfile
+++ b/dev/docker/database/epidata/Dockerfile
@@ -1,5 +1,14 @@
-# start with the `delphi_database` image
-FROM delphi_database
+# start with a standard percona mysql image
+FROM percona:ps-8
+
+# percona runs as the mysql user but we need root for additional setup
+USER root
+
+# use delphi's timezone
+RUN ln -s -f /usr/share/zoneinfo/America/New_York /etc/localtime
+
+# specify a development-only password for the database user "root"
+ENV MYSQL_ROOT_PASSWORD pass

 # create the `epidata` database
 ENV MYSQL_DATABASE epidata
@@ -8,8 +17,17 @@ ENV MYSQL_DATABASE epidata
 ENV MYSQL_USER user
 ENV MYSQL_PASSWORD pass

+# provide DDL which will configure dev environment at container startup
+COPY repos/delphi/delphi-epidata/dev/docker/database/epidata/_init.sql /docker-entrypoint-initdb.d/
+
 # provide DDL which will create empty tables at container startup
 COPY repos/delphi/delphi-epidata/src/ddl/*.sql /docker-entrypoint-initdb.d/

+# provide additional configuration needed for percona
+COPY repos/delphi/delphi-epidata/dev/docker/database/mysql.d/*.cnf /etc/my.cnf.d/
+
 # grant access to SQL scripts
 RUN chmod o+r /docker-entrypoint-initdb.d/*.sql
+
+# restore mysql user for percona
+USER mysql
diff --git a/dev/docker/database/epidata/_init.sql b/dev/docker/database/epidata/_init.sql
new file mode 100644
index 000000000..5ebdcfb08
--- /dev/null
+++ b/dev/docker/database/epidata/_init.sql
@@ -0,0 +1,2 @@
+CREATE DATABASE covid;
+GRANT ALL ON covid.* TO 'user';
diff --git a/dev/docker/database/mysql.d/my.cnf b/dev/docker/database/mysql.d/my.cnf
new file mode 100644
index 000000000..0c952a7a7
--- /dev/null
+++ b/dev/docker/database/mysql.d/my.cnf
@@ -0,0 +1,2 @@
+[mysqld]
+default_authentication_plugin=mysql_native_password
\ No newline at end of file
diff --git a/dev/local/Makefile b/dev/local/Makefile
index 52c9e98f0..55d0d8ea3 100644
--- a/dev/local/Makefile
+++ b/dev/local/Makefile
@@ -10,19 +10,19 @@
 # Creates all prereq images (delphi_database, delphi_python) only if they don't
 # exist. If you need to rebuild a prereq, you're probably doing something
 # complicated, and can figure out the rebuild command on your own.
-# 
-# 
+#
+#
 # Commands:
-# 
+#
 # web: Stops currently-running delphi_web_epidata instances, if any.
 #      Rebuilds delphi_web_epidata image.
 #      Runs image in the background and pipes stdout to a log file.
-# 
+#
 # db:  Stops currently-running delphi_database_epidata instances, if any.
 #      Rebuilds delphi_database_epidata image.
 #      Runs image in the background and pipes stdout to a log file.
 #      Blocks until database is ready to receive connections.
-# +# # python: Rebuilds delphi_web_python image. You shouldn't need to do this # often; only if you are installing a new environment, or have # made changes to delphi-epidata/dev/docker/python/Dockerfile. @@ -35,7 +35,7 @@ # # clean: Cleans up dangling Docker images. # -# +# # Optional arguments: # pdb=1 Drops you into debug mode upon test failure, if running tests. # test= Only runs tests in the directories provided here, e.g. @@ -105,11 +105,12 @@ db: @# Run the database @docker run --rm -p 127.0.0.1:13306:3306 \ --network delphi-net --name delphi_database_epidata \ + --cap-add=sys_nice \ delphi_database_epidata >$(LOG_DB) 2>&1 & @# Block until DB is ready @while true; do \ - sed -n '/Temporary server stopped/,/mysqld: ready for connections/p' $(LOG_DB) | grep "ready for connections" && break; \ + sed -n '/mysqld: ready for connections/p' $(LOG_DB) | grep "ready for connections" && break; \ tail -1 $(LOG_DB); \ sleep 1; \ done @@ -127,7 +128,7 @@ py: all: web db py .PHONY=test -test: +test: @docker run -i --rm --network delphi-net \ --mount type=bind,source=$(CWD)repos/delphi/delphi-epidata,target=/usr/src/app/repos/delphi/delphi-epidata,readonly \ --mount type=bind,source=$(CWD)repos/delphi/delphi-epidata/src,target=/usr/src/app/delphi/epidata,readonly \ diff --git a/docs/Gemfile.lock b/docs/Gemfile.lock index a74174c2a..97a81de75 100644 --- a/docs/Gemfile.lock +++ b/docs/Gemfile.lock @@ -246,7 +246,7 @@ GEM thread_safe (0.3.6) typhoeus (1.4.0) ethon (>= 0.9.0) - tzinfo (1.2.9) + tzinfo (1.2.10) thread_safe (~> 0.1) tzinfo-data (1.2021.1) tzinfo (>= 1.0.0) diff --git a/integrations/acquisition/covidcast/test_covidcast_meta_caching.py b/integrations/acquisition/covidcast/test_covidcast_meta_caching.py index c18363c03..b435b2b7c 100644 --- a/integrations/acquisition/covidcast/test_covidcast_meta_caching.py +++ b/integrations/acquisition/covidcast/test_covidcast_meta_caching.py @@ -10,9 +10,9 @@ # first party from delphi_utils import Nans -from delphi.epidata.client.delphi_epidata import Epidata import delphi.operations.secrets as secrets -import delphi.epidata.acquisition.covidcast.database as live +from delphi.epidata.client.delphi_epidata import Epidata +from delphi.epidata.acquisition.covidcast.database_meta import DatabaseMeta from delphi.epidata.acquisition.covidcast.covidcast_meta_cache_updater import main # py3tester coverage target (equivalent to `import *`) @@ -40,9 +40,9 @@ def setUp(self): cur = cnx.cursor() # clear all tables - cur.execute("truncate table signal_load") - cur.execute("truncate table signal_history") - cur.execute("truncate table signal_latest") + cur.execute("truncate table epimetric_load") + cur.execute("truncate table epimetric_full") + cur.execute("truncate table epimetric_latest") cur.execute("truncate table geo_dim") cur.execute("truncate table signal_dim") # reset the `covidcast_meta_cache` table (it should always have one row) @@ -71,14 +71,19 @@ def test_caching(self): # insert dummy data self.cur.execute(f''' - INSERT INTO `signal_dim` (`signal_key_id`, `source`, `signal`) VALUES (42, 'src', 'sig'); + INSERT INTO `signal_dim` (`signal_key_id`, `source`, `signal`) + VALUES + (42, 'src', 'sig'); ''') self.cur.execute(f''' - INSERT INTO `geo_dim` (`geo_key_id`, `geo_type`, `geo_value`) VALUES (96, 'state', 'pa'), (97, 'state', 'wa'); + INSERT INTO `geo_dim` (`geo_key_id`, `geo_type`, `geo_value`) + VALUES + (96, 'state', 'pa'), + (97, 'state', 'wa'); ''') self.cur.execute(f''' INSERT INTO - `signal_latest` (`signal_data_id`, `signal_key_id`, 
`geo_key_id`, `time_type`, + `epimetric_latest` (`epimetric_id`, `signal_key_id`, `geo_key_id`, `time_type`, `time_value`, `value_updated_timestamp`, `value`, `stderr`, `sample_size`, `issue`, `lag`, `missing_value`, @@ -92,7 +97,7 @@ def test_caching(self): self.cnx.commit() # make sure the live utility is serving something sensible - cvc_database = live.Database() + cvc_database = DatabaseMeta() cvc_database.connect() epidata1 = cvc_database.compute_covidcast_meta() cvc_database.disconnect(False) diff --git a/integrations/acquisition/covidcast/test_csv_uploading.py b/integrations/acquisition/covidcast/test_csv_uploading.py index 29f74f46d..de3eb5f13 100644 --- a/integrations/acquisition/covidcast/test_csv_uploading.py +++ b/integrations/acquisition/covidcast/test_csv_uploading.py @@ -15,7 +15,6 @@ from delphi_utils import Nans from delphi.epidata.client.delphi_epidata import Epidata from delphi.epidata.acquisition.covidcast.csv_to_database import main -from delphi.epidata.acquisition.covidcast.dbjobs_runner import main as dbjobs_main import delphi.operations.secrets as secrets # py3tester coverage target (equivalent to `import *`) @@ -37,9 +36,9 @@ def setUp(self): cur = cnx.cursor() # clear all tables - cur.execute("truncate table signal_load") - cur.execute("truncate table signal_history") - cur.execute("truncate table signal_latest") + cur.execute("truncate table epimetric_load") + cur.execute("truncate table epimetric_full") + cur.execute("truncate table epimetric_latest") cur.execute("truncate table geo_dim") cur.execute("truncate table signal_dim") # reset the `covidcast_meta_cache` table (it should always have one row) @@ -79,9 +78,9 @@ def apply_lag(expected_epidata): def verify_timestamps_and_defaults(self): self.cur.execute(''' -select value_updated_timestamp from signal_history +select value_updated_timestamp from epimetric_full UNION ALL -select value_updated_timestamp from signal_latest''') +select value_updated_timestamp from epimetric_latest''') for (value_updated_timestamp,) in self.cur: self.assertGreater(value_updated_timestamp, 0) @@ -102,8 +101,6 @@ def test_uploading(self): log_file=log_file_directory + "output.log", data_dir=data_dir, - is_wip_override=False, - not_wip_override=False, specific_issue_date=False) uploader_column_rename = {"geo_id": "geo_value", "val": "value", "se": "stderr", "missing_val": "missing_value", "missing_se": "missing_stderr"} @@ -123,7 +120,6 @@ def test_uploading(self): # upload CSVs main(args) - dbjobs_main() response = Epidata.covidcast('src-name', signal_name, 'day', 'state', 20200419, '*') expected_values = pd.concat([values, pd.DataFrame({ "time_value": [20200419] * 3, "signal": [signal_name] * 3, "direction": [None] * 3})], axis=1).rename(columns=uploader_column_rename).to_dict(orient="records") @@ -152,7 +148,6 @@ def test_uploading(self): # upload CSVs main(args) - dbjobs_main() response = Epidata.covidcast('src-name', signal_name, 'day', 'state', 20200419, '*') expected_values = pd.concat([values, pd.DataFrame({ @@ -187,7 +182,6 @@ def test_uploading(self): # upload CSVs main(args) - dbjobs_main() response = Epidata.covidcast('src-name', signal_name, 'day', 'state', 20200419, '*') expected_response = {'result': -2, 'message': 'no results'} @@ -213,7 +207,6 @@ def test_uploading(self): # upload CSVs main(args) - dbjobs_main() response = Epidata.covidcast('src-name', signal_name, 'day', 'state', 20200419, '*') expected_values_df = pd.concat([values, pd.DataFrame({ @@ -232,42 +225,6 @@ def test_uploading(self): self.setUp() - with 
self.subTest("Valid wip"): - values = pd.DataFrame({ - "geo_id": ["me", "nd", "wa"], - "val": [10.0, 20.0, 30.0], - "se": [0.01, 0.02, 0.03], - "sample_size": [100.0, 200.0, 300.0], - "missing_val": [Nans.NOT_MISSING] * 3, - "missing_se": [Nans.NOT_MISSING] * 3, - "missing_sample_size": [Nans.NOT_MISSING] * 3 - }) - signal_name = "wip_prototype" - values.to_csv(source_receiving_dir + f'/20200419_state_{signal_name}.csv', index=False) - - # upload CSVs - main(args) - dbjobs_main() - response = Epidata.covidcast('src-name', signal_name, 'day', 'state', 20200419, '*') - - expected_values = pd.concat([values, pd.DataFrame({ - "time_value": [20200419] * 3, - "signal": [signal_name] * 3, - "direction": [None] * 3 - })], axis=1).rename(columns=uploader_column_rename).to_dict(orient="records") - expected_response = {'result': 1, 'epidata': self.apply_lag(expected_values), 'message': 'success'} - - self.assertEqual(response, expected_response) - self.verify_timestamps_and_defaults() - - # Verify that files were archived - path = data_dir + f'/archive/successful/src-name/20200419_state_wip_prototype.csv.gz' - self.assertIsNotNone(os.stat(path)) - - self.tearDown() - self.setUp() - - with self.subTest("Valid signal with name length 32 pd.DataFrame: df.geo_value = df.geo_value.str.zfill(5) return df -class CovidcastEndpointTests(unittest.TestCase): - """Tests the `covidcast/*` endpoint.""" - - def setUp(self): - """Perform per-test setup.""" - - # connect to the database and clear the tables - cnx = mysql.connector.connect(user="user", password="pass", host="delphi_database_epidata", database="covid") - cur = cnx.cursor() +def _diff_rows(rows: Sequence[float]): + return [float(x - y) if x is not None and y is not None else None for x, y in zip(rows[1:], rows[:-1])] - # clear all tables - cur.execute("truncate table signal_load") - cur.execute("truncate table signal_history") - cur.execute("truncate table signal_latest") - cur.execute("truncate table geo_dim") - cur.execute("truncate table signal_dim") - # reset the `covidcast_meta_cache` table (it should always have one row) - cur.execute('update covidcast_meta_cache set timestamp = 0, epidata = "[]"') +def _smooth_rows(rows: Sequence[float]): + return [sum(e)/len(e) if None not in e else None for e in windowed(rows, 7)] - cnx.commit() - cur.close() - # make connection and cursor available to the Database object - self._db = Database() - self._db._connection = cnx - self._db._cursor = cnx.cursor() +class CovidcastEndpointTests(CovidcastBase): - def tearDown(self): - """Perform per-test teardown.""" - self._db._cursor.close() - self._db._connection.close() + """Tests the `covidcast/*` endpoint.""" - def _insert_rows(self, rows: Iterable[CovidcastRow]): - self._db.insert_or_update_bulk(rows) - self._db.run_dbjobs() - self._db._connection.commit() - return rows + def localSetUp(self): + """Perform per-test setup.""" + # reset the `covidcast_meta_cache` table (it should always have one row) + self._db._cursor.execute('update covidcast_meta_cache set timestamp = 0, epidata = "[]"') def _fetch(self, endpoint="/", is_compatibility=False, **params): # make the request @@ -81,14 +57,8 @@ def _fetch(self, endpoint="/", is_compatibility=False, **params): response.raise_for_status() return response.json() - def _diff_rows(self, rows: Sequence[float]): - return [float(x - y) if x is not None and y is not None else None for x, y in zip(rows[1:], rows[:-1])] - - def _smooth_rows(self, rows: Sequence[float]): - return [sum(e)/len(e) if None not in e else None for e 
in windowed(rows, 7)] - def test_basic(self): - """Request a signal the / endpoint.""" + """Request a signal from the / endpoint.""" rows = [CovidcastRow(time_value=20200401 + i, value=i) for i in range(10)] first = rows[0] @@ -112,6 +82,26 @@ def test_basic(self): expected_values = [float(row.value) for row in rows] self.assertEqual(out_values, expected_values) + def test_compatibility(self): + """Request at the /api.php endpoint.""" + rows = [CovidcastRow(source="src", signal="sig", time_value=20200401 + i, value=i) for i in range(10)] + first = rows[0] + self._insert_rows(rows) + + with self.subTest("simple"): + out = self._fetch("/", signal=first.signal_pair, geo=first.geo_pair, time="day:*") + self.assertEqual(len(out["epidata"]), len(rows)) + + with self.subTest("unknown signal"): + rows = [CovidcastRow(source="jhu-csse", signal="confirmed_unknown", time_value=20200401 + i, value=i) for i in range(10)] + first = rows[0] + self._insert_rows(rows) + + out = self._fetch("/", signal="jhu-csse:confirmed_unknown", geo=first.geo_pair, time="day:*") + out_values = [row["value"] for row in out["epidata"]] + expected_values = [float(row.value) for row in rows] + self.assertEqual(out_values, expected_values) + def test_derived_signals(self): time_value_pairs = [(20200401 + i, i ** 2) for i in range(10)] rows01 = [CovidcastRow(source="jhu-csse", signal="confirmed_cumulative_num", time_value=time_value, value=value, geo_value="01") for time_value, value in time_value_pairs] @@ -125,7 +115,7 @@ def test_derived_signals(self): out = self._fetch("/", signal="jhu-csse:confirmed_incidence_num", geo=first.geo_pair, time="day:20200401-20200410") out_values = [row["value"] for row in out["epidata"]] values = [value for _, value in time_value_pairs] - expected_values = self._diff_rows(values) + expected_values = _diff_rows(values) self.assertAlmostEqual(out_values, expected_values) with self.subTest("diffed signal, multiple geos"): @@ -133,29 +123,29 @@ def test_derived_signals(self): out_values = [row["value"] for row in out["epidata"]] values1 = [value for _, value in time_value_pairs] values2 = [2 * value for _, value in time_value_pairs] - expected_values = self._diff_rows(values1) + self._diff_rows(values2) + expected_values = _diff_rows(values1) + _diff_rows(values2) self.assertAlmostEqual(out_values, expected_values) with self.subTest("diffed signal, multiple geos using geo:*"): out = self._fetch("/", signal="jhu-csse:confirmed_incidence_num", geo="county:*", time="day:20200401-20200410") values1 = [value for _, value in time_value_pairs] values2 = [2 * value for _, value in time_value_pairs] - expected_values = self._diff_rows(values1) + self._diff_rows(values2) + expected_values = _diff_rows(values1) + _diff_rows(values2) self.assertAlmostEqual(out_values, expected_values) with self.subTest("smooth diffed signal"): out = self._fetch("/", signal="jhu-csse:confirmed_7dav_incidence_num", geo=first.geo_pair, time="day:20200401-20200410") out_values = [row["value"] for row in out["epidata"]] values = [value for _, value in time_value_pairs] - expected_values = self._smooth_rows(self._diff_rows(values)) + expected_values = _smooth_rows(_diff_rows(values)) self.assertAlmostEqual(out_values, expected_values) with self.subTest("diffed signal and smoothed signal in one request"): out = self._fetch("/", signal="jhu-csse:confirmed_incidence_num;jhu-csse:confirmed_7dav_incidence_num", geo=first.geo_pair, time="day:20200401-20200410") out_values = [row["value"] for row in out["epidata"]] values = [value for _, 
value in time_value_pairs] - expected_diff = self._diff_rows(values) - expected_smoothed = self._smooth_rows(expected_diff) + expected_diff = _diff_rows(values) + expected_smoothed = _smooth_rows(expected_diff) expected_values = list(interleave_longest(expected_smoothed, expected_diff)) self.assertAlmostEqual(out_values, expected_values) @@ -169,7 +159,7 @@ def test_derived_signals(self): out = self._fetch("/", signal="jhu-csse:confirmed_incidence_num", geo=first.geo_pair, time="day:20200401-20200420") out_values = [row["value"] for row in out["epidata"]] values = [value for _, value in time_value_pairs][:10] + [None] * 5 + [value for _, value in time_value_pairs][10:] - expected_values = self._diff_rows(values) + expected_values = _diff_rows(values) self.assertAlmostEqual(out_values, expected_values) with self.subTest("smoothing and diffing with a time gap"): @@ -177,7 +167,7 @@ def test_derived_signals(self): out = self._fetch("/", signal="jhu-csse:confirmed_7dav_incidence_num", geo=first.geo_pair, time="day:20200401-20200420") out_values = [row["value"] for row in out["epidata"]] values = [value for _, value in time_value_pairs][:10] + [None] * 5 + [value for _, value in time_value_pairs][10:] - expected_values = self._smooth_rows(self._diff_rows(values)) + expected_values = _smooth_rows(_diff_rows(values)) self.assertAlmostEqual(out_values, expected_values) def test_compatibility(self): @@ -199,7 +189,7 @@ def _diff_covidcast_rows(self, rows: List[CovidcastRow]) -> List[CovidcastRow]: return new_rows def test_trend(self): - """Request a signal the /trend endpoint.""" + """Request a signal from the /trend endpoint.""" num_rows = 30 rows = [CovidcastRow(time_value=20200401 + i, value=i) for i in range(num_rows)] @@ -270,7 +260,7 @@ def test_trend(self): def test_trendseries(self): - """Request a signal the /trendseries endpoint.""" + """Request a signal from the /trendseries endpoint.""" num_rows = 3 rows = [CovidcastRow(time_value=20200401 + i, value=num_rows - i) for i in range(num_rows)] @@ -395,7 +385,7 @@ def match_row(trend, row): def test_correlation(self): - """Request a signal the /correlation endpoint.""" + """Request a signal from the /correlation endpoint.""" num_rows = 30 reference_rows = [CovidcastRow(signal="ref", time_value=20200401 + i, value=i) for i in range(num_rows)] @@ -422,7 +412,7 @@ def test_correlation(self): self.assertEqual(df["samples"].tolist(), [num_rows - abs(l) for l in range(-max_lag, max_lag + 1)]) def test_csv(self): - """Request a signal the /csv endpoint.""" + """Request a signal from the /csv endpoint.""" expected_columns = ["geo_value", "signal", "time_value", "issue", "lag", "value", "stderr", "sample_size", "geo_type", "data_source"] data = CovidcastRows.from_args( @@ -483,7 +473,7 @@ def test_csv(self): pd.testing.assert_frame_equal(df_diffed, expected_df) def test_backfill(self): - """Request a signal the /backfill endpoint.""" + """Request a signal from the /backfill endpoint.""" num_rows = 10 issue_0 = [CovidcastRow(time_value=20200401 + i, value=i, sample_size=1, lag=0, issue=20200401 + i) for i in range(num_rows)] @@ -511,7 +501,7 @@ def test_backfill(self): self.assertEqual(df_t0["sample_size_completeness"].tolist(), [1 / 3, 2 / 3, 3 / 3]) # total 2, given 0,1,2 def test_meta(self): - """Request a signal the /meta endpoint.""" + """Request a signal from the /meta endpoint.""" num_rows = 10 rows = [CovidcastRow(time_value=20200401 + i, value=i, source="fb-survey", signal="smoothed_cli") for i in range(num_rows)] @@ -551,7 +541,7 @@ def 
test_meta(self): self.assertEqual(len(out), 0) def test_coverage(self): - """Request a signal the /coverage endpoint.""" + """Request a signal from the /coverage endpoint.""" num_geos_per_date = [10, 20, 30, 40, 44] dates = [20200401 + i for i in range(len(num_geos_per_date))] @@ -560,17 +550,17 @@ def test_coverage(self): first = rows[0] with self.subTest("default"): - out = self._fetch("/coverage", signal=first.signal_pair, latest=dates[-1], format="json") + out = self._fetch("/coverage", signal=first.signal_pair, geo_type=first.geo_type, latest=dates[-1], format="json") self.assertEqual(len(out), len(num_geos_per_date)) self.assertEqual([o["time_value"] for o in out], dates) self.assertEqual([o["count"] for o in out], num_geos_per_date) with self.subTest("specify window"): - out = self._fetch("/coverage", signal=first.signal_pair, window=f"{dates[0]}-{dates[1]}", format="json") + out = self._fetch("/coverage", signal=first.signal_pair, geo_type=first.geo_type, window=f"{dates[0]}-{dates[1]}", format="json") self.assertEqual(len(out), 2) self.assertEqual([o["time_value"] for o in out], dates[:2]) self.assertEqual([o["count"] for o in out], num_geos_per_date[:2]) with self.subTest("invalid geo_type"): - out = self._fetch("/coverage", signal=first.signal_pair, geo_type="state", format="json") + out = self._fetch("/coverage", signal=first.signal_pair, geo_type="doesnt_exist", format="json") self.assertEqual(len(out), 0) diff --git a/integrations/server/test_covidcast_meta.py b/integrations/server/test_covidcast_meta.py index 6c52f56f1..ef034e8ea 100644 --- a/integrations/server/test_covidcast_meta.py +++ b/integrations/server/test_covidcast_meta.py @@ -1,22 +1,43 @@ """Integration tests for the `covidcast_meta` endpoint.""" # standard library +from datetime import date +from numbers import Number +from typing import Iterable, Optional, Union import unittest # third party -import mysql.connector +import numpy as np +import pandas as pd import requests -#first party +# first party from delphi_utils import Nans from delphi.epidata.acquisition.covidcast.covidcast_meta_cache_updater import main as update_cache +from delphi.epidata.acquisition.covidcast.database import CovidcastRow +from delphi.epidata.acquisition.covidcast.database_meta import DatabaseMeta import delphi.operations.secrets as secrets # use the local instance of the Epidata API BASE_URL = 'http://delphi_web_epidata/epidata/api.php' -class CovidcastMetaTests(unittest.TestCase): +def _almost_equal(v1: Optional[Union[Number, str]], v2: Optional[Union[Number, str]], atol: float = 1e-08) -> bool: + if v1 is None and v2 is None: + return True + elif (v1 is None and v2 is not None) or (v1 is not None and v2 is None): + return False + else: + return np.allclose(v1, v2, atol=atol) if isinstance(v1, Number) and isinstance(v2, Number) else v1 == v2 + + +def _dicts_equal(d1: dict, d2: dict, ignore_keys: Optional[list] = None, atol: float = 1e-08) -> bool: + """Compare dictionary values using floating point comparison for numeric values.""" + assert set(d1.keys()) == set(d2.keys()) + return all(_almost_equal(d1.get(key), d2.get(key), atol=atol) for key in d1.keys() if (ignore_keys and key not in ignore_keys)) + + +class TestCovidcastMeta(unittest.TestCase): """Tests the `covidcast_meta` endpoint.""" src_sig_lookups = { @@ -33,75 +54,71 @@ class CovidcastMetaTests(unittest.TestCase): } template = ''' - INSERT INTO - `signal_latest` (`signal_data_id`, `signal_key_id`, `geo_key_id`, - `time_type`, `time_value`, `value_updated_timestamp`, + INSERT 
INTO `epimetric_latest` ( + `epimetric_id`, `signal_key_id`, `geo_key_id`, + `time_type`, `time_value`, `value_updated_timestamp`, `value`, `stderr`, `sample_size`, `issue`, `lag`, `missing_value`, `missing_stderr`,`missing_sample_size`) VALUES - (%d, %d, %d, "%s", %d, 123, - %d, 0, 0, %d, 0, %d, %d, %d) + (%d, %d, %d, + "%s", %d, 123, + %d, 0, 0, + %d, 0, %d, + %d, %d) ''' def setUp(self): """Perform per-test setup.""" - # connect to the `epidata` database and clear the `covidcast` table - cnx = mysql.connector.connect( - user='user', - password='pass', - host='delphi_database_epidata', - database='covid') - cur = cnx.cursor() + # connect to the `epidata` database + self.db = DatabaseMeta(base_url="http://delphi_web_epidata/epidata") + self.db.connect(user="user", password="pass", host="delphi_database_epidata", database="covid") + + # TODO: Switch when delphi_epidata client is released. + self.db.delphi_epidata = False # clear all tables - cur.execute("truncate table signal_load") - cur.execute("truncate table signal_history") - cur.execute("truncate table signal_latest") - cur.execute("truncate table geo_dim") - cur.execute("truncate table signal_dim") + self.db._cursor.execute("truncate table epimetric_load") + self.db._cursor.execute("truncate table epimetric_full") + self.db._cursor.execute("truncate table epimetric_latest") + self.db._cursor.execute("truncate table geo_dim") + self.db._cursor.execute("truncate table signal_dim") + self.db._connection.commit() # reset the `covidcast_meta_cache` table (it should always have one row) - cur.execute('update covidcast_meta_cache set timestamp = 0, epidata = "[]"') + self.db._cursor.execute('update covidcast_meta_cache set timestamp = 0, epidata = "[]"') - # populate dimension tables for convenience + # populate dimension tables for (src,sig) in self.src_sig_lookups: - cur.execute(''' + self.db._cursor.execute(''' INSERT INTO `signal_dim` (`signal_key_id`, `source`, `signal`) VALUES (%d, '%s', '%s'); ''' % ( self.src_sig_lookups[(src,sig)], src, sig )) for (gt,gv) in self.geo_lookups: - cur.execute(''' + self.db._cursor.execute(''' INSERT INTO `geo_dim` (`geo_key_id`, `geo_type`, `geo_value`) VALUES (%d, '%s', '%s'); ''' % ( self.geo_lookups[(gt,gv)], gt, gv )) - cnx.commit() - cur.close() + self.db._connection.commit() # initialize counter for tables without non-autoincrement id self.id_counter = 666 - # make connection and cursor available to test cases - self.cnx = cnx - self.cur = cnx.cursor() - # use the local instance of the epidata database secrets.db.host = 'delphi_database_epidata' secrets.db.epi = ('user', 'pass') - def tearDown(self): """Perform per-test teardown.""" - self.cur.close() - self.cnx.close() + self.db._cursor.close() + self.db._connection.close() - def _get_id(self): - self.id_counter += 1 - return self.id_counter - - def test_round_trip(self): - """Make a simple round-trip with some sample data.""" + def _insert_rows(self, rows: Iterable[CovidcastRow]): + self.db.insert_or_update_bulk(list(rows)) + self.db.run_dbjobs() + self.db._connection.commit() + return rows - # insert dummy data and accumulate expected results (in sort order) + def insert_placeholder_data(self): expected = [] for src in ('src1', 'src2'): for sig in ('sig1', 'sig2'): @@ -126,13 +143,25 @@ def test_round_trip(self): }) for tv in (1, 2): for gv, v in zip(('geo1', 'geo2'), (10, 20)): - self.cur.execute(self.template % ( + self.db._cursor.execute(self.template % ( self._get_id(), - self.src_sig_lookups[(src,sig)], self.geo_lookups[(gt,gv)], tt, 
tv, v, tv, + self.src_sig_lookups[(src,sig)], self.geo_lookups[(gt,gv)], + tt, tv, v, tv, # re-use time value for issue Nans.NOT_MISSING, Nans.NOT_MISSING, Nans.NOT_MISSING )) - self.cnx.commit() + self.db._connection.commit() update_cache(args=None) + return expected + + def _get_id(self): + self.id_counter += 1 + return self.id_counter + + def test_round_trip(self): + """Make a simple round-trip with some sample data.""" + + # insert placeholder data and accumulate expected results (in sort order) + expected = self.insert_placeholder_data() # make the request response = requests.get(BASE_URL, params={'endpoint': 'covidcast_meta'}) @@ -146,42 +175,11 @@ def test_round_trip(self): 'message': 'success', }) - def test_filter(self): """Test filtering options some sample data.""" - # insert dummy data and accumulate expected results (in sort order) - expected = [] - for src in ('src1', 'src2'): - for sig in ('sig1', 'sig2'): - for tt in ('day', 'week'): - for gt in ('hrr', 'msa'): - expected.append({ - 'data_source': src, - 'signal': sig, - 'time_type': tt, - 'geo_type': gt, - 'min_time': 1, - 'max_time': 2, - 'num_locations': 2, - 'min_value': 10, - 'max_value': 20, - 'mean_value': 15, - 'stdev_value': 5, - 'last_update': 123, - 'max_issue': 2, - 'min_lag': 0, - 'max_lag': 0, - }) - for tv in (1, 2): - for gv, v in zip(('geo1', 'geo2'), (10, 20)): - self.cur.execute(self.template % ( - self._get_id(), - self.src_sig_lookups[(src,sig)], self.geo_lookups[(gt,gv)], tt, tv, v, tv, - Nans.NOT_MISSING, Nans.NOT_MISSING, Nans.NOT_MISSING - )) - self.cnx.commit() - update_cache(args=None) + # insert placeholder data and accumulate expected results (in sort order) + expected = self.insert_placeholder_data() def fetch(**kwargs): # make the request @@ -260,3 +258,45 @@ def fetch(**kwargs): self.assertEqual(len(res['epidata']), len(expected)) self.assertEqual(res['epidata'][0], {}) + def test_meta_values2(self): + """This is an A/B test between the old meta compute approach and the new one which relies on an API call for JIT signals. + + It relies on synthetic data that attempts to be as realistic and as general as possible. + """ + + def get_rows_gen(df: pd.DataFrame, filter_nans: bool = False) -> Iterable[CovidcastRow]: + for args in df.itertuples(index=False): + if not filter_nans or (filter_nans and not any(map(pd.isna, args._asdict().values()))): + yield CovidcastRow(**args._asdict()) + + start_date = date(2022, 4, 1) + end_date = date(2022, 6, 1) + n = (end_date - start_date).days + 1 + + # TODO: Build a more complex synthetic dataset here. 
+ cumulative_df = pd.DataFrame( + { + "source": ["jhu-csse"] * n, + "signal": ["confirmed_cumulative_num"] * n, + "time_value": pd.date_range(start_date, end_date), + "issue": pd.date_range(start_date, end_date), + "value": list(range(n)), + } + ) + incidence_df = cumulative_df.assign( + signal="confirmed_incidence_num", value=cumulative_df.value.diff(), issue=[max(window) if window.size >= 2 else np.nan for window in cumulative_df.issue.rolling(2)] + ) + smoothed_incidence_df = incidence_df.assign( + signal="confirmed_7dav_incidence_num", value=incidence_df.value.rolling(7).mean(), issue=[max(window) if window.size >= 7 else np.nan for window in incidence_df.issue.rolling(7)] + ) + + self._insert_rows(get_rows_gen(cumulative_df, filter_nans=True)) + self._insert_rows(get_rows_gen(incidence_df, filter_nans=True)) + self._insert_rows(get_rows_gen(smoothed_incidence_df, filter_nans=True)) + + meta_values = self.db.compute_covidcast_meta(jit=False) + meta_values2 = self.db.compute_covidcast_meta(jit=True, parallel=False) + + out = [_dicts_equal(x, y, ignore_keys=["max_lag"]) for x, y in zip(meta_values, meta_values2)] + + assert all(out) diff --git a/integrations/server/test_covidcast_meta_ab.py b/integrations/server/test_covidcast_meta_ab.py new file mode 100644 index 000000000..8b6c14873 --- /dev/null +++ b/integrations/server/test_covidcast_meta_ab.py @@ -0,0 +1,285 @@ +from datetime import datetime, timedelta +from functools import reduce +from math import inf +from numbers import Number +from pathlib import Path +from typing import Iterable, Iterator, List, Literal, Optional, Tuple, Union + +# third party +import numpy as np +import pandas as pd +import pytest +import pytest_check as check +import requests + +# first party +from delphi_utils.geomap import GeoMapper +import delphi.operations.secrets as secrets +from delphi.epidata.acquisition.covidcast.database import CovidcastRow +from delphi.epidata.acquisition.covidcast.database_meta import DatabaseMeta + +# use the local instance of the Epidata API +BASE_URL = "http://delphi_web_epidata/epidata/api.php" +TEST_DATA_DIR = Path("repos/delphi/delphi-epidata/testdata/acquisition/covidcast/") + + +def _df_to_covidcastrows(df: pd.DataFrame) -> Iterable[CovidcastRow]: + """Iterates over the rows of a dataframe. + + The dataframe is expected to have many columns, see below for which. 
+ """ + for _, row in df.iterrows(): + yield CovidcastRow( + source=row.data_source if "data_source" in df.columns else row.source, + signal=row.signal, + time_type=row.time_type, + geo_type=row.geo_type, + time_value=datetime.strptime(row.time_value, "%Y-%m-%d"), + geo_value=row.geo_value, + value=row.value, + stderr=row.stderr if not np.isnan(row.stderr) else None, + sample_size=row.sample_size if not np.isnan(row.sample_size) else None, + missing_value=row.missing_value, + missing_stderr=row.missing_stderr, + missing_sample_size=row.missing_sample_size, + issue=datetime.strptime(row.issue, "%Y-%m-%d"), + lag=row.lag, + ) + + +def _almost_equal(v1: Optional[Union[Number, str]], v2: Optional[Union[Number, str]], atol: float = 1e-08) -> bool: + if v1 is None and v2 is None: + return True + elif (v1 is None and v2 is not None) or (v1 is not None and v2 is None): + return False + else: + return np.allclose(v1, v2, atol=atol) if isinstance(v1, Number) and isinstance(v2, Number) else v1 == v2 + + +def _dicts_equal(d1: dict, d2: dict, ignore_keys: Optional[list] = None, atol: float = 1e-08) -> bool: + """Compare dictionary values using floating point comparison for numeric values.""" + assert set(d1.keys()) == set(d2.keys()) + return all(_almost_equal(d1.get(key), d2.get(key), atol=atol) for key in d1.keys() if (ignore_keys and key not in ignore_keys)) + + +class TestCovidcastMeta: + def setup_method(self): + """Perform per-test setup.""" + + # connect to the `epidata` database + self.db = DatabaseMeta(base_url="http://delphi_web_epidata/epidata") + self.db.connect(user="user", password="pass", host="delphi_database_epidata", database="covid") + + # TODO: Switch when delphi_epidata client is released. + self.db.delphi_epidata = False + + # clear all tables + self.db._cursor.execute("truncate table epimetric_load") + self.db._cursor.execute("truncate table epimetric_full") + self.db._cursor.execute("truncate table epimetric_latest") + self.db._cursor.execute("truncate table geo_dim") + self.db._cursor.execute("truncate table signal_dim") + self.db._connection.commit() + # reset the `covidcast_meta_cache` table (it should always have one row) + self.db._cursor.execute('update covidcast_meta_cache set timestamp = 0, epidata = "[]"') + + self.db._connection.commit() + + # initialize counter for tables without non-autoincrement id + self.id_counter = 666 + + # use the local instance of the epidata database + secrets.db.host = "delphi_database_epidata" + secrets.db.epi = ("user", "pass") + + def teardown_method(self): + """Perform per-test teardown.""" + self.db._cursor.close() + self.db._connection.close() + + def _insert_rows(self, rows: Iterable[CovidcastRow]): + self.db.insert_or_update_bulk(list(rows)) + self.db.run_dbjobs() + self.db._connection.commit() + return rows + + def get_source_signal_from_db(self, source: str, signal: str) -> pd.DataFrame: + """Get the source signal data from the database.""" + sql = f"""SELECT c.signal, c.geo_type, c.geo_value, c.time_value, c.value FROM epimetric_latest_v c WHERE 1 = 1 AND c.`source` = '{source}' AND c.`signal` = '{signal}'""" + self.db._cursor.execute(sql) + df = ( + pd.DataFrame.from_records(self.db._cursor.fetchall(), columns=["signal", "geo_type", "geo_value", "time_value", "value"]) + .assign(time_value=lambda x: pd.to_datetime(x["time_value"], format="%Y%m%d")) + .set_index(["signal", "geo_value", "time_value"]) + .sort_index() + ) + return df + + def get_source_signal_from_api(self, source: str, signal: str) -> pd.DataFrame: + """Query the 
source signal data from the local API.""" + base_url = "http://delphi_web_epidata/epidata/covidcast/" + + def get_api_df(**params) -> pd.DataFrame: + return pd.DataFrame.from_records(requests.get(base_url, params=params).json()["epidata"]) + + ALLTIME = "19000101-20500101" + params = {"signal": f"{source}:{signal}", "geo": "state:*;county:*", "time": f"day:{ALLTIME}"} + df = get_api_df(**params).assign(geo_type="day", time_value=lambda x: pd.to_datetime(x["time_value"], format="%Y%m%d")).set_index(["signal", "geo_value", "time_value"]).sort_index() + return df + + def _insert_csv(self, filename: str): + with pd.read_csv(filename, chunksize=10_000) as reader: + for chunk_df in reader: + self._insert_rows(_df_to_covidcastrows(chunk_df)) + + @pytest.mark.skip("Too slow.") + @pytest.mark.parametrize("test_data_filepath", TEST_DATA_DIR.glob("*.csv")) + def test_incidence(self, test_data_filepath): + """This is large-scale A/B test of the JIT system for the incidence signal. + + Uses live API data and compares: + - the results of the new JIT system to the API data + - the results of the new JIT system to the Pandas-derived data + """ + source = "usa-facts" if "usa-facts" in str(test_data_filepath) else "jhu-csse" + print(test_data_filepath) + self._insert_csv(test_data_filepath) + + # Here we load: + # test_data_full_df - the original CSV file with our test data + # db_pandas_incidence_df - the incidence data pulled from the database as cumulative, placed on a contiguous index (live data has gaps), and then diffed via Pandas + # api_incidence_df - the incidence data as returned by the API from JIT + test_data_full_df = ( + pd.read_csv(test_data_filepath) + .assign(time_value=lambda x: pd.to_datetime(x["time_value"]), geo_value=lambda x: x["geo_value"].astype(str)) + .set_index(["signal", "geo_value", "time_value"]) + .sort_index() + ) + db_pandas_incidence_df = ( + self.get_source_signal_from_db(source, "confirmed_cumulative_num") + # Place on a contiguous index + .groupby(["signal", "geo_value"]) + .apply(lambda x: x.reset_index().drop(columns=["signal", "geo_value"]).set_index("time_value").reindex(pd.date_range("2020-01-25", "2022-09-10"))) + .reset_index() + .rename(columns={"level_2": "time_value"}) + .set_index(["signal", "geo_value", "time_value"]) + # Diff + .groupby(["signal", "geo_value"]) + .apply(lambda x: x["value"].reset_index().drop(columns=["signal", "geo_value"]).set_index("time_value").diff()) + .reset_index() + .assign(signal="confirmed_incidence_num") + .set_index(["signal", "geo_value", "time_value"]) + ) + api_incidence_df = self.get_source_signal_from_api(source, "confirmed_incidence_num") + + # Join into one dataframe for easy comparison + test_data_full_df = test_data_full_df.join(db_pandas_incidence_df.value, rsuffix="_db_pandas") + test_data_full_df = test_data_full_df.join(api_incidence_df.value, rsuffix="_api_jit") + test_data_cumulative_df: pd.DataFrame = test_data_full_df.loc["confirmed_cumulative_num"] + test_data_incidence_df: pd.DataFrame = test_data_full_df.loc["confirmed_incidence_num"] + + # Test 1: show that Pandas-recomputed incidence (from cumulative) is identical to JIT incidence (up to 7 decimal places). 
+ pandas_ne_jit = test_data_full_df[["value_db_pandas", "value_api_jit"]].dropna(how="any", axis=0) + pandas_ne_jit = pandas_ne_jit[pandas_ne_jit.value_db_pandas.sub(pandas_ne_jit.value_api_jit, fill_value=inf).abs().ge(1e-7)] + check.is_true(pandas_ne_jit.empty, "Check Pandas-JIT incidence match.") + if not pandas_ne_jit.empty: + print("Pandas-JIT incidence mismatch:") + print(pandas_ne_jit.to_string()) + + # Test 2: show that some JIT incidence values do not match live data. These are errors in the live data. + live_ne_jit = test_data_full_df[["value", "value_api_jit"]].dropna(how="any", axis=0) + live_ne_jit = live_ne_jit[live_ne_jit.value.sub(live_ne_jit.value_api_jit, fill_value=inf).abs().ge(1e-7)] + check.is_true(live_ne_jit.empty, "Check JIT-live match.") + if not live_ne_jit.empty: + print("JIT-live mismatch:") + print(live_ne_jit.to_string()) + + # Test 3: show that when JIT has a NAN, it is reasonable: the cumulative signal is either missing today or yesterday. + jit_nan_df = test_data_incidence_df[["value", "value_api_jit"]].query("value_api_jit.isna()") + jit_nan_df = reduce( + lambda x, y: pd.merge(x, y, how="outer", left_index=True, right_index=True), + ( + test_data_cumulative_df.filter(items=jit_nan_df.index.map(lambda x: (x[0], x[1] - timedelta(days=i))), axis=0)["value"].rename(f"value_{i}_days_past") + for i in range(2) + ), + ) + jit_nan_df = jit_nan_df[jit_nan_df.notna().all(axis=1)] + check.is_true(jit_nan_df.empty, "Check JIT NANs are reasonable.") + if not jit_nan_df.empty: + print("JIT NANs are unreasonable:") + print(jit_nan_df.to_string()) + + @pytest.mark.skip("Too slow.") + @pytest.mark.parametrize("test_data_filepath", TEST_DATA_DIR.glob("*.csv")) + def test_7dav_incidence(self, test_data_filepath): + """This is large-scale A/B test of the JIT system for the 7dav incidence signal. 
+ + Uses live API data and compares: + - the results of the new JIT system to the API data + - the results of the new JIT system to the Pandas-derived data + """ + source = "usa-facts" if "usa-facts" in str(test_data_filepath) else "jhu-csse" + print(test_data_filepath) + self._insert_csv(test_data_filepath) + + # Here we load: + # test_data_full_df - the original CSV file with our test data + # db_pandas_incidence_df - the incidence data pulled from the database as cumulative, placed on a contiguous index (live data has gaps), and then diffed via Pandas + # api_incidence_df - the incidence data as returned by the API from JIT + test_data_full_df = ( + pd.read_csv(test_data_filepath) + .assign(time_value=lambda x: pd.to_datetime(x["time_value"]), geo_value=lambda x: x["geo_value"].astype(str)) + .set_index(["signal", "geo_value", "time_value"]) + .sort_index() + ) + db_pandas_7dav_incidence_df = ( + self.get_source_signal_from_db(source, "confirmed_cumulative_num") + .groupby(["signal", "geo_value"]) + .apply(lambda x: x.reset_index().drop(columns=["signal", "geo_value"]).set_index("time_value").reindex(pd.date_range("2020-01-25", "2022-09-10"))) + .reset_index() + .rename(columns={"level_2": "time_value"}) + .set_index(["signal", "geo_value", "time_value"]) + .groupby(["signal", "geo_value"]) + .apply(lambda x: x["value"].reset_index().drop(columns=["signal", "geo_value"]).set_index("time_value").diff().rolling(7).mean()) + .reset_index() + .assign(signal="confirmed_7dav_incidence_num") + .set_index(["signal", "geo_value", "time_value"]) + ) + api_7dav_incidence_df = self.get_source_signal_from_api(source, "confirmed_7dav_incidence_num") + + # Join into one dataframe for easy comparison + test_data_full_df = test_data_full_df.join(db_pandas_7dav_incidence_df.value, rsuffix="_db_pandas") + test_data_full_df = test_data_full_df.join(api_7dav_incidence_df.value, rsuffix="_api_jit") + test_data_cumulative_df: pd.DataFrame = test_data_full_df.loc["confirmed_cumulative_num"] + test_data_7dav_incidence_df: pd.DataFrame = test_data_full_df.loc["confirmed_7dav_incidence_num"] + + # Test 1: show that Pandas-recomputed incidence (from cumulative) is identical to JIT incidence (up to 7 decimal places). + pandas_ne_jit = test_data_full_df[["value_db_pandas", "value_api_jit"]].dropna(how="any", axis=0) + pandas_ne_jit = pandas_ne_jit[pandas_ne_jit.value_db_pandas.sub(pandas_ne_jit.value_api_jit, fill_value=inf).abs().ge(1e-7)] + check.is_true(pandas_ne_jit.empty, "Check Pandas-JIT incidence match.") + if not pandas_ne_jit.empty: + print("Pandas-JIT incidence mismatch:") + print(pandas_ne_jit.to_string()) + + # Test 2: show that some JIT incidence values do not match live data. These are errors in the live data. + live_ne_jit = test_data_7dav_incidence_df[["value", "value_api_jit"]].dropna(how="any", axis=0) + live_ne_jit = live_ne_jit[live_ne_jit.value.sub(live_ne_jit.value_api_jit, fill_value=inf).abs().ge(1e-7)] + check.is_true(live_ne_jit.empty, "Check JIT-live match.") + if not live_ne_jit.empty: + print("JIT-live mismatch:") + print(live_ne_jit.to_string()) + + # Test 3: show that when JIT has a NAN, it is reasonable: the cumulative signal is either missing today or yesterday. 
+ jit_nan_df = test_data_7dav_incidence_df[["value", "value_api_jit"]].query("value_api_jit.isna()") + jit_nan_df = reduce( + lambda x, y: pd.merge(x, y, how="outer", left_index=True, right_index=True), + ( + test_data_cumulative_df.filter(items=jit_nan_df.index.map(lambda x: (x[0], x[1] - timedelta(days=i))), axis=0)["value"].rename(f"value_{i}_days_past") + for i in range(8) + ), + ) + jit_nan_df = jit_nan_df.dropna(how="any", axis=0) + check.is_true(jit_nan_df.empty, "Check JIT NANs are reasonable.") + if not jit_nan_df.empty: + print("JIT NANs are not reasonable:") + print(jit_nan_df.to_string()) diff --git a/src/acquisition/covidcast/config.py b/src/acquisition/covidcast/config.py new file mode 100644 index 000000000..729408227 --- /dev/null +++ b/src/acquisition/covidcast/config.py @@ -0,0 +1,3 @@ +# TODO: Fill these in. +GEO_TYPES = ["county", "state", "hhs", "msa", "nation", "hrr"] +ALL_TIME = "19000101-20500101" diff --git a/src/acquisition/covidcast/covidcast_meta_cache_updater.py b/src/acquisition/covidcast/covidcast_meta_cache_updater.py index 624a1919b..d5550312a 100644 --- a/src/acquisition/covidcast/covidcast_meta_cache_updater.py +++ b/src/acquisition/covidcast/covidcast_meta_cache_updater.py @@ -6,7 +6,7 @@ import time # first party -from delphi.epidata.acquisition.covidcast.database import Database +from .database_meta import DatabaseMeta from delphi.epidata.acquisition.covidcast.logger import get_structured_logger from delphi.epidata.client.delphi_epidata import Epidata @@ -15,17 +15,20 @@ def get_argument_parser(): parser = argparse.ArgumentParser() parser.add_argument("--log_file", help="filename for log output") + parser.add_argument("--num_threads", type=int, help="number of worker threads to spawn for processing source/signal pairs") return parser -def main(args, epidata_impl=Epidata, database_impl=Database): +def main(args, epidata_impl: Epidata = Epidata, database_impl: DatabaseMeta = DatabaseMeta): """Update the covidcast metadata cache. `args`: parsed command-line arguments """ log_file = None + num_threads = None if (args): log_file = args.log_file + num_threads = args.num_threads logger = get_structured_logger( "metadata_cache_updater", @@ -37,7 +40,7 @@ def main(args, epidata_impl=Epidata, database_impl=Database): # fetch metadata try: metadata_calculation_start_time = time.time() - metadata = database.compute_covidcast_meta() + metadata = database.compute_covidcast_meta(n_threads=num_threads) metadata_calculation_interval_in_seconds = time.time() - metadata_calculation_start_time except: # clean up before failing diff --git a/src/acquisition/covidcast/covidcast_row.py b/src/acquisition/covidcast/covidcast_row.py index 8d9df70c6..af57b0b28 100644 --- a/src/acquisition/covidcast/covidcast_row.py +++ b/src/acquisition/covidcast/covidcast_row.py @@ -19,7 +19,9 @@ class CovidcastRow: Used for: - inserting rows into the database - - quickly creating rows with default fields for testing + - creating test rows with default fields for testing + - created from many formats (dict, csv, df, kwargs) + - can be viewed in many formats (dict, csv, df) The rows are specified in 'v4_schema.sql'. """ @@ -57,7 +59,7 @@ def __post_init__(self): # - 3. If this row was returned by the database. 
self._db_row_ignore_fields = [] - def _sanity_check_fields(self, test_mode: bool = True): + def _sanity_check_fields(self, extra_checks: bool = True): if self.issue and self.issue < self.time_value: self.issue = self.time_value @@ -68,13 +70,13 @@ def _sanity_check_fields(self, test_mode: bool = True): # This sanity checking is already done in CsvImporter, but it's here so the testing class gets it too. if _is_none(self.value) and self.missing_value == Nans.NOT_MISSING: - self.missing_value = Nans.NOT_APPLICABLE.value if test_mode else Nans.OTHER.value + self.missing_value = Nans.NOT_APPLICABLE.value if extra_checks else Nans.OTHER.value if _is_none(self.stderr) and self.missing_stderr == Nans.NOT_MISSING: - self.missing_stderr = Nans.NOT_APPLICABLE.value if test_mode else Nans.OTHER.value + self.missing_stderr = Nans.NOT_APPLICABLE.value if extra_checks else Nans.OTHER.value if _is_none(self.sample_size) and self.missing_sample_size == Nans.NOT_MISSING: - self.missing_sample_size = Nans.NOT_APPLICABLE.value if test_mode else Nans.OTHER.value + self.missing_sample_size = Nans.NOT_APPLICABLE.value if extra_checks else Nans.OTHER.value return self @@ -159,7 +161,8 @@ def geo_pair(self): def time_pair(self): return f"{self.time_type}:{self.time_value}" - +# TODO: Deprecate this class in favor of functions over the List[CovidcastRow] datatype. +# All the inner variables of this class are derived from the CovidcastRow class. @dataclass class CovidcastRows: rows: List[CovidcastRow] = field(default_factory=list) @@ -198,7 +201,7 @@ def from_args(sanity_check: bool = True, test_mode: bool = True, **kwargs: Dict[ # All the arg values must be lists of the same length. assert len(set(len(lst) for lst in kwargs.values())) == 1 - return CovidcastRows(rows=[CovidcastRow(**_kwargs)._sanity_check_fields(test_mode=test_mode) if sanity_check else CovidcastRow(**_kwargs) for _kwargs in transpose_dict(kwargs)]) + return CovidcastRows(rows=[CovidcastRow(**_kwargs)._sanity_check_fields(extra_checks=test_mode) if sanity_check else CovidcastRow(**_kwargs) for _kwargs in transpose_dict(kwargs)]) @staticmethod def from_records(records: Iterable[dict], sanity_check: bool = False): diff --git a/src/acquisition/covidcast/csv_to_database.py b/src/acquisition/covidcast/csv_to_database.py index 9828564d0..34cbad663 100644 --- a/src/acquisition/covidcast/csv_to_database.py +++ b/src/acquisition/covidcast/csv_to_database.py @@ -7,7 +7,7 @@ # first party from delphi.epidata.acquisition.covidcast.csv_importer import CsvImporter -from delphi.epidata.acquisition.covidcast.database import Database, CovidcastRow +from delphi.epidata.acquisition.covidcast.database import Database, CovidcastRow, DBLoadStateException from delphi.epidata.acquisition.covidcast.file_archiver import FileArchiver from delphi.epidata.acquisition.covidcast.logger import get_structured_logger @@ -77,12 +77,12 @@ def upload_archive( csv_importer_impl=CsvImporter): """Upload CSVs to the database and archive them using the specified handlers. 
- :path_details: output from CsvImporter.find*_csv_files - + :path_details: output from CsvImporter.find*_csv_files + :database: an open connection to the epidata database :handlers: functions for archiving (successful, failed) files - + :return: the number of modified rows """ archive_as_successful, archive_as_failed = handlers @@ -120,9 +120,13 @@ def upload_archive( if modified_row_count is None or modified_row_count: # else would indicate zero rows inserted total_modified_row_count += (modified_row_count if modified_row_count else 0) database.commit() + except DBLoadStateException as e: + # if the db is in a state that is not fit for loading new data, + # then we should stop processing any more files + raise e except Exception as e: all_rows_valid = False - logger.exception('exception while inserting rows:', e) + logger.exception('exception while inserting rows', exc_info=e) database.rollback() # archive the current file based on validation results @@ -130,7 +134,7 @@ def upload_archive( archive_as_successful(path_src, filename, source, logger) else: archive_as_failed(path_src, filename, source,logger) - + return total_modified_row_count @@ -149,7 +153,7 @@ def main( if not path_details: logger.info('nothing to do; exiting...') return - + logger.info("Ingesting CSVs", csv_count = len(path_details)) database = database_impl() @@ -161,13 +165,12 @@ def main( database, make_handlers(args.data_dir, args.specific_issue_date), logger) - logger.info("Finished inserting database rows", row_count = modified_row_count) - # the following print statement serves the same function as the logger.info call above - # print('inserted/updated %d rows' % modified_row_count) + logger.info("Finished inserting/updating database rows", row_count = modified_row_count) finally: + database.do_analyze() # unconditionally commit database changes since CSVs have been archived database.disconnect(True) - + logger.info( "Ingested CSVs into database", total_runtime_in_seconds=round(time.time() - start_time, 2)) diff --git a/src/acquisition/covidcast/database.py b/src/acquisition/covidcast/database.py index 4eec82194..5b0e54972 100644 --- a/src/acquisition/covidcast/database.py +++ b/src/acquisition/covidcast/database.py @@ -4,61 +4,14 @@ """ # third party -import json +from typing import Iterable, Sequence import mysql.connector -import numpy as np from math import ceil -from queue import Queue, Empty -import threading -from multiprocessing import cpu_count - -# first party import delphi.operations.secrets as secrets -from delphi.epidata.acquisition.covidcast.logger import get_structured_logger - -class CovidcastRow(): - """A container for all the values of a single covidcast row.""" - - @staticmethod - def fromCsvRowValue(row_value, source, signal, time_type, geo_type, time_value, issue, lag): - if row_value is None: return None - return CovidcastRow(source, signal, time_type, geo_type, time_value, - row_value.geo_value, - row_value.value, - row_value.stderr, - row_value.sample_size, - row_value.missing_value, - row_value.missing_stderr, - row_value.missing_sample_size, - issue, lag) - - @staticmethod - def fromCsvRows(row_values, source, signal, time_type, geo_type, time_value, issue, lag): - # NOTE: returns a generator, as row_values is expected to be a generator - return (CovidcastRow.fromCsvRowValue(row_value, source, signal, time_type, geo_type, time_value, issue, lag) - for row_value in row_values) - - def __init__(self, source, signal, time_type, geo_type, time_value, geo_value, value, stderr, - sample_size, 
missing_value, missing_stderr, missing_sample_size, issue, lag): - self.id = None - self.source = source - self.signal = signal - self.time_type = time_type - self.geo_type = geo_type - self.time_value = time_value - self.geo_value = geo_value # from CSV row - self.value = value # ... - self.stderr = stderr # ... - self.sample_size = sample_size # ... - self.missing_value = missing_value # ... - self.missing_stderr = missing_stderr # ... - self.missing_sample_size = missing_sample_size # from CSV row - self.direction_updated_timestamp = 0 - self.direction = None - self.issue = issue - self.lag = lag +from .logger import get_structured_logger +from .covidcast_row import CovidcastRow # constants for the codes used in the `process_status` column of `signal_load` @@ -68,29 +21,47 @@ class _PROCESS_STATUS(object): BATCHING = 'b' PROCESS_STATUS = _PROCESS_STATUS() +class DBLoadStateException(Exception): + pass class Database: """A collection of covidcast database operations.""" - DATABASE_NAME = 'covid' - - load_table = "signal_load" - latest_table = "signal_latest" # NOTE: careful! probably want to use variable `latest_view` instead for semantics purposes - latest_view = latest_table + "_v" - history_table = "signal_history" # NOTE: careful! probably want to use variable `history_view` instead for semantics purposes - history_view = history_table + "_v" - - - def connect(self, connector_impl=mysql.connector): + def __init__(self): + self.load_table = "epimetric_load" + # if you want to deal with foreign key ids: use table + # if you want to deal with source/signal names, geo type/values, etc: use view + self.latest_table = "epimetric_latest" + self.latest_view = self.latest_table + "_v" + self.history_table = "epimetric_full" + self.history_view = self.history_table + "_v" + # TODO: consider using class variables like this for dimension table names too + # TODO: also consider that for composite key tuples, like short_comp_key and long_comp_key as used in delete_batch() + + self._connector_impl = mysql.connector + self._db_credential_user, self._db_credential_password = secrets.db.epi + self._db_host = secrets.db.host + self._db_database = 'covid' + + def connect(self, connector_impl=None, host=None, user=None, password=None, database=None): """Establish a connection to the database.""" + if connector_impl: + self._connector_impl = connector_impl + if host: + self._db_host = host + if user: + self._db_credential_user = user + if password: + self._db_credential_password = password + if database: + self._db_database = database - u, p = secrets.db.epi - self._connector_impl = connector_impl self._connection = self._connector_impl.connect( - host=secrets.db.host, - user=u, - password=p, - database=Database.DATABASE_NAME) + host=self._db_host, + user=self._db_credential_user, + password=self._db_credential_password, + database=self._db_database + ) self._cursor = self._connection.cursor() def commit(self): @@ -111,52 +82,66 @@ def disconnect(self, commit): self._connection.close() - def count_all_rows(self, tablename=None): - """Return the total number of rows in table `covidcast`.""" - - if tablename is None: - tablename = self.history_view - - self._cursor.execute(f'SELECT count(1) FROM `{tablename}`') + def count_all_load_rows(self): + self._cursor.execute(f'SELECT count(1) FROM `{self.load_table}`') for (num,) in self._cursor: return num - def count_all_history_rows(self): - return self.count_all_rows(self.history_view) + def _reset_load_table_ai_counter(self): + """Corrects the AUTO_INCREMENT 
counter in the load table. - def count_all_latest_rows(self): - return self.count_all_rows(self.latest_view) - - def count_insertstatus_rows(self): - self._cursor.execute(f"SELECT count(1) from `{self.load_table}` where `process_status`='{PROCESS_STATUS.INSERTING}'") - - for (num,) in self._cursor: - return num + To be used in emergencies only, if the load table was accidentally TRUNCATEd. + This ensures any `epimetric_id`s generated by the load table will not collide with the history or latest tables. + This is also destructive to any data in the load table. + """ + self._cursor.execute(f'DELETE FROM epimetric_load') + # NOTE: 'ones' are used as filler here for the (required) NOT NULL columns. + self._cursor.execute(f""" + INSERT INTO epimetric_load + (epimetric_id, + source, `signal`, geo_type, geo_value, time_type, time_value, issue, `lag`, value_updated_timestamp) + VALUES + ((SELECT 1+MAX(epimetric_id) FROM epimetric_full), + '1', '1', '1', '1', '1', 1, 1, 1, 1);""") + self._cursor.execute(f'DELETE FROM epimetric_load') + + def do_analyze(self): + """performs and stores key distribution analyses, used for join order and index selection""" + # TODO: consider expanding this to update columns' histograms + # https://dev.mysql.com/doc/refman/8.0/en/analyze-table.html#analyze-table-histogram-statistics-analysis + self._cursor.execute( + f'''ANALYZE TABLE + signal_dim, geo_dim, + {self.load_table}, {self.history_table}, {self.latest_table}''') + output = [self._cursor.column_names] + self._cursor.fetchall() + get_structured_logger('do_analyze').info("ANALYZE results", results=str(output)) def insert_or_update_bulk(self, cc_rows): return self.insert_or_update_batch(cc_rows) - def insert_or_update_batch(self, cc_rows, batch_size=2**20, commit_partial=False): + def insert_or_update_batch(self, cc_rows: Sequence[CovidcastRow], batch_size: int = 2**20, commit_partial: bool = False, suppress_jobs: bool = False): """ - Insert new rows (or update existing) into the load table. - Data inserted this way will not be available to clients until the appropriate steps from src/dbjobs/ have run + Insert new rows into the load table and dispatch into dimension and fact tables. """ + if 0 != self.count_all_load_rows(): + err_msg = "Non-zero count in the load table!!! This indicates a previous acquisition run may have failed, another acquisition is in progress, or this process does not otherwise have exclusive access to the db!" + get_structured_logger("insert_or_update_batch").fatal(err_msg) + raise DBLoadStateException(err_msg) + # NOTE: `value_update_timestamp` is hardcoded to "NOW" (which is appropriate) and # `is_latest_issue` is hardcoded to 1 (which is temporary and addressed later in this method) insert_into_loader_sql = f''' INSERT INTO `{self.load_table}` (`source`, `signal`, `time_type`, `geo_type`, `time_value`, `geo_value`, `value_updated_timestamp`, `value`, `stderr`, `sample_size`, `issue`, `lag`, - `is_latest_issue`, `missing_value`, `missing_stderr`, `missing_sample_size`, - `process_status`) + `is_latest_issue`, `missing_value`, `missing_stderr`, `missing_sample_size`) VALUES (%s, %s, %s, %s, %s, %s, UNIX_TIMESTAMP(NOW()), %s, %s, %s, %s, %s, - 1, %s, %s, %s, - '{PROCESS_STATUS.INSERTING}') + 1, %s, %s, %s) ''' # all load table entries are already marked "is_latest_issue". @@ -164,24 +149,12 @@ def insert_or_update_batch(self, cc_rows, batch_size=2**20, commit_partial=False # if an entry *IS* in both load and latest tables, but latest table issue is newer, unmark is_latest_issue in load. 
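The check at the top of insert_or_update_batch is what makes acquisition runs safe to restart: a non-empty load table means a previous run died mid-flight or another acquisition is using the table, and piling more rows on top would mix batches. A minimal sketch of how a caller is expected to treat that condition (the wrapper function below is illustrative only; the real handling is the DBLoadStateException branch in csv_to_database.upload_archive shown earlier in this diff):

from delphi.epidata.acquisition.covidcast.database import Database, DBLoadStateException

def insert_one_file(database: Database, rows, logger) -> int:
    """Illustrative wrapper: insert the rows of one CSV and report how many were modified."""
    try:
        modified = database.insert_or_update_batch(rows)
        database.commit()
        return modified if modified else 0
    except DBLoadStateException:
        # load table not empty: a prior run failed or another acquisition is in progress,
        # so abort the whole run instead of archiving just this one file as failed
        raise
    except Exception as e:
        logger.exception('exception while inserting rows', exc_info=e)
        database.rollback()
        return 0
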
fix_is_latest_issue_sql = f''' UPDATE - `{self.load_table}` JOIN `{self.latest_view}` + `{self.load_table}` JOIN `{self.latest_view}` USING (`source`, `signal`, `geo_type`, `geo_value`, `time_type`, `time_value`) SET `{self.load_table}`.`is_latest_issue`=0 - WHERE `{self.load_table}`.`issue` < `{self.latest_view}`.`issue` - AND `process_status` = '{PROCESS_STATUS.INSERTING}' + WHERE `{self.load_table}`.`issue` < `{self.latest_view}`.`issue` ''' - update_status_sql = f''' - UPDATE `{self.load_table}` - SET `process_status` = '{PROCESS_STATUS.LOADED}' - WHERE `process_status` = '{PROCESS_STATUS.INSERTING}' - ''' - - if 0 != self.count_insertstatus_rows(): - # TODO: determine if this should be fatal?! - logger = get_structured_logger("insert_or_update_batch") - logger.warn("Non-zero count in the load table!!! This indicates scheduling of acqusition and dbjobs may be out of sync.") - # TODO: consider handling cc_rows as a generator instead of a list try: @@ -216,7 +189,8 @@ def insert_or_update_batch(self, cc_rows, batch_size=2**20, commit_partial=False self._cursor.executemany(insert_into_loader_sql, args) modified_row_count = self._cursor.rowcount self._cursor.execute(fix_is_latest_issue_sql) - self._cursor.execute(update_status_sql) + if not suppress_jobs: + self.run_dbjobs() # TODO: incorporate the logic of dbjobs() into this method [once calls to dbjobs() are no longer needed for migrations] if modified_row_count is None or modified_row_count == -1: # the SQL connector does not support returning number of rows affected (see PEP 249) @@ -226,60 +200,45 @@ def insert_or_update_batch(self, cc_rows, batch_size=2**20, commit_partial=False if commit_partial: self._connection.commit() except Exception as e: - # TODO: rollback??? something??? + # rollback is handled in csv_to_database; if you're calling this yourself, handle your own rollback raise e return total def run_dbjobs(self): - signal_load_set_comp_keys = f''' - UPDATE `{self.load_table}` - SET compressed_signal_key = md5(CONCAT(`source`,`signal`)), - compressed_geo_key = md5(CONCAT(`geo_type`,`geo_value`)) - ''' - - signal_load_mark_batch = f''' - UPDATE `{self.load_table}` - SET process_status = '{PROCESS_STATUS.BATCHING}' - ''' - + # we do this LEFT JOIN trick because mysql cant do set difference (aka EXCEPT or MINUS) + # (as in " select distinct source, signal from signal_dim minus select distinct source, signal from epimetric_load ") signal_dim_add_new_load = f''' - INSERT INTO signal_dim (`source`, `signal`, `compressed_signal_key`) - SELECT DISTINCT `source`, `signal`, compressed_signal_key - FROM `{self.load_table}` - WHERE compressed_signal_key NOT IN - (SELECT DISTINCT compressed_signal_key - FROM signal_dim) + INSERT INTO signal_dim (`source`, `signal`) + SELECT DISTINCT sl.source, sl.signal + FROM {self.load_table} AS sl LEFT JOIN signal_dim AS sd + USING (`source`, `signal`) + WHERE sd.source IS NULL ''' + # again, same trick to get around lack of EXCEPT/MINUS geo_dim_add_new_load = f''' - INSERT INTO geo_dim (`geo_type`, `geo_value`, `compressed_geo_key`) - SELECT DISTINCT `geo_type`, `geo_value`, compressed_geo_key - FROM `{self.load_table}` - WHERE compressed_geo_key NOT IN - (SELECT DISTINCT compressed_geo_key - FROM geo_dim) + INSERT INTO geo_dim (`geo_type`, `geo_value`) + SELECT DISTINCT sl.geo_type, sl.geo_value + FROM {self.load_table} AS sl LEFT JOIN geo_dim AS gd + USING (`geo_type`, `geo_value`) + WHERE gd.geo_type IS NULL ''' - signal_history_load = f''' - INSERT INTO signal_history - (signal_data_id, 
signal_key_id, geo_key_id, demog_key_id, issue, data_as_of_dt, - time_type, time_value, `value`, stderr, sample_size, `lag`, value_updated_timestamp, - computation_as_of_dt, is_latest_issue, missing_value, missing_stderr, missing_sample_size, `legacy_id`) + epimetric_full_load = f''' + INSERT INTO {self.history_table} + (epimetric_id, signal_key_id, geo_key_id, issue, data_as_of_dt, + time_type, time_value, `value`, stderr, sample_size, `lag`, value_updated_timestamp, + computation_as_of_dt, missing_value, missing_stderr, missing_sample_size) SELECT - signal_data_id, sd.signal_key_id, gd.geo_key_id, 0, issue, data_as_of_dt, - time_type, time_value, `value`, stderr, sample_size, `lag`, value_updated_timestamp, - computation_as_of_dt, is_latest_issue, missing_value, missing_stderr, missing_sample_size, `legacy_id` + epimetric_id, sd.signal_key_id, gd.geo_key_id, issue, data_as_of_dt, + time_type, time_value, `value`, stderr, sample_size, `lag`, value_updated_timestamp, + computation_as_of_dt, missing_value, missing_stderr, missing_sample_size FROM `{self.load_table}` sl - INNER JOIN signal_dim sd - USE INDEX(`compressed_signal_key_ind`) - ON sd.compressed_signal_key = sl.compressed_signal_key - INNER JOIN geo_dim gd - USE INDEX(`compressed_geo_key_ind`) - ON gd.compressed_geo_key = sl.compressed_geo_key - WHERE process_status = '{PROCESS_STATUS.BATCHING}' + INNER JOIN signal_dim sd USING (source, `signal`) + INNER JOIN geo_dim gd USING (geo_type, geo_value) ON DUPLICATE KEY UPDATE - `signal_data_id` = sl.`signal_data_id`, + `epimetric_id` = sl.`epimetric_id`, `value_updated_timestamp` = sl.`value_updated_timestamp`, `value` = sl.`value`, `stderr` = sl.`stderr`, @@ -290,26 +249,21 @@ def run_dbjobs(self): `missing_sample_size` = sl.`missing_sample_size` ''' - signal_latest_load = f''' - INSERT INTO signal_latest - (signal_data_id, signal_key_id, geo_key_id, demog_key_id, issue, data_as_of_dt, - time_type, time_value, `value`, stderr, sample_size, `lag`, value_updated_timestamp, + epimetric_latest_load = f''' + INSERT INTO {self.latest_table} + (epimetric_id, signal_key_id, geo_key_id, issue, data_as_of_dt, + time_type, time_value, `value`, stderr, sample_size, `lag`, value_updated_timestamp, computation_as_of_dt, missing_value, missing_stderr, missing_sample_size) SELECT - signal_data_id, sd.signal_key_id, gd.geo_key_id, 0, issue, data_as_of_dt, - time_type, time_value, `value`, stderr, sample_size, `lag`, value_updated_timestamp, + epimetric_id, sd.signal_key_id, gd.geo_key_id, issue, data_as_of_dt, + time_type, time_value, `value`, stderr, sample_size, `lag`, value_updated_timestamp, computation_as_of_dt, missing_value, missing_stderr, missing_sample_size - FROM `{self.load_table}` sl - INNER JOIN signal_dim sd - USE INDEX(`compressed_signal_key_ind`) - ON sd.compressed_signal_key = sl.compressed_signal_key - INNER JOIN geo_dim gd - USE INDEX(`compressed_geo_key_ind`) - ON gd.compressed_geo_key = sl.compressed_geo_key - WHERE process_status = '{PROCESS_STATUS.BATCHING}' - AND is_latest_issue = 1 + FROM `{self.load_table}` sl + INNER JOIN signal_dim sd USING (source, `signal`) + INNER JOIN geo_dim gd USING (geo_type, geo_value) + WHERE is_latest_issue = 1 ON DUPLICATE KEY UPDATE - `signal_data_id` = sl.`signal_data_id`, + `epimetric_id` = sl.`epimetric_id`, `value_updated_timestamp` = sl.`value_updated_timestamp`, `value` = sl.`value`, `stderr` = sl.`stderr`, @@ -318,32 +272,44 @@ def run_dbjobs(self): `lag` = sl.`lag`, `missing_value` = sl.`missing_value`, `missing_stderr` = 
sl.`missing_stderr`, - `missing_sample_size` = sl.`missing_sample_size` + `missing_sample_size` = sl.`missing_sample_size` ''' - signal_load_delete_processed = f''' - DELETE FROM `{self.load_table}` - WHERE process_status <> '{PROCESS_STATUS.LOADED}' + # NOTE: DO NOT `TRUNCATE` THIS TABLE! doing so will ruin the AUTO_INCREMENT counter that the history and latest tables depend on... + epimetric_load_delete_processed = f''' + DELETE FROM `{self.load_table}` ''' - print('signal_load_set_comp_keys:') - self._cursor.execute(signal_load_set_comp_keys) - print('signal_load_mark_batch:') - self._cursor.execute(signal_load_mark_batch) - print('signal_dim_add_new_load:') - self._cursor.execute(signal_dim_add_new_load) - print('geo_dim_add_new_load:') - self._cursor.execute(geo_dim_add_new_load) - print('signal_history_load:') - self._cursor.execute(signal_history_load) - print('signal_latest_load:') - self._cursor.execute(signal_latest_load) - print('signal_load_delete_processed:') - self._cursor.execute(signal_load_delete_processed) - print("done.") + logger = get_structured_logger("run_dbjobs") + import time + time_q = [time.time()] + + try: + self._cursor.execute(signal_dim_add_new_load) + time_q.append(time.time()) + logger.debug('signal_dim_add_new_load', rows=self._cursor.rowcount, elapsed=time_q[-1]-time_q[-2]) + + self._cursor.execute(geo_dim_add_new_load) + time_q.append(time.time()) + logger.debug('geo_dim_add_new_load', rows=self._cursor.rowcount, elapsed=time_q[-1]-time_q[-2]) + + self._cursor.execute(epimetric_full_load) + time_q.append(time.time()) + logger.debug('epimetric_full_load', rows=self._cursor.rowcount, elapsed=time_q[-1]-time_q[-2]) + + self._cursor.execute(epimetric_latest_load) + time_q.append(time.time()) + logger.debug('epimetric_latest_load', rows=self._cursor.rowcount, elapsed=time_q[-1]-time_q[-2]) + + self._cursor.execute(epimetric_load_delete_processed) + time_q.append(time.time()) + logger.debug('epimetric_load_delete_processed', rows=self._cursor.rowcount, elapsed=time_q[-1]-time_q[-2]) + except Exception as e: + raise e return self + def delete_batch(self, cc_deletions): """ Remove rows specified by a csv file or list of tuples. @@ -376,9 +342,11 @@ def delete_batch(self, cc_deletions): # composite keys: short_comp_key = "`source`, `signal`, `time_type`, `geo_type`, `time_value`, `geo_value`" long_comp_key = short_comp_key + ", `issue`" + short_comp_ref_key = "`signal_key_id`, `geo_key_id`, `time_type`, `time_value`" + long_comp_ref_key = short_comp_ref_key + ", `issue`" create_tmp_table_sql = f''' -CREATE OR REPLACE TABLE {tmp_table_name} LIKE {self.load_table}; +CREATE TABLE {tmp_table_name} LIKE {self.load_table}; ''' amend_tmp_table_sql = f''' @@ -407,46 +375,50 @@ def delete_batch(self, cc_deletions): add_history_id_sql = f''' UPDATE {tmp_table_name} d INNER JOIN {self.history_view} h USING ({long_comp_key}) -SET d.delete_history_id=h.signal_data_id; +SET d.delete_history_id=h.epimetric_id; ''' # if a row we are deleting also appears in the 'latest' table (with a matching 'issue')... 
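Stepping back to run_dbjobs for a moment: MySQL has no EXCEPT/MINUS, so the dimension-table loads above express "load-table pairs not yet present in the dimension table" as a LEFT JOIN that keeps only unmatched rows. The pattern in isolation, using the table names from this schema (the small wrapper function is hypothetical; it only needs an open cursor):

# LEFT JOIN anti-join: unmatched rows come back with NULLs on the signal_dim side,
# so `sd.source IS NULL` selects exactly the (source, signal) pairs that are new.
ADD_NEW_SIGNAL_PAIRS = '''
    INSERT INTO signal_dim (`source`, `signal`)
    SELECT DISTINCT sl.source, sl.signal
    FROM epimetric_load AS sl
    LEFT JOIN signal_dim AS sd USING (`source`, `signal`)
    WHERE sd.source IS NULL
'''

def add_new_signal_pairs(cursor):
    cursor.execute(ADD_NEW_SIGNAL_PAIRS)
    return cursor.rowcount  # number of previously-unseen (source, signal) pairs
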
mark_for_update_latest_sql = f''' UPDATE {tmp_table_name} d INNER JOIN {self.latest_view} ell USING ({long_comp_key}) -SET d.update_latest=1, d.delete_latest_id=ell.signal_data_id; +SET d.update_latest=1, d.delete_latest_id=ell.epimetric_id; ''' delete_history_sql = f''' -DELETE h FROM {tmp_table_name} d INNER JOIN {self.history_table} h ON d.delete_history_id=h.signal_data_id; +DELETE h FROM {tmp_table_name} d INNER JOIN {self.history_table} h ON d.delete_history_id=h.epimetric_id; ''' # ...remove it from 'latest'... delete_latest_sql = f''' -DELETE ell FROM {tmp_table_name} d INNER JOIN {self.latest_table} ell ON d.delete_latest_id=ell.signal_data_id; +DELETE ell FROM {tmp_table_name} d INNER JOIN {self.latest_table} ell ON d.delete_latest_id=ell.epimetric_id; ''' # ...and re-write that record with its next-latest issue (from 'history') instead. # NOTE: this must be executed *AFTER* `delete_history_sql` to ensure we get the correct `issue` # AND also after `delete_latest_sql` so that we dont get a key collision on insert. update_latest_sql = f''' -INSERT INTO signal_latest - (issue, - signal_data_id, signal_key_id, geo_key_id, time_type, time_value, +INSERT INTO {self.latest_table} + (epimetric_id, + signal_key_id, geo_key_id, time_type, time_value, issue, value, stderr, sample_size, `lag`, value_updated_timestamp, missing_value, missing_stderr, missing_sample_size) SELECT - MAX(h.issue), - h.signal_data_id, h.signal_key_id, h.geo_key_id, h.time_type, h.time_value, + h.epimetric_id, + h.signal_key_id, h.geo_key_id, h.time_type, h.time_value, h.issue, h.value, h.stderr, h.sample_size, h.`lag`, h.value_updated_timestamp, h.missing_value, h.missing_stderr, h.missing_sample_size -FROM {self.history_view} h JOIN {tmp_table_name} d USING ({short_comp_key}) -WHERE d.update_latest=1 GROUP BY {short_comp_key}; +FROM {self.history_view} h JOIN ( + SELECT {short_comp_key}, MAX(hh.issue) AS issue + FROM {self.history_view} hh JOIN {tmp_table_name} dd USING ({short_comp_key}) + WHERE dd.update_latest=1 GROUP BY {short_comp_key} + ) d USING ({long_comp_key}); ''' - drop_tmp_table_sql = f'DROP TABLE {tmp_table_name}' + drop_tmp_table_sql = f'DROP TABLE IF EXISTS {tmp_table_name}' total = None try: + self._cursor.execute(drop_tmp_table_sql) self._cursor.execute(create_tmp_table_sql) self._cursor.execute(amend_tmp_table_sql) if isinstance(cc_deletions, str): @@ -457,15 +429,21 @@ def split_list(lst, n): yield lst[i:(i+n)] for deletions_batch in split_list(cc_deletions, 100000): self._cursor.executemany(load_tmp_table_insert_sql, deletions_batch) + print(f"load_tmp_table_insert_sql:{self._cursor.rowcount}") else: raise Exception(f"Bad deletions argument: need a filename or a list of tuples; got a {type(cc_deletions)}") self._cursor.execute(add_history_id_sql) + print(f"add_history_id_sql:{self._cursor.rowcount}") self._cursor.execute(mark_for_update_latest_sql) + print(f"mark_for_update_latest_sql:{self._cursor.rowcount}") self._cursor.execute(delete_history_sql) + print(f"delete_history_sql:{self._cursor.rowcount}") total = self._cursor.rowcount # TODO: consider reporting rows removed and/or replaced in latest table as well self._cursor.execute(delete_latest_sql) + print(f"delete_latest_sql:{self._cursor.rowcount}") self._cursor.execute(update_latest_sql) + print(f"update_latest_sql:{self._cursor.rowcount}") self._connection.commit() if total == -1: @@ -476,131 +454,3 @@ def split_list(lst, n): finally: self._cursor.execute(drop_tmp_table_sql) return total - - - def compute_covidcast_meta(self, 
table_name=None): - """Compute and return metadata on all COVIDcast signals.""" - logger = get_structured_logger("compute_covidcast_meta") - - if table_name is None: - table_name = self.latest_view - - n_threads = max(1, cpu_count()*9//10) # aka number of concurrent db connections, which [sh|c]ould be ~<= 90% of the #cores available to SQL server - # NOTE: this may present a small problem if this job runs on different hardware than the db, - # but we should not run into that issue in prod. - logger.info(f"using {n_threads} workers") - - srcsigs = Queue() # multi-consumer threadsafe! - sql = f'SELECT `source`, `signal` FROM `{table_name}` GROUP BY `source`, `signal` ORDER BY `source` ASC, `signal` ASC;' - self._cursor.execute(sql) - for source, signal in self._cursor: - srcsigs.put((source, signal)) - - inner_sql = f''' - SELECT - `source` AS `data_source`, - `signal`, - `time_type`, - `geo_type`, - MIN(`time_value`) AS `min_time`, - MAX(`time_value`) AS `max_time`, - COUNT(DISTINCT `geo_value`) AS `num_locations`, - MIN(`value`) AS `min_value`, - MAX(`value`) AS `max_value`, - ROUND(AVG(`value`),7) AS `mean_value`, - ROUND(STD(`value`),7) AS `stdev_value`, - MAX(`value_updated_timestamp`) AS `last_update`, - MAX(`issue`) as `max_issue`, - MIN(`lag`) as `min_lag`, - MAX(`lag`) as `max_lag` - FROM - `{table_name}` - WHERE - `source` = %s AND - `signal` = %s - GROUP BY - `time_type`, - `geo_type` - ORDER BY - `time_type` ASC, - `geo_type` ASC - ''' - - meta = [] - meta_lock = threading.Lock() - - def worker(): - name = threading.current_thread().name - logger.info("starting thread", thread=name) - # set up new db connection for thread - worker_dbc = Database() - worker_dbc.connect(connector_impl=self._connector_impl) - w_cursor = worker_dbc._cursor - try: - while True: - (source, signal) = srcsigs.get_nowait() # this will throw the Empty caught below - logger.info("starting pair", thread=name, pair=f"({source}, {signal})") - w_cursor.execute(inner_sql, (source, signal)) - with meta_lock: - meta.extend(list( - dict(zip(w_cursor.column_names, x)) for x in w_cursor - )) - srcsigs.task_done() - except Empty: - logger.info("no jobs left, thread terminating", thread=name) - finally: - worker_dbc.disconnect(False) # cleanup - - threads = [] - for n in range(n_threads): - t = threading.Thread(target=worker, name='MetacacheThread-'+str(n)) - t.start() - threads.append(t) - - srcsigs.join() - logger.info("jobs complete") - for t in threads: - t.join() - logger.info("all threads terminated") - - # sort the metadata because threaded workers dgaf - sorting_fields = "data_source signal time_type geo_type".split() - sortable_fields_fn = lambda x: [(field, x[field]) for field in sorting_fields] - prepended_sortables_fn = lambda x: sortable_fields_fn(x) + list(x.items()) - tuple_representation = list(map(prepended_sortables_fn, meta)) - tuple_representation.sort() - meta = list(map(dict, tuple_representation)) # back to dict form - - return meta - - - def update_covidcast_meta_cache(self, metadata): - """Updates the `covidcast_meta_cache` table.""" - - sql = ''' - UPDATE - `covidcast_meta_cache` - SET - `timestamp` = UNIX_TIMESTAMP(NOW()), - `epidata` = %s - ''' - epidata_json = json.dumps(metadata) - - self._cursor.execute(sql, (epidata_json,)) - - def retrieve_covidcast_meta_cache(self): - """Useful for viewing cache entries (was used in debugging)""" - - sql = ''' - SELECT `epidata` - FROM `covidcast_meta_cache` - ORDER BY `timestamp` DESC - LIMIT 1; - ''' - self._cursor.execute(sql) - cache_json = 
self._cursor.fetchone()[0] - cache = json.loads(cache_json) - cache_hash = {} - for entry in cache: - cache_hash[(entry['data_source'], entry['signal'], entry['time_type'], entry['geo_type'])] = entry - return cache_hash diff --git a/src/acquisition/covidcast/database_meta.py b/src/acquisition/covidcast/database_meta.py new file mode 100644 index 000000000..005af02de --- /dev/null +++ b/src/acquisition/covidcast/database_meta.py @@ -0,0 +1,461 @@ +from dataclasses import asdict, dataclass, fields +from datetime import datetime +import json +from multiprocessing import cpu_count +from queue import Queue, Empty +import threading +from typing import Dict, List, Tuple + +import pandas as pd +from requests import get + +# TODO: Switch to epidatpy when we release it https://github.com/cmu-delphi/delphi-epidata/issues/942. +# from epidatpy.request import Epidata, EpiRange + +from .logger import get_structured_logger +from .covidcast_row import CovidcastRow, set_df_dtypes +from .database import Database +from .config import GEO_TYPES, ALL_TIME +from ...server.endpoints.covidcast_utils.model import DataSignal, data_signals_by_key + + +@dataclass +class MetaTableRow: + data_source: str + signal: str + time_type: str + geo_type: str + min_time: int + max_time: int + num_locations: int + min_value: float + max_value: float + mean_value: float + stdev_value: float + last_update: int + max_issue: int + min_lag: int + max_lag: int + + def as_df(self): + df = pd.DataFrame( + { + "data_source": self.data_source, + "signal": self.signal, + "time_type": self.time_type, + "geo_type": self.geo_type, + "min_time": self.min_time, + "max_time": self.max_time, + "num_locations": self.num_locations, + "min_value": self.min_value, + "max_value": self.max_value, + "mean_value": self.mean_value, + "stdev_value": self.stdev_value, + "last_update": self.last_update, + "max_issue": self.max_issue, + "min_lag": self.min_lag, + "max_lag": self.max_lag, + }, + index=[0], + ) + set_df_dtypes( + df, + dtypes={ + "data_source": str, + "signal": str, + "time_type": str, + "geo_type": str, + "min_time": int, + "max_time": int, + "num_locations": int, + "min_value": float, + "max_value": float, + "mean_value": float, + "stdev_value": float, + "last_update": int, + "max_issue": int, + "min_lag": int, + "max_lag": int, + }, + ) + return df + + def as_dict(self): + return asdict(self) + + @staticmethod + def _extract_fields(group_df): + if "source" in group_df.columns: + assert group_df["source"].unique().size == 1 + source = group_df["source"].iloc[0] + elif "data_source" in group_df.columns: + assert group_df["data_source"].unique().size == 1 + source = group_df["data_source"].iloc[0] + else: + raise ValueError("Source name not found in group_df.") + + if "signal" in group_df.columns: + assert group_df["signal"].unique().size == 1 + signal = group_df["signal"].iloc[0] + else: + raise ValueError("Signal name not found in group_df.") + + if "time_type" in group_df.columns: + assert group_df["time_type"].unique().size == 1 + time_type = group_df["time_type"].iloc[0] + else: + raise ValueError("Time type not found in group_df.") + + if "geo_type" in group_df.columns: + assert group_df["geo_type"].unique().size == 1 + geo_type = group_df["geo_type"].iloc[0] + else: + raise ValueError("Geo type not found in group_df.") + + if "value_updated_timestamp" in group_df.columns: + last_updated = max(group_df["value_updated_timestamp"]) + else: + last_updated = int(datetime.now().timestamp()) + + return source, signal, time_type, geo_type, 
last_updated + + @staticmethod + def from_group_df(group_df): + if group_df is None or group_df.empty: + raise ValueError("Empty group_df given.") + + source, signal, time_type, geo_type, last_updated = MetaTableRow._extract_fields(group_df) + + return MetaTableRow( + data_source=source, + signal=signal, + time_type=time_type, + geo_type=geo_type, + min_time=min(group_df["time_value"]), + max_time=max(group_df["time_value"]), + num_locations=len(group_df["geo_value"].unique()), + min_value=min(group_df["value"]), + max_value=max(group_df["value"]), + mean_value=group_df["value"].mean().round(7), + stdev_value=group_df["value"].std(ddof=0).round(7), + last_update=last_updated, + max_issue=max(group_df["issue"]), + min_lag=min(group_df["lag"]), + max_lag=max(group_df["lag"]), + ) + +class DatabaseMeta(Database): + # TODO: Verify the correct base_url for a local API server. + def __init__(self, base_url: str = "http://localhost/epidata") -> "DatabaseMeta": + Database.__init__(self) + self.epidata_base_url = base_url + # TODO: Switch to epidatpy when we release it https://github.com/cmu-delphi/delphi-epidata/issues/942. + self.delphi_epidata = False + + def compute_covidcast_meta(self, table_name=None, jit=False, parallel=False, n_threads=None): + """This wrapper is here for A/B testing the JIT and non-JIT metadata computation. + + TODO: Remove after code review. + """ + return self.compute_covidcast_meta_new(table_name, parallel) if jit else self.compute_covidcast_meta_old(table_name) + + def compute_covidcast_meta_old(self, table_name=None): + """This is the old method (not using JIT) to compute and return metadata on all COVIDcast signals. + + TODO: This is here for A/B testing. Remove this after code review. + """ + logger = get_structured_logger("compute_covidcast_meta") + + if table_name is None: + table_name = self.latest_view + + n_threads = max(1, cpu_count() * 9 // 10) # aka number of concurrent db connections, which [sh|c]ould be ~<= 90% of the #cores available to SQL server + # NOTE: this may present a small problem if this job runs on different hardware than the db, + # but we should not run into that issue in prod. + logger.info(f"using {n_threads} workers") + + srcsigs = Queue() # multi-consumer threadsafe! 
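The metadata computation here fans the per-(source, signal) aggregation out to a pool of worker threads, each holding its own database connection and draining a single shared queue until it is empty. The concurrency skeleton on its own, with the SQL stripped out (do_one_pair is a hypothetical stand-in for the aggregate query a worker runs on its own connection):

import threading
from queue import Queue, Empty

def compute_in_parallel(pairs, do_one_pair, n_threads=4):
    jobs = Queue()               # multi-consumer threadsafe
    for pair in pairs:
        jobs.put(pair)

    results = []
    results_lock = threading.Lock()

    def worker():
        try:
            while True:
                source, signal = jobs.get_nowait()    # raises Empty once the queue is drained
                rows = do_one_pair(source, signal)    # e.g. one GROUP BY query per pair
                with results_lock:
                    results.extend(rows)
                jobs.task_done()
        except Empty:
            pass                                      # no jobs left, thread terminates

    threads = [threading.Thread(target=worker, name=f"MetacacheThread-{n}") for n in range(n_threads)]
    for t in threads:
        t.start()
    jobs.join()                                       # block until every queued pair is done
    for t in threads:
        t.join()
    return results
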
+ sql = f"SELECT `source`, `signal` FROM `{table_name}` GROUP BY `source`, `signal` ORDER BY `source` ASC, `signal` ASC;" + self._cursor.execute(sql) + for source, signal in self._cursor: + srcsigs.put((source, signal)) + + inner_sql = f""" + SELECT + `source` AS `data_source`, + `signal`, + `time_type`, + `geo_type`, + MIN(`time_value`) AS `min_time`, + MAX(`time_value`) AS `max_time`, + COUNT(DISTINCT `geo_value`) AS `num_locations`, + MIN(`value`) AS `min_value`, + MAX(`value`) AS `max_value`, + ROUND(AVG(`value`),7) AS `mean_value`, + ROUND(STD(`value`),7) AS `stdev_value`, + MAX(`value_updated_timestamp`) AS `last_update`, + MAX(`issue`) as `max_issue`, + MIN(`lag`) as `min_lag`, + MAX(`lag`) as `max_lag` + FROM + `{table_name}` + WHERE + `source` = %s AND + `signal` = %s + GROUP BY + `time_type`, + `geo_type` + ORDER BY + `time_type` ASC, + `geo_type` ASC + """ + + meta = [] + meta_lock = threading.Lock() + + def worker(): + name = threading.current_thread().name + logger.info("starting thread", thread=name) + # set up new db connection for thread + worker_dbc = Database() + worker_dbc.connect(connector_impl=self._connector_impl, host=self._db_host, user=self._db_credential_user, password=self._db_credential_password, database=self._db_database) + w_cursor = worker_dbc._cursor + try: + while True: + (source, signal) = srcsigs.get_nowait() # this will throw the Empty caught below + logger.info("starting pair", thread=name, pair=f"({source}, {signal})") + w_cursor.execute(inner_sql, (source, signal)) + with meta_lock: + meta.extend(list(dict(zip(w_cursor.column_names, x)) for x in w_cursor)) + srcsigs.task_done() + except Empty: + logger.info("no jobs left, thread terminating", thread=name) + finally: + worker_dbc.disconnect(False) # cleanup + + threads = [] + for n in range(n_threads): + t = threading.Thread(target=worker, name="MetacacheThread-" + str(n)) + t.start() + threads.append(t) + + srcsigs.join() + logger.info("jobs complete") + for t in threads: + t.join() + logger.info("all threads terminated") + + # sort the metadata because threaded workers dgaf + sorting_fields = "data_source signal time_type geo_type".split() + sortable_fields_fn = lambda x: [(field, x[field]) for field in sorting_fields] + prepended_sortables_fn = lambda x: sortable_fields_fn(x) + list(x.items()) + tuple_representation = list(map(prepended_sortables_fn, meta)) + tuple_representation.sort() + meta = list(map(dict, tuple_representation)) # back to dict form + + return meta + + def get_source_sig_list(self, data_signal_table: Dict[Tuple[str, str], DataSignal] = data_signals_by_key, derived: bool = False) -> List[Tuple[str]]: + """Return the source-signal pair names from the database. + + The derived flag determines whether the signals returned are derived or base signals. + """ + return [(data_signal.source, data_signal.signal) for data_signal in data_signal_table.values() if data_signal.compute_from_base == derived] + + def compute_base_signal_meta(self, source: str, signal: str, table_name: str = None) -> pd.DataFrame: + """Compute the meta information for base signals. + + A base signal is a signal whose values do not depend on another signal. A derived signal is one whose values are obtained + through a transformation of a base signal, e.g. a 7 day average signal or an incidence (compared to cumulative) signal. 
+ """ + if table_name is None: + table_name = self.latest_view + + inner_sql = f""" + SELECT + `source` AS `data_source`, + `signal`, + `time_type`, + `geo_type`, + MIN(`time_value`) AS `min_time`, + MAX(`time_value`) AS `max_time`, + COUNT(DISTINCT `geo_value`) AS `num_locations`, + MIN(`value`) AS `min_value`, + MAX(`value`) AS `max_value`, + ROUND(AVG(`value`),7) AS `mean_value`, + ROUND(STD(`value`),7) AS `stdev_value`, + MAX(`value_updated_timestamp`) AS `last_update`, + MAX(`issue`) as `max_issue`, + MIN(`lag`) as `min_lag`, + MAX(`lag`) as `max_lag` + FROM + `{table_name}` + WHERE + `source` = %s AND + `signal` = %s + GROUP BY + `time_type`, + `geo_type` + ORDER BY + `time_type` ASC, + `geo_type` ASC + """ + + # TODO: Consider whether we really need this new object. Maybe for parallel or maybe not. + db = Database() + db.connect(connector_impl=self._connector_impl, host=self._db_host, user=self._db_credential_user, password=self._db_credential_password, database=self._db_database) + db._cursor.execute(inner_sql, (source, signal)) + base_signal_meta = pd.DataFrame(db._cursor.fetchall(), columns=["data_source", "signal", "time_type", "geo_type", "min_time", "max_time", "num_locations", "min_value", "max_value", "mean_value", "stdev_value", "last_update", "max_issue", "min_lag", "max_lag"]) + + return base_signal_meta + + def compute_derived_signal_meta(self, source: str, signal: str, base_signal_meta: pd.DataFrame, data_signal_table: Dict[Tuple[str, str], DataSignal] = data_signals_by_key) -> pd.DataFrame: + """Compute the meta information for a derived signal. + + A derived signal is a transformation of a base signal. Since derived signals are not stored in the database, but are computed + on the fly by the API, we call the API here. It is assumed that we have already computed the meta information for the base + signals and passed that in base_signal_meta. The latter is needed to set the `last_updated` field. + """ + logger = get_structured_logger("get_derived_signal_meta") + + meta_table_columns = [field.name for field in fields(MetaTableRow)] + covidcast_response_columns = [field.name for field in fields(CovidcastRow) if field.name not in CovidcastRow()._api_row_ignore_fields] + + # We should be able to find the signal in our table. + data_signal = data_signal_table.get((source, signal)) + if not data_signal: + logger.warn(f"Could not find the requested derived signal {source}:{signal} in the data signal table. Returning no meta results.") + return pd.DataFrame(columns=meta_table_columns) + + # Request all the data for the derived signal. + # TODO: Use when epidatpy is released https://github.com/cmu-delphi/delphi-epidata/issues/942. + if self.delphi_epidata: + raise NotImplemented("Use the old epidata client for now.") + # TODO: Consider refactoring to combine multiple signal requests in one call. 
+ all_time = EpiRange(19000101, 20500101) + epidata = Epidata.with_base_url(self.epidata_base_url) + api_response_df = pd.concat([epidata.covidcast(data_source=source, signals=signal, time_type=data_signal.time_type, geo_type=geo_type, time_values=all_time, geo_values="*").df() for geo_type in GEO_TYPES]) + else: + base_url = f"{self.epidata_base_url}/covidcast/" + params = {"data_source": source, "signals": signal, "time_type": data_signal.time_type, "time_values": ALL_TIME, "geo_values": "*"} + signal_data_dfs = [] + for geo_type in GEO_TYPES: + params.update({"geo_type": geo_type}) + response = get(base_url, params) + if response.status_code in [200]: + signal_data_dfs.append(pd.DataFrame.from_records(response.json()['epidata'], columns=covidcast_response_columns)) + else: + raise Exception(f"The API responded with an error when attempting to get data for the derived signal's {source}:{signal} meta computation. There may be an issue with the API server.") + + # Group the data by time_type and geo_type and find the statistical summaries for their values. + meta_rows = [MetaTableRow.from_group_df(group_df).as_df() for signal_data_df in signal_data_dfs for _, group_df in signal_data_df.groupby("time_type")] + if meta_rows: + meta_df = pd.concat(meta_rows) + else: + logger.warn(f"The meta computation for {source}:{signal} returned no summary statistics. There may be an issue with the API server or the database.") + return pd.DataFrame(columns=meta_table_columns) + + # Copy the value of 'last_updated' column from the base signal meta to the derived signal meta. + # TODO: Remove if/when we remove the 'last_updated' column. + meta_df = pd.merge( + meta_df.assign(parent_signal = data_signal.signal_basename), + base_signal_meta[["data_source", "signal", "time_type", "geo_type", "last_update"]], + left_on = ["data_source", "parent_signal", "time_type", "geo_type"], + right_on = ["data_source", "signal", "time_type", "geo_type"] + ) + meta_df = meta_df.assign(signal = meta_df["signal_x"], last_update = meta_df["last_update_y"]) + + return meta_df[meta_table_columns] + + def compute_covidcast_meta_new(self, table_name=None, parallel=True, data_signal_table: Dict[Tuple[str, str], DataSignal] = data_signals_by_key) -> Dict: + """Compute and return metadata on all non-WIP COVIDcast signals.""" + logger = get_structured_logger("compute_covidcast_meta") + + if table_name is None: + table_name = self.latest_view + + if parallel: + n_threads = max(1, cpu_count() * 9 // 10) # aka number of concurrent db connections, which [sh|c]ould be ~<= 90% of the #cores available to SQL server + # NOTE: this may present a small problem if this job runs on different hardware than the db, + # but we should not run into that issue in prod. + logger.info(f"using {n_threads} workers") + + srcsigs = Queue() # multi-consumer threadsafe! 
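For derived signals the rows come back from the API rather than from SQL, so the summary statistics are computed client-side: MetaTableRow.from_group_df reduces one group of rows to one meta row, mirroring the GROUP BY aggregate used for base signals. A rough pandas equivalent of that reduction, assuming a dataframe with the API response columns used above:

import pandas as pd

def summarize_meta(df: pd.DataFrame) -> pd.DataFrame:
    # One summary row per (time_type, geo_type), matching the columns of MetaTableRow;
    # ddof=0 mirrors the population-style STD() used in the SQL aggregate.
    return (
        df.groupby(["time_type", "geo_type"])
          .agg(
              min_time=("time_value", "min"),
              max_time=("time_value", "max"),
              num_locations=("geo_value", "nunique"),
              min_value=("value", "min"),
              max_value=("value", "max"),
              mean_value=("value", lambda s: round(s.mean(), 7)),
              stdev_value=("value", lambda s: round(s.std(ddof=0), 7)),
              max_issue=("issue", "max"),
              min_lag=("lag", "min"),
              max_lag=("lag", "max"),
          )
          .reset_index()
    )
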
+ for source, signal in self.get_source_sig_list(table_name): + srcsigs.put((source, signal)) + + meta_dfs = [] + meta_lock = threading.Lock() + + def worker(): + name = threading.current_thread().name + logger.info("starting thread", thread=name) + # set up new db connection for thread + worker_dbc = DatabaseMeta() + worker_dbc.connect(connector_impl=self._connector_impl, host=self._db_host, user=self._db_credential_user, password=self._db_credential_password, database=self._db_database) + try: + while True: + (source, signal) = srcsigs.get_nowait() # this will throw the Empty caught below + logger.info("starting pair", thread=name, pair=f"({source}, {signal})") + + df = worker_dbc.covidcast_meta_job(table_name, source, signal) + with meta_lock: + meta_dfs.append(df) + + srcsigs.task_done() + except Empty: + logger.info("no jobs left, thread terminating", thread=name) + finally: + worker_dbc.disconnect(False) # cleanup + + threads = [] + for n in range(n_threads): + t = threading.Thread(target=worker, name="MetacacheThread-" + str(n)) + t.start() + threads.append(t) + + srcsigs.join() + logger.info("jobs complete") + + for t in threads: + t.join() + logger.info("all threads terminated") + else: + # Here to illustrate the simple logic behind meta computations without the parallel boilerplate + base_meta_dfs = pd.concat([self.compute_base_signal_meta(source, signal, table_name) for source, signal in self.get_source_sig_list(data_signal_table=data_signal_table, derived=False)]) + derived_meta_dfs = pd.concat([self.compute_derived_signal_meta(source, signal, base_meta_dfs, data_signal_table=data_signal_table) for source, signal in self.get_source_sig_list(data_signal_table=data_signal_table, derived=True)]) + + # combine and sort the metadata results + meta_df = pd.concat([base_meta_dfs, derived_meta_dfs]).sort_values(by="data_source signal time_type geo_type".split()) + return meta_df.to_dict(orient="records") + + def update_covidcast_meta_cache(self, metadata): + """Updates the `covidcast_meta_cache` table.""" + + sql = """ + UPDATE + `covidcast_meta_cache` + SET + `timestamp` = UNIX_TIMESTAMP(NOW()), + `epidata` = %s + """ + epidata_json = json.dumps(metadata) + + self._cursor.execute(sql, (epidata_json,)) + + def retrieve_covidcast_meta_cache(self): + """Useful for viewing cache entries (was used in debugging)""" + + sql = """ + SELECT `epidata` + FROM `covidcast_meta_cache` + ORDER BY `timestamp` DESC + LIMIT 1; + """ + self._cursor.execute(sql) + cache_json = self._cursor.fetchone()[0] + cache = json.loads(cache_json) + cache_hash = {} + for entry in cache: + cache_hash[(entry["data_source"], entry["signal"], entry["time_type"], entry["geo_type"])] = entry + return cache_hash diff --git a/src/acquisition/covidcast/dbjobs_runner.py b/src/acquisition/covidcast/dbjobs_runner.py deleted file mode 100644 index a8f8e1c80..000000000 --- a/src/acquisition/covidcast/dbjobs_runner.py +++ /dev/null @@ -1,15 +0,0 @@ - -from delphi.epidata.acquisition.covidcast.database import Database - -# simple helper to easily run dbjobs from the command line, such as after an acquisition cycle is complete - -def main(): - database = Database() - database.connect() - try: - database.run_dbjobs() - finally: - database.disconnect(True) - -if __name__ == '__main__': - main() diff --git a/src/acquisition/covidcast/migrate_epidata_to_v4.py b/src/acquisition/covidcast/migrate_epidata_to_v4.py new file mode 100644 index 000000000..a4afafc11 --- /dev/null +++ b/src/acquisition/covidcast/migrate_epidata_to_v4.py @@ -0,0 
+1,188 @@ +# run as: +# python3 -u -m delphi.epidata.acquisition.covidcast.migrate_epidata_to_v4 +# ("-u" allows unbuffered print statements so we can watch timing in closer-to-real-time) + + +#####import delphi.operations.secrets as secrets +#####secrets.db.host = '172.30.n.n' # aka 'epidata-db-qa-01' +#####secrets.db.epi = ('delphi', 'xxxxxxxx') +# ^ these are already set appropriately on qa-automation in/by the operations module ^ + + +# TODO: make cli flags for these two variables: +use_transaction_wrappers = False +use_autocommit = False + +# TODO: maybe output: was autocommit enabled? was table locking used? what isolation type was used? were indexes enabled? were uniqueness checks enabled? + +# TODO: consider dropping indexes before moving data and recreating them afterward + +''' + +mysql> select count(id) from epidata.covidcast; ++------------+ +| count(id) | ++------------+ +| 2647381579 | ++------------+ +1 row in set (13 min 49.32 sec) + +mysql> select max(id) from epidata.covidcast; ++------------+ +| max(id) | ++------------+ +| 3740757041 | ++------------+ +1 row in set (0.00 sec) + +-- so ~71% coverage of actual rows per allocated ids ( 2647381579 / 3740757041 = .70771278379851347314 ) + +mysql> select time_value, issue from epidata.covidcast where id=3740757041; ++------------+----------+ +| time_value | issue | ++------------+----------+ +| 20210927 | 20210930 | ++------------+----------+ +1 row in set (0.01 sec) + +mysql> select now(); ++---------------------+ +| now() | ++---------------------+ +| 2022-05-16 16:45:34 | ++---------------------+ +1 row in set (0.00 sec) + +''' + + +from delphi.epidata.acquisition.covidcast.database import Database +import time +import argparse + +def start_tx(cursor): + cursor.execute('SET SESSION TRANSACTION ISOLATION LEVEL READ UNCOMMITTED;') + cursor.execute('SET autocommit=0;') # starts a transaction as suggested in https://dev.mysql.com/doc/refman/8.0/en/lock-tables.html + # NOTE: locks must be specified for any aliases of table names that are used + cursor.execute('''LOCK TABLES epidata.covidcast AS cc READ, + epimetric_load WRITE, epimetric_load AS sl WRITE, + epimetric_full WRITE, + epimetric_latest WRITE, + signal_dim WRITE, signal_dim AS sd READ, + geo_dim WRITE, geo_dim AS gd READ;''') + cursor.execute('SET unique_checks=0;') + +def finish_tx(cursor): + cursor.execute('SET unique_checks=1;') + cursor.execute('COMMIT;') + cursor.execute('UNLOCK TABLES;') + + +def do_batches(db, start, upper_lim, batch_size): + # NOTE: upper_lim is not actually selected for ; make sure it exceeds any ids you want to include + batch_lower = start + + while batch_lower < upper_lim: + batch_upper = min(batch_lower + batch_size, upper_lim) + + # NOTE: first rows of column names are identical, second rows are for specifying a rename and a literal + batch_sql = f""" + INSERT INTO epimetric_load ( + `issue`, `source`, `signal`, geo_type, geo_value, time_type, time_value, `value`, stderr, sample_size, `lag`, value_updated_timestamp, is_latest_issue, missing_value, missing_stderr, missing_sample_size + ) SELECT + `issue`, `source`, `signal`, geo_type, geo_value, time_type, time_value, `value`, stderr, sample_size, `lag`, value_updated_timestamp, is_latest_issue, missing_value, missing_stderr, missing_sample_size + FROM epidata.covidcast AS cc + USE INDEX(`PRIMARY`) + WHERE {batch_lower} <= cc.id AND cc.id < {batch_upper}; """ + # TODO: use LIMIT instead of id range?? + # TODO: might it be worth adding "ORDER BY id ASC" ? 
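The migration moves rows out of the old epidata.covidcast table in id-bounded batches; after each batch it runs the normal run_dbjobs() dispatch and a commit, so progress is durable and the job can be resumed from the printed --start_id. The loop structure of do_batches, reduced to just the range arithmetic (the INSERT ... SELECT and the dbjobs call are elided):

def batch_ranges(start, upper_lim, batch_size):
    # upper_lim is exclusive: each yielded pair selects rows with lower <= id < upper
    lower = start
    while lower < upper_lim:
        upper = min(lower + batch_size, upper_lim)
        yield lower, upper
        lower = upper

# e.g. list(batch_ranges(1, 45, 20)) == [(1, 21), (21, 41), (41, 45)]
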
+ + if use_transaction_wrappers: + start_tx(db._cursor) + + print(f"-=-=-=-=-=-=-=- RUNNING BATCH STARTING AT {batch_lower} -=-=-=-=-=-=-=-") + print(f"-=-=-=-=-=-=-=- RUNNING ''INSERT INTO SELECT FROM''... ", end="") + t = time.time() + db._cursor.execute(batch_sql) + print(f"elapsed: {time.time()-t} sec, rows: {db._cursor.rowcount} -=-=-=-=-=-=-=-") + + t = time.time() + db.run_dbjobs() + print(f"-=-=-=-=-=-=-=- RAN db_jobs()... elapsed: {time.time()-t} sec -=-=-=-=-=-=-=-") + + print("-=-=-=-=-=-=-=- RUNNING commit()... ", end="") + t = time.time() + db.commit() + if use_transaction_wrappers: + finish_tx(db._cursor) + print(f"elapsed: {time.time()-t} sec -=-=-=-=-=-=-=-") + + print("\n\n") + # move pointer for next batch + batch_lower = batch_upper + + +def main(destination_schema, batch_size, start_id, upper_lim_override): + Database.DATABASE_NAME = destination_schema + db = Database() + db.connect() + if use_autocommit: + db._connection.autocommit = True + + if upper_lim_override: + upper_lim = upper_lim_override + else: + # find upper limit for data to be imported + db._cursor.execute("SELECT MAX(id) FROM epidata.covidcast;") + for (max_id,) in db._cursor: + upper_lim = 1 + max_id + + print(f"migrating data to schema '{destination_schema}', with batch size {batch_size} and {start_id} <= ids < {upper_lim}") + if start_id==0: + print("this WILL truncate any existing v4 tables") + print() + if input("type 'yes' to continue: ") != 'yes': + import sys + sys.exit('operation cancelled!') + + print(f"starting run at: {time.strftime('%c')}") + + if start_id==0: + # clear tables in the v4 schema + print("truncating tables...") + for table in "epimetric_load epimetric_latest epimetric_full geo_dim signal_dim".split(): + db._cursor.execute(f"TRUNCATE TABLE {table}") + db.commit() + start_id = 1 + + # run batch loop + do_batches(db, start_id, upper_lim, batch_size) + + # get table counts [the quick and dirty way] + print("-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-") + db._cursor.execute(f"SELECT MAX(epimetric_id) FROM epimetric_full;") + for (max_id,) in db._cursor: + print(f"epimetric_full: {max_id}") + db._cursor.execute(f"SELECT MAX(epimetric_id) FROM epimetric_latest;") + for (max_id,) in db._cursor: + print(f"epimetric_latest: {max_id} (this should be <= the number above)") + db._cursor.execute(f"SELECT COUNT(signal_key_id), MAX(signal_key_id) FROM signal_dim;") + for (count_id, max_id) in db._cursor: + print(f"signal_dim: count {count_id} / max {max_id}") + db._cursor.execute(f"SELECT COUNT(geo_key_id), MAX(geo_key_id) FROM geo_dim;") + for (count_id, max_id) in db._cursor: + print(f"geo_dim: count {count_id} / max {max_id}") + + return upper_lim + + +if __name__ == '__main__': + argparser = argparse.ArgumentParser() + argparser.add_argument('--destination_schema', type=str, default='covid') + argparser.add_argument('--batch_size', type=int, default=20_000_000) + argparser.add_argument('--start_id', type=int, default=0) + argparser.add_argument('--upper_lim_override', type=int) # should default to None + args = argparser.parse_args() + + upper_lim = main(args.destination_schema, args.batch_size, args.start_id, args.upper_lim_override) + print(f"the next execution of this program should include argument: --start_id={upper_lim}") diff --git a/src/acquisition/covidcast/test_utils.py b/src/acquisition/covidcast/test_utils.py new file mode 100644 index 000000000..33556eca2 --- /dev/null +++ b/src/acquisition/covidcast/test_utils.py @@ -0,0 +1,53 @@ +from typing import Sequence +import unittest 
+ +from delphi_utils import Nans +from delphi.epidata.acquisition.covidcast.covidcast_row import CovidcastRow +from delphi.epidata.acquisition.covidcast.database import Database +import delphi.operations.secrets as secrets + +# all the Nans we use here are just one value, so this is a shortcut to it: +nmv = Nans.NOT_MISSING.value + +class CovidcastBase(unittest.TestCase): + def setUp(self): + # use the local test instance of the database + secrets.db.host = 'delphi_database_epidata' + secrets.db.epi = ('user', 'pass') + + self._db = Database() + self._db.connect() + + # empty all of the data tables + for table in "epimetric_load epimetric_latest epimetric_full geo_dim signal_dim".split(): + self._db._cursor.execute(f"TRUNCATE TABLE {table};") + self.localSetUp() + self._db._connection.commit() + + def localSetUp(self): + # stub; override in subclasses to perform custom setup. + # runs after tables have been truncated but before database changes have been committed + pass + + def tearDown(self): + # close and destroy conenction to the database + self._db.disconnect(False) + del self._db + + def _insert_rows(self, rows: Sequence[CovidcastRow]): + # inserts rows into the database using the full acquisition process, including 'dbjobs' load into history & latest tables + n = self._db.insert_or_update_bulk(rows) + print(f"{n} rows added to load table & dispatched to v4 schema") + self._db._connection.commit() # NOTE: this isnt expressly needed for our test cases, but would be if using external access (like through client lib) to ensure changes are visible outside of this db session + + def params_from_row(self, row: CovidcastRow, **kwargs): + ret = { + 'data_source': row.source, + 'signals': row.signal, + 'time_type': row.time_type, + 'geo_type': row.geo_type, + 'time_values': row.time_value, + 'geo_value': row.geo_value, + } + ret.update(kwargs) + return ret diff --git a/src/acquisition/covidcast_nowcast/load_sensors.py b/src/acquisition/covidcast_nowcast/load_sensors.py index ab9a6b33e..079b2f27c 100644 --- a/src/acquisition/covidcast_nowcast/load_sensors.py +++ b/src/acquisition/covidcast_nowcast/load_sensors.py @@ -92,7 +92,12 @@ def _move_after_processing(filepath, success): def _create_upsert_method(meta): def method(table, conn, keys, data_iter): - sql_table = sqlalchemy.Table(table.name, meta, autoload=True) + sql_table = sqlalchemy.Table( + table.name, + meta, + # specify lag column explicitly; lag is a reserved word sqlalchemy doesn't know about + sqlalchemy.Column("lag", sqlalchemy.Integer, quote=True), + autoload=True) insert_stmt = sqlalchemy.dialects.mysql.insert(sql_table).values([dict(zip(keys, data)) for data in data_iter]) upsert_stmt = insert_stmt.on_duplicate_key_update({x.name: x for x in insert_stmt.inserted}) conn.execute(upsert_stmt) diff --git a/src/ddl/api_analytics.sql b/src/ddl/api_analytics.sql index a0a06fe8b..7b8aa0279 100644 --- a/src/ddl/api_analytics.sql +++ b/src/ddl/api_analytics.sql @@ -1,3 +1,4 @@ +USE epidata; /* `api_analytics` logs API usage, which Delphi uses to improve the API. diff --git a/src/ddl/cdc.sql b/src/ddl/cdc.sql index 006c33754..06e445acf 100644 --- a/src/ddl/cdc.sql +++ b/src/ddl/cdc.sql @@ -1,3 +1,4 @@ +USE epidata; /* TODO: document */ diff --git a/src/ddl/covid_hosp.sql b/src/ddl/covid_hosp.sql index be3ee9e86..bd47e20f1 100644 --- a/src/ddl/covid_hosp.sql +++ b/src/ddl/covid_hosp.sql @@ -1,3 +1,4 @@ +USE epidata; /* These tables store the collection of datasets relating to COVID-19 patient impact and hospital capacity. 
Data is provided by the US Department of Health & diff --git a/src/ddl/covidcast_nowcast.sql b/src/ddl/covidcast_nowcast.sql index 0b9671a6e..8c3944f6e 100644 --- a/src/ddl/covidcast_nowcast.sql +++ b/src/ddl/covidcast_nowcast.sql @@ -1,3 +1,4 @@ +USE epidata; /* This table stores various sensors of Delphi's COVID-19 surveillance streams for nowcasting. */ diff --git a/src/ddl/ecdc_ili.sql b/src/ddl/ecdc_ili.sql index 201cba805..1d57876c5 100644 --- a/src/ddl/ecdc_ili.sql +++ b/src/ddl/ecdc_ili.sql @@ -1,3 +1,4 @@ +USE epidata; /* TODO: document */ diff --git a/src/ddl/fluview.sql b/src/ddl/fluview.sql index b3ece8b6c..9da1589ce 100644 --- a/src/ddl/fluview.sql +++ b/src/ddl/fluview.sql @@ -1,3 +1,4 @@ +USE epidata; /* These tables are generally a mirror of what CDC publishes through the interactive FluView web app at: diff --git a/src/ddl/forecasts.sql b/src/ddl/forecasts.sql index 806e46dbc..362f65cfa 100644 --- a/src/ddl/forecasts.sql +++ b/src/ddl/forecasts.sql @@ -1,3 +1,4 @@ +USE epidata; /* TODO: document */ diff --git a/src/ddl/gft.sql b/src/ddl/gft.sql index aa2f8e6e4..d6d6a64a0 100644 --- a/src/ddl/gft.sql +++ b/src/ddl/gft.sql @@ -1,3 +1,4 @@ +USE epidata; /* TODO: document */ diff --git a/src/ddl/ght.sql b/src/ddl/ght.sql index 928d0f63c..15db5667f 100644 --- a/src/ddl/ght.sql +++ b/src/ddl/ght.sql @@ -1,3 +1,4 @@ +USE epidata; /* TODO: document */ diff --git a/src/ddl/kcdc_ili.sql b/src/ddl/kcdc_ili.sql index 1d5a71dbe..64f3f576c 100644 --- a/src/ddl/kcdc_ili.sql +++ b/src/ddl/kcdc_ili.sql @@ -1,3 +1,4 @@ +USE epidata; /* TODO: document */ diff --git a/src/ddl/migrations/v4_renaming.sql b/src/ddl/migrations/v4_renaming.sql new file mode 100644 index 000000000..c0edf8e96 --- /dev/null +++ b/src/ddl/migrations/v4_renaming.sql @@ -0,0 +1,93 @@ +-- drop VIEWs in `epidata` that act as aliases to (ie, they reference) VIEWs in `covid` +USE epidata; +DROP VIEW + signal_history_v, + signal_latest_v; + +-- return to v4 schema namespace +USE covid; + +-- drop VIEWs that reference main TABLEs +DROP VIEW + signal_history_v, + signal_latest_v; + +-- rename main TABLEs +RENAME TABLE + signal_history TO epimetric_full, + signal_latest TO epimetric_latest, + signal_load TO epimetric_load; + +-- rename id COLUMNs in main TABLEs +ALTER TABLE epimetric_full RENAME COLUMN signal_data_id TO epimetric_id; +ALTER TABLE epimetric_latest RENAME COLUMN signal_data_id TO epimetric_id; +ALTER TABLE epimetric_load RENAME COLUMN signal_data_id TO epimetric_id; + +-- -- -- TODO: rename `value_key_*` INDEXes in `epimetric_*` TABLEs to `???_idx_*`? + +-- re-create VIEWs that reference newly renamed TABLEs (this is a straight copy of the VIEW definitions from ../v4_schema.sql +CREATE OR REPLACE VIEW epimetric_full_v AS + SELECT + 0 AS `is_latest_issue`, -- provides column-compatibility to match `covidcast` table + -- ^ this value is essentially undefined in this view, the notion of a 'latest' issue is not encoded here and must be drawn from the 'latest' table or view or otherwise computed... 
+ NULL AS `direction`, -- provides column-compatibility to match `covidcast` table + `t2`.`source` AS `source`, + `t2`.`signal` AS `signal`, + `t3`.`geo_type` AS `geo_type`, + `t3`.`geo_value` AS `geo_value`, + `t1`.`epimetric_id` AS `epimetric_id`, + `t1`.`strat_key_id` AS `strat_key_id`, -- TODO: for future use + `t1`.`issue` AS `issue`, + `t1`.`data_as_of_dt` AS `data_as_of_dt`, -- TODO: for future use ; also "as_of" is problematic and should be renamed + `t1`.`time_type` AS `time_type`, + `t1`.`time_value` AS `time_value`, + `t1`.`reference_dt` AS `reference_dt`, -- TODO: for future use + `t1`.`value` AS `value`, + `t1`.`stderr` AS `stderr`, + `t1`.`sample_size` AS `sample_size`, + `t1`.`lag` AS `lag`, + `t1`.`value_updated_timestamp` AS `value_updated_timestamp`, + `t1`.`computation_as_of_dt` AS `computation_as_of_dt`, -- TODO: for future use ; also "as_of" is problematic and should be renamed + `t1`.`missing_value` AS `missing_value`, + `t1`.`missing_stderr` AS `missing_stderr`, + `t1`.`missing_sample_size` AS `missing_sample_size`, + `t1`.`signal_key_id` AS `signal_key_id`, + `t1`.`geo_key_id` AS `geo_key_id` + FROM `epimetric_full` `t1` + JOIN `signal_dim` `t2` USING (`signal_key_id`) + JOIN `geo_dim` `t3` USING (`geo_key_id`); +CREATE OR REPLACE VIEW epimetric_latest_v AS + SELECT + 1 AS `is_latest_issue`, -- provides column-compatibility to match `covidcast` table + NULL AS `direction`, -- provides column-compatibility to match `covidcast` table + `t2`.`source` AS `source`, + `t2`.`signal` AS `signal`, + `t3`.`geo_type` AS `geo_type`, + `t3`.`geo_value` AS `geo_value`, + `t1`.`epimetric_id` AS `epimetric_id`, + `t1`.`strat_key_id` AS `strat_key_id`, -- TODO: for future use + `t1`.`issue` AS `issue`, + `t1`.`data_as_of_dt` AS `data_as_of_dt`, -- TODO: for future use ; also "as_of" is problematic and should be renamed + `t1`.`time_type` AS `time_type`, + `t1`.`time_value` AS `time_value`, + `t1`.`reference_dt` AS `reference_dt`, -- TODO: for future use + `t1`.`value` AS `value`, + `t1`.`stderr` AS `stderr`, + `t1`.`sample_size` AS `sample_size`, + `t1`.`lag` AS `lag`, + `t1`.`value_updated_timestamp` AS `value_updated_timestamp`, + `t1`.`computation_as_of_dt` AS `computation_as_of_dt`, -- TODO: for future use ; also "as_of" is problematic and should be renamed + `t1`.`missing_value` AS `missing_value`, + `t1`.`missing_stderr` AS `missing_stderr`, + `t1`.`missing_sample_size` AS `missing_sample_size`, + `t1`.`signal_key_id` AS `signal_key_id`, + `t1`.`geo_key_id` AS `geo_key_id` + FROM `epimetric_latest` `t1` + JOIN `signal_dim` `t2` USING (`signal_key_id`) + JOIN `geo_dim` `t3` USING (`geo_key_id`); + + +-- re-create `epidata` alias VIEWs +USE epidata; +CREATE VIEW epidata.epimetric_full_v AS SELECT * FROM covid.epimetric_full_v; +CREATE VIEW epidata.epimetric_latest_v AS SELECT * FROM covid.epimetric_latest_v; diff --git a/src/ddl/nidss.sql b/src/ddl/nidss.sql index 51f5b60f6..936de64a2 100644 --- a/src/ddl/nidss.sql +++ b/src/ddl/nidss.sql @@ -1,3 +1,4 @@ +USE epidata; /* TODO: document */ diff --git a/src/ddl/nowcasts.sql b/src/ddl/nowcasts.sql index 9ed07cb6d..ecffb9deb 100644 --- a/src/ddl/nowcasts.sql +++ b/src/ddl/nowcasts.sql @@ -1,3 +1,4 @@ +USE epidata; /* TODO: document */ diff --git a/src/ddl/paho_dengue.sql b/src/ddl/paho_dengue.sql index 4d51ed58f..e2fd98c72 100644 --- a/src/ddl/paho_dengue.sql +++ b/src/ddl/paho_dengue.sql @@ -1,3 +1,4 @@ +USE epidata; /* TODO: document */ diff --git a/src/ddl/quidel.sql b/src/ddl/quidel.sql index eb454d281..5e3d35820 100644 --- 
a/src/ddl/quidel.sql +++ b/src/ddl/quidel.sql @@ -1,3 +1,4 @@ +USE epidata; /* TODO: document */ diff --git a/src/ddl/sensors.sql b/src/ddl/sensors.sql index 19740e8f5..3854865fb 100644 --- a/src/ddl/sensors.sql +++ b/src/ddl/sensors.sql @@ -1,3 +1,4 @@ +USE epidata; /* TODO: document */ diff --git a/src/ddl/signal_dashboard.sql b/src/ddl/signal_dashboard.sql index f5655dc57..cad7f0a5f 100644 --- a/src/ddl/signal_dashboard.sql +++ b/src/ddl/signal_dashboard.sql @@ -1,3 +1,4 @@ +USE epidata; /* This table stores the signals used in the public signal dashboard. diff --git a/src/ddl/twitter.sql b/src/ddl/twitter.sql index 9790e6529..c7e610121 100644 --- a/src/ddl/twitter.sql +++ b/src/ddl/twitter.sql @@ -1,3 +1,4 @@ +USE epidata; /* TODO: document */ diff --git a/src/ddl/v4_schema.sql b/src/ddl/v4_schema.sql index 0d266a4ca..7551707f6 100644 --- a/src/ddl/v4_schema.sql +++ b/src/ddl/v4_schema.sql @@ -1,106 +1,77 @@ --- -------------------------------- --- TODO: REMOVE THESE HACKS!!! (find a better way to do this --- --- the database schema `epidata` is created by ENV variables specified in the docker image definition found at: --- ../../dev/docker/database/epidata/Dockerfile --- and the user 'user' is created with permissions on that database. --- here we create the `covid` schema and extend permissions to the same user, --- as the ENV options do not appear to be expressive enough to do this as well. --- this is incredibly permissive and easily guessable, but is reqd for testing our environment. --- -CREATE DATABASE covid; USE covid; -GRANT ALL ON covid.* TO 'user'; --- END TODO --- -------------------------------- CREATE TABLE geo_dim ( - `geo_key_id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, - `geo_type` VARCHAR(12), - `geo_value` VARCHAR(12), - `compressed_geo_key` VARCHAR(100), - - PRIMARY KEY (`geo_key_id`) USING BTREE, - UNIQUE INDEX `compressed_geo_key_ind` (`compressed_geo_key`) USING BTREE -); - + `geo_key_id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY, + `geo_type` VARCHAR(12) NOT NULL, + `geo_value` VARCHAR(12) NOT NULL, -CREATE TABLE signal_dim ( - `signal_key_id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, - `source` VARCHAR(32), - `signal` VARCHAR(64), - `compressed_signal_key` VARCHAR(100), - - PRIMARY KEY (`signal_key_id`) USING BTREE, - UNIQUE INDEX `compressed_signal_key_ind` (`compressed_signal_key`) USING BTREE + UNIQUE INDEX `geo_dim_index` (`geo_type`, `geo_value`) ) ENGINE=InnoDB; +CREATE TABLE signal_dim ( + `signal_key_id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY, + `source` VARCHAR(32) NOT NULL, + `signal` VARCHAR(64) NOT NULL, -CREATE TABLE signal_history ( - `signal_data_id` BIGINT(20) UNSIGNED NOT NULL, - `signal_key_id` BIGINT(20) UNSIGNED, - `geo_key_id` BIGINT(20) UNSIGNED, - `demog_key_id` BIGINT(20) UNSIGNED, -- TODO: for future use ; also rename s/demog/stratification/ - `issue` INT(11), - `data_as_of_dt` DATETIME(0), -- TODO: for future use ; also "as_of" is problematic and should be renamed - `time_type` VARCHAR(12) NOT NULL, - `time_value` INT(11) NOT NULL, - `reference_dt` DATETIME(0), -- TODO: for future use - `value` DOUBLE NULL DEFAULT NULL, - `stderr` DOUBLE NULL DEFAULT NULL, - `sample_size` DOUBLE NULL DEFAULT NULL, - `lag` INT(11) NOT NULL, - `value_updated_timestamp` INT(11) NOT NULL, - `computation_as_of_dt` DATETIME(0), -- TODO: for future use ; also "as_of" is problematic and should be renamed - `is_latest_issue` BINARY(1) NOT NULL DEFAULT '0', -- TODO: delete this, its hard to keep updated and its not currently 
used - `missing_value` INT(1) NULL DEFAULT '0', - `missing_stderr` INT(1) NULL DEFAULT '0', - `missing_sample_size` INT(1) NULL DEFAULT '0', - `legacy_id` BIGINT(20) UNSIGNED NULL DEFAULT NULL, -- not used beyond import of previous data into the v4 schema - - PRIMARY KEY (`signal_data_id`) USING BTREE, - UNIQUE INDEX `value_key` (`signal_key_id`,`geo_key_id`,`issue`,`time_type`,`time_value`) USING BTREE + UNIQUE INDEX `signal_dim_index` (`source`, `signal`) ) ENGINE=InnoDB; - -CREATE TABLE signal_latest ( - `signal_data_id` BIGINT(20) UNSIGNED NOT NULL, - `signal_key_id` BIGINT(20) UNSIGNED, - `geo_key_id` BIGINT(20) UNSIGNED, - `demog_key_id` BIGINT(20) UNSIGNED, -- TODO: for future use ; also rename s/demog/stratification/ - `issue` INT(11), +CREATE TABLE strat_dim ( + `strat_key_id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY, + `stratification_name` VARCHAR(64) NOT NULL UNIQUE, + `stratification_descr` VARCHAR(64) NOT NULL +) ENGINE=InnoDB; +INSERT INTO strat_dim VALUES (1, 'NO_STRATIFICATION', ''); + +CREATE TABLE epimetric_full ( + `epimetric_id` BIGINT(20) UNSIGNED NOT NULL PRIMARY KEY, + `signal_key_id` BIGINT(20) UNSIGNED NOT NULL, + `geo_key_id` BIGINT(20) UNSIGNED NOT NULL, + `strat_key_id` BIGINT(20) UNSIGNED NOT NULL DEFAULT 1, -- TODO: for future use + `issue` INT(11) NOT NULL, `data_as_of_dt` DATETIME(0), -- TODO: for future use ; also "as_of" is problematic and should be renamed `time_type` VARCHAR(12) NOT NULL, `time_value` INT(11) NOT NULL, `reference_dt` DATETIME(0), -- TODO: for future use - `value` DOUBLE NULL DEFAULT NULL, - `stderr` DOUBLE NULL DEFAULT NULL, - `sample_size` DOUBLE NULL DEFAULT NULL, + `value` DOUBLE, + `stderr` DOUBLE, + `sample_size` DOUBLE, `lag` INT(11) NOT NULL, `value_updated_timestamp` INT(11) NOT NULL, `computation_as_of_dt` DATETIME(0), -- TODO: for future use ; also "as_of" is problematic and should be renamed - `missing_value` INT(1) NULL DEFAULT '0', - `missing_stderr` INT(1) NULL DEFAULT '0', - `missing_sample_size` INT(1) NULL DEFAULT '0', - - PRIMARY KEY (`signal_data_id`) USING BTREE, - UNIQUE INDEX `value_key` (`signal_key_id`,`geo_key_id`,`time_type`,`time_value`) USING BTREE + `missing_value` INT(1) DEFAULT '0', + `missing_stderr` INT(1) DEFAULT '0', + `missing_sample_size` INT(1) DEFAULT '0', + + UNIQUE INDEX `value_key_tig` (`signal_key_id`, `time_type`, `time_value`, `issue`, `geo_key_id`), + UNIQUE INDEX `value_key_tgi` (`signal_key_id`, `time_type`, `time_value`, `geo_key_id`, `issue`), + UNIQUE INDEX `value_key_itg` (`signal_key_id`, `issue`, `time_type`, `time_value`, `geo_key_id`), + UNIQUE INDEX `value_key_igt` (`signal_key_id`, `issue`, `geo_key_id`, `time_type`, `time_value`), + UNIQUE INDEX `value_key_git` (`signal_key_id`, `geo_key_id`, `issue`, `time_type`, `time_value`), + UNIQUE INDEX `value_key_gti` (`signal_key_id`, `geo_key_id`, `time_type`, `time_value`, `issue`) ) ENGINE=InnoDB; +CREATE TABLE epimetric_latest ( + PRIMARY KEY (`epimetric_id`), + UNIQUE INDEX `value_key_tg` (`signal_key_id`, `time_type`, `time_value`, `geo_key_id`), + UNIQUE INDEX `value_key_gt` (`signal_key_id`, `geo_key_id`, `time_type`, `time_value`) +) ENGINE=InnoDB +SELECT * FROM epimetric_full; + -- NOTE: In production or any non-testing system that should maintain consistency, -- **DO NOT** 'TRUNCATE' this table. --- Doing so will function as a DROP/CREATE and reset the AUTO_INCREMENT counter for the `signal_data_id` field. 
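The `epimetric_latest` table defined above is declared with only its primary key and two unique indexes and then borrows its column layout (and any rows, none at DDL time) from `epimetric_full` through MySQL's CREATE TABLE ... SELECT form. A minimal sketch of that pattern, assuming a DB-API cursor and an existing source table; `demo_full` and `demo_latest` are illustrative names, not part of this schema:

def clone_with_new_keys(cursor) -> None:
    # CREATE TABLE ... SELECT copies the column definitions (and rows) of the
    # SELECT result, while the key clauses listed here are layered on top; this
    # is how a "latest" table can mirror a "full" table with fewer indexes.
    cursor.execute("""
        CREATE TABLE demo_latest (
            PRIMARY KEY (`epimetric_id`),
            UNIQUE INDEX `value_key_tg` (`signal_key_id`, `time_type`, `time_value`, `geo_key_id`)
        ) ENGINE=InnoDB
        SELECT * FROM demo_full
    """)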
--- This field is used to populate the non-AUTO_INCREMENT fields of the same name in `signal_latest` and `signal_history`, +-- Doing so will function as a DROP/CREATE and reset the AUTO_INCREMENT counter for the `epimetric_id` field. +-- This field is used to populate the non-AUTO_INCREMENT fields of the same name in `epimetric_latest` and `epimetric_full`, -- and resetting it will ultimately cause PK collisions. --- To restore the counter, a row must be written with a `signal_data_id` value greater than the maximum +-- To restore the counter, a row must be written with a `epimetric_id` value greater than the maximum -- of its values in the other tables. -CREATE TABLE signal_load ( - `signal_data_id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, +CREATE TABLE epimetric_load ( + `epimetric_id` BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY, `signal_key_id` BIGINT(20) UNSIGNED, `geo_key_id` BIGINT(20) UNSIGNED, - `demog_key_id` BIGINT(20) UNSIGNED, -- TODO: for future use ; also rename s/demog/stratification/ - `issue` INT(11), + `strat_key_id` BIGINT(20) UNSIGNED NOT NULL DEFAULT 1, -- TODO: for future use + `issue` INT(11) NOT NULL, `data_as_of_dt` DATETIME(0), -- TODO: for future use ; also "as_of" is problematic and should be renamed `source` VARCHAR(32) NOT NULL, `signal` VARCHAR(64) NOT NULL, @@ -109,98 +80,81 @@ CREATE TABLE signal_load ( `time_type` VARCHAR(12) NOT NULL, `time_value` INT(11) NOT NULL, `reference_dt` DATETIME(0), -- TODO: for future use - `value` DOUBLE NULL DEFAULT NULL, - `stderr` DOUBLE NULL DEFAULT NULL, - `sample_size` DOUBLE NULL DEFAULT NULL, + `value` DOUBLE, + `stderr` DOUBLE, + `sample_size` DOUBLE, `lag` INT(11) NOT NULL, `value_updated_timestamp` INT(11) NOT NULL, `computation_as_of_dt` DATETIME(0), -- TODO: for future use ; also "as_of" is problematic and should be renamed `is_latest_issue` BINARY(1) NOT NULL DEFAULT '0', - `missing_value` INT(1) NULL DEFAULT '0', - `missing_stderr` INT(1) NULL DEFAULT '0', - `missing_sample_size` INT(1) NULL DEFAULT '0', - `legacy_id` BIGINT(20) UNSIGNED, -- not used beyond import of previous data into the v4 schema - `compressed_signal_key` VARCHAR(100), - `compressed_geo_key` VARCHAR(100), - `compressed_demog_key` VARCHAR(100), -- TODO: for future use ; also rename s/demog/stratification/ - `process_status` VARCHAR(2) DEFAULT 'l', -- using codes: 'i' (I) for "inserting", 'l' (L) for "loaded", and 'b' for "batching" - -- TODO: change `process_status` default to 'i' (I) "inserting" or even 'x'/'u' "undefined" ? - - PRIMARY KEY (`signal_data_id`) USING BTREE, - INDEX `comp_signal_key` (`compressed_signal_key`) USING BTREE, - INDEX `comp_geo_key` (`compressed_geo_key`) USING BTREE -) ENGINE=InnoDB AUTO_INCREMENT=4000000001; - - -CREATE OR REPLACE VIEW signal_history_v AS + `missing_value` INT(1) DEFAULT '0', + `missing_stderr` INT(1) DEFAULT '0', + `missing_sample_size` INT(1) DEFAULT '0', + + UNIQUE INDEX (`source`, `signal`, `time_type`, `geo_type`, `time_value`, `geo_value`, `issue`) +) ENGINE=InnoDB; + + +CREATE OR REPLACE VIEW epimetric_full_v AS SELECT - 0 AS is_latest_issue, -- provides column-compatibility to match `covidcast` table + 0 AS `is_latest_issue`, -- provides column-compatibility to match `covidcast` table -- ^ this value is essentially undefined in this view, the notion of a 'latest' issue is not encoded here and must be drawn from the 'latest' table or view or otherwise computed... 
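As the NOTE above explains, if `epimetric_load` is ever truncated in a dev environment its AUTO_INCREMENT counter has to be pushed past the largest `epimetric_id` already copied into `epimetric_full` and `epimetric_latest`. A sketch of one way to reseed it, using ALTER TABLE ... AUTO_INCREMENT rather than the throwaway row the NOTE describes; the function name and the DB-API `cursor` argument are illustrative, not part of this change:

def reseed_epimetric_load_counter(cursor) -> None:
    # Largest id already handed out, across both destination tables.
    cursor.execute("""
        SELECT GREATEST(
            COALESCE((SELECT MAX(`epimetric_id`) FROM `epimetric_full`), 0),
            COALESCE((SELECT MAX(`epimetric_id`) FROM `epimetric_latest`), 0)
        ) + 1
    """)
    (next_id,) = cursor.fetchone()
    # Seeding below the existing maximum would reintroduce the PK collisions
    # described above, hence the GREATEST(...) + 1.
    cursor.execute(f"ALTER TABLE `epimetric_load` AUTO_INCREMENT = {int(next_id)}")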
- NULL AS direction, -- provides column-compatibility to match `covidcast` table + NULL AS `direction`, -- provides column-compatibility to match `covidcast` table `t2`.`source` AS `source`, `t2`.`signal` AS `signal`, `t3`.`geo_type` AS `geo_type`, `t3`.`geo_value` AS `geo_value`, - `t1`.`signal_data_id` AS `signal_data_id`, -- TODO: unnecessary ...remove? - `t1`.`demog_key_id` AS `demog_key_id`, -- TODO: for future use ; also rename s/demog/stratification/ ...remove? + `t1`.`epimetric_id` AS `epimetric_id`, + `t1`.`strat_key_id` AS `strat_key_id`, -- TODO: for future use `t1`.`issue` AS `issue`, - `t1`.`data_as_of_dt` AS `data_as_of_dt`, -- TODO: for future use ; also "as_of" is problematic and should be renamed ...remove? + `t1`.`data_as_of_dt` AS `data_as_of_dt`, -- TODO: for future use ; also "as_of" is problematic and should be renamed `t1`.`time_type` AS `time_type`, `t1`.`time_value` AS `time_value`, - `t1`.`reference_dt` AS `reference_dt`, -- TODO: for future use ...remove? + `t1`.`reference_dt` AS `reference_dt`, -- TODO: for future use `t1`.`value` AS `value`, `t1`.`stderr` AS `stderr`, `t1`.`sample_size` AS `sample_size`, `t1`.`lag` AS `lag`, `t1`.`value_updated_timestamp` AS `value_updated_timestamp`, - `t1`.`computation_as_of_dt` AS `computation_as_of_dt`, -- TODO: for future use ; also "as_of" is problematic and should be renamed ...remove? + `t1`.`computation_as_of_dt` AS `computation_as_of_dt`, -- TODO: for future use ; also "as_of" is problematic and should be renamed `t1`.`missing_value` AS `missing_value`, `t1`.`missing_stderr` AS `missing_stderr`, `t1`.`missing_sample_size` AS `missing_sample_size`, - `t1`.`signal_key_id` AS `signal_key_id`, -- TODO: unnecessary ...remove? - `t1`.`geo_key_id` AS `geo_key_id` -- TODO: unnecessary ...remove? - FROM ((`signal_history` `t1` - JOIN `signal_dim` `t2` - USE INDEX (PRIMARY) - ON `t1`.`signal_key_id` = `t2`.`signal_key_id`) - JOIN `geo_dim` `t3` - USE INDEX (PRIMARY) - ON `t1`.`geo_key_id` = `t3`.`geo_key_id`); - - -CREATE OR REPLACE VIEW signal_latest_v AS + `t1`.`signal_key_id` AS `signal_key_id`, + `t1`.`geo_key_id` AS `geo_key_id` + FROM `epimetric_full` `t1` + JOIN `signal_dim` `t2` USING (`signal_key_id`) + JOIN `geo_dim` `t3` USING (`geo_key_id`); + +CREATE OR REPLACE VIEW epimetric_latest_v AS SELECT - 1 AS is_latest_issue, -- provides column-compatibility to match `covidcast` table - NULL AS direction, -- provides column-compatibility to match `covidcast` table + 1 AS `is_latest_issue`, -- provides column-compatibility to match `covidcast` table + NULL AS `direction`, -- provides column-compatibility to match `covidcast` table `t2`.`source` AS `source`, `t2`.`signal` AS `signal`, `t3`.`geo_type` AS `geo_type`, `t3`.`geo_value` AS `geo_value`, - `t1`.`signal_data_id` AS `signal_data_id`, -- TODO: unnecessary ...remove? - `t1`.`demog_key_id` AS `demog_key_id`, -- TODO: for future use ; also rename s/demog/stratification/ ...remove? + `t1`.`epimetric_id` AS `epimetric_id`, + `t1`.`strat_key_id` AS `strat_key_id`, -- TODO: for future use `t1`.`issue` AS `issue`, - `t1`.`data_as_of_dt` AS `data_as_of_dt`, -- TODO: for future use ; also "as_of" is problematic and should be renamed ...remove? + `t1`.`data_as_of_dt` AS `data_as_of_dt`, -- TODO: for future use ; also "as_of" is problematic and should be renamed `t1`.`time_type` AS `time_type`, `t1`.`time_value` AS `time_value`, - `t1`.`reference_dt` AS `reference_dt`, -- TODO: for future use ...remove? 
+ `t1`.`reference_dt` AS `reference_dt`, -- TODO: for future use `t1`.`value` AS `value`, `t1`.`stderr` AS `stderr`, `t1`.`sample_size` AS `sample_size`, `t1`.`lag` AS `lag`, `t1`.`value_updated_timestamp` AS `value_updated_timestamp`, - `t1`.`computation_as_of_dt` AS `computation_as_of_dt`, -- TODO: for future use ; also "as_of" is problematic and should be renamed ...remove? + `t1`.`computation_as_of_dt` AS `computation_as_of_dt`, -- TODO: for future use ; also "as_of" is problematic and should be renamed `t1`.`missing_value` AS `missing_value`, `t1`.`missing_stderr` AS `missing_stderr`, `t1`.`missing_sample_size` AS `missing_sample_size`, - `t1`.`signal_key_id` AS `signal_key_id`, -- TODO: unnecessary ...remove? - `t1`.`geo_key_id` AS `geo_key_id` -- TODO: unnecessary ...remove? - FROM ((`signal_latest` `t1` - JOIN `signal_dim` `t2` - USE INDEX (PRIMARY) - ON `t1`.`signal_key_id` = `t2`.`signal_key_id`) - JOIN `geo_dim` `t3` - USE INDEX (PRIMARY) - ON `t1`.`geo_key_id` = `t3`.`geo_key_id`); + `t1`.`signal_key_id` AS `signal_key_id`, + `t1`.`geo_key_id` AS `geo_key_id` + FROM `epimetric_latest` `t1` + JOIN `signal_dim` `t2` USING (`signal_key_id`) + JOIN `geo_dim` `t3` USING (`geo_key_id`); CREATE TABLE `covidcast_meta_cache` ( @@ -208,5 +162,5 @@ CREATE TABLE `covidcast_meta_cache` ( `epidata` LONGTEXT NOT NULL, PRIMARY KEY (`timestamp`) -) ENGINE=InnoDB DEFAULT CHARSET=utf8; +) ENGINE=InnoDB; INSERT INTO covidcast_meta_cache VALUES (0, '[]'); diff --git a/src/ddl/v4_schema_aliases.sql b/src/ddl/v4_schema_aliases.sql index 838facc53..f5c6340e9 100644 --- a/src/ddl/v4_schema_aliases.sql +++ b/src/ddl/v4_schema_aliases.sql @@ -5,6 +5,6 @@ -- frontend api code still uses `epidata` but has these relevant tables/views "aliased" to use covid.blah when referred to as epidata.blah in context. -- ---------------------------------- -CREATE VIEW `epidata`.`signal_history_v` AS SELECT * FROM `covid`.`signal_history_v`; -CREATE VIEW `epidata`.`signal_latest_v` AS SELECT * FROM `covid`.`signal_latest_v`; +CREATE VIEW `epidata`.`epimetric_full_v` AS SELECT * FROM `covid`.`epimetric_full_v`; +CREATE VIEW `epidata`.`epimetric_latest_v` AS SELECT * FROM `covid`.`epimetric_latest_v`; CREATE VIEW `epidata`.`covidcast_meta_cache` AS SELECT * FROM `covid`.`covidcast_meta_cache`; diff --git a/src/ddl/wiki.sql b/src/ddl/wiki.sql index f743fb0e0..0bb2a79c1 100644 --- a/src/ddl/wiki.sql +++ b/src/ddl/wiki.sql @@ -1,3 +1,4 @@ +USE epidata; /* These tables coordinate scraping of Wikipedia page visit stats and store page visit counts for pages of interest (i.e. those which are epidemiologically diff --git a/src/server/_params.py b/src/server/_params.py index 33ca168df..8b780a89c 100644 --- a/src/server/_params.py +++ b/src/server/_params.py @@ -93,6 +93,13 @@ def count(self) -> float: return inf if self.signal else 0 return len(self.signal) + def add_signal(self, signal: str) -> None: + if not isinstance(self.signal, bool): + self.signal.append(signal) + + def __hash__(self) -> int: + return hash((self.source, self.signal if isinstance(self.signal, bool) else tuple(self.signal))) + def _combine_source_signal_pairs(source_signal_pairs: List[SourceSignalPair]) -> List[SourceSignalPair]: """Combine SourceSignalPairs with the same source into a single SourceSignalPair object.
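Because a dataclass that keeps the default eq=True is made unhashable unless the class body supplies __hash__ itself, the explicit __hash__ added above is what allows SourceSignalPair (assuming it is a plain dataclass, as the rest of _params.py suggests) to serve as a dictionary key for the new SignalTransforms mapping in covidcast_utils/model.py; there, only the values are grown with add_signal, the keys are never mutated. A self-contained sketch of the idea, where Pair is a stand-in rather than the real class:

from dataclasses import dataclass
from typing import List, Union


@dataclass
class Pair:  # stand-in for SourceSignalPair
    source: str
    signal: Union[bool, List[str]]

    def __hash__(self) -> int:
        # lists are unhashable, so freeze the signal list into a tuple
        return hash((self.source, self.signal if isinstance(self.signal, bool) else tuple(self.signal)))


transforms = {Pair("src", ["sig_base"]): Pair("src", [])}           # key: the base signal
transforms[Pair("src", ["sig_base"])].signal.append("sig_smooth")   # mutate the value, not the key
assert transforms[Pair("src", ["sig_base"])].signal == ["sig_smooth"]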
@@ -112,6 +119,7 @@ def _combine_source_signal_pairs(source_signal_pairs: List[SourceSignalPair]) -> source_signal_pairs_combined.append(SourceSignalPair(source, combined_signals)) return source_signal_pairs_combined + def parse_source_signal_arg(key: str = "signal") -> List[SourceSignalPair]: return _combine_source_signal_pairs([SourceSignalPair(source, signals) for [source, signals] in _parse_common_multi_arg(key)]) diff --git a/src/server/_query.py b/src/server/_query.py index 66c9e5d43..bd1b37270 100644 --- a/src/server/_query.py +++ b/src/server/_query.py @@ -350,7 +350,7 @@ def retable(self, new_table: str): updates this QueryBuilder to point to another table. useful for switching to a different view of the data... """ - # TODO: consider creating a copy of the QueryBuilder, modifying that, and returning the new one. + # WARNING: if we ever switch to re-using QueryBuilder, we should change this to return a copy. self.table: str = f"{new_table} {self.alias}" return self diff --git a/src/server/endpoints/covid_hosp_state_timeseries.py b/src/server/endpoints/covid_hosp_state_timeseries.py index c8684ddba..5da4d4e16 100644 --- a/src/server/endpoints/covid_hosp_state_timeseries.py +++ b/src/server/endpoints/covid_hosp_state_timeseries.py @@ -154,16 +154,16 @@ def handle(): if issues is not None: q.where_integers("issue", issues) # final query using specific issues - query = f"WITH c as (SELECT {q.fields_clause}, ROW_NUMBER() OVER (PARTITION BY date, state, issue ORDER BY record_type) row FROM {q.table} WHERE {q.conditions_clause}) SELECT {q.fields_clause} FROM {q.alias} WHERE row = 1 ORDER BY {q.order_clause}" + query = f"WITH c as (SELECT {q.fields_clause}, ROW_NUMBER() OVER (PARTITION BY date, state, issue ORDER BY record_type) `row` FROM {q.table} WHERE {q.conditions_clause}) SELECT {q.fields_clause} FROM {q.alias} WHERE `row` = 1 ORDER BY {q.order_clause}" elif as_of is not None: sub_condition_asof = "(issue <= :as_of)" q.params["as_of"] = as_of - query = f"WITH c as (SELECT {q.fields_clause}, ROW_NUMBER() OVER (PARTITION BY date, state ORDER BY issue DESC, record_type) row FROM {q.table} WHERE {q.conditions_clause} AND {sub_condition_asof}) SELECT {q.fields_clause} FROM {q.alias} WHERE row = 1 ORDER BY {q.order_clause}" + query = f"WITH c as (SELECT {q.fields_clause}, ROW_NUMBER() OVER (PARTITION BY date, state ORDER BY issue DESC, record_type) `row` FROM {q.table} WHERE {q.conditions_clause} AND {sub_condition_asof}) SELECT {q.fields_clause} FROM {q.alias} WHERE `row` = 1 ORDER BY {q.order_clause}" else: # final query using most recent issues subquery = f"(SELECT max(`issue`) `max_issue`, `date`, `state` FROM {q.table} WHERE {q.conditions_clause} GROUP BY `date`, `state`) x" condition = f"x.`max_issue` = {q.alias}.`issue` AND x.`date` = {q.alias}.`date` AND x.`state` = {q.alias}.`state`" - query = f"WITH c as (SELECT {q.fields_clause}, ROW_NUMBER() OVER (PARTITION BY date, state, issue ORDER BY record_type) row FROM {q.table} JOIN {subquery} ON {condition}) select {q.fields_clause} FROM {q.alias} WHERE row = 1 ORDER BY {q.order_clause}" + query = f"WITH c as (SELECT {q.fields_clause}, ROW_NUMBER() OVER (PARTITION BY date, state, issue ORDER BY record_type) `row` FROM {q.table} JOIN {subquery} ON {condition}) select {q.fields_clause} FROM {q.alias} WHERE `row` = 1 ORDER BY {q.order_clause}" # send query return execute_query(query, q.params, fields_string, fields_int, fields_float) diff --git a/src/server/endpoints/covidcast.py b/src/server/endpoints/covidcast.py index 
afd575274..eb2dff165 100644 --- a/src/server/endpoints/covidcast.py +++ b/src/server/endpoints/covidcast.py @@ -48,8 +48,8 @@ alias = None JIT_COMPUTE = True -latest_table = "signal_latest_v" -history_table = "signal_history_v" +latest_table = "epimetric_latest_v" +history_table = "epimetric_full_v" def parse_source_signal_pairs() -> List[SourceSignalPair]: ds = request.values.get("data_source") @@ -96,9 +96,7 @@ def parse_time_pairs() -> List[TimePair]: raise ValidationFailedException("missing parameter: time or (time_type and time_values)") time_pairs = parse_time_arg() - # TODO: Put a bound on the number of time_values? - # if sum(len(time_pair.time_values) for time_pair in time_pairs if not isinstance(time_pair.time_values, bool)) > 30: - # raise ValidationFailedException("parameter value exceed: too many time pairs requested, consider using a timerange instead YYYYMMDD-YYYYMMDD") + # TODO: Put a bound on the number of time_values? (see above) return time_pairs @@ -124,27 +122,6 @@ def _handle_lag_issues_as_of(q: QueryBuilder, issues: Optional[List[Union[Tuple[ pass -def guess_index_to_use(time: List[TimePair], geo: List[GeoPair], issues: Optional[List[Union[Tuple[int, int], int]]] = None, lag: Optional[int] = None, as_of: Optional[int] = None) -> Optional[str]: - #TODO: remove this method? - return None - - time_values_to_retrieve = sum((t.count() for t in time)) - geo_values_to_retrieve = sum((g.count() for g in geo)) - - if geo_values_to_retrieve > 5 or time_values_to_retrieve < 30: - # no optimization known - return None - - if issues: - return "by_issue" - elif lag is not None: - return "by_lag" - elif as_of is None: - # latest - return "by_issue" - return None - - def parse_transform_args(): # The length of the window to smooth over. smoother_window_length = request.values.get("smoother_window_length", 7) @@ -200,6 +177,7 @@ def alias_row(row): fields_int = ["time_value", "direction", "issue", "lag", "missing_value", "missing_stderr", "missing_sample_size"] fields_float = ["value", "stderr", "sample_size"] + # TODO: JIT computations don't support time_value = *; there may be a clever way to implement this. use_server_side_compute = not any((issues, lag, is_time_type_week, is_time_value_true)) and JIT_COMPUTE and not jit_bypass if use_server_side_compute: transform_args = parse_transform_args() @@ -228,8 +206,6 @@ def gen_transform(rows): q.where_geo_pairs("geo_type", "geo_value", geo_pairs) q.where_time_pairs("time_type", "time_value", time_pairs) - q.index = guess_index_to_use(time_pairs, geo_pairs, issues, lag, as_of) - _handle_lag_issues_as_of(q, issues, lag, as_of) p = create_printer() @@ -423,7 +399,7 @@ def handle_correlation(): if lag is None: lag = 28 - # build query -- TODO: should this be using most recent issue but also specifying a lag? 
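The backticks added around the `row` alias in the ROW_NUMBER() queries (covid_hosp_state_timeseries above, covidcast_nowcast further down) are presumably required because ROW is a reserved word in MySQL 8, which the development database image now runs; the window-function pattern itself keeps one row per partition. A reduced sketch of that pattern, with a hypothetical `demo` table and hard-coded columns standing in for the QueryBuilder pieces:

# keep only the most recent issue for every (date, state) in a hypothetical `demo` table
LATEST_PER_GROUP = """
    WITH c AS (
        SELECT `date`, `state`, `issue`, `value`,
               ROW_NUMBER() OVER (PARTITION BY `date`, `state` ORDER BY `issue` DESC) `row`
        FROM `demo`
    )
    SELECT `date`, `state`, `issue`, `value` FROM c WHERE `row` = 1
"""
# cursor.execute(LATEST_PER_GROUP)  # assuming a DB-API cursor, as elsewhere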
+ # `lag` above is used in post-processing, not in the database query, so we can use latest here q = QueryBuilder(latest_table, "t") fields_string = ["geo_type", "geo_value", "source", "signal"] diff --git a/src/server/endpoints/covidcast_nowcast.py b/src/server/endpoints/covidcast_nowcast.py index 9b2e79848..9a773f572 100644 --- a/src/server/endpoints/covidcast_nowcast.py +++ b/src/server/endpoints/covidcast_nowcast.py @@ -94,7 +94,7 @@ def handle(): query = f"SELECT {fields} FROM {table} {subquery} WHERE {conditions} AND ({condition_version}) ORDER BY {order}" else: # fetch most recent issue fast - query = f"WITH t as (SELECT {fields}, ROW_NUMBER() OVER (PARTITION BY t.`time_type`, t.`time_value`, t.`source`, t.`signal`, t.`geo_type`, t.`geo_value` ORDER BY t.`issue` DESC) row FROM {table} {subquery} WHERE {conditions}) SELECT {fields} FROM t where row = 1 ORDER BY {order}" + query = f"WITH t as (SELECT {fields}, ROW_NUMBER() OVER (PARTITION BY t.`time_type`, t.`time_value`, t.`source`, t.`signal`, t.`geo_type`, t.`geo_value` ORDER BY t.`issue` DESC) `row` FROM {table} {subquery} WHERE {conditions}) SELECT {fields} FROM t where `row` = 1 ORDER BY {order}" fields_string = ["geo_value", "signal"] fields_int = ["time_value", "issue", "lag"] diff --git a/src/server/endpoints/covidcast_utils/model.py b/src/server/endpoints/covidcast_utils/model.py index 2fa1d90e3..4b5977f28 100644 --- a/src/server/endpoints/covidcast_utils/model.py +++ b/src/server/endpoints/covidcast_utils/model.py @@ -23,6 +23,7 @@ SMOOTH: Callable = lambda rows, **kwargs: generate_smoothed_rows(rows, **kwargs) DIFF_SMOOTH: Callable = lambda rows, **kwargs: generate_smoothed_rows(generate_diffed_rows(rows, **kwargs), **kwargs) +SignalTransforms = Dict[SourceSignalPair, SourceSignalPair] class HighValuesAre(str, Enum): bad = "bad" @@ -366,30 +367,30 @@ def _reindex_iterable(iterator: Iterator[Dict], time_pairs: Optional[List[TimePa # Non-trivial operations otherwise. min_time_value = first_item.get("time_value") - for day in get_day_range(time_pairs): - if day < min_time_value: + for expected_time_value in get_day_range(time_pairs): + if expected_time_value < min_time_value: continue try: - # This will stay the same until the iterator is iterated. - # When _iterator is exhausted, it will raise StopIteration, ending this loop. + # This will stay the same until the peeked element is consumed. new_item = _iterator.peek() - if day == new_item.get("time_value"): - # Get the value we just peeked. - yield next(_iterator) - else: - # Return a default row instead. - # Copy to avoid Python by-reference memory issues. - default_item = _default_item.copy() - default_item.update({ - "time_value": day, - "value": fill_value, - "missing_value": Nans.NOT_MISSING if fill_value and not np.isnan(fill_value) else Nans.NOT_APPLICABLE - }) - yield default_item except StopIteration: return + if expected_time_value == new_item.get("time_value"): + # Get the value we just peeked. + yield next(_iterator) + else: + # Return a default row instead. + # Copy to avoid Python by-reference memory issues. 
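The loop being reshuffled here walks the expected day range and, when the next real row does not match the expected day, emits a copy of a default row instead; peek() is what lets it look at the next database row without consuming it. A stripped-down sketch of that reindexing idea, using more_itertools.peekable (which the peek()/next() calls in this function appear to rely on) over plain dicts:

from more_itertools import peekable


def reindex(days, rows, fill_value=None):
    # `rows` yields dicts ordered by "time_value"; emit one dict per expected day,
    # substituting a filler dict where no real row exists.
    it = peekable(rows)
    for day in days:
        try:
            nxt = it.peek()
        except StopIteration:
            return
        if nxt["time_value"] == day:
            yield next(it)
        else:
            yield {"time_value": day, "value": fill_value}


rows = [{"time_value": 1, "value": 10.0}, {"time_value": 3, "value": 30.0}]
assert [r["time_value"] for r in reindex([1, 2, 3], iter(rows))] == [1, 2, 3]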
+ default_item = _default_item.copy() + default_item.update({ + "time_value": expected_time_value, + "value": fill_value, + "missing_value": Nans.NOT_MISSING if fill_value and not np.isnan(fill_value) else Nans.NOT_APPLICABLE + }) + yield default_item + def _get_base_signal_transform(signal: Union[DataSignal, Tuple[str, str]], data_signals_by_key: Dict[Tuple[str, str], DataSignal] = data_signals_by_key) -> Callable: """Given a DataSignal, return the transformation that needs to be applied to its base signal to derive the signal.""" @@ -543,7 +544,7 @@ def get_day_range(time_pairs: List[TimePair]) -> Iterator[int]: def _generate_transformed_rows( - parsed_rows: Iterator[Dict], time_pairs: Optional[List[TimePair]] = None, transform_dict: Optional[Dict[Tuple[str, str], List[Tuple[str, str]]]] = None, transform_args: Optional[Dict] = None, group_keyfunc: Optional[Callable] = None, data_signals_by_key: Dict[Tuple[str, str], DataSignal] = data_signals_by_key, + parsed_rows: Iterator[Dict], time_pairs: Optional[List[TimePair]] = None, transform_dict: Optional[SignalTransforms] = None, transform_args: Optional[Dict] = None, group_keyfunc: Optional[Callable] = None, data_signals_by_key: Dict[Tuple[str, str], DataSignal] = data_signals_by_key, ) -> Iterator[Dict]: """Applies time-series transformations to streamed rows from a database. @@ -553,7 +554,7 @@ def _generate_transformed_rows( time_pairs: Optional[List[TimePair]], default None A list of TimePairs, which can be used to create a continguous time index for time-series operations. The min and max dates in the TimePairs list is used. - transform_dict: Optional[Dict[Tuple[str, str], List[Tuple[str, str]]]], default None + transform_dict: Optional[SignalTransforms], default None A dictionary mapping base sources to a list of their derived signals that the user wishes to query. For example, transform_dict may be {("jhu-csse", "confirmed_cumulative_num): [("jhu-csse", "confirmed_incidence_num"), ("jhu-csse", "confirmed_7dav_incidence_num")]}. transform_args: Optional[Dict], default None @@ -576,54 +577,52 @@ def _generate_transformed_rows( group_keyfunc = lambda row: (row["geo_type"], row["geo_value"], row["source"], row["signal"]) for key, group in groupby(parsed_rows, group_keyfunc): - group = [x for x in group] _, _, source_name, signal_name = key - # Extract the list of derived signals. - derived_signals: List[Tuple[str, str]] = transform_dict.get((source_name, signal_name), [(source_name, signal_name)]) + # Extract the list of derived signals; if a signal is not in the dictionary, then use the identity map. + derived_signals: SourceSignalPair = transform_dict.get(SourceSignalPair(source_name, [signal_name]), SourceSignalPair(source_name, [signal_name])) # Create a list of source-signal pairs along with the transformation required for the signal. - source_signal_pairs_and_group_transforms: List[Tuple[Tuple[str, str], Callable]] = [((derived_source, derived_signal), _get_base_signal_transform((derived_source, derived_signal), data_signals_by_key)) for (derived_source, derived_signal) in derived_signals] + signal_and_group_transforms: List[Tuple[Tuple[str, str], Callable]] = [(signal, _get_base_signal_transform((source_name, signal), data_signals_by_key)) for signal in derived_signals.signal] # Put the current time series on a contiguous time index. group_contiguous_time = _reindex_iterable(group, time_pairs, fill_value=transform_args.get("pad_fill_value")) # Create copies of the iterable, with smart memory usage. 
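The lines that follow fan each per-group stream out with itertools.tee, apply one transform per requested derived signal to each copy, and interleave the transformed streams so that only a small window of the group is ever materialized. A tiny self-contained sketch of that fan-out-and-interleave pattern; the doubling and tripling transforms are placeholders for DIFF and SMOOTH:

from itertools import tee

from more_itertools import interleave_longest

transforms = [lambda xs: (x * 2 for x in xs), lambda xs: (x * 3 for x in xs)]
group = iter([1, 2, 3])

copies = tee(group, len(transforms))
streams = (t(rows) for t, rows in zip(transforms, copies))
assert list(interleave_longest(*streams)) == [2, 3, 4, 6, 6, 9]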
- group_iter_copies: Iterator[Iterator[Dict]] = tee(group_contiguous_time, len(source_signal_pairs_and_group_transforms)) + group_iter_copies: Iterator[Iterator[Dict]] = tee(group_contiguous_time, len(signal_and_group_transforms)) # Create a list of transformed group iterables, remembering their derived name as needed. - transformed_group_rows: Iterator[Iterator[Dict]] = (zip(transform(rows, **transform_args), repeat(key)) for (key, transform), rows in zip(source_signal_pairs_and_group_transforms, group_iter_copies)) + transformed_group_rows: Iterator[Iterator[Dict]] = (zip(transform(rows, **transform_args), repeat(derived_signal)) for (derived_signal, transform), rows in zip(signal_and_group_transforms, group_iter_copies)) # Traverse through the transformed iterables in an interleaved fashion, which makes sure that only a small window # of the original iterable (group) is stored in memory. - for row, (_, derived_signal) in interleave_longest(*transformed_group_rows): + for row, derived_signal in interleave_longest(*transformed_group_rows): row["signal"] = derived_signal yield row - def get_basename_signal_and_jit_generator(source_signal_pairs: List[SourceSignalPair], transform_args: Optional[Dict[str, Union[str, int]]] = None, data_sources_by_id: Dict[str, DataSource] = data_sources_by_id, data_signals_by_key: Dict[Tuple[str, str], DataSignal] = data_signals_by_key) -> Tuple[List[SourceSignalPair], Generator]: """From a list of SourceSignalPairs, return the base signals required to derive them and a transformation function to take a stream of the base signals and return the transformed signals. Example: SourceSignalPair("src", signal=["sig_base", "sig_smoothed"]) would return SourceSignalPair("src", signal=["sig_base"]) and a transformation function - that will take the returned database query for "sig_base" and return both the base time series and the smoothed time series. + that will take the returned database query for "sig_base" and return both the base time series and the smoothed time series. transform_dict in this case + would be {("src", "sig_base"): [("src", "sig_base"), ("src", "sig_smooth")]}. """ source_signal_pairs = _resolve_bool_source_signals(source_signal_pairs, data_sources_by_id) base_signal_pairs: List[SourceSignalPair] = [] - transform_dict: Dict[Tuple[str, str], List[Tuple[str, str]]] = dict() + transform_dict: SignalTransforms = dict() for pair in source_signal_pairs: + # Should only occur when the SourceSignalPair was unrecognized by _resolve_bool_source_signals. Useful for testing with fake signal names. 
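For reference, the transform_dict built in the loop that continues below maps each base pair to a pair that accumulates every signal requested on top of it, so the docstring's example ends up roughly as {SourceSignalPair("src", ["sig_base"]): SourceSignalPair("src", ["sig_base", "sig_smoothed"])}; the unit tests later in this diff construct the same shape by hand. A small sketch of that accumulation step, importable in the dev environment this diff sets up:

from delphi.epidata.server._params import SourceSignalPair

transform_dict = {}
base = SourceSignalPair(source="src", signal=["sig_base"])
# the identity signal and a derived signal both land under the same base key
transform_dict.setdefault(base, SourceSignalPair(source="src", signal=[])).add_signal("sig_base")
transform_dict.setdefault(base, SourceSignalPair(source="src", signal=[])).add_signal("sig_smoothed")
assert transform_dict[base].signal == ["sig_base", "sig_smoothed"]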
if isinstance(pair.signal, bool): base_signal_pairs.append(pair) continue - source_name = pair.source - signal_names = pair.signal signals = [] - for signal_name in signal_names: - signal = data_signals_by_key.get((source_name, signal_name)) + for signal_name in pair.signal: + signal = data_signals_by_key.get((pair.source, signal_name)) if not signal or not signal.compute_from_base: + transform_dict.setdefault(SourceSignalPair(source=pair.source, signal=[signal_name]), SourceSignalPair(source=pair.source, signal=[])).add_signal(signal_name) signals.append(signal_name) - transform_dict.setdefault((source_name, signal_name), []).append((source_name, signal_name)) else: + transform_dict.setdefault(SourceSignalPair(source=pair.source, signal=[signal.signal_basename]), SourceSignalPair(source=pair.source, signal=[])).add_signal(signal_name) signals.append(signal.signal_basename) - transform_dict.setdefault((source_name, signal.signal_basename), []).append((source_name, signal_name)) base_signal_pairs.append(SourceSignalPair(pair.source, signals)) row_transform_generator = partial(_generate_transformed_rows, transform_dict=transform_dict, transform_args=transform_args, data_signals_by_key=data_signals_by_key) diff --git a/tests/acquisition/covidcast/test_covidcast_meta_cache_updater.py b/tests/acquisition/covidcast/test_covidcast_meta_cache_updater.py index 40a242e22..90c9cb5ad 100644 --- a/tests/acquisition/covidcast/test_covidcast_meta_cache_updater.py +++ b/tests/acquisition/covidcast/test_covidcast_meta_cache_updater.py @@ -2,73 +2,62 @@ # standard library import argparse + import unittest from unittest.mock import MagicMock # third party -import pandas -from delphi.epidata.acquisition.covidcast.covidcast_meta_cache_updater import get_argument_parser, \ - main -# py3tester coverage target -__test_target__ = ( - 'delphi.epidata.acquisition.covidcast.' 
- 'covidcast_meta_cache_updater' -) +from delphi.epidata.acquisition.covidcast.covidcast_meta_cache_updater import get_argument_parser, main class UnitTests(unittest.TestCase): - """Basic unit tests.""" - - def test_get_argument_parser(self): - """Return a parser for command-line arguments.""" - - self.assertIsInstance(get_argument_parser(), argparse.ArgumentParser) + """Basic unit tests.""" - def test_main_successful(self): - """Run the main program successfully.""" + def test_get_argument_parser(self): + """Return a parser for command-line arguments.""" + self.assertIsInstance(get_argument_parser(), argparse.ArgumentParser) - api_response = { - 'result': 1, - 'message': 'yes', - 'epidata': [{'foo': 'bar'}], - } + def test_main_successful(self): + """Run the main program successfully.""" - args = MagicMock(log_file="log") - mock_epidata_impl = MagicMock() - mock_epidata_impl.covidcast_meta.return_value = api_response - mock_database = MagicMock() - mock_database.compute_covidcast_meta.return_value=api_response['epidata'] - fake_database_impl = lambda: mock_database + api_response = { + "result": 1, + "message": "yes", + "epidata": [{"foo": "bar"}], + } - main( - args, - epidata_impl=mock_epidata_impl, - database_impl=fake_database_impl) + args = MagicMock(log_file="log") + mock_epidata_impl = MagicMock() + mock_epidata_impl.covidcast_meta.return_value = api_response + mock_database = MagicMock() + mock_database.compute_covidcast_meta.return_value = api_response["epidata"] + fake_database_impl = lambda: mock_database - self.assertTrue(mock_database.connect.called) + main(args, epidata_impl=mock_epidata_impl, database_impl=fake_database_impl) - self.assertTrue(mock_database.update_covidcast_meta_cache.called) - actual_args = mock_database.update_covidcast_meta_cache.call_args[0] - expected_args = (api_response['epidata'],) - self.assertEqual(actual_args, expected_args) + self.assertTrue(mock_database.connect.called) - self.assertTrue(mock_database.disconnect.called) - self.assertTrue(mock_database.disconnect.call_args[0][0]) + self.assertTrue(mock_database.update_covidcast_meta_cache.called) + actual_args = mock_database.update_covidcast_meta_cache.call_args[0] + expected_args = (api_response["epidata"],) + self.assertEqual(actual_args, expected_args) - def test_main_failure(self): - """Run the main program with a query failure.""" + self.assertTrue(mock_database.disconnect.called) + self.assertTrue(mock_database.disconnect.call_args[0][0]) - api_response = { - 'result': -123, - 'message': 'no', - } + def test_main_failure(self): + """Run the main program with a query failure.""" + api_response = { + "result": -123, + "message": "no", + } - args = MagicMock(log_file="log") - mock_database = MagicMock() - mock_database.compute_covidcast_meta.return_value = list() - fake_database_impl = lambda: mock_database + args = MagicMock(log_file="log") + mock_database = MagicMock() + mock_database.compute_covidcast_meta.return_value = list() + fake_database_impl = lambda: mock_database - main(args, epidata_impl=None, database_impl=fake_database_impl) + main(args, epidata_impl=None, database_impl=fake_database_impl) - self.assertTrue(mock_database.compute_covidcast_meta.called) + self.assertTrue(mock_database.compute_covidcast_meta.called) diff --git a/tests/acquisition/covidcast/test_csv_to_database.py b/tests/acquisition/covidcast/test_csv_to_database.py index 2444c5262..0b91815fb 100644 --- a/tests/acquisition/covidcast/test_csv_to_database.py +++ b/tests/acquisition/covidcast/test_csv_to_database.py 
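In the test_csv_to_database changes that follow, insert_or_update_bulk is replaced with MagicMock(wraps=iter_len), so the mock keeps recording its calls while delegating to iter_len for the return value (the number of rows passed in), which is how the test can check modified_row_count without a real database. A minimal sketch of that wraps= behavior:

from unittest.mock import MagicMock


def iter_len(l):
    return len(list(l))


mock = MagicMock(wraps=iter_len)
assert mock(iter(["a", "b", "c"])) == 3  # delegated to the wrapped callable
assert mock.call_count == 1              # while still being recorded like a mock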
@@ -2,6 +2,7 @@ # standard library import argparse +from typing import Iterable import unittest from unittest.mock import MagicMock @@ -27,9 +28,7 @@ def _path_details(self): # a file with a data error ('path/b.csv', ('src_b', 'sig_b', 'week', 'msa', 202016, 202017, 1)), # emulate a file that's named incorrectly - ('path/c.csv', None), - # another good file w/ wip - ('path/d.csv', ('src_d', 'wip_sig_d', 'week', 'msa', 202016, 202017, 1)), + ('path/c.csv', None) ] def test_collect_files(self): @@ -65,15 +64,16 @@ def load_csv_impl(path, *args): yield make_row('b1') yield None yield make_row('b3') - elif path == 'path/d.csv': - yield make_row('d1') else: # fail the test for any other path raise Exception('unexpected path') + def iter_len(l: Iterable) -> int: + return len(list(l)) + data_dir = 'data_dir' mock_database = MagicMock() - mock_database.insert_or_update_bulk.return_value = 2 + mock_database.insert_or_update_bulk = MagicMock(wraps=iter_len) mock_csv_importer = MagicMock() mock_csv_importer.load_csv = load_csv_impl mock_file_archiver = MagicMock() @@ -87,9 +87,9 @@ def load_csv_impl(path, *args): mock_logger, csv_importer_impl=mock_csv_importer) - self.assertEqual(modified_row_count, 4) + self.assertEqual(modified_row_count, 3) # verify that appropriate rows were added to the database - self.assertEqual(mock_database.insert_or_update_bulk.call_count, 2) + self.assertEqual(mock_database.insert_or_update_bulk.call_count, 1) call_args_list = mock_database.insert_or_update_bulk.call_args_list actual_args = [[(a.source, a.signal, a.time_type, a.geo_type, a.time_value, a.geo_value, a.value, a.stderr, a.sample_size, a.issue, a.lag) @@ -97,20 +97,18 @@ def load_csv_impl(path, *args): expected_args = [ [('src_a', 'sig_a', 'day', 'hrr', 20200419, 'a1', 'a1', 'a1', 'a1', 20200420, 1), ('src_a', 'sig_a', 'day', 'hrr', 20200419, 'a2', 'a2', 'a2', 'a2', 20200420, 1), - ('src_a', 'sig_a', 'day', 'hrr', 20200419, 'a3', 'a3', 'a3', 'a3', 20200420, 1)], - [('src_d', 'wip_sig_d', 'week', 'msa', 202016, 'd1', 'd1', 'd1', 'd1', 202017, 1)] + ('src_a', 'sig_a', 'day', 'hrr', 20200419, 'a3', 'a3', 'a3', 'a3', 20200420, 1)] ] self.assertEqual(actual_args, expected_args) # verify that two files were successful (a, d) and two failed (b, c) - self.assertEqual(mock_file_archiver.archive_file.call_count, 4) + self.assertEqual(mock_file_archiver.archive_file.call_count, 3) call_args_list = mock_file_archiver.archive_file.call_args_list actual_args = [args for (args, kwargs) in call_args_list] expected_args = [ ('path', 'data_dir/archive/successful/src_a', 'a.csv', True), ('path', 'data_dir/archive/failed/src_b', 'b.csv', False), ('path', 'data_dir/archive/failed/unknown', 'c.csv', False), - ('path', 'data_dir/archive/successful/src_d', 'd.csv', True), ] self.assertEqual(actual_args, expected_args) diff --git a/tests/acquisition/covidcast/test_database.py b/tests/acquisition/covidcast/test_database.py index 71fd429b9..c75e1bd8e 100644 --- a/tests/acquisition/covidcast/test_database.py +++ b/tests/acquisition/covidcast/test_database.py @@ -51,55 +51,6 @@ def test_disconnect_with_commit(self): self.assertTrue(connection.commit.called) self.assertTrue(connection.close.called) - def test_count_all_rows_query(self): - """Query to count all rows looks sensible. - - NOTE: Actual behavior is tested by integration test. 
- """ - - mock_connector = MagicMock() - database = Database() - database.connect(connector_impl=mock_connector) - connection = mock_connector.connect() - cursor = connection.cursor() - cursor.__iter__.return_value = [(123,)] - - num = database.count_all_rows() - - self.assertEqual(num, 123) - self.assertTrue(cursor.execute.called) - - sql = cursor.execute.call_args[0][0].lower() - self.assertIn('select count(1)', sql) - self.assertIn('from `signal_', sql) # note that this table name is incomplete - - def test_update_covidcast_meta_cache_query(self): - """Query to update the metadata cache looks sensible. - - NOTE: Actual behavior is tested by integration test. - """ - - args = ('epidata_json_str',) - mock_connector = MagicMock() - database = Database() - database.connect(connector_impl=mock_connector) - - database.update_covidcast_meta_cache(*args) - - connection = mock_connector.connect() - cursor = connection.cursor() - self.assertTrue(cursor.execute.called) - - sql, args = cursor.execute.call_args[0] - expected_args = ('"epidata_json_str"',) - self.assertEqual(args, expected_args) - - sql = sql.lower() - self.assertIn('update', sql) - self.assertIn('`covidcast_meta_cache`', sql) - self.assertIn('timestamp', sql) - self.assertIn('epidata', sql) - def test_insert_or_update_batch_exception_reraised(self): """Test that an exception is reraised""" mock_connector = MagicMock() @@ -116,6 +67,7 @@ def test_insert_or_update_batch_row_count_returned(self): """Test that the row count is returned""" mock_connector = MagicMock() database = Database() + database.count_all_load_rows = lambda:0 # simulate an empty load table database.connect(connector_impl=mock_connector) connection = mock_connector.connect() cursor = connection.cursor() @@ -129,6 +81,7 @@ def test_insert_or_update_batch_none_returned(self): """Test that None is returned when row count cannot be returned""" mock_connector = MagicMock() database = Database() + database.count_all_load_rows = lambda:0 # simulate an empty load table database.connect(connector_impl=mock_connector) connection = mock_connector.connect() cursor = connection.cursor() diff --git a/tests/acquisition/covidcast/test_database_meta.py b/tests/acquisition/covidcast/test_database_meta.py new file mode 100644 index 000000000..78566d86a --- /dev/null +++ b/tests/acquisition/covidcast/test_database_meta.py @@ -0,0 +1,34 @@ +import unittest +from unittest.mock import MagicMock + +from delphi.epidata.acquisition.covidcast.database_meta import DatabaseMeta + +class UnitTests(unittest.TestCase): + """Basic unit tests.""" + + def test_update_covidcast_meta_cache_query(self): + """Query to update the metadata cache looks sensible. + + NOTE: Actual behavior is tested by integration test. 
+ """ + + args = ('epidata_json_str',) + mock_connector = MagicMock() + database = DatabaseMeta() + database.connect(connector_impl=mock_connector) + + database.update_covidcast_meta_cache(*args) + + connection = mock_connector.connect() + cursor = connection.cursor() + self.assertTrue(cursor.execute.called) + + sql, args = cursor.execute.call_args[0] + expected_args = ('"epidata_json_str"',) + self.assertEqual(args, expected_args) + + sql = sql.lower() + self.assertIn('update', sql) + self.assertIn('`covidcast_meta_cache`', sql) + self.assertIn('timestamp', sql) + self.assertIn('epidata', sql) diff --git a/tests/server/endpoints/covidcast_utils/test_model.py b/tests/server/endpoints/covidcast_utils/test_model.py index 26994e005..667b4bd4a 100644 --- a/tests/server/endpoints/covidcast_utils/test_model.py +++ b/tests/server/endpoints/covidcast_utils/test_model.py @@ -1,35 +1,35 @@ -from dataclasses import fields -from numbers import Number -from typing import Iterable, List, Optional, Union, get_args, get_origin import unittest from itertools import chain +from numbers import Number +from typing import Iterable, List, Optional + +import pandas as pd from more_itertools import interleave_longest, windowed -from pandas import DataFrame, date_range from pandas.testing import assert_frame_equal -from delphi_utils.nancodes import Nans - -from delphi.epidata.acquisition.covidcast.covidcast_row import CovidcastRow, CovidcastRows, set_df_dtypes +from delphi.epidata.acquisition.covidcast.covidcast_row import CovidcastRows from delphi.epidata.server._params import SourceSignalPair, TimePair from delphi.epidata.server.endpoints.covidcast_utils.model import ( - IDENTITY, DIFF, - SMOOTH, DIFF_SMOOTH, - DataSource, + IDENTITY, + SMOOTH, DataSignal, - _resolve_bool_source_signals, - _reindex_iterable, + DataSource, + _generate_transformed_rows, _get_base_signal_transform, - get_transform_types, - get_pad_length, - pad_time_pairs, - get_day_range, _iterate_over_ints_and_ranges, - _generate_transformed_rows, + _reindex_iterable, + _resolve_bool_source_signals, get_basename_signal_and_jit_generator, + get_day_range, + get_pad_length, + get_transform_types, + pad_time_pairs, ) +from delphi_utils.nancodes import Nans +# fmt: off DATA_SIGNALS_BY_KEY = { ("src", "sig_diff"): DataSignal( source="src", @@ -73,10 +73,28 @@ compute_from_base=True, ), ("src", "sig_base"): DataSignal( - source="src", signal="sig_base", signal_basename="sig_base", name="src", active=True, short_description="", description="", time_label="", value_label="", is_cumulative=True, + source="src", + signal="sig_base", + signal_basename="sig_base", + name="src", + active=True, + short_description="", + description="", + time_label="", + value_label="", + is_cumulative=True, ), ("src2", "sig_base"): DataSignal( - source="src2", signal="sig_base", signal_basename="sig_base", name="sig_base", active=True, short_description="", description="", time_label="", value_label="", is_cumulative=True, + source="src2", + signal="sig_base", + signal_basename="sig_base", + name="sig_base", + active=True, + short_description="", + description="", + time_label="", + value_label="", + is_cumulative=True, ), ("src2", "sig_diff_smooth"): DataSignal( source="src2", @@ -112,16 +130,19 @@ signals=[DATA_SIGNALS_BY_KEY[key] for key in DATA_SIGNALS_BY_KEY if key[0] == "src2"], ), } +# fmt: on def _diff_rows(rows: Iterable[Number]) -> List[Number]: return [round(float(y - x), 8) if not (x is None or y is None) else None for x, y in windowed(rows, 2)] + def 
_smooth_rows(rows: Iterable[Number], window_length: int = 7, kernel: Optional[List[Number]] = None): if not kernel: - kernel = [1. / window_length] * window_length + kernel = [1.0 / window_length] * window_length return [round(sum(x * y for x, y in zip(window, kernel)), 8) if None not in window else None for window in windowed(rows, len(kernel))] + def _reindex_windowed(lst: list, window_length: int) -> list: return [max(window) if None not in window else None for window in windowed(lst, window_length)] @@ -141,37 +162,35 @@ def test__reindex_iterable(self): time_pairs = [(20210503, 20210508)] assert list(_reindex_iterable([], time_pairs)) == [] - data = CovidcastRows.from_args(time_value=date_range("2021-05-03", "2021-05-08").to_list()).api_row_df + data = CovidcastRows.from_args(time_value=pd.date_range("2021-05-03", "2021-05-08").to_list()).api_row_df for time_pairs in [[TimePair("day", [(20210503, 20210508)])], [], None]: with self.subTest(f"Identity operations: {time_pairs}"): - df = CovidcastRows.from_records(_reindex_iterable(data.to_dict(orient='records'), time_pairs)).api_row_df + df = CovidcastRows.from_records(_reindex_iterable(data.to_dict(orient="records"), time_pairs)).api_row_df assert_frame_equal(df, data) - data = CovidcastRows.from_args( - time_value=date_range("2021-05-03", "2021-05-08").to_list() + date_range("2021-05-11", "2021-05-14").to_list() - ).api_row_df + data = CovidcastRows.from_args(time_value=pd.date_range("2021-05-03", "2021-05-08").to_list() + pd.date_range("2021-05-11", "2021-05-14").to_list()).api_row_df with self.subTest("Non-trivial operations"): time_pairs = [TimePair("day", [(20210501, 20210513)])] - df = CovidcastRows.from_records(_reindex_iterable(data.to_dict(orient='records'), time_pairs)).api_row_df + df = CovidcastRows.from_records(_reindex_iterable(data.to_dict(orient="records"), time_pairs)).api_row_df expected_df = CovidcastRows.from_args( - time_value=date_range("2021-05-03", "2021-05-13"), - issue=date_range("2021-05-03", "2021-05-08").to_list() + [None] * 2 + date_range("2021-05-11", "2021-05-13").to_list(), + time_value=pd.date_range("2021-05-03", "2021-05-13"), + issue=pd.date_range("2021-05-03", "2021-05-08").to_list() + [None] * 2 + pd.date_range("2021-05-11", "2021-05-13").to_list(), lag=[0] * 6 + [None] * 2 + [0] * 3, - value=chain([10.] * 6, [None] * 2, [10.] * 3), - stderr=chain([10.] * 6, [None] * 2, [10.] * 3), - sample_size=chain([10.] * 6, [None] * 2, [10.] * 3) + value=chain([10.0] * 6, [None] * 2, [10.0] * 3), + stderr=chain([10.0] * 6, [None] * 2, [10.0] * 3), + sample_size=chain([10.0] * 6, [None] * 2, [10.0] * 3), ).api_row_df assert_frame_equal(df, expected_df) - df = CovidcastRows.from_records(_reindex_iterable(data.to_dict(orient='records'), time_pairs, fill_value=2.)).api_row_df + df = CovidcastRows.from_records(_reindex_iterable(data.to_dict(orient="records"), time_pairs, fill_value=2.0)).api_row_df expected_df = CovidcastRows.from_args( - time_value=date_range("2021-05-03", "2021-05-13"), - issue=date_range("2021-05-03", "2021-05-08").to_list() + [None] * 2 + date_range("2021-05-11", "2021-05-13").to_list(), + time_value=pd.date_range("2021-05-03", "2021-05-13"), + issue=pd.date_range("2021-05-03", "2021-05-08").to_list() + [None] * 2 + pd.date_range("2021-05-11", "2021-05-13").to_list(), lag=[0] * 6 + [None] * 2 + [0] * 3, - value=chain([10.] * 6, [2.] * 2, [10.] * 3), - stderr=chain([10.] * 6, [None] * 2, [10.] * 3), - sample_size=chain([10.] * 6, [None] * 2, [10.] 
* 3), + value=chain([10.0] * 6, [2.0] * 2, [10.0] * 3), + stderr=chain([10.0] * 6, [None] * 2, [10.0] * 3), + sample_size=chain([10.0] * 6, [None] * 2, [10.0] * 3), ).api_row_df assert_frame_equal(df, expected_df) @@ -188,14 +207,17 @@ def test_get_transform_types(self): transform_types = get_transform_types(source_signal_pairs, data_sources_by_id=DATA_SOURCES_BY_ID, data_signals_by_key=DATA_SIGNALS_BY_KEY) expected_transform_types = {IDENTITY, DIFF, SMOOTH, DIFF_SMOOTH} assert transform_types == expected_transform_types + source_signal_pairs = [SourceSignalPair(source="src", signal=["sig_diff"])] transform_types = get_transform_types(source_signal_pairs, data_sources_by_id=DATA_SOURCES_BY_ID, data_signals_by_key=DATA_SIGNALS_BY_KEY) expected_transform_types = {DIFF} assert transform_types == expected_transform_types + source_signal_pairs = [SourceSignalPair(source="src", signal=["sig_smooth"])] transform_types = get_transform_types(source_signal_pairs, data_sources_by_id=DATA_SOURCES_BY_ID, data_signals_by_key=DATA_SIGNALS_BY_KEY) expected_transform_types = {SMOOTH} assert transform_types == expected_transform_types + source_signal_pairs = [SourceSignalPair(source="src", signal=["sig_diff_smooth"])] transform_types = get_transform_types(source_signal_pairs, data_sources_by_id=DATA_SOURCES_BY_ID, data_signals_by_key=DATA_SIGNALS_BY_KEY) expected_transform_types = {DIFF_SMOOTH} @@ -205,29 +227,57 @@ def test_get_pad_length(self): source_signal_pairs = [SourceSignalPair(source="src", signal=True)] pad_length = get_pad_length(source_signal_pairs, smoother_window_length=7, data_sources_by_id=DATA_SOURCES_BY_ID, data_signals_by_key=DATA_SIGNALS_BY_KEY) assert pad_length == 7 + source_signal_pairs = [SourceSignalPair(source="src", signal=["sig_diff"])] pad_length = get_pad_length(source_signal_pairs, smoother_window_length=7, data_sources_by_id=DATA_SOURCES_BY_ID, data_signals_by_key=DATA_SIGNALS_BY_KEY) assert pad_length == 1 + source_signal_pairs = [SourceSignalPair(source="src", signal=["sig_smooth"])] pad_length = get_pad_length(source_signal_pairs, smoother_window_length=5, data_sources_by_id=DATA_SOURCES_BY_ID, data_signals_by_key=DATA_SIGNALS_BY_KEY) assert pad_length == 4 + source_signal_pairs = [SourceSignalPair(source="src", signal=["sig_diff_smooth"])] pad_length = get_pad_length(source_signal_pairs, smoother_window_length=10, data_sources_by_id=DATA_SOURCES_BY_ID, data_signals_by_key=DATA_SIGNALS_BY_KEY) assert pad_length == 10 def test_pad_time_pairs(self): - time_pairs = [TimePair("day", [20210817, (20210810, 20210815)]), TimePair("day", True), TimePair("day", [20210816])] - padded_time_pairs = pad_time_pairs(time_pairs, pad_length=7) - expected_padded_time_pairs = [TimePair("day", [20210817, (20210810, 20210815)]), TimePair("day", True), TimePair("day", [20210816]), TimePair("day", [(20210803, 20210810)])] - assert padded_time_pairs == expected_padded_time_pairs - time_pairs = [TimePair("day", [20210817, (20210810, 20210815)]), TimePair("day", True), TimePair("day", [20210816]), TimePair("day", [20210809])] - padded_time_pairs = pad_time_pairs(time_pairs, pad_length=8) - expected_padded_time_pairs = [TimePair("day", [20210817, (20210810, 20210815)]), TimePair("day", True), TimePair("day", [20210816]), TimePair("day", [20210809]), TimePair("day", [(20210801, 20210809)])] - assert padded_time_pairs == expected_padded_time_pairs - time_pairs = [TimePair("day", [20210817, (20210810, 20210815)])] - padded_time_pairs = pad_time_pairs(time_pairs, pad_length=0) - 
expected_padded_time_pairs = [TimePair("day", [20210817, (20210810, 20210815)])] - assert padded_time_pairs == expected_padded_time_pairs + # fmt: off + time_pairs = [ + TimePair("day", [20210817, (20210810, 20210815)]), + TimePair("day", True), + TimePair("day", [20210816]) + ] + expected_padded_time_pairs = [ + TimePair("day", [20210817, (20210810, 20210815)]), + TimePair("day", True), + TimePair("day", [20210816]), + TimePair("day", [(20210803, 20210810)]) + ] + assert pad_time_pairs(time_pairs, pad_length=7) == expected_padded_time_pairs + + time_pairs = [ + TimePair("day", [20210817, (20210810, 20210815)]), + TimePair("day", True), + TimePair("day", [20210816]), + TimePair("day", [20210809]) + ] + expected_padded_time_pairs = [ + TimePair("day", [20210817, (20210810, 20210815)]), + TimePair("day", True), + TimePair("day", [20210816]), + TimePair("day", [20210809]), + TimePair("day", [(20210801, 20210809)]), + ] + assert pad_time_pairs(time_pairs, pad_length=8) == expected_padded_time_pairs + + time_pairs = [ + TimePair("day", [20210817, (20210810, 20210815)]) + ] + expected_padded_time_pairs = [ + TimePair("day", [20210817, (20210810, 20210815)]) + ] + assert pad_time_pairs(time_pairs, pad_length=0) == expected_padded_time_pairs + # fmt: on def test_get_day_range(self): assert list(_iterate_over_ints_and_ranges([0, (5, 8)], use_dates=False)) == [0, 5, 6, 7, 8] @@ -241,14 +291,19 @@ def test_get_day_range(self): assert list(get_day_range([TimePair("day", [(20210801, 20210805)]), TimePair("day", [(20210803, 20210807)])])) == [20210801, 20210802, 20210803, 20210804, 20210805, 20210806, 20210807] def test__generate_transformed_rows(self): + # fmt: off with self.subTest("diffed signal test"): - data = CovidcastRows.from_args(signal=["sig_base"] * 5, time_value=date_range("2021-05-01", "2021-05-05"), value=range(5)).api_row_df - transform_dict = {("src", "sig_base"): [("src", "sig_diff")]} + data = CovidcastRows.from_args( + signal=["sig_base"] * 5, + time_value=pd.date_range("2021-05-01", "2021-05-05"), + value=range(5) + ).api_row_df + transform_dict = {SourceSignalPair("src", ["sig_base"]): SourceSignalPair("src", ["sig_diff"])} df = CovidcastRows.from_records(_generate_transformed_rows(data.to_dict(orient="records"), transform_dict=transform_dict, data_signals_by_key=DATA_SIGNALS_BY_KEY)).api_row_df expected_df = CovidcastRows.from_args( signal=["sig_diff"] * 4, - time_value=date_range("2021-05-02", "2021-05-05"), + time_value=pd.date_range("2021-05-02", "2021-05-05"), value=[1.0] * 4, stderr=[None] * 4, sample_size=[None] * 4, @@ -260,14 +315,18 @@ def test__generate_transformed_rows(self): with self.subTest("smoothed and diffed signals on one base test"): data = CovidcastRows.from_args( - signal=["sig_base"] * 10, time_value=date_range("2021-05-01", "2021-05-10"), value=range(10), stderr=range(10), sample_size=range(10) + signal=["sig_base"] * 10, + time_value=pd.date_range("2021-05-01", "2021-05-10"), + value=range(10), + stderr=range(10), + sample_size=range(10) ).api_row_df - transform_dict = {("src", "sig_base"): [("src", "sig_diff"), ("src", "sig_smooth")]} + transform_dict = {SourceSignalPair("src", ["sig_base"]): SourceSignalPair("src", ["sig_diff", "sig_smooth"])} df = CovidcastRows.from_records(_generate_transformed_rows(data.to_dict(orient="records"), transform_dict=transform_dict, data_signals_by_key=DATA_SIGNALS_BY_KEY)).api_row_df expected_df = CovidcastRows.from_args( signal=interleave_longest(["sig_diff"] * 9, ["sig_smooth"] * 4), - 
time_value=interleave_longest(date_range("2021-05-02", "2021-05-10"), date_range("2021-05-07", "2021-05-10")), + time_value=interleave_longest(pd.date_range("2021-05-02", "2021-05-10"), pd.date_range("2021-05-07", "2021-05-10")), value=interleave_longest(_diff_rows(data.value.to_list()), _smooth_rows(data.value.to_list())), stderr=[None] * 13, sample_size=[None] * 13, @@ -281,18 +340,24 @@ def test__generate_transformed_rows(self): with self.subTest("smoothed and diffed signal on two non-continguous regions"): data = CovidcastRows.from_args( - signal=["sig_base"] * 15, time_value=chain(date_range("2021-05-01", "2021-05-10"), date_range("2021-05-16", "2021-05-20")), value=range(15), stderr=range(15), sample_size=range(15) + signal=["sig_base"] * 15, + time_value=chain(pd.date_range("2021-05-01", "2021-05-10"), pd.date_range("2021-05-16", "2021-05-20")), + value=range(15), + stderr=range(15), + sample_size=range(15), ).api_row_df - transform_dict = {("src", "sig_base"): [("src", "sig_diff"), ("src", "sig_smooth")]} + transform_dict = {SourceSignalPair("src", ["sig_base"]): SourceSignalPair("src", ["sig_diff", "sig_smooth"])} time_pairs = [TimePair("day", [(20210501, 20210520)])] - df = CovidcastRows.from_records(_generate_transformed_rows(data.to_dict(orient="records"), time_pairs=time_pairs, transform_dict=transform_dict, data_signals_by_key=DATA_SIGNALS_BY_KEY)).api_row_df + df = CovidcastRows.from_records( + _generate_transformed_rows(data.to_dict(orient="records"), time_pairs=time_pairs, transform_dict=transform_dict, data_signals_by_key=DATA_SIGNALS_BY_KEY) + ).api_row_df filled_values = data.value.to_list()[:10] + [None] * 5 + data.value.to_list()[10:] - filled_time_values = list(chain(date_range("2021-05-01", "2021-05-10"), [None] * 5, date_range("2021-05-16", "2021-05-20"))) + filled_time_values = list(chain(pd.date_range("2021-05-01", "2021-05-10"), [None] * 5, pd.date_range("2021-05-16", "2021-05-20"))) expected_df = CovidcastRows.from_args( signal=interleave_longest(["sig_diff"] * 19, ["sig_smooth"] * 14), - time_value=interleave_longest(date_range("2021-05-02", "2021-05-20"), date_range("2021-05-07", "2021-05-20")), + time_value=interleave_longest(pd.date_range("2021-05-02", "2021-05-20"), pd.date_range("2021-05-07", "2021-05-20")), value=interleave_longest(_diff_rows(filled_values), _smooth_rows(filled_values)), stderr=[None] * 33, sample_size=[None] * 33, @@ -303,6 +368,7 @@ def test__generate_transformed_rows(self): assert_frame_equal(df.set_index(idx).sort_index(), expected_df.set_index(idx).sort_index()) # Test order. 
             assert_frame_equal(df, expected_df)
+        # fmt: on

     def test_get_basename_signals(self):
         with self.subTest("none to transform"):
@@ -336,12 +402,13 @@ def test_get_basename_signals(self):
             assert basename_pairs == expected_basename_pairs

         with self.subTest("test base, diff, smooth"):
+            # fmt: off
             data = CovidcastRows.from_args(
                 signal=["sig_base"] * 20 + ["sig_other"] * 5,
-                time_value=chain(date_range("2021-05-01", "2021-05-10"), date_range("2021-05-21", "2021-05-30"), date_range("2021-05-01", "2021-05-05")),
+                time_value=chain(pd.date_range("2021-05-01", "2021-05-10"), pd.date_range("2021-05-21", "2021-05-30"), pd.date_range("2021-05-01", "2021-05-05")),
                 value=chain(range(20), range(5)),
                 stderr=chain(range(20), range(5)),
-                sample_size=chain(range(20), range(5))
+                sample_size=chain(range(20), range(5)),
             ).api_row_df
             source_signal_pairs = [SourceSignalPair("src", ["sig_base", "sig_diff", "sig_other", "sig_smooth"])]
             _, row_transform_generator = get_basename_signal_and_jit_generator(source_signal_pairs, data_sources_by_id=DATA_SOURCES_BY_ID, data_signals_by_key=DATA_SIGNALS_BY_KEY)
@@ -349,15 +416,20 @@ def test_get_basename_signals(self):
             df = CovidcastRows.from_records(row_transform_generator(data.to_dict(orient="records"), time_pairs=time_pairs)).api_row_df

             filled_values = list(chain(range(10), [None] * 10, range(10, 20)))
-            filled_time_values = list(chain(date_range("2021-05-01", "2021-05-10"), [None] * 10, date_range("2021-05-21", "2021-05-30")))
+            filled_time_values = list(chain(pd.date_range("2021-05-01", "2021-05-10"), [None] * 10, pd.date_range("2021-05-21", "2021-05-30")))
             expected_df = CovidcastRows.from_args(
                 signal=["sig_base"] * 30 + ["sig_diff"] * 29 + ["sig_other"] * 5 + ["sig_smooth"] * 24,
-                time_value=chain(date_range("2021-05-01", "2021-05-30"), date_range("2021-05-02", "2021-05-30"), date_range("2021-05-01", "2021-05-05"), date_range("2021-05-07", "2021-05-30")),
+                time_value=chain(
+                    pd.date_range("2021-05-01", "2021-05-30"),
+                    pd.date_range("2021-05-02", "2021-05-30"),
+                    pd.date_range("2021-05-01", "2021-05-05"),
+                    pd.date_range("2021-05-07", "2021-05-30")
+                ),
                 value=chain(
-                    filled_values, 
-                    _diff_rows(filled_values), 
-                    range(5), 
+                    filled_values,
+                    _diff_rows(filled_values),
+                    range(5),
                     _smooth_rows(filled_values)
                 ),
                 stderr=chain(
@@ -372,25 +444,22 @@ def test_get_basename_signals(self):
                     range(5),
                     chain([None] * 24),
                 ),
-                issue=chain(
-                    filled_time_values,
-                    _reindex_windowed(filled_time_values, 2),
-                    date_range("2021-05-01", "2021-05-05"),
-                    _reindex_windowed(filled_time_values, 7)
-                ),
+                issue=chain(filled_time_values, _reindex_windowed(filled_time_values, 2), pd.date_range("2021-05-01", "2021-05-05"), _reindex_windowed(filled_time_values, 7)),
             ).api_row_df
+            # fmt: on
             # Test no order.
idx = ["source", "signal", "time_value"] assert_frame_equal(df.set_index(idx).sort_index(), expected_df.set_index(idx).sort_index()) with self.subTest("test base, diff, smooth; multiple geos"): + # fmt: off data = CovidcastRows.from_args( signal=["sig_base"] * 40, geo_value=["ak"] * 20 + ["ca"] * 20, - time_value=chain(date_range("2021-05-01", "2021-05-20"), date_range("2021-05-01", "2021-05-20")), + time_value=chain(pd.date_range("2021-05-01", "2021-05-20"), pd.date_range("2021-05-01", "2021-05-20")), value=chain(range(20), range(0, 40, 2)), stderr=chain(range(20), range(0, 40, 2)), - sample_size=chain(range(20), range(0, 40, 2)) + sample_size=chain(range(20), range(0, 40, 2)), ).api_row_df source_signal_pairs = [SourceSignalPair("src", ["sig_base", "sig_diff", "sig_other", "sig_smooth"])] _, row_transform_generator = get_basename_signal_and_jit_generator(source_signal_pairs, data_sources_by_id=DATA_SOURCES_BY_ID, data_signals_by_key=DATA_SIGNALS_BY_KEY) @@ -400,17 +469,18 @@ def test_get_basename_signals(self): signal=["sig_base"] * 40 + ["sig_diff"] * 38 + ["sig_smooth"] * 28, geo_value=["ak"] * 20 + ["ca"] * 20 + ["ak"] * 19 + ["ca"] * 19 + ["ak"] * 14 + ["ca"] * 14, time_value=chain( - date_range("2021-05-01", "2021-05-20"), - date_range("2021-05-01", "2021-05-20"), - date_range("2021-05-02", "2021-05-20"), - date_range("2021-05-02", "2021-05-20"), - date_range("2021-05-07", "2021-05-20"), - date_range("2021-05-07", "2021-05-20") + pd.date_range("2021-05-01", "2021-05-20"), + pd.date_range("2021-05-01", "2021-05-20"), + pd.date_range("2021-05-02", "2021-05-20"), + pd.date_range("2021-05-02", "2021-05-20"), + pd.date_range("2021-05-07", "2021-05-20"), + pd.date_range("2021-05-07", "2021-05-20"), ), value=chain( - chain(range(20), range(0, 40, 2)), - chain([1] * 19, [2] * 19), - chain([sum(x) / len(x) for x in windowed(range(20), 7)], [sum(x) / len(x) for x in windowed(range(0, 40, 2), 7)]) + chain(range(20), range(0, 40, 2)), + chain([1] * 19, [2] * 19), + chain([sum(x) / len(x) for x in windowed(range(20), 7)], + [sum(x) / len(x) for x in windowed(range(0, 40, 2), 7)]) ), stderr=chain( chain(range(20), range(0, 40, 2)), @@ -421,8 +491,9 @@ def test_get_basename_signals(self): chain(range(20), range(0, 40, 2)), chain([None] * 38), chain([None] * 28), - ) + ), ).api_row_df + # fmt: on # Test no order. 
idx = ["source", "signal", "time_value"] assert_frame_equal(df.set_index(idx).sort_index(), expected_df.set_index(idx).sort_index()) @@ -430,10 +501,10 @@ def test_get_basename_signals(self): with self.subTest("resolve signals called"): data = CovidcastRows.from_args( signal=["sig_base"] * 20 + ["sig_other"] * 5, - time_value=chain(date_range("2021-05-01", "2021-05-10"), date_range("2021-05-21", "2021-05-30"), date_range("2021-05-01", "2021-05-05")), + time_value=chain(pd.date_range("2021-05-01", "2021-05-10"), pd.date_range("2021-05-21", "2021-05-30"), pd.date_range("2021-05-01", "2021-05-05")), value=chain(range(20), range(5)), stderr=chain(range(20), range(5)), - sample_size=chain(range(20), range(5)) + sample_size=chain(range(20), range(5)), ).api_row_df source_signal_pairs = [SourceSignalPair("src", True)] _, row_transform_generator = get_basename_signal_and_jit_generator(source_signal_pairs, data_sources_by_id=DATA_SOURCES_BY_ID, data_signals_by_key=DATA_SIGNALS_BY_KEY) @@ -441,19 +512,20 @@ def test_get_basename_signals(self): df = CovidcastRows.from_records(row_transform_generator(data.to_dict(orient="records"), time_pairs=time_pairs)).api_row_df filled_values = list(chain(range(10), [None] * 10, range(10, 20))) - filled_time_values = list(chain(date_range("2021-05-01", "2021-05-10"), [None] * 10, date_range("2021-05-21", "2021-05-30"))) + filled_time_values = list(chain(pd.date_range("2021-05-01", "2021-05-10"), [None] * 10, pd.date_range("2021-05-21", "2021-05-30"))) + # fmt: off expected_df = CovidcastRows.from_args( signal=["sig_base"] * 30 + ["sig_diff"] * 29 + ["sig_diff_smooth"] * 23 + ["sig_other"] * 5 + ["sig_smooth"] * 24, time_value=chain( - date_range("2021-05-01", "2021-05-30"), - date_range("2021-05-02", "2021-05-30"), - date_range("2021-05-08", "2021-05-30"), - date_range("2021-05-01", "2021-05-05"), - date_range("2021-05-07", "2021-05-30") + pd.date_range("2021-05-01", "2021-05-30"), + pd.date_range("2021-05-02", "2021-05-30"), + pd.date_range("2021-05-08", "2021-05-30"), + pd.date_range("2021-05-01", "2021-05-05"), + pd.date_range("2021-05-07", "2021-05-30"), ), value=chain( - filled_values, + filled_values, _diff_rows(filled_values), _smooth_rows(_diff_rows(filled_values)), range(5), @@ -477,15 +549,15 @@ def test_get_basename_signals(self): filled_time_values, _reindex_windowed(filled_time_values, 2), _reindex_windowed(filled_time_values, 8), - date_range("2021-05-01", "2021-05-05"), - _reindex_windowed(filled_time_values, 7) + pd.date_range("2021-05-01", "2021-05-05"), + _reindex_windowed(filled_time_values, 7), ), ).api_row_df + # fmt: off # Test no order. 
idx = ["source", "signal", "time_value"] assert_frame_equal(df.set_index(idx).sort_index(), expected_df.set_index(idx).sort_index()) - with self.subTest("empty iterator"): source_signal_pairs = [SourceSignalPair("src", ["sig_base", "sig_diff", "sig_smooth"])] _, row_transform_generator = get_basename_signal_and_jit_generator(source_signal_pairs, data_sources_by_id=DATA_SOURCES_BY_ID, data_signals_by_key=DATA_SIGNALS_BY_KEY) diff --git a/tests/server/endpoints/test_covidcast.py b/tests/server/endpoints/test_covidcast.py index 80a2191d2..7b183fe16 100644 --- a/tests/server/endpoints/test_covidcast.py +++ b/tests/server/endpoints/test_covidcast.py @@ -5,7 +5,6 @@ from flask import Response from delphi.epidata.server.main import app -from delphi.epidata.server.endpoints.covidcast import guess_index_to_use, parse_transform_args from delphi.epidata.server._params import ( GeoPair, TimePair, @@ -40,25 +39,6 @@ def test_time(self): self.assertEqual(msg["result"], -2) # no result self.assertEqual(msg["message"], "no results") - def test_guess_index_to_use(self): - self.assertFalse(False, "deprecated tests...") - return - # TODO: remove this as we are no longer planning to hint at indexes... - self.assertEqual(guess_index_to_use([TimePair("day", True)], [GeoPair("county", ["a"])], issues=None, lag=None, as_of=None), "by_issue") - self.assertEqual(guess_index_to_use([TimePair("day", True)], [GeoPair("county", ["a", "b"])], issues=None, lag=None, as_of=None), "by_issue") - self.assertEqual(guess_index_to_use([TimePair("day", True)], [GeoPair("county", ["a", "b"])], issues=None, lag=None, as_of=None), "by_issue") - self.assertEqual(guess_index_to_use([TimePair("day", True)], [GeoPair("county", ["a", "b", "c"])], issues=None, lag=None, as_of=None), "by_issue") - - # to many geo - self.assertIsNone(guess_index_to_use([TimePair("day", True)], [GeoPair("county", ["a", "b", "c", "d", "e", "f"])], issues=None, lag=None, as_of=None)) - # to short time frame - self.assertIsNone(guess_index_to_use([TimePair("day", [(20200101, 20200115)])], [GeoPair("county", ["a", "b", "c", "d", "e", "f"])], issues=None, lag=None, as_of=None)) - - self.assertEqual(guess_index_to_use([TimePair("day", True)], [GeoPair("county", ["a"])], issues=None, lag=3, as_of=None), "by_lag") - self.assertEqual(guess_index_to_use([TimePair("day", True)], [GeoPair("county", ["a"])], issues=[20200202], lag=3, as_of=None), "by_issue") - self.assertIsNone(guess_index_to_use([TimePair("day", [20200201])], [GeoPair("county", ["a"])], issues=[20200202], lag=3, as_of=None)) - self.assertIsNone(guess_index_to_use([TimePair("day", True)], [GeoPair("county", True)], issues=None, lag=3, as_of=None)) - # TODO def test_parse_transform_args(self): ...