refactor: rewrite

This commit is contained in:
Kevin Yue
2023-02-17 01:21:36 -05:00
parent 7bef2ccc68
commit 19b9b757f4
194 changed files with 7885 additions and 8034 deletions

255
common/src/client.rs Normal file
View File

@@ -0,0 +1,255 @@
use crate::cmd::{Connect, Disconnect, Status};
use crate::reader::Reader;
use crate::request::CommandPayload;
use crate::response::ResponseData;
use crate::writer::Writer;
use crate::RequestPool;
use crate::Response;
use crate::SOCKET_PATH;
use crate::{Request, VpnStatus};
use serde::{Deserialize, Serialize};
use std::fmt::Display;
use std::sync::Arc;
use tokio::io::{self, ReadHalf, WriteHalf};
use tokio::net::UnixStream;
use tokio::sync::{mpsc, Mutex, RwLock};
use tokio_util::sync::CancellationToken;
#[derive(Debug)]
enum ServerEvent {
Response(Response),
ServerDisconnected,
}
impl From<Response> for ServerEvent {
fn from(response: Response) -> Self {
Self::Response(response)
}
}
#[derive(Debug)]
pub struct Client {
// pool of requests that are waiting for responses
request_pool: Arc<RequestPool>,
// tx for sending requests to the channel
request_tx: mpsc::Sender<Request>,
// rx for receiving requests from the channel
request_rx: Arc<Mutex<mpsc::Receiver<Request>>>,
// tx for sending responses to the channel
server_event_tx: mpsc::Sender<ServerEvent>,
// rx for receiving responses from the channel
server_event_rx: Arc<Mutex<mpsc::Receiver<ServerEvent>>>,
is_healthy: Arc<RwLock<bool>>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct ServerApiError {
pub message: String,
}
impl Display for ServerApiError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{message}", message = self.message)
}
}
impl From<String> for ServerApiError {
fn from(message: String) -> Self {
Self { message }
}
}
impl From<&str> for ServerApiError {
fn from(message: &str) -> Self {
Self {
message: message.to_string(),
}
}
}
impl Default for Client {
fn default() -> Self {
let (request_tx, request_rx) = mpsc::channel::<Request>(32);
let (server_event_tx, server_event_rx) = mpsc::channel::<ServerEvent>(32);
Self {
request_pool: Default::default(),
request_tx,
request_rx: Arc::new(Mutex::new(request_rx)),
server_event_tx,
server_event_rx: Arc::new(Mutex::new(server_event_rx)),
is_healthy: Default::default(),
}
}
}
impl Client {
pub fn subscribe_status(&self, callback: impl Fn(VpnStatus) + Send + Sync + 'static) {
let server_event_rx = self.server_event_rx.clone();
tokio::spawn(async move {
loop {
let mut server_event_rx = server_event_rx.lock().await;
if let Some(server_event) = server_event_rx.recv().await {
match server_event {
ServerEvent::ServerDisconnected => {
callback(VpnStatus::Disconnected);
}
ServerEvent::Response(response) => {
if let ResponseData::Status(vpn_status) = response.data() {
callback(vpn_status);
}
}
}
}
}
});
}
pub async fn run(&self) {
loop {
match self.connect_to_server().await {
Ok(_) => {
println!("Disconnected from server, reconnecting...");
}
Err(err) => {
println!(
"Disconnected from server with error: {:?}, reconnecting...",
err
)
}
}
// wait for a second before trying to reconnect
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
}
}
async fn connect_to_server(&self) -> Result<(), Box<dyn std::error::Error>> {
let stream = UnixStream::connect(SOCKET_PATH).await?;
let (read_stream, write_stream) = io::split(stream);
let cancel_token = CancellationToken::new();
let read_handle = tokio::spawn(handle_read(
read_stream,
self.request_pool.clone(),
self.server_event_tx.clone(),
cancel_token.clone(),
));
let write_handle = tokio::spawn(handle_write(
write_stream,
self.request_rx.clone(),
cancel_token,
));
*self.is_healthy.write().await = true;
println!("Connected to server");
let _ = tokio::join!(read_handle, write_handle);
*self.is_healthy.write().await = false;
Ok(())
}
async fn send_command<T: TryFrom<ResponseData>>(
&self,
payload: CommandPayload,
) -> Result<T, ServerApiError> {
if !*self.is_healthy.read().await {
return Err("Background service is not running".into());
}
let (request, response_rx) = self.request_pool.create_request(payload).await;
if let Err(err) = self.request_tx.send(request).await {
return Err(format!("Error sending request to the channel: {}", err).into());
}
if let Ok(response) = response_rx.await {
if response.success() {
match response.data().try_into() {
Ok(it) => Ok(it),
Err(_) => Err("Error parsing response data".into()),
}
} else {
Err(response.message().into())
}
} else {
Err("Error receiving response from the channel".into())
}
}
pub async fn connect(&self, server: String, cookie: String) -> Result<(), ServerApiError> {
self.send_command(Connect::new(server, cookie).into()).await
}
pub async fn disconnect(&self) -> Result<(), ServerApiError> {
self.send_command(Disconnect.into()).await
}
pub async fn status(&self) -> Result<VpnStatus, ServerApiError> {
self.send_command(Status.into()).await
}
}
async fn handle_read(
read_stream: ReadHalf<UnixStream>,
request_pool: Arc<RequestPool>,
server_event_tx: mpsc::Sender<ServerEvent>,
cancel_token: CancellationToken,
) {
let mut reader: Reader = read_stream.into();
loop {
match reader.read_multiple::<Response>().await {
Ok(responses) => {
for response in responses {
match response.request_id() {
Some(id) => request_pool.complete_request(id, response).await,
None => {
if let Err(err) = server_event_tx.send(response.into()).await {
println!("Error sending response to output channel: {}", err);
}
}
}
}
}
Err(err) if err.kind() == io::ErrorKind::ConnectionAborted => {
println!("Server disconnected");
if let Err(err) = server_event_tx.send(ServerEvent::ServerDisconnected).await {
println!("Error sending server disconnected event: {}", err);
}
cancel_token.cancel();
break;
}
Err(err) => {
println!("Error reading from server: {}", err);
}
}
}
}
async fn handle_write(
write_stream: WriteHalf<UnixStream>,
request_rx: Arc<Mutex<mpsc::Receiver<Request>>>,
cancel_token: CancellationToken,
) {
let mut writer: Writer = write_stream.into();
loop {
let mut request_rx = request_rx.lock().await;
tokio::select! {
Some(request) = request_rx.recv() => {
if let Err(err) = writer.write(&request).await {
println!("Error writing to server: {}", err);
}
}
_ = cancel_token.cancelled() => {
println!("The read loop has been cancelled, exiting the write loop");
break;
}
else => {
println!("Error reading command from channel");
}
}
}
}

