mirror of
https://github.com/yuezk/GlobalProtect-openconnect.git
synced 2025-05-20 07:26:58 -04:00
refactor: improve workflow
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
use crate::cmd::{Connect, Disconnect, Status};
|
||||
use crate::cmd::{Connect, Disconnect, GetStatus};
|
||||
use crate::reader::Reader;
|
||||
use crate::request::CommandPayload;
|
||||
use crate::response::ResponseData;
|
||||
@@ -7,7 +7,7 @@ use crate::RequestPool;
|
||||
use crate::Response;
|
||||
use crate::SOCKET_PATH;
|
||||
use crate::{Request, VpnStatus};
|
||||
use log::{info, warn, debug};
|
||||
use log::{debug, info, warn};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt::Display;
|
||||
use std::sync::Arc;
|
||||
@@ -17,17 +17,24 @@ use tokio::sync::{mpsc, Mutex, RwLock};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
#[derive(Debug)]
|
||||
enum ServerEvent {
|
||||
enum ServiceEvent {
|
||||
Online,
|
||||
Response(Response),
|
||||
ServerDisconnected,
|
||||
Offline,
|
||||
}
|
||||
|
||||
impl From<Response> for ServerEvent {
|
||||
impl From<Response> for ServiceEvent {
|
||||
fn from(response: Response) -> Self {
|
||||
Self::Response(response)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum ClientStatus {
|
||||
Vpn(VpnStatus),
|
||||
Service(bool),
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Client {
|
||||
// pool of requests that are waiting for responses
|
||||
@@ -37,10 +44,10 @@ pub struct Client {
|
||||
// rx for receiving requests from the channel
|
||||
request_rx: Arc<Mutex<mpsc::Receiver<Request>>>,
|
||||
// tx for sending responses to the channel
|
||||
server_event_tx: mpsc::Sender<ServerEvent>,
|
||||
service_event_tx: mpsc::Sender<ServiceEvent>,
|
||||
// rx for receiving responses from the channel
|
||||
server_event_rx: Arc<Mutex<mpsc::Receiver<ServerEvent>>>,
|
||||
is_healthy: Arc<RwLock<bool>>,
|
||||
service_event_rx: Arc<Mutex<mpsc::Receiver<ServiceEvent>>>,
|
||||
is_online: Arc<RwLock<bool>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
@@ -71,34 +78,41 @@ impl From<&str> for ServerApiError {
|
||||
impl Default for Client {
|
||||
fn default() -> Self {
|
||||
let (request_tx, request_rx) = mpsc::channel::<Request>(32);
|
||||
let (server_event_tx, server_event_rx) = mpsc::channel::<ServerEvent>(32);
|
||||
let (service_event_tx, server_event_rx) = mpsc::channel::<ServiceEvent>(32);
|
||||
|
||||
Self {
|
||||
request_pool: Default::default(),
|
||||
request_tx,
|
||||
request_rx: Arc::new(Mutex::new(request_rx)),
|
||||
server_event_tx,
|
||||
server_event_rx: Arc::new(Mutex::new(server_event_rx)),
|
||||
is_healthy: Default::default(),
|
||||
service_event_tx,
|
||||
service_event_rx: Arc::new(Mutex::new(server_event_rx)),
|
||||
is_online: Default::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Client {
|
||||
pub fn subscribe_status(&self, callback: impl Fn(VpnStatus) + Send + Sync + 'static) {
|
||||
let server_event_rx = self.server_event_rx.clone();
|
||||
pub async fn is_online(&self) -> bool {
|
||||
*self.is_online.read().await
|
||||
}
|
||||
|
||||
pub fn subscribe_status(&self, callback: impl Fn(ClientStatus) + Send + Sync + 'static) {
|
||||
let service_event_rx = self.service_event_rx.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
let mut server_event_rx = server_event_rx.lock().await;
|
||||
let mut server_event_rx = service_event_rx.lock().await;
|
||||
if let Some(server_event) = server_event_rx.recv().await {
|
||||
match server_event {
|
||||
ServerEvent::ServerDisconnected => {
|
||||
callback(VpnStatus::Disconnected);
|
||||
ServiceEvent::Online => {
|
||||
callback(ClientStatus::Service(true));
|
||||
}
|
||||
ServerEvent::Response(response) => {
|
||||
ServiceEvent::Offline => {
|
||||
callback(ClientStatus::Service(false));
|
||||
}
|
||||
ServiceEvent::Response(response) => {
|
||||
if let ResponseData::Status(vpn_status) = response.data() {
|
||||
callback(vpn_status);
|
||||
callback(ClientStatus::Vpn(vpn_status));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -134,7 +148,7 @@ impl Client {
|
||||
let read_handle = tokio::spawn(handle_read(
|
||||
read_stream,
|
||||
self.request_pool.clone(),
|
||||
self.server_event_tx.clone(),
|
||||
self.service_event_tx.clone(),
|
||||
cancel_token.clone(),
|
||||
));
|
||||
|
||||
@@ -144,13 +158,16 @@ impl Client {
|
||||
cancel_token,
|
||||
));
|
||||
|
||||
*self.is_healthy.write().await = true;
|
||||
*self.is_online.write().await = true;
|
||||
info!("Connected to the background service");
|
||||
if let Err(err) = self.service_event_tx.send(ServiceEvent::Online).await {
|
||||
warn!("Error sending online event to the channel: {}", err);
|
||||
}
|
||||
|
||||
let _ = tokio::join!(read_handle, write_handle);
|
||||
*self.is_healthy.write().await = false;
|
||||
*self.is_online.write().await = false;
|
||||
|
||||
// TODO connection was lost, cleanup the request pool and notify the UI
|
||||
// TODO connection was lost, cleanup the request pool
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -159,7 +176,7 @@ impl Client {
|
||||
&self,
|
||||
payload: CommandPayload,
|
||||
) -> Result<T, ServerApiError> {
|
||||
if !*self.is_healthy.read().await {
|
||||
if !*self.is_online.read().await {
|
||||
return Err("Background service is not running".into());
|
||||
}
|
||||
|
||||
@@ -169,18 +186,19 @@ impl Client {
|
||||
return Err(format!("Error sending request to the channel: {}", err).into());
|
||||
}
|
||||
|
||||
if let Ok(response) = response_rx.await {
|
||||
if response.success() {
|
||||
match response.data().try_into() {
|
||||
Ok(it) => Ok(it),
|
||||
Err(_) => Err("Error parsing response data".into()),
|
||||
response_rx
|
||||
.await
|
||||
.map_err(|_| "Error receiving response from the channel".into())
|
||||
.and_then(|response| {
|
||||
if response.success() {
|
||||
response
|
||||
.data()
|
||||
.try_into()
|
||||
.map_err(|_| "Error parsing response data".into())
|
||||
} else {
|
||||
Err(response.message().into())
|
||||
}
|
||||
} else {
|
||||
Err(response.message().into())
|
||||
}
|
||||
} else {
|
||||
Err("Error receiving response from the channel".into())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn connect(&self, server: String, cookie: String) -> Result<(), ServerApiError> {
|
||||
@@ -192,14 +210,14 @@ impl Client {
|
||||
}
|
||||
|
||||
pub async fn status(&self) -> Result<VpnStatus, ServerApiError> {
|
||||
self.send_command(Status.into()).await
|
||||
self.send_command(GetStatus.into()).await
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_read(
|
||||
read_stream: ReadHalf<UnixStream>,
|
||||
request_pool: Arc<RequestPool>,
|
||||
server_event_tx: mpsc::Sender<ServerEvent>,
|
||||
service_event_tx: mpsc::Sender<ServiceEvent>,
|
||||
cancel_token: CancellationToken,
|
||||
) {
|
||||
let mut reader: Reader = read_stream.into();
|
||||
@@ -211,7 +229,7 @@ async fn handle_read(
|
||||
match response.request_id() {
|
||||
Some(id) => request_pool.complete_request(id, response).await,
|
||||
None => {
|
||||
if let Err(err) = server_event_tx.send(response.into()).await {
|
||||
if let Err(err) = service_event_tx.send(response.into()).await {
|
||||
warn!("Error sending response to output channel: {}", err);
|
||||
}
|
||||
}
|
||||
@@ -220,7 +238,7 @@ async fn handle_read(
|
||||
}
|
||||
Err(err) if err.kind() == io::ErrorKind::ConnectionAborted => {
|
||||
warn!("Disconnected from the background service");
|
||||
if let Err(err) = server_event_tx.send(ServerEvent::ServerDisconnected).await {
|
||||
if let Err(err) = service_event_tx.send(ServiceEvent::Offline).await {
|
||||
warn!(
|
||||
"Error sending server disconnected event to channel: {}",
|
||||
err
|
||||
|
@@ -12,7 +12,7 @@ mod status;
|
||||
|
||||
pub use connect::Connect;
|
||||
pub use disconnect::Disconnect;
|
||||
pub use status::Status;
|
||||
pub use status::GetStatus;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct CommandContext {
|
||||
|
@@ -4,10 +4,10 @@ use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct Status;
|
||||
pub struct GetStatus;
|
||||
|
||||
#[async_trait]
|
||||
impl Command for Status {
|
||||
impl Command for GetStatus {
|
||||
async fn handle(&self, context: CommandContext) -> Result<ResponseData, CommandError> {
|
||||
let status = context.server_context.vpn().status().await;
|
||||
|
||||
|
@@ -23,30 +23,31 @@ async fn handle_read(
|
||||
let mut authenticated: Option<bool> = None;
|
||||
|
||||
loop {
|
||||
match reader.read::<Request>().await {
|
||||
Ok(request) => {
|
||||
match reader.read_multiple::<Request>().await {
|
||||
Ok(requests) => {
|
||||
if authenticated.is_none() {
|
||||
authenticated = Some(authenticate(peer_pid));
|
||||
}
|
||||
|
||||
if !authenticated.unwrap_or(false) {
|
||||
warn!("Client not authenticated, closing connection");
|
||||
cancel_token.cancel();
|
||||
break;
|
||||
}
|
||||
|
||||
debug!("Received client request: {:?}", request);
|
||||
for request in requests {
|
||||
debug!("Received client request: {:?}", request);
|
||||
|
||||
let command = request.command();
|
||||
let context = server_context.clone().into();
|
||||
let command = request.command();
|
||||
let context = server_context.clone().into();
|
||||
|
||||
let mut response = match command.handle(context).await {
|
||||
Ok(data) => Response::from(data),
|
||||
Err(err) => Response::from(err.to_string()),
|
||||
};
|
||||
response.set_request_id(request.id());
|
||||
let mut response = match command.handle(context).await {
|
||||
Ok(data) => Response::from(data),
|
||||
Err(err) => Response::from(err.to_string()),
|
||||
};
|
||||
response.set_request_id(request.id());
|
||||
|
||||
let _ = response_tx.send(response).await;
|
||||
let _ = response_tx.send(response).await;
|
||||
}
|
||||
}
|
||||
|
||||
Err(err) if err.kind() == io::ErrorKind::ConnectionAborted => {
|
||||
|
@@ -30,6 +30,7 @@ pub(crate) use writer::Writer;
|
||||
|
||||
pub use client::Client;
|
||||
pub use client::ServerApiError;
|
||||
pub use client::ClientStatus;
|
||||
pub use vpn::VpnStatus;
|
||||
|
||||
pub fn sha256_digest<P: AsRef<Path>>(file_path: P) -> Result<String, std::io::Error> {
|
||||
|
@@ -13,22 +13,6 @@ impl From<ReadHalf<UnixStream>> for Reader {
|
||||
}
|
||||
|
||||
impl Reader {
|
||||
pub async fn read<T: for<'a> Deserialize<'a>>(&mut self) -> Result<T, io::Error> {
|
||||
let mut buffer = [0; 2048];
|
||||
|
||||
match self.stream.read(&mut buffer).await {
|
||||
Ok(0) => Err(io::Error::new(
|
||||
io::ErrorKind::ConnectionAborted,
|
||||
"Peer disconnected",
|
||||
)),
|
||||
Ok(bytes_read) => {
|
||||
let data = serde_json::from_slice::<T>(&buffer[..bytes_read])?;
|
||||
Ok(data)
|
||||
}
|
||||
Err(err) => Err(err),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn read_multiple<T: for<'a> Deserialize<'a>>(&mut self) -> Result<Vec<T>, io::Error> {
|
||||
let mut buffer = [0; 2048];
|
||||
|
||||
|
@@ -1,4 +1,4 @@
|
||||
use crate::cmd::{Command, Connect, Disconnect, Status};
|
||||
use crate::cmd::{Command, Connect, Disconnect, GetStatus};
|
||||
use crate::Response;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::Arc;
|
||||
@@ -21,7 +21,7 @@ impl Request {
|
||||
|
||||
pub fn command(&self) -> Box<dyn Command> {
|
||||
match &self.payload {
|
||||
CommandPayload::Status(status) => Box::new(status.clone()),
|
||||
CommandPayload::GetStatus(status) => Box::new(status.clone()),
|
||||
CommandPayload::Connect(connect) => Box::new(connect.clone()),
|
||||
CommandPayload::Disconnect(disconnect) => Box::new(disconnect.clone()),
|
||||
}
|
||||
@@ -30,14 +30,14 @@ impl Request {
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub(crate) enum CommandPayload {
|
||||
Status(Status),
|
||||
GetStatus(GetStatus),
|
||||
Connect(Connect),
|
||||
Disconnect(Disconnect),
|
||||
}
|
||||
|
||||
impl From<Status> for CommandPayload {
|
||||
fn from(status: Status) -> Self {
|
||||
Self::Status(status)
|
||||
impl From<GetStatus> for CommandPayload {
|
||||
fn from(status: GetStatus) -> Self {
|
||||
Self::GetStatus(status)
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -1,4 +1,4 @@
|
||||
use log::{warn, info, debug};
|
||||
use log::{debug, info, warn};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::ffi::{c_void, CString};
|
||||
use std::sync::Arc;
|
||||
@@ -91,7 +91,7 @@ impl Vpn {
|
||||
*self.vpn_options.lock().await = Some(VpnOptions {
|
||||
server: VpnOptions::to_cstr(server),
|
||||
cookie: VpnOptions::to_cstr(cookie),
|
||||
script: VpnOptions::to_cstr("/usr/share/vpnc-scripts/vpnc-script")
|
||||
script: VpnOptions::to_cstr("/usr/share/vpnc-scripts/vpnc-script"),
|
||||
});
|
||||
|
||||
let vpn_options = self.vpn_options.clone();
|
||||
@@ -133,12 +133,15 @@ impl Vpn {
|
||||
}
|
||||
|
||||
info!("Disconnecting VPN...");
|
||||
self.status_holder
|
||||
.lock()
|
||||
.await
|
||||
.set(VpnStatus::Disconnecting);
|
||||
unsafe { ffi::disconnect() };
|
||||
|
||||
let mut status_rx = self.status_rx().await;
|
||||
debug!("Waiting for the VPN to disconnect...");
|
||||
|
||||
|
||||
while status_rx.changed().await.is_ok() {
|
||||
if *status_rx.borrow() == VpnStatus::Disconnected {
|
||||
info!("VPN disconnected");
|
||||
|
Reference in New Issue
Block a user