more refactor

This commit is contained in:
Kevin Yue
2023-11-13 10:05:06 +08:00
parent 0b4829a610
commit bf2d327687
20 changed files with 965 additions and 64 deletions

View File

@@ -3,18 +3,19 @@ name = "gpcommon"
version.workspace = true
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
anyhow.workspace = true
async-trait.workspace = true
base64.workspace = true
bytes.workspace = true
configparser.workspace = true
data-encoding.workspace = true
is_executable.workspace = true
lexopt.workspace = true
log.workspace = true
reqwest.workspace = true
ring.workspace = true
roxmltree.workspace = true
serde_json.workspace = true
serde.workspace = true
shlex.workspace = true
@@ -23,5 +24,8 @@ thiserror.workspace = true
tokio-util.workspace = true
tokio.workspace = true
[dev-dependencies]
mockito.workspace = true
[build-dependencies]
cc = "1.0"

View File

@@ -17,6 +17,7 @@ mod response;
pub mod server;
mod vpn;
mod writer;
pub mod portal;
pub(crate) use request::Request;
pub(crate) use request::RequestPool;

201
gpcommon/src/portal.rs Normal file
View File

@@ -0,0 +1,201 @@
use anyhow::bail;
use base64::{engine::general_purpose, Engine};
use roxmltree::Document;
#[derive(Debug, Clone)]
pub struct Portal {
address: String,
}
pub enum PortalCredential {
Standard(String, String),
Prelogin(String),
Cached(String, String),
}
#[derive(Debug)]
pub struct SamlPrelogin {
pub region: String,
pub method: String,
pub request: String,
}
#[derive(Debug)]
pub struct StandardPrelogin {
pub region: String,
pub label_username: String,
pub label_password: String,
pub auth_message: String,
}
#[derive(Debug)]
pub enum Prelogin {
Saml(SamlPrelogin),
Standard(StandardPrelogin),
}
impl Portal {
pub fn new(address: &str) -> Self {
Self {
address: address.to_string(),
}
}
pub async fn prelogin(&self) -> anyhow::Result<Prelogin> {
let prelogin_url = format!("{}/global-protect/prelogin.esp", self.address);
let client = reqwest::Client::builder()
.user_agent("PAN GlobalProtect")
.build()?;
let res_xml = client.get(&prelogin_url).send().await?.text().await?;
let doc = Document::parse(&res_xml)?;
let status = get_child_text(&doc, "status")
.ok_or_else(|| anyhow::anyhow!("Prelogin response does not contain status element"))?;
// Check the status of the prelogin response
if status.to_uppercase() != "SUCCESS" {
let msg = get_child_text(&doc, "msg").unwrap_or(String::from("Unknown error"));
bail!("Prelogin failed: {}", msg)
}
let region = get_child_text(&doc, "region")
.ok_or_else(|| anyhow::anyhow!("Prelogin response does not contain region element"))?;
let saml_method = get_child_text(&doc, "saml-auth-method");
let saml_request = get_child_text(&doc, "saml-request");
// Check if the prelogin response is SAML
if saml_method.is_some() && saml_request.is_some() {
return Ok(Prelogin::Saml(SamlPrelogin {
region,
method: saml_method.unwrap(),
request: base64_decode(&saml_request.unwrap())?,
}));
}
let label_username = get_child_text(&doc, "username-label");
let label_password = get_child_text(&doc, "password-label");
// Check if the prelogin response is standard login
if label_username.is_some() && label_password.is_some() {
let auth_message = get_child_text(&doc, "authentication-message")
.unwrap_or(String::from("Please enter the login credentials"));
return Ok(Prelogin::Standard(StandardPrelogin {
region,
auth_message,
label_username: label_username.unwrap(),
label_password: label_password.unwrap(),
}));
}
bail!("Unknown prelogin response");
}
pub fn retrieve_config(&self, credential: &PortalCredential) {
todo!()
}
}
fn get_child_text(doc: &Document, name: &str) -> Option<String> {
let node = doc.descendants().find(|n| n.has_tag_name(name))?;
node.text().map(|s| s.to_string())
}
fn base64_decode(s: &str) -> anyhow::Result<String> {
let engine = general_purpose::STANDARD;
let decoded = engine.decode(s)?;
Ok(String::from_utf8(decoded)?)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_new() {
let portal = Portal::new("vpn.example.com");
assert_eq!(portal.address, "vpn.example.com");
}
#[tokio::test]
async fn test_prelogin_saml() {
let mut s = mockito::Server::new();
let mock = s
.mock("GET", "/global-protect/prelogin.esp")
.with_body(
r#"<?xml version="1.0" encoding="UTF-8" ?>
<prelogin-response>
<status>Success</status>
<ccusername></ccusername>
<autosubmit>false</autosubmit>
<msg></msg>
<newmsg></newmsg>
<authentication-message>Enter login credentials</authentication-message>
<username-label>Username</username-label>
<password-label>Password</password-label>
<panos-version>1</panos-version>
<saml-default-browser>yes</saml-default-browser>
<cas-auth></cas-auth>
<saml-auth-status>0</saml-auth-status>
<saml-auth-method>REDIRECT</saml-auth-method>
<saml-request-timeout>600</saml-request-timeout>
<saml-request-id>0</saml-request-id>
<saml-request>U0FNTFJlcXVlc3Q9eHh4</saml-request>
<auth-api>no</auth-api><region>CN</region>
</prelogin-response>"#,
)
.create();
let url = s.url();
let portal = Portal::new(&url);
let prelogin = portal.prelogin().await.unwrap();
let saml = match prelogin {
Prelogin::Saml(saml) => saml,
_ => panic!("Prelogin is not SAML"),
};
mock.assert();
assert!(saml.method == "REDIRECT");
assert!(saml.request.contains("SAMLRequest"));
assert!(saml.region == "CN")
}
#[tokio::test]
async fn test_prelogin_standard() {
let mut s = mockito::Server::new();
let mock = s
.mock("GET", "/global-protect/prelogin.esp")
.with_body(
r#"<?xml version="1.0" encoding="UTF-8" ?>
<prelogin-response>
<status>Success</status>
<ccusername></ccusername>
<autosubmit>false</autosubmit>
<msg></msg>
<newmsg></newmsg>
<authentication-message>Enter login credentials</authentication-message>
<username-label>Username</username-label>
<password-label>Password</password-label>
<panos-version>1</panos-version>
<saml-default-browser>yes</saml-default-browser><auth-api>no</auth-api><region>US</region>
</prelogin-response>"#,
)
.create();
let url = s.url();
let portal = Portal::new(&url);
let prelogin = portal.prelogin().await.unwrap();
let standard = match prelogin {
Prelogin::Standard(standard) => standard,
_ => panic!("Prelogin is not standard"),
};
mock.assert();
assert!(standard.label_username == "Username");
assert!(standard.label_password == "Password");
assert!(standard.auth_message == "Enter login credentials");
assert!(standard.region == "US");
}
#[tokio::test]
async fn test_retrieve_config_standard_credential() {}
}