mirror of
				https://github.com/yuezk/GlobalProtect-openconnect.git
				synced 2025-05-20 07:26:58 -04:00 
			
		
		
		
	refactor: improve workflow
This commit is contained in:
		| @@ -1,4 +1,4 @@ | ||||
| use crate::cmd::{Connect, Disconnect, Status}; | ||||
| use crate::cmd::{Connect, Disconnect, GetStatus}; | ||||
| use crate::reader::Reader; | ||||
| use crate::request::CommandPayload; | ||||
| use crate::response::ResponseData; | ||||
| @@ -7,7 +7,7 @@ use crate::RequestPool; | ||||
| use crate::Response; | ||||
| use crate::SOCKET_PATH; | ||||
| use crate::{Request, VpnStatus}; | ||||
| use log::{info, warn, debug}; | ||||
| use log::{debug, info, warn}; | ||||
| use serde::{Deserialize, Serialize}; | ||||
| use std::fmt::Display; | ||||
| use std::sync::Arc; | ||||
| @@ -17,17 +17,24 @@ use tokio::sync::{mpsc, Mutex, RwLock}; | ||||
| use tokio_util::sync::CancellationToken; | ||||
|  | ||||
| #[derive(Debug)] | ||||
| enum ServerEvent { | ||||
| enum ServiceEvent { | ||||
|     Online, | ||||
|     Response(Response), | ||||
|     ServerDisconnected, | ||||
|     Offline, | ||||
| } | ||||
|  | ||||
| impl From<Response> for ServerEvent { | ||||
| impl From<Response> for ServiceEvent { | ||||
|     fn from(response: Response) -> Self { | ||||
|         Self::Response(response) | ||||
|     } | ||||
| } | ||||
|  | ||||
| #[derive(Debug)] | ||||
| pub enum ClientStatus { | ||||
|     Vpn(VpnStatus), | ||||
|     Service(bool), | ||||
| } | ||||
|  | ||||
| #[derive(Debug)] | ||||
| pub struct Client { | ||||
|     // pool of requests that are waiting for responses | ||||
| @@ -37,10 +44,10 @@ pub struct Client { | ||||
|     // rx for receiving requests from the channel | ||||
|     request_rx: Arc<Mutex<mpsc::Receiver<Request>>>, | ||||
|     // tx for sending responses to the channel | ||||
|     server_event_tx: mpsc::Sender<ServerEvent>, | ||||
|     service_event_tx: mpsc::Sender<ServiceEvent>, | ||||
|     // rx for receiving responses from the channel | ||||
|     server_event_rx: Arc<Mutex<mpsc::Receiver<ServerEvent>>>, | ||||
|     is_healthy: Arc<RwLock<bool>>, | ||||
|     service_event_rx: Arc<Mutex<mpsc::Receiver<ServiceEvent>>>, | ||||
|     is_online: Arc<RwLock<bool>>, | ||||
| } | ||||
|  | ||||
| #[derive(Debug, Serialize, Deserialize)] | ||||
| @@ -71,34 +78,41 @@ impl From<&str> for ServerApiError { | ||||
| impl Default for Client { | ||||
|     fn default() -> Self { | ||||
|         let (request_tx, request_rx) = mpsc::channel::<Request>(32); | ||||
|         let (server_event_tx, server_event_rx) = mpsc::channel::<ServerEvent>(32); | ||||
|         let (service_event_tx, server_event_rx) = mpsc::channel::<ServiceEvent>(32); | ||||
|  | ||||
|         Self { | ||||
|             request_pool: Default::default(), | ||||
|             request_tx, | ||||
|             request_rx: Arc::new(Mutex::new(request_rx)), | ||||
|             server_event_tx, | ||||
|             server_event_rx: Arc::new(Mutex::new(server_event_rx)), | ||||
|             is_healthy: Default::default(), | ||||
|             service_event_tx, | ||||
|             service_event_rx: Arc::new(Mutex::new(server_event_rx)), | ||||
|             is_online: Default::default(), | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl Client { | ||||
|     pub fn subscribe_status(&self, callback: impl Fn(VpnStatus) + Send + Sync + 'static) { | ||||
|         let server_event_rx = self.server_event_rx.clone(); | ||||
|     pub async fn is_online(&self) -> bool { | ||||
|         *self.is_online.read().await | ||||
|     } | ||||
|  | ||||
|     pub fn subscribe_status(&self, callback: impl Fn(ClientStatus) + Send + Sync + 'static) { | ||||
|         let service_event_rx = self.service_event_rx.clone(); | ||||
|  | ||||
|         tokio::spawn(async move { | ||||
|             loop { | ||||
|                 let mut server_event_rx = server_event_rx.lock().await; | ||||
|                 let mut server_event_rx = service_event_rx.lock().await; | ||||
|                 if let Some(server_event) = server_event_rx.recv().await { | ||||
|                     match server_event { | ||||
|                         ServerEvent::ServerDisconnected => { | ||||
|                             callback(VpnStatus::Disconnected); | ||||
|                         ServiceEvent::Online => { | ||||
|                             callback(ClientStatus::Service(true)); | ||||
|                         } | ||||
|                         ServerEvent::Response(response) => { | ||||
|                         ServiceEvent::Offline => { | ||||
|                             callback(ClientStatus::Service(false)); | ||||
|                         } | ||||
|                         ServiceEvent::Response(response) => { | ||||
|                             if let ResponseData::Status(vpn_status) = response.data() { | ||||
|                                 callback(vpn_status); | ||||
|                                 callback(ClientStatus::Vpn(vpn_status)); | ||||
|                             } | ||||
|                         } | ||||
|                     } | ||||
| @@ -134,7 +148,7 @@ impl Client { | ||||
|         let read_handle = tokio::spawn(handle_read( | ||||
|             read_stream, | ||||
|             self.request_pool.clone(), | ||||
|             self.server_event_tx.clone(), | ||||
|             self.service_event_tx.clone(), | ||||
|             cancel_token.clone(), | ||||
|         )); | ||||
|  | ||||
| @@ -144,13 +158,16 @@ impl Client { | ||||
|             cancel_token, | ||||
|         )); | ||||
|  | ||||
|         *self.is_healthy.write().await = true; | ||||
|         *self.is_online.write().await = true; | ||||
|         info!("Connected to the background service"); | ||||
|         if let Err(err) = self.service_event_tx.send(ServiceEvent::Online).await { | ||||
|             warn!("Error sending online event to the channel: {}", err); | ||||
|         } | ||||
|  | ||||
|         let _ = tokio::join!(read_handle, write_handle); | ||||
|         *self.is_healthy.write().await = false; | ||||
|         *self.is_online.write().await = false; | ||||
|  | ||||
|         // TODO connection was lost, cleanup the request pool and notify the UI | ||||
|         // TODO connection was lost, cleanup the request pool | ||||
|  | ||||
|         Ok(()) | ||||
|     } | ||||
| @@ -159,7 +176,7 @@ impl Client { | ||||
|         &self, | ||||
|         payload: CommandPayload, | ||||
|     ) -> Result<T, ServerApiError> { | ||||
|         if !*self.is_healthy.read().await { | ||||
|         if !*self.is_online.read().await { | ||||
|             return Err("Background service is not running".into()); | ||||
|         } | ||||
|  | ||||
| @@ -169,18 +186,19 @@ impl Client { | ||||
|             return Err(format!("Error sending request to the channel: {}", err).into()); | ||||
|         } | ||||
|  | ||||
|         if let Ok(response) = response_rx.await { | ||||
|             if response.success() { | ||||
|                 match response.data().try_into() { | ||||
|                     Ok(it) => Ok(it), | ||||
|                     Err(_) => Err("Error parsing response data".into()), | ||||
|         response_rx | ||||
|             .await | ||||
|             .map_err(|_| "Error receiving response from the channel".into()) | ||||
|             .and_then(|response| { | ||||
|                 if response.success() { | ||||
|                     response | ||||
|                         .data() | ||||
|                         .try_into() | ||||
|                         .map_err(|_| "Error parsing response data".into()) | ||||
|                 } else { | ||||
|                     Err(response.message().into()) | ||||
|                 } | ||||
|             } else { | ||||
|                 Err(response.message().into()) | ||||
|             } | ||||
|         } else { | ||||
|             Err("Error receiving response from the channel".into()) | ||||
|         } | ||||
|             }) | ||||
|     } | ||||
|  | ||||
|     pub async fn connect(&self, server: String, cookie: String) -> Result<(), ServerApiError> { | ||||
| @@ -192,14 +210,14 @@ impl Client { | ||||
|     } | ||||
|  | ||||
|     pub async fn status(&self) -> Result<VpnStatus, ServerApiError> { | ||||
|         self.send_command(Status.into()).await | ||||
|         self.send_command(GetStatus.into()).await | ||||
|     } | ||||
| } | ||||
|  | ||||
| async fn handle_read( | ||||
|     read_stream: ReadHalf<UnixStream>, | ||||
|     request_pool: Arc<RequestPool>, | ||||
|     server_event_tx: mpsc::Sender<ServerEvent>, | ||||
|     service_event_tx: mpsc::Sender<ServiceEvent>, | ||||
|     cancel_token: CancellationToken, | ||||
| ) { | ||||
|     let mut reader: Reader = read_stream.into(); | ||||
| @@ -211,7 +229,7 @@ async fn handle_read( | ||||
|                     match response.request_id() { | ||||
|                         Some(id) => request_pool.complete_request(id, response).await, | ||||
|                         None => { | ||||
|                             if let Err(err) = server_event_tx.send(response.into()).await { | ||||
|                             if let Err(err) = service_event_tx.send(response.into()).await { | ||||
|                                 warn!("Error sending response to output channel: {}", err); | ||||
|                             } | ||||
|                         } | ||||
| @@ -220,7 +238,7 @@ async fn handle_read( | ||||
|             } | ||||
|             Err(err) if err.kind() == io::ErrorKind::ConnectionAborted => { | ||||
|                 warn!("Disconnected from the background service"); | ||||
|                 if let Err(err) = server_event_tx.send(ServerEvent::ServerDisconnected).await { | ||||
|                 if let Err(err) = service_event_tx.send(ServiceEvent::Offline).await { | ||||
|                     warn!( | ||||
|                         "Error sending server disconnected event to channel: {}", | ||||
|                         err | ||||
|   | ||||
| @@ -12,7 +12,7 @@ mod status; | ||||
|  | ||||
| pub use connect::Connect; | ||||
| pub use disconnect::Disconnect; | ||||
| pub use status::Status; | ||||
| pub use status::GetStatus; | ||||
|  | ||||
| #[derive(Debug)] | ||||
| pub(crate) struct CommandContext { | ||||
|   | ||||
| @@ -4,10 +4,10 @@ use async_trait::async_trait; | ||||
| use serde::{Deserialize, Serialize}; | ||||
|  | ||||
| #[derive(Debug, Serialize, Deserialize, Clone)] | ||||
| pub struct Status; | ||||
| pub struct GetStatus; | ||||
|  | ||||
| #[async_trait] | ||||
| impl Command for Status { | ||||
| impl Command for GetStatus { | ||||
|     async fn handle(&self, context: CommandContext) -> Result<ResponseData, CommandError> { | ||||
|         let status = context.server_context.vpn().status().await; | ||||
|  | ||||
|   | ||||
| @@ -23,30 +23,31 @@ async fn handle_read( | ||||
|     let mut authenticated: Option<bool> = None; | ||||
|  | ||||
|     loop { | ||||
|         match reader.read::<Request>().await { | ||||
|             Ok(request) => { | ||||
|         match reader.read_multiple::<Request>().await { | ||||
|             Ok(requests) => { | ||||
|                 if authenticated.is_none() { | ||||
|                     authenticated = Some(authenticate(peer_pid)); | ||||
|                 } | ||||
|  | ||||
|                 if !authenticated.unwrap_or(false) { | ||||
|                     warn!("Client not authenticated, closing connection"); | ||||
|                     cancel_token.cancel(); | ||||
|                     break; | ||||
|                 } | ||||
|  | ||||
|                 debug!("Received client request: {:?}", request); | ||||
|                 for request in requests { | ||||
|                     debug!("Received client request: {:?}", request); | ||||
|  | ||||
|                 let command = request.command(); | ||||
|                 let context = server_context.clone().into(); | ||||
|                     let command = request.command(); | ||||
|                     let context = server_context.clone().into(); | ||||
|  | ||||
|                 let mut response = match command.handle(context).await { | ||||
|                     Ok(data) => Response::from(data), | ||||
|                     Err(err) => Response::from(err.to_string()), | ||||
|                 }; | ||||
|                 response.set_request_id(request.id()); | ||||
|                     let mut response = match command.handle(context).await { | ||||
|                         Ok(data) => Response::from(data), | ||||
|                         Err(err) => Response::from(err.to_string()), | ||||
|                     }; | ||||
|                     response.set_request_id(request.id()); | ||||
|  | ||||
|                 let _ = response_tx.send(response).await; | ||||
|                     let _ = response_tx.send(response).await; | ||||
|                 } | ||||
|             } | ||||
|  | ||||
|             Err(err) if err.kind() == io::ErrorKind::ConnectionAborted => { | ||||
|   | ||||
| @@ -30,6 +30,7 @@ pub(crate) use writer::Writer; | ||||
|  | ||||
| pub use client::Client; | ||||
| pub use client::ServerApiError; | ||||
| pub use client::ClientStatus; | ||||
| pub use vpn::VpnStatus; | ||||
|  | ||||
| pub fn sha256_digest<P: AsRef<Path>>(file_path: P) -> Result<String, std::io::Error> { | ||||
|   | ||||
| @@ -13,22 +13,6 @@ impl From<ReadHalf<UnixStream>> for Reader { | ||||
| } | ||||
|  | ||||
| impl Reader { | ||||
|     pub async fn read<T: for<'a> Deserialize<'a>>(&mut self) -> Result<T, io::Error> { | ||||
|         let mut buffer = [0; 2048]; | ||||
|  | ||||
|         match self.stream.read(&mut buffer).await { | ||||
|             Ok(0) => Err(io::Error::new( | ||||
|                 io::ErrorKind::ConnectionAborted, | ||||
|                 "Peer disconnected", | ||||
|             )), | ||||
|             Ok(bytes_read) => { | ||||
|                 let data = serde_json::from_slice::<T>(&buffer[..bytes_read])?; | ||||
|                 Ok(data) | ||||
|             } | ||||
|             Err(err) => Err(err), | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     pub async fn read_multiple<T: for<'a> Deserialize<'a>>(&mut self) -> Result<Vec<T>, io::Error> { | ||||
|         let mut buffer = [0; 2048]; | ||||
|  | ||||
|   | ||||
| @@ -1,4 +1,4 @@ | ||||
| use crate::cmd::{Command, Connect, Disconnect, Status}; | ||||
| use crate::cmd::{Command, Connect, Disconnect, GetStatus}; | ||||
| use crate::Response; | ||||
| use serde::{Deserialize, Serialize}; | ||||
| use std::sync::Arc; | ||||
| @@ -21,7 +21,7 @@ impl Request { | ||||
|  | ||||
|     pub fn command(&self) -> Box<dyn Command> { | ||||
|         match &self.payload { | ||||
|             CommandPayload::Status(status) => Box::new(status.clone()), | ||||
|             CommandPayload::GetStatus(status) => Box::new(status.clone()), | ||||
|             CommandPayload::Connect(connect) => Box::new(connect.clone()), | ||||
|             CommandPayload::Disconnect(disconnect) => Box::new(disconnect.clone()), | ||||
|         } | ||||
| @@ -30,14 +30,14 @@ impl Request { | ||||
|  | ||||
| #[derive(Debug, Serialize, Deserialize)] | ||||
| pub(crate) enum CommandPayload { | ||||
|     Status(Status), | ||||
|     GetStatus(GetStatus), | ||||
|     Connect(Connect), | ||||
|     Disconnect(Disconnect), | ||||
| } | ||||
|  | ||||
| impl From<Status> for CommandPayload { | ||||
|     fn from(status: Status) -> Self { | ||||
|         Self::Status(status) | ||||
| impl From<GetStatus> for CommandPayload { | ||||
|     fn from(status: GetStatus) -> Self { | ||||
|         Self::GetStatus(status) | ||||
|     } | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -1,4 +1,4 @@ | ||||
| use log::{warn, info, debug}; | ||||
| use log::{debug, info, warn}; | ||||
| use serde::{Deserialize, Serialize}; | ||||
| use std::ffi::{c_void, CString}; | ||||
| use std::sync::Arc; | ||||
| @@ -91,7 +91,7 @@ impl Vpn { | ||||
|         *self.vpn_options.lock().await = Some(VpnOptions { | ||||
|             server: VpnOptions::to_cstr(server), | ||||
|             cookie: VpnOptions::to_cstr(cookie), | ||||
|             script: VpnOptions::to_cstr("/usr/share/vpnc-scripts/vpnc-script") | ||||
|             script: VpnOptions::to_cstr("/usr/share/vpnc-scripts/vpnc-script"), | ||||
|         }); | ||||
|  | ||||
|         let vpn_options = self.vpn_options.clone(); | ||||
| @@ -133,12 +133,15 @@ impl Vpn { | ||||
|         } | ||||
|  | ||||
|         info!("Disconnecting VPN..."); | ||||
|         self.status_holder | ||||
|             .lock() | ||||
|             .await | ||||
|             .set(VpnStatus::Disconnecting); | ||||
|         unsafe { ffi::disconnect() }; | ||||
|  | ||||
|         let mut status_rx = self.status_rx().await; | ||||
|         debug!("Waiting for the VPN to disconnect..."); | ||||
|  | ||||
|  | ||||
|         while status_rx.changed().await.is_ok() { | ||||
|             if *status_rx.borrow() == VpnStatus::Disconnected { | ||||
|                 info!("VPN disconnected"); | ||||
|   | ||||
		Reference in New Issue
	
	Block a user