34
common/src/cmd/connect.rs Normal file
View File

@@ -0,0 +1,34 @@
use super::{Command, CommandContext, CommandError};
use crate::{ResponseData, VpnStatus};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Connect {
server: String,
cookie: String,
}
impl Connect {
pub fn new(server: String, cookie: String) -> Self {
Self { server, cookie }
}
}
#[async_trait]
impl Command for Connect {
async fn handle(&self, context: CommandContext) -> Result<ResponseData, CommandError> {
let vpn = context.server_context.vpn();
let status = vpn.status().await;
if status != VpnStatus::Disconnected {
return Err(format!("VPN is already in state: {:?}", status).into());
}
if let Err(err) = vpn.connect(&self.server, &self.cookie).await {
return Err(err.to_string().into());
}
Ok(ResponseData::Empty)
}
}

View File

@@ -0,0 +1,15 @@
use super::{Command, CommandContext, CommandError};
use crate::ResponseData;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Disconnect;
#[async_trait]
impl Command for Disconnect {
async fn handle(&self, context: CommandContext) -> Result<ResponseData, CommandError> {
context.server_context.vpn().disconnect().await;
Ok(ResponseData::Empty)
}
}

54
common/src/cmd/mod.rs Normal file
View File

@@ -0,0 +1,54 @@
use crate::{response::ResponseData, server::ServerContext};
use async_trait::async_trait;
use core::fmt::Debug;
use std::{
fmt::{self, Display},
sync::Arc,
};
mod connect;
mod disconnect;
mod status;
pub use connect::Connect;
pub use disconnect::Disconnect;
pub use status::Status;
#[derive(Debug)]
pub(crate) struct CommandContext {
server_context: Arc<ServerContext>,
}
impl From<Arc<ServerContext>> for CommandContext {
fn from(server_context: Arc<ServerContext>) -> Self {
Self { server_context }
}
}
#[derive(Debug)]
pub(crate) struct CommandError {
message: String,
}
impl From<String> for CommandError {
fn from(message: String) -> Self {
Self { message }
}
}
impl Display for CommandError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "CommandError {:#?}", self.message)
}
}
#[async_trait]
pub(crate) trait Command: Send + Sync {
async fn handle(&self, context: CommandContext) -> Result<ResponseData, CommandError>;
}
impl Debug for dyn Command {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Command")
}
}

