Skip to content

Commit

Permalink
Re-implement pool clearing
Browse files Browse the repository at this point in the history
  • Loading branch information
sosthene-nitrokey committed Jan 31, 2025
1 parent bbc7404 commit a28dbd4
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 20 deletions.
4 changes: 2 additions & 2 deletions pkcs11/src/backend/login.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,7 @@ impl std::fmt::Display for LoginError {

/// Perform a health check with a timeout of 1 second
fn health_check_get_timeout(instance: &InstanceData) -> bool {
todo!("{instance:?}");
// instance.config.client.clear_pool();
instance.clear_pool();
let config = &instance.config;
let uri_str = format!("{}/health/ready", config.base_path);
let mut req = config
Expand Down Expand Up @@ -534,6 +533,7 @@ fn get_user_api_config(
..api_config.config.clone()
},
state: api_config.state.clone(),
clear_flag: api_config.clear_flag.clone(),
})
}

Expand Down
23 changes: 18 additions & 5 deletions pkcs11/src/config/device.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
use std::{
collections::BTreeMap,
sync::{
atomic::{AtomicUsize, Ordering::Relaxed},
atomic::{
AtomicBool, AtomicUsize,
Ordering::{self, Relaxed},
},
mpsc::{self, RecvError, RecvTimeoutError},
Arc, Condvar, LazyLock, Mutex, RwLock, Weak,
},
thread,
time::{Duration, Instant},
};

use arc_swap::ArcSwap;
use nethsm_sdk_rs::apis::{configuration::Configuration, default_api::health_ready_get};

