|
1 | 1 | from collections import namedtuple
|
2 |
| -from dataclasses import dataclass |
3 |
| -from datetime import date, datetime, timedelta |
| 2 | +from datetime import date, datetime |
4 | 3 | from itertools import chain
|
5 |
| -from typing import Any, Dict, List, Union |
6 | 4 | import pandas as pd
|
7 |
| -from pandas.util.testing import assert_frame_equal |
8 | 5 | import numpy as np
|
9 | 6 | import pytest
|
10 | 7 | from unittest.mock import patch, Mock
|
11 | 8 |
|
12 | 9 | from delphi_utils.geomap import GeoMapper
|
13 | 10 |
|
14 |
| -from delphi_dsew_community_profile.pull import ( |
15 |
| - DatasetTimes, Dataset, |
| 11 | +from delphi_dsew_community_profile.pull import (DatasetTimes, Dataset, |
16 | 12 | fetch_listing, nation_from_state, generate_prop_signal,
|
17 |
| - std_err, add_max_ts_col, unify_testing_sigs, interpolate_missing_values, |
18 |
| - extend_listing_for_interp |
19 |
| -) |
| 13 | + std_err, add_max_ts_col, unify_testing_sigs) |
20 | 14 |
|
21 | 15 |
|
22 | 16 | example = namedtuple("example", "given expected")
|
23 |
| - |
24 |
| -def _assert_frame_equal(df1, df2, index_cols: List[str] = None): |
25 |
| - # Ensure same columns present. |
26 |
| - assert set(df1.columns) == set(df2.columns) |
27 |
| - # Ensure same column order. |
28 |
| - df1 = df1[df1.columns] |
29 |
| - df2 = df2[df1.columns] |
30 |
| - # Ensure same row order by using a common index and sorting. |
31 |
| - df1 = df1.set_index(index_cols).sort_index() |
32 |
| - df2 = df2.set_index(index_cols).sort_index() |
33 |
| - return assert_frame_equal(df1, df2) |
34 |
| - |
35 |
| -def _set_df_dtypes(df: pd.DataFrame, dtypes: Dict[str, Any]) -> pd.DataFrame: |
36 |
| - df = df.copy() |
37 |
| - for k, v in dtypes.items(): |
38 |
| - if k in df.columns: |
39 |
| - df[k] = df[k].astype(v) |
40 |
| - return df |
41 |
| - |
| 17 | + |
42 | 18 | class TestPull:
|
43 | 19 | def test_DatasetTimes(self):
|
44 | 20 | examples = [
|
@@ -477,77 +453,3 @@ def test_std_err(self):
|
477 | 453 | "sample_size": [2, 2, 5, 10, 20, 0]
|
478 | 454 | })
|
479 | 455 | )
|
480 |
| - |
481 |
| - def test_interpolation(self): |
482 |
| - DTYPES = {"geo_id": str, "timestamp": "datetime64[ns]", "val": float, "se": float, "sample_size": float, "publish_date": "datetime64[ns]"} |
483 |
| - line = lambda x: 3 * x + 5 |
484 |
| - |
485 |
| - sig1 = _set_df_dtypes(pd.DataFrame({ |
486 |
| - "geo_id": "1", |
487 |
| - "timestamp": pd.date_range("2022-01-01", "2022-01-10"), |
488 |
| - "val": [line(i) for i in range(2, 12)], |
489 |
| - "se": [line(i) for i in range(1, 11)], |
490 |
| - "sample_size": [line(i) for i in range(0, 10)], |
491 |
| - "publish_date": pd.to_datetime("2022-01-10") |
492 |
| - }), dtypes=DTYPES) |
493 |
| - # A linear signal missing two days which should be filled exactly by the linear interpolation. |
494 |
| - missing_sig1 = sig1[(sig1.timestamp <= "2022-01-05") | (sig1.timestamp >= "2022-01-08")] |
495 |
| - |
496 |
| - sig2 = sig1.copy() |
497 |
| - sig2["geo_id"] = "2" |
498 |
| - # A linear signal missing everything but the end points, should be filled exactly by linear interpolation. |
499 |
| - missing_sig2 = sig2[(sig2.timestamp == "2022-01-01") | (sig2.timestamp == "2022-01-10")] |
500 |
| - |
501 |
| - sig3 = _set_df_dtypes(pd.DataFrame({ |
502 |
| - "geo_id": "3", |
503 |
| - "timestamp": pd.date_range("2022-01-01", "2022-01-10"), |
504 |
| - "val": None, |
505 |
| - "se": [line(i) for i in range(1, 11)], |
506 |
| - "sample_size": [line(i) for i in range(0, 10)], |
507 |
| - "publish_date": pd.to_datetime("2022-01-10") |
508 |
| - }), dtypes=DTYPES) |
509 |
| - # A signal missing everything, should be left alone. |
510 |
| - missing_sig3 = sig3[(sig3.timestamp <= "2022-01-05") | (sig3.timestamp >= "2022-01-08")] |
511 |
| - |
512 |
| - sig4 = _set_df_dtypes(pd.DataFrame({ |
513 |
| - "geo_id": "4", |
514 |
| - "timestamp": pd.date_range("2022-01-01", "2022-01-10"), |
515 |
| - "val": [None] * 9 + [10.0], |
516 |
| - "se": [line(i) for i in range(1, 11)], |
517 |
| - "sample_size": [line(i) for i in range(0, 10)], |
518 |
| - "publish_date": pd.to_datetime("2022-01-10") |
519 |
| - }), dtypes=DTYPES) |
520 |
| - # A signal missing everything except for one point, should be left alone. |
521 |
| - missing_sig4 = sig4[(sig4.timestamp <= "2022-01-05") | (sig4.timestamp >= "2022-01-08")] |
522 |
| - |
523 |
| - missing_dfs = [missing_sig1, missing_sig2, missing_sig3, missing_sig4] |
524 |
| - interpolated_dfs1 = interpolate_missing_values({("src", "sig", False): pd.concat(missing_dfs)}) |
525 |
| - expected_dfs = pd.concat([sig1, sig2, sig3, sig4]) |
526 |
| - _assert_frame_equal(interpolated_dfs1[("src", "sig", False)], expected_dfs, index_cols=["geo_id", "timestamp"]) |
527 |
| - |
528 |
| - @patch("delphi_dsew_community_profile.pull.INTERP_LENGTH", 2) |
529 |
| - def test_extend_listing(self): |
530 |
| - listing = [ |
531 |
| - {"publish_date": date(2020, 1, 20) - timedelta(days=i)} |
532 |
| - for i in range(20) |
533 |
| - ] |
534 |
| - examples = [ |
535 |
| - # single range |
536 |
| - example( |
537 |
| - [{"publish_date": date(2020, 1, 20)}], |
538 |
| - [{"publish_date": date(2020, 1, 20)}, {"publish_date": date(2020, 1, 19)}] |
539 |
| - ), |
540 |
| - # disjoint ranges |
541 |
| - example( |
542 |
| - [{"publish_date": date(2020, 1, 20)}, {"publish_date": date(2020, 1, 10)}], |
543 |
| - [{"publish_date": date(2020, 1, 20)}, {"publish_date": date(2020, 1, 19)}, |
544 |
| - {"publish_date": date(2020, 1, 10)}, {"publish_date": date(2020, 1, 9)}] |
545 |
| - ), |
546 |
| - # conjoined ranges |
547 |
| - example( |
548 |
| - [{"publish_date": date(2020, 1, 20)}, {"publish_date": date(2020, 1, 19)}], |
549 |
| - [{"publish_date": date(2020, 1, 20)}, {"publish_date": date(2020, 1, 19)}, {"publish_date": date(2020, 1, 18)}] |
550 |
| - ), |
551 |
| - ] |
552 |
| - for ex in examples: |
553 |
| - assert extend_listing_for_interp(ex.given, listing) == ex.expected, ex.given |
0 commit comments