16
common/src/cmd/status.rs Normal file
View File

@@ -0,0 +1,16 @@
use super::{Command, CommandContext, CommandError};
use crate::ResponseData;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Status;
#[async_trait]
impl Command for Status {
async fn handle(&self, context: CommandContext) -> Result<ResponseData, CommandError> {
let status = context.server_context.vpn().status().await;
Ok(ResponseData::Status(status))
}
}

144
common/src/connection.rs Normal file
View File

@@ -0,0 +1,144 @@
use crate::request::Request;
use crate::server::ServerContext;
use crate::Reader;
use crate::Response;
use crate::ResponseData;
use crate::VpnStatus;
use crate::Writer;
use std::sync::Arc;
use tokio::io::{self, ReadHalf, WriteHalf};
use tokio::net::UnixStream;
use tokio::sync::{mpsc, watch};
use tokio_util::sync::CancellationToken;
async fn handle_read(
read_stream: ReadHalf<UnixStream>,
server_context: Arc<ServerContext>,
response_tx: mpsc::Sender<Response>,
cancel_token: CancellationToken,
) {
let mut reader: Reader = read_stream.into();
loop {
match reader.read::<Request>().await {
Ok(request) => {
println!("Received request: {:?}", request);
let command = request.command();
let context = server_context.clone().into();
let mut response = match command.handle(context).await {
Ok(data) => Response::from(data),
Err(err) => Response::from(err.to_string()),
};
response.set_request_id(request.id());
let _ = response_tx.send(response).await;
}
Err(err) if err.kind() == io::ErrorKind::ConnectionAborted => {
println!("Client disconnected");
cancel_token.cancel();
break;
}
Err(err) => {
println!("Error receiving command: {:?}", err);
}
}
}
}
async fn handle_write(
write_stream: WriteHalf<UnixStream>,
mut response_rx: mpsc::Receiver<Response>,
cancel_token: CancellationToken,
) {
let mut writer: Writer = write_stream.into();
loop {
tokio::select! {
Some(response) = response_rx.recv() => {
println!("Sending response: {:?}", response);
if let Err(err) = writer.write(&response).await {
println!("Error sending response: {:?}", err);
} else {
println!("Response sent");
}
}
_ = cancel_token.cancelled() => {
println!("Exiting write loop");
break;
}
else => {
println!("Error receiving response");
}
}
}
}
async fn handle_status_change(
mut status_rx: watch::Receiver<VpnStatus>,
response_tx: mpsc::Sender<Response>,
cancel_token: CancellationToken,
) {
send_status(&status_rx, &response_tx).await;
println!("Waiting for status change");
let start_time = std::time::Instant::now();
loop {
tokio::select! {
_ = status_rx.changed() => {
println!("Status changed: {:?}", start_time.elapsed());
send_status(&status_rx, &response_tx).await;
}
_ = cancel_token.cancelled() => {
println!("Exiting status loop");
break;
}
else => {
println!("Error receiving status");
}
}
}
}
async fn send_status(status_rx: &watch::Receiver<VpnStatus>, response_tx: &mpsc::Sender<Response>) {
let status = *status_rx.borrow();
println!("received = {:?}", status);
if let Err(err) = response_tx
.send(Response::from(ResponseData::Status(status)))
.await
{
println!("Error sending status: {:?}", err);
}
}
pub(crate) async fn handle_connection(socket: UnixStream, context: Arc<ServerContext>) {
let (read_stream, write_stream) = io::split(socket);
let (response_tx, response_rx) = mpsc::channel::<Response>(32);
let cancel_token = CancellationToken::new();
let status_rx = context.vpn().status_rx().await;
let read_handle = tokio::spawn(handle_read(
read_stream,
context.clone(),
response_tx.clone(),
cancel_token.clone(),
));
let write_handle = tokio::spawn(handle_write(
write_stream,
response_rx,
cancel_token.clone(),
));
let status_handle = tokio::spawn(handle_status_change(
status_rx,
response_tx.clone(),
cancel_token,
));
let _ = tokio::join!(read_handle, write_handle, status_handle);
println!("Connection closed")
}

