Support CAS authentication

This commit is contained in:
Kevin Yue 2024-04-01 06:28:20 -04:00
parent b2ca82e105
commit 338854b4aa
5 changed files with 135 additions and 29 deletions

View File

@ -82,7 +82,7 @@ async fn feed_auth_data(auth_data: &str) -> anyhow::Result<()> {
reqwest::Client::default() reqwest::Client::default()
.post(format!("{}/auth-data", service_endpoint)) .post(format!("{}/auth-data", service_endpoint))
.json(&auth_data) .body(auth_data.to_string())
.send() .send()
.await? .await?
.error_for_status()?; .error_for_status()?;

View File

@ -1,4 +1,4 @@
use anyhow::bail; use anyhow::anyhow;
use regex::Regex; use regex::Regex;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -35,7 +35,7 @@ impl SamlAuthData {
} }
} }
pub fn parse_html(html: &str) -> anyhow::Result<SamlAuthData> { pub fn from_html(html: &str) -> anyhow::Result<SamlAuthData> {
match parse_xml_tag(html, "saml-auth-status") { match parse_xml_tag(html, "saml-auth-status") {
Some(saml_status) if saml_status == "1" => { Some(saml_status) if saml_status == "1" => {
let username = parse_xml_tag(html, "saml-username"); let username = parse_xml_tag(html, "saml-username");
@ -43,21 +43,17 @@ impl SamlAuthData {
let portal_userauthcookie = parse_xml_tag(html, "portal-userauthcookie"); let portal_userauthcookie = parse_xml_tag(html, "portal-userauthcookie");
if SamlAuthData::check(&username, &prelogin_cookie, &portal_userauthcookie) { if SamlAuthData::check(&username, &prelogin_cookie, &portal_userauthcookie) {
return Ok(SamlAuthData::new( Ok(SamlAuthData::new(
username.unwrap(), username.unwrap(),
prelogin_cookie, prelogin_cookie,
portal_userauthcookie, portal_userauthcookie,
)); ))
} else {
Err(anyhow!("Found invalid auth data in HTML"))
} }
bail!("Found invalid auth data in HTML");
}
Some(status) => {
bail!("Found invalid SAML status {} in HTML", status);
}
None => {
bail!("No auth data found in HTML");
} }
Some(status) => Err(anyhow!("Found invalid SAML status {} in HTML", status)),
None => Err(anyhow!("No auth data found in HTML")),
} }
} }

View File

@ -1,5 +1,6 @@
use std::collections::HashMap; use std::collections::HashMap;
use log::info;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use specta::Type; use specta::Type;
@ -155,32 +156,58 @@ impl From<PasswordCredential> for CachedCredential {
} }
} }
#[derive(Debug, Serialize, Deserialize, Type, Clone)]
pub struct TokenCredential {
#[serde(alias = "un")]
username: String,
token: String,
}
impl TokenCredential {
pub fn username(&self) -> &str {
&self.username
}
pub fn token(&self) -> &str {
&self.token
}
}
#[derive(Debug, Serialize, Deserialize, Type, Clone)] #[derive(Debug, Serialize, Deserialize, Type, Clone)]
#[serde(tag = "type", rename_all = "camelCase")] #[serde(tag = "type", rename_all = "camelCase")]
pub enum Credential { pub enum Credential {
Password(PasswordCredential), Password(PasswordCredential),
PreloginCookie(PreloginCookieCredential), PreloginCookie(PreloginCookieCredential),
AuthCookie(AuthCookieCredential), AuthCookie(AuthCookieCredential),
TokenCredential(TokenCredential),
CachedCredential(CachedCredential), CachedCredential(CachedCredential),
} }
impl Credential { impl Credential {
/// Create a credential from a globalprotectcallback:<base64 encoded string> /// Create a credential from a globalprotectcallback:<base64 encoded string>,
pub fn parse_gpcallback(auth_data: &str) -> anyhow::Result<Self> { /// or globalprotectcallback:cas-as=1&un=user@xyz.com&token=very_long_string
// Remove the surrounding quotes pub fn from_gpcallback(auth_data: &str) -> anyhow::Result<Self> {
let auth_data = auth_data.trim_matches('"');
let auth_data = auth_data.trim_start_matches("globalprotectcallback:"); let auth_data = auth_data.trim_start_matches("globalprotectcallback:");
if auth_data.starts_with("cas-as") {
info!("Got token auth data: {}", auth_data);
let token_cred: TokenCredential = serde_urlencoded::from_str(auth_data)?;
Ok(Self::TokenCredential(token_cred))
} else {
info!("Parsing SAML auth data...");
let auth_data = decode_to_string(auth_data)?; let auth_data = decode_to_string(auth_data)?;
let auth_data = SamlAuthData::parse_html(&auth_data)?; let auth_data = SamlAuthData::from_html(&auth_data)?;
Self::try_from(auth_data) Self::try_from(auth_data)
} }
}
pub fn username(&self) -> &str { pub fn username(&self) -> &str {
match self { match self {
Credential::Password(cred) => cred.username(), Credential::Password(cred) => cred.username(),
Credential::PreloginCookie(cred) => cred.username(), Credential::PreloginCookie(cred) => cred.username(),
Credential::AuthCookie(cred) => cred.username(), Credential::AuthCookie(cred) => cred.username(),
Credential::TokenCredential(cred) => cred.username(),
Credential::CachedCredential(cred) => cred.username(), Credential::CachedCredential(cred) => cred.username(),
} }
} }
@ -189,20 +216,23 @@ impl Credential {
let mut params = HashMap::new(); let mut params = HashMap::new();
params.insert("user", self.username()); params.insert("user", self.username());
let (passwd, prelogin_cookie, portal_userauthcookie, portal_prelogonuserauthcookie) = match self { let (passwd, prelogin_cookie, portal_userauthcookie, portal_prelogonuserauthcookie, token) = match self {
Credential::Password(cred) => (Some(cred.password()), None, None, None), Credential::Password(cred) => (Some(cred.password()), None, None, None, None),
Credential::PreloginCookie(cred) => (None, Some(cred.prelogin_cookie()), None, None), Credential::PreloginCookie(cred) => (None, Some(cred.prelogin_cookie()), None, None, None),
Credential::AuthCookie(cred) => ( Credential::AuthCookie(cred) => (
None, None,
None, None,
Some(cred.user_auth_cookie()), Some(cred.user_auth_cookie()),
Some(cred.prelogon_user_auth_cookie()), Some(cred.prelogon_user_auth_cookie()),
None,
), ),
Credential::TokenCredential(cred) => (None, None, None, None, Some(cred.token())),
Credential::CachedCredential(cred) => ( Credential::CachedCredential(cred) => (
cred.password(), cred.password(),
None, None,
Some(cred.auth_cookie.user_auth_cookie()), Some(cred.auth_cookie.user_auth_cookie()),
Some(cred.auth_cookie.prelogon_user_auth_cookie()), Some(cred.auth_cookie.prelogon_user_auth_cookie()),
None,
), ),
}; };
@ -214,6 +244,10 @@ impl Credential {
portal_prelogonuserauthcookie.unwrap_or_default(), portal_prelogonuserauthcookie.unwrap_or_default(),
); );
if let Some(token) = token {
params.insert("token", token);
}
params params
} }
} }
@ -245,3 +279,38 @@ impl From<&CachedCredential> for Credential {
Self::CachedCredential(value.clone()) Self::CachedCredential(value.clone())
} }
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn cred_from_gpcallback_cas() {
let auth_data = "globalprotectcallback:cas-as=1&un=xyz@email.com&token=very_long_string";
let cred = Credential::from_gpcallback(auth_data).unwrap();
match cred {
Credential::TokenCredential(token_cred) => {
assert_eq!(token_cred.username(), "xyz@email.com");
assert_eq!(token_cred.token(), "very_long_string");
}
_ => panic!("Expected TokenCredential"),
}
}
#[test]
fn cred_from_gpcallback_non_cas() {
let auth_data = "PGh0bWw+PCEtLSA8c2FtbC1hdXRoLXN0YXR1cz4xPC9zYW1sLWF1dGgtc3RhdHVzPjxwcmVsb2dpbi1jb29raWU+cHJlbG9naW4tY29va2llPC9wcmVsb2dpbi1jb29raWU+PHNhbWwtdXNlcm5hbWU+eHl6QGVtYWlsLmNvbTwvc2FtbC11c2VybmFtZT48c2FtbC1zbG8+bm88L3NhbWwtc2xvPjxzYW1sLVNlc3Npb25Ob3RPbk9yQWZ0ZXI+PC9zYW1sLVNlc3Npb25Ob3RPbk9yQWZ0ZXI+IC0tPjwvaHRtbD4=";
let cred = Credential::from_gpcallback(auth_data).unwrap();
match cred {
Credential::PreloginCookie(cred) => {
assert_eq!(cred.username(), "xyz@email.com");
assert_eq!(cred.prelogin_cookie(), "prelogin-cookie");
}
_ => panic!("Expected PreloginCookieCredential")
}
}
}

