Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
ALTER TABLE settings DROP COLUMN use_openid_for_mfa;
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
ALTER TABLE settings ADD COLUMN use_openid_for_mfa BOOLEAN NOT NULL DEFAULT FALSE;
3 changes: 2 additions & 1 deletion crates/defguard_core/src/db/models/audit_log/metadata.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use crate::{
audit_stream::{AuditStream, AuditStreamType},
openid_provider::{DirectorySyncTarget, DirectorySyncUserBehavior, OpenIdProvider},
},
events::ClientMFAMethod,
};

#[derive(Serialize)]
Expand Down Expand Up @@ -159,7 +160,7 @@ pub struct VpnClientMetadata {
pub struct VpnClientMfaMetadata {
pub location: WireguardNetwork<Id>,
pub device: Device<Id>,
pub method: MFAMethod,
pub method: ClientMFAMethod,
}

#[derive(Serialize)]
Expand Down
3 changes: 2 additions & 1 deletion crates/defguard_core/src/db/models/audit_log/mod.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use crate::db::{Id, NoId};
use chrono::NaiveDateTime;
use ipnetwork::IpNetwork;
use model_derive::Model;
use sqlx::{FromRow, Type};

use crate::db::{Id, NoId};

pub mod metadata;

#[derive(Clone, Debug, Deserialize, Serialize, Type)]
Expand Down
7 changes: 5 additions & 2 deletions crates/defguard_core/src/db/models/settings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ pub struct Settings {
// Whether to create a new account when users try to log in with external OpenID
pub openid_create_account: bool,
pub openid_username_handling: OpenidUsernameHandling,
pub use_openid_for_mfa: bool,
pub license: Option<String>,
// Gateway disconnect notifications
pub gateway_disconnect_notifications_enabled: bool,
Expand Down Expand Up @@ -152,7 +153,7 @@ impl Settings {
ldap_enabled, ldap_sync_enabled, ldap_is_authoritative, \
ldap_sync_interval, ldap_user_auxiliary_obj_classes, ldap_uses_ad, \
ldap_user_rdn_attr, ldap_sync_groups, \
openid_username_handling \"openid_username_handling: OpenidUsernameHandling\" \
openid_username_handling \"openid_username_handling: OpenidUsernameHandling\", use_openid_for_mfa \
FROM \"settings\" WHERE id = 1",
)
.fetch_optional(executor)
Expand Down Expand Up @@ -224,7 +225,8 @@ impl Settings {
ldap_uses_ad = $45, \
ldap_user_rdn_attr = $46, \
ldap_sync_groups = $47, \
openid_username_handling = $48 \
openid_username_handling = $48, \
use_openid_for_mfa = $49 \
WHERE id = 1",
self.openid_enabled,
self.wireguard_enabled,
Expand Down Expand Up @@ -274,6 +276,7 @@ impl Settings {
self.ldap_user_rdn_attr,
&self.ldap_sync_groups as &Vec<String>,
&self.openid_username_handling as &OpenidUsernameHandling,
self.use_openid_for_mfa,
)
.execute(executor)
.await?;
Expand Down
26 changes: 17 additions & 9 deletions crates/defguard_core/src/db/models/user.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use rand::{
prelude::Distribution,
Rng,
};
use serde::Serialize;
use sqlx::{
query, query_as, query_scalar, Error as SqlxError, FromRow, PgConnection, PgExecutor, PgPool,
Type,
Expand Down Expand Up @@ -53,15 +54,7 @@ pub enum MFAMethod {
Email,
}

impl From<MfaMethod> for MFAMethod {
fn from(method: MfaMethod) -> Self {
match method {
MfaMethod::Totp => Self::OneTimePassword,
MfaMethod::Email => Self::Email,
}
}
}

// Web MFA methods
impl fmt::Display for MFAMethod {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
Expand All @@ -77,6 +70,21 @@ impl fmt::Display for MFAMethod {
}
}

// Client MFA methods
impl fmt::Display for MfaMethod {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"{}",
match self {
MfaMethod::Totp => "TOTP",
MfaMethod::Email => "Email",
MfaMethod::Oidc => "OIDC",
}
)
}
}

// User information ready to be sent as part of diagnostic data.
#[derive(Serialize)]
pub struct UserDiagnostic {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,15 @@ use bytes::Bytes;
use sqlx::PgPool;
use tokio::{sync::broadcast::Receiver, task::JoinSet, time::sleep};
use tokio_util::sync::CancellationToken;

use tracing::debug;

use super::AuditStreamReconfigurationNotification;
use crate::enterprise::{
audit_stream::http_stream::{run_http_stream_task, HttpAuditStreamConfig},
db::models::audit_stream::{AuditStream, AuditStreamConfig},
is_enterprise_enabled,
};

use super::AuditStreamReconfigurationNotification;

pub async fn run_audit_stream_manager(
pool: PgPool,
notification: AuditStreamReconfigurationNotification,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ use bytes::Bytes;
use reqwest::tls;
use tokio::sync::broadcast::Receiver;
use tokio_util::sync::CancellationToken;

use tracing::{debug, error};

use crate::{
Expand Down
16 changes: 11 additions & 5 deletions crates/defguard_core/src/enterprise/db/models/openid_provider.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::fmt;

use model_derive::Model;
use sqlx::{query, query_as, Error as SqlxError, PgPool, Type};
use sqlx::{query, query_as, Error as SqlxError, PgExecutor, PgPool, Type};

use crate::db::{Id, NoId};

Expand Down Expand Up @@ -195,7 +195,10 @@ impl OpenIdProvider {
}

impl OpenIdProvider<Id> {
pub async fn find_by_name(pool: &PgPool, name: &str) -> Result<Option<Self>, SqlxError> {
pub async fn find_by_name<'e, E>(executor: E, name: &str) -> Result<Option<Self>, SqlxError>
where
E: PgExecutor<'e>,
{
query_as!(
OpenIdProvider,
"SELECT id, name, base_url, client_id, client_secret, display_name, \
Expand All @@ -207,11 +210,14 @@ impl OpenIdProvider<Id> {
FROM openidprovider WHERE name = $1",
name
)
.fetch_optional(pool)
.fetch_optional(executor)
.await
}

pub async fn get_current(pool: &PgPool) -> Result<Option<Self>, SqlxError> {
pub async fn get_current<'e, E>(executor: E) -> Result<Option<Self>, SqlxError>
where
E: PgExecutor<'e>,
{
query_as!(
OpenIdProvider,
"SELECT id, name, base_url, client_id, client_secret, display_name, \
Expand All @@ -222,7 +228,7 @@ impl OpenIdProvider<Id> {
okta_private_jwk, okta_dirsync_client_id, directory_sync_group_match \
FROM openidprovider LIMIT 1"
)
.fetch_optional(pool)
.fetch_optional(executor)
.await
}
}
146 changes: 146 additions & 0 deletions crates/defguard_core/src/enterprise/grpc/desktop_client_mfa.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
use openidconnect::{AuthorizationCode, Nonce};
use reqwest::Url;
use tonic::Status;

use crate::{
enterprise::{
handlers::openid_login::{extract_state_data, user_from_claims},
is_enterprise_enabled,
},
events::{BidiRequestContext, BidiStreamEvent, BidiStreamEventType, DesktopClientMfaEvent},
grpc::{
desktop_client_mfa::{ClientLoginSession, ClientMfaServer},
proto::proxy::{ClientMfaOidcAuthenticateRequest, DeviceInfo, MfaMethod},
utils::parse_client_info,
},
};

impl ClientMfaServer {
#[instrument(skip_all)]
pub async fn auth_mfa_session_with_oidc(
&mut self,
request: ClientMfaOidcAuthenticateRequest,
info: Option<DeviceInfo>,
) -> Result<(), Status> {
debug!("Received OIDC MFA authentication request: {request:?}");
if !is_enterprise_enabled() {
error!("OIDC MFA method requires enterprise feature to be enabled");
return Err(Status::invalid_argument("OIDC MFA method is not supported"));
}

let token = extract_state_data(&request.state).ok_or_else(|| {
error!(
"Failed to extract state data from state: {:?}",
request.state
);
Status::invalid_argument("invalid state data")
})?;
if token.is_empty() {
debug!("Empty token provided in request");
return Err(Status::invalid_argument("empty token provided"));
}
let pubkey = Self::parse_token(&token)?;

// fetch login session
let Some(session) = self.sessions.get(&pubkey).cloned() else {
debug!("Client login session not found");
return Err(Status::invalid_argument("login session not found"));
};
let ClientLoginSession {
method,
device,
location,
user,
openid_auth_completed,
} = session;

if openid_auth_completed {
debug!("Client login session already completed");
return Err(Status::invalid_argument("login session already completed"));
}

if method != MfaMethod::Oidc {
debug!("Invalid MFA method for OIDC authentication: {method:?}");
self.sessions.remove(&pubkey);
return Err(Status::invalid_argument("invalid MFA method"));
}

let (ip, user_agent) = parse_client_info(&info).map_err(Status::internal)?;
let context = BidiRequestContext::new(user.id, user.username.clone(), ip, user_agent);

let code = AuthorizationCode::new(request.code.clone());
let url = match Url::parse(&request.callback_url).map_err(|err| {
error!("Invalid redirect URL provided: {err:?}");
Status::invalid_argument("invalid redirect URL")
}) {
Ok(url) => url,
Err(status) => {
self.sessions.remove(&pubkey);
self.emit_event(BidiStreamEvent {
context,
event: BidiStreamEventType::DesktopClientMfa(Box::new(
DesktopClientMfaEvent::Failed {
location: location.clone(),
device: device.clone(),
method,
},
)),
})?;
return Err(status);
}
};

match user_from_claims(&self.pool, Nonce::new(request.nonce.clone()), code, url).await {
Ok(claims_user) => {
// if thats not our user, prevent login
if claims_user.id != user.id {
info!("User {claims_user} tried to use OIDC MFA for another user: {user}");
self.sessions.remove(&pubkey);
self.emit_event(BidiStreamEvent {
context,
event: BidiStreamEventType::DesktopClientMfa(Box::new(
DesktopClientMfaEvent::Failed {
location: location.clone(),
device: device.clone(),
method,
},
)),
})?;
return Err(Status::unauthenticated("unauthorized"));
}
info!(
"OIDC MFA authentication completed successfully for user: {}",
user.username
);
}
Err(err) => {
info!("Failed to verify OIDC code: {err:?}");
self.sessions.remove(&pubkey);
self.emit_event(BidiStreamEvent {
context,
event: BidiStreamEventType::DesktopClientMfa(Box::new(
DesktopClientMfaEvent::Failed {
location: location.clone(),
device: device.clone(),
method,
},
)),
})?;
return Err(Status::unauthenticated("unauthorized"));
}
};

self.sessions.insert(
pubkey.clone(),
ClientLoginSession {
method,
device: device.clone(),
location: location.clone(),
user: user.clone(),
openid_auth_completed: true,
},
);

Ok(())
}
}
1 change: 1 addition & 0 deletions crates/defguard_core/src/enterprise/grpc/mod.rs
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
pub mod desktop_client_mfa;
pub mod polling;
Loading