50
common/src/lib.rs Normal file
View File

@@ -0,0 +1,50 @@
use data_encoding::HEXUPPER;
use ring::digest::{Context, SHA256};
use std::{
fs::File,
io::{BufReader, Read},
path::Path,
};
pub const SOCKET_PATH: &str = "/tmp/gpservice.sock";
mod client;
mod cmd;
mod connection;
mod reader;
mod request;
mod response;
pub mod server;
mod vpn;
mod writer;
pub(crate) use request::Request;
pub(crate) use request::RequestPool;
pub use response::Response;
pub use response::ResponseData;
pub use response::TryFromResponseDataError;
pub(crate) use reader::Reader;
pub(crate) use writer::Writer;
pub use client::Client;
pub use client::ServerApiError;
pub use vpn::VpnStatus;
pub fn sha256_digest<P: AsRef<Path>>(file_path: P) -> Result<String, std::io::Error> {
let input = File::open(file_path)?;
let mut reader = BufReader::new(input);
let mut context = Context::new(&SHA256);
let mut buffer = [0; 1024];
loop {
let count = reader.read(&mut buffer)?;
if count == 0 {
break;
}
context.update(&buffer[..count]);
}
Ok(HEXUPPER.encode(context.finish().as_ref()))
}

59
common/src/reader.rs Normal file
View File

@@ -0,0 +1,59 @@
use serde::Deserialize;
use tokio::io::{self, AsyncReadExt, ReadHalf};
use tokio::net::UnixStream;
pub(crate) struct Reader {
stream: ReadHalf<UnixStream>,
}
impl From<ReadHalf<UnixStream>> for Reader {
fn from(stream: ReadHalf<UnixStream>) -> Self {
Self { stream }
}
}
impl Reader {
pub async fn read<T: for<'a> Deserialize<'a>>(&mut self) -> Result<T, io::Error> {
let mut buffer = [0; 2048];
match self.stream.read(&mut buffer).await {
Ok(0) => Err(io::Error::new(
io::ErrorKind::ConnectionAborted,
"Peer disconnected",
)),
Ok(bytes_read) => {
let data = serde_json::from_slice::<T>(&buffer[..bytes_read])?;
Ok(data)
}
Err(err) => Err(err),
}
}
pub async fn read_multiple<T: for<'a> Deserialize<'a>>(&mut self) -> Result<Vec<T>, io::Error> {
let mut buffer = [0; 2048];
match self.stream.read(&mut buffer).await {
Ok(0) => Err(io::Error::new(
io::ErrorKind::ConnectionAborted,
"Peer disconnected",
)),
Ok(bytes_read) => {
let response_str = String::from_utf8_lossy(&buffer[..bytes_read]);
let responses: Vec<&str> = response_str.split("\n\n").collect();
let responses = responses
.iter()
.filter_map(|r| {
if !r.is_empty() {
serde_json::from_str(r).ok()
} else {
None
}
})
.collect::<Vec<T>>();
Ok(responses)
}
Err(err) => Err(err),
}
}
}

105
common/src/request.rs Normal file
View File

