Compare commits

...

7 Commits

Author SHA1 Message Date
Kevin Yue
3bb115bd2d Merge branch 'main' into dev 2024-05-19 10:23:00 +08:00
Kevin Yue
e08f239176 fix: do not panic when failed to start service (fix #362) 2024-05-19 10:21:18 +08:00
Kevin Yue
a01c55e38d fix: do not panic when failed to start service (fix #362) 2024-05-19 10:19:21 +08:00
Kevin Yue
af51bc257b feat: add the --reconnect-timeout option 2024-05-19 09:59:25 +08:00
Kevin Yue
90a8c11acb feat: add disable_ipv6 option (related #364) 2024-05-19 09:04:45 +08:00
Kevin Yue
92b858884c fix: check executable for file 2024-05-10 10:26:45 -04:00
Kevin Yue
159673652c Refactor prelogin.rs to use default labels for username and password 2024-05-09 01:48:02 -04:00
10 changed files with 145 additions and 61 deletions

View File

@@ -48,8 +48,12 @@ pub(crate) struct ConnectArgs {
#[arg(long, help = "Same as the '--csd-wrapper' option in the openconnect command")]
csd_wrapper: Option<String>,
#[arg(long, default_value = "300", help = "Reconnection retry timeout in seconds")]
reconnect_timeout: u32,
#[arg(short, long, help = "Request MTU from server (legacy servers only)")]
mtu: Option<u32>,
#[arg(long, help = "Do not ask for IPv6 connectivity")]
disable_ipv6: bool,
#[arg(long, default_value = GP_USER_AGENT, help = "The user agent to use")]
user_agent: String,
@@ -215,7 +219,9 @@ impl<'a> ConnectHandler<'a> {
.user_agent(self.args.user_agent.clone())
.csd_uid(csd_uid)
.csd_wrapper(csd_wrapper)
.reconnect_timeout(self.args.reconnect_timeout)
.mtu(mtu)
.disable_ipv6(self.args.disable_ipv6)
.build()?;
let vpn = Arc::new(vpn);

View File

@@ -38,10 +38,12 @@ impl VpnTaskContext {
let vpn = match Vpn::builder(req.gateway().server(), args.cookie())
.script(args.vpnc_script())
.user_agent(args.user_agent())
.os(args.openconnect_os())
.csd_uid(args.csd_uid())
.csd_wrapper(args.csd_wrapper())
.reconnect_timeout(args.reconnect_timeout())
.mtu(args.mtu())
.os(args.openconnect_os())
.disable_ipv6(args.disable_ipv6())
.build()
{
Ok(vpn) => vpn,

View File

@@ -118,12 +118,14 @@ impl WsServer {
}
pub async fn start(&self, shutdown_tx: mpsc::Sender<()>) {
if let Ok(listener) = TcpListener::bind("127.0.0.1:0").await {
let local_addr = listener.local_addr().unwrap();
self.lock_file.lock(local_addr.port().to_string()).unwrap();
info!("WS server listening on port: {}", local_addr.port());
let listener = match self.start_tcp_server().await {
Ok(listener) => listener,
Err(err) => {
warn!("Failed to start WS server: {}", err);
let _ = shutdown_tx.send(()).await;
return;
},
};
tokio::select! {
_ = watch_vpn_state(self.ctx.vpn_state_rx(), Arc::clone(&self.ctx)) => {
@@ -136,10 +138,21 @@ impl WsServer {
info!("WS server cancelled");
}
}
}
let _ = shutdown_tx.send(()).await;
}
async fn start_tcp_server(&self) -> anyhow::Result<TcpListener> {
let listener = TcpListener::bind("127.0.0.1:0").await?;
let local_addr = listener.local_addr()?;
let port = local_addr.port();
info!("WS server listening on port: {}", port);
self.lock_file.lock(port.to_string())?;
Ok(listener)
}
}
async fn watch_vpn_state(mut vpn_state_rx: watch::Receiver<VpnState>, ctx: Arc<WsServerContext>) {

View File

@@ -1,7 +1,6 @@
use is_executable::IsExecutable;
use std::path::Path;
use std::{io, path::Path};
pub use is_executable::is_executable;
use is_executable::IsExecutable;
const VPNC_SCRIPT_LOCATIONS: [&str; 6] = [
"/usr/local/share/vpnc-scripts/vpnc-script",
@@ -39,3 +38,17 @@ pub fn find_vpnc_script() -> Option<String> {
pub fn find_csd_wrapper() -> Option<String> {
find_executable(&CSD_WRAPPER_LOCATIONS)
}
/// If file exists, check if it is executable
pub fn check_executable(file: &str) -> Result<(), io::Error> {
let path = Path::new(file);
if path.exists() && !path.is_executable() {
return Err(io::Error::new(
io::ErrorKind::PermissionDenied,
format!("{} is not executable", file),
));
}
Ok(())
}

View File

@@ -181,22 +181,24 @@ fn parse_res_xml(res_xml: &str, is_gateway: bool) -> anyhow::Result<Prelogin> {
return Ok(Prelogin::Saml(saml_prelogin));
}
let label_username = xml::get_child_text(&doc, "username-label");
let label_password = xml::get_child_text(&doc, "password-label");
// Check if the prelogin response is standard login
if label_username.is_some() && label_password.is_some() {
let label_username = xml::get_child_text(&doc, "username-label").unwrap_or_else(|| {
info!("Username label has no value, using default");
String::from("Username")
});
let label_password = xml::get_child_text(&doc, "password-label").unwrap_or_else(|| {
info!("Password label has no value, using default");
String::from("Password")
});
let auth_message =
xml::get_child_text(&doc, "authentication-message").unwrap_or(String::from("Please enter the login credentials"));
let standard_prelogin = StandardPrelogin {
region,
is_gateway,
auth_message,
label_username: label_username.unwrap(),
label_password: label_password.unwrap(),
label_username,
label_password,
};
Ok(Prelogin::Standard(standard_prelogin))
} else {
Err(anyhow!("Invalid prelogin response"))
}
}

View File

@@ -32,10 +32,12 @@ pub struct ConnectArgs {
cookie: String,
vpnc_script: Option<String>,
user_agent: Option<String>,
os: Option<ClientOs>,
csd_uid: u32,
csd_wrapper: Option<String>,
reconnect_timeout: u32,
mtu: u32,
os: Option<ClientOs>,
disable_ipv6: bool,
}
impl ConnectArgs {
@@ -47,7 +49,9 @@ impl ConnectArgs {
os: None,
csd_uid: 0,
csd_wrapper: None,
reconnect_timeout: 300,
mtu: 0,
disable_ipv6: false,
}
}
@@ -75,9 +79,17 @@ impl ConnectArgs {
self.csd_wrapper.clone()
}
pub fn reconnect_timeout(&self) -> u32 {
self.reconnect_timeout
}
pub fn mtu(&self) -> u32 {
self.mtu
}
pub fn disable_ipv6(&self) -> bool {
self.disable_ipv6
}
}
#[derive(Debug, Deserialize, Serialize, Type)]
@@ -109,11 +121,6 @@ impl ConnectRequest {
self
}
pub fn with_mtu(mut self, mtu: u32) -> Self {
self.args.mtu = mtu;
self
}
pub fn with_user_agent<T: Into<Option<String>>>(mut self, user_agent: T) -> Self {
self.args.user_agent = user_agent.into();
self
@@ -124,6 +131,21 @@ impl ConnectRequest {
self
}
pub fn with_reconnect_timeout(mut self, reconnect_timeout: u32) -> Self {
self.args.reconnect_timeout = reconnect_timeout;
self
}
pub fn with_mtu(mut self, mtu: u32) -> Self {
self.args.mtu = mtu;
self
}
pub fn with_disable_ipv6(mut self, disable_ipv6: bool) -> Self {
self.args.disable_ipv6 = disable_ipv6;
self
}
pub fn gateway(&self) -> &Gateway {
self.info.gateway()
}

View File

@@ -19,7 +19,9 @@ pub(crate) struct ConnectOptions {
pub csd_uid: u32,
pub csd_wrapper: *const c_char,
pub reconnect_timeout: u32,
pub mtu: u32,
pub disable_ipv6: u32,
}
#[link(name = "vpn")]

View File

@@ -63,7 +63,9 @@ int vpn_connect(const vpn_options *options, vpn_connected_callback callback)
INFO("OS: %s", options->os);
INFO("CSD_USER: %d", options->csd_uid);
INFO("CSD_WRAPPER: %s", options->csd_wrapper);
INFO("RECONNECT_TIMEOUT: %d", options->reconnect_timeout);
INFO("MTU: %d", options->mtu);
INFO("DISABLE_IPV6: %d", options->disable_ipv6);
vpninfo = openconnect_vpninfo_new(options->user_agent, validate_peer_cert, NULL, NULL, print_progress, NULL);
@@ -103,6 +105,10 @@ int vpn_connect(const vpn_options *options, vpn_connected_callback callback)
openconnect_set_reqmtu(vpninfo, mtu);
}
if (options->disable_ipv6) {
openconnect_disable_ipv6(vpninfo);
}
g_cmd_pipe_fd = openconnect_setup_cmd_pipe(vpninfo);
if (g_cmd_pipe_fd < 0)
{
@@ -132,7 +138,7 @@ int vpn_connect(const vpn_options *options, vpn_connected_callback callback)
while (1)
{
int ret = openconnect_mainloop(vpninfo, 300, 10);
int ret = openconnect_mainloop(vpninfo, options->reconnect_timeout, 10);
if (ret)
{

View File

@@ -20,7 +20,9 @@ typedef struct vpn_options
const uid_t csd_uid;
const char *csd_wrapper;
const int reconnect_timeout;
const int mtu;
const int disable_ipv6;
} vpn_options;
int vpn_connect(const vpn_options *options, vpn_connected_callback callback);
@@ -35,7 +37,7 @@ static char *format_message(const char *format, va_list args)
int len = vsnprintf(NULL, 0, format, args_copy);
va_end(args_copy);
char *buffer = malloc(len + 1);
char *buffer = (char*)malloc(len + 1);
if (buffer == NULL)
{
return NULL;

View File

@@ -4,7 +4,7 @@ use std::{
sync::{Arc, RwLock},
};
use common::vpn_utils::{find_vpnc_script, is_executable};
use common::vpn_utils::{check_executable, find_vpnc_script};
use log::info;
use crate::ffi;
@@ -23,7 +23,9 @@ pub struct Vpn {
csd_uid: u32,
csd_wrapper: Option<CString>,
reconnect_timeout: u32,
mtu: u32,
disable_ipv6: bool,
callback: OnConnectedCallback,
}
@@ -67,7 +69,9 @@ impl Vpn {
csd_uid: self.csd_uid,
csd_wrapper: Self::option_to_ptr(&self.csd_wrapper),
reconnect_timeout: self.reconnect_timeout,
mtu: self.mtu,
disable_ipv6: self.disable_ipv6 as u32,
}
}
@@ -80,23 +84,23 @@ impl Vpn {
}
#[derive(Debug)]
pub struct VpnError<'a> {
message: &'a str,
pub struct VpnError {
message: String,
}
impl<'a> VpnError<'a> {
fn new(message: &'a str) -> Self {
impl VpnError {
fn new(message: String) -> Self {
Self { message }
}
}
impl fmt::Display for VpnError<'_> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
impl fmt::Display for VpnError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", self.message)
}
}
impl std::error::Error for VpnError<'_> {}
impl std::error::Error for VpnError {}
pub struct VpnBuilder {
server: String,
@@ -109,7 +113,9 @@ pub struct VpnBuilder {
csd_uid: u32,
csd_wrapper: Option<String>,
reconnect_timeout: u32,
mtu: u32,
disable_ipv6: bool,
}
impl VpnBuilder {
@@ -125,7 +131,9 @@ impl VpnBuilder {
csd_uid: 0,
csd_wrapper: None,
reconnect_timeout: 300,
mtu: 0,
disable_ipv6: false,
}
}
@@ -154,26 +162,32 @@ impl VpnBuilder {
self
}
pub fn reconnect_timeout(mut self, reconnect_timeout: u32) -> Self {
self.reconnect_timeout = reconnect_timeout;
self
}
pub fn mtu(mut self, mtu: u32) -> Self {
self.mtu = mtu;
self
}
pub fn build(self) -> Result<Vpn, VpnError<'static>> {
pub fn disable_ipv6(mut self, disable_ipv6: bool) -> Self {
self.disable_ipv6 = disable_ipv6;
self
}
pub fn build(self) -> Result<Vpn, VpnError> {
let script = match self.script {
Some(script) => {
if !is_executable(&script) {
return Err(VpnError::new("vpnc script is not executable"));
}
check_executable(&script).map_err(|e| VpnError::new(e.to_string()))?;
script
}
None => find_vpnc_script().ok_or_else(|| VpnError::new("Failed to find vpnc-script"))?,
None => find_vpnc_script().ok_or_else(|| VpnError::new(String::from("Failed to find vpnc-script")))?,
};
if let Some(csd_wrapper) = &self.csd_wrapper {
if !is_executable(csd_wrapper) {
return Err(VpnError::new("CSD wrapper is not executable"));
}
check_executable(csd_wrapper).map_err(|e| VpnError::new(e.to_string()))?;
}
let user_agent = self.user_agent.unwrap_or_default();
@@ -191,7 +205,9 @@ impl VpnBuilder {
csd_uid: self.csd_uid,
csd_wrapper: self.csd_wrapper.as_deref().map(Self::to_cstring),
reconnect_timeout: self.reconnect_timeout,
mtu: self.mtu,
disable_ipv6: self.disable_ipv6,
callback: Default::default(),
})