refactor: improve workflow

This commit is contained in:
Kevin Yue 2023-06-11 15:55:47 +08:00
parent 1af21432d4
commit 15e798c1e7
39 changed files with 950 additions and 683 deletions

View File

@ -3,8 +3,10 @@
"authcookie", "authcookie",
"bindgen", "bindgen",
"clickaway", "clickaway",
"clientgpversion",
"clientos", "clientos",
"gpcommon", "gpcommon",
"gpservice",
"Immer", "Immer",
"jnlp", "jnlp",
"oneshot", "oneshot",

29
Cargo.lock generated
View File

@ -1106,16 +1106,6 @@ dependencies = [
"system-deps 6.0.3", "system-deps 6.0.3",
] ]
[[package]]
name = "gpauth"
version = "0.1.0"
dependencies = [
"regex",
"tokio",
"webkit2gtk",
"wry",
]
[[package]] [[package]]
name = "gpclient" name = "gpclient"
version = "0.1.0" version = "0.1.0"
@ -1828,6 +1818,16 @@ version = "1.17.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6f61fba1741ea2b3d6a1e3178721804bb716a68a6aeba1149b5d52e3d464ea66" checksum = "6f61fba1741ea2b3d6a1e3178721804bb716a68a6aeba1149b5d52e3d464ea66"
[[package]]
name = "open"
version = "3.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2078c0039e6a54a0c42c28faa984e115fb4c2d5bf2208f77d1961002df8576f8"
dependencies = [
"pathdiff",
"windows-sys 0.42.0",
]
[[package]] [[package]]
name = "openssl" name = "openssl"
version = "0.10.45" version = "0.10.45"
@ -1927,6 +1927,12 @@ dependencies = [
"windows-sys 0.45.0", "windows-sys 0.45.0",
] ]
[[package]]
name = "pathdiff"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8835116a5c179084a830efb3adc117ab007512b535bc1a21c991d3b32a6b44dd"
[[package]] [[package]]
name = "percent-encoding" name = "percent-encoding"
version = "2.2.0" version = "2.2.0"
@ -2878,9 +2884,11 @@ dependencies = [
"ignore", "ignore",
"objc", "objc",
"once_cell", "once_cell",
"open",
"percent-encoding", "percent-encoding",
"rand 0.8.5", "rand 0.8.5",
"raw-window-handle", "raw-window-handle",
"regex",
"semver 1.0.16", "semver 1.0.16",
"serde", "serde",
"serde_json", "serde_json",
@ -2934,6 +2942,7 @@ dependencies = [
"png", "png",
"proc-macro2", "proc-macro2",
"quote", "quote",
"regex",
"semver 1.0.16", "semver 1.0.16",
"serde", "serde",
"serde_json", "serde_json",

View File

@ -4,7 +4,6 @@ members = [
"gpcommon", "gpcommon",
"gpclient", "gpclient",
"gpservice", "gpservice",
"gpauth",
"gpgui/src-tauri" "gpgui/src-tauri"
] ]

View File

@ -1,12 +0,0 @@
[package]
name = "gpauth"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
wry = "0.24.3"
webkit2gtk = "0.18.2"
tokio = { version = "1.14", features = ["full"] }
regex="1"

View File

@ -1,54 +0,0 @@
use crate::{
duplex::duplex,
saml::{SamlAuth, SamlBinding, SamlOptions},
DuplexStreamHandle,
};
use std::sync::Arc;
use tokio::sync::Mutex;
#[derive(Debug)]
pub struct AuthService {
server: DuplexStreamHandle,
client: Arc<Mutex<DuplexStreamHandle>>,
saml_auth: Arc<SamlAuth>,
}
impl Default for AuthService {
fn default() -> Self {
let (client, server) = duplex(4096);
Self {
client: Arc::new(Mutex::new(client)),
server,
saml_auth: Default::default(),
}
}
}
impl AuthService {
pub async fn run(&mut self) {
loop {
println!("Server waiting for data");
match self.server.read().await {
Ok(data) => {
println!("Server received: {}", data);
let target = String::from("https://login.microsoftonline.com/901c038b-4638-4259-b115-c1753c7735aa/saml2?SAMLRequest=lVLBbsIwDP2VKveSNGlaiGilDg5DYlpFux12mdIQIFKbdEmKtr8fhaGxC9Lkk%2BXnZ79nzx3v2p4Vgz%2FojfwYpPPBZ9dqx86FDAxWM8OdckzzTjrmBauKpzXDE8R6a7wRpgVB4Zy0Xhm9MNoNnbSVtEcl5MtmnYGD971jEB57PemUsMZ5y73cf02E6VgcEzgyYgSrEhaLCgTL0xZK85Hvt7s1e3XtNztvdKu0HBngDEUCkWkTxgmZhjGms7CJIhqKKKVEpCmhnMNRDgbBapmBdzpL5BZFEu0oalKMpglttukpaBLHuEEnmHODXGnnufYZwAiTENEQ0xoljBJGyBsIyh%2F1D0pvld7ft6q5gBx7rOsyLJ%2BrGgSv0rqzxBMA5PNxQ3YebG9OcJ%2BWX30H%2BT9cnsObWfkl%2B%2FsD%2BTc%3D&RelayState=HEgCAOLrNmRmZTBkM2FlNDE2MDQyMDhjZTVmMTZlMTdiZTdiMTliNg%3D%3D");
let ua = String::from("PAN GlobalProtect");
let saml_options = SamlOptions::new(SamlBinding::Redirect, target, ua);
let saml_auth = self.saml_auth.clone();
tokio::spawn(async move {
saml_auth.process(saml_options).await;
});
// self.server.write(&data).await.expect("write failed");
}
Err(err) => {
println!("Server error: {:?}", err);
}
}
}
}
pub fn client(&self) -> Arc<Mutex<DuplexStreamHandle>> {
self.client.clone()
}
}

View File

@ -1,35 +0,0 @@
use tokio::io::{AsyncReadExt, AsyncWriteExt, DuplexStream};
#[derive(Debug)]
pub struct DuplexStreamHandle {
stream: DuplexStream,
buf_size: usize,
}
impl DuplexStreamHandle {
fn new(stream: DuplexStream, buf_size: usize) -> Self {
Self { stream, buf_size }
}
pub async fn write(&mut self, data: &str) -> Result<(), Box<dyn std::error::Error>> {
self.stream.write_all(data.as_bytes()).await?;
Ok(())
}
pub async fn read(&mut self) -> Result<String, Box<dyn std::error::Error>> {
let mut buffer = vec![0; self.buf_size];
match self.stream.read(&mut buffer).await {
Ok(0) => Err("EOF".into()),
Ok(n) => Ok(String::from_utf8_lossy(&buffer[..n]).to_string()),
Err(err) => Err(err.to_string().into()),
}
}
}
pub(crate) fn duplex(max_buf_size: usize) -> (DuplexStreamHandle, DuplexStreamHandle) {
let (a, b) = tokio::io::duplex(max_buf_size);
(
DuplexStreamHandle::new(a, max_buf_size),
DuplexStreamHandle::new(b, max_buf_size),
)
}

View File

@ -1,8 +0,0 @@
mod auth_service;
mod saml;
mod duplex;
pub use auth_service::AuthService;
pub use duplex::DuplexStreamHandle;
pub use saml::saml_login;
pub use saml::SamlBinding;

View File

@ -1,55 +0,0 @@
use gpauth::{AuthService, saml_login, SamlBinding};
#[tokio::main]
async fn main() {
let url = String::from("https://globalprotect.kochind.com/global-protect/prelogin.esp?tmp=tmp&kerberos-support=yes&ipv6-support=yes&clientVer=4100&clientos=Linux");
let _html = String::from(
r#"<html>
<body>
<form id="myform" method="POST" action="https://auth.kochid.com/idp/SSO.saml2">
<input type="hidden" name="SAMLRequest" value="PHNhbWxwOkF1dGhuUmVxdWVzdCB4bWxuczpzYW1scD0idXJuOm9hc2lzOm5hbWVzOnRjOlNBTUw6Mi4wOnByb3RvY29sIiBBc3NlcnRpb25Db25zdW1lclNlcnZpY2VVUkw9Imh0dHBzOi8vZ2xvYmFscHJvdGVjdC5rb2NoaW5kLmNvbTo0NDMvU0FNTDIwL1NQL0FDUyIgRGVzdGluYXRpb249Imh0dHBzOi8vYXV0aC5rb2NoaWQuY29tL2lkcC9TU08uc2FtbDIiIElEPSJfZmEzZTA4NDE5NjdkZTdlYzUyNzc4Nzc4YzBkOTViMDEiIElzc3VlSW5zdGFudD0iMjAyMy0wNS0yNFQwNToyNDo1OVoiIFByb3RvY29sQmluZGluZz0idXJuOm9hc2lzOm5hbWVzOnRjOlNBTUw6Mi4wOmJpbmRpbmdzOkhUVFAtUE9TVCIgVmVyc2lvbj0iMi4wIj48c2FtbDpJc3N1ZXIgeG1sbnM6c2FtbD0idXJuOm9hc2lzOm5hbWVzOnRjOlNBTUw6Mi4wOmFzc2VydGlvbiI+aHR0cHM6Ly9nbG9iYWxwcm90ZWN0LmtvY2hpbmQuY29tOjQ0My9TQU1MMjAvU1A8L3NhbWw6SXNzdWVyPjwvc2FtbHA6QXV0aG5SZXF1ZXN0Pg==" />
<input type="hidden" name="RelayState" value="rgbNAP1wSGI0NGE1ZDZjOGM4YTkzNjk5NWNhY2JlZjkwMWJmMzIwYg==" />
</form>
<script>
document.getElementById('myform').subm{
let (client, server) = duplex(1);
AuthService { client, server }
}>
</body>
</html>
"#,
);
let ua = String::from("PAN GlobalProtect");
match saml_login(SamlBinding::Redirect, url, ua) {
Ok(saml_result) => {
println!("SAML result: {:?}", saml_result);
}
Err(err) => {
println!("Error: {:?}", err);
}
}
// let mut auth_service = AuthService::default();
// let client = auth_service.client();
// tokio::spawn(async move {
// let mut client = client.lock().await;
// client.write("Hello").await.expect("write failed");
// loop {
// if let Ok(data) = client.read().await {
// println!("Received: {}", data);
// }
// }
// });
// tokio::select! {
// _ = auth_service.run() => {
// println!("AuthService exited");
// }
// _ = tokio::signal::ctrl_c() => {
// println!("Ctrl-C received, exiting");
// }
// }
}

View File

@ -1,233 +0,0 @@
use regex::Regex;
use std::{cell::RefCell, rc::Rc};
use webkit2gtk::gio::Cancellable;
use webkit2gtk::glib::GString;
use webkit2gtk::{LoadEvent, URIResponseExt, WebResourceExt, WebViewExt};
use wry::application::event::{Event, StartCause, WindowEvent};
use wry::application::event_loop::{ControlFlow, EventLoop, EventLoopProxy};
use wry::application::platform::run_return::EventLoopExtRunReturn;
use wry::application::window::WindowBuilder;
use wry::webview::{WebViewBuilder, WebviewExtUnix};
#[derive(Debug, Default)]
pub(crate) struct SamlAuth {}
pub(crate) struct SamlOptions {
binding: SamlBinding,
target: String,
user_agent: String,
}
impl SamlOptions {
pub fn new(binding: SamlBinding, target: String, user_agent: String) -> Self {
Self {
binding,
target,
user_agent,
}
}
}
impl SamlAuth {
pub async fn process(&self, options: SamlOptions) -> Result<(), Box<dyn std::error::Error>> {
let saml_result = saml_login(options.binding, options.target, options.user_agent);
println!("SAML result: {:?}", saml_result);
Ok(())
}
}
pub enum SamlBinding {
Redirect,
Post,
}
#[derive(Debug, Clone)]
pub struct SamlResult {
username: Option<String>,
prelogin_cookie: Option<String>,
portal_userauthcookie: Option<String>,
}
impl SamlResult {
fn new(
username: Option<String>,
prelogin_cookie: Option<String>,
portal_userauthcookie: Option<String>,
) -> Self {
Self {
username,
prelogin_cookie,
portal_userauthcookie,
}
}
fn check(&self) -> bool {
self.username.is_some()
&& (self.prelogin_cookie.is_some() || self.portal_userauthcookie.is_some())
}
}
#[derive(Debug, PartialEq)]
enum SamlResultError {
NotFound,
Invalid,
}
#[derive(Debug)]
enum UserEvent {
SamlSuccess(SamlResult),
SamlError(String),
}
pub fn saml_login(
binding: SamlBinding,
target: String,
user_agent: String,
) -> Result<SamlResult, Box<dyn std::error::Error>> {
let mut event_loop: EventLoop<UserEvent> = EventLoop::with_user_event();
let event_proxy = event_loop.create_proxy();
let window = WindowBuilder::new()
.with_title("GlobalProtect Login")
.build(&event_loop)?;
let wv_builder = WebViewBuilder::new(window)?.with_user_agent(&user_agent);
let wv_builder = if let SamlBinding::Redirect = binding {
wv_builder.with_url(&target)?
} else {
wv_builder.with_html(&target)?
};
let wv = wv_builder.build()?;
let wv = wv.webview();
wv.connect_load_changed(move |webview, event| {
if let LoadEvent::Finished = event {
if let Some(main_resource) = webview.main_resource() {
// Read the SAML result from the HTTP headers
if let Some(response) = main_resource.response() {
if let Some(saml_result) = read_saml_result_from_response(&response) {
println!("Got SAML result from HTTP headers");
return emit_event(&event_proxy, UserEvent::SamlSuccess(saml_result));
}
}
// Read the SAML result from the HTTP body
let event_proxy = event_proxy.clone();
main_resource.data(Cancellable::NONE, move |data| {
if let Ok(data) = data {
match read_saml_result_from_html(&data) {
Ok(saml_result) => {
println!("Got SAML result from HTTP body");
emit_event(&event_proxy, UserEvent::SamlSuccess(saml_result));
}
Err(err) if err == SamlResultError::Invalid => {
println!("Error reading SAML result from HTTP body: {:?}", err);
emit_event(
&event_proxy,
UserEvent::SamlError("Invalid SAML result".into()),
);
}
Err(_) => {
println!("SAML result not found in HTTP body");
}
}
}
});
}
}
});
let saml_result: Rc<RefCell<Option<SamlResult>>> = Rc::new(RefCell::new(None));
let saml_result_clone = saml_result.clone();
let exit_code = event_loop.run_return(move |event, _, control_flow| {
*control_flow = ControlFlow::Wait;
match event {
Event::NewEvents(StartCause::Init) => println!("Wry has started!"),
Event::WindowEvent {
event: WindowEvent::CloseRequested,
..
} => {
println!("User closed the window");
*control_flow = ControlFlow::Exit
}
Event::UserEvent(UserEvent::SamlSuccess(result)) => {
*saml_result_clone.borrow_mut() = Some(result);
*control_flow = ControlFlow::Exit;
}
Event::UserEvent(UserEvent::SamlError(_)) => {
println!("Error reading SAML result");
wv.load_uri("https://baidu.com");
}
_ => (),
}
});
println!("Exit code: {:?}", exit_code);
let saml_result = if let Some(saml_result) = saml_result.borrow().clone() {
println!("SAML result: {:?}", saml_result);
Ok(saml_result)
} else {
println!("SAML result: None");
// TODO: Return a proper error
Err("SAML result not found".into())
};
saml_result
}
fn read_saml_result_from_response(response: &webkit2gtk::URIResponse) -> Option<SamlResult> {
response.http_headers().and_then(|mut headers| {
let saml_result = SamlResult::new(
headers.get("saml-username").map(GString::into),
headers.get("prelogin-cookie").map(GString::into),
headers.get("portal-userauthcookie").map(GString::into),
);
if saml_result.check() {
Some(saml_result)
} else {
None
}
})
}
fn read_saml_result_from_html(data: &[u8]) -> Result<SamlResult, SamlResultError> {
let body = String::from_utf8_lossy(data);
let saml_auth_status = parse_saml_tag(&body, "saml-auth-status");
match saml_auth_status {
Some(status) if status == "1" => extract_saml_result(&body).ok_or(SamlResultError::Invalid),
Some(status) if status == "-1" => Err(SamlResultError::Invalid),
_ => Err(SamlResultError::NotFound),
}
}
fn extract_saml_result(body: &str) -> Option<SamlResult> {
let saml_result = SamlResult::new(
parse_saml_tag(body, "saml-username"),
parse_saml_tag(body, "prelogin-cookie"),
parse_saml_tag(body, "portal-userauthcookie"),
);
if saml_result.check() {
Some(saml_result)
} else {
None
}
}
fn parse_saml_tag(body: &str, tag: &str) -> Option<String> {
let re = Regex::new(&format!("<{}>(.*)</{}>", tag, tag)).unwrap();
re.captures(body)
.and_then(|captures| captures.get(1))
.map(|m| m.as_str().to_string())
}
fn emit_event(event_proxy: &EventLoopProxy<UserEvent>, event: UserEvent) {
if let Err(err) = event_proxy.send_event(event) {
println!("Error sending event: {:?}", err);
}
}

View File

@ -1,4 +1,4 @@
use crate::cmd::{Connect, Disconnect, Status}; use crate::cmd::{Connect, Disconnect, GetStatus};
use crate::reader::Reader; use crate::reader::Reader;
use crate::request::CommandPayload; use crate::request::CommandPayload;
use crate::response::ResponseData; use crate::response::ResponseData;
@ -7,7 +7,7 @@ use crate::RequestPool;
use crate::Response; use crate::Response;
use crate::SOCKET_PATH; use crate::SOCKET_PATH;
use crate::{Request, VpnStatus}; use crate::{Request, VpnStatus};
use log::{info, warn, debug}; use log::{debug, info, warn};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::fmt::Display; use std::fmt::Display;
use std::sync::Arc; use std::sync::Arc;
@ -17,17 +17,24 @@ use tokio::sync::{mpsc, Mutex, RwLock};
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
#[derive(Debug)] #[derive(Debug)]
enum ServerEvent { enum ServiceEvent {
Online,
Response(Response), Response(Response),
ServerDisconnected, Offline,
} }
impl From<Response> for ServerEvent { impl From<Response> for ServiceEvent {
fn from(response: Response) -> Self { fn from(response: Response) -> Self {
Self::Response(response) Self::Response(response)
} }
} }
#[derive(Debug)]
pub enum ClientStatus {
Vpn(VpnStatus),
Service(bool),
}
#[derive(Debug)] #[derive(Debug)]
pub struct Client { pub struct Client {
// pool of requests that are waiting for responses // pool of requests that are waiting for responses
@ -37,10 +44,10 @@ pub struct Client {
// rx for receiving requests from the channel // rx for receiving requests from the channel
request_rx: Arc<Mutex<mpsc::Receiver<Request>>>, request_rx: Arc<Mutex<mpsc::Receiver<Request>>>,
// tx for sending responses to the channel // tx for sending responses to the channel
server_event_tx: mpsc::Sender<ServerEvent>, service_event_tx: mpsc::Sender<ServiceEvent>,
// rx for receiving responses from the channel // rx for receiving responses from the channel
server_event_rx: Arc<Mutex<mpsc::Receiver<ServerEvent>>>, service_event_rx: Arc<Mutex<mpsc::Receiver<ServiceEvent>>>,
is_healthy: Arc<RwLock<bool>>, is_online: Arc<RwLock<bool>>,
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
@ -71,34 +78,41 @@ impl From<&str> for ServerApiError {
impl Default for Client { impl Default for Client {
fn default() -> Self { fn default() -> Self {
let (request_tx, request_rx) = mpsc::channel::<Request>(32); let (request_tx, request_rx) = mpsc::channel::<Request>(32);
let (server_event_tx, server_event_rx) = mpsc::channel::<ServerEvent>(32); let (service_event_tx, server_event_rx) = mpsc::channel::<ServiceEvent>(32);
Self { Self {
request_pool: Default::default(), request_pool: Default::default(),
request_tx, request_tx,
request_rx: Arc::new(Mutex::new(request_rx)), request_rx: Arc::new(Mutex::new(request_rx)),
server_event_tx, service_event_tx,
server_event_rx: Arc::new(Mutex::new(server_event_rx)), service_event_rx: Arc::new(Mutex::new(server_event_rx)),
is_healthy: Default::default(), is_online: Default::default(),
} }
} }
} }
impl Client { impl Client {
pub fn subscribe_status(&self, callback: impl Fn(VpnStatus) + Send + Sync + 'static) { pub async fn is_online(&self) -> bool {
let server_event_rx = self.server_event_rx.clone(); *self.is_online.read().await
}
pub fn subscribe_status(&self, callback: impl Fn(ClientStatus) + Send + Sync + 'static) {
let service_event_rx = self.service_event_rx.clone();
tokio::spawn(async move { tokio::spawn(async move {
loop { loop {
let mut server_event_rx = server_event_rx.lock().await; let mut server_event_rx = service_event_rx.lock().await;
if let Some(server_event) = server_event_rx.recv().await { if let Some(server_event) = server_event_rx.recv().await {
match server_event { match server_event {
ServerEvent::ServerDisconnected => { ServiceEvent::Online => {
callback(VpnStatus::Disconnected); callback(ClientStatus::Service(true));
} }
ServerEvent::Response(response) => { ServiceEvent::Offline => {
callback(ClientStatus::Service(false));
}
ServiceEvent::Response(response) => {
if let ResponseData::Status(vpn_status) = response.data() { if let ResponseData::Status(vpn_status) = response.data() {
callback(vpn_status); callback(ClientStatus::Vpn(vpn_status));
} }
} }
} }
@ -134,7 +148,7 @@ impl Client {
let read_handle = tokio::spawn(handle_read( let read_handle = tokio::spawn(handle_read(
read_stream, read_stream,
self.request_pool.clone(), self.request_pool.clone(),
self.server_event_tx.clone(), self.service_event_tx.clone(),
cancel_token.clone(), cancel_token.clone(),
)); ));
@ -144,13 +158,16 @@ impl Client {
cancel_token, cancel_token,
)); ));
*self.is_healthy.write().await = true; *self.is_online.write().await = true;
info!("Connected to the background service"); info!("Connected to the background service");
if let Err(err) = self.service_event_tx.send(ServiceEvent::Online).await {
warn!("Error sending online event to the channel: {}", err);
}
let _ = tokio::join!(read_handle, write_handle); let _ = tokio::join!(read_handle, write_handle);
*self.is_healthy.write().await = false; *self.is_online.write().await = false;
// TODO connection was lost, cleanup the request pool and notify the UI // TODO connection was lost, cleanup the request pool
Ok(()) Ok(())
} }
@ -159,7 +176,7 @@ impl Client {
&self, &self,
payload: CommandPayload, payload: CommandPayload,
) -> Result<T, ServerApiError> { ) -> Result<T, ServerApiError> {
if !*self.is_healthy.read().await { if !*self.is_online.read().await {
return Err("Background service is not running".into()); return Err("Background service is not running".into());
} }
@ -169,18 +186,19 @@ impl Client {
return Err(format!("Error sending request to the channel: {}", err).into()); return Err(format!("Error sending request to the channel: {}", err).into());
} }
if let Ok(response) = response_rx.await { response_rx
.await
.map_err(|_| "Error receiving response from the channel".into())
.and_then(|response| {
if response.success() { if response.success() {
match response.data().try_into() { response
Ok(it) => Ok(it), .data()
Err(_) => Err("Error parsing response data".into()), .try_into()
} .map_err(|_| "Error parsing response data".into())
} else { } else {
Err(response.message().into()) Err(response.message().into())
} }
} else { })
Err("Error receiving response from the channel".into())
}
} }
pub async fn connect(&self, server: String, cookie: String) -> Result<(), ServerApiError> { pub async fn connect(&self, server: String, cookie: String) -> Result<(), ServerApiError> {
@ -192,14 +210,14 @@ impl Client {
} }
pub async fn status(&self) -> Result<VpnStatus, ServerApiError> { pub async fn status(&self) -> Result<VpnStatus, ServerApiError> {
self.send_command(Status.into()).await self.send_command(GetStatus.into()).await
} }
} }
async fn handle_read( async fn handle_read(
read_stream: ReadHalf<UnixStream>, read_stream: ReadHalf<UnixStream>,
request_pool: Arc<RequestPool>, request_pool: Arc<RequestPool>,
server_event_tx: mpsc::Sender<ServerEvent>, service_event_tx: mpsc::Sender<ServiceEvent>,
cancel_token: CancellationToken, cancel_token: CancellationToken,
) { ) {
let mut reader: Reader = read_stream.into(); let mut reader: Reader = read_stream.into();
@ -211,7 +229,7 @@ async fn handle_read(
match response.request_id() { match response.request_id() {
Some(id) => request_pool.complete_request(id, response).await, Some(id) => request_pool.complete_request(id, response).await,
None => { None => {
if let Err(err) = server_event_tx.send(response.into()).await { if let Err(err) = service_event_tx.send(response.into()).await {
warn!("Error sending response to output channel: {}", err); warn!("Error sending response to output channel: {}", err);
} }
} }
@ -220,7 +238,7 @@ async fn handle_read(
} }
Err(err) if err.kind() == io::ErrorKind::ConnectionAborted => { Err(err) if err.kind() == io::ErrorKind::ConnectionAborted => {
warn!("Disconnected from the background service"); warn!("Disconnected from the background service");
if let Err(err) = server_event_tx.send(ServerEvent::ServerDisconnected).await { if let Err(err) = service_event_tx.send(ServiceEvent::Offline).await {
warn!( warn!(
"Error sending server disconnected event to channel: {}", "Error sending server disconnected event to channel: {}",
err err

View File

@ -12,7 +12,7 @@ mod status;
pub use connect::Connect; pub use connect::Connect;
pub use disconnect::Disconnect; pub use disconnect::Disconnect;
pub use status::Status; pub use status::GetStatus;
#[derive(Debug)] #[derive(Debug)]
pub(crate) struct CommandContext { pub(crate) struct CommandContext {

View File

@ -4,10 +4,10 @@ use async_trait::async_trait;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
#[derive(Debug, Serialize, Deserialize, Clone)] #[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Status; pub struct GetStatus;
#[async_trait] #[async_trait]
impl Command for Status { impl Command for GetStatus {
async fn handle(&self, context: CommandContext) -> Result<ResponseData, CommandError> { async fn handle(&self, context: CommandContext) -> Result<ResponseData, CommandError> {
let status = context.server_context.vpn().status().await; let status = context.server_context.vpn().status().await;

View File

@ -23,18 +23,18 @@ async fn handle_read(
let mut authenticated: Option<bool> = None; let mut authenticated: Option<bool> = None;
loop { loop {
match reader.read::<Request>().await { match reader.read_multiple::<Request>().await {
Ok(request) => { Ok(requests) => {
if authenticated.is_none() { if authenticated.is_none() {
authenticated = Some(authenticate(peer_pid)); authenticated = Some(authenticate(peer_pid));
} }
if !authenticated.unwrap_or(false) { if !authenticated.unwrap_or(false) {
warn!("Client not authenticated, closing connection"); warn!("Client not authenticated, closing connection");
cancel_token.cancel(); cancel_token.cancel();
break; break;
} }
for request in requests {
debug!("Received client request: {:?}", request); debug!("Received client request: {:?}", request);
let command = request.command(); let command = request.command();
@ -48,6 +48,7 @@ async fn handle_read(
let _ = response_tx.send(response).await; let _ = response_tx.send(response).await;
} }
}
Err(err) if err.kind() == io::ErrorKind::ConnectionAborted => { Err(err) if err.kind() == io::ErrorKind::ConnectionAborted => {
info!("Client disconnected"); info!("Client disconnected");

View File

@ -30,6 +30,7 @@ pub(crate) use writer::Writer;
pub use client::Client; pub use client::Client;
pub use client::ServerApiError; pub use client::ServerApiError;
pub use client::ClientStatus;
pub use vpn::VpnStatus; pub use vpn::VpnStatus;
pub fn sha256_digest<P: AsRef<Path>>(file_path: P) -> Result<String, std::io::Error> { pub fn sha256_digest<P: AsRef<Path>>(file_path: P) -> Result<String, std::io::Error> {

View File

@ -13,22 +13,6 @@ impl From<ReadHalf<UnixStream>> for Reader {
} }
impl Reader { impl Reader {
pub async fn read<T: for<'a> Deserialize<'a>>(&mut self) -> Result<T, io::Error> {
let mut buffer = [0; 2048];
match self.stream.read(&mut buffer).await {
Ok(0) => Err(io::Error::new(
io::ErrorKind::ConnectionAborted,
"Peer disconnected",
)),
Ok(bytes_read) => {
let data = serde_json::from_slice::<T>(&buffer[..bytes_read])?;
Ok(data)
}
Err(err) => Err(err),
}
}
pub async fn read_multiple<T: for<'a> Deserialize<'a>>(&mut self) -> Result<Vec<T>, io::Error> { pub async fn read_multiple<T: for<'a> Deserialize<'a>>(&mut self) -> Result<Vec<T>, io::Error> {
let mut buffer = [0; 2048]; let mut buffer = [0; 2048];

View File

@ -1,4 +1,4 @@
use crate::cmd::{Command, Connect, Disconnect, Status}; use crate::cmd::{Command, Connect, Disconnect, GetStatus};
use crate::Response; use crate::Response;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::sync::Arc; use std::sync::Arc;
@ -21,7 +21,7 @@ impl Request {
pub fn command(&self) -> Box<dyn Command> { pub fn command(&self) -> Box<dyn Command> {
match &self.payload { match &self.payload {
CommandPayload::Status(status) => Box::new(status.clone()), CommandPayload::GetStatus(status) => Box::new(status.clone()),
CommandPayload::Connect(connect) => Box::new(connect.clone()), CommandPayload::Connect(connect) => Box::new(connect.clone()),
CommandPayload::Disconnect(disconnect) => Box::new(disconnect.clone()), CommandPayload::Disconnect(disconnect) => Box::new(disconnect.clone()),
} }
@ -30,14 +30,14 @@ impl Request {
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
pub(crate) enum CommandPayload { pub(crate) enum CommandPayload {
Status(Status), GetStatus(GetStatus),
Connect(Connect), Connect(Connect),
Disconnect(Disconnect), Disconnect(Disconnect),
} }
impl From<Status> for CommandPayload { impl From<GetStatus> for CommandPayload {
fn from(status: Status) -> Self { fn from(status: GetStatus) -> Self {
Self::Status(status) Self::GetStatus(status)
} }
} }

View File

@ -1,4 +1,4 @@
use log::{warn, info, debug}; use log::{debug, info, warn};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::ffi::{c_void, CString}; use std::ffi::{c_void, CString};
use std::sync::Arc; use std::sync::Arc;
@ -91,7 +91,7 @@ impl Vpn {
*self.vpn_options.lock().await = Some(VpnOptions { *self.vpn_options.lock().await = Some(VpnOptions {
server: VpnOptions::to_cstr(server), server: VpnOptions::to_cstr(server),
cookie: VpnOptions::to_cstr(cookie), cookie: VpnOptions::to_cstr(cookie),
script: VpnOptions::to_cstr("/usr/share/vpnc-scripts/vpnc-script") script: VpnOptions::to_cstr("/usr/share/vpnc-scripts/vpnc-script"),
}); });
let vpn_options = self.vpn_options.clone(); let vpn_options = self.vpn_options.clone();
@ -133,12 +133,15 @@ impl Vpn {
} }
info!("Disconnecting VPN..."); info!("Disconnecting VPN...");
self.status_holder
.lock()
.await
.set(VpnStatus::Disconnecting);
unsafe { ffi::disconnect() }; unsafe { ffi::disconnect() };
let mut status_rx = self.status_rx().await; let mut status_rx = self.status_rx().await;
debug!("Waiting for the VPN to disconnect..."); debug!("Waiting for the VPN to disconnect...");
while status_rx.changed().await.is_ok() { while status_rx.changed().await.is_ok() {
if *status_rx.borrow() == VpnStatus::Disconnected { if *status_rx.borrow() == VpnStatus::Disconnected {
info!("VPN disconnected"); info!("VPN disconnected");

View File

@ -16,7 +16,7 @@ tauri-build = { version = "1.3", features = [] }
[dependencies] [dependencies]
gpcommon = { path = "../../gpcommon" } gpcommon = { path = "../../gpcommon" }
tauri = { version = "1.3", features = ["http-all", "window-all", "window-data-url"] } tauri = { version = "1.3", features = ["http-all", "process-exit", "shell-open", "window-all", "window-data-url"] }
tauri-plugin-log = { git = "https://github.com/tauri-apps/plugins-workspace", branch = "v1", features = [ tauri-plugin-log = { git = "https://github.com/tauri-apps/plugins-workspace", branch = "v1", features = [
"colored", "colored",
] } ] }

View File

@ -4,9 +4,9 @@ use regex::Regex;
use serde::de::Error; use serde::de::Error;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::{sync::Arc, time::Duration}; use std::{sync::Arc, time::Duration};
use tauri::EventHandler; use tauri::{AppHandle, Manager, Window, WindowUrl};
use tauri::{AppHandle, Manager, Window, WindowEvent::CloseRequested, WindowUrl}; use tauri::{EventHandler, WindowEvent};
use tokio::sync::{mpsc, Mutex}; use tokio::sync::{mpsc, oneshot, Mutex};
use tokio::time::timeout; use tokio::time::timeout;
use veil::Redact; use veil::Redact;
use webkit2gtk::gio::Cancellable; use webkit2gtk::gio::Cancellable;
@ -100,7 +100,6 @@ enum AuthEvent {
Request(AuthRequest), Request(AuthRequest),
Success(AuthData), Success(AuthData),
Error(AuthError), Error(AuthError),
Cancel,
} }
pub(crate) struct SamlLoginParams { pub(crate) struct SamlLoginParams {
@ -113,10 +112,10 @@ pub(crate) struct SamlLoginParams {
pub(crate) async fn saml_login(params: SamlLoginParams) -> tauri::Result<Option<AuthData>> { pub(crate) async fn saml_login(params: SamlLoginParams) -> tauri::Result<Option<AuthData>> {
info!("Starting SAML login"); info!("Starting SAML login");
let (event_tx, event_rx) = mpsc::channel::<AuthEvent>(8); let (auth_event_tx, auth_event_rx) = mpsc::channel::<AuthEvent>(1);
let window = build_window(&params.app_handle, &params.user_agent)?; let window = build_window(&params.app_handle, &params.user_agent)?;
setup_webview(&window, event_tx.clone())?; setup_webview(&window, auth_event_tx.clone())?;
let handler = setup_window(&window, event_tx); let handler = setup_window(&window, auth_event_tx);
if params.clear_cookies { if params.clear_cookies {
if let Err(err) = clear_webview_cookies(&window).await { if let Err(err) = clear_webview_cookies(&window).await {
@ -124,7 +123,7 @@ pub(crate) async fn saml_login(params: SamlLoginParams) -> tauri::Result<Option<
} }
} }
let result = process(&window, params.auth_request, event_rx).await; let result = process(&window, params.auth_request, auth_event_rx).await;
window.unlisten(handler); window.unlisten(handler);
result result
} }
@ -134,6 +133,8 @@ fn build_window(app_handle: &AppHandle, ua: &str) -> tauri::Result<Window> {
Window::builder(app_handle, AUTH_WINDOW_LABEL, url) Window::builder(app_handle, AUTH_WINDOW_LABEL, url)
.visible(false) .visible(false)
.title("GlobalProtect Login") .title("GlobalProtect Login")
.inner_size(390.0, 694.0)
.min_inner_size(390.0, 600.0)
.user_agent(ua) .user_agent(ua)
.always_on_top(true) .always_on_top(true)
.focused(true) .focused(true)
@ -142,10 +143,10 @@ fn build_window(app_handle: &AppHandle, ua: &str) -> tauri::Result<Window> {
} }
// Setup webview events // Setup webview events
fn setup_webview(window: &Window, event_tx: mpsc::Sender<AuthEvent>) -> tauri::Result<()> { fn setup_webview(window: &Window, auth_event_tx: mpsc::Sender<AuthEvent>) -> tauri::Result<()> {
window.with_webview(move |wv| { window.with_webview(move |wv| {
let wv = wv.inner(); let wv = wv.inner();
let event_tx_clone = event_tx.clone(); let auth_event_tx_clone = auth_event_tx.clone();
wv.connect_load_changed(move |wv, event| { wv.connect_load_changed(move |wv, event| {
if LoadEvent::Finished != event { if LoadEvent::Finished != event {
@ -156,13 +157,13 @@ fn setup_webview(window: &Window, event_tx: mpsc::Sender<AuthEvent>) -> tauri::R
// Empty URI indicates that an error occurred // Empty URI indicates that an error occurred
if uri.is_empty() { if uri.is_empty() {
warn!("Empty URI loaded, retrying"); warn!("Empty URI loaded, retrying");
send_auth_error(event_tx_clone.clone(), AuthError::TokenInvalid); send_auth_error(auth_event_tx_clone.clone(), AuthError::TokenInvalid);
return; return;
} }
info!("Loaded URI: {}", redact_url(&uri)); info!("Loaded URI: {}", redact_url(&uri));
if let Some(main_res) = wv.main_resource() { if let Some(main_res) = wv.main_resource() {
parse_auth_data(&main_res, event_tx_clone.clone()); parse_auth_data(&main_res, auth_event_tx_clone.clone());
} else { } else {
warn!("No main_resource"); warn!("No main_resource");
} }
@ -170,20 +171,13 @@ fn setup_webview(window: &Window, event_tx: mpsc::Sender<AuthEvent>) -> tauri::R
wv.connect_load_failed(move |_wv, event, _uri, err| { wv.connect_load_failed(move |_wv, event, _uri, err| {
warn!("Load failed: {:?}, {:?}", event, err); warn!("Load failed: {:?}, {:?}", event, err);
send_auth_error(event_tx.clone(), AuthError::TokenInvalid); send_auth_error(auth_event_tx.clone(), AuthError::TokenInvalid);
false false
}); });
}) })
} }
fn setup_window(window: &Window, event_tx: mpsc::Sender<AuthEvent>) -> EventHandler { fn setup_window(window: &Window, event_tx: mpsc::Sender<AuthEvent>) -> EventHandler {
let event_tx_clone = event_tx.clone();
window.on_window_event(move |event| {
if let CloseRequested { .. } = event {
send_auth_event(event_tx_clone.clone(), AuthEvent::Cancel);
}
});
window.listen_global(AUTH_REQUEST_EVENT, move |event| { window.listen_global(AUTH_REQUEST_EVENT, move |event| {
if let Ok(payload) = TryInto::<AuthRequest>::try_into(event.payload()) { if let Ok(payload) = TryInto::<AuthRequest>::try_into(event.payload()) {
let event_tx = event_tx.clone(); let event_tx = event_tx.clone();
@ -204,7 +198,7 @@ async fn process(
process_request(window, auth_request)?; process_request(window, auth_request)?;
let handle = tokio::spawn(show_window_after_timeout(window.clone())); let handle = tokio::spawn(show_window_after_timeout(window.clone()));
let auth_data = process_auth_event(&window, event_rx).await; let auth_data = monitor_events(&window, event_rx).await;
if !handle.is_finished() { if !handle.is_finished() {
handle.abort(); handle.abort();
@ -239,20 +233,32 @@ async fn show_window_after_timeout(window: Window) {
show_window(&window); show_window(&window);
} }
async fn process_auth_event( async fn monitor_events(window: &Window, event_rx: mpsc::Receiver<AuthEvent>) -> Option<AuthData> {
window: &Window, tokio::select! {
mut event_rx: mpsc::Receiver<AuthEvent>, auth_data = monitor_auth_event(window, event_rx) => Some(auth_data),
) -> Option<AuthData> { _ = monitor_window_close_event(window) => {
info!("Processing auth event..."); warn!("Auth window closed without auth data");
None
}
}
}
async fn monitor_auth_event(window: &Window, mut event_rx: mpsc::Receiver<AuthEvent>) -> AuthData {
info!("Monitoring auth events");
let (cancel_timeout_tx, cancel_timeout_rx) = mpsc::channel::<()>(1); let (cancel_timeout_tx, cancel_timeout_rx) = mpsc::channel::<()>(1);
let cancel_timeout_rx = Arc::new(Mutex::new(cancel_timeout_rx)); let cancel_timeout_rx = Arc::new(Mutex::new(cancel_timeout_rx));
let mut attempt_times = 1;
loop { loop {
if let Some(auth_event) = event_rx.recv().await { if let Some(auth_event) = event_rx.recv().await {
match auth_event { match auth_event {
AuthEvent::Request(auth_request) => { AuthEvent::Request(auth_request) => {
info!("Got auth request from auth-request event, processing"); attempt_times = attempt_times + 1;
info!(
"Got auth request from auth-request event, attempt #{}",
attempt_times
);
if let Err(err) = process_request(&window, auth_request) { if let Err(err) = process_request(&window, auth_request) {
warn!("Error processing auth request: {}", err); warn!("Error processing auth request: {}", err);
} }
@ -260,20 +266,26 @@ async fn process_auth_event(
AuthEvent::Success(auth_data) => { AuthEvent::Success(auth_data) => {
info!("Got auth data successfully, closing window"); info!("Got auth data successfully, closing window");
close_window(window); close_window(window);
return Some(auth_data); return auth_data;
}
AuthEvent::Cancel => {
info!("User cancelled the authentication process, closing window");
return None;
} }
AuthEvent::Error(AuthError::TokenInvalid) => { AuthEvent::Error(AuthError::TokenInvalid) => {
// Found the invalid token, means that user is authenticated, keep retrying and no need to show the window // Found the invalid token, means that user is authenticated, keep retrying and no need to show the window
warn!("Found invalid auth data, retrying"); warn!(
if let Err(err) = cancel_timeout_tx.send(()).await { "Attempt #{} failed, found invalid token, retrying",
attempt_times
);
// If the cancel timeout is locked, it means that the window is about to show, so we need to cancel it
if cancel_timeout_rx.try_lock().is_err() {
if let Err(err) = cancel_timeout_tx.try_send(()) {
warn!("Error sending cancel timeout: {}", err); warn!("Error sending cancel timeout: {}", err);
} }
} else {
info!("Window is not about to show, skipping cancel timeout");
}
// Send the error event to the outside, so that we can retry it when receiving the auth-request event // Send the error event to the outside, so that we can retry it when receiving the auth-request event
if let Err(err) = window.emit_all(AUTH_ERROR_EVENT, ()) { if let Err(err) = window.emit_all(AUTH_ERROR_EVENT, attempt_times) {
warn!("Error emitting auth-error event: {:?}", err); warn!("Error emitting auth-error event: {:?}", err);
} }
} }
@ -296,6 +308,26 @@ async fn process_auth_event(
} }
} }
async fn monitor_window_close_event(window: &Window) {
let (close_tx, close_rx) = oneshot::channel();
let close_tx = Arc::new(Mutex::new(Some(close_tx)));
window.on_window_event(move |event| {
if matches!(event, WindowEvent::CloseRequested { .. }) {
if let Ok(mut close_tx_locked) = close_tx.try_lock() {
if let Some(close_tx) = close_tx_locked.take() {
if let Err(_) = close_tx.send(()) {
println!("Error sending close event");
}
}
}
}
});
if let Err(err) = close_rx.await {
warn!("Error receiving close event: {}", err);
}
}
/// Tokens not found means that the page might need the user interaction to login, /// Tokens not found means that the page might need the user interaction to login,
/// we should show the window after a short timeout, it will be cancelled if the /// we should show the window after a short timeout, it will be cancelled if the
/// token is found in the response, no matter it's valid or not. /// token is found in the response, no matter it's valid or not.
@ -309,36 +341,36 @@ async fn handle_token_not_found(window: Window, cancel_timeout_rx: Arc<Mutex<mps
); );
show_window(&window); show_window(&window);
} else { } else {
info!("Showing window timeout cancelled"); info!("The scheduled show window task is cancelled");
} }
} else { } else {
debug!("Window will be shown by another task, skipping"); info!("The show window task has been already been scheduled, skipping");
} }
} }
/// Parse the authentication data from the response headers or HTML content /// Parse the authentication data from the response headers or HTML content
/// and send it to the event channel /// and send it to the event channel
fn parse_auth_data(main_res: &WebResource, event_tx: mpsc::Sender<AuthEvent>) { fn parse_auth_data(main_res: &WebResource, auth_event_tx: mpsc::Sender<AuthEvent>) {
if let Some(response) = main_res.response() { if let Some(response) = main_res.response() {
if let Some(auth_data) = read_auth_data_from_response(&response) { if let Some(auth_data) = read_auth_data_from_response(&response) {
debug!("Got auth data from HTTP headers: {:?}", auth_data); debug!("Got auth data from HTTP headers: {:?}", auth_data);
send_auth_data(event_tx, auth_data); send_auth_data(auth_event_tx, auth_data);
return; return;
} }
} }
let event_tx = event_tx.clone(); let auth_event_tx = auth_event_tx.clone();
main_res.data(Cancellable::NONE, move |data| { main_res.data(Cancellable::NONE, move |data| {
if let Ok(data) = data { if let Ok(data) = data {
let html = String::from_utf8_lossy(&data); let html = String::from_utf8_lossy(&data);
match read_auth_data_from_html(&html) { match read_auth_data_from_html(&html) {
Ok(auth_data) => { Ok(auth_data) => {
debug!("Got auth data from HTML: {:?}", auth_data); debug!("Got auth data from HTML: {:?}", auth_data);
send_auth_data(event_tx, auth_data); send_auth_data(auth_event_tx, auth_data);
} }
Err(err) => { Err(err) => {
debug!("Error reading auth data from HTML: {:?}", err); debug!("Error reading auth data from HTML: {:?}", err);
send_auth_error(event_tx, err); send_auth_error(auth_event_tx, err);
} }
} }
} }
@ -400,17 +432,17 @@ fn parse_xml_tag(html: &str, tag: &str) -> Option<String> {
.map(|m| m.as_str().to_string()) .map(|m| m.as_str().to_string())
} }
fn send_auth_data(event_tx: mpsc::Sender<AuthEvent>, auth_data: AuthData) { fn send_auth_data(auth_event_tx: mpsc::Sender<AuthEvent>, auth_data: AuthData) {
send_auth_event(event_tx, AuthEvent::Success(auth_data)); send_auth_event(auth_event_tx, AuthEvent::Success(auth_data));
} }
fn send_auth_error(event_tx: mpsc::Sender<AuthEvent>, err: AuthError) { fn send_auth_error(auth_event_tx: mpsc::Sender<AuthEvent>, err: AuthError) {
send_auth_event(event_tx, AuthEvent::Error(err)); send_auth_event(auth_event_tx, AuthEvent::Error(err));
} }
fn send_auth_event(event_tx: mpsc::Sender<AuthEvent>, auth_event: AuthEvent) { fn send_auth_event(auth_event_tx: mpsc::Sender<AuthEvent>, auth_event: AuthEvent) {
let _ = tauri::async_runtime::spawn(async move { let _ = tauri::async_runtime::spawn(async move {
if let Err(err) = event_tx.send(auth_event).await { if let Err(err) = auth_event_tx.send(auth_event).await {
warn!("Error sending event: {}", err); warn!("Error sending event: {}", err);
} }
}); });

View File

@ -3,6 +3,11 @@ use gpcommon::{Client, ServerApiError, VpnStatus};
use std::sync::Arc; use std::sync::Arc;
use tauri::{AppHandle, State}; use tauri::{AppHandle, State};
#[tauri::command]
pub(crate) async fn service_online<'a>(client: State<'a, Arc<Client>>) -> Result<bool, ()> {
Ok(client.is_online().await)
}
#[tauri::command] #[tauri::command]
pub(crate) async fn vpn_status<'a>( pub(crate) async fn vpn_status<'a>(
client: State<'a, Arc<Client>>, client: State<'a, Arc<Client>>,
@ -30,10 +35,10 @@ pub(crate) async fn vpn_disconnect<'a>(
pub(crate) async fn saml_login( pub(crate) async fn saml_login(
binding: SamlBinding, binding: SamlBinding,
request: String, request: String,
clear_cookies: bool,
app_handle: AppHandle, app_handle: AppHandle,
) -> tauri::Result<Option<AuthData>> { ) -> tauri::Result<Option<AuthData>> {
let user_agent = String::from("PAN GlobalProtect"); let user_agent = String::from("PAN GlobalProtect");
let clear_cookies = false;
let params = SamlLoginParams { let params = SamlLoginParams {
auth_request: AuthRequest::new(binding, request), auth_request: AuthRequest::new(binding, request),
user_agent, user_agent,

View File

@ -4,7 +4,7 @@
)] )]
use env_logger::Env; use env_logger::Env;
use gpcommon::{Client, VpnStatus}; use gpcommon::{Client, ClientStatus, VpnStatus};
use log::warn; use log::warn;
use serde::Serialize; use serde::Serialize;
use std::sync::Arc; use std::sync::Arc;
@ -16,7 +16,7 @@ mod commands;
mod utils; mod utils;
#[derive(Debug, Clone, Serialize)] #[derive(Debug, Clone, Serialize)]
struct StatusPayload { struct VpnStatusPayload {
status: VpnStatus, status: VpnStatus,
} }
@ -26,11 +26,18 @@ fn setup(app: &mut tauri::App) -> Result<(), Box<dyn std::error::Error>> {
let app_handle = app.handle(); let app_handle = app.handle();
tauri::async_runtime::spawn(async move { tauri::async_runtime::spawn(async move {
let _ = client_clone.subscribe_status(move |status| { let _ = client_clone.subscribe_status(move |client_status| match client_status {
let payload = StatusPayload { status }; ClientStatus::Vpn(vpn_status) => {
let payload = VpnStatusPayload { status: vpn_status };
if let Err(err) = app_handle.emit_all("vpn-status-received", payload) { if let Err(err) = app_handle.emit_all("vpn-status-received", payload) {
warn!("Error emitting event: {}", err); warn!("Error emitting event: {}", err);
} }
}
ClientStatus::Service(is_online) => {
if let Err(err) = app_handle.emit_all("service-status-changed", is_online) {
warn!("Error emitting event: {}", err);
}
}
}); });
let _ = client_clone.run().await; let _ = client_clone.run().await;
@ -56,6 +63,7 @@ fn main() {
) )
.setup(setup) .setup(setup)
.invoke_handler(tauri::generate_handler![ .invoke_handler(tauri::generate_handler![
commands::service_online,
commands::vpn_status, commands::vpn_status,
commands::vpn_connect, commands::vpn_connect,
commands::vpn_disconnect, commands::vpn_disconnect,

View File

@ -12,6 +12,9 @@
}, },
"tauri": { "tauri": {
"allowlist": { "allowlist": {
"shell": {
"open": true
},
"http": { "http": {
"all": true, "all": true,
"request": true, "request": true,
@ -19,6 +22,9 @@
}, },
"window": { "window": {
"all": true "all": true
},
"process": {
"exit": true
} }
}, },
"bundle": { "bundle": {

View File

@ -1,10 +1,8 @@
html { html, body {
height: 100%; height: 100%;
margin: 0;
padding: 0;
-webkit-user-select: none; -webkit-user-select: none;
user-select: none; user-select: none;
cursor: default; cursor: default;
} }
body {
height: 100%;
}

View File

@ -1,15 +1,48 @@
import { Box } from "@mui/material"; import { Box } from "@mui/material";
import { useAtomValue } from "jotai";
import "./App.css";
import { statusReadyAtom } from "./atoms/status";
import ConnectForm from "./components/ConnectForm"; import ConnectForm from "./components/ConnectForm";
import ConnectionStatus from "./components/ConnectionStatus"; import ConnectionStatus from "./components/ConnectionStatus";
import Feedback from "./components/Feedback"; import Feedback from "./components/Feedback";
import GatewaySwitcher from "./components/GatewaySwitcher";
import MainMenu from "./components/MainMenu";
import Notification from "./components/Notification"; import Notification from "./components/Notification";
export default function App() { function Loading() {
return ( return (
<Box padding={2} paddingTop={3}> <Box
sx={{
position: "absolute",
inset: 0,
display: "flex",
alignItems: "center",
justifyContent: "center",
}}
>
Loading...
</Box>
);
}
function MainContent() {
return (
<>
<MainMenu />
<ConnectionStatus /> <ConnectionStatus />
<ConnectForm /> <ConnectForm />
<GatewaySwitcher />
<Feedback /> <Feedback />
</>
);
}
export default function App() {
const ready = useAtomValue(statusReadyAtom);
return (
<Box padding={2} paddingBottom={0}>
{ready ? <MainContent /> : <Loading />}
<Notification /> <Notification />
</Box> </Box>
); );

View File

@ -22,7 +22,8 @@ export const gatewayLoginAtom = atom(
throw new Error("Failed to login to gateway"); throw new Error("Failed to login to gateway");
} }
if (!get(isProcessingAtom)) { const isProcessing = get(isProcessingAtom);
if (!isProcessing) {
console.info("Request cancelled"); console.info("Request cancelled");
return; return;
} }
@ -44,13 +45,21 @@ const connectVpnAtom = atom(
} }
); );
const sleep = (ms: number) => new Promise((resolve) => setTimeout(resolve, ms));
export const disconnectVpnAtom = atom(null, async (get, set) => { export const disconnectVpnAtom = atom(null, async (get, set) => {
try { try {
set(statusAtom, "disconnecting"); set(statusAtom, "disconnecting");
await vpnService.disconnect(); await vpnService.disconnect();
set(statusAtom, "disconnected"); // Sleep a short time, so that the client can receive the service's disconnected event.
await sleep(100);
} catch (err) { } catch (err) {
set(statusAtom, "disconnected"); set(statusAtom, "disconnected");
set(notifyErrorAtom, "Failed to disconnect from VPN"); set(notifyErrorAtom, "Failed to disconnect from VPN");
} }
}); });
export const gatewaySwitcherVisibleAtom = atom(false);
export const openGatewaySwitcherAtom = atom(null, (get, set) => {
set(gatewaySwitcherVisibleAtom, true);
});

20
gpgui/src/atoms/menu.ts Normal file
View File

@ -0,0 +1,20 @@
import { exit } from "@tauri-apps/api/process";
import { atom } from "jotai";
import { RESET } from "jotai/utils";
import { disconnectVpnAtom } from "./gateway";
import { appDataStorageAtom, portalAddressAtom } from "./portal";
import { statusAtom } from "./status";
export const resetAtom = atom(null, (_get, set) => {
set(appDataStorageAtom, RESET);
set(portalAddressAtom, "");
});
export const quitAtom = atom(null, async (get, set) => {
const status = get(statusAtom);
if (status === "connected") {
await set(disconnectVpnAtom);
}
await exit();
});

View File

@ -3,11 +3,19 @@ import { atom } from "jotai";
export type Severity = AlertColor; export type Severity = AlertColor;
type NotificationConfig = {
title: string;
message: string;
severity: Severity;
duration?: number;
};
const notificationVisibleAtom = atom(false); const notificationVisibleAtom = atom(false);
export const notificationConfigAtom = atom({ export const notificationConfigAtom = atom<NotificationConfig>({
title: "", title: "",
message: "", message: "",
severity: "info" as Severity, severity: "info" as Severity,
duration: 5000,
}); });
export const closeNotificationAtom = atom( export const closeNotificationAtom = atom(
@ -17,7 +25,9 @@ export const closeNotificationAtom = atom(
} }
); );
export const notifyErrorAtom = atom(null, (_get, set, err: unknown) => { export const notifyErrorAtom = atom(
null,
(_get, set, err: unknown, duration: number = 5000) => {
let msg: string; let msg: string;
if (err instanceof Error) { if (err instanceof Error) {
msg = err.message; msg = err.message;
@ -32,5 +42,20 @@ export const notifyErrorAtom = atom(null, (_get, set, err: unknown) => {
title: "Error", title: "Error",
message: msg, message: msg,
severity: "error", severity: "error",
duration: duration <= 0 ? undefined : duration,
}); });
}); }
);
export const notifySuccessAtom = atom(
null,
(_get, set, msg: string, duration: number = 5000) => {
set(notificationVisibleAtom, true);
set(notificationConfigAtom, {
title: "Success",
message: msg,
severity: "success",
duration: duration <= 0 ? undefined : duration,
});
}
);

View File

@ -1,42 +1,119 @@
import { atom } from "jotai"; import { atom } from "jotai";
import { focusAtom } from "jotai-optics"; import { withImmer } from "jotai-immer";
import { atomWithDefault, atomWithStorage } from "jotai/utils";
import authService, { AuthData } from "../services/authService"; import authService, { AuthData } from "../services/authService";
import portalService, { import portalService, {
PasswordPrelogin, PasswordPrelogin,
PortalCredential,
Prelogin, Prelogin,
SamlPrelogin, SamlPrelogin,
} from "../services/portalService"; } from "../services/portalService";
import { gatewayLoginAtom } from "./gateway"; import { disconnectVpnAtom, gatewayLoginAtom } from "./gateway";
import { notifyErrorAtom } from "./notification"; import { notifyErrorAtom } from "./notification";
import { isProcessingAtom, statusAtom } from "./status"; import { isProcessingAtom, statusAtom } from "./status";
type GatewayData = { export type GatewayData = {
name: string; name: string;
address: string; address: string;
}; };
type Credential = { type CachedPortalCredential = Omit<PortalCredential, "prelogin-cookie">;
user: string;
passwd: string; type PortalData = {
userAuthCookie: string; address: string;
prelogonUserAuthCookie: string; gateways: GatewayData[];
cachedCredential?: CachedPortalCredential;
selectedGateway?: string;
}; };
type AppData = { type AppData = {
portal: string; portal: string;
gateways: GatewayData[]; portals: PortalData[];
selectedGateway: string; clearCookies: boolean;
credentials: Record<string, Credential>;
}; };
const appAtom = atom<AppData>({ type AppDataUpdate =
| {
type: "PORTAL";
payload: PortalData;
}
| {
type: "SELECTED_GATEWAY";
payload: string;
};
const defaultAppData: AppData = {
portal: "", portal: "",
gateways: [], portals: [],
selectedGateway: "", // Whether to clear the cookies of the SAML login webview, default is true
credentials: {}, clearCookies: true,
};
export const appDataStorageAtom = atomWithStorage<AppData>(
"APP_DATA",
defaultAppData
);
const appDataImmerAtom = withImmer(appDataStorageAtom);
const updateAppDataAtom = atom(null, (_get, set, update: AppDataUpdate) => {
const { type, payload } = update;
switch (type) {
case "PORTAL":
const { address } = payload;
set(appDataImmerAtom, (draft) => {
draft.portal = address;
const portalIndex = draft.portals.findIndex(
({ address: portalAddress }) => portalAddress === address
);
if (portalIndex === -1) {
draft.portals.push(payload);
} else {
draft.portals[portalIndex] = payload;
}
});
break;
case "SELECTED_GATEWAY":
set(appDataImmerAtom, (draft) => {
const { portal, portals } = draft;
const portalData = portals.find(({ address }) => address === portal);
if (portalData) {
portalData.selectedGateway = payload;
}
});
break;
}
}); });
export const portalAtom = focusAtom(appAtom, (optic) => optic.prop("portal")); export const portalAddressAtom = atomWithDefault(
(get) => get(appDataImmerAtom).portal
);
export const currentPortalDataAtom = atom<PortalData>((get) => {
const portalAddress = get(portalAddressAtom);
const { portals } = get(appDataImmerAtom);
const portalData = portals.find(({ address }) => address === portalAddress);
return portalData || { address: portalAddress, gateways: [] };
});
const clearCookiesAtom = atom(
(get) => get(appDataImmerAtom).clearCookies,
(_get, set, update: boolean) => {
set(appDataImmerAtom, (draft) => {
draft.clearCookies = update;
});
}
);
export const portalGatewaysAtom = atom<GatewayData[]>((get) => {
const { gateways } = get(currentPortalDataAtom);
return gateways;
});
export const selectedGatewayAtom = atom(
(get) => get(currentPortalDataAtom).selectedGateway
);
export const connectPortalAtom = atom( export const connectPortalAtom = atom(
(get) => get(isProcessingAtom), (get) => get(isProcessingAtom),
async (get, set, action?: "retry-auth") => { async (get, set, action?: "retry-auth") => {
@ -46,7 +123,7 @@ export const connectPortalAtom = atom(
return; return;
} }
const portal = get(portalAtom); const portal = get(portalAddressAtom);
if (!portal) { if (!portal) {
set(notifyErrorAtom, "Portal is empty"); set(notifyErrorAtom, "Portal is empty");
return; return;
@ -55,16 +132,21 @@ export const connectPortalAtom = atom(
try { try {
set(statusAtom, "prelogin"); set(statusAtom, "prelogin");
const prelogin = await portalService.prelogin(portal); const prelogin = await portalService.prelogin(portal);
if (!get(isProcessingAtom)) { const isProcessing = get(isProcessingAtom);
if (!isProcessing) {
console.info("Request cancelled"); console.info("Request cancelled");
return; return;
} }
try {
await set(loginWithCachedCredentialAtom, prelogin);
} catch {
if (prelogin.isSamlAuth) { if (prelogin.isSamlAuth) {
await set(launchSamlAuthAtom, prelogin); await set(launchSamlAuthAtom, prelogin);
} else { } else {
await set(launchPasswordAuthAtom, prelogin); await set(launchPasswordAuthAtom, prelogin);
} }
}
} catch (err) { } catch (err) {
set(cancelConnectPortalAtom); set(cancelConnectPortalAtom);
set(notifyErrorAtom, err); set(notifyErrorAtom, err);
@ -78,6 +160,17 @@ connectPortalAtom.onMount = (dispatch) => {
}); });
}; };
const loginWithCachedCredentialAtom = atom(
null,
async (get, set, prelogin: Prelogin) => {
const { cachedCredential } = get(currentPortalDataAtom);
if (!cachedCredential) {
throw new Error("No cached credential");
}
await set(portalLoginAtom, cachedCredential, prelogin);
}
);
export const passwordPreloginAtom = atom<PasswordPrelogin>({ export const passwordPreloginAtom = atom<PasswordPrelogin>({
isSamlAuth: false, isSamlAuth: false,
region: "", region: "",
@ -90,8 +183,14 @@ export const cancelConnectPortalAtom = atom(null, (_get, set) => {
set(statusAtom, "disconnected"); set(statusAtom, "disconnected");
}); });
export const usernameAtom = atom(""); export const usernameAtom = atomWithDefault(
export const passwordAtom = atom(""); (get) => get(currentPortalDataAtom).cachedCredential?.user ?? ""
);
export const passwordAtom = atomWithDefault(
(get) => get(currentPortalDataAtom).cachedCredential?.passwd ?? ""
);
const passwordAuthVisibleAtom = atom(false); const passwordAuthVisibleAtom = atom(false);
const launchPasswordAuthAtom = atom( const launchPasswordAuthAtom = atom(
@ -114,7 +213,7 @@ export const cancelPasswordAuthAtom = atom(
export const passwordLoginAtom = atom( export const passwordLoginAtom = atom(
(get) => get(portalConfigLoadingAtom), (get) => get(portalConfigLoadingAtom),
async (get, set, username: string, password: string) => { async (get, set, username: string, password: string) => {
const portal = get(portalAtom); const portal = get(portalAddressAtom);
if (!portal) { if (!portal) {
set(notifyErrorAtom, "Portal is empty"); set(notifyErrorAtom, "Portal is empty");
return; return;
@ -138,13 +237,18 @@ export const passwordLoginAtom = atom(
const launchSamlAuthAtom = atom( const launchSamlAuthAtom = atom(
null, null,
async (_get, set, prelogin: SamlPrelogin) => { async (get, set, prelogin: SamlPrelogin) => {
const { samlAuthMethod, samlRequest } = prelogin; const { samlAuthMethod, samlRequest } = prelogin;
let authData: AuthData; let authData: AuthData;
try { try {
set(statusAtom, "authenticating-saml"); set(statusAtom, "authenticating-saml");
authData = await authService.samlLogin(samlAuthMethod, samlRequest); const clearCookies = get(clearCookiesAtom);
authData = await authService.samlLogin(
samlAuthMethod,
samlRequest,
clearCookies
);
} catch (err) { } catch (err) {
throw new Error("SAML login failed"); throw new Error("SAML login failed");
} }
@ -155,17 +259,21 @@ const launchSamlAuthAtom = atom(
return; return;
} }
// SAML login success, update clearCookies to false to reuse the SAML session
set(clearCookiesAtom, false);
const credential = { const credential = {
user: authData.username, user: authData.username,
"prelogin-cookie": authData.prelogin_cookie, "prelogin-cookie": authData.prelogin_cookie,
"portal-userauthcookie": authData.portal_userauthcookie, "portal-userauthcookie": authData.portal_userauthcookie,
}; };
await set(portalLoginAtom, credential, prelogin); await set(portalLoginAtom, credential, prelogin);
} }
); );
const retrySamlAuthAtom = atom(null, async (get) => { const retrySamlAuthAtom = atom(null, async (get) => {
const portal = get(portalAtom); const portal = get(portalAddressAtom);
const prelogin = await portalService.prelogin(portal); const prelogin = await portalService.prelogin(portal);
if (prelogin.isSamlAuth) { if (prelogin.isSamlAuth) {
await authService.emitAuthRequest({ await authService.emitAuthRequest({
@ -175,17 +283,6 @@ const retrySamlAuthAtom = atom(null, async (get) => {
} }
}); });
type PortalCredential =
| {
user: string;
passwd: string;
}
| {
user: string;
"prelogin-cookie": string | null;
"portal-userauthcookie": string | null;
};
const portalConfigLoadingAtom = atom(false); const portalConfigLoadingAtom = atom(false);
const portalLoginAtom = atom( const portalLoginAtom = atom(
(get) => get(portalConfigLoadingAtom), (get) => get(portalConfigLoadingAtom),
@ -193,33 +290,88 @@ const portalLoginAtom = atom(
set(statusAtom, "portal-config"); set(statusAtom, "portal-config");
set(portalConfigLoadingAtom, true); set(portalConfigLoadingAtom, true);
const portal = get(portalAtom); const portalAddress = get(portalAddressAtom);
let portalConfig; let portalConfig;
try { try {
portalConfig = await portalService.fetchConfig(portal, credential); portalConfig = await portalService.fetchConfig(portalAddress, credential);
// Ensure the password auth window is closed // Ensure the password auth window is closed
set(passwordAuthVisibleAtom, false); set(passwordAuthVisibleAtom, false);
} finally { } finally {
set(portalConfigLoadingAtom, false); set(portalConfigLoadingAtom, false);
} }
if (!get(isProcessingAtom)) { const isProcessing = get(isProcessingAtom);
if (!isProcessing) {
console.info("Request cancelled"); console.info("Request cancelled");
return; return;
} }
const { gateways, userAuthCookie, prelogonUserAuthCookie } = portalConfig; const { gateways, userAuthCookie, prelogonUserAuthCookie } = portalConfig;
console.info("portalConfig", portalConfig);
if (!gateways.length) { if (!gateways.length) {
throw new Error("No gateway found"); throw new Error("No gateway found");
} }
if (userAuthCookie === "empty" || prelogonUserAuthCookie === "empty") {
throw new Error("Failed to login, please try again");
}
// Previous selected gateway
const previousGateway = get(selectedGatewayAtom);
// Update the app data to persist the portal data
set(updateAppDataAtom, {
type: "PORTAL",
payload: {
address: portalAddress,
gateways: gateways.map(({ name, address }) => ({
name,
address,
})),
cachedCredential: {
user: credential.user,
passwd: credential.passwd,
"portal-userauthcookie": userAuthCookie,
"portal-prelogonuserauthcookie": prelogonUserAuthCookie,
},
selectedGateway: previousGateway,
},
});
const { region } = prelogin; const { region } = prelogin;
const { address } = portalService.preferredGateway(gateways, region); const { name, address } = portalService.preferredGateway(gateways, {
region,
previousGateway,
});
await set(gatewayLoginAtom, address, { await set(gatewayLoginAtom, address, {
user: credential.user, user: credential.user,
userAuthCookie, userAuthCookie,
prelogonUserAuthCookie, prelogonUserAuthCookie,
}); });
// Update the app data to persist the gateway data
set(updateAppDataAtom, {
type: "SELECTED_GATEWAY",
payload: name,
});
}
);
export const switchingGatewayAtom = atom(false);
export const switchToGatewayAtom = atom(
(get) => get(switchingGatewayAtom),
async (get, set, gateway: GatewayData) => {
set(updateAppDataAtom, {
type: "SELECTED_GATEWAY",
payload: gateway.name,
});
if (get(statusAtom) === "connected") {
try {
set(switchingGatewayAtom, true);
await set(disconnectVpnAtom);
await set(connectPortalAtom);
} finally {
set(switchingGatewayAtom, false);
}
}
} }
); );

View File

@ -1,5 +1,8 @@
import { atom } from "jotai"; import { atom } from "jotai";
import { atomWithDefault } from "jotai/utils";
import vpnService from "../services/vpnService"; import vpnService from "../services/vpnService";
import { notifyErrorAtom, notifySuccessAtom } from "./notification";
import { selectedGatewayAtom, switchingGatewayAtom } from "./portal";
export type Status = export type Status =
| "disconnected" | "disconnected"
@ -13,13 +16,42 @@ export type Status =
| "disconnecting" | "disconnecting"
| "error"; | "error";
export const statusAtom = atom<Status>("disconnected"); const internalIsOnlineAtom = atomWithDefault(() => vpnService.isOnline());
statusAtom.onMount = (setAtom) => { export const isOnlineAtom = atom(
return vpnService.onStatusChanged((status) => { (get) => get(internalIsOnlineAtom),
status === "connected" && setAtom("connected"); async (get, set, update: boolean) => {
}); const isOnline = await get(internalIsOnlineAtom);
// Already online, do nothing
if (update && update === isOnline) {
return;
}
set(internalIsOnlineAtom, update);
if (update) {
set(notifySuccessAtom, "The background service is online");
} else {
set(notifyErrorAtom, "The background service is offline", 0);
}
}
);
isOnlineAtom.onMount = (setAtom) => vpnService.onServiceStatusChanged(setAtom);
const internalStatusReadyAtom = atom(false);
export const statusReadyAtom = atom(
(get) => get(internalStatusReadyAtom),
(get, set, status: Status) => {
set(internalStatusReadyAtom, true);
set(statusAtom, status);
}
);
statusReadyAtom.onMount = (setAtom) => {
vpnService.status().then(setAtom);
}; };
export const statusAtom = atom<Status>("disconnected");
statusAtom.onMount = (setAtom) => vpnService.onVpnStatusChanged(setAtom);
const statusTextMap: Record<Status, String> = { const statusTextMap: Record<Status, String> = {
disconnected: "Not Connected", disconnected: "Not Connected",
prelogin: "Portal pre-logging in...", prelogin: "Portal pre-logging in...",
@ -35,10 +67,28 @@ const statusTextMap: Record<Status, String> = {
export const statusTextAtom = atom((get) => { export const statusTextAtom = atom((get) => {
const status = get(statusAtom); const status = get(statusAtom);
const switchingGateway = get(switchingGatewayAtom);
if (status === "connected") {
const selectedGateway = get(selectedGatewayAtom);
return selectedGateway
? `Gateway: ${selectedGateway}`
: statusTextMap[status];
}
if (switchingGateway) {
const selectedGateway = get(selectedGatewayAtom);
return `Switching to ${selectedGateway}`;
}
return statusTextMap[status]; return statusTextMap[status];
}); });
export const isProcessingAtom = atom((get) => { export const isProcessingAtom = atom((get) => {
const status = get(statusAtom); const status = get(statusAtom);
return status !== "disconnected" && status !== "connected"; const switchingGateway = get(switchingGatewayAtom);
return (
(status !== "disconnected" && status !== "connected") || switchingGateway
);
}); });

View File

@ -5,16 +5,29 @@ import { disconnectVpnAtom } from "../../atoms/gateway";
import { import {
cancelConnectPortalAtom, cancelConnectPortalAtom,
connectPortalAtom, connectPortalAtom,
portalAtom, portalAddressAtom,
switchingGatewayAtom,
} from "../../atoms/portal"; } from "../../atoms/portal";
import { statusAtom } from "../../atoms/status"; import { isOnlineAtom, statusAtom } from "../../atoms/status";
export default function PortalForm() { export default function PortalForm() {
const [portal, setPortal] = useAtom(portalAtom); const isOnline = useAtomValue(isOnlineAtom);
const [portalAddress, setPortalAddress] = useAtom(portalAddressAtom);
const status = useAtomValue(statusAtom); const status = useAtomValue(statusAtom);
const [processing, connectPortal] = useAtom(connectPortalAtom); const [processing, connectPortal] = useAtom(connectPortalAtom);
const cancelConnectPortal = useSetAtom(cancelConnectPortalAtom); const cancelConnectPortal = useSetAtom(cancelConnectPortalAtom);
const disconnectVpn = useSetAtom(disconnectVpnAtom); const disconnectVpn = useSetAtom(disconnectVpnAtom);
const switchingGateway = useAtomValue(switchingGatewayAtom);
function handlePortalAddressChange(e: ChangeEvent<HTMLInputElement>) {
let host = e.target.value.trim();
if (/^https?:\/\//.test(host)) {
try {
host = new URL(host).hostname;
} catch (e) {}
}
setPortalAddress(host);
}
function handleSubmit(e: ChangeEvent<HTMLFormElement>) { function handleSubmit(e: ChangeEvent<HTMLFormElement>) {
e.preventDefault(); e.preventDefault();
@ -29,26 +42,32 @@ export default function PortalForm() {
placeholder="Hostname or IP address" placeholder="Hostname or IP address"
fullWidth fullWidth
size="small" size="small"
value={portal} value={portalAddress}
onChange={(e) => setPortal(e.target.value.trim())} onChange={handlePortalAddressChange}
InputProps={{ readOnly: status !== "disconnected" }} InputProps={{ readOnly: status !== "disconnected" || switchingGateway }}
sx={{ mb: 1 }} sx={{ mb: 1 }}
/> />
{status === "disconnected" && ( {status === "disconnected" && !switchingGateway && (
<Button <Button
fullWidth fullWidth
type="submit" type="submit"
variant="contained" variant="contained"
disabled={!isOnline}
sx={{ textTransform: "none" }} sx={{ textTransform: "none" }}
> >
Connect Connect
</Button> </Button>
)} )}
{processing && ( {(processing || switchingGateway) && (
<Button <Button
fullWidth fullWidth
variant="outlined" variant="outlined"
disabled={status === "authenticating-saml"} disabled={
status === "authenticating-saml" ||
status === "connecting" ||
status === "disconnecting" ||
switchingGateway
}
onClick={cancelConnectPortal} onClick={cancelConnectPortal}
sx={{ textTransform: "none" }} sx={{ textTransform: "none" }}
> >

View File

@ -2,7 +2,7 @@ import { GppBad, VerifiedUser as VerifiedIcon } from "@mui/icons-material";
import { Box, CircularProgress, styled, useTheme } from "@mui/material"; import { Box, CircularProgress, styled, useTheme } from "@mui/material";
import { useAtomValue } from "jotai"; import { useAtomValue } from "jotai";
import { BeatLoader } from "react-spinners"; import { BeatLoader } from "react-spinners";
import { statusAtom, isProcessingAtom } from "../../atoms/status"; import { isProcessingAtom, statusAtom } from "../../atoms/status";
function useStatusColor() { function useStatusColor() {
const status = useAtomValue(statusAtom); const status = useAtomValue(statusAtom);
@ -25,14 +25,14 @@ function useStatusColor() {
function BackgroundIcon() { function BackgroundIcon() {
const color = useStatusColor(); const color = useStatusColor();
const processing = useAtomValue(isProcessingAtom); const isProcessing = useAtomValue(isProcessingAtom);
return ( return (
<CircularProgress <CircularProgress
size={150} size={150}
thickness={1} thickness={1}
value={processing ? undefined : 100} value={isProcessing ? undefined : 100}
variant={processing ? "indeterminate" : "determinate"} variant={isProcessing ? "indeterminate" : "determinate"}
sx={{ sx={{
position: "absolute", position: "absolute",
top: 0, top: 0,
@ -40,7 +40,7 @@ function BackgroundIcon() {
color, color,
"& circle": { "& circle": {
fill: color, fill: color,
fillOpacity: processing ? 0.1 : 0.25, fillOpacity: isProcessing ? 0.1 : 0.25,
transition: "all 0.3s ease", transition: "all 0.3s ease",
}, },
}} }}
@ -78,16 +78,26 @@ const IconContainer = styled(Box)(({ theme }) =>
}) })
); );
export default function StatusIcon() { function InnerStatusIcon() {
const status = useAtomValue(statusAtom); const status = useAtomValue(statusAtom);
const processing = useAtomValue(isProcessingAtom); const isProcessing = useAtomValue(isProcessingAtom);
if (isProcessing) {
return <ProcessingIcon />;
}
if (status === "connected") {
return <ConnectedIcon />;
}
return <DisconnectedIcon />;
}
export default function StatusIcon() {
return ( return (
<IconContainer> <IconContainer>
<BackgroundIcon /> <BackgroundIcon />
{status === "disconnected" && <DisconnectedIcon />} <InnerStatusIcon />
{processing && <ProcessingIcon />}
{status === "connected" && <ConnectedIcon />}
</IconContainer> </IconContainer>
); );
} }

View File

@ -6,7 +6,17 @@ export default function StatusText() {
const statusText = useAtomValue(statusTextAtom); const statusText = useAtomValue(statusTextAtom);
return ( return (
<Typography textAlign="center" mt={1.5} variant="subtitle1" paragraph> <Typography
textAlign="center"
mt={1.5}
variant="subtitle1"
paragraph
sx={{
overflow: "hidden",
whiteSpace: "nowrap",
textOverflow: "ellipsis",
}}
>
{statusText} {statusText}
</Typography> </Typography>
); );

View File

@ -1,3 +1,43 @@
import { BugReport, Favorite } from "@mui/icons-material";
import { Chip, ChipProps, Stack } from "@mui/material";
import { red } from "@mui/material/colors";
const LinkChip = (props: ChipProps<"a">) => (
<Chip
component="a"
target="_blank"
clickable
variant="outlined"
size="small"
{...props}
/>
);
export default function Feedback() { export default function Feedback() {
return <div>Feedback</div> return (
<Stack direction="row" justifyContent="space-evenly" mt={1}>
<LinkChip
avatar={<BugReport />}
label="Feedback"
href="https://github.com/yuezk/GlobalProtect-openconnect/issues"
/>
<LinkChip
avatar={<Favorite />}
label="Donate"
href="https://www.buymeacoffee.com/yuezk"
sx={{
"& .MuiSvgIcon-root": {
color: red[300],
transition: "all 0.3s ease",
},
"&:hover": {
".MuiSvgIcon-root": {
color: red[500],
transform: "scale(1.1)",
},
},
}}
/>
</Stack>
);
} }

View File

@ -0,0 +1,58 @@
import { Check } from "@mui/icons-material";
import {
Drawer,
ListItemIcon,
ListItemText,
MenuItem,
MenuList,
} from "@mui/material";
import { useAtom, useAtomValue, useSetAtom } from "jotai";
import { gatewaySwitcherVisibleAtom } from "../../atoms/gateway";
import {
GatewayData,
portalGatewaysAtom,
selectedGatewayAtom,
switchToGatewayAtom,
} from "../../atoms/portal";
export default function GatewaySwitcher() {
const [visible, setGatewaySwitcherVisible] = useAtom(
gatewaySwitcherVisibleAtom
);
const gateways = useAtomValue(portalGatewaysAtom);
const selectedGateway = useAtomValue(selectedGatewayAtom);
const switchToGateway = useSetAtom(switchToGatewayAtom);
const handleClose = () => {
setGatewaySwitcherVisible(false);
};
const handleMenuClick = (gateway: GatewayData) => () => {
setGatewaySwitcherVisible(false);
if (gateway.name !== selectedGateway) {
switchToGateway(gateway);
}
};
return (
<Drawer anchor="bottom" open={visible} onClose={handleClose}>
<MenuList
sx={{
maxHeight: 320,
}}
>
{!gateways.length && <MenuItem disabled>No gateways found</MenuItem>}
{gateways.map(({ name, address }) => (
<MenuItem key={name} onClick={handleMenuClick({ name, address })}>
{selectedGateway === name && (
<ListItemIcon>
<Check />
</ListItemIcon>
)}
<ListItemText inset={selectedGateway !== name}>{name}</ListItemText>
</MenuItem>
))}
</MenuList>
</Drawer>
);
}

View File

@ -0,0 +1,111 @@
import {
ExitToApp,
GitHub,
LockReset,
Menu as MenuIcon,
Settings,
VpnLock,
} from "@mui/icons-material";
import { Box, Divider, IconButton, Menu, MenuItem } from "@mui/material";
import { alpha, styled } from "@mui/material/styles";
import { useAtomValue, useSetAtom } from "jotai";
import { useState } from "react";
import { openGatewaySwitcherAtom } from "../../atoms/gateway";
import { quitAtom, resetAtom } from "../../atoms/menu";
import { isProcessingAtom, statusAtom } from "../../atoms/status";
const MenuContainer = styled(Box)(({ theme }) => ({
position: "absolute",
left: theme.spacing(1),
top: theme.spacing(1),
}));
const StyledMenu = styled(Menu)(({ theme }) => ({
"& .MuiPaper-root": {
borderRadius: 6,
minWidth: 180,
"& .MuiMenu-list": {
padding: "4px 0",
},
"& .MuiMenuItem-root": {
minHeight: "auto",
"& .MuiSvgIcon-root": {
fontSize: 18,
color: theme.palette.text.secondary,
marginRight: theme.spacing(1.5),
},
"&:active": {
backgroundColor: alpha(
theme.palette.primary.main,
theme.palette.action.selectedOpacity
),
},
},
},
}));
export default function MainMenu() {
const isProcessing = useAtomValue(isProcessingAtom);
const [anchorEl, setAnchorEl] = useState<null | HTMLElement>(null);
const openGatewaySwitcher = useSetAtom(openGatewaySwitcherAtom);
const status = useAtomValue(statusAtom);
const reset = useSetAtom(resetAtom);
const quit = useSetAtom(quitAtom);
const open = Boolean(anchorEl);
const handleClick = (event: React.MouseEvent<HTMLElement>) => {
setAnchorEl(event.currentTarget);
};
const handleClose = () => {
setAnchorEl(null);
};
return (
<>
<MenuContainer>
<IconButton onClick={handleClick} disabled={isProcessing}>
<MenuIcon />
</IconButton>
<StyledMenu
anchorEl={anchorEl}
open={open}
onClose={handleClose}
onClick={handleClose}
>
<MenuItem onClick={openGatewaySwitcher} disableRipple>
<VpnLock />
Switch Gateway
</MenuItem>
<MenuItem onClick={handleClose} disableRipple>
<Settings />
Settings
</MenuItem>
<MenuItem
onClick={reset}
disableRipple
disabled={status !== "disconnected"}
>
<LockReset />
Reset
</MenuItem>
<Divider />
<MenuItem onClick={quit} disableRipple>
<ExitToApp />
Quit
</MenuItem>
</StyledMenu>
</MenuContainer>
<IconButton
href="https://github.com/yuezk/GlobalProtect-openconnect"
target="_blank"
sx={{
position: "absolute",
right: (theme) => theme.spacing(1),
top: (theme) => theme.spacing(1),
}}
>
<GitHub />
</IconButton>
</>
);
}

View File

@ -11,16 +11,23 @@ function TransitionDown(props: TransitionProps) {
} }
export default function Notification() { export default function Notification() {
const { title, message, severity } = useAtomValue(notificationConfigAtom); const { title, message, severity, duration } = useAtomValue(
notificationConfigAtom
);
const [visible, closeNotification] = useAtom(closeNotificationAtom); const [visible, closeNotification] = useAtom(closeNotificationAtom);
const handleClose = () => {
if (duration) {
closeNotification();
}
};
return ( return (
<Snackbar <Snackbar
open={visible} open={visible}
anchorOrigin={{ vertical: "top", horizontal: "center" }} anchorOrigin={{ vertical: "top", horizontal: "center" }}
autoHideDuration={5000} autoHideDuration={duration}
TransitionComponent={TransitionDown} TransitionComponent={TransitionDown}
onClose={closeNotification} onClose={handleClose}
sx={{ sx={{
top: 0, top: 0,
left: 0, left: 0,

View File

@ -3,8 +3,8 @@ import invokeCommand from "../utils/invokeCommand";
export type AuthData = { export type AuthData = {
username: string; username: string;
prelogin_cookie: string | null; prelogin_cookie?: string;
portal_userauthcookie: string | null; portal_userauthcookie?: string;
}; };
class AuthService { class AuthService {
@ -15,7 +15,8 @@ class AuthService {
} }
private async init() { private async init() {
await listen("auth-error", () => { await listen("auth-error", (evt) => {
console.error("auth-error", evt);
this.authErrorCallback?.(); this.authErrorCallback?.();
}); });
} }
@ -28,8 +29,12 @@ class AuthService {
} }
// binding: "POST" | "REDIRECT" // binding: "POST" | "REDIRECT"
async samlLogin(binding: string, request: string) { async samlLogin(binding: string, request: string, clearCookies: boolean) {
return invokeCommand<AuthData>("saml_login", { binding, request }); return invokeCommand<AuthData>("saml_login", {
binding,
request,
clearCookies,
});
} }
async emitAuthRequest({ async emitAuthRequest({

View File

@ -25,12 +25,12 @@ export type PortalConfig = {
gateways: Gateway[]; gateways: Gateway[];
}; };
export type PortalConfigParams = { export type PortalCredential = {
user: string; user: string;
passwd?: string | null; passwd?: string; // for password auth
"prelogin-cookie"?: string | null; "prelogin-cookie"?: string; // for saml auth
"portal-userauthcookie"?: string | null; "portal-userauthcookie"?: string; // cached cookie from previous portal config
"portal-prelogonuserauthcookie"?: string | null; "portal-prelogonuserauthcookie"?: string; // cached cookie from previous portal config
}; };
class PortalService { class PortalService {
@ -105,7 +105,7 @@ class PortalService {
throw new Error("Unknown prelogin response"); throw new Error("Unknown prelogin response");
} }
async fetchConfig(portal: string, params: PortalConfigParams) { async fetchConfig(portal: string, params: PortalCredential) {
const { const {
user, user,
passwd, passwd,
@ -125,8 +125,10 @@ class PortalService {
direct: "yes", direct: "yes",
clientVer: "4100", clientVer: "4100",
"os-version": "Linux", "os-version": "Linux",
clientgpversion: "6.0.1-19",
"ipv6-support": "yes", "ipv6-support": "yes",
server: portal, server: portal,
host: portal,
user, user,
passwd: passwd || "", passwd: passwd || "",
"prelogin-cookie": preloginCookie || "", "prelogin-cookie": preloginCookie || "",
@ -152,7 +154,7 @@ class PortalService {
} }
private parsePortalConfigResponse(response: string): PortalConfig { private parsePortalConfigResponse(response: string): PortalConfig {
console.log(response); // console.log(response);
const result = parseXml(response); const result = parseXml(response);
const gateways = result.all("gateways list > entry").map((entry) => { const gateways = result.all("gateways list > entry").map((entry) => {
@ -182,8 +184,16 @@ class PortalService {
}; };
} }
preferredGateway(gateways: Gateway[], region: string) { preferredGateway(
console.log(gateways); gateways: Gateway[],
{ region, previousGateway }: { region: string; previousGateway?: string }
) {
for (const gateway of gateways) {
if (gateway.name === previousGateway) {
return gateway;
}
}
let defaultGateway = gateways[0]; let defaultGateway = gateways[0];
for (const gateway of gateways) { for (const gateway of gateways) {
if (gateway.priority < defaultGateway.priority) { if (gateway.priority < defaultGateway.priority) {

View File

@ -1,39 +1,66 @@
import { Event, listen } from "@tauri-apps/api/event"; import { Event, listen } from "@tauri-apps/api/event";
import invokeCommand from "../utils/invokeCommand"; import invokeCommand from "../utils/invokeCommand";
type Status = "disconnected" | "connecting" | "connected" | "disconnecting"; type VpnStatus = "disconnected" | "connecting" | "connected" | "disconnecting";
type StatusCallback = (status: Status) => void; type VpnStatusCallback = (status: VpnStatus) => void;
type StatusPayload = { type VpnStatusPayload = {
status: Status; status: VpnStatus;
}; };
type ServiceStatusCallback = (status: boolean) => void;
class VpnService { class VpnService {
private _status: Status = "disconnected"; private _isOnline?: boolean;
private statusCallbacks: StatusCallback[] = []; private _status?: VpnStatus;
private statusCallbacks: VpnStatusCallback[] = [];
private serviceStatusCallbacks: ServiceStatusCallback[] = [];
constructor() { constructor() {
this.init(); this.init();
} }
private async init() { private async init() {
await listen("vpn-status-received", (event: Event<StatusPayload>) => { await listen("service-status-changed", (event: Event<boolean>) => {
console.log("vpn-status-received", event.payload); this.setIsOnline(event.payload);
this.setStatus(event.payload.status);
}); });
const status = await this.status(); await listen("vpn-status-received", (event: Event<VpnStatusPayload>) => {
this.setStatus(status); this.setStatus(event.payload.status);
});
} }
private setStatus(status: Status) { async isOnline() {
if (this._status != status) { try {
const isOnline = await invokeCommand<boolean>("service_online");
this.setIsOnline(isOnline);
return isOnline;
} catch (err) {
return false;
}
}
private setIsOnline(isOnline: boolean) {
if (this._isOnline !== isOnline) {
this._isOnline = isOnline;
this.serviceStatusCallbacks.forEach((cb) => cb(isOnline));
}
}
private setStatus(status: VpnStatus) {
if (this._status !== status) {
this._status = status; this._status = status;
this.fireStatusCallbacks(); this.statusCallbacks.forEach((cb) => cb(status));
} }
} }
private async status(): Promise<Status> { async status(): Promise<VpnStatus> {
return invokeCommand<Status>("vpn_status"); try {
const status = await invokeCommand<VpnStatus>("vpn_status");
this._status = status;
return status;
} catch (err) {
return "disconnected";
}
} }
async connect(server: string, cookie: string) { async connect(server: string, cookie: string) {
@ -44,19 +71,31 @@ class VpnService {
return invokeCommand("vpn_disconnect"); return invokeCommand("vpn_disconnect");
} }
onStatusChanged(callback: StatusCallback) { onVpnStatusChanged(callback: VpnStatusCallback) {
this.statusCallbacks.push(callback); this.statusCallbacks.push(callback);
if (typeof this._status === "string") {
callback(this._status); callback(this._status);
return () => this.removeStatusCallback(callback); }
return () => this.removeVpnStatusCallback(callback);
} }
private fireStatusCallbacks() { onServiceStatusChanged(callback: ServiceStatusCallback) {
this.statusCallbacks.forEach((cb) => cb(this._status)); this.serviceStatusCallbacks.push(callback);
if (typeof this._isOnline === "boolean") {
callback(this._isOnline);
}
return () => this.removeServiceStatusCallback(callback);
} }
private removeStatusCallback(callback: StatusCallback) { private removeVpnStatusCallback(callback: VpnStatusCallback) {
this.statusCallbacks = this.statusCallbacks.filter((cb) => cb !== callback); this.statusCallbacks = this.statusCallbacks.filter((cb) => cb !== callback);
} }
private removeServiceStatusCallback(callback: ServiceStatusCallback) {
this.serviceStatusCallbacks = this.serviceStatusCallbacks.filter(
(cb) => cb !== callback
);
}
} }
export default new VpnService(); export default new VpnService();