@@ -0,0 +1,105 @@
use crate::cmd::{Command, Connect, Disconnect, Status};
use crate::Response;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tokio::sync::{oneshot, RwLock};
#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct Request {
id: u64,
payload: CommandPayload,
}
impl Request {
fn new(id: u64, payload: CommandPayload) -> Self {
Self { id, payload }
}
pub fn id(&self) -> u64 {
self.id
}
pub fn command(&self) -> Box<dyn Command> {
match &self.payload {
CommandPayload::Status(status) => Box::new(status.clone()),
CommandPayload::Connect(connect) => Box::new(connect.clone()),
CommandPayload::Disconnect(disconnect) => Box::new(disconnect.clone()),
}
}
}
#[derive(Debug, Serialize, Deserialize)]
pub(crate) enum CommandPayload {
Status(Status),
Connect(Connect),
Disconnect(Disconnect),
}
impl From<Status> for CommandPayload {
fn from(status: Status) -> Self {
Self::Status(status)
}
}
impl From<Connect> for CommandPayload {
fn from(connect: Connect) -> Self {
Self::Connect(connect)
}
}
impl From<Disconnect> for CommandPayload {
fn from(disconnect: Disconnect) -> Self {
Self::Disconnect(disconnect)
}
}
#[derive(Debug)]
struct RequestHandle {
id: u64,
response_tx: oneshot::Sender<Response>,
}
#[derive(Debug, Default)]
struct IdGenerator {
current_id: u64,
}
impl IdGenerator {
fn next(&mut self) -> u64 {
let current_id = self.current_id;
self.current_id = self.current_id.wrapping_add(1);
current_id
}
}
#[derive(Debug, Default)]
pub(crate) struct RequestPool {
id_generator: Arc<RwLock<IdGenerator>>,
request_handles: Arc<RwLock<Vec<RequestHandle>>>,
}
impl RequestPool {
pub async fn create_request(
&self,
payload: CommandPayload,
) -> (Request, oneshot::Receiver<Response>) {
let id = self.id_generator.write().await.next();
let (response_tx, response_rx) = oneshot::channel();
let request_handle = RequestHandle { id, response_tx };
self.request_handles.write().await.push(request_handle);
(Request::new(id, payload), response_rx)
}
pub async fn complete_request(&self, id: u64, response: Response) {
let mut request_handles = self.request_handles.write().await;
let request_handle = request_handles
.iter()
.position(|handle| handle.id == id)
.map(|index| request_handles.remove(index));
if let Some(request_handle) = request_handle {
let _ = request_handle.response_tx.send(response);
}
}
}

113
common/src/response.rs Normal file
View File

@@ -0,0 +1,113 @@
use crate::vpn::VpnStatus;
use serde::{Deserialize, Serialize};
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Response {
request_id: Option<u64>,
success: bool,
message: String,
data: ResponseData,
}
impl From<ResponseData> for Response {
fn from(data: ResponseData) -> Self {
Self {
request_id: None,
success: true,
message: String::from("Success"),
data,
}
}
}
impl From<String> for Response {
fn from(message: String) -> Self {
Self {
request_id: None,
success: false,
message,
data: ResponseData::Empty,
}
}
}
impl Response {
pub fn success(&self) -> bool {
self.success
}
pub fn message(&self) -> &str {
&self.message
}
pub fn set_request_id(&mut self, command_id: u64) {
self.request_id = Some(command_id);
}
pub fn request_id(&self) -> Option<u64> {
self.request_id
}
pub fn data(&self) -> ResponseData {
self.data
}
}
#[derive(Debug, Serialize, Deserialize, Clone, Copy)]
pub enum ResponseData {
Status(VpnStatus),
Empty,
}
impl From<VpnStatus> for ResponseData {
fn from(status: VpnStatus) -> Self {
Self::Status(status)
}
}
impl From<()> for ResponseData {
fn from(_: ()) -> Self {
Self::Empty
}
}
#[derive(Debug)]
pub struct TryFromResponseDataError {
message: String,
}
impl std::fmt::Display for TryFromResponseDataError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Invalid ResponseData: {}", self.message)
}
}
impl From<&str> for TryFromResponseDataError {
fn from(message: &str) -> Self {
Self {
message: message.into(),
}
}
}
impl TryFrom<ResponseData> for VpnStatus {
type Error = TryFromResponseDataError;
fn try_from(value: ResponseData) -> Result<Self, Self::Error> {
match value {
ResponseData::Status(status) => Ok(status),
_ => Err("ResponseData is not a VpnStatus".into()),
}
}
}
impl TryFrom<ResponseData> for () {
type Error = TryFromResponseDataError;
fn try_from(value: ResponseData) -> Result<Self, Self::Error> {
match value {
ResponseData::Empty => Ok(()),
_ => Err("ResponseData is not empty".into()),
}
}
}