use crate::{backend::db::Db, data::THREADS_ALLOWED};
Expand Down Expand Up @@ -82,8 +86,7 @@ fn background_timer(
fn background_thread(rx: mpsc::Receiver<InstanceData>) -> impl FnOnce() {
move || loop {
while let Ok(instance) = rx.recv() {
todo!("{instance:?}");
// instance.config.client.clear_pool();
instance.clear_pool();
match health_ready_get(&instance.config) {
Ok(_) => instance.clear_failed(),
Err(_) => instance.bump_failed(),
Expand Down Expand Up @@ -122,19 +125,29 @@ impl InstanceState {
/// Strong handle to a single NetHSM instance: API configuration plus
/// the shared runtime state and the pool-invalidation flag.
pub struct InstanceData {
/// nethsm-sdk API configuration (base path, HTTP client agent, credentials).
pub config: Configuration,
/// Shared health/failure state for this instance.
pub state: Arc<RwLock<InstanceState>>,
/// Pool-invalidation flag handed to each pooled TCP connection at connect
/// time: `true` means the connection may keep being reused, `false` means
/// it must be closed (see `TcpTransport::is_open`). `clear_pool` swaps the
/// inner `Arc` wholesale to invalidate every currently pooled connection.
pub clear_flag: Arc<ArcSwap<AtomicBool>>,
}

impl InstanceData {
/// Invalidate every pooled TCP connection of this instance.
///
/// First installs a fresh flag set to `true`, so connections opened from
/// this point on capture the new flag and are unaffected; then sets the
/// old flag — still held by all previously pooled connections — to
/// `false`, making their `is_open` check fail so they are closed instead
/// of reused. The swap-before-store order matters: reversing it could let
/// a connection capture the old flag just before it is flipped to `false`.
pub fn clear_pool(&self) {
let old_flag = self.clear_flag.swap(Arc::new(AtomicBool::new(true)));
old_flag.store(false, Ordering::Relaxed);
}
}

/// Variant of `InstanceData` that holds only a `Weak` reference to the
/// shared state, so it does not by itself keep the instance state alive.
#[derive(Debug, Clone)]
pub struct WeakInstanceData {
/// nethsm-sdk API configuration (same as `InstanceData::config`).
pub config: Configuration,
/// Weak reference to the shared instance state; upgrade before use.
pub state: Weak<RwLock<InstanceState>>,
/// Pool-invalidation flag, kept strong so pooled connections created
/// through this handle still observe `clear_pool` invalidations.
pub clear_flag: Arc<ArcSwap<AtomicBool>>,
}

impl From<InstanceData> for WeakInstanceData {
fn from(value: InstanceData) -> Self {
Self {
config: value.config,
state: Arc::downgrade(&value.state),
clear_flag: value.clear_flag,
}
}
}
Expand All @@ -145,6 +158,7 @@ impl WeakInstanceData {
Some(InstanceData {
config: self.config,
state,
clear_flag: self.clear_flag,
})
}
}
Expand Down Expand Up @@ -260,8 +274,7 @@ impl Slot {

pub fn clear_all_pools(&self) {
for instance in &self.instances {
todo!("{instance:?}");
// instance.config.client.clear_pool();
instance.clear_pool();
}
}
}
7 changes: 6 additions & 1 deletion pkcs11/src/config/initialization.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::{
path::PathBuf,
sync::{Arc, Condvar, Mutex},
sync::{atomic::AtomicBool, Arc, Condvar, Mutex},
thread::available_parallelism,
time::Duration,
};
Expand All @@ -14,6 +14,7 @@ use super::{
config_file::{config_files, ConfigError, SlotConfig},
device::{Device, Slot},
};
use arc_swap::ArcSwap;
use log::{debug, error, info, trace};
use nethsm_sdk_rs::ureq;
use rustls::{
Expand Down Expand Up @@ -286,13 +287,16 @@ fn slot_from_config(slot: &SlotConfig) -> Result<Slot, InitializationError> {
builder = builder.max_idle_age(Duration::from_secs(max_idle_duration));
}

let clear_flag = Arc::new(ArcSwap::new(Arc::new(AtomicBool::new(true))));

let api_config = nethsm_sdk_rs::apis::configuration::Configuration {
client: ureq::Agent::with_parts(
builder.build(),
().chain(TcpConnector {
tcp_keepalive_time,
tcp_keepalive_retries,
tcp_keepalive_interval,
clear_flag: clear_flag.clone(),
})
.chain(RustlsConnector {
config: tls_conf.into(),
Expand All @@ -308,6 +312,7 @@ fn slot_from_config(slot: &SlotConfig) -> Result<Slot, InitializationError> {
instances.push(InstanceData {
config: api_config,
state: Default::default(),
clear_flag,
});
}
if instances.is_empty() {
Expand Down
2 changes: 0 additions & 2 deletions pkcs11/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
#![allow(unreachable_code)]

mod api;

mod data;
Expand Down
28 changes: 18 additions & 10 deletions pkcs11/src/ureq/tcp_connector.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
use std::io::{Read, Write};
use std::net::{SocketAddr, TcpStream};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration as StdDuration;
use std::{fmt, io};

use arc_swap::ArcSwap;
use ureq::config::Config;
use ureq::unversioned::transport::time::Duration;
use ureq::Error;
Expand All @@ -22,6 +25,9 @@ pub struct TcpConnector {
pub tcp_keepalive_time: Option<StdDuration>,
pub tcp_keepalive_interval: Option<StdDuration>,
pub tcp_keepalive_retries: Option<u32>,
/// True means connection can continue, false
/// means connection must be closed
pub clear_flag: Arc<ArcSwap<AtomicBool>>,
}

impl<In: Transport> Connector<In> for TcpConnector {
Expand Down Expand Up @@ -71,7 +77,7 @@ impl<In: Transport> Connector<In> for TcpConnector {
}

let buffers = LazyBuffers::new(config.input_buffer_size(), config.output_buffer_size());
let transport = TcpTransport::new(socket.into(), buffers);
let transport = TcpTransport::new(socket.into(), buffers, self.clear_flag.load_full());

Ok(Some(Either::B(transport)))
}
Expand Down Expand Up @@ -139,15 +145,22 @@ pub struct TcpTransport {
buffers: LazyBuffers,
timeout_write: Option<Duration>,
timeout_read: Option<Duration>,
/// Flag used to indicate that the connection must be closed
clear_flag: Arc<AtomicBool>,
}

impl TcpTransport {
pub fn new(stream: TcpStream, buffers: LazyBuffers) -> TcpTransport {
pub fn new(
stream: TcpStream,
buffers: LazyBuffers,
clear_flag: Arc<AtomicBool>,
) -> TcpTransport {
TcpTransport {
stream,
buffers,
timeout_read: None,
timeout_write: None,
clear_flag,
}
}
}
Expand Down Expand Up @@ -217,7 +230,8 @@ impl Transport for TcpTransport {
}

/// A pooled connection is considered open only while its pool-clear flag
/// still reads `true` AND the underlying TCP stream probes as alive.
/// A cleared pool (`clear_pool`) short-circuits the check, forcing the
/// connection to be discarded instead of reused.
fn is_open(&mut self) -> bool {
    self.clear_flag.load(Ordering::Relaxed)
        && probe_tcp_stream(&mut self.stream).unwrap_or(false)
}
}

Expand All @@ -227,15 +241,9 @@ fn probe_tcp_stream(stream: &mut TcpStream) -> Result<bool, Error> {

let mut buf = [0];
match stream.read(&mut buf) {
Err(e)
if matches!(
e.kind(),
io::ErrorKind::WouldBlock | io::ErrorKind::TimedOut
) =>
{
Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
// This is the correct condition. There should be no waiting
// bytes, and therefore reading would block
// And a time out means the connection died and was detected dead by the tcp keepalive
}
// Any bytes read means the server sent some garbage we didn't ask for
Ok(_) => {
Expand Down

0 comments on commit a28dbd4

Please sign in to comment.