From 68227b64a26fc126cfa5e88230f2e1803036fb64 Mon Sep 17 00:00:00 2001 From: Kevin Yue Date: Sun, 7 Jul 2024 18:15:12 +0800 Subject: [PATCH] refactor: improve the XML parsing --- crates/gpapi/src/gateway/parse_gateways.rs | 39 ++++------- crates/gpapi/src/portal/config.rs | 79 ++++++++++++---------- crates/gpapi/src/portal/prelogin.rs | 55 +++++++++------ crates/gpapi/src/utils/xml.rs | 19 ++++-- 4 files changed, 106 insertions(+), 86 deletions(-) diff --git a/crates/gpapi/src/gateway/parse_gateways.rs b/crates/gpapi/src/gateway/parse_gateways.rs index 2ce2675..4810749 100644 --- a/crates/gpapi/src/gateway/parse_gateways.rs +++ b/crates/gpapi/src/gateway/parse_gateways.rs @@ -1,52 +1,41 @@ -use roxmltree::Document; +use roxmltree::Node; + +use crate::utils::xml::NodeExt; use super::{Gateway, PriorityRule}; -pub(crate) fn parse_gateways(doc: &Document, external: bool) -> Option<Vec<Gateway>> { - let node_gateways = doc.descendants().find(|n| n.has_tag_name("gateways"))?; +pub(crate) fn parse_gateways(node: &Node, use_internal: bool) -> Option<Vec<Gateway>> { + let node_gateways = node.find_child("gateways")?; - // if external flag is set, look for external gateways, otherwise look for internal gateways - let kind_gateways = if external { - node_gateways.descendants().find(|n| n.has_tag_name("external"))? + let list_gateway = if use_internal { + node_gateways.find_child("internal")?.find_child("list")? } else { - node_gateways.descendants().find(|n| n.has_tag_name("internal"))? + node_gateways.find_child("external")?.find_child("list")? 
}; - let list_gateway = kind_gateways.descendants().find(|n| n.has_tag_name("list"))?; - let gateways = list_gateway .children() .filter_map(|gateway_item| { if !gateway_item.has_tag_name("entry") { return None; } - let address = gateway_item.attribute("name").unwrap_or("").to_string(); - let name = gateway_item - .children() - .find(|n| n.has_tag_name("description")) - .and_then(|n| n.text()) - .unwrap_or("") - .to_string(); + let address = gateway_item.attribute("name").unwrap_or_default().to_string(); + let name = gateway_item.child_text("description").unwrap_or_default().to_string(); let priority = gateway_item - .children() - .find(|n| n.has_tag_name("priority")) - .and_then(|n| n.text()) + .child_text("priority") .and_then(|s| s.parse().ok()) .unwrap_or(u32::MAX); let priority_rules = gateway_item - .children() - .find(|n| n.has_tag_name("priority-rule")) + .find_child("priority-rule") .map(|n| { n.children() .filter_map(|n| { if !n.has_tag_name("entry") { return None; } - let name = n.attribute("name").unwrap_or("").to_string(); + let name = n.attribute("name").unwrap_or_default().to_string(); let priority: u32 = n - .children() - .find(|n| n.has_tag_name("priority")) - .and_then(|n| n.text()) + .child_text("priority") .and_then(|s| s.parse().ok()) .unwrap_or(u32::MAX); diff --git a/crates/gpapi/src/portal/config.rs b/crates/gpapi/src/portal/config.rs index 94d6b08..4855f51 100644 --- a/crates/gpapi/src/portal/config.rs +++ b/crates/gpapi/src/portal/config.rs @@ -2,7 +2,7 @@ use anyhow::bail; use dns_lookup::lookup_addr; use log::{info, warn}; use reqwest::{Client, StatusCode}; -use roxmltree::Document; +use roxmltree::{Document, Node}; use serde::Serialize; use specta::Type; @@ -11,7 +11,7 @@ use crate::{ error::PortalError, gateway::{parse_gateways, Gateway}, gp_params::GpParams, - utils::{normalize_server, parse_gp_response, remove_url_scheme, xml}, + utils::{normalize_server, parse_gp_response, remove_url_scheme, xml::NodeExt}, }; #[derive(Debug, 
Serialize, Type)] @@ -125,46 +125,22 @@ pub async fn retrieve_config(portal: &str, cred: &Credential, gp_params: &GpPara } let doc = Document::parse(&res_xml).map_err(|e| PortalError::ConfigError(e.to_string()))?; + let root = doc.root(); - let mut external_gateway = true; - + let mut use_internal_gateways = false; // Perform DNS lookup, set flag to internal or external, and pass it to parse_gateways - if let Some(_) = xml::get_child_text(&doc, "internal-host-detection") { - let ip_info = [ - (xml::get_child_text(&doc, "ip-address"), xml::get_child_text(&doc, "host")), - (xml::get_child_text(&doc, "ipv6-address"), xml::get_child_text(&doc, "ipv6-host")), - ]; - - info!("internal-host-detection returned, performing DNS lookup"); - - for (ip_address, host) in ip_info.iter() { - if let (Some(ip_address), Some(host)) = (ip_address.as_deref(), host.as_deref()) { - if !ip_address.is_empty() && !host.is_empty() { - match ip_address.parse::() { - Ok(ip) => match lookup_addr(&ip) { - Ok(host_lookup) if host_lookup == *host => { - external_gateway = false; - break; - } - Ok(_) => (), - Err(err) => warn!("DNS lookup failed for {}: {}", ip_address, err), - }, - Err(err) => warn!("Invalid IP address {}: {}", ip_address, err), - } - } - } - } + if let Some(ihd_node) = root.find_child("internal-host-detection") { + use_internal_gateways = internal_host_detect(&ihd_node) } - - let mut gateways = parse_gateways(&doc, external_gateway).unwrap_or_else(|| { + let mut gateways = parse_gateways(&root, use_internal_gateways).unwrap_or_else(|| { info!("No gateways found in portal config"); vec![] }); - let user_auth_cookie = xml::get_child_text(&doc, "portal-userauthcookie").unwrap_or_default(); - let prelogon_user_auth_cookie = xml::get_child_text(&doc, "portal-prelogonuserauthcookie").unwrap_or_default(); - let config_digest = xml::get_child_text(&doc, "config-digest"); + let user_auth_cookie = root.child_text("portal-userauthcookie").unwrap_or_default(); + let prelogon_user_auth_cookie 
= root.child_text("portal-prelogonuserauthcookie").unwrap_or_default(); + let config_digest = root.child_text("config-digest"); if gateways.is_empty() { gateways.push(Gateway::new(server.to_string(), server.to_string())); @@ -172,9 +148,40 @@ pub async fn retrieve_config(portal: &str, cred: &Credential, gp_params: &GpPara Ok(PortalConfig { portal: server.to_string(), - auth_cookie: AuthCookieCredential::new(cred.username(), &user_auth_cookie, &prelogon_user_auth_cookie), + auth_cookie: AuthCookieCredential::new(cred.username(), user_auth_cookie, prelogon_user_auth_cookie), config_cred: cred.clone(), gateways, - config_digest, + config_digest: config_digest.map(|s| s.to_string()), }) } + +fn internal_host_detect(node: &Node) -> bool { + let ip_info = [ + (node.child_text("ip-address"), node.child_text("host")), + (node.child_text("ipv6-address"), node.child_text("ipv6-host")), + ]; + + info!("Found internal-host-detection, performing DNS lookup"); + + for (ip_address, host) in ip_info.iter() { + if let (Some(ip_address), Some(host)) = (ip_address.as_deref(), host.as_deref()) { + if !ip_address.is_empty() && !host.is_empty() { + match ip_address.parse::() { + Ok(ip) => match lookup_addr(&ip) { + Ok(host_lookup) if host_lookup == *host => return true, + Ok(host_lookup) => { + info!( + "rDNS lookup for {} returned {}, expected {}", + ip_address, host_lookup, host + ); + } + Err(err) => warn!("DNS lookup failed for {}: {}", ip_address, err), + }, + Err(err) => warn!("Invalid IP address {}: {}", ip_address, err), + } + } + } + } + + false +} diff --git a/crates/gpapi/src/portal/prelogin.rs b/crates/gpapi/src/portal/prelogin.rs index d1c4573..9748e98 100644 --- a/crates/gpapi/src/portal/prelogin.rs +++ b/crates/gpapi/src/portal/prelogin.rs @@ -8,7 +8,7 @@ use specta::Type; use crate::{ error::PortalError, gp_params::GpParams, - utils::{base64, normalize_server, parse_gp_response, xml}, + utils::{base64, normalize_server, parse_gp_response, xml::NodeExt}, }; const 
REQUIRED_PARAMS: [&str; 8] = [ @@ -146,26 +146,31 @@ pub async fn prelogin(portal: &str, gp_params: &GpParams) -> anyhow::Result anyhow::Result { let doc = Document::parse(res_xml)?; + let root = doc.root(); - let status = xml::get_child_text(&doc, "status") + let status = root + .child_text("status") .ok_or_else(|| anyhow::anyhow!("Prelogin response does not contain status element"))?; // Check the status of the prelogin response if status.to_uppercase() != "SUCCESS" { - let msg = xml::get_child_text(&doc, "msg").unwrap_or(String::from("Unknown error")); + let msg = root.child_text("msg").unwrap_or("Unknown error"); bail!("{}", msg) } - let region = xml::get_child_text(&doc, "region").unwrap_or_else(|| { - info!("Prelogin response does not contain region element"); - String::from("Unknown") - }); + let region = root + .child_text("region") + .unwrap_or_else(|| { + info!("Prelogin response does not contain region element"); + "Unknown" + }) + .to_string(); - let saml_method = xml::get_child_text(&doc, "saml-auth-method"); - let saml_request = xml::get_child_text(&doc, "saml-request"); - let saml_default_browser = xml::get_child_text(&doc, "saml-default-browser"); + let saml_method = root.child_text("saml-auth-method"); + let saml_request = root.child_text("saml-request"); + let saml_default_browser = root.child_text("saml-default-browser"); // Check if the prelogin response is SAML if saml_method.is_some() && saml_request.is_some() { - let saml_request = base64::decode_to_string(&saml_request.unwrap())?; + let saml_request = base64::decode_to_string(saml_request.unwrap())?; let support_default_browser = saml_default_browser.map(|s| s.to_lowercase() == "yes").unwrap_or(false); let saml_prelogin = SamlPrelogin { @@ -178,17 +183,25 @@ fn parse_res_xml(res_xml: &str, is_gateway: bool) -> anyhow::Result { return Ok(Prelogin::Saml(saml_prelogin)); } - let label_username = xml::get_child_text(&doc, "username-label").unwrap_or_else(|| { - info!("Username label has no 
value, using default"); - String::from("Username") - }); - let label_password = xml::get_child_text(&doc, "password-label").unwrap_or_else(|| { - info!("Password label has no value, using default"); - String::from("Password") - }); + let label_username = root + .child_text("username-label") + .unwrap_or_else(|| { + info!("Username label has no value, using default"); + "Username" + }) + .to_string(); + let label_password = root + .child_text("password-label") + .unwrap_or_else(|| { + info!("Password label has no value, using default"); + "Password" + }) + .to_string(); - let auth_message = - xml::get_child_text(&doc, "authentication-message").unwrap_or(String::from("Please enter the login credentials")); + let auth_message = root + .child_text("authentication-message") + .unwrap_or("Please enter the login credentials") + .to_string(); let standard_prelogin = StandardPrelogin { region, is_gateway, diff --git a/crates/gpapi/src/utils/xml.rs b/crates/gpapi/src/utils/xml.rs index 674e866..c1f9699 100644 --- a/crates/gpapi/src/utils/xml.rs +++ b/crates/gpapi/src/utils/xml.rs @@ -1,6 +1,17 @@ -use roxmltree::Document; +use roxmltree::Node; -pub(crate) fn get_child_text(doc: &Document, name: &str) -> Option<String> { - let node = doc.descendants().find(|n| n.has_tag_name(name))?; - node.text().map(|s| s.to_string()) +pub(crate) trait NodeExt<'a> { + fn find_child(&self, name: &str) -> Option<Node<'a, 'a>>; + fn child_text(&self, name: &str) -> Option<&'a str>; +} + +impl<'a> NodeExt<'a> for Node<'a, 'a> { + fn find_child(&self, name: &str) -> Option<Node<'a, 'a>> { + self.children().find(|n| n.has_tag_name(name)) + } + + fn child_text(&self, name: &str) -> Option<&'a str> { + let node = self.find_child(name)?; + node.text() + } }