View File

@ -42,7 +42,7 @@ impl ClientOs {
} }
} }
#[derive(Debug, Serialize, Deserialize, Type, Default)] #[derive(Debug, Serialize, Deserialize, Type, Default, Clone)]
pub struct GpParams { pub struct GpParams {
is_gateway: bool, is_gateway: bool,
user_agent: String, user_agent: String,
@ -83,6 +83,10 @@ impl GpParams {
self.prefer_default_browser self.prefer_default_browser
} }
pub fn set_prefer_default_browser(&mut self, prefer_default_browser: bool) {
self.prefer_default_browser = prefer_default_browser;
}
pub fn client_os(&self) -> &str { pub fn client_os(&self) -> &str {
self.client_os.as_str() self.client_os.as_str()
} }

View File

@ -1,4 +1,4 @@
use anyhow::bail; use anyhow::{anyhow, bail};
use log::{info, warn}; use log::{info, warn};
use reqwest::{Client, StatusCode}; use reqwest::{Client, StatusCode};
use roxmltree::Document; use roxmltree::Document;
@ -29,6 +29,7 @@ pub struct SamlPrelogin {
is_gateway: bool, is_gateway: bool,
saml_request: String, saml_request: String,
support_default_browser: bool, support_default_browser: bool,
is_cas: bool,
} }
impl SamlPrelogin { impl SamlPrelogin {
@ -43,6 +44,14 @@ impl SamlPrelogin {
pub fn support_default_browser(&self) -> bool { pub fn support_default_browser(&self) -> bool {
self.support_default_browser self.support_default_browser
} }
pub fn is_cas(&self) -> bool {
self.is_cas
}
fn set_is_cas(&mut self, is_cas: bool) {
self.is_cas = is_cas;
}
} }
#[derive(Debug, Serialize, Type, Clone)] #[derive(Debug, Serialize, Type, Clone)]
@ -97,6 +106,29 @@ impl Prelogin {
} }
pub async fn prelogin(portal: &str, gp_params: &GpParams) -> anyhow::Result<Prelogin> { pub async fn prelogin(portal: &str, gp_params: &GpParams) -> anyhow::Result<Prelogin> {
match prelogin_impl(portal, gp_params).await {
Ok(prelogin) => Ok(prelogin),
Err(e) => {
if e.to_string().contains("CAS is not supported by the client") {
info!("CAS authentication detected, retrying with default browser");
let mut gp_params = gp_params.clone();
gp_params.set_prefer_default_browser(true);
let mut prelogin = prelogin_impl(portal, &gp_params).await?;
// Mark the prelogin as CAS
if let Prelogin::Saml(saml) = &mut prelogin {
saml.set_is_cas(true);
}
Ok(prelogin)
} else {
Err(e)
}
}
}
}
pub async fn prelogin_impl(portal: &str, gp_params: &GpParams) -> anyhow::Result<Prelogin> {
let user_agent = gp_params.user_agent(); let user_agent = gp_params.user_agent();
info!("Prelogin with user_agent: {}", user_agent); info!("Prelogin with user_agent: {}", user_agent);
@ -107,12 +139,16 @@ pub async fn prelogin(portal: &str, gp_params: &GpParams) -> anyhow::Result<Prel
let mut params = gp_params.to_params(); let mut params = gp_params.to_params();
params.insert("tmp", "tmp"); params.insert("tmp", "tmp");
// CAS support requires external browser
if gp_params.prefer_default_browser() { if gp_params.prefer_default_browser() {
params.insert("default-browser", "1"); params.insert("default-browser", "1");
params.insert("cas-support", "yes");
} }
params.retain(|k, _| REQUIRED_PARAMS.iter().any(|required_param| required_param == k)); params.retain(|k, _| REQUIRED_PARAMS.iter().any(|required_param| required_param == k));
info!("Prelogin with params: {:?}", params);
let client = Client::builder() let client = Client::builder()
.danger_accept_invalid_certs(gp_params.ignore_tls_errors()) .danger_accept_invalid_certs(gp_params.ignore_tls_errors())
.user_agent(user_agent) .user_agent(user_agent)
@ -124,8 +160,8 @@ pub async fn prelogin(portal: &str, gp_params: &GpParams) -> anyhow::Result<Prel
.send() .send()
.await .await
.map_err(|e| anyhow::anyhow!(PortalError::NetworkError(e.to_string())))?; .map_err(|e| anyhow::anyhow!(PortalError::NetworkError(e.to_string())))?;
let status = res.status();
let status = res.status();
if status == StatusCode::NOT_FOUND { if status == StatusCode::NOT_FOUND {
bail!(PortalError::PreloginError("Prelogin endpoint not found".to_string())) bail!(PortalError::PreloginError("Prelogin endpoint not found".to_string()))
} }
@ -177,6 +213,7 @@ fn parse_res_xml(res_xml: String, is_gateway: bool) -> anyhow::Result<Prelogin>
is_gateway, is_gateway,
saml_request, saml_request,
support_default_browser, support_default_browser,
is_cas: false,
}; };
return Ok(Prelogin::Saml(saml_prelogin)); return Ok(Prelogin::Saml(saml_prelogin));
@ -196,8 +233,8 @@ fn parse_res_xml(res_xml: String, is_gateway: bool) -> anyhow::Result<Prelogin>
label_password: label_password.unwrap(), label_password: label_password.unwrap(),
}; };
return Ok(Prelogin::Standard(standard_prelogin)); Ok(Prelogin::Standard(standard_prelogin))
} else {
Err(anyhow!("Invalid prelogin response"))
} }
bail!("Invalid prelogin response");
} }