refactor: add the process check

This commit is contained in:
Kevin Yue 2023-05-14 20:11:37 +08:00
parent 16696e3840
commit d5af0e58c2
3 changed files with 100 additions and 63 deletions

View File

@ -15,13 +15,25 @@ async fn handle_read(
read_stream: ReadHalf<UnixStream>, read_stream: ReadHalf<UnixStream>,
server_context: Arc<ServerContext>, server_context: Arc<ServerContext>,
response_tx: mpsc::Sender<Response>, response_tx: mpsc::Sender<Response>,
peer_pid: Option<i32>,
cancel_token: CancellationToken, cancel_token: CancellationToken,
) { ) {
let mut reader: Reader = read_stream.into(); let mut reader: Reader = read_stream.into();
let mut authenticated: Option<bool> = None;
loop { loop {
match reader.read::<Request>().await { match reader.read::<Request>().await {
Ok(request) => { Ok(request) => {
if authenticated.is_none() {
authenticated = Some(authenticate(peer_pid));
}
if !authenticated.unwrap_or(false) {
println!("Client not authenticated");
cancel_token.cancel();
break;
}
println!("Received request: {:?}", request); println!("Received request: {:?}", request);
let command = request.command(); let command = request.command();
let context = server_context.clone().into(); let context = server_context.clone().into();
@ -114,6 +126,7 @@ async fn send_status(status_rx: &watch::Receiver<VpnStatus>, response_tx: &mpsc:
} }
pub(crate) async fn handle_connection(socket: UnixStream, context: Arc<ServerContext>) { pub(crate) async fn handle_connection(socket: UnixStream, context: Arc<ServerContext>) {
let peer_pid = peer_pid(&socket);
let (read_stream, write_stream) = io::split(socket); let (read_stream, write_stream) = io::split(socket);
let (response_tx, response_rx) = mpsc::channel::<Response>(32); let (response_tx, response_rx) = mpsc::channel::<Response>(32);
let cancel_token = CancellationToken::new(); let cancel_token = CancellationToken::new();
@ -123,6 +136,7 @@ pub(crate) async fn handle_connection(socket: UnixStream, context: Arc<ServerCon
read_stream, read_stream,
context.clone(), context.clone(),
response_tx.clone(), response_tx.clone(),
peer_pid,
cancel_token.clone(), cancel_token.clone(),
)); ));
@ -142,3 +156,19 @@ pub(crate) async fn handle_connection(socket: UnixStream, context: Arc<ServerCon
println!("Connection closed") println!("Connection closed")
} }
fn peer_pid(socket: &UnixStream) -> Option<i32> {
match socket.peer_cred() {
Ok(ucred) => ucred.pid(),
Err(_) => None,
}
}
fn authenticate(peer_pid: Option<i32>) -> bool {
if let Some(pid) = peer_pid {
println!("Peer PID: {}", pid);
true
} else {
false
}
}

View File

@ -1,6 +1,6 @@
import { Box, TextField } from "@mui/material"; import { Box, TextField } from "@mui/material";
import Button from "@mui/material/Button"; import Button from "@mui/material/Button";
import { ChangeEvent, useEffect, useState } from "react"; import { ChangeEvent, FormEvent, useEffect, useState } from "react";
import "./App.css"; import "./App.css";
import ConnectionStatus, { Status } from "./components/ConnectionStatus"; import ConnectionStatus, { Status } from "./components/ConnectionStatus";
@ -18,7 +18,7 @@ export default function App() {
const [status, setStatus] = useState<Status>("disconnected"); const [status, setStatus] = useState<Status>("disconnected");
const [processing, setProcessing] = useState(false); const [processing, setProcessing] = useState(false);
const [passwordAuthOpen, setPasswordAuthOpen] = useState(false); const [passwordAuthOpen, setPasswordAuthOpen] = useState(false);
const [passwordAuthenticating, setPasswordAuthenticating] = useState(false); `` const [passwordAuthenticating, setPasswordAuthenticating] = useState(false);
const [passwordAuth, setPasswordAuth] = useState<PasswordAuthData>(); const [passwordAuth, setPasswordAuth] = useState<PasswordAuthData>();
const [notification, setNotification] = useState<NotificationConfig>({ const [notification, setNotification] = useState<NotificationConfig>({
open: false, open: false,
@ -43,9 +43,9 @@ export default function App() {
} }
function clearOverlays() { function clearOverlays() {
closeNotification() closeNotification();
setPasswordAuthenticating(false) setPasswordAuthenticating(false);
setPasswordAuthOpen(false) setPasswordAuthOpen(false);
} }
function handlePortalChange(e: ChangeEvent<HTMLInputElement>) { function handlePortalChange(e: ChangeEvent<HTMLInputElement>) {
@ -53,9 +53,10 @@ export default function App() {
setPortalAddress(value.trim()); setPortalAddress(value.trim());
} }
async function handleConnect() { async function handleConnect(e: FormEvent<HTMLFormElement>) {
e.preventDefault();
setProcessing(true); setProcessing(true);
// setStatus("connecting");
try { try {
const response = await portalService.prelogin(portalAddress); const response = await portalService.prelogin(portalAddress);
@ -79,7 +80,7 @@ export default function App() {
function handleCancel() { function handleCancel() {
// TODO cancel the request first // TODO cancel the request first
setProcessing(false) setProcessing(false);
} }
async function handleDisconnect() { async function handleDisconnect() {
@ -145,50 +146,56 @@ export default function App() {
} }
return ( return (
<Box padding={2} paddingTop={3}> <Box padding={2} paddingTop={3}>
<ConnectionStatus sx={{ mb: 2 }} status={processing ? "processing" : status} /> <ConnectionStatus
sx={{ mb: 2 }}
<TextField status={processing ? "processing" : status}
autoFocus
label="Portal address"
placeholder="Hostname or IP address"
fullWidth
size="small"
value={portalAddress}
onChange={handlePortalChange}
InputProps={{ readOnly: status !== "disconnected" }}
/> />
<Box sx={{ mt: 1.5 }}>
{status === "disconnected" && ( <form onSubmit={handleConnect}>
<Button <TextField
variant="contained" autoFocus
fullWidth label="Portal address"
onClick={handleConnect} placeholder="Hostname or IP address"
sx={{ textTransform: "none" }} fullWidth
> size="small"
Connect value={portalAddress}
</Button> onChange={handlePortalChange}
)} InputProps={{ readOnly: status !== "disconnected" }}
{status === "connecting" && ( />
<Button <Box sx={{ mt: 1.5 }}>
variant="outlined" {status === "disconnected" && (
fullWidth <Button
onClick={handleCancel} type="submit"
sx={{ textTransform: "none" }} variant="contained"
> fullWidth
Cancel sx={{ textTransform: "none" }}
</Button> >
)} Connect
{status === "connected" && ( </Button>
<Button )}
variant="contained" {status === "connecting" && (
fullWidth <Button
onClick={handleDisconnect} variant="outlined"
sx={{ textTransform: "none" }} fullWidth
> onClick={handleCancel}
Disconnect sx={{ textTransform: "none" }}
</Button> >
)} Cancel
</Box> </Button>
)}
{status === "connected" && (
<Button
variant="contained"
fullWidth
onClick={handleDisconnect}
sx={{ textTransform: "none" }}
>
Disconnect
</Button>
)}
</Box>
</form>
<PasswordAuth <PasswordAuth
open={passwordAuthOpen} open={passwordAuthOpen}
authData={passwordAuth} authData={passwordAuth}

View File

@ -1,16 +1,16 @@
import { invoke } from "@tauri-apps/api"; import { invoke } from "@tauri-apps/api";
import { listen } from '@tauri-apps/api/event'; import { listen } from "@tauri-apps/api/event";
type Status = 'disconnected' | 'connecting' | 'connected' | 'disconnecting' type Status = "disconnected" | "connecting" | "connected" | "disconnecting";
type StatusCallback = (status: Status) => void type StatusCallback = (status: Status) => void;
type StatusEvent = { type StatusEvent = {
payload: { payload: {
status: Status status: Status;
} };
} };
class VpnService { class VpnService {
private _status: Status = 'disconnected'; private _status: Status = "disconnected";
private statusCallbacks: StatusCallback[] = []; private statusCallbacks: StatusCallback[] = [];
constructor() { constructor() {
@ -18,10 +18,10 @@ class VpnService {
} }
private async init() { private async init() {
const unlisten = await listen('vpn-status-received', (event: StatusEvent) => { await listen("vpn-status-received", (event: StatusEvent) => {
console.log('vpn-status-received', event.payload) console.log("vpn-status-received", event.payload);
this.setStatus(event.payload.status); this.setStatus(event.payload.status);
}) });
const status = await this.status(); const status = await this.status();
this.setStatus(status); this.setStatus(status);
@ -53,11 +53,11 @@ class VpnService {
} }
private fireStatusCallbacks() { private fireStatusCallbacks() {
this.statusCallbacks.forEach(cb => cb(this._status)); this.statusCallbacks.forEach((cb) => cb(this._status));
} }
private removeStatusCallback(callback: StatusCallback) { private removeStatusCallback(callback: StatusCallback) {
this.statusCallbacks = this.statusCallbacks.filter(cb => cb !== callback); this.statusCallbacks = this.statusCallbacks.filter((cb) => cb !== callback);
} }
private async invokeCommand<T>(command: string, args?: any) { private async invokeCommand<T>(command: string, args?: any) {