refactor: refine the saml login

This commit is contained in:
Kevin Yue 2023-05-28 16:26:02 +08:00
parent 54d3bb8a92
commit 89eb42ceac
3 changed files with 79 additions and 82 deletions

View File

@ -2,6 +2,7 @@
"cSpell.words": [
"bindgen",
"clientos",
"gpcommon",
"jnlp",
"openconnect",
"prelogin",

View File

@ -1,4 +1,4 @@
use log::{debug, warn};
use log::{debug, info, warn};
use regex::Regex;
use serde::{Deserialize, Serialize};
use std::{sync::Arc, time::Duration};
@ -14,6 +14,11 @@ use webkit2gtk::{
const AUTH_WINDOW_LABEL: &str = "auth_window";
const AUTH_ERROR_EVENT: &str = "auth-error";
const AUTH_REQUEST_EVENT: &str = "auth-request";
// Timeout to show the window if the token is not found in the response
// It will be cancelled if the token is found in the response
const SHOW_WINDOW_TIMEOUT: u64 = 3;
// A fallback timeout to show the window in case the authentication process takes longer than expected
const FALLBACK_SHOW_WINDOW_TIMEOUT: u64 = 15;
#[derive(Debug, Clone, Deserialize)]
pub(crate) enum SamlBinding {
@ -96,18 +101,11 @@ pub(crate) async fn saml_login(
let (event_tx, event_rx) = mpsc::channel::<AuthEvent>(8);
let window = build_window(app_handle, ua)?;
setup_webview(&window, event_tx.clone())?;
let handler_id = setup_window(&window, event_tx);
let handler = setup_window(&window, event_tx);
match process(&window, event_rx, auth_request).await {
Ok(auth_data) => {
window.unlisten(handler_id);
Ok(auth_data)
}
Err(err) => {
window.unlisten(handler_id);
Err(err)
}
}
let result = process(&window, auth_request, event_rx).await;
window.unlisten(handler);
result
}
fn build_window(app_handle: &AppHandle, ua: &str) -> tauri::Result<Window> {
@ -122,6 +120,7 @@ fn build_window(app_handle: &AppHandle, ua: &str) -> tauri::Result<Window> {
.build()
}
// Setup webview events
fn setup_webview(window: &Window, event_tx: mpsc::Sender<AuthEvent>) -> tauri::Result<()> {
window.with_webview(move |wv| {
let wv = wv.inner();
@ -129,25 +128,24 @@ fn setup_webview(window: &Window, event_tx: mpsc::Sender<AuthEvent>) -> tauri::R
wv.connect_load_changed(move |wv, event| {
if LoadEvent::Finished != event {
debug!("Skipping load event: {:?}", event);
return;
}
let uri = wv.uri().unwrap_or("".into());
// Empty URI indicates that an error occurred
if uri.is_empty() {
warn!("Empty URI");
warn!("Empty URI loaded");
if let Err(err) = event_tx.blocking_send(AuthEvent::Error(AuthError::TokenInvalid))
{
println!("Error sending event: {}", err);
warn!("Error sending event: {}", err);
}
return;
}
// TODO, redact URI
debug!("Loaded URI: {}", uri);
if let Some(main_res) = wv.main_resource() {
// AuthDataParser::new(&window_tx_clone).parse(&main_res);
parse_auth_data(&main_res, event_tx.clone());
} else {
warn!("No main_resource");
@ -155,7 +153,7 @@ fn setup_webview(window: &Window, event_tx: mpsc::Sender<AuthEvent>) -> tauri::R
});
wv.connect_load_failed(|_wv, event, err_msg, err| {
println!("Load failed: {:?}, {}, {:?}", event, err_msg, err);
warn!("Load failed: {:?}, {}, {:?}", event, err_msg, err);
false
});
})
@ -167,13 +165,11 @@ fn setup_window(window: &Window, event_tx: mpsc::Sender<AuthEvent>) -> EventHand
if let CloseRequested { api, .. } = event {
api.prevent_close();
if let Err(err) = event_tx_clone.blocking_send(AuthEvent::Cancel) {
println!("Error sending event: {}", err)
warn!("Error sending event: {}", err)
}
}
});
window.open_devtools();
window.listen_global(AUTH_REQUEST_EVENT, move |event| {
if let Ok(payload) = TryInto::<AuthRequest>::try_into(event.payload()) {
debug!("---------Received auth request");
@ -190,15 +186,15 @@ fn setup_window(window: &Window, event_tx: mpsc::Sender<AuthEvent>) -> EventHand
async fn process(
window: &Window,
event_rx: mpsc::Receiver<AuthEvent>,
auth_request: AuthRequest,
event_rx: mpsc::Receiver<AuthEvent>,
) -> tauri::Result<Option<AuthData>> {
process_request(window, auth_request)?;
let (close_tx, close_rx) = mpsc::channel::<()>(1);
tokio::spawn(show_window_after_timeout(window.clone(), close_rx));
process_auth_event(&window, event_rx, close_tx).await
let handle = tokio::spawn(show_window_after_timeout(window.clone()));
let auth_data = process_auth_event(&window, event_rx).await;
handle.abort();
Ok(auth_data)
}
fn process_request(window: &Window, auth_request: AuthRequest) -> tauri::Result<()> {
@ -211,63 +207,50 @@ fn process_request(window: &Window, auth_request: AuthRequest) -> tauri::Result<
// Load SAML request as HTML if POST binding is used
wv.load_html(&saml_request, None);
} else {
println!("Redirecting to SAML request URL: {}", saml_request);
// Redirect to SAML request URL if REDIRECT binding is used
wv.load_uri(&saml_request);
}
})
}
async fn show_window_after_timeout(window: Window, mut close_rx: mpsc::Receiver<()>) {
// Show the window after 10 seconds
let duration = Duration::from_secs(10);
if let Err(_) = timeout(duration, close_rx.recv()).await {
println!("Final show window");
async fn show_window_after_timeout(window: Window) {
tokio::time::sleep(Duration::from_secs(FALLBACK_SHOW_WINDOW_TIMEOUT)).await;
info!("Showing window after timeout expired: {} seconds", FALLBACK_SHOW_WINDOW_TIMEOUT);
show_window(&window);
} else {
println!("Window closed, cancel the final show window");
}
}
async fn process_auth_event(
window: &Window,
mut event_rx: mpsc::Receiver<AuthEvent>,
close_tx: mpsc::Sender<()>,
) -> tauri::Result<Option<AuthData>> {
) -> Option<AuthData> {
let (cancel_timeout_tx, cancel_timeout_rx) = mpsc::channel::<()>(1);
let cancel_timeout_rx = Arc::new(Mutex::new(cancel_timeout_rx));
async fn close_window(window: &Window, close_tx: mpsc::Sender<()>) {
if let Err(err) = window.close() {
println!("Error closing window: {}", err);
}
if let Err(err) = close_tx.send(()).await {
warn!("Error sending the close event: {:?}", err);
}
}
loop {
if let Some(auth_event) = event_rx.recv().await {
match auth_event {
AuthEvent::Request(auth_request) => {
println!("Got auth request: {:?}", auth_request);
process_request(&window, auth_request)?;
info!("Got auth request from auth-request event, processing");
if let Err(err) = process_request(&window, auth_request) {
warn!("Error processing auth request: {}", err);
}
}
AuthEvent::Success(auth_data) => {
close_window(window, close_tx).await;
return Ok(Some(auth_data));
close_window(window);
return Some(auth_data);
}
AuthEvent::Cancel => {
close_window(window, close_tx).await;
return Ok(None);
close_window(window);
return None;
}
AuthEvent::Error(AuthError::TokenInvalid) => {
// Found the invalid token, means that user is authenticated, keep retrying and no need to show the window
warn!("Found invalid auth data, retrying");
if let Err(err) = cancel_timeout_tx.send(()).await {
println!("Error sending event: {}", err);
warn!("Error sending cancel timeout: {}", err);
}
if let Err(err) =
window.emit_all(AUTH_ERROR_EVENT, "Invalid SAML result".to_string())
{
// Send the error event to the outside, so that we can retry it when receiving the auth-request event
if let Err(err) = window.emit_all(AUTH_ERROR_EVENT, ()) {
warn!("Error emitting auth-error event: {:?}", err);
}
}
@ -280,31 +263,34 @@ async fn process_auth_event(
}
}
/// Tokens not found means that the page might need the user interaction to login,
/// we should show the window after a short timeout, it will be cancelled if the
/// token is found in the response, no matter it's valid or not.
async fn handle_token_not_found(window: Window, cancel_timeout_rx: Arc<Mutex<mpsc::Receiver<()>>>) {
// Tokens not found, show the window in 5 seconds
match cancel_timeout_rx.try_lock() {
Ok(mut cancel_timeout_rx) => {
println!("Scheduling timeout");
let duration = Duration::from_secs(5);
debug!("Scheduling timeout to show window");
let duration = Duration::from_secs(SHOW_WINDOW_TIMEOUT);
if let Err(_) = timeout(duration, cancel_timeout_rx.recv()).await {
println!("Show window after timeout");
info!("Timeout expired, showing window");
show_window(&window);
} else {
println!("Cancel timeout");
debug!("Showing window timeout cancelled");
}
}
Err(_) => {
println!("Timeout already scheduled");
debug!("Timeout already scheduled, skipping");
}
}
}
/// Parse the authentication data from the response headers or HTML content
/// and send it to the event channel
fn parse_auth_data(main_res: &WebResource, event_tx: mpsc::Sender<AuthEvent>) {
if let Some(response) = main_res.response() {
if let Some(saml_result) = read_auth_data_from_response(&response) {
// Got SAML result from HTTP headers
println!("SAML result: {:?}", saml_result);
send_auth_data(&event_tx, saml_result);
if let Some(auth_data) = read_auth_data_from_response(&response) {
info!("Got auth data from HTTP headers: {:?}", auth_data);
send_auth_data(&event_tx, auth_data);
return;
}
}
@ -314,15 +300,14 @@ fn parse_auth_data(main_res: &WebResource, event_tx: mpsc::Sender<AuthEvent>) {
if let Ok(data) = data {
let html = String::from_utf8_lossy(&data);
match read_auth_data_from_html(&html) {
Ok(saml_result) => {
// Got SAML result from HTML
println!("SAML result: {:?}", saml_result);
send_auth_data(&event_tx, saml_result);
Ok(auth_data) => {
info!("Got auth data from HTML: {:?}", auth_data);
send_auth_data(&event_tx, auth_data);
}
Err(err) => {
println!("Auth error: {:?}", err);
debug!("Error reading auth data from HTML: {:?}", err);
if let Err(err) = event_tx.blocking_send(AuthEvent::Error(err)) {
println!("Error sending event: {}", err)
warn!("Error sending event: {}", err)
}
}
}
@ -330,22 +315,24 @@ fn parse_auth_data(main_res: &WebResource, event_tx: mpsc::Sender<AuthEvent>) {
});
}
/// Read the authentication data from the response headers
fn read_auth_data_from_response(response: &webkit2gtk::URIResponse) -> Option<AuthData> {
response.http_headers().and_then(|mut headers| {
let saml_result = AuthData::new(
let auth_data = AuthData::new(
headers.get("saml-username").map(GString::into),
headers.get("prelogin-cookie").map(GString::into),
headers.get("portal-userauthcookie").map(GString::into),
);
if saml_result.check() {
Some(saml_result)
if auth_data.check() {
Some(auth_data)
} else {
None
}
})
}
/// Read the authentication data from the HTML content
fn read_auth_data_from_html(html: &str) -> Result<AuthData, AuthError> {
let saml_auth_status = parse_xml_tag(html, "saml-auth-status");
@ -356,6 +343,7 @@ fn read_auth_data_from_html(html: &str) -> Result<AuthData, AuthError> {
}
}
/// Extract the authentication data from the HTML content
fn extract_auth_data(html: &str) -> Option<AuthData> {
let auth_data = AuthData::new(
parse_xml_tag(html, "saml-username"),
@ -377,21 +365,27 @@ fn parse_xml_tag(html: &str, tag: &str) -> Option<String> {
.map(|m| m.as_str().to_string())
}
fn send_auth_data(event_tx: &mpsc::Sender<AuthEvent>, saml_result: AuthData) {
if let Err(err) = event_tx.blocking_send(AuthEvent::Success(saml_result)) {
println!("Error sending event: {}", err)
fn send_auth_data(event_tx: &mpsc::Sender<AuthEvent>, auth_data: AuthData) {
if let Err(err) = event_tx.blocking_send(AuthEvent::Success(auth_data)) {
warn!("Error sending event: {}", err)
}
}
fn show_window(window: &Window) {
match window.is_visible() {
Ok(true) => {
println!("Window is already visible");
debug!("Window is already visible");
}
_ => {
if let Err(err) = window.show() {
println!("Error showing window: {}", err);
warn!("Error showing window: {}", err);
}
}
}
}
fn close_window(window: &Window) {
if let Err(err) = window.close() {
warn!("Error closing window: {}", err);
}
}

View File

@ -6,6 +6,7 @@
use auth::{AuthData, AuthRequest, SamlBinding};
use env_logger::Env;
use gpcommon::{Client, ServerApiError, VpnStatus};
use log::warn;
use serde::Serialize;
use std::sync::Arc;
use tauri::{AppHandle, Manager, State};
@ -56,11 +57,11 @@ fn setup(app: &mut tauri::App) -> Result<(), Box<dyn std::error::Error>> {
let _ = client_clone.subscribe_status(move |status| {
let payload = StatusPayload { status };
if let Err(err) = app_handle.emit_all("vpn-status-received", payload) {
println!("Error emitting event: {}", err);
warn!("Error emitting event: {}", err);
}
});
// let _ = client_clone.run().await;
let _ = client_clone.run().await;
});
app.manage(client);
@ -77,6 +78,7 @@ fn main() {
LogTarget::LogDir,
LogTarget::Stdout, /*LogTarget::Webview*/
])
.level(log::LevelFilter::Info)
.build(),
)
.setup(setup)