summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSerguey Parkhomovsky <xindigo@gmail.com>2026-03-21 13:34:32 -0700
committerSerguey Parkhomovsky <xindigo@gmail.com>2026-03-21 13:37:07 -0700
commitcc618ed79c2031ac30e17b023bceed7c0d83cf84 (patch)
tree730185fceb74d05a6769404f27d482a3903e476c
parenta23a714659b10473d74daaa41e03237c74c6861b (diff)
Make individual locks on each printer state
-rw-r--r--src/main.rs41
1 files changed, 24 insertions, 17 deletions
diff --git a/src/main.rs b/src/main.rs
index ddb1ed6..b6d8006 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -11,7 +11,8 @@ use std::sync::{Arc, RwLock};
use std::time::Duration;
use tracing::{debug, error, info, warn};
-type StateMap = Arc<RwLock<HashMap<String, PrinterState>>>;
+type PrinterStateRef = Arc<RwLock<PrinterState>>;
+type StateMap = HashMap<String, PrinterStateRef>;
#[derive(Parser, Debug)]
#[command(about = "Print statistics scraper")]
@@ -38,9 +39,8 @@ async fn main() {
error!(path = %args.config, error = %e, "Failed to parse config file");
process::exit(1);
});
- let state: StateMap = Arc::new(RwLock::new(HashMap::new()));
+ let mut map: StateMap = HashMap::new();
for printer in config.printers {
- let state = state.clone();
match printer {
Printer::Prusa {
name,
@@ -48,9 +48,11 @@ async fn main() {
api_key,
} => {
info!(name, host, "Found Prusa");
- state.write().unwrap().insert(name.clone(), PrinterState::default());
+ let printer_state = Arc::new(RwLock::new(PrinterState::default()));
+ map.insert(name.clone(), printer_state.clone());
tokio::spawn(async move {
- match tokio::spawn(poll_prusa(name.clone(), host, api_key, state)).await {
+ match tokio::spawn(poll_prusa(name.clone(), host, api_key, printer_state)).await
+ {
Ok(()) => warn!(name, "Prusa polling task exited unexpectedly"),
Err(e) => error!(name, error = ?e, "Prusa polling task panicked"),
}
@@ -63,14 +65,15 @@ async fn main() {
access_code,
} => {
info!(name, host, "Found Bambu");
- state.write().unwrap().insert(name.clone(), PrinterState::default());
+ let printer_state = Arc::new(RwLock::new(PrinterState::default()));
+ map.insert(name.clone(), printer_state.clone());
tokio::spawn(async move {
match tokio::spawn(poll_bambu(
name.clone(),
host,
serial_number,
access_code,
- state,
+ printer_state,
))
.await
{
@@ -81,6 +84,8 @@ async fn main() {
}
}
}
+ // Drop mutability on the StateMap
+ let state: StateMap = map;
let app = Router::new().route("/", get(root)).with_state(state);
@@ -99,10 +104,9 @@ async fn main() {
async fn fetch_prusa(
client: &Client,
- name: &str,
host: &str,
api_key: &str,
- state: &StateMap,
+ state: &PrinterStateRef,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let response = client
.get(format!("http://{}/api/v1/status", host))
@@ -111,12 +115,11 @@ async fn fetch_prusa(
.await?
.json::<PrusaStatus>()
.await?;
- let mut lock = state.write().unwrap();
- lock.get_mut(name).unwrap().update_from(&response);
+ state.write().unwrap().update_from(&response);
Ok(())
}
-async fn poll_prusa(name: String, host: String, api_key: String, state: StateMap) {
+async fn poll_prusa(name: String, host: String, api_key: String, state: PrinterStateRef) {
let client = Client::builder()
.timeout(Duration::from_secs(10))
.build()
@@ -125,7 +128,7 @@ async fn poll_prusa(name: String, host: String, api_key: String, state: StateMap
process::exit(1);
});
loop {
- if let Err(e) = fetch_prusa(&client, &name, &host, &api_key, &state).await {
+ if let Err(e) = fetch_prusa(&client, &host, &api_key, &state).await {
error!(name, error = %e, "Error polling Prusa printer");
}
tokio::time::sleep(Duration::from_secs(5)).await;
@@ -137,7 +140,7 @@ async fn poll_bambu(
host: String,
serial_number: String,
access_code: String,
- state: StateMap,
+ state: PrinterStateRef,
) {
let mut mqttoptions = MqttOptions::new(&name, &host, 8883);
mqttoptions.set_keep_alive(Duration::from_secs(5));
@@ -193,8 +196,7 @@ async fn poll_bambu(
debug!(payload = ?p.payload, "Received Bambu payload");
match serde_json::from_slice::<BambuStatus>(&p.payload) {
Ok(msg) => {
- let mut lock = state.write().unwrap();
- lock.get_mut(name.as_str()).unwrap().update_from(&msg);
+ state.write().unwrap().update_from(&msg);
debug!(name, "Updated state");
}
Err(e) => error!(error = %e, "Failed to deserialize BambuStatus"),
@@ -210,5 +212,10 @@ async fn poll_bambu(
}
async fn root(State(state): State<StateMap>) -> Json<HashMap<String, PrinterState>> {
- Json(state.read().unwrap().clone())
+ Json(
+ state
+ .iter()
+ .map(|(k, v)| (k.clone(), v.read().unwrap().clone()))
+ .collect(),
+ )
}