81
common/src/server.rs Normal file
View File

@@ -0,0 +1,81 @@
use crate::{connection::handle_connection, vpn::Vpn};
use std::{future::Future, os::unix::prelude::PermissionsExt, path::Path, sync::Arc};
use tokio::{fs, net::UnixListener};
#[derive(Debug, Default)]
pub(crate) struct ServerContext {
vpn: Arc<Vpn>,
}
struct Server {
socket_path: String,
context: Arc<ServerContext>,
}
impl ServerContext {
pub fn vpn(&self) -> Arc<Vpn> {
self.vpn.clone()
}
}
impl Server {
fn new(socket_path: String) -> Self {
Self {
socket_path,
context: Default::default(),
}
}
async fn start(&self) -> Result<(), Box<dyn std::error::Error>> {
if Path::new(&self.socket_path).exists() {
fs::remove_file(&self.socket_path).await?;
}
let listener = UnixListener::bind(&self.socket_path)?;
println!("Listening on socket: {:?}", listener.local_addr()?);
let metadata = fs::metadata(&self.socket_path).await?;
let mut permissions = metadata.permissions();
permissions.set_mode(0o666);
fs::set_permissions(&self.socket_path, permissions).await?;
loop {
match listener.accept().await {
Ok((socket, _)) => {
println!("Accepted connection: {:?}", socket.peer_addr()?);
tokio::spawn(handle_connection(socket, self.context.clone()));
}
Err(err) => {
println!("Error accepting connection: {:?}", err);
}
}
}
}
async fn stop(&self) -> Result<(), Box<dyn std::error::Error>> {
self.context.vpn().disconnect().await;
fs::remove_file(&self.socket_path).await?;
Ok(())
}
}
pub async fn run(
socket_path: &str,
shutdown: impl Future,
) -> Result<(), Box<dyn std::error::Error>> {
let server = Server::new(socket_path.to_string());
tokio::select! {
res = server.start() => {
if let Err(err) = res {
println!("Error starting server: {:?}", err);
}
},
_ = shutdown => {
println!("Shutting down");
server.stop().await?;
},
}
Ok(())
}

22
common/src/vpn/ffi.rs Normal file
View File

@@ -0,0 +1,22 @@
use std::ffi::c_void;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub(crate) struct Options {
pub server: *const ::std::os::raw::c_char,
pub cookie: *const ::std::os::raw::c_char,
pub script: *const ::std::os::raw::c_char,
pub user_data: *mut c_void,
}
#[link(name = "vpn")]
extern "C" {
#[link_name = "start"]
pub(crate) fn connect(
options: *const Options,
on_connected: extern "C" fn(i32, *mut c_void),
) -> ::std::os::raw::c_int;
#[link_name = "stop"]
pub(crate) fn disconnect();
}

161
common/src/vpn/mod.rs Normal file
View File

