|
17 | 17 | import os
|
18 | 18 | import time
|
19 | 19 | from contextlib import asynccontextmanager
|
20 |
| -from pathlib import Path |
21 | 20 | from typing import Final
|
22 | 21 |
|
23 |
| -import apsw |
24 |
| -import apsw.bestpractice |
| 22 | +import mariadb |
25 | 23 | import uvicorn
|
26 | 24 | from fastapi import FastAPI
|
27 | 25 | from fastapi.middleware.cors import CORSMiddleware
|
|
80 | 78 | ]
|
81 | 79 |
|
82 | 80 |
|
83 |
| -def _enable_best_practice(connection: apsw.Connection): |
84 |
| - """Enable aspw best practice.""" |
85 |
| - apsw.bestpractice.connection_wal(connection) |
86 |
| - apsw.bestpractice.library_logging() |
| 81 | +def _get_database_connection() -> mariadb.Connection: |
| 82 | + """Get a MriaDB database connection.""" |
| 83 | + connection = mariadb.connect( |
| 84 | + user=os.environ["DB_USER"], |
| 85 | + password=os.environ["DB_PASS"], |
| 86 | + host="127.0.0.1", |
| 87 | + port=3306, |
| 88 | + database=os.environ["DB_DATABASE"], |
| 89 | + autocommit=True, |
| 90 | + ) |
| 91 | + return connection |
87 | 92 |
|
88 | 93 |
|
89 | 94 | @asynccontextmanager
|
90 | 95 | async def lifespan(app: FastAPI):
|
91 | 96 | """Load the database connection for the life of the app.s"""
|
92 |
| - db_path = Path(os.environ["DATABASE_PATH"]) |
93 |
| - logger.info("validator database: %s", db_path) |
94 |
| - app.state.connection = apsw.Connection( |
95 |
| - str(db_path), flags=apsw.SQLITE_OPEN_READONLY |
96 |
| - ) |
97 |
| - _enable_best_practice(app.state.connection) |
| 97 | + app.state.connection = _get_database_connection() |
98 | 98 | app.state.kupo_url = os.environ["KUPO_URL"]
|
99 | 99 | app.state.kupo_port = os.environ["KUPO_PORT"]
|
100 | 100 | yield
|
@@ -141,34 +141,36 @@ def redirect_root_to_docs():
|
141 | 141 | @app.get("/get_active_participants", tags=[TAG_STATISTICS])
|
142 | 142 | async def get_active_participants():
|
143 | 143 | """Return participants in the ITN database."""
|
| 144 | + cursor = app.state.connection.cursor() |
144 | 145 | try:
|
145 |
| - participants = app.state.connection.execute( |
146 |
| - "select distinct address from data_points;" |
147 |
| - ) |
148 |
| - except apsw.SQLError as err: |
| 146 | + cursor.execute("select distinct address from data_points;") |
| 147 | + except mariadb.Error as err: |
149 | 148 | return {"error": f"{err}"}
|
150 |
| - data = [participant[0] for participant in participants] |
| 149 | + data = [participant[0] for participant in cursor] |
| 150 | + cursor.close() |
151 | 151 | return data
|
152 | 152 |
|
153 | 153 |
|
154 | 154 | @app.get("/get_participants_counts_total", tags=[TAG_STATISTICS])
|
155 | 155 | async def get_participants_counts_total():
|
156 | 156 | """Return participants total counts."""
|
| 157 | + cursor = app.state.connection.cursor() |
157 | 158 | try:
|
158 |
| - participants_count_total = app.state.connection.execute( |
| 159 | + cursor.execute( |
159 | 160 | "select count(*) as count, address from data_points group by address order by count desc;"
|
160 | 161 | )
|
161 |
| - except apsw.SQLError as err: |
| 162 | + except mariadb.Error as err: |
162 | 163 | return {"error": f"{err}"}
|
163 |
| - return participants_count_total |
| 164 | + res = list(cursor) |
| 165 | + cursor.close() |
| 166 | + return res |
164 | 167 |
|
165 | 168 |
|
166 | 169 | @app.get("/get_participants_counts_day", tags=[TAG_STATISTICS])
|
167 | 170 | async def get_participants_counts_day(
|
168 | 171 | date_start: str = "1970-01-01", date_end: str = "1970-01-03"
|
169 | 172 | ):
|
170 | 173 | """Return participants in ITN."""
|
171 |
| - |
172 | 174 | report = reports.get_participants_counts_date_range(app, date_start, date_end)
|
173 | 175 | return report
|
174 | 176 |
|
@@ -231,28 +233,33 @@ async def get_itn_participants() -> str:
|
231 | 233 | @app.get("/online_collectors", tags=[TAG_HTMX], response_class=HTMLResponse)
|
232 | 234 | async def get_online_collectors() -> str:
|
233 | 235 | """Return ITN aliases and collector counts."""
|
| 236 | + cursor = app.state.connection.cursor() |
234 | 237 | try:
|
235 |
| - participants_count = app.state.connection.execute( |
| 238 | + cursor.execute( |
236 | 239 | """SELECT address, COUNT(*) AS total_count,
|
237 |
| - SUM(CASE WHEN datetime(date_time) >= datetime('now', '-24 hours') |
| 240 | + SUM(CASE WHEN date_time >= (SELECT DATE_SUB(NOW(), INTERVAL 1 DAY)) |
238 | 241 | THEN 1 ELSE 0 END) AS count_24hr
|
239 | 242 | FROM data_points
|
240 | 243 | GROUP BY address ORDER BY total_count DESC;
|
241 | 244 | """
|
242 | 245 | )
|
243 |
| - except apsw.SQLError: |
| 246 | + except mariadb.Error: |
244 | 247 | return "zero collectors online"
|
245 | 248 |
|
| 249 | + participants_count = list(cursor) |
| 250 | + |
246 | 251 | try:
|
247 |
| - feed_count = app.state.connection.execute( |
| 252 | + cursor.execute( |
248 | 253 | """SELECT distinct feed_id
|
249 | 254 | from data_points
|
250 |
| - where datetime(date_time) >= datetime('now', '-48 hours'); |
| 255 | + where date_time >= (SELECT DATE_SUB(NOW(), INTERVAL 1 DAY)); |
251 | 256 | """
|
252 | 257 | )
|
253 |
| - except apsw.SQLError: |
| 258 | + except mariadb.Error: |
254 | 259 | return "zero collectors online"
|
255 | 260 |
|
| 261 | + feed_count = list(cursor) |
| 262 | + |
256 | 263 | no_feeds = len(list(feed_count))
|
257 | 264 |
|
258 | 265 | # FIXME: These can all be combined better, e.g. into a dataclass or
|
@@ -308,13 +315,13 @@ async def get_locations_map_hx():
|
308 | 315 | @app.get("/count_active_participants", tags=[TAG_HTMX], response_class=HTMLResponse)
|
309 | 316 | async def count_active_participants():
|
310 | 317 | """Count active participants."""
|
| 318 | + cursor = app.state.connection.cursor() |
311 | 319 | try:
|
312 |
| - participants = app.state.connection.execute( |
313 |
| - "select count(distinct address) as count from data_points;" |
314 |
| - ) |
315 |
| - except apsw.SQLError as err: |
| 320 | + cursor.execute("select count(distinct address) as count from data_points;") |
| 321 | + except mariadb.Error as err: |
316 | 322 | return {"error": f"{err}"}
|
317 |
| - data = list(participants) |
| 323 | + data = list(cursor) |
| 324 | + cursor.close() |
318 | 325 | return f"{data[0][0]}"
|
319 | 326 |
|
320 | 327 |
|
|
0 commit comments