Skip to content

use axum extractor to retrieve database connection from the pool #2312

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions src/web/error.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
use std::borrow::Cow;

use crate::{
db::PoolError,
storage::PathNotFoundError,
web::{releases::Search, AxumErrorPage},
};
use anyhow::anyhow;
use axum::{
http::StatusCode,
response::{IntoResponse, Response as AxumResponse},
};
use std::borrow::Cow;

#[derive(Debug, thiserror::Error)]
#[allow(dead_code)] // FIXME: remove after iron is gone
Expand Down Expand Up @@ -131,6 +132,12 @@ impl From<anyhow::Error> for AxumNope {
}
}

impl From<PoolError> for AxumNope {
fn from(err: PoolError) -> Self {
AxumNope::InternalError(anyhow!(err))
}
}

pub(crate) type AxumResult<T> = Result<T, AxumNope>;

#[cfg(test)]
Expand Down
54 changes: 54 additions & 0 deletions src/web/extractors.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
use crate::db::{AsyncPoolClient, Pool};
use anyhow::Context as _;
use axum::{
async_trait,
extract::{Extension, FromRequestParts},
http::request::Parts,
RequestPartsExt,
};
use std::ops::{Deref, DerefMut};

use super::error::AxumNope;

/// Extractor for a async sqlx database connection.
/// Can be used in normal axum handlers, middleware, or other extractors.
///
/// For now, we will retrieve a new connection each time the extractor is used.
///
/// This could be optimized in the future by caching the connection as a request
/// extension, so one request only uses on connection.
#[derive(Debug)]
pub(crate) struct DbConnection(AsyncPoolClient);

#[async_trait]
impl<S> FromRequestParts<S> for DbConnection
where
S: Send + Sync,
{
type Rejection = AxumNope;

async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
let Extension(pool) = parts
.extract::<Extension<Pool>>()
.await
.context("could not extract pool extension")?;

Ok(Self(pool.get_async().await?))
}
}

impl Deref for DbConnection {
type Target = sqlx::PgConnection;

fn deref(&self) -> &Self::Target {
&self.0
}
}

impl DerefMut for DbConnection {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}