@@ -0,0 +1,161 @@
mod ffi;
use serde::{Deserialize, Serialize};
use std::ffi::{c_void, CString};
use std::sync::Arc;
use std::thread;
use tokio::sync::watch;
use tokio::sync::{mpsc, Mutex};
#[no_mangle]
extern "C" fn on_connected(value: i32, sender: *mut c_void) {
let sender = unsafe { &*(sender as *const mpsc::Sender<i32>) };
sender
.blocking_send(value)
.expect("Failed to send VPN connection code");
}
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum VpnStatus {
Disconnected,
Connecting,
Connected,
Disconnecting,
}
#[derive(Debug)]
struct StatusHolder {
status: VpnStatus,
status_tx: watch::Sender<VpnStatus>,
status_rx: watch::Receiver<VpnStatus>,
}
impl Default for StatusHolder {
fn default() -> Self {
Self::new()
}
}
impl StatusHolder {
fn new() -> Self {
let (status_tx, status_rx) = watch::channel(VpnStatus::Disconnected);
Self {
status: VpnStatus::Disconnected,
status_tx,
status_rx,
}
}
fn set(&mut self, status: VpnStatus) {
self.status = status;
if let Err(err) = self.status_tx.send(status) {
println!("Failed to send VPN status: {}", err);
}
}
fn status_rx(&self) -> watch::Receiver<VpnStatus> {
self.status_rx.clone()
}
}
#[derive(Debug)]
pub(crate) struct VpnOptions {
server: CString,
cookie: CString,
script: CString,
}
impl VpnOptions {
fn as_oc_options(&self, user_data: *mut c_void) -> ffi::Options {
ffi::Options {
server: self.server.as_ptr(),
cookie: self.cookie.as_ptr(),
script: self.script.as_ptr(),
user_data,
}
}
fn to_cstr(value: &str) -> CString {
CString::new(value.to_string()).expect("Failed to convert to CString")
}
}
#[derive(Debug, Default)]
pub(crate) struct Vpn {
status_holder: Arc<Mutex<StatusHolder>>,
vpn_options: Arc<Mutex<Option<VpnOptions>>>,
}
impl Vpn {
pub async fn status_rx(&self) -> watch::Receiver<VpnStatus> {
self.status_holder.lock().await.status_rx()
}
pub async fn connect(
&self,
server: &str,
cookie: &str,
) -> Result<(), Box<dyn std::error::Error>> {
// Save the VPN options so we can use them later, e.g. reconnect
*self.vpn_options.lock().await = Some(VpnOptions {
server: VpnOptions::to_cstr(server),
cookie: VpnOptions::to_cstr(cookie),
script: VpnOptions::to_cstr("/usr/share/vpnc-scripts/vpnc-script")
});
let vpn_options = self.vpn_options.clone();
let status_holder = self.status_holder.clone();
let (vpn_tx, mut vpn_rx) = mpsc::channel::<i32>(1);
thread::spawn(move || {
let vpn_tx = &vpn_tx as *const _ as *mut c_void;
let oc_options = vpn_options
.blocking_lock()
.as_ref()
.expect("Failed to unwrap vpn_options")
.as_oc_options(vpn_tx);
// Start the VPN connection, this will block until the connection is closed
status_holder.blocking_lock().set(VpnStatus::Connecting);
let ret = unsafe { ffi::connect(&oc_options, on_connected) };
println!("VPN connection closed with code: {}", ret);
status_holder.blocking_lock().set(VpnStatus::Disconnected);
});
println!("Waiting for the VPN connection...");
if let Some(cmd_pipe_fd) = vpn_rx.recv().await {
println!("VPN connection started, code: {}", cmd_pipe_fd);
self.status_holder.lock().await.set(VpnStatus::Connected);
} else {
println!("VPN connection failed to start");
}
Ok(())
}
pub async fn disconnect(&self) {
if self.status().await == VpnStatus::Disconnected {
println!("VPN already disconnected");
return;
}
unsafe { ffi::disconnect() };
// Wait for the VPN to disconnect
println!("VPN disconnect waiting for disconnect...");
let mut status_rx = self.status_rx().await;
while status_rx.changed().await.is_ok() {
if *status_rx.borrow() == VpnStatus::Disconnected {
break;
}
}
}
pub async fn status(&self) -> VpnStatus {
self.status_holder.lock().await.status
}
}

127
common/src/vpn/vpn.c Normal file
View File

