From cc618ed79c2031ac30e17b023bceed7c0d83cf84 Mon Sep 17 00:00:00 2001 From: Serguey Parkhomovsky Date: Sat, 21 Mar 2026 13:34:32 -0700 Subject: Make individual locks on each printer state --- src/main.rs | 41 ++++++++++++++++++++++++----------------- 1 file changed, 24 insertions(+), 17 deletions(-) (limited to 'src') diff --git a/src/main.rs b/src/main.rs index ddb1ed6..b6d8006 100644 --- a/src/main.rs +++ b/src/main.rs @@ -11,7 +11,8 @@ use std::sync::{Arc, RwLock}; use std::time::Duration; use tracing::{debug, error, info, warn}; -type StateMap = Arc>>; +type PrinterStateRef = Arc>; +type StateMap = HashMap; #[derive(Parser, Debug)] #[command(about = "Print statistics scraper")] @@ -38,9 +39,8 @@ async fn main() { error!(path = %args.config, error = %e, "Failed to parse config file"); process::exit(1); }); - let state: StateMap = Arc::new(RwLock::new(HashMap::new())); + let mut map: StateMap = HashMap::new(); for printer in config.printers { - let state = state.clone(); match printer { Printer::Prusa { name, @@ -48,9 +48,11 @@ async fn main() { api_key, } => { info!(name, host, "Found Prusa"); - state.write().unwrap().insert(name.clone(), PrinterState::default()); + let printer_state = Arc::new(RwLock::new(PrinterState::default())); + map.insert(name.clone(), printer_state.clone()); tokio::spawn(async move { - match tokio::spawn(poll_prusa(name.clone(), host, api_key, state)).await { + match tokio::spawn(poll_prusa(name.clone(), host, api_key, printer_state)).await + { Ok(()) => warn!(name, "Prusa polling task exited unexpectedly"), Err(e) => error!(name, error = ?e, "Prusa polling task panicked"), } @@ -63,14 +65,15 @@ async fn main() { access_code, } => { info!(name, host, "Found Bambu"); - state.write().unwrap().insert(name.clone(), PrinterState::default()); + let printer_state = Arc::new(RwLock::new(PrinterState::default())); + map.insert(name.clone(), printer_state.clone()); tokio::spawn(async move { match tokio::spawn(poll_bambu( name.clone(), host, serial_number, access_code, - state, + printer_state, )) .await { @@ -81,6 +84,8 @@ async fn main() { } } } + // Drop mutability on the StateMap + let state: StateMap = map; let app = Router::new().route("/", get(root)).with_state(state); @@ -99,10 +104,9 @@ async fn main() { async fn fetch_prusa( client: &Client, - name: &str, host: &str, api_key: &str, - state: &StateMap, + state: &PrinterStateRef, ) -> Result<(), Box> { let response = client .get(format!("http://{}/api/v1/status", host)) @@ -111,12 +115,11 @@ async fn fetch_prusa( .await? .json::() .await?; - let mut lock = state.write().unwrap(); - lock.get_mut(name).unwrap().update_from(&response); + state.write().unwrap().update_from(&response); Ok(()) } -async fn poll_prusa(name: String, host: String, api_key: String, state: StateMap) { +async fn poll_prusa(name: String, host: String, api_key: String, state: PrinterStateRef) { let client = Client::builder() .timeout(Duration::from_secs(10)) .build() @@ -125,7 +128,7 @@ async fn poll_prusa(name: String, host: String, api_key: String, state: StateMap process::exit(1); }); loop { - if let Err(e) = fetch_prusa(&client, &name, &host, &api_key, &state).await { + if let Err(e) = fetch_prusa(&client, &host, &api_key, &state).await { error!(name, error = %e, "Error polling Prusa printer"); } tokio::time::sleep(Duration::from_secs(5)).await; @@ -137,7 +140,7 @@ async fn poll_bambu( host: String, serial_number: String, access_code: String, - state: StateMap, + state: PrinterStateRef, ) { let mut mqttoptions = MqttOptions::new(&name, &host, 8883); mqttoptions.set_keep_alive(Duration::from_secs(5)); @@ -193,8 +196,7 @@ async fn poll_bambu( debug!(payload = ?p.payload, "Received Bambu payload"); match serde_json::from_slice::(&p.payload) { Ok(msg) => { - let mut lock = state.write().unwrap(); - lock.get_mut(name.as_str()).unwrap().update_from(&msg); + state.write().unwrap().update_from(&msg); debug!(name, "Updated state"); } Err(e) => error!(error = %e, "Failed to deserialize BambuStatus"), @@ -210,5 +212,10 @@ async fn poll_bambu( } async fn root(State(state): State) -> Json> { - Json(state.read().unwrap().clone()) + Json( + state + .iter() + .map(|(k, v)| (k.clone(), v.read().unwrap().clone())) + .collect(), + ) } -- cgit v1.2.3