// TODO: we will write tests for this when async db tests are working
1 change: 1 addition & 0 deletions src/web/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ pub(crate) mod cache;
pub(crate) mod crate_details;
mod csp;
pub(crate) mod error;
mod extractors;
mod features;
mod file;
mod headers;
Expand Down
62 changes: 19 additions & 43 deletions src/web/releases.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use crate::{
web::{
axum_parse_uri_with_params, axum_redirect, encode_url_path,
error::{AxumNope, AxumResult},
extractors::DbConnection,
match_version_axum,
},
BuildQueue, Config, InstanceMetrics,
Expand Down Expand Up @@ -133,7 +134,7 @@ struct SearchResult {
///
/// This delegates to the crates.io search API.
async fn get_search_results(
pool: Pool,
conn: &mut sqlx::PgConnection,
config: &Config,
query_params: &str,
) -> Result<SearchResult, anyhow::Error> {
Expand Down Expand Up @@ -212,10 +213,6 @@ async fn get_search_results(
// So for now we are using the version with the youngest release_time.
// This is different from all other release-list views where we show
// our latest build.
let mut conn = pool
.get_async()
.await
.context("can't get pool connection")?;
let crates: HashMap<String, Release> = sqlx::query!(
r#"SELECT
crates.name,
Expand Down Expand Up @@ -276,12 +273,7 @@ impl_axum_webpage! {
HomePage = "core/home.html",
}

pub(crate) async fn home_page(Extension(pool): Extension<Pool>) -> AxumResult<impl IntoResponse> {
let mut conn = pool
.get_async()
.await
.context("can't get pool connection")?;

pub(crate) async fn home_page(mut conn: DbConnection) -> AxumResult<impl IntoResponse> {
let recent_releases =
get_releases(&mut conn, 1, RELEASES_IN_HOME, Order::ReleaseTime, true).await?;

Expand All @@ -298,14 +290,7 @@ impl_axum_webpage! {
content_type = "application/xml",
}

pub(crate) async fn releases_feed_handler(
Extension(pool): Extension<Pool>,
) -> AxumResult<impl IntoResponse> {
let mut conn = pool
.get_async()
.await
.context("can't get pool connection")?;

pub(crate) async fn releases_feed_handler(mut conn: DbConnection) -> AxumResult<impl IntoResponse> {
let recent_releases =
get_releases(&mut conn, 1, RELEASES_IN_FEED, Order::ReleaseTime, true).await?;
Ok(ReleaseFeed { recent_releases })
Expand Down Expand Up @@ -337,15 +322,10 @@ pub(crate) enum ReleaseType {
}

pub(crate) async fn releases_handler(
pool: Pool,
conn: &mut sqlx::PgConnection,
page: Option<i64>,
release_type: ReleaseType,
) -> AxumResult<impl IntoResponse> {
let mut conn = pool
.get_async()
.await
.context("can't get pool connection")?;

let page_number = page.unwrap_or(1);

let (description, release_order, latest_only) = match release_type {
Expand All @@ -368,7 +348,7 @@ pub(crate) async fn releases_handler(
};

let releases = get_releases(
&mut conn,
&mut *conn,
page_number,
RELEASES_IN_RELEASES,
release_order,
Expand All @@ -395,30 +375,30 @@ pub(crate) async fn releases_handler(

pub(crate) async fn recent_releases_handler(
page: Option<Path<i64>>,
Extension(pool): Extension<Pool>,
mut conn: DbConnection,
) -> AxumResult<impl IntoResponse> {
releases_handler(pool, page.map(|p| p.0), ReleaseType::Recent).await
releases_handler(&mut conn, page.map(|p| p.0), ReleaseType::Recent).await
}

pub(crate) async fn releases_by_stars_handler(
page: Option<Path<i64>>,
Extension(pool): Extension<Pool>,
mut conn: DbConnection,
) -> AxumResult<impl IntoResponse> {
releases_handler(pool, page.map(|p| p.0), ReleaseType::Stars).await
releases_handler(&mut conn, page.map(|p| p.0), ReleaseType::Stars).await
}

pub(crate) async fn releases_recent_failures_handler(
page: Option<Path<i64>>,
Extension(pool): Extension<Pool>,
mut conn: DbConnection,
) -> AxumResult<impl IntoResponse> {
releases_handler(pool, page.map(|p| p.0), ReleaseType::RecentFailures).await
releases_handler(&mut conn, page.map(|p| p.0), ReleaseType::RecentFailures).await
}

pub(crate) async fn releases_failures_by_stars_handler(
page: Option<Path<i64>>,
Extension(pool): Extension<Pool>,
mut conn: DbConnection,
) -> AxumResult<impl IntoResponse> {
releases_handler(pool, page.map(|p| p.0), ReleaseType::Failures).await
releases_handler(&mut conn, page.map(|p| p.0), ReleaseType::Failures).await
}

pub(crate) async fn owner_handler(Path(owner): Path<String>) -> AxumResult<impl IntoResponse> {
Expand Down Expand Up @@ -460,19 +440,14 @@ impl Default for Search {
async fn redirect_to_random_crate(
config: Arc<Config>,
metrics: Arc<InstanceMetrics>,
pool: Pool,
conn: &mut sqlx::PgConnection,
) -> AxumResult<impl IntoResponse> {
// We try to find a random crate and redirect to it.
//
// The query is efficient, but relies on a static factor which depends
// on the amount of crates with > 100 GH stars over the amount of all crates.
//
// If random-crate-searches end up being empty, increase that value.
let mut conn = pool
.get_async()
.await
.context("can't get pool connection")?;

let row = sqlx::query!(
"WITH params AS (
-- get maximum possible id-value in crates-table
Expand Down Expand Up @@ -519,6 +494,7 @@ impl_axum_webpage! {
}

pub(crate) async fn search_handler(
mut conn: DbConnection,
Extension(pool): Extension<Pool>,
Extension(config): Extension<Arc<Config>>,
Extension(metrics): Extension<Arc<InstanceMetrics>>,
Expand All @@ -534,7 +510,7 @@ pub(crate) async fn search_handler(
if params.remove("i-am-feeling-lucky").is_some() || query.contains("::") {
// redirect to a random crate if query is empty
if query.is_empty() {
return Ok(redirect_to_random_crate(config, metrics, pool)
return Ok(redirect_to_random_crate(config, metrics, &mut conn)
.await?
.into_response());
}
Expand Down Expand Up @@ -595,14 +571,14 @@ pub(crate) async fn search_handler(
return Err(AxumNope::NoResults);
}

get_search_results(pool, &config, &query_params).await?
get_search_results(&mut conn, &config, &query_params).await?
} else if !query.is_empty() {
let query_params: String = form_urlencoded::Serializer::new(String::new())
.append_pair("q", &query)
.append_pair("per_page", &RELEASES_IN_RELEASES.to_string())
.finish();

get_search_results(pool, &config, &format!("?{}", &query_params)).await?
get_search_results(&mut conn, &config, &format!("?{}", &query_params)).await?
} else {
return Err(AxumNope::NoResults);
};
Expand Down