diff --git a/lib/src/lightclient.rs b/lib/src/lightclient.rs index ca754a9..ab73d98 100644 --- a/lib/src/lightclient.rs +++ b/lib/src/lightclient.rs @@ -3,7 +3,7 @@ use crate::lightwallet::LightWallet; use std::sync::{Arc, RwLock, Mutex, mpsc::channel}; -use std::sync::atomic::{AtomicI32, AtomicUsize, Ordering}; +use std::sync::atomic::{AtomicBool, AtomicI32, AtomicUsize, Ordering}; use std::path::{Path, PathBuf}; use std::fs::File; use std::collections::{HashSet, HashMap}; @@ -307,6 +307,7 @@ pub struct LightClient { sync_lock : Mutex<()>, sync_status : Arc>, // The current syncing status of the Wallet. + pub shutdown_flag : Arc, // Signal mempool threads to stop } impl LightClient { @@ -392,6 +393,7 @@ impl LightClient { sapling_spend : vec![], sync_lock : Mutex::new(()), sync_status : Arc::new(RwLock::new(WalletStatus::new())), + shutdown_flag : Arc::new(AtomicBool::new(false)), }; l.set_wallet_initial_state(0); @@ -420,6 +422,7 @@ impl LightClient { sapling_spend : vec![], sync_lock : Mutex::new(()), sync_status : Arc::new(RwLock::new(WalletStatus::new())), + shutdown_flag : Arc::new(AtomicBool::new(false)), }; l.set_wallet_initial_state(latest_block); @@ -447,6 +450,7 @@ impl LightClient { sapling_spend : vec![], sync_lock : Mutex::new(()), sync_status : Arc::new(RwLock::new(WalletStatus::new())), + shutdown_flag : Arc::new(AtomicBool::new(false)), }; // println!("Setting birthday to {}", birthday); @@ -464,13 +468,38 @@ impl LightClient { pub fn read_from_disk(config: &LightClientConfig) -> io::Result { if !config.wallet_exists() { - return Err(Error::new(ErrorKind::AlreadyExists, - format!("Cannot read wallet. No file at {}", config.get_wallet_path().display()))); + // Try to recover from backup + let bak_path = config.get_wallet_path().with_extension("dat.bak"); + if bak_path.exists() { + warn!("Wallet file missing but backup found, attempting recovery from {:?}", bak_path); + std::fs::copy(&bak_path, config.get_wallet_path())?; + info!("Wallet recovered from backup"); + } else { + return Err(Error::new(ErrorKind::AlreadyExists, + format!("Cannot read wallet. No file at {}", config.get_wallet_path().display()))); + } } - let mut file_buffer = BufReader::new(File::open(config.get_wallet_path())?); - - let wallet = LightWallet::read(&mut file_buffer, config)?; + // Try to open the wallet file; if it fails (corrupted/truncated), try the backup + let wallet = match File::open(config.get_wallet_path()) + .and_then(|f| { + let mut file_buffer = BufReader::new(f); + LightWallet::read(&mut file_buffer, config) + }) { + Ok(w) => w, + Err(e) => { + warn!("Failed to read wallet file: {}, trying backup", e); + let bak_path = config.get_wallet_path().with_extension("dat.bak"); + if bak_path.exists() { + std::fs::copy(&bak_path, config.get_wallet_path())?; + let mut file_buffer = BufReader::new(File::open(config.get_wallet_path())?); + LightWallet::read(&mut file_buffer, config)? + } else { + return Err(e); + } + } + }; + let mut lc = LightClient { wallet : Arc::new(RwLock::new(wallet)), config : config.clone(), @@ -478,6 +507,7 @@ impl LightClient { sapling_spend : vec![], sync_lock : Mutex::new(()), sync_status : Arc::new(RwLock::new(WalletStatus::new())), + shutdown_flag : Arc::new(AtomicBool::new(false)), }; #[cfg(feature = "embed_params")] @@ -682,22 +712,54 @@ impl LightClient { // Prevent any overlapping syncs during save, and don't save in the middle of a sync let _lock = self.sync_lock.lock().unwrap(); + let wallet_path = self.config.get_wallet_path(); + let tmp_path = wallet_path.with_extension("dat.tmp"); + let bak_path = wallet_path.with_extension("dat.bak"); + + // Write to a temporary file first (atomic save) let wallet = self.wallet.write().unwrap(); let mut file_buffer = BufWriter::with_capacity( 1_000_000, // 1 MB write buffer - File::create(self.config.get_wallet_path()).unwrap()); + File::create(&tmp_path).map_err(|e| format!("Failed to create temp file: {}", e))?); r = match wallet.write(&mut file_buffer) { Ok(_) => Ok(()), Err(e) => { let err = format!("ERR: {}", e); error!("{}", err); + // Clean up temp file on write failure + let _ = std::fs::remove_file(&tmp_path); Err(e.to_string()) } }; - file_buffer.flush().map_err(|e| format!("{}", e))?; + file_buffer.flush().map_err(|e| { + let _ = std::fs::remove_file(&tmp_path); + format!("{}", e) + })?; + + // Only proceed with rename if the write succeeded + if r.is_ok() { + // Create backup of existing wallet (if it exists) + if wallet_path.exists() { + if let Err(e) = std::fs::copy(&wallet_path, &bak_path) { + warn!("Failed to create wallet backup: {}", e); + // Non-fatal: continue with the save + } + } + + // Atomically replace the wallet file with the new one + if let Err(e) = std::fs::rename(&tmp_path, &wallet_path) { + error!("Failed to rename temp wallet file: {}", e); + // Try direct copy as fallback (rename can fail across filesystems) + if let Err(e2) = std::fs::copy(&tmp_path, &wallet_path) { + let _ = std::fs::remove_file(&tmp_path); + return Err(format!("Failed to save wallet: {} / {}", e, e2)); + } + let _ = std::fs::remove_file(&tmp_path); + } } + } r } @@ -1071,12 +1133,17 @@ impl LightClient { pub fn start_mempool_monitor(lc: Arc) -> Result<(), String> { let config = lc.config.clone(); let uri = config.server.clone(); + let shutdown = lc.shutdown_flag.clone(); let (incoming_mempool_tx, incoming_mempool_rx) = std::sync::mpsc::channel::(); - // Thread for reveive transactions + // Thread for receive transactions + let shutdown_rx = lc.shutdown_flag.clone(); std::thread::spawn(move || { - while let Ok(rtx) = incoming_mempool_rx.recv() { + while let Ok(rtx) = incoming_mempool_rx.recv() { + if shutdown_rx.load(Ordering::Relaxed) { + break; + } if let Ok(tx) = Transaction::read( &rtx.data[..]) { @@ -1091,6 +1158,11 @@ pub fn start_mempool_monitor(lc: Arc) -> Result<(), String> { let mut rt = Runtime::new().unwrap(); rt.block_on(async { loop { + if shutdown.load(Ordering::Relaxed) { + info!("Mempool monitor shutting down"); + break; + } + let incoming_mempool_tx_clone = incoming_mempool_tx.clone(); let send_closure = move |rtx: RawTransaction| { incoming_mempool_tx_clone.send(rtx).map_err(|e| Box::new(e) as Box) @@ -1101,13 +1173,26 @@ pub fn start_mempool_monitor(lc: Arc) -> Result<(), String> { Err(e) => warn!("Mempool monitor returned {:?}, will restart listening", e), } - std::thread::sleep(Duration::from_secs(10)); + // Sleep in 1-second increments so we can check shutdown flag + for _ in 0..10 { + if shutdown.load(Ordering::Relaxed) { + info!("Mempool monitor shutting down during sleep"); + return; + } + std::thread::sleep(Duration::from_secs(1)); + } } }); }); Ok(()) } + + /// Signal all background threads to stop + pub fn shutdown(&self) { + self.shutdown_flag.store(true, Ordering::Relaxed); + info!("Shutdown flag set"); + } /// Convinence function to determine what type of key this is and import it pub fn do_import_key(&self, key: String, birthday: u64) -> Result { if key.starts_with(self.config.hrp_sapling_private_key()) { @@ -1714,6 +1799,7 @@ pub mod tests { sapling_spend : vec![], sync_lock : Mutex::new(()), sync_status : Arc::new(RwLock::new(WalletStatus::new())), + shutdown_flag : Arc::new(AtomicBool::new(false)), }; { let addresses = lc.do_address();