@@ -0,0 +1,127 @@
#include <stdio.h>
#include <openconnect.h>
#include <stdlib.h>
#include <stdarg.h>
#include <time.h>
#include <sys/utsname.h>
#include <unistd.h>
#include "vpn.h"
void *g_user_data;
on_connected_cb g_on_connected_cb;
static int g_cmd_pipe_fd;
const char *g_vpnc_script;
/* Validate the peer certificate */
static int validate_peer_cert(__attribute__((unused)) void *_vpninfo, const char *reason)
{
printf("Validating peer cert: %s\n", reason);
return 0;
}
/* Print progress messages */
static void print_progress(__attribute__((unused)) void *_vpninfo, int level, const char *fmt, ...)
{
FILE *outf = level ? stdout : stderr;
va_list args;
char ts[64];
time_t t = time(NULL);
struct tm *tm = localtime(&t);
strftime(ts, 64, "[%Y-%m-%d %H:%M:%S] ", tm);
fprintf(outf, "%s", ts);
va_start(args, fmt);
vfprintf(outf, fmt, args);
va_end(args);
fflush(outf);
}
static void setup_tun_handler(void *_vpninfo)
{
openconnect_setup_tun_device(_vpninfo, g_vpnc_script, NULL);
if (g_on_connected_cb)
{
g_on_connected_cb(g_cmd_pipe_fd, g_user_data);
}
}
/* Initialize VPN connection */
int start(const Options *options, on_connected_cb cb)
{
struct openconnect_info *vpninfo;
struct utsname utsbuf;
vpninfo = openconnect_vpninfo_new("PAN GlobalProtect", validate_peer_cert, NULL, NULL, print_progress, NULL);
if (!vpninfo)
{
printf("openconnect_vpninfo_new failed\n");
return 1;
}
openconnect_set_loglevel(vpninfo, 1);
openconnect_init_ssl();
openconnect_set_protocol(vpninfo, "gp");
openconnect_set_hostname(vpninfo, options->server);
openconnect_set_cookie(vpninfo, options->cookie);
g_cmd_pipe_fd = openconnect_setup_cmd_pipe(vpninfo);
if (g_cmd_pipe_fd < 0)
{
printf("openconnect_setup_cmd_pipe failed\n");
return 1;
}
if (!uname(&utsbuf))
{
openconnect_set_localname(vpninfo, utsbuf.nodename);
}
// Essential step
if (openconnect_make_cstp_connection(vpninfo) != 0)
{
printf("openconnect_make_cstp_connection failed\n");
return 1;
}
if (openconnect_setup_dtls(vpninfo, 60) != 0)
{
openconnect_disable_dtls(vpninfo);
}
// Essential step
// openconnect_setup_tun_device(vpninfo, options->script, NULL);
g_user_data = options->user_data;
g_on_connected_cb = cb;
g_vpnc_script = options->script;
openconnect_set_setup_tun_handler(vpninfo, setup_tun_handler);
while (1)
{
int ret = openconnect_mainloop(vpninfo, 300, 10);
printf("openconnect_mainloop returned %d\n", ret);
if (ret)
{
openconnect_vpninfo_free(vpninfo);
return ret;
}
printf("openconnect_mainloop returned\n");
}
}
/* Stop the VPN connection */
void stop()
{
char cmd = OC_CMD_CANCEL;
if (write(g_cmd_pipe_fd, &cmd, 1) < 0)
{
printf("Stopping VPN failed\n");
}
}

11
common/src/vpn/vpn.h Normal file
View File

@@ -0,0 +1,11 @@
typedef void (*on_connected_cb)(int32_t, void *);
typedef struct Options {
const char *server;
const char *cookie;
const char *script;
void *user_data;
} Options;
int start(const Options *options, on_connected_cb cb);
void stop();

24
common/src/writer.rs Normal file
View File

@@ -0,0 +1,24 @@
use serde::Serialize;
use tokio::io::{self, AsyncWriteExt, WriteHalf};
use tokio::net::UnixStream;
pub(crate) struct Writer {
stream: WriteHalf<UnixStream>,
}
impl From<WriteHalf<UnixStream>> for Writer {
fn from(stream: WriteHalf<UnixStream>) -> Self {
Self { stream }
}
}
impl Writer {
pub async fn write<T: Serialize>(&mut self, data: &T) -> Result<(), io::Error> {
let data = serde_json::to_string(data)?;
let data = format!("{}\n\n", data);
self.stream.write_all(data.as_bytes()).await?;
self.stream.flush().await?;
Ok(())
}
}