diff --git a/.gitignore b/.gitignore index b07033a86..641f4f03b 100644 --- a/.gitignore +++ b/.gitignore @@ -45,5 +45,3 @@ cobertura.xml # Build scripts artifacts *.log -/dash-spv-ffi/peer_reputation.json -/dash-spv/peer_reputation.json diff --git a/dash-spv-ffi/src/types.rs b/dash-spv-ffi/src/types.rs index c644c52da..e05504e59 100644 --- a/dash-spv-ffi/src/types.rs +++ b/dash-spv-ffi/src/types.rs @@ -181,8 +181,6 @@ impl From for FFIDetailedSyncProgress { #[repr(C)] pub struct FFIChainState { - pub header_height: u32, - pub filter_header_height: u32, pub masternode_height: u32, pub last_chainlock_height: u32, pub last_chainlock_hash: FFIString, @@ -192,8 +190,6 @@ pub struct FFIChainState { impl From for FFIChainState { fn from(state: ChainState) -> Self { FFIChainState { - header_height: state.headers.len() as u32, - filter_header_height: state.filter_headers.len() as u32, masternode_height: state.last_masternode_diff_height.unwrap_or(0), last_chainlock_height: state.last_chainlock_height.unwrap_or(0), last_chainlock_hash: FFIString::new( diff --git a/dash-spv-ffi/tests/unit/test_type_conversions.rs b/dash-spv-ffi/tests/unit/test_type_conversions.rs index 58e29ce5f..9713fe183 100644 --- a/dash-spv-ffi/tests/unit/test_type_conversions.rs +++ b/dash-spv-ffi/tests/unit/test_type_conversions.rs @@ -163,8 +163,6 @@ mod tests { #[test] fn test_chain_state_none_values() { let state = dash_spv::ChainState { - headers: vec![], - filter_headers: vec![], last_chainlock_height: None, last_chainlock_hash: None, current_filter_tip: None, @@ -174,8 +172,7 @@ mod tests { }; let ffi_state = FFIChainState::from(state); - assert_eq!(ffi_state.header_height, 0); - assert_eq!(ffi_state.filter_header_height, 0); + assert_eq!(ffi_state.masternode_height, 0); assert_eq!(ffi_state.last_chainlock_height, 0); assert_eq!(ffi_state.current_filter_tip, 0); diff --git a/dash-spv/examples/filter_sync.rs b/dash-spv/examples/filter_sync.rs index 25e86a5bf..98c9cc103 100644 --- a/dash-spv/examples/filter_sync.rs +++ b/dash-spv/examples/filter_sync.rs @@ -28,8 +28,7 @@ async fn main() -> Result<(), Box> { let network_manager = PeerNetworkManager::new(&config).await?; // Create storage manager - let storage_manager = - DiskStorageManager::new("./.tmp/filter-sync-example-storage".into()).await?; + let storage_manager = DiskStorageManager::new("./.tmp/filter-sync-example-storage").await?; // Create wallet manager let wallet = Arc::new(RwLock::new(WalletManager::::new())); diff --git a/dash-spv/examples/simple_sync.rs b/dash-spv/examples/simple_sync.rs index 89a70066a..08238c8ea 100644 --- a/dash-spv/examples/simple_sync.rs +++ b/dash-spv/examples/simple_sync.rs @@ -24,8 +24,7 @@ async fn main() -> Result<(), Box> { let network_manager = PeerNetworkManager::new(&config).await?; // Create storage manager - let storage_manager = - DiskStorageManager::new("./.tmp/simple-sync-example-storage".into()).await?; + let storage_manager = DiskStorageManager::new("./.tmp/simple-sync-example-storage").await?; // Create wallet manager let wallet = Arc::new(RwLock::new(WalletManager::::new())); diff --git a/dash-spv/examples/spv_with_wallet.rs b/dash-spv/examples/spv_with_wallet.rs index 8e4b4e866..d1fce9e6f 100644 --- a/dash-spv/examples/spv_with_wallet.rs +++ b/dash-spv/examples/spv_with_wallet.rs @@ -26,8 +26,7 @@ async fn main() -> Result<(), Box> { let network_manager = PeerNetworkManager::new(&config).await?; // Create storage manager - use disk storage for persistence - let storage_manager = - DiskStorageManager::new("./.tmp/spv-with-wallet-example-storage".into()).await?; + let storage_manager = DiskStorageManager::new("./.tmp/spv-with-wallet-example-storage").await?; // Create wallet manager let wallet = Arc::new(RwLock::new(WalletManager::::new())); diff --git a/dash-spv/src/chain/chainlock_manager.rs b/dash-spv/src/chain/chainlock_manager.rs index b4780bbd7..0bdcfb5c3 100644 --- a/dash-spv/src/chain/chainlock_manager.rs +++ b/dash-spv/src/chain/chainlock_manager.rs @@ -175,7 +175,11 @@ impl ChainLockManager { } // Verify the block exists in our chain - if let Some(header) = chain_state.header_at_height(chain_lock.block_height) { + if let Some(header) = storage + .get_header(chain_lock.block_height) + .await + .map_err(ValidationError::StorageError)? + { let header_hash = header.block_hash(); if header_hash != chain_lock.block_hash { return Err(ValidationError::InvalidChainLock(format!( diff --git a/dash-spv/src/client/block_processor_test.rs b/dash-spv/src/client/block_processor_test.rs index 418a449ed..7106a7a13 100644 --- a/dash-spv/src/client/block_processor_test.rs +++ b/dash-spv/src/client/block_processor_test.rs @@ -4,7 +4,7 @@ mod tests { use crate::client::block_processor::{BlockProcessingTask, BlockProcessor}; - use crate::storage::DiskStorageManager; + use crate::storage::{BlockHeaderStorage, DiskStorageManager}; use crate::types::{SpvEvent, SpvStats}; use dashcore::{blockdata::constants::genesis_block, Block, Network, Transaction}; diff --git a/dash-spv/src/client/chainlock.rs b/dash-spv/src/client/chainlock.rs index 553f0b58d..59632d8dd 100644 --- a/dash-spv/src/client/chainlock.rs +++ b/dash-spv/src/client/chainlock.rs @@ -43,8 +43,7 @@ impl< .await { // Penalize the peer that relayed the invalid ChainLock - let reason = format!("Invalid ChainLock: {}", e); - let _ = self.network.penalize_last_message_peer_invalid_chainlock(&reason).await; + let _ = self.network.penalize_last_message_peer_invalid_chainlock().await; return Err(SpvError::Validation(e)); } } @@ -111,7 +110,7 @@ impl< tracing::warn!("{}", reason); // Ban the peer using the reputation system - let _ = self.network.penalize_last_message_peer_invalid_instantlock(&reason).await; + let _ = self.network.penalize_last_message_peer_invalid_instantlock().await; return Err(SpvError::Validation(e)); } diff --git a/dash-spv/src/client/core.rs b/dash-spv/src/client/core.rs index c4ab3199c..e3c1011d2 100644 --- a/dash-spv/src/client/core.rs +++ b/dash-spv/src/client/core.rs @@ -189,14 +189,17 @@ impl< /// Returns the current chain tip hash if available. pub async fn tip_hash(&self) -> Option { - let state = self.state.read().await; - state.tip_hash() + let storage = self.storage.lock().await; + + let tip_height = storage.get_tip_height().await?; + let header = storage.get_header(tip_height).await.ok()??; + + Some(header.block_hash()) } /// Returns the current chain tip height (absolute), accounting for checkpoint base. pub async fn tip_height(&self) -> u32 { - let state = self.state.read().await; - state.tip_height() + self.storage.lock().await.get_tip_height().await.unwrap_or(0) } /// Get current chain state (read-only). @@ -271,42 +274,6 @@ impl< Ok(()) } - /// Clear all stored filter headers and compact filters while keeping other data intact. - pub async fn clear_filters(&mut self) -> Result<()> { - { - let mut storage = self.storage.lock().await; - storage.clear_filters().await.map_err(SpvError::Storage)?; - } - - // Reset in-memory chain state for filters - { - let mut state = self.state.write().await; - state.filter_headers.clear(); - state.current_filter_tip = None; - } - - // Reset filter sync manager tracking - self.sync_manager.filter_sync_mut().clear_filter_state().await; - - // Reset filter-related statistics - let received_heights = { - let stats = self.stats.read().await; - stats.received_filter_heights.clone() - }; - - { - let mut stats = self.stats.write().await; - stats.filter_headers_downloaded = 0; - stats.filter_height = 0; - stats.filters_downloaded = 0; - stats.filters_received = 0; - } - - received_heights.lock().await.clear(); - - Ok(()) - } - // ============ Configuration ============ /// Update the client configuration. diff --git a/dash-spv/src/client/lifecycle.rs b/dash-spv/src/client/lifecycle.rs index 2711db224..b0de35d19 100644 --- a/dash-spv/src/client/lifecycle.rs +++ b/dash-spv/src/client/lifecycle.rs @@ -169,30 +169,12 @@ impl< // This ensures the ChainState has headers loaded for both checkpoint and normal sync let tip_height = { let storage = self.storage.lock().await; - storage.get_tip_height().await.map_err(SpvError::Storage)?.unwrap_or(0) + storage.get_tip_height().await.unwrap_or(0) }; if tip_height > 0 { tracing::info!("Found {} headers in storage, loading into sync manager...", tip_height); - let loaded_count = { - let storage = self.storage.lock().await; - self.sync_manager.load_headers_from_storage(&storage).await - }; - - match loaded_count { - Ok(loaded_count) => { - tracing::info!("✅ Sync manager loaded {} headers from storage", loaded_count); - } - Err(e) => { - tracing::error!("Failed to load headers into sync manager: {}", e); - // For checkpoint sync, this is critical - let state = self.state.read().await; - if state.synced_from_checkpoint() { - return Err(SpvError::Sync(e)); - } - // For normal sync, we can continue as headers will be re-synced - tracing::warn!("Continuing without pre-loaded headers for normal sync"); - } - } + let storage = self.storage.lock().await; + self.sync_manager.load_headers_from_storage(&storage).await } // Connect to network @@ -209,8 +191,7 @@ impl< // Get initial header count from storage let (header_height, filter_height) = { let storage = self.storage.lock().await; - let h_height = - storage.get_tip_height().await.map_err(SpvError::Storage)?.unwrap_or(0); + let h_height = storage.get_tip_height().await.unwrap_or(0); let f_height = storage.get_filter_tip_height().await.map_err(SpvError::Storage)?.unwrap_or(0); (h_height, f_height) @@ -244,7 +225,7 @@ impl< // Shutdown storage to ensure all data is persisted { let mut storage = self.storage.lock().await; - storage.shutdown().await.map_err(SpvError::Storage)?; + storage.shutdown().await; tracing::info!("Storage shutdown completed - all data persisted"); } @@ -271,7 +252,7 @@ impl< // Check if we already have any headers in storage let current_tip = { let storage = self.storage.lock().await; - storage.get_tip_height().await.map_err(SpvError::Storage)? + storage.get_tip_height().await }; if current_tip.is_some() { @@ -344,12 +325,12 @@ impl< // Clone the chain state for storage let chain_state_for_storage = (*chain_state).clone(); - let headers_len = chain_state_for_storage.headers.len() as u32; drop(chain_state); // Update storage with chain state including sync_base_height { let mut storage = self.storage.lock().await; + storage.store_headers(&[checkpoint_header]).await?; storage .store_chain_state(&chain_state_for_storage) .await @@ -366,7 +347,7 @@ impl< ); // Update the sync manager's cached flags from the checkpoint-initialized state - self.sync_manager.update_chain_state_cache(checkpoint.height, headers_len); + self.sync_manager.update_chain_state_cache(checkpoint.height); tracing::info!( "Updated sync manager with checkpoint-initialized chain state" ); @@ -414,7 +395,7 @@ impl< // Verify it was stored correctly let stored_height = { let storage = self.storage.lock().await; - storage.get_tip_height().await.map_err(SpvError::Storage)? + storage.get_tip_height().await }; tracing::info!( "✅ Genesis block initialized at height 0, storage reports tip height: {:?}", diff --git a/dash-spv/src/client/progress.rs b/dash-spv/src/client/progress.rs index 7998560a6..5bc2b8d4c 100644 --- a/dash-spv/src/client/progress.rs +++ b/dash-spv/src/client/progress.rs @@ -38,7 +38,7 @@ impl< // Get current heights from storage { let storage = self.storage.lock().await; - if let Ok(Some(header_height)) = storage.get_tip_height().await { + if let Some(header_height) = storage.get_tip_height().await { stats.header_height = header_height; } diff --git a/dash-spv/src/client/queries.rs b/dash-spv/src/client/queries.rs index bb0be8c3b..6adb2e271 100644 --- a/dash-spv/src/client/queries.rs +++ b/dash-spv/src/client/queries.rs @@ -42,20 +42,6 @@ impl< self.network.peer_count() } - /// Disconnect a specific peer. - pub async fn disconnect_peer(&self, addr: &std::net::SocketAddr, reason: &str) -> Result<()> { - // Cast network manager to PeerNetworkManager to access disconnect_peer - let network = self - .network - .as_any() - .downcast_ref::() - .ok_or_else(|| { - SpvError::Config("Network manager does not support peer disconnection".to_string()) - })?; - - network.disconnect_peer(addr, reason).await - } - // ============ Masternode Queries ============ /// Get a reference to the masternode list engine. diff --git a/dash-spv/src/client/status_display.rs b/dash-spv/src/client/status_display.rs index 0324fe964..3b07fca9d 100644 --- a/dash-spv/src/client/status_display.rs +++ b/dash-spv/src/client/status_display.rs @@ -76,7 +76,7 @@ impl<'a, S: StorageManager + Send + Sync + 'static, W: WalletInterface + Send + // For genesis sync: sync_base_height = 0, so height = 0 + storage_count // For checkpoint sync: height = checkpoint_height + storage_count let storage = self.storage.lock().await; - if let Ok(Some(storage_tip)) = storage.get_tip_height().await { + if let Some(storage_tip) = storage.get_tip_height().await { let blockchain_height = storage_tip; if with_logging { tracing::debug!( diff --git a/dash-spv/src/client/sync_coordinator.rs b/dash-spv/src/client/sync_coordinator.rs index de06633ec..2af4716dc 100644 --- a/dash-spv/src/client/sync_coordinator.rs +++ b/dash-spv/src/client/sync_coordinator.rs @@ -42,7 +42,7 @@ impl< let result = SyncProgress { header_height: { let storage = self.storage.lock().await; - storage.get_tip_height().await.map_err(SpvError::Storage)?.unwrap_or(0) + storage.get_tip_height().await.unwrap_or(0) }, filter_header_height: { let storage = self.storage.lock().await; @@ -241,7 +241,7 @@ impl< // Storage tip now represents the absolute blockchain height. let current_tip_height = { let storage = self.storage.lock().await; - storage.get_tip_height().await.ok().flatten().unwrap_or(0) + storage.get_tip_height().await.unwrap_or(0) }; let current_height = current_tip_height; let peer_best = self @@ -315,7 +315,7 @@ impl< // Emit filter headers progress only when heights change let (abs_header_height, filter_header_height) = { let storage = self.storage.lock().await; - let storage_tip = storage.get_tip_height().await.ok().flatten().unwrap_or(0); + let storage_tip = storage.get_tip_height().await.unwrap_or(0); let filter_tip = storage.get_filter_tip_height().await.ok().flatten().unwrap_or(0); (storage_tip, filter_tip) diff --git a/dash-spv/src/lib.rs b/dash-spv/src/lib.rs index 2e93b57b6..291807819 100644 --- a/dash-spv/src/lib.rs +++ b/dash-spv/src/lib.rs @@ -30,7 +30,7 @@ //! //! // Create the required components //! let network = PeerNetworkManager::new(&config).await?; -//! let storage = DiskStorageManager::new("./.tmp/example-storage".into()).await?; +//! let storage = DiskStorageManager::new("./.tmp/example-storage").await?; //! let wallet = Arc::new(RwLock::new(WalletManager::::new())); //! //! // Create and start the client diff --git a/dash-spv/src/network/manager.rs b/dash-spv/src/network/manager.rs index c0dc87ff2..ce0e2166d 100644 --- a/dash-spv/src/network/manager.rs +++ b/dash-spv/src/network/manager.rs @@ -1,6 +1,6 @@ //! Peer network manager for SPV client -use std::collections::{HashMap, HashSet}; +use std::collections::HashSet; use std::net::SocketAddr; use std::path::PathBuf; use std::sync::atomic::{AtomicUsize, Ordering}; @@ -22,12 +22,10 @@ use crate::error::{NetworkError, NetworkResult, SpvError as Error}; use crate::network::addrv2::AddrV2Handler; use crate::network::constants::*; use crate::network::discovery::DnsDiscovery; -use crate::network::persist::PeerStore; use crate::network::pool::PeerPool; -use crate::network::reputation::{ - misbehavior_scores, positive_scores, PeerReputationManager, ReputationAware, -}; +use crate::network::reputation::{PeerReputationManager, ReputationChangeReason}; use crate::network::{HandshakeManager, NetworkManager, Peer}; +use crate::storage::{PeerStorage, PersistentPeerStorage, PersistentStorage}; use crate::types::PeerInfo; /// Peer network manager @@ -39,9 +37,9 @@ pub struct PeerNetworkManager { /// AddrV2 handler addrv2_handler: Arc, /// Peer persistence - peer_store: Arc, + peer_store: Arc, /// Peer reputation manager - reputation_manager: Arc, + reputation_manager: Arc>, /// Network type network: Network, /// Shutdown token @@ -80,25 +78,10 @@ impl PeerNetworkManager { let discovery = DnsDiscovery::new().await?; let data_dir = config.storage_path.clone().unwrap_or_else(|| PathBuf::from(".")); - let peer_store = PeerStore::new(config.network, data_dir.clone()); - let reputation_manager = Arc::new(PeerReputationManager::new()); + let peer_store = PersistentPeerStorage::open(data_dir.clone()).await?; - // Load reputation data if available - let reputation_path = data_dir.join("peer_reputation.json"); - - // Ensure the directory exists before attempting to load - if let Some(parent_dir) = reputation_path.parent() { - if !parent_dir.exists() { - if let Err(e) = std::fs::create_dir_all(parent_dir) { - log::warn!("Failed to create directory for reputation data: {}", e); - } - } - } - - if let Err(e) = reputation_manager.load_from_storage(&reputation_path).await { - log::warn!("Failed to load peer reputation data: {}", e); - } + let reputation_manager = PeerReputationManager::load_or_new(&peer_store).await; // Determine exclusive mode: either explicitly requested or peers were provided let exclusive_mode = config.restrict_to_configured_peers || !config.peers.is_empty(); @@ -108,7 +91,7 @@ impl PeerNetworkManager { discovery: Arc::new(discovery), addrv2_handler: Arc::new(AddrV2Handler::new()), peer_store: Arc::new(peer_store), - reputation_manager, + reputation_manager: Arc::new(Mutex::new(reputation_manager)), network: config.network, shutdown_token: CancellationToken::new(), message_tx, @@ -183,7 +166,7 @@ impl PeerNetworkManager { /// Connect to a specific peer async fn connect_to_peer(&self, addr: SocketAddr) { // Check reputation first - if !self.reputation_manager.should_connect_to_peer(&addr).await { + if !self.reputation_manager.lock().await.should_connect_to_peer(&addr).await { log::warn!("Not connecting to {} due to bad reputation", addr); return; } @@ -199,7 +182,7 @@ impl PeerNetworkManager { } // Record connection attempt - self.reputation_manager.record_connection_attempt(addr).await; + self.reputation_manager.lock().await.record_connection_attempt(addr).await; let pool = self.pool.clone(); let network = self.network; @@ -225,9 +208,6 @@ impl PeerNetworkManager { Ok(_) => { log::info!("Successfully connected to {}", addr); - // Record successful connection - reputation_manager.record_successful_connection(addr).await; - // Add to pool if let Err(e) = pool.add_peer(addr, peer).await { log::error!("Failed to add peer to pool: {}", e); @@ -256,11 +236,9 @@ impl PeerNetworkManager { log::warn!("Handshake failed with {}: {}", addr, e); // Update reputation for handshake failure reputation_manager - .update_reputation( - addr, - misbehavior_scores::INVALID_MESSAGE, - "Handshake failed", - ) + .lock() + .await + .update_reputation(addr, ReputationChangeReason::InvalidMessage) .await; // For handshake failures, try again later tokio::time::sleep(RECONNECT_DELAY).await; @@ -271,11 +249,9 @@ impl PeerNetworkManager { log::debug!("Failed to connect to {}: {}", addr, e); // Minor reputation penalty for connection failure reputation_manager - .update_reputation( - addr, - misbehavior_scores::TIMEOUT / 2, - "Connection failed", - ) + .lock() + .await + .update_reputation(addr, ReputationChangeReason::Timeout) .await; } } @@ -289,7 +265,7 @@ impl PeerNetworkManager { message_tx: mpsc::Sender<(SocketAddr, NetworkMessage)>, addrv2_handler: Arc, shutdown_token: CancellationToken, - reputation_manager: Arc, + reputation_manager: Arc>, connected_peer_count: Arc, ) { tokio::spawn(async move { @@ -496,11 +472,9 @@ impl PeerNetworkManager { log::debug!("Timeout reading from {}, continuing...", addr); // Minor reputation penalty for timeout reputation_manager - .update_reputation( - addr, - misbehavior_scores::TIMEOUT, - "Read timeout", - ) + .lock() + .await + .update_reputation(addr, ReputationChangeReason::Timeout) .await; continue; } @@ -518,10 +492,11 @@ impl PeerNetworkManager { ); // Reputation penalty for invalid data reputation_manager + .lock() + .await .update_reputation( addr, - misbehavior_scores::INVALID_TRANSACTION, - "Invalid transaction type in block", + ReputationChangeReason::InvalidTransaction, ) .await; } else if error_msg @@ -578,7 +553,9 @@ impl PeerNetworkManager { if conn_duration > Duration::from_secs(3600) { // 1 hour reputation_manager - .update_reputation(addr, positive_scores::LONG_UPTIME, "Long connection uptime") + .lock() + .await + .update_reputation(addr, ReputationChangeReason::LongUptime) .await; } }); @@ -595,7 +572,6 @@ impl PeerNetworkManager { let reputation_manager = self.reputation_manager.clone(); let peer_search_started = self.peer_search_started.clone(); let initial_peers = self.initial_peers.clone(); - let data_dir = self.data_dir.clone(); let connected_peer_count = self.connected_peer_count.clone(); // Check if we're in exclusive mode (explicit flag or peers configured) @@ -656,7 +632,7 @@ impl PeerNetworkManager { let known = addrv2_handler.get_known_addresses().await; let needed = TARGET_PEERS.saturating_sub(count); // Select best peers based on reputation - let best_peers = reputation_manager.select_best_peers(known, needed * 2).await; + let best_peers = reputation_manager.lock().await.select_best_peers(known, needed * 2).await; let mut attempted = 0; for addr in best_peers { @@ -730,10 +706,9 @@ impl PeerNetworkManager { if let Err(e) = peer_guard.send_ping().await { log::error!("Failed to ping {}: {}", addr, e); // Update reputation for ping failure - reputation_manager.update_reputation( + reputation_manager.lock().await.update_reputation( addr, - misbehavior_scores::TIMEOUT, - "Ping failed", + ReputationChangeReason::Timeout, ).await; } } @@ -750,8 +725,7 @@ impl PeerNetworkManager { } // Save reputation data periodically - let storage_path = data_dir.join("peer_reputation.json"); - if let Err(e) = reputation_manager.save_to_storage(&storage_path).await { + if let Err(e) = reputation_manager.lock().await.save_to_storage(&peer_store).await { log::warn!("Failed to save reputation data: {}", e); } } @@ -946,21 +920,12 @@ impl PeerNetworkManager { } /// Disconnect a specific peer - pub async fn disconnect_peer(&self, addr: &SocketAddr, reason: &str) -> Result<(), Error> { - log::info!("Disconnecting peer {} - reason: {}", addr, reason); - - // Remove the peer + pub async fn disconnect_peer(&self, addr: &SocketAddr) -> Result<(), Error> { self.pool.remove_peer(addr).await; Ok(()) } - /// Get reputation information for all peers - pub async fn get_peer_reputations(&self) -> HashMap { - let reputations = self.reputation_manager.get_all_reputations().await; - reputations.into_iter().map(|(addr, rep)| (addr, (rep.score, rep.is_banned()))).collect() - } - /// Get the last peer that sent us a message pub async fn get_last_message_peer(&self) -> Option { let last_peer = self.last_message_peer.lock().await; @@ -987,28 +952,9 @@ impl PeerNetworkManager { *last_peer } - /// Ban a specific peer manually - pub async fn ban_peer(&self, addr: &SocketAddr, reason: &str) -> Result<(), Error> { - log::info!("Manually banning peer {} - reason: {}", addr, reason); - - // Disconnect the peer first - self.disconnect_peer(addr, reason).await?; - - // Update reputation to trigger ban - self.reputation_manager - .update_reputation( - *addr, - misbehavior_scores::INVALID_HEADER * 2, // Severe penalty - reason, - ) - .await; - - Ok(()) - } - /// Unban a specific peer pub async fn unban_peer(&self, addr: &SocketAddr) { - self.reputation_manager.unban_peer(addr).await; + self.reputation_manager.lock().await.unban_peer(addr).await; } /// Shutdown the network manager @@ -1025,8 +971,8 @@ impl PeerNetworkManager { } // Save reputation data before shutdown - let reputation_path = self.data_dir.join("peer_reputation.json"); - if let Err(e) = self.reputation_manager.save_to_storage(&reputation_path).await { + if let Err(e) = self.reputation_manager.lock().await.save_to_storage(&self.peer_store).await + { log::warn!("Failed to save reputation data on shutdown: {}", e); } @@ -1121,81 +1067,68 @@ impl NetworkManager for PeerNetworkManager { async fn penalize_last_message_peer( &self, - score_change: i32, - reason: &str, + reason: ReputationChangeReason, ) -> NetworkResult<()> { // Get the last peer that sent us a message if let Some(addr) = self.get_last_message_peer().await { - self.reputation_manager.update_reputation(addr, score_change, reason).await; + self.reputation_manager.lock().await.update_reputation(addr, reason).await; } Ok(()) } - async fn penalize_last_message_peer_invalid_chainlock( - &self, - reason: &str, - ) -> NetworkResult<()> { + async fn penalize_last_message_peer_invalid_chainlock(&self) -> NetworkResult<()> { if let Some(addr) = self.get_last_message_peer().await { - match self.disconnect_peer(&addr, reason).await { + match self.disconnect_peer(&addr).await { Ok(()) => { - log::warn!( - "Peer {} disconnected for invalid ChainLock enforcement: {}", - addr, - reason - ); + log::warn!("Peer {addr} disconnected for invalid ChainLock enforcement",); } Err(err) => { log::error!( - "Failed to disconnect peer {} after invalid ChainLock enforcement ({}): {}", - addr, - reason, - err + "Failed to disconnect peer {addr} after invalid ChainLock enforcement: {err}", ); } } // Apply misbehavior score and a short temporary ban self.reputation_manager - .update_reputation(addr, misbehavior_scores::INVALID_CHAINLOCK, reason) + .lock() + .await + .update_reputation(addr, ReputationChangeReason::InvalidChainLock) .await; // Short ban: 10 minutes for relaying invalid ChainLock self.reputation_manager - .temporary_ban_peer(addr, Duration::from_secs(10 * 60), reason) + .lock() + .await + .temporary_ban_peer(addr, Duration::from_secs(10 * 60)) .await; } Ok(()) } - async fn penalize_last_message_peer_invalid_instantlock( - &self, - reason: &str, - ) -> NetworkResult<()> { + async fn penalize_last_message_peer_invalid_instantlock(&self) -> NetworkResult<()> { if let Some(addr) = self.get_last_message_peer().await { // Apply misbehavior score and a short temporary ban self.reputation_manager - .update_reputation(addr, misbehavior_scores::INVALID_INSTANTLOCK, reason) + .lock() + .await + .update_reputation(addr, ReputationChangeReason::InvalidInstantLock) .await; // Short ban: 10 minutes for relaying invalid InstantLock self.reputation_manager - .temporary_ban_peer(addr, Duration::from_secs(10 * 60), reason) + .lock() + .await + .temporary_ban_peer(addr, Duration::from_secs(10 * 60)) .await; - match self.disconnect_peer(&addr, reason).await { + match self.disconnect_peer(&addr).await { Ok(()) => { - log::warn!( - "Peer {} disconnected for invalid InstantLock enforcement: {}", - addr, - reason - ); + log::warn!("Peer {addr} disconnected for invalid InstantLock enforcement",); } Err(err) => { log::error!( - "Failed to disconnect peer {} after invalid InstantLock enforcement ({}): {}", - addr, - reason, - err + "Failed to disconnect peer {addr} after invalid InstantLock enforcement: {err}" ); } } diff --git a/dash-spv/src/network/mod.rs b/dash-spv/src/network/mod.rs index 89e8bde78..f2bcaa562 100644 --- a/dash-spv/src/network/mod.rs +++ b/dash-spv/src/network/mod.rs @@ -6,9 +6,8 @@ pub mod discovery; pub mod handshake; pub mod manager; pub mod peer; -pub mod persist; pub mod pool; -pub mod reputation; +mod reputation; #[cfg(test)] mod tests; @@ -18,13 +17,14 @@ pub mod mock; use async_trait::async_trait; -use crate::error::NetworkResult; +use crate::{error::NetworkResult, network::reputation::ReputationChangeReason}; use dashcore::network::message::NetworkMessage; use dashcore::BlockHash; pub use handshake::{HandshakeManager, HandshakeState}; pub use manager::PeerNetworkManager; pub use peer::Peer; +pub use reputation::PeerReputation; /// Network manager trait for abstracting network operations. #[async_trait] @@ -130,33 +130,18 @@ pub trait NetworkManager: Send + Sync { /// Default implementation is a no-op for managers without reputation. async fn penalize_last_message_peer( &self, - _score_change: i32, - _reason: &str, + _reason: ReputationChangeReason, ) -> NetworkResult<()> { Ok(()) } /// Convenience: penalize last peer for an invalid ChainLock. - async fn penalize_last_message_peer_invalid_chainlock( - &self, - reason: &str, - ) -> NetworkResult<()> { - self.penalize_last_message_peer( - crate::network::reputation::misbehavior_scores::INVALID_CHAINLOCK, - reason, - ) - .await + async fn penalize_last_message_peer_invalid_chainlock(&self) -> NetworkResult<()> { + self.penalize_last_message_peer(ReputationChangeReason::InvalidChainLock).await } /// Convenience: penalize last peer for an invalid InstantLock. - async fn penalize_last_message_peer_invalid_instantlock( - &self, - reason: &str, - ) -> NetworkResult<()> { - self.penalize_last_message_peer( - crate::network::reputation::misbehavior_scores::INVALID_INSTANTLOCK, - reason, - ) - .await + async fn penalize_last_message_peer_invalid_instantlock(&self) -> NetworkResult<()> { + self.penalize_last_message_peer(ReputationChangeReason::InvalidInstantLock).await } } diff --git a/dash-spv/src/network/persist.rs b/dash-spv/src/network/persist.rs deleted file mode 100644 index 814eedeff..000000000 --- a/dash-spv/src/network/persist.rs +++ /dev/null @@ -1,159 +0,0 @@ -//! Peer persistence for saving and loading known peers - -use dashcore::Network; -use serde::{Deserialize, Serialize}; -use std::path::PathBuf; - -use crate::error::{SpvError as Error, StorageError}; -use crate::storage::io::atomic_write; - -/// Peer persistence for saving and loading known peer addresses -pub struct PeerStore { - network: Network, - path: PathBuf, -} - -#[derive(Serialize, Deserialize)] -struct SavedPeers { - version: u32, - network: String, - peers: Vec, -} - -#[derive(Serialize, Deserialize)] -struct SavedPeer { - address: String, - services: u64, - last_seen: u64, -} - -impl PeerStore { - /// Create a new peer store for the given network - pub fn new(network: Network, data_dir: PathBuf) -> Self { - let filename = format!("peers_{}.json", network); - let path = data_dir.join(filename); - - Self { - network, - path, - } - } - - /// Save peers to disk - pub async fn save_peers( - &self, - peers: &[dashcore::network::address::AddrV2Message], - ) -> Result<(), Error> { - let saved = SavedPeers { - version: 1, - network: format!("{:?}", self.network), - peers: peers - .iter() - .filter_map(|p| { - p.socket_addr().ok().map(|addr| SavedPeer { - address: addr.to_string(), - services: p.services.as_u64(), - last_seen: p.time as u64, - }) - }) - .collect(), - }; - - let json = serde_json::to_string_pretty(&saved) - .map_err(|e| Error::Storage(StorageError::Serialization(e.to_string())))?; - - atomic_write(&self.path, json.as_bytes()).await.map_err(Error::Storage)?; - - log::debug!("Saved {} peers to {:?}", saved.peers.len(), self.path); - Ok(()) - } - - /// Load peers from disk - pub async fn load_peers(&self) -> Result, Error> { - match tokio::fs::read_to_string(&self.path).await { - Ok(json) => { - let saved: SavedPeers = serde_json::from_str(&json).map_err(|e| { - Error::Storage(StorageError::Corruption(format!( - "Failed to parse peers file: {}", - e - ))) - })?; - - // Verify network matches - if saved.network != format!("{:?}", self.network) { - return Err(Error::Storage(StorageError::Corruption(format!( - "Peers file is for network {} but we are on {:?}", - saved.network, self.network - )))); - } - - let addresses: Vec<_> = - saved.peers.iter().filter_map(|p| p.address.parse().ok()).collect(); - - log::info!("Loaded {} peers from {:?}", addresses.len(), self.path); - Ok(addresses) - } - Err(e) if e.kind() == std::io::ErrorKind::NotFound => { - log::debug!("No saved peers file found at {:?}", self.path); - Ok(vec![]) - } - Err(e) => Err(Error::Storage(StorageError::ReadFailed(e.to_string()))), - } - } - - /// Delete the peers file - pub async fn clear(&self) -> Result<(), Error> { - match tokio::fs::remove_file(&self.path).await { - Ok(_) => { - log::info!("Cleared peer store at {:?}", self.path); - Ok(()) - } - Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(()), - Err(e) => Err(Error::Storage(StorageError::WriteFailed(e.to_string()))), - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use dashcore::network::address::{AddrV2, AddrV2Message}; - use dashcore::network::constants::ServiceFlags; - use tempfile::TempDir; - - #[tokio::test] - async fn test_peer_store_save_load() { - let temp_dir = TempDir::new().expect("Failed to create temporary directory for test"); - let store = PeerStore::new(Network::Dash, temp_dir.path().to_path_buf()); - - // Create test peer messages - let addr: std::net::SocketAddr = - "192.168.1.1:9999".parse().expect("Failed to parse test address"); - let msg = AddrV2Message { - time: 1234567890, - services: ServiceFlags::from(1), - addr: AddrV2::Ipv4( - addr.ip().to_string().parse().expect("Failed to parse IPv4 address"), - ), - port: addr.port(), - }; - - // Save peers - store.save_peers(&[msg]).await.expect("Failed to save peers in test"); - - // Load peers - let loaded = store.load_peers().await.expect("Failed to load peers in test"); - assert_eq!(loaded.len(), 1); - assert_eq!(loaded[0], addr); - } - - #[tokio::test] - async fn test_peer_store_empty() { - let temp_dir = TempDir::new().expect("Failed to create temporary directory for test"); - let store = PeerStore::new(Network::Testnet, temp_dir.path().to_path_buf()); - - // Load from non-existent file - let loaded = store.load_peers().await.expect("Failed to load peers from empty store"); - assert!(loaded.is_empty()); - } -} diff --git a/dash-spv/src/network/reputation.rs b/dash-spv/src/network/reputation.rs index 87e6666f3..4d78b3087 100644 --- a/dash-spv/src/network/reputation.rs +++ b/dash-spv/src/network/reputation.rs @@ -5,124 +5,137 @@ //! implements automatic banning for excessive misbehavior, and provides reputation //! decay over time for recovery. -use serde::{Deserialize, Serialize}; +use serde::{Deserialize, Deserializer, Serialize}; use std::collections::HashMap; use std::net::SocketAddr; -use std::sync::Arc; use std::time::{Duration, Instant}; -use tokio::sync::RwLock; -use crate::storage::io::atomic_write; +use crate::storage::{PeerStorage, PersistentPeerStorage}; -/// Maximum misbehavior score before a peer is banned -const MAX_MISBEHAVIOR_SCORE: i32 = 100; +pub enum ReputationChangeReason { + // Negative Changes + InvalidMessage, + InvalidHeader, + Timeout, + InvalidTransaction, + InvalidChainLock, + InvalidInstantLock, -/// Misbehavior score thresholds for different violations -pub mod misbehavior_scores { - /// Invalid message format or protocol violation - pub const INVALID_MESSAGE: i32 = 10; + // Positive changes + LongUptime, - /// Invalid block header - pub const INVALID_HEADER: i32 = 50; + // Other + Other(i32, String), +} - /// Invalid compact filter - pub const INVALID_FILTER: i32 = 25; +impl ReputationChangeReason { + pub fn score(&self) -> i32 { + // This score represents the missbehaviour score change, that means + // the higher the score, the more severe the violation. + match self { + ReputationChangeReason::InvalidMessage => 10, + ReputationChangeReason::InvalidHeader => 50, + ReputationChangeReason::Timeout => 5, + ReputationChangeReason::InvalidTransaction => 20, + ReputationChangeReason::InvalidChainLock => 40, + ReputationChangeReason::InvalidInstantLock => 35, + ReputationChangeReason::LongUptime => -5, + ReputationChangeReason::Other(score, _) => *score, + } + } +} - /// Timeout or slow response - pub const TIMEOUT: i32 = 5; +impl std::fmt::Display for ReputationChangeReason { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + ReputationChangeReason::InvalidMessage => { + write!(f, "Invalid message format or protocol violation") + } + ReputationChangeReason::InvalidHeader => write!(f, "Invalid block header"), + ReputationChangeReason::Timeout => write!(f, "Timeout or slow response"), + ReputationChangeReason::InvalidTransaction => write!(f, "Invalid transaction"), + ReputationChangeReason::InvalidChainLock => write!(f, "Invalid ChainLock"), + ReputationChangeReason::InvalidInstantLock => write!(f, "Invalid InstantLock"), + ReputationChangeReason::LongUptime => write!(f, "Long uptime"), + ReputationChangeReason::Other(_, reason) => write!(f, "{}", reason), + } + } +} - /// Sending unsolicited data - pub const UNSOLICITED_DATA: i32 = 15; +/// Ban duration for misbehaving peers +const BAN_DURATION: Duration = Duration::from_secs(24 * 60 * 60); // 24 hours - /// Invalid transaction - pub const INVALID_TRANSACTION: i32 = 20; +/// Reputation decay interval +const DECAY_INTERVAL: Duration = Duration::from_secs(60 * 60); // 1 hour - /// Invalid masternode list diff - pub const INVALID_MASTERNODE_DIFF: i32 = 30; +/// Amount to decay reputation score per interval +const DECAY_AMOUNT: i32 = 5; - /// Invalid ChainLock - pub const INVALID_CHAINLOCK: i32 = 40; +/// Maximum misbehavior score before a peer is banned +const MAX_MISBEHAVIOR_SCORE: i32 = 100; - /// Invalid InstantLock - pub const INVALID_INSTANTLOCK: i32 = 35; +/// Minimum score (most positive reputation) +const MIN_MISBEHAVIOR_SCORE: i32 = -50; - /// Duplicate message - pub const DUPLICATE_MESSAGE: i32 = 5; +const MAX_BAN_COUNT: u32 = 1000; - /// Connection flood attempt - pub const CONNECTION_FLOOD: i32 = 20; +fn default_instant() -> Instant { + Instant::now() } -/// Positive behavior scores -pub mod positive_scores { - /// Successfully provided valid headers - pub const VALID_HEADERS: i32 = -5; - - /// Successfully provided valid filters - pub const VALID_FILTERS: i32 = -3; - - /// Successfully provided valid block - pub const VALID_BLOCK: i32 = -10; - - /// Fast response time - pub const FAST_RESPONSE: i32 = -2; +fn clamp_peer_score<'de, D>(deserializer: D) -> Result +where + D: Deserializer<'de>, +{ + let mut v = i32::deserialize(deserializer)?; + + if v < MIN_MISBEHAVIOR_SCORE { + log::warn!("Peer has invalid score {v}, clamping to min {MIN_MISBEHAVIOR_SCORE}"); + v = MIN_MISBEHAVIOR_SCORE + } else if v > MAX_MISBEHAVIOR_SCORE { + log::warn!("Peer has invalid score {v}, clamping to max {MAX_MISBEHAVIOR_SCORE}"); + v = MAX_MISBEHAVIOR_SCORE + } - /// Long uptime connection - pub const LONG_UPTIME: i32 = -5; + Ok(v) } -/// Ban duration for misbehaving peers -const BAN_DURATION: Duration = Duration::from_secs(24 * 60 * 60); // 24 hours +fn clamp_peer_ban_count<'de, D>(deserializer: D) -> Result +where + D: Deserializer<'de>, +{ + let mut v = u32::deserialize(deserializer)?; -/// Reputation decay interval -const DECAY_INTERVAL: Duration = Duration::from_secs(60 * 60); // 1 hour - -/// Amount to decay reputation score per interval -const DECAY_AMOUNT: i32 = 5; + if v > MAX_BAN_COUNT { + log::warn!("Peer has excessive ban count {v}, clamping to {MAX_BAN_COUNT}"); + v = MAX_BAN_COUNT + } -/// Minimum score (most positive reputation) -const MIN_SCORE: i32 = -50; + Ok(v) +} /// Peer reputation entry -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct PeerReputation { /// Current misbehavior score - pub score: i32, + #[serde(deserialize_with = "clamp_peer_score")] + score: i32, /// Number of times this peer has been banned - pub ban_count: u32, + #[serde(deserialize_with = "clamp_peer_ban_count")] + ban_count: u32, /// Time when the peer was banned (if currently banned) - pub banned_until: Option, + #[serde(skip)] + banned_until: Option, /// Last time the reputation was updated - pub last_update: Instant, - - /// Total number of positive actions - pub positive_actions: u64, - - /// Total number of negative actions - pub negative_actions: u64, - - /// Connection count - pub connection_attempts: u64, - - /// Successful connection count - pub successful_connections: u64, + #[serde(skip, default = "default_instant")] + last_update: Instant, /// Last connection time - pub last_connection: Option, -} - -// Custom serialization for PeerReputation -#[derive(Serialize, Deserialize)] -struct SerializedPeerReputation { - score: i32, - ban_count: u32, - positive_actions: u64, - negative_actions: u64, - connection_attempts: u64, - successful_connections: u64, + #[serde(skip)] + last_connection: Option, } impl Default for PeerReputation { @@ -131,24 +144,18 @@ impl Default for PeerReputation { score: 0, ban_count: 0, banned_until: None, - last_update: Instant::now(), - positive_actions: 0, - negative_actions: 0, - connection_attempts: 0, - successful_connections: 0, + last_update: default_instant(), last_connection: None, } } } impl PeerReputation { - /// Check if the peer is currently banned - pub fn is_banned(&self) -> bool { + fn is_banned(&self) -> bool { self.banned_until.is_some_and(|until| Instant::now() < until) } - /// Get remaining ban time - pub fn ban_time_remaining(&self) -> Option { + fn ban_time_remaining(&self) -> Option { self.banned_until.and_then(|until| { let now = Instant::now(); if now < until { @@ -160,7 +167,7 @@ impl PeerReputation { } /// Apply reputation decay - pub fn apply_decay(&mut self) { + fn apply_decay(&mut self) { let now = Instant::now(); let elapsed = now - self.last_update; @@ -171,7 +178,7 @@ impl PeerReputation { // Cap at a reasonable maximum to avoid excessive decay let intervals_i32 = intervals.min(i32::MAX as u64) as i32; let decay = intervals_i32.saturating_mul(DECAY_AMOUNT); - self.score = (self.score - decay).max(MIN_SCORE); + self.score = (self.score - decay).max(MIN_MISBEHAVIOR_SCORE); self.last_update = now; } @@ -182,52 +189,35 @@ impl PeerReputation { } } -/// Reputation change event -#[derive(Debug, Clone)] -pub struct ReputationEvent { - pub peer: SocketAddr, - pub change: i32, - pub reason: String, - pub timestamp: Instant, -} - -/// Peer reputation manager pub struct PeerReputationManager { - /// Reputation data for each peer - reputations: Arc>>, + reputations: HashMap, +} - /// Recent reputation events for monitoring - recent_events: Arc>>, +impl PeerReputationManager { + pub async fn load_or_new(storage: &PersistentPeerStorage) -> Self { + let mut reputations = + storage.load_peers_reputation().await.unwrap_or_else(|_| HashMap::new()); - /// Maximum number of events to keep - max_events: usize, -} + log::info!("Loaded reputation data for {} peers", reputations.len()); -impl Default for PeerReputationManager { - fn default() -> Self { - Self::new() - } -} + for (_, reputation) in reputations.iter_mut() { + if reputation.ban_count > 0 { + reputation.score = reputation.score.max(50); // Start with higher score for previously banned peers + } + } -impl PeerReputationManager { - /// Create a new reputation manager - pub fn new() -> Self { Self { - reputations: Arc::new(RwLock::new(HashMap::new())), - recent_events: Arc::new(RwLock::new(Vec::new())), - max_events: 1000, + reputations, } } /// Update peer reputation pub async fn update_reputation( - &self, + &mut self, peer: SocketAddr, - score_change: i32, - reason: &str, + reason: ReputationChangeReason, ) -> bool { - let mut reputations = self.reputations.write().await; - let reputation = reputations.entry(peer).or_default(); + let reputation = self.reputations.entry(peer).or_default(); // Apply decay first reputation.apply_decay(); @@ -235,14 +225,7 @@ impl PeerReputationManager { // Update score let old_score = reputation.score; reputation.score = - (reputation.score + score_change).clamp(MIN_SCORE, MAX_MISBEHAVIOR_SCORE); - - // Track positive/negative actions - if score_change > 0 { - reputation.negative_actions += 1; - } else if score_change < 0 { - reputation.positive_actions += 1; - } + (reputation.score + reason.score()).clamp(MIN_MISBEHAVIOR_SCORE, MAX_MISBEHAVIOR_SCORE); // Check if peer should be banned let should_ban = reputation.score >= MAX_MISBEHAVIOR_SCORE && !reputation.is_banned(); @@ -259,46 +242,23 @@ impl PeerReputationManager { } // Log significant changes - if score_change.abs() >= 10 || should_ban { + if reason.score().abs() >= 10 || should_ban { log::info!( "Peer {} reputation changed: {} -> {} (change: {}, reason: {})", peer, old_score, reputation.score, - score_change, + reason.score(), reason ); } - // Record event - let event = ReputationEvent { - peer, - change: score_change, - reason: reason.to_string(), - timestamp: Instant::now(), - }; - - drop(reputations); // Release lock before recording event - self.record_event(event).await; - should_ban } - /// Record a reputation event - async fn record_event(&self, event: ReputationEvent) { - let mut events = self.recent_events.write().await; - events.push(event); - - // Keep only recent events - if events.len() > self.max_events { - let drain_count = events.len() - self.max_events; - events.drain(0..drain_count); - } - } - /// Check if a peer is banned - pub async fn is_banned(&self, peer: &SocketAddr) -> bool { - let mut reputations = self.reputations.write().await; + pub async fn is_banned(&mut self, peer: &SocketAddr) -> bool { + let reputations = &mut self.reputations; if let Some(reputation) = reputations.get_mut(peer) { reputation.apply_decay(); reputation.is_banned() @@ -307,70 +267,33 @@ impl PeerReputationManager { } } - /// Get peer reputation score - pub async fn get_score(&self, peer: &SocketAddr) -> i32 { - let mut reputations = self.reputations.write().await; - if let Some(reputation) = reputations.get_mut(peer) { - reputation.apply_decay(); - reputation.score - } else { - 0 - } - } - /// Temporarily ban a peer for a specified duration, regardless of score. /// This can be used for critical protocol violations (e.g., invalid ChainLocks). - pub async fn temporary_ban_peer(&self, peer: SocketAddr, duration: Duration, reason: &str) { - let mut reputations = self.reputations.write().await; + pub async fn temporary_ban_peer(&mut self, peer: SocketAddr, duration: Duration) { + let reputations = &mut self.reputations; let reputation = reputations.entry(peer).or_default(); reputation.banned_until = Some(Instant::now() + duration); reputation.ban_count += 1; log::warn!( - "Peer {} temporarily banned for {:?} (ban #{}, reason: {})", + "Peer {} temporarily banned for {:?} (ban #{})", peer, duration, reputation.ban_count, - reason ); } /// Record a connection attempt - pub async fn record_connection_attempt(&self, peer: SocketAddr) { - let mut reputations = self.reputations.write().await; + pub async fn record_connection_attempt(&mut self, peer: SocketAddr) { + let reputations = &mut self.reputations; let reputation = reputations.entry(peer).or_default(); - reputation.connection_attempts += 1; reputation.last_connection = Some(Instant::now()); } - /// Record a successful connection - pub async fn record_successful_connection(&self, peer: SocketAddr) { - let mut reputations = self.reputations.write().await; - let reputation = reputations.entry(peer).or_default(); - reputation.successful_connections += 1; - } - - /// Get all peer reputations - pub async fn get_all_reputations(&self) -> HashMap { - let mut reputations = self.reputations.write().await; - - // Apply decay to all peers - for reputation in reputations.values_mut() { - reputation.apply_decay(); - } - - reputations.clone() - } - - /// Get recent reputation events - pub async fn get_recent_events(&self) -> Vec { - self.recent_events.read().await.clone() - } - /// Clear banned status for a peer (admin function) - pub async fn unban_peer(&self, peer: &SocketAddr) { - let mut reputations = self.reputations.write().await; + pub async fn unban_peer(&mut self, peer: &SocketAddr) { + let reputations = &mut self.reputations; if let Some(reputation) = reputations.get_mut(peer) { reputation.banned_until = None; reputation.score = reputation.score.min(MAX_MISBEHAVIOR_SCORE - 10); @@ -378,178 +301,13 @@ impl PeerReputationManager { } } - /// Reset reputation for a peer - pub async fn reset_reputation(&self, peer: &SocketAddr) { - let mut reputations = self.reputations.write().await; - reputations.remove(peer); - log::info!("Reset reputation for peer {}", peer); - } - - /// Get peers sorted by reputation (best first) - pub async fn get_peers_by_reputation(&self) -> Vec<(SocketAddr, i32)> { - let mut reputations = self.reputations.write().await; - - // Apply decay and collect scores - let mut peer_scores: Vec<(SocketAddr, i32)> = reputations - .iter_mut() - .map(|(addr, rep)| { - rep.apply_decay(); - (*addr, rep.score) - }) - .filter(|(_, score)| *score < MAX_MISBEHAVIOR_SCORE) // Exclude banned peers - .collect(); - - // Sort by score (lower is better) - peer_scores.sort_by_key(|(_, score)| *score); - - peer_scores - } - - /// Save reputation data to persistent storage - pub async fn save_to_storage(&self, path: &std::path::Path) -> std::io::Result<()> { - let reputations = self.reputations.read().await; - - // Convert to serializable format - let data: Vec<(SocketAddr, SerializedPeerReputation)> = reputations - .iter() - .map(|(addr, rep)| { - let serialized = SerializedPeerReputation { - score: rep.score, - ban_count: rep.ban_count, - positive_actions: rep.positive_actions, - negative_actions: rep.negative_actions, - connection_attempts: rep.connection_attempts, - successful_connections: rep.successful_connections, - }; - (*addr, serialized) - }) - .collect(); - - let json = serde_json::to_string_pretty(&data)?; - atomic_write(path, json.as_bytes()).await.map_err(std::io::Error::other) - } - - /// Load reputation data from persistent storage - pub async fn load_from_storage(&self, path: &std::path::Path) -> std::io::Result<()> { - if !path.exists() { - return Ok(()); - } - - let json = tokio::fs::read_to_string(path).await?; - let data: Vec<(SocketAddr, SerializedPeerReputation)> = serde_json::from_str(&json)?; - - let mut reputations = self.reputations.write().await; - let mut loaded_count = 0; - let mut skipped_count = 0; - - for (addr, serialized) in data { - // Validate score is within expected range - let score = if serialized.score < MIN_SCORE { - log::warn!( - "Peer {} has invalid score {} (below minimum), clamping to {}", - addr, - serialized.score, - MIN_SCORE - ); - MIN_SCORE - } else if serialized.score > MAX_MISBEHAVIOR_SCORE { - log::warn!( - "Peer {} has invalid score {} (above maximum), clamping to {}", - addr, - serialized.score, - MAX_MISBEHAVIOR_SCORE - ); - MAX_MISBEHAVIOR_SCORE - } else { - serialized.score - }; - - // Validate ban count is reasonable (max 1000 bans) - const MAX_BAN_COUNT: u32 = 1000; - let ban_count = if serialized.ban_count > MAX_BAN_COUNT { - log::warn!( - "Peer {} has excessive ban count {}, clamping to {}", - addr, - serialized.ban_count, - MAX_BAN_COUNT - ); - MAX_BAN_COUNT - } else { - serialized.ban_count - }; - - // Validate action counts are reasonable (max 1 million actions) - const MAX_ACTION_COUNT: u64 = 1_000_000; - let positive_actions = serialized.positive_actions.min(MAX_ACTION_COUNT); - let negative_actions = serialized.negative_actions.min(MAX_ACTION_COUNT); - let connection_attempts = serialized.connection_attempts.min(MAX_ACTION_COUNT); - let successful_connections = serialized.successful_connections.min(MAX_ACTION_COUNT); - - // Validate successful connections don't exceed attempts - let successful_connections = successful_connections.min(connection_attempts); - - // Skip entry if data appears corrupted - if positive_actions == MAX_ACTION_COUNT || negative_actions == MAX_ACTION_COUNT { - log::warn!("Skipping peer {} with potentially corrupted action counts", addr); - skipped_count += 1; - continue; - } - - let rep = PeerReputation { - score, - ban_count, - banned_until: None, - last_update: Instant::now(), - positive_actions, - negative_actions, - connection_attempts, - successful_connections, - last_connection: None, - }; - - // Apply initial decay based on ban count - let mut rep = rep; - if rep.ban_count > 0 { - rep.score = rep.score.max(50); // Start with higher score for previously banned peers - } - - reputations.insert(addr, rep); - loaded_count += 1; - } - - log::info!( - "Loaded reputation data for {} peers (skipped {} corrupted entries)", - loaded_count, - skipped_count - ); - Ok(()) - } -} - -/// Helper trait for reputation-aware peer selection -pub trait ReputationAware { - /// Select best peers based on reputation - fn select_best_peers( - &self, - available_peers: Vec, - count: usize, - ) -> impl std::future::Future> + Send; - - /// Check if we should connect to a peer based on reputation - fn should_connect_to_peer( - &self, - peer: &SocketAddr, - ) -> impl std::future::Future + Send; -} - -impl ReputationAware for PeerReputationManager { - async fn select_best_peers( - &self, + pub async fn select_best_peers( + &mut self, available_peers: Vec, count: usize, ) -> Vec { let mut peer_scores = Vec::new(); - let mut reputations = self.reputations.write().await; + let reputations = &mut self.reputations; for peer in available_peers { let reputation = reputations.entry(peer).or_default(); @@ -567,12 +325,123 @@ impl ReputationAware for PeerReputationManager { peer_scores.into_iter().take(count).map(|(peer, _)| peer).collect() } - async fn should_connect_to_peer(&self, peer: &SocketAddr) -> bool { + pub async fn should_connect_to_peer(&mut self, peer: &SocketAddr) -> bool { !self.is_banned(peer).await } + + /// Save reputation data to persistent storage + pub async fn save_to_storage( + &mut self, + storage: &PersistentPeerStorage, + ) -> std::io::Result<()> { + storage.save_peers_reputation(&self.reputations).await.map_err(std::io::Error::other) + } } -// Include tests module #[cfg(test)] -#[path = "reputation_tests.rs"] -mod reputation_tests; +mod tests { + use crate::storage::PersistentStorage; + + use super::*; + use std::net::SocketAddr; + + async fn build_peer_reputation_manager() -> PeerReputationManager { + let temp_dir = tempfile::TempDir::new().unwrap(); + let peer_storage = PersistentPeerStorage::open(temp_dir.path()) + .await + .expect("Failed to open PersistentPeerStorage"); + PeerReputationManager::load_or_new(&peer_storage).await + } + + #[tokio::test] + async fn test_basic_reputation_operations() { + let mut manager = build_peer_reputation_manager().await; + let peer: SocketAddr = "127.0.0.1:8333".parse().unwrap(); + + // Initial score should be 0 + assert_eq!(manager.select_best_peers(vec![peer], 1).await[0], peer); + assert_eq!(manager.reputations.get(&peer).expect("Peer not found").score, 0); + + // Test misbehavior + manager.update_reputation(peer, ReputationChangeReason::InvalidMessage).await; + assert_eq!(manager.reputations.get(&peer).expect("Peer not found").score, 10); + } + + #[tokio::test] + async fn test_banning_mechanism() { + let mut manager = build_peer_reputation_manager().await; + let peer: SocketAddr = "192.168.1.1:8333".parse().unwrap(); + + // Accumulate misbehavior + for i in 0..10 { + let banned = + manager.update_reputation(peer, ReputationChangeReason::InvalidMessage).await; + + // Should be banned on the 10th violation (total score = 100) + if i == 9 { + assert!(banned); + } else { + assert!(!banned); + } + } + + assert!(manager.is_banned(&peer).await); + } + + #[tokio::test] + async fn test_reputation_persistence() { + let mut manager = build_peer_reputation_manager().await; + let peer1: SocketAddr = "10.0.0.1:8333".parse().unwrap(); + let peer2: SocketAddr = "10.0.0.2:8333".parse().unwrap(); + + // Set reputations + manager + .update_reputation(peer1, ReputationChangeReason::Other(-10, "Good peer".to_string())) + .await; + manager + .update_reputation(peer2, ReputationChangeReason::Other(50, "Bad peer".to_string())) + .await; + + // Save and load + let temp_dir = tempfile::TempDir::new().unwrap(); + let peer_storage = PersistentPeerStorage::open(temp_dir.path()) + .await + .expect("Failed to open PersistentPeerStorage"); + manager.save_to_storage(&peer_storage).await.unwrap(); + + let new_manager = PeerReputationManager::load_or_new(&peer_storage).await; + + // Verify scores were preserved + assert_eq!(new_manager.reputations.get(&peer1).expect("Peer not found").score, -10); + assert_eq!(new_manager.reputations.get(&peer2).expect("Peer not found").score, 50); + } + + #[tokio::test] + async fn test_peer_selection() { + let mut manager = build_peer_reputation_manager().await; + + let good_peer: SocketAddr = "1.1.1.1:8333".parse().unwrap(); + let neutral_peer: SocketAddr = "2.2.2.2:8333".parse().unwrap(); + let bad_peer: SocketAddr = "3.3.3.3:8333".parse().unwrap(); + + // Set different reputations + manager + .update_reputation( + good_peer, + ReputationChangeReason::Other(-20, "Very good".to_string()), + ) + .await; + manager + .update_reputation(bad_peer, ReputationChangeReason::Other(80, "Very bad".to_string())) + .await; + // neutral_peer has default score of 0 + + let all_peers = vec![good_peer, neutral_peer, bad_peer]; + let selected = manager.select_best_peers(all_peers, 2).await; + + // Should select good_peer first, then neutral_peer + assert_eq!(selected.len(), 2); + assert_eq!(selected[0], good_peer); + assert_eq!(selected[1], neutral_peer); + } +} diff --git a/dash-spv/src/network/reputation_tests.rs b/dash-spv/src/network/reputation_tests.rs deleted file mode 100644 index 82c8453af..000000000 --- a/dash-spv/src/network/reputation_tests.rs +++ /dev/null @@ -1,113 +0,0 @@ -//! Unit tests for reputation system (in-module tests) - -#[cfg(test)] -mod tests { - use super::super::*; - use std::net::SocketAddr; - - #[tokio::test] - async fn test_basic_reputation_operations() { - let manager = PeerReputationManager::new(); - let peer: SocketAddr = "127.0.0.1:8333".parse().unwrap(); - - // Initial score should be 0 - assert_eq!(manager.get_score(&peer).await, 0); - - // Test misbehavior - manager - .update_reputation(peer, misbehavior_scores::INVALID_MESSAGE, "Test invalid message") - .await; - assert_eq!(manager.get_score(&peer).await, 10); - - // Test positive behavior - manager.update_reputation(peer, positive_scores::VALID_HEADERS, "Test valid headers").await; - assert_eq!(manager.get_score(&peer).await, 5); - } - - #[tokio::test] - async fn test_banning_mechanism() { - let manager = PeerReputationManager::new(); - let peer: SocketAddr = "192.168.1.1:8333".parse().unwrap(); - - // Accumulate misbehavior - for i in 0..10 { - let banned = manager - .update_reputation( - peer, - misbehavior_scores::INVALID_MESSAGE, - &format!("Violation {}", i), - ) - .await; - - // Should be banned on the 10th violation (total score = 100) - if i == 9 { - assert!(banned); - } else { - assert!(!banned); - } - } - - assert!(manager.is_banned(&peer).await); - } - - #[tokio::test] - async fn test_reputation_persistence() { - let manager = PeerReputationManager::new(); - let peer1: SocketAddr = "10.0.0.1:8333".parse().unwrap(); - let peer2: SocketAddr = "10.0.0.2:8333".parse().unwrap(); - - // Set reputations - manager.update_reputation(peer1, -10, "Good peer").await; - manager.update_reputation(peer2, 50, "Bad peer").await; - - // Save and load - let temp_file = tempfile::NamedTempFile::new().unwrap(); - manager.save_to_storage(temp_file.path()).await.unwrap(); - - let new_manager = PeerReputationManager::new(); - new_manager.load_from_storage(temp_file.path()).await.unwrap(); - - // Verify scores were preserved - assert_eq!(new_manager.get_score(&peer1).await, -10); - assert_eq!(new_manager.get_score(&peer2).await, 50); - } - - #[tokio::test] - async fn test_peer_selection() { - let manager = PeerReputationManager::new(); - - let good_peer: SocketAddr = "1.1.1.1:8333".parse().unwrap(); - let neutral_peer: SocketAddr = "2.2.2.2:8333".parse().unwrap(); - let bad_peer: SocketAddr = "3.3.3.3:8333".parse().unwrap(); - - // Set different reputations - manager.update_reputation(good_peer, -20, "Very good").await; - manager.update_reputation(bad_peer, 80, "Very bad").await; - // neutral_peer has default score of 0 - - let all_peers = vec![good_peer, neutral_peer, bad_peer]; - let selected = manager.select_best_peers(all_peers, 2).await; - - // Should select good_peer first, then neutral_peer - assert_eq!(selected.len(), 2); - assert_eq!(selected[0], good_peer); - assert_eq!(selected[1], neutral_peer); - } - - #[tokio::test] - async fn test_connection_tracking() { - let manager = PeerReputationManager::new(); - let peer: SocketAddr = "127.0.0.1:9999".parse().unwrap(); - - // Track connection attempts - manager.record_connection_attempt(peer).await; - manager.record_connection_attempt(peer).await; - manager.record_successful_connection(peer).await; - - let reputations = manager.get_all_reputations().await; - let rep = &reputations[&peer]; - - assert_eq!(rep.connection_attempts, 2); - assert_eq!(rep.successful_connections, 1); - } -} diff --git a/dash-spv/src/storage/blocks.rs b/dash-spv/src/storage/blocks.rs new file mode 100644 index 000000000..430cb17e4 --- /dev/null +++ b/dash-spv/src/storage/blocks.rs @@ -0,0 +1,180 @@ +//! Header storage operations for DiskStorageManager. + +use std::collections::HashMap; +use std::ops::Range; +use std::path::PathBuf; + +use async_trait::async_trait; +use dashcore::block::Header as BlockHeader; +use dashcore::BlockHash; +use tokio::sync::RwLock; + +use crate::error::StorageResult; +use crate::storage::io::atomic_write; +use crate::storage::segments::SegmentCache; +use crate::storage::PersistentStorage; +use crate::StorageError; + +#[async_trait] +pub trait BlockHeaderStorage { + async fn store_headers(&mut self, headers: &[BlockHeader]) -> StorageResult<()>; + + async fn store_headers_at_height( + &mut self, + headers: &[BlockHeader], + height: u32, + ) -> StorageResult<()>; + + async fn load_headers(&self, range: Range) -> StorageResult>; + + async fn get_header(&self, height: u32) -> StorageResult> { + if let Some(tip_height) = self.get_tip_height().await { + if height > tip_height { + return Ok(None); + } + } else { + return Ok(None); + } + + if let Some(start_height) = self.get_start_height().await { + if height < start_height { + return Ok(None); + } + } else { + return Ok(None); + } + + Ok(self.load_headers(height..height + 1).await?.first().copied()) + } + + async fn get_tip_height(&self) -> Option; + + async fn get_start_height(&self) -> Option; + + async fn get_stored_headers_len(&self) -> u32; + + async fn get_header_height_by_hash( + &self, + hash: &dashcore::BlockHash, + ) -> StorageResult>; +} + +pub struct PersistentBlockHeaderStorage { + block_headers: RwLock>, + header_hash_index: HashMap, +} + +impl PersistentBlockHeaderStorage { + const FOLDER_NAME: &str = "block_headers"; + const INDEX_FILE_NAME: &str = "index.dat"; +} + +#[async_trait] +impl PersistentStorage for PersistentBlockHeaderStorage { + async fn open(storage_path: impl Into + Send) -> StorageResult { + let storage_path = storage_path.into(); + let segments_folder = storage_path.join(Self::FOLDER_NAME); + + let index_path = segments_folder.join(Self::INDEX_FILE_NAME); + + let mut block_headers = SegmentCache::load_or_new(&segments_folder).await?; + + let header_hash_index = match tokio::fs::read(&index_path) + .await + .ok() + .and_then(|content| bincode::deserialize(&content).ok()) + { + Some(index) => index, + _ => { + if segments_folder.exists() { + block_headers.build_block_index_from_segments().await? + } else { + HashMap::new() + } + } + }; + + Ok(Self { + block_headers: RwLock::new(block_headers), + header_hash_index, + }) + } + + async fn persist(&mut self, storage_path: impl Into + Send) -> StorageResult<()> { + let block_headers_folder = storage_path.into().join(Self::FOLDER_NAME); + let index_path = block_headers_folder.join(Self::INDEX_FILE_NAME); + + tokio::fs::create_dir_all(&block_headers_folder).await?; + + self.block_headers.write().await.persist(&block_headers_folder).await; + + let data = bincode::serialize(&self.header_hash_index) + .map_err(|e| StorageError::WriteFailed(format!("Failed to serialize index: {}", e)))?; + + atomic_write(&index_path, &data).await + } +} + +#[async_trait] +impl BlockHeaderStorage for PersistentBlockHeaderStorage { + async fn store_headers(&mut self, headers: &[BlockHeader]) -> StorageResult<()> { + let height = self.block_headers.read().await.next_height(); + self.store_headers_at_height(headers, height).await + } + + async fn store_headers_at_height( + &mut self, + headers: &[BlockHeader], + height: u32, + ) -> StorageResult<()> { + let mut height = height; + + let hashes = headers.iter().map(|header| header.block_hash()).collect::>(); + + self.block_headers.write().await.store_items_at_height(headers, height).await?; + + for hash in hashes { + self.header_hash_index.insert(hash, height); + height += 1; + } + + Ok(()) + } + + async fn load_headers(&self, range: Range) -> StorageResult> { + self.block_headers.write().await.get_items(range).await + } + + async fn get_tip_height(&self) -> Option { + self.block_headers.read().await.tip_height() + } + + async fn get_start_height(&self) -> Option { + self.block_headers.read().await.start_height() + } + + async fn get_stored_headers_len(&self) -> u32 { + let block_headers = self.block_headers.read().await; + + let start_height = if let Some(start_height) = block_headers.start_height() { + start_height + } else { + return 0; + }; + + let end_height = if let Some(end_height) = block_headers.tip_height() { + end_height + } else { + return 0; + }; + + end_height - start_height + 1 + } + + async fn get_header_height_by_hash( + &self, + hash: &dashcore::BlockHash, + ) -> StorageResult> { + Ok(self.header_hash_index.get(hash).copied()) + } +} diff --git a/dash-spv/src/storage/chainstate.rs b/dash-spv/src/storage/chainstate.rs new file mode 100644 index 000000000..c6c3b69af --- /dev/null +++ b/dash-spv/src/storage/chainstate.rs @@ -0,0 +1,101 @@ +use std::path::PathBuf; + +use async_trait::async_trait; + +use crate::{ + error::StorageResult, + storage::{io::atomic_write, PersistentStorage}, + ChainState, +}; + +#[async_trait] +pub trait ChainStateStorage { + async fn store_chain_state(&mut self, state: &ChainState) -> StorageResult<()>; + + async fn load_chain_state(&self) -> StorageResult>; +} + +pub struct PersistentChainStateStorage { + storage_path: PathBuf, +} + +impl PersistentChainStateStorage { + const FOLDER_NAME: &str = "chainstate"; + const FILE_NAME: &str = "chainstate.json"; +} + +#[async_trait] +impl PersistentStorage for PersistentChainStateStorage { + async fn open(storage_path: impl Into + Send) -> StorageResult { + Ok(PersistentChainStateStorage { + storage_path: storage_path.into(), + }) + } + + async fn persist(&mut self, _storage_path: impl Into + Send) -> StorageResult<()> { + // Current implementation persists data everytime data is stored + Ok(()) + } +} + +#[async_trait] +impl ChainStateStorage for PersistentChainStateStorage { + async fn store_chain_state(&mut self, state: &ChainState) -> StorageResult<()> { + let state_data = serde_json::json!({ + "last_chainlock_height": state.last_chainlock_height, + "last_chainlock_hash": state.last_chainlock_hash, + "current_filter_tip": state.current_filter_tip, + "last_masternode_diff_height": state.last_masternode_diff_height, + "sync_base_height": state.sync_base_height, + }); + + let chainstate_folder = self.storage_path.join(Self::FOLDER_NAME); + let path = chainstate_folder.join(Self::FILE_NAME); + + tokio::fs::create_dir_all(chainstate_folder).await?; + + let json = state_data.to_string(); + atomic_write(&path, json.as_bytes()).await?; + + Ok(()) + } + + async fn load_chain_state(&self) -> StorageResult> { + let path = self.storage_path.join(Self::FOLDER_NAME).join(Self::FILE_NAME); + if !path.exists() { + return Ok(None); + } + + let content = tokio::fs::read_to_string(path).await?; + let value: serde_json::Value = serde_json::from_str(&content).map_err(|e| { + crate::error::StorageError::Serialization(format!("Failed to parse chain state: {}", e)) + })?; + + let state = ChainState { + last_chainlock_height: value + .get("last_chainlock_height") + .and_then(|v| v.as_u64()) + .map(|h| h as u32), + last_chainlock_hash: value + .get("last_chainlock_hash") + .and_then(|v| v.as_str()) + .and_then(|s| s.parse().ok()), + current_filter_tip: value + .get("current_filter_tip") + .and_then(|v| v.as_str()) + .and_then(|s| s.parse().ok()), + masternode_engine: None, + last_masternode_diff_height: value + .get("last_masternode_diff_height") + .and_then(|v| v.as_u64()) + .map(|h| h as u32), + sync_base_height: value + .get("sync_base_height") + .and_then(|v| v.as_u64()) + .map(|h| h as u32) + .unwrap_or(0), + }; + + Ok(Some(state)) + } +} diff --git a/dash-spv/src/storage/filters.rs b/dash-spv/src/storage/filters.rs new file mode 100644 index 000000000..0e4916805 --- /dev/null +++ b/dash-spv/src/storage/filters.rs @@ -0,0 +1,141 @@ +use std::{ops::Range, path::PathBuf}; + +use async_trait::async_trait; +use dashcore::hash_types::FilterHeader; +use tokio::sync::RwLock; + +use crate::{ + error::StorageResult, + storage::{segments::SegmentCache, PersistentStorage}, +}; + +#[async_trait] +pub trait FilterHeaderStorage { + async fn store_filter_headers(&mut self, headers: &[FilterHeader]) -> StorageResult<()>; + + async fn load_filter_headers(&self, range: Range) -> StorageResult>; + + async fn get_filter_header(&self, height: u32) -> StorageResult> { + if let Some(tip_height) = self.get_filter_tip_height().await? { + if height > tip_height { + return Ok(None); + } + } else { + return Ok(None); + } + + if let Some(start_height) = self.get_filter_start_height().await { + if height < start_height { + return Ok(None); + } + } else { + return Ok(None); + } + + Ok(self.load_filter_headers(height..height + 1).await?.first().copied()) + } + + async fn get_filter_tip_height(&self) -> StorageResult>; + + async fn get_filter_start_height(&self) -> Option; +} + +#[async_trait] +pub trait FilterStorage { + async fn store_filter(&mut self, height: u32, filter: &[u8]) -> StorageResult<()>; + + async fn load_filters(&self, range: Range) -> StorageResult>>; +} + +pub struct PersistentFilterHeaderStorage { + filter_headers: RwLock>, +} + +impl PersistentFilterHeaderStorage { + const FOLDER_NAME: &str = "filter_headers"; +} + +#[async_trait] +impl PersistentStorage for PersistentFilterHeaderStorage { + async fn open(storage_path: impl Into + Send) -> StorageResult { + let storage_path = storage_path.into(); + let segments_folder = storage_path.join(Self::FOLDER_NAME); + + let filter_headers = SegmentCache::load_or_new(segments_folder).await?; + + Ok(Self { + filter_headers: RwLock::new(filter_headers), + }) + } + + async fn persist(&mut self, base_path: impl Into + Send) -> StorageResult<()> { + let filter_headers_folder = base_path.into().join(Self::FOLDER_NAME); + + tokio::fs::create_dir_all(&filter_headers_folder).await?; + + self.filter_headers.write().await.persist(&filter_headers_folder).await; + Ok(()) + } +} + +#[async_trait] +impl FilterHeaderStorage for PersistentFilterHeaderStorage { + async fn store_filter_headers(&mut self, headers: &[FilterHeader]) -> StorageResult<()> { + self.filter_headers.write().await.store_items(headers).await + } + + async fn load_filter_headers(&self, range: Range) -> StorageResult> { + self.filter_headers.write().await.get_items(range).await + } + + async fn get_filter_tip_height(&self) -> StorageResult> { + Ok(self.filter_headers.read().await.tip_height()) + } + + async fn get_filter_start_height(&self) -> Option { + self.filter_headers.read().await.start_height() + } +} + +pub struct PersistentFilterStorage { + filters: RwLock>>, +} + +impl PersistentFilterStorage { + const FOLDER_NAME: &str = "filters"; +} + +#[async_trait] +impl PersistentStorage for PersistentFilterStorage { + async fn open(storage_path: impl Into + Send) -> StorageResult { + let storage_path = storage_path.into(); + let filters_folder = storage_path.join(Self::FOLDER_NAME); + + let filters = SegmentCache::load_or_new(filters_folder).await?; + + Ok(Self { + filters: RwLock::new(filters), + }) + } + + async fn persist(&mut self, storage_path: impl Into + Send) -> StorageResult<()> { + let storage_path = storage_path.into(); + let filters_folder = storage_path.join(Self::FOLDER_NAME); + + tokio::fs::create_dir_all(&filters_folder).await?; + + self.filters.write().await.persist(&filters_folder).await; + Ok(()) + } +} + +#[async_trait] +impl FilterStorage for PersistentFilterStorage { + async fn store_filter(&mut self, height: u32, filter: &[u8]) -> StorageResult<()> { + self.filters.write().await.store_items_at_height(&[filter.to_vec()], height).await + } + + async fn load_filters(&self, range: Range) -> StorageResult>> { + self.filters.write().await.get_items(range).await + } +} diff --git a/dash-spv/src/storage/headers.rs b/dash-spv/src/storage/headers.rs deleted file mode 100644 index 0b0e87b09..000000000 --- a/dash-spv/src/storage/headers.rs +++ /dev/null @@ -1,77 +0,0 @@ -//! Header storage operations for DiskStorageManager. - -use std::collections::HashMap; -use std::path::Path; - -use dashcore::block::Header as BlockHeader; -use dashcore::BlockHash; - -use crate::error::StorageResult; -use crate::storage::io::atomic_write; -use crate::StorageError; - -use super::manager::DiskStorageManager; - -impl DiskStorageManager { - pub async fn store_headers_at_height( - &mut self, - headers: &[BlockHeader], - mut height: u32, - ) -> StorageResult<()> { - let hashes = headers.iter().map(|header| header.block_hash()).collect::>(); - - self.block_headers.write().await.store_items(headers, height, self).await?; - - // Update reverse index - let mut reverse_index = self.header_hash_index.write().await; - - for hash in hashes { - reverse_index.insert(hash, height); - height += 1; - } - - // Release locks before saving (to avoid deadlocks during background saves) - drop(reverse_index); - - Ok(()) - } - - pub async fn store_headers(&mut self, headers: &[BlockHeader]) -> StorageResult<()> { - let height = self.block_headers.read().await.next_height(); - self.store_headers_at_height(headers, height).await - } - - /// Get header height by hash. - pub async fn get_header_height_by_hash(&self, hash: &BlockHash) -> StorageResult> { - Ok(self.header_hash_index.read().await.get(hash).copied()) - } -} - -/// Load index from file, if it fails it tries to build it from block -/// header segments and, if that also fails, it return an empty index. -/// -/// IO and deserialize errors are returned, the empty index is only built -/// if there is no persisted data to recreate it. -pub(super) async fn load_block_index( - manager: &DiskStorageManager, -) -> StorageResult> { - let index_path = manager.base_path.join("headers/index.dat"); - - if let Ok(content) = tokio::fs::read(&index_path).await { - bincode::deserialize(&content) - .map_err(|e| StorageError::ReadFailed(format!("Failed to deserialize index: {}", e))) - } else { - manager.block_headers.write().await.build_block_index_from_segments().await - } -} - -/// Save index to disk. -pub(super) async fn save_index_to_disk( - path: &Path, - index: &HashMap, -) -> StorageResult<()> { - let data = bincode::serialize(index) - .map_err(|e| StorageError::WriteFailed(format!("Failed to serialize index: {}", e)))?; - - atomic_write(path, &data).await -} diff --git a/dash-spv/src/storage/manager.rs b/dash-spv/src/storage/manager.rs deleted file mode 100644 index 9f13cda28..000000000 --- a/dash-spv/src/storage/manager.rs +++ /dev/null @@ -1,253 +0,0 @@ -//! Core DiskStorageManager struct and background worker implementation. - -use std::collections::HashMap; -use std::path::PathBuf; -use std::sync::Arc; -use tokio::sync::{mpsc, RwLock}; - -use dashcore::{block::Header as BlockHeader, hash_types::FilterHeader, BlockHash, Txid}; - -use crate::error::{StorageError, StorageResult}; -use crate::storage::headers::load_block_index; -use crate::storage::segments::SegmentCache; -use crate::types::{MempoolState, UnconfirmedTransaction}; - -use super::lockfile::LockFile; - -/// Commands for the background worker -#[derive(Debug, Clone)] -pub(super) enum WorkerCommand { - SaveBlockHeaderSegmentCache { - segment_id: u32, - }, - SaveFilterHeaderSegmentCache { - segment_id: u32, - }, - SaveFilterSegmentCache { - segment_id: u32, - }, - SaveIndex { - index: HashMap, - }, - Shutdown, -} - -/// Disk-based storage manager with segmented files and async background saving. -pub struct DiskStorageManager { - pub(super) base_path: PathBuf, - - // Segmented header storage - pub(super) block_headers: Arc>>, - pub(super) filter_headers: Arc>>, - pub(super) filters: Arc>>>, - - // Reverse index for O(1) lookups - pub(super) header_hash_index: Arc>>, - - // Background worker - pub(super) worker_tx: Option>, - pub(super) worker_handle: Option>, - - // Index save tracking to avoid redundant saves - pub(super) last_index_save_count: Arc>, - - // Mempool storage - pub(super) mempool_transactions: Arc>>, - pub(super) mempool_state: Arc>>, - - // Lock file to prevent concurrent access from multiple processes. - _lock_file: LockFile, -} - -impl DiskStorageManager { - pub async fn new(base_path: PathBuf) -> StorageResult { - use std::fs; - - // Create directories if they don't exist - fs::create_dir_all(&base_path) - .map_err(|e| StorageError::WriteFailed(format!("Failed to create directory: {}", e)))?; - - // Acquire exclusive lock on the data directory - let lock_file = LockFile::new(base_path.join(".lock"))?; - - let headers_dir = base_path.join("headers"); - let filters_dir = base_path.join("filters"); - let state_dir = base_path.join("state"); - - fs::create_dir_all(&headers_dir).map_err(|e| { - StorageError::WriteFailed(format!("Failed to create headers directory: {}", e)) - })?; - fs::create_dir_all(&filters_dir).map_err(|e| { - StorageError::WriteFailed(format!("Failed to create filters directory: {}", e)) - })?; - fs::create_dir_all(&state_dir).map_err(|e| { - StorageError::WriteFailed(format!("Failed to create state directory: {}", e)) - })?; - - let mut storage = Self { - base_path: base_path.clone(), - block_headers: Arc::new(RwLock::new( - SegmentCache::load_or_new(base_path.clone()).await?, - )), - filter_headers: Arc::new(RwLock::new( - SegmentCache::load_or_new(base_path.clone()).await?, - )), - filters: Arc::new(RwLock::new(SegmentCache::load_or_new(base_path.clone()).await?)), - header_hash_index: Arc::new(RwLock::new(HashMap::new())), - worker_tx: None, - worker_handle: None, - last_index_save_count: Arc::new(RwLock::new(0)), - mempool_transactions: Arc::new(RwLock::new(HashMap::new())), - mempool_state: Arc::new(RwLock::new(None)), - _lock_file: lock_file, - }; - - // Load chain state to get sync_base_height - if let Ok(Some(state)) = storage.load_chain_state().await { - tracing::debug!("Loaded sync_base_height: {}", state.sync_base_height); - } - - // Start background worker - storage.start_worker().await; - - // Rebuild index - let block_index = match load_block_index(&storage).await { - Ok(index) => index, - Err(e) => { - tracing::error!( - "An unexpected IO or deserialization error didn't allow the block index to be built: {}", - e - ); - HashMap::new() - } - }; - storage.header_hash_index = Arc::new(RwLock::new(block_index)); - - Ok(storage) - } - - #[cfg(test)] - pub async fn with_temp_dir() -> StorageResult { - use tempfile::TempDir; - - let temp_dir = TempDir::new()?; - Self::new(temp_dir.path().into()).await - } - - /// Start the background worker - pub(super) async fn start_worker(&mut self) { - let (worker_tx, mut worker_rx) = mpsc::channel::(100); - - let worker_base_path = self.base_path.clone(); - let base_path = self.base_path.clone(); - - let block_headers = Arc::clone(&self.block_headers); - let filter_headers = Arc::clone(&self.filter_headers); - let cfilters = Arc::clone(&self.filters); - - let worker_handle = tokio::spawn(async move { - while let Some(cmd) = worker_rx.recv().await { - match cmd { - WorkerCommand::SaveBlockHeaderSegmentCache { - segment_id, - } => { - let mut cache = block_headers.write().await; - let segment = match cache.get_segment_mut(&segment_id).await { - Ok(segment) => segment, - Err(e) => { - eprintln!("Failed to get segment {}: {}", segment_id, e); - continue; - } - }; - - match segment.persist(&base_path).await { - Ok(()) => { - tracing::trace!( - "Background worker completed saving header segment {}", - segment_id - ); - } - Err(e) => { - eprintln!("Failed to save segment {}: {}", segment_id, e); - } - } - } - WorkerCommand::SaveFilterHeaderSegmentCache { - segment_id, - } => { - let mut cache = filter_headers.write().await; - let segment = match cache.get_segment_mut(&segment_id).await { - Ok(segment) => segment, - Err(e) => { - eprintln!("Failed to get segment {}: {}", segment_id, e); - continue; - } - }; - - match segment.persist(&base_path).await { - Ok(()) => { - tracing::trace!( - "Background worker completed saving header segment {}", - segment_id - ); - } - Err(e) => { - eprintln!("Failed to save segment {}: {}", segment_id, e); - } - } - } - WorkerCommand::SaveFilterSegmentCache { - segment_id, - } => { - let mut cache = cfilters.write().await; - let segment = match cache.get_segment_mut(&segment_id).await { - Ok(segment) => segment, - Err(e) => { - eprintln!("Failed to get segment {}: {}", segment_id, e); - continue; - } - }; - - match segment.persist(&base_path).await { - Ok(()) => { - tracing::trace!( - "Background worker completed saving filter segment {}", - segment_id - ); - } - Err(e) => { - eprintln!("Failed to save segment {}: {}", segment_id, e); - } - } - } - WorkerCommand::SaveIndex { - index, - } => { - let path = worker_base_path.join("headers/index.dat"); - if let Err(e) = super::headers::save_index_to_disk(&path, &index).await { - eprintln!("Failed to save index: {}", e); - } else { - tracing::trace!("Background worker completed saving index"); - } - } - WorkerCommand::Shutdown => { - break; - } - } - } - }); - - self.worker_tx = Some(worker_tx); - self.worker_handle = Some(worker_handle); - } - - /// Stop the background worker without forcing a save. - pub(super) async fn stop_worker(&mut self) { - if let Some(tx) = self.worker_tx.take() { - let _ = tx.send(WorkerCommand::Shutdown).await; - } - if let Some(handle) = self.worker_handle.take() { - let _ = handle.await; - } - } -} diff --git a/dash-spv/src/storage/masternode.rs b/dash-spv/src/storage/masternode.rs new file mode 100644 index 000000000..d7ec1dd9f --- /dev/null +++ b/dash-spv/src/storage/masternode.rs @@ -0,0 +1,76 @@ +use std::path::PathBuf; + +use async_trait::async_trait; + +use crate::{ + error::StorageResult, + storage::{io::atomic_write, MasternodeState, PersistentStorage}, +}; + +#[async_trait] +pub trait MasternodeStateStorage { + async fn store_masternode_state(&mut self, state: &MasternodeState) -> StorageResult<()>; + + async fn load_masternode_state(&self) -> StorageResult>; +} + +pub struct PersistentMasternodeStateStorage { + storage_path: PathBuf, +} + +impl PersistentMasternodeStateStorage { + const FOLDER_NAME: &str = "masternodestate"; + const MASTERNODE_FILE_NAME: &str = "masternodestate.json"; +} + +#[async_trait] +impl PersistentStorage for PersistentMasternodeStateStorage { + async fn open(storage_path: impl Into + Send) -> StorageResult { + Ok(PersistentMasternodeStateStorage { + storage_path: storage_path.into(), + }) + } + + async fn persist(&mut self, _storage_path: impl Into + Send) -> StorageResult<()> { + // Current implementation persists data everytime data is stored + Ok(()) + } +} + +#[async_trait] +impl MasternodeStateStorage for PersistentMasternodeStateStorage { + async fn store_masternode_state(&mut self, state: &MasternodeState) -> StorageResult<()> { + let masternodestate_folder = self.storage_path.join(Self::FOLDER_NAME); + let path = masternodestate_folder.join(Self::MASTERNODE_FILE_NAME); + + tokio::fs::create_dir_all(masternodestate_folder).await?; + + let json = serde_json::to_string_pretty(state).map_err(|e| { + crate::error::StorageError::Serialization(format!( + "Failed to serialize masternode state: {}", + e + )) + })?; + + atomic_write(&path, json.as_bytes()).await?; + Ok(()) + } + + async fn load_masternode_state(&self) -> StorageResult> { + let path = self.storage_path.join(Self::FOLDER_NAME).join(Self::MASTERNODE_FILE_NAME); + + if !path.exists() { + return Ok(None); + } + + let content = tokio::fs::read_to_string(path).await?; + let state = serde_json::from_str(&content).map_err(|e| { + crate::error::StorageError::Serialization(format!( + "Failed to deserialize masternode state: {}", + e + )) + })?; + + Ok(Some(state)) + } +} diff --git a/dash-spv/src/storage/metadata.rs b/dash-spv/src/storage/metadata.rs new file mode 100644 index 000000000..7707e41ab --- /dev/null +++ b/dash-spv/src/storage/metadata.rs @@ -0,0 +1,62 @@ +use std::path::PathBuf; + +use async_trait::async_trait; + +use crate::{ + error::StorageResult, + storage::{io::atomic_write, PersistentStorage}, +}; + +#[async_trait] +pub trait MetadataStorage { + async fn store_metadata(&mut self, key: &str, value: &[u8]) -> StorageResult<()>; + + async fn load_metadata(&self, key: &str) -> StorageResult>>; +} + +pub struct PersistentMetadataStorage { + storage_path: PathBuf, +} + +impl PersistentMetadataStorage { + const FOLDER_NAME: &str = "metadata"; +} + +#[async_trait] +impl PersistentStorage for PersistentMetadataStorage { + async fn open(storage_path: impl Into + Send) -> StorageResult { + Ok(PersistentMetadataStorage { + storage_path: storage_path.into(), + }) + } + + async fn persist(&mut self, _storage_path: impl Into + Send) -> StorageResult<()> { + // Current implementation persists data everytime data is stored + Ok(()) + } +} + +#[async_trait] +impl MetadataStorage for PersistentMetadataStorage { + async fn store_metadata(&mut self, key: &str, value: &[u8]) -> StorageResult<()> { + let metadata_folder = self.storage_path.join(Self::FOLDER_NAME); + let path = metadata_folder.join(format!("{key}.dat")); + + tokio::fs::create_dir_all(metadata_folder).await?; + + atomic_write(&path, value).await?; + + Ok(()) + } + + async fn load_metadata(&self, key: &str) -> StorageResult>> { + let path = self.storage_path.join(Self::FOLDER_NAME).join(format!("{key}.dat")); + + if !path.exists() { + return Ok(None); + } + + let data = tokio::fs::read(path).await?; + Ok(Some(data)) + } +} diff --git a/dash-spv/src/storage/mod.rs b/dash-spv/src/storage/mod.rs index aa8e0387a..8a052bbe3 100644 --- a/dash-spv/src/storage/mod.rs +++ b/dash-spv/src/storage/mod.rs @@ -1,186 +1,565 @@ //! Storage abstraction for the Dash SPV client. -pub(crate) mod io; - pub mod types; -mod headers; +mod blocks; +mod chainstate; +mod filters; +mod io; mod lockfile; -mod manager; +mod masternode; +mod metadata; +mod peers; mod segments; -mod state; +mod transactions; use async_trait::async_trait; +use dashcore::hash_types::FilterHeader; +use dashcore::{Header as BlockHeader, Txid}; use std::collections::HashMap; use std::ops::Range; - -use dashcore::{block::Header as BlockHeader, hash_types::FilterHeader, Txid}; +use std::path::PathBuf; +use std::sync::Arc; +use std::time::Duration; +use tokio::sync::RwLock; use crate::error::StorageResult; -use crate::types::{ChainState, MempoolState, UnconfirmedTransaction}; +use crate::storage::blocks::PersistentBlockHeaderStorage; +use crate::storage::chainstate::PersistentChainStateStorage; +use crate::storage::filters::{PersistentFilterHeaderStorage, PersistentFilterStorage}; +use crate::storage::lockfile::LockFile; +use crate::storage::masternode::PersistentMasternodeStateStorage; +use crate::storage::metadata::PersistentMetadataStorage; +use crate::storage::transactions::PersistentTransactionStorage; +use crate::types::{MempoolState, UnconfirmedTransaction}; +use crate::ChainState; + +pub use crate::storage::blocks::BlockHeaderStorage; +pub use crate::storage::chainstate::ChainStateStorage; +pub use crate::storage::filters::FilterHeaderStorage; +pub use crate::storage::filters::FilterStorage; +pub use crate::storage::masternode::MasternodeStateStorage; +pub use crate::storage::metadata::MetadataStorage; +pub use crate::storage::peers::{PeerStorage, PersistentPeerStorage}; +pub use crate::storage::transactions::TransactionStorage; -pub use manager::DiskStorageManager; pub use types::*; -/// Storage manager trait for abstracting data persistence. -/// -/// # Thread Safety -/// -/// This trait requires `Send + Sync` bounds to ensure thread safety, but uses `&mut self` -/// for mutation methods. This design choice provides several benefits: -/// -/// 1. **Simplified Implementation**: Storage backends don't need to implement interior -/// mutability patterns (like `Arc>` or `RwLock`) internally. -/// -/// 2. **Performance**: Avoids unnecessary locking overhead when the storage manager -/// is already protected by external synchronization. -/// -/// 3. **Flexibility**: Callers can choose the appropriate synchronization strategy -/// based on their specific use case (e.g., single-threaded, mutex-protected, etc.). -/// -/// ## Usage Pattern -/// -/// The typical usage pattern wraps the storage manager in an `Arc>` or similar: -/// -/// ```rust,no_run -/// # use std::sync::Arc; -/// # use tokio::sync::Mutex; -/// # use dash_spv::storage::DiskStorageManager; -/// # use dashcore::blockdata::block::Header as BlockHeader; -/// # -/// # async fn example() -> Result<(), Box> { -/// let storage: Arc> = Arc::new(Mutex::new(DiskStorageManager::new("./.tmp/example-storage".into()).await?)); -/// let headers: Vec = vec![]; // Your headers here -/// -/// // In async context: -/// let mut guard = storage.lock().await; -/// guard.store_headers(&headers).await?; -/// # Ok(()) -/// # } -/// ``` -/// -/// ## Implementation Requirements -/// -/// Implementations must ensure that: -/// - All operations are atomic at the logical level (e.g., all headers in a batch succeed or fail together) -/// - Read operations are consistent (no partial reads of in-progress writes) -/// - The implementation is safe to move between threads (`Send`) -/// - The implementation can be referenced from multiple threads (`Sync`) -/// -/// Note that the `&mut self` requirement means only one thread can be mutating the storage -/// at a time when using external synchronization, which naturally provides consistency. #[async_trait] -pub trait StorageManager: Send + Sync { - /// Store block headers. - async fn store_headers(&mut self, headers: &[BlockHeader]) -> StorageResult<()>; - - /// Load block headers in the given range. - async fn load_headers(&self, range: Range) -> StorageResult>; - - /// Get a specific header by blockchain height. - async fn get_header(&self, height: u32) -> StorageResult>; - - /// Get the current tip blockchain height. - async fn get_tip_height(&self) -> StorageResult>; +pub trait PersistentStorage: Sized { + /// If the storage_path contains persisted data the storage will use it, if not, + /// a empty storage will be created. + async fn open(storage_path: impl Into + Send) -> StorageResult; - /// Store filter headers. - async fn store_filter_headers(&mut self, headers: &[FilterHeader]) -> StorageResult<()>; - - /// Load filter headers in the given blockchain height range. - async fn load_filter_headers(&self, range: Range) -> StorageResult>; - - /// Get a specific filter header by blockchain height. - async fn get_filter_header(&self, height: u32) -> StorageResult>; + async fn persist(&mut self, storage_path: impl Into + Send) -> StorageResult<()>; +} - /// Get the current filter tip blockchain height. - async fn get_filter_tip_height(&self) -> StorageResult>; +#[async_trait] +pub trait StorageManager: + BlockHeaderStorage + + FilterHeaderStorage + + FilterStorage + + TransactionStorage + + MetadataStorage + + ChainStateStorage + + MasternodeStateStorage + + Send + + Sync +{ + /// Deletes in-disk and in-memory data + async fn clear(&mut self) -> StorageResult<()>; - /// Store masternode state. - async fn store_masternode_state(&mut self, state: &MasternodeState) -> StorageResult<()>; + /// Stops all background tasks and persists the data. + async fn shutdown(&mut self); +} - /// Load masternode state. - async fn load_masternode_state(&self) -> StorageResult>; +/// Disk-based storage manager with segmented files and async background saving. +/// Only one instance of DiskStorageManager working on the same storage path +/// can exist at a time. +pub struct DiskStorageManager { + storage_path: PathBuf, + + block_headers: Arc>, + filter_headers: Arc>, + filters: Arc>, + transactions: Arc>, + metadata: Arc>, + chainstate: Arc>, + masternodestate: Arc>, + peers: Arc>, + + // Background worker + worker_handle: Option>, + + _lock_file: LockFile, +} - /// Store chain state. - async fn store_chain_state(&mut self, state: &ChainState) -> StorageResult<()>; +impl DiskStorageManager { + pub async fn new(storage_path: impl Into + Send) -> StorageResult { + use std::fs; + + let storage_path = storage_path.into(); + let lock_file = { + let mut lock_file = storage_path.clone(); + lock_file.set_extension("lock"); + lock_file + }; + + fs::create_dir_all(&storage_path)?; + + let lock_file = LockFile::new(lock_file)?; + + let mut storage = Self { + storage_path: storage_path.clone(), + + block_headers: Arc::new(RwLock::new( + PersistentBlockHeaderStorage::open(&storage_path).await?, + )), + filter_headers: Arc::new(RwLock::new( + PersistentFilterHeaderStorage::open(&storage_path).await?, + )), + filters: Arc::new(RwLock::new(PersistentFilterStorage::open(&storage_path).await?)), + transactions: Arc::new(RwLock::new( + PersistentTransactionStorage::open(&storage_path).await?, + )), + metadata: Arc::new(RwLock::new(PersistentMetadataStorage::open(&storage_path).await?)), + chainstate: Arc::new(RwLock::new( + PersistentChainStateStorage::open(&storage_path).await?, + )), + masternodestate: Arc::new(RwLock::new( + PersistentMasternodeStateStorage::open(&storage_path).await?, + )), + peers: Arc::new(RwLock::new(PersistentPeerStorage::open(&storage_path).await?)), + + worker_handle: None, + + _lock_file: lock_file, + }; + + storage.start_worker().await; + + Ok(storage) + } + + #[cfg(test)] + pub async fn with_temp_dir() -> StorageResult { + use tempfile::TempDir; + + let temp_dir = TempDir::new()?; + Self::new(temp_dir.path()).await + } + + /// Start the background worker saving data every 5 seconds + async fn start_worker(&mut self) { + let block_headers = Arc::clone(&self.block_headers); + let filter_headers = Arc::clone(&self.filter_headers); + let filters = Arc::clone(&self.filters); + let transactions = Arc::clone(&self.transactions); + let metadata = Arc::clone(&self.metadata); + let chainstate = Arc::clone(&self.chainstate); + let peers = Arc::clone(&self.peers); + + let storage_path = self.storage_path.clone(); + + let worker_handle = tokio::spawn(async move { + let mut ticker = tokio::time::interval(Duration::from_secs(5)); + + loop { + ticker.tick().await; + + let _ = block_headers.write().await.persist(&storage_path).await; + let _ = filter_headers.write().await.persist(&storage_path).await; + let _ = filters.write().await.persist(&storage_path).await; + let _ = transactions.write().await.persist(&storage_path).await; + let _ = metadata.write().await.persist(&storage_path).await; + let _ = chainstate.write().await.persist(&storage_path).await; + let _ = peers.write().await.persist(&storage_path).await; + } + }); + + self.worker_handle = Some(worker_handle); + } + + /// Stop the background worker without forcing a save. + fn stop_worker(&self) { + if let Some(handle) = &self.worker_handle { + handle.abort(); + } + } + + async fn persist(&self) { + let storage_path = &self.storage_path; + + let _ = self.block_headers.write().await.persist(storage_path).await; + let _ = self.filter_headers.write().await.persist(storage_path).await; + let _ = self.filters.write().await.persist(storage_path).await; + let _ = self.transactions.write().await.persist(storage_path).await; + let _ = self.metadata.write().await.persist(storage_path).await; + let _ = self.chainstate.write().await.persist(storage_path).await; + let _ = self.peers.write().await.persist(storage_path).await; + } +} - /// Load chain state. - async fn load_chain_state(&self) -> StorageResult>; +#[async_trait] +impl StorageManager for DiskStorageManager { + async fn clear(&mut self) -> StorageResult<()> { + // First, stop the background worker to avoid races with file deletion + self.stop_worker(); + + // Remove all files and directories under storage_path + if self.storage_path.exists() { + // Best-effort removal; if concurrent files appear, retry once + match tokio::fs::remove_dir_all(&self.storage_path).await { + Ok(_) => {} + Err(e) + if e.kind() == std::io::ErrorKind::Other + || e.kind() == std::io::ErrorKind::DirectoryNotEmpty => + { + tokio::time::sleep(std::time::Duration::from_millis(50)).await; + tokio::fs::remove_dir_all(&self.storage_path).await?; + } + Err(e) => return Err(crate::error::StorageError::Io(e)), + } + tokio::fs::create_dir_all(&self.storage_path).await?; + } + + // Instantiate storages again once persisted data has been cleared + let storage_path = &self.storage_path; + + self.block_headers = + Arc::new(RwLock::new(PersistentBlockHeaderStorage::open(storage_path).await?)); + self.filter_headers = + Arc::new(RwLock::new(PersistentFilterHeaderStorage::open(storage_path).await?)); + self.filters = Arc::new(RwLock::new(PersistentFilterStorage::open(storage_path).await?)); + self.transactions = + Arc::new(RwLock::new(PersistentTransactionStorage::open(storage_path).await?)); + self.metadata = Arc::new(RwLock::new(PersistentMetadataStorage::open(storage_path).await?)); + self.chainstate = + Arc::new(RwLock::new(PersistentChainStateStorage::open(storage_path).await?)); + + // Restart the background worker for future operations + self.start_worker().await; + + Ok(()) + } + + async fn shutdown(&mut self) { + self.stop_worker(); + + self.persist().await; + } +} - /// Store a compact filter at a blockchain height. - async fn store_filter(&mut self, height: u32, filter: &[u8]) -> StorageResult<()>; +#[async_trait] +impl blocks::BlockHeaderStorage for DiskStorageManager { + async fn store_headers(&mut self, headers: &[BlockHeader]) -> StorageResult<()> { + self.block_headers.write().await.store_headers(headers).await + } - /// Load compact filters in the given blockchain height range. - async fn load_filters(&self, range: Range) -> StorageResult>>; + async fn store_headers_at_height( + &mut self, + headers: &[BlockHeader], + height: u32, + ) -> StorageResult<()> { + self.block_headers.write().await.store_headers_at_height(headers, height).await + } - /// Store metadata. - async fn store_metadata(&mut self, key: &str, value: &[u8]) -> StorageResult<()>; + async fn load_headers(&self, range: Range) -> StorageResult> { + self.block_headers.write().await.load_headers(range).await + } - /// Load metadata. - async fn load_metadata(&self, key: &str) -> StorageResult>>; + async fn get_tip_height(&self) -> Option { + self.block_headers.read().await.get_tip_height().await + } - /// Clear all data. - async fn clear(&mut self) -> StorageResult<()>; + async fn get_start_height(&self) -> Option { + self.block_headers.read().await.get_start_height().await + } - /// Clear all filter headers and compact filters. - async fn clear_filters(&mut self) -> StorageResult<()>; + async fn get_stored_headers_len(&self) -> u32 { + self.block_headers.read().await.get_stored_headers_len().await + } - /// Get header height by block hash (reverse lookup). async fn get_header_height_by_hash( &self, hash: &dashcore::BlockHash, - ) -> StorageResult>; - - // UTXO methods removed - handled by external wallet - - /// Store a chain lock. - async fn store_chain_lock( - &mut self, - height: u32, - chain_lock: &dashcore::ChainLock, - ) -> StorageResult<()>; + ) -> StorageResult> { + self.block_headers.read().await.get_header_height_by_hash(hash).await + } +} - /// Load a chain lock by height. - async fn load_chain_lock(&self, height: u32) -> StorageResult>; +#[async_trait] +impl filters::FilterHeaderStorage for DiskStorageManager { + async fn store_filter_headers(&mut self, headers: &[FilterHeader]) -> StorageResult<()> { + self.filter_headers.write().await.store_filter_headers(headers).await + } + + async fn load_filter_headers(&self, range: Range) -> StorageResult> { + self.filter_headers.write().await.load_filter_headers(range).await + } + + async fn get_filter_tip_height(&self) -> StorageResult> { + self.filter_headers.read().await.get_filter_tip_height().await + } + + async fn get_filter_start_height(&self) -> Option { + self.filter_headers.read().await.get_filter_start_height().await + } +} - /// Get chain locks in a height range. - async fn get_chain_locks( - &self, - start_height: u32, - end_height: u32, - ) -> StorageResult>; +#[async_trait] +impl filters::FilterStorage for DiskStorageManager { + async fn store_filter(&mut self, height: u32, filter: &[u8]) -> StorageResult<()> { + self.filters.write().await.store_filter(height, filter).await + } + + async fn load_filters(&self, range: Range) -> StorageResult>> { + self.filters.write().await.load_filters(range).await + } +} - // Mempool storage methods - /// Store an unconfirmed transaction. +#[async_trait] +impl transactions::TransactionStorage for DiskStorageManager { async fn store_mempool_transaction( &mut self, txid: &Txid, tx: &UnconfirmedTransaction, - ) -> StorageResult<()>; + ) -> StorageResult<()> { + self.transactions.write().await.store_mempool_transaction(txid, tx).await + } - /// Remove a mempool transaction. - async fn remove_mempool_transaction(&mut self, txid: &Txid) -> StorageResult<()>; + async fn remove_mempool_transaction(&mut self, txid: &Txid) -> StorageResult<()> { + self.transactions.write().await.remove_mempool_transaction(txid).await + } - /// Get a mempool transaction. async fn get_mempool_transaction( &self, txid: &Txid, - ) -> StorageResult>; + ) -> StorageResult> { + self.transactions.read().await.get_mempool_transaction(txid).await + } - /// Get all mempool transactions. async fn get_all_mempool_transactions( &self, - ) -> StorageResult>; + ) -> StorageResult> { + self.transactions.read().await.get_all_mempool_transactions().await + } - /// Store the complete mempool state. - async fn store_mempool_state(&mut self, state: &MempoolState) -> StorageResult<()>; + async fn store_mempool_state(&mut self, state: &MempoolState) -> StorageResult<()> { + self.transactions.write().await.store_mempool_state(state).await + } - /// Load the mempool state. - async fn load_mempool_state(&self) -> StorageResult>; + async fn load_mempool_state(&self) -> StorageResult> { + self.transactions.read().await.load_mempool_state().await + } +} - /// Clear all mempool data. - async fn clear_mempool(&mut self) -> StorageResult<()>; +#[async_trait] +impl metadata::MetadataStorage for DiskStorageManager { + async fn store_metadata(&mut self, key: &str, value: &[u8]) -> StorageResult<()> { + self.metadata.write().await.store_metadata(key, value).await + } + + async fn load_metadata(&self, key: &str) -> StorageResult>> { + self.metadata.read().await.load_metadata(key).await + } +} + +#[async_trait] +impl chainstate::ChainStateStorage for DiskStorageManager { + async fn store_chain_state(&mut self, state: &ChainState) -> StorageResult<()> { + self.chainstate.write().await.store_chain_state(state).await + } + + async fn load_chain_state(&self) -> StorageResult> { + self.chainstate.read().await.load_chain_state().await + } +} + +#[async_trait] +impl masternode::MasternodeStateStorage for DiskStorageManager { + async fn store_masternode_state(&mut self, state: &MasternodeState) -> StorageResult<()> { + self.masternodestate.write().await.store_masternode_state(state).await + } + + async fn load_masternode_state(&self) -> StorageResult> { + self.masternodestate.read().await.load_masternode_state().await + } +} - /// Shutdown the storage manager - async fn shutdown(&mut self) -> StorageResult<()>; +#[cfg(test)] +mod tests { + use crate::ChainState; + + use super::*; + use dashcore::{block::Version, pow::CompactTarget, BlockHash, Header as BlockHeader}; + use dashcore_hashes::Hash; + use tempfile::TempDir; + + fn build_headers(count: usize) -> Vec { + let mut headers = Vec::with_capacity(count); + let mut prev_hash = BlockHash::from_byte_array([0u8; 32]); + + for i in 0..count { + let header = BlockHeader { + version: Version::from_consensus(1), + prev_blockhash: prev_hash, + merkle_root: dashcore::hashes::sha256d::Hash::from_byte_array( + [(i % 255) as u8; 32], + ) + .into(), + time: 1 + i as u32, + bits: CompactTarget::from_consensus(0x1d00ffff), + nonce: i as u32, + }; + prev_hash = header.block_hash(); + headers.push(header); + } + + headers + } + + #[tokio::test] + async fn test_load_headers() -> Result<(), Box> { + // Create a temporary directory for the test + let temp_dir = TempDir::new()?; + let mut storage = DiskStorageManager::new(temp_dir.path().to_path_buf()) + .await + .expect("Unable to create storage"); + + // Create a test header + let test_header = BlockHeader { + version: Version::from_consensus(1), + prev_blockhash: BlockHash::from_byte_array([1; 32]), + merkle_root: dashcore::hashes::sha256d::Hash::from_byte_array([2; 32]).into(), + time: 12345, + bits: CompactTarget::from_consensus(0x1d00ffff), + nonce: 67890, + }; + + // Store just one header + storage.store_headers(&[test_header]).await?; + + let loaded_headers = storage.load_headers(0..1).await?; + + // Should only get back the one header we stored + assert_eq!(loaded_headers.len(), 1); + assert_eq!(loaded_headers[0], test_header); + + Ok(()) + } + + #[tokio::test] + async fn test_checkpoint_storage_indexing() -> StorageResult<()> { + use dashcore::TxMerkleNode; + use tempfile::tempdir; + + let temp_dir = tempdir().expect("Failed to create temp dir"); + let mut storage = DiskStorageManager::new(temp_dir.path().to_path_buf()).await?; + + // Create test headers starting from checkpoint height + let checkpoint_height = 1_100_000; + let headers: Vec = (0..100) + .map(|i| BlockHeader { + version: Version::from_consensus(1), + prev_blockhash: BlockHash::from_byte_array([i as u8; 32]), + merkle_root: TxMerkleNode::from_byte_array([(i + 1) as u8; 32]), + time: 1234567890 + i, + bits: CompactTarget::from_consensus(0x1a2b3c4d), + nonce: 67890 + i, + }) + .collect(); + + let mut base_state = ChainState::new(); + base_state.sync_base_height = checkpoint_height; + storage.store_chain_state(&base_state).await?; + + storage.store_headers_at_height(&headers, checkpoint_height).await?; + assert_eq!(storage.get_stored_headers_len().await, headers.len() as u32); + + // Verify headers are stored at correct blockchain heights + let header_at_base = storage.get_header(checkpoint_height).await?; + assert_eq!( + header_at_base.expect("Header at base blockchain height should exist"), + headers[0] + ); + + let header_at_ending = storage.get_header(checkpoint_height + 99).await?; + assert_eq!( + header_at_ending.expect("Header at ending blockchain height should exist"), + headers[99] + ); + + // Test the reverse index (hash -> blockchain height) + let hash_0 = headers[0].block_hash(); + let height_0 = storage.get_header_height_by_hash(&hash_0).await?; + assert_eq!( + height_0, + Some(checkpoint_height), + "Hash should map to blockchain height 1,100,000" + ); + + let hash_99 = headers[99].block_hash(); + let height_99 = storage.get_header_height_by_hash(&hash_99).await?; + assert_eq!( + height_99, + Some(checkpoint_height + 99), + "Hash should map to blockchain height 1,100,099" + ); + + // Store chain state to persist sync_base_height + let mut chain_state = ChainState::new(); + chain_state.sync_base_height = checkpoint_height; + storage.store_chain_state(&chain_state).await?; + + // Force save to disk + storage.persist().await; + + drop(storage); + + // Create a new storage instance to test index rebuilding + let storage2 = DiskStorageManager::new(temp_dir.path().to_path_buf()).await?; + + // Verify the index was rebuilt correctly + let height_after_rebuild = storage2.get_header_height_by_hash(&hash_0).await?; + assert_eq!( + height_after_rebuild, + Some(checkpoint_height), + "After index rebuild, hash should still map to blockchain height 1,100,000" + ); + + // Verify header can still be retrieved by blockchain height after reload + let header_after_reload = storage2.get_header(checkpoint_height).await?; + assert!( + header_after_reload.is_some(), + "Header at base blockchain height should exist after reload" + ); + assert_eq!(header_after_reload.unwrap(), headers[0]); + + Ok(()) + } + + #[tokio::test] + async fn test_shutdown_flushes_index() -> Result<(), Box> { + let temp_dir = TempDir::new()?; + let base_path = temp_dir.path().to_path_buf(); + let headers = build_headers(11_000); + let last_hash = headers.last().unwrap().block_hash(); + + { + let mut storage = DiskStorageManager::new(base_path.clone()).await?; + + storage.store_headers(&headers[..10_000]).await?; + storage.persist().await; + + storage.store_headers(&headers[10_000..]).await?; + storage.shutdown().await; + } + + let storage = DiskStorageManager::new(base_path).await?; + let height = storage.get_header_height_by_hash(&last_hash).await?; + assert_eq!(height, Some(10_999)); + + Ok(()) + } } diff --git a/dash-spv/src/storage/peers.rs b/dash-spv/src/storage/peers.rs new file mode 100644 index 000000000..9d39baff0 --- /dev/null +++ b/dash-spv/src/storage/peers.rs @@ -0,0 +1,206 @@ +use std::{ + collections::HashMap, + fs::{self, File}, + io::BufReader, + net::SocketAddr, + path::PathBuf, +}; + +use async_trait::async_trait; +use dashcore::{ + consensus::{encode, Decodable, Encodable}, + network::address::AddrV2Message, +}; + +use crate::{ + error::StorageResult, + network::PeerReputation, + storage::{io::atomic_write, PersistentStorage}, + StorageError, +}; + +#[async_trait] +pub trait PeerStorage { + async fn save_peers( + &self, + peers: &[dashcore::network::address::AddrV2Message], + ) -> StorageResult<()>; + + async fn load_peers(&self) -> StorageResult>; + + async fn save_peers_reputation( + &self, + reputations: &HashMap, + ) -> StorageResult<()>; + + async fn load_peers_reputation(&self) -> StorageResult>; +} + +pub struct PersistentPeerStorage { + storage_path: PathBuf, +} + +impl PersistentPeerStorage { + const FOLDER_NAME: &str = "peers"; + + fn peers_data_file(&self) -> PathBuf { + self.storage_path.join("peers.dat") + } + + fn peers_reputation_file(&self) -> PathBuf { + self.storage_path.join("reputations.json") + } +} + +#[async_trait] +impl PersistentStorage for PersistentPeerStorage { + async fn open(storage_path: impl Into + Send) -> StorageResult { + let storage_path = storage_path.into(); + + Ok(PersistentPeerStorage { + storage_path: storage_path.join(Self::FOLDER_NAME), + }) + } + + async fn persist(&mut self, _storage_path: impl Into + Send) -> StorageResult<()> { + // Current implementation persists data everytime data is stored + Ok(()) + } +} + +#[async_trait] +impl PeerStorage for PersistentPeerStorage { + async fn save_peers( + &self, + peers: &[dashcore::network::address::AddrV2Message], + ) -> StorageResult<()> { + let peers_file = self.peers_data_file(); + + if let Err(e) = fs::create_dir_all(peers_file.parent().unwrap()) { + return Err(StorageError::WriteFailed(format!("Failed to persist peers: {}", e))); + } + + let mut buffer = Vec::new(); + + for item in peers.iter() { + item.consensus_encode(&mut buffer) + .map_err(|e| StorageError::WriteFailed(format!("Failed to encode peer: {}", e)))?; + } + + let peers_file_parent = peers_file + .parent() + .ok_or(StorageError::NotFound("peers_file doesn't have a parent".to_string()))?; + + tokio::fs::create_dir_all(peers_file_parent).await?; + + atomic_write(&peers_file, &buffer).await?; + + Ok(()) + } + + async fn load_peers(&self) -> StorageResult> { + let peers_file = self.peers_data_file(); + + if !peers_file.exists() { + return Ok(Vec::new()); + }; + + let file = File::open(&peers_file)?; + let mut reader = BufReader::new(file); + let mut peers = Vec::new(); + + loop { + match AddrV2Message::consensus_decode(&mut reader) { + Ok(peer) => peers.push(peer), + Err(encode::Error::Io(ref e)) if e.kind() == std::io::ErrorKind::UnexpectedEof => { + break + } + Err(e) => { + return Err(StorageError::ReadFailed(format!("Failed to decode peer: {e}"))) + } + } + } + + let peers = peers.into_iter().filter_map(|p| p.socket_addr().ok()).collect(); + + Ok(peers) + } + + async fn save_peers_reputation( + &self, + reputations: &HashMap, + ) -> StorageResult<()> { + let reputation_file = self.peers_reputation_file(); + + let json = serde_json::to_string_pretty(reputations).map_err(|e| { + StorageError::Serialization(format!("Failed to serialize peers reputations: {e}")) + })?; + + let reputation_file_parent = reputation_file + .parent() + .ok_or(StorageError::NotFound("reputation_file doesn't have a parent".to_string()))?; + + tokio::fs::create_dir_all(reputation_file_parent).await?; + + atomic_write(&reputation_file, json.as_bytes()).await + } + + async fn load_peers_reputation(&self) -> StorageResult> { + let reputation_file = self.peers_reputation_file(); + + if !reputation_file.exists() { + return Ok(HashMap::new()); + } + + let json = tokio::fs::read_to_string(reputation_file).await?; + serde_json::from_str(&json).map_err(|e| { + StorageError::ReadFailed(format!("Failed to deserialize peers reputations: {e}")) + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use dashcore::network::address::{AddrV2, AddrV2Message}; + use dashcore::network::constants::ServiceFlags; + use tempfile::TempDir; + + #[tokio::test] + async fn test_persistent_peer_storage_save_load() { + let temp_dir = TempDir::new().expect("Failed to create temporary directory for test"); + let store = PersistentPeerStorage::open(temp_dir.path()) + .await + .expect("Failed to open persistent peer storage"); + + // Create test peer messages + let addr: std::net::SocketAddr = + "192.168.1.1:9999".parse().expect("Failed to parse test address"); + let msg = AddrV2Message { + time: 1234567890, + services: ServiceFlags::from(1), + addr: AddrV2::Ipv4( + addr.ip().to_string().parse().expect("Failed to parse IPv4 address"), + ), + port: addr.port(), + }; + + store.save_peers(&[msg]).await.expect("Failed to save peers in test"); + + let loaded = store.load_peers().await.expect("Failed to load peers in test"); + assert_eq!(loaded.len(), 1); + assert_eq!(loaded[0], addr); + } + + #[tokio::test] + async fn test_persistent_peer_storage_empty() { + let temp_dir = TempDir::new().expect("Failed to create temporary directory for test"); + let store = PersistentPeerStorage::open(temp_dir.path()) + .await + .expect("Failed to open persistent peer storage"); + + // Load from non-existent file + let loaded = store.load_peers().await.expect("Failed to load peers from empty store"); + assert!(loaded.is_empty()); + } +} diff --git a/dash-spv/src/storage/segments.rs b/dash-spv/src/storage/segments.rs index 64ff4adad..fb72a3f42 100644 --- a/dash-spv/src/storage/segments.rs +++ b/dash-spv/src/storage/segments.rs @@ -17,65 +17,26 @@ use dashcore::{ }; use dashcore_hashes::Hash; -use crate::{ - error::StorageResult, - storage::{io::atomic_write, manager::WorkerCommand}, - StorageError, -}; - -use super::manager::DiskStorageManager; - -/// State of a segment in memory -#[derive(Debug, Clone, PartialEq)] -enum SegmentState { - Clean, // No changes, up to date on disk - Dirty, // Has changes, needs saving - Saving, // Currently being saved in background -} +use crate::{error::StorageResult, storage::io::atomic_write, StorageError}; pub trait Persistable: Sized + Encodable + Decodable + PartialEq + Clone { - const FOLDER_NAME: &'static str; const SEGMENT_PREFIX: &'static str = "segment"; const DATA_FILE_EXTENSION: &'static str = "dat"; - fn relative_disk_path(segment_id: u32) -> PathBuf { - format!( - "{}/{}_{:04}.{}", - Self::FOLDER_NAME, - Self::SEGMENT_PREFIX, - segment_id, - Self::DATA_FILE_EXTENSION - ) - .into() + fn segment_file_name(segment_id: u32) -> String { + format!("{}_{:04}.{}", Self::SEGMENT_PREFIX, segment_id, Self::DATA_FILE_EXTENSION) } fn sentinel() -> Self; - fn make_save_command(segment: &Segment) -> WorkerCommand; } impl Persistable for Vec { - const FOLDER_NAME: &'static str = "filters"; - - fn make_save_command(segment: &Segment) -> WorkerCommand { - WorkerCommand::SaveFilterSegmentCache { - segment_id: segment.segment_id, - } - } - fn sentinel() -> Self { vec![] } } impl Persistable for BlockHeader { - const FOLDER_NAME: &'static str = "block_headers"; - - fn make_save_command(segment: &Segment) -> WorkerCommand { - WorkerCommand::SaveBlockHeaderSegmentCache { - segment_id: segment.segment_id, - } - } - fn sentinel() -> Self { Self { version: Version::from_consensus(i32::MAX), // Invalid version @@ -89,14 +50,6 @@ impl Persistable for BlockHeader { } impl Persistable for FilterHeader { - const FOLDER_NAME: &'static str = "filter_headers"; - - fn make_save_command(segment: &Segment) -> WorkerCommand { - WorkerCommand::SaveFilterHeaderSegmentCache { - segment_id: segment.segment_id, - } - } - fn sentinel() -> Self { FilterHeader::from_byte_array([0u8; 32]) } @@ -106,20 +59,20 @@ impl Persistable for FilterHeader { #[derive(Debug)] pub struct SegmentCache { segments: HashMap>, + evicted: HashMap>, tip_height: Option, - base_path: PathBuf, + start_height: Option, + segments_dir: PathBuf, } impl SegmentCache { pub async fn build_block_index_from_segments( &mut self, ) -> StorageResult> { - let segments_dir = self.base_path.join(BlockHeader::FOLDER_NAME); + let entries = fs::read_dir(&self.segments_dir)?; let mut block_index = HashMap::new(); - let entries = fs::read_dir(&segments_dir)?; - for entry in entries.flatten() { let name = match entry.file_name().into_string() { Ok(s) => s, @@ -157,19 +110,21 @@ impl SegmentCache { impl SegmentCache { const MAX_ACTIVE_SEGMENTS: usize = 10; - pub async fn load_or_new(base_path: impl Into) -> StorageResult { - let base_path = base_path.into(); - let items_dir = base_path.join(I::FOLDER_NAME); + pub async fn load_or_new(segments_dir: impl Into) -> StorageResult { + let segments_dir = segments_dir.into(); let mut cache = Self { segments: HashMap::with_capacity(Self::MAX_ACTIVE_SEGMENTS), + evicted: HashMap::new(), tip_height: None, - base_path, + start_height: None, + segments_dir: segments_dir.clone(), }; // Building the metadata - if let Ok(entries) = fs::read_dir(&items_dir) { - let mut max_segment_id = None; + if let Ok(entries) = fs::read_dir(&segments_dir) { + let mut max_seg_id = None; + let mut min_seg_id = None; for entry in entries.flatten() { if let Some(name) = entry.file_name().to_str() { @@ -180,26 +135,33 @@ impl SegmentCache { let segment_id_end = segment_id_start + 4; if let Ok(id) = name[segment_id_start..segment_id_end].parse::() { - max_segment_id = - Some(max_segment_id.map_or(id, |max: u32| max.max(id))); + max_seg_id = Some(max_seg_id.map_or(id, |max: u32| max.max(id))); + min_seg_id = Some(min_seg_id.map_or(id, |min: u32| min.min(id))); } } } } - if let Some(segment_id) = max_segment_id { + if let Some(segment_id) = max_seg_id { let segment = cache.get_segment(&segment_id).await?; cache.tip_height = segment .last_valid_offset() - .map(|offset| segment_id * Segment::::ITEMS_PER_SEGMENT + offset); + .map(|offset| Self::segment_id_to_start_height(segment_id) + offset); + } + + if let Some(segment_id) = min_seg_id { + let segment = cache.get_segment(&segment_id).await?; + + cache.start_height = segment + .first_valid_offset() + .map(|offset| Self::segment_id_to_start_height(segment_id) + offset); } } Ok(cache) } - /// Get the segment ID for a given storage index. #[inline] fn height_to_segment_id(height: u32) -> u32 { height / Segment::::ITEMS_PER_SEGMENT @@ -216,54 +178,45 @@ impl SegmentCache { height % Segment::::ITEMS_PER_SEGMENT } - pub fn clear_in_memory(&mut self) { - self.segments.clear(); - self.tip_height = None; - } - - pub async fn clear_all(&mut self) -> StorageResult<()> { - self.clear_in_memory(); - - let persistence_dir = self.base_path.join(I::FOLDER_NAME); - if persistence_dir.exists() { - tokio::fs::remove_dir_all(&persistence_dir).await?; - } - tokio::fs::create_dir_all(&persistence_dir).await?; - - Ok(()) - } - - pub async fn get_segment(&mut self, segment_id: &u32) -> StorageResult<&Segment> { + async fn get_segment(&mut self, segment_id: &u32) -> StorageResult<&Segment> { let segment = self.get_segment_mut(segment_id).await?; Ok(&*segment) } - pub async fn get_segment_mut<'a>( + async fn get_segment_mut<'a>( &'a mut self, segment_id: &u32, ) -> StorageResult<&'a mut Segment> { let segments_len = self.segments.len(); - let segments = &mut self.segments; - if segments.contains_key(segment_id) { - let segment = segments.get_mut(segment_id).expect("We already checked that it exists"); - segment.last_accessed = Instant::now(); + if self.segments.contains_key(segment_id) { + let segment = + self.segments.get_mut(segment_id).expect("We already checked that it exists"); return Ok(segment); } if segments_len >= Self::MAX_ACTIVE_SEGMENTS { let key_to_evict = - segments.iter_mut().min_by_key(|(_, s)| s.last_accessed).map(|(k, v)| (*k, v)); + self.segments.iter_mut().min_by_key(|(_, s)| s.last_accessed).map(|(k, v)| (*k, v)); - if let Some((key, segment)) = key_to_evict { - segment.persist(&self.base_path).await?; - segments.remove(&key); + if let Some((key, _)) = key_to_evict { + if let Some(segment) = self.segments.remove(&key) { + if segment.state == SegmentState::Dirty { + self.evicted.insert(key, segment); + } + } } } - // Load and insert - let segment = Segment::load(&self.base_path, *segment_id).await?; - let segment = segments.entry(*segment_id).or_insert(segment); + // If the segment is already in the to_persist map, load it from there. + // If the segment is not in the to_persist map, load it from disk. + let segment = if let Some(segment) = self.evicted.remove(segment_id) { + segment + } else { + Segment::load(&self.segments_dir, *segment_id).await? + }; + + let segment = self.segments.entry(*segment_id).or_insert(segment); Ok(segment) } @@ -332,11 +285,14 @@ impl SegmentCache { Ok(items) } - pub async fn store_items( + pub async fn store_items(&mut self, items: &[I]) -> StorageResult<()> { + self.store_items_at_height(items, self.next_height()).await + } + + pub async fn store_items_at_height( &mut self, items: &[I], start_height: u32, - manager: &DiskStorageManager, ) -> StorageResult<()> { if items.is_empty() { tracing::trace!("DiskStorage: no items to store"); @@ -356,35 +312,41 @@ impl SegmentCache { let offset = Self::height_to_offset(height); // Update segment - let segments = self.get_segment_mut(&segment_id).await?; - segments.insert(item.clone(), offset); + let segment = self.get_segment_mut(&segment_id).await?; + segment.insert(item.clone(), offset); height += 1; } - // Update cached tip height with blockchain height + // Update cached tip height and start height + // if needed self.tip_height = match self.tip_height { Some(current) => Some(current.max(height - 1)), None => Some(height - 1), }; - // Persist dirty segments periodically (every 1000 filter items) - if items.len() >= 1000 || start_height.is_multiple_of(1000) { - self.persist_dirty(manager).await; - } + self.start_height = match self.start_height { + Some(current) => Some(current.min(start_height)), + None => Some(start_height), + }; Ok(()) } - pub async fn persist_dirty(&mut self, manager: &DiskStorageManager) { - // Collect segments to persist (only dirty ones) - let segments: Vec<_> = - self.segments.values().filter(|s| s.state == SegmentState::Dirty).collect(); + pub async fn persist(&mut self, segments_dir: impl Into) { + let segments_dir = segments_dir.into(); + + for (id, segments) in self.evicted.iter_mut() { + if let Err(e) = segments.persist(&segments_dir).await { + tracing::error!("Failed to persist segment with id {id}: {e}"); + } + } + + self.evicted.clear(); - // Send header segments to worker if exists - if let Some(tx) = &manager.worker_tx { - for segment in segments { - let _ = tx.send(I::make_save_command(segment)).await; + for (id, segments) in self.segments.iter_mut() { + if let Err(e) = segments.persist(&segments_dir).await { + tracing::error!("Failed to persist segment with id {id}: {e}"); } } } @@ -394,6 +356,11 @@ impl SegmentCache { self.tip_height } + #[inline] + pub fn start_height(&self) -> Option { + self.start_height + } + #[inline] pub fn next_height(&self) -> u32 { match self.tip_height() { @@ -403,6 +370,13 @@ impl SegmentCache { } } +/// State of a segment in memory +#[derive(Debug, Clone, PartialEq)] +enum SegmentState { + Clean, // No changes, up to date on disk + Dirty, // Has changes, needs saving +} + /// In-memory cache for a segment of items #[derive(Debug, Clone)] pub struct Segment { @@ -453,7 +427,7 @@ impl Segment { pub async fn load(base_path: &Path, segment_id: u32) -> StorageResult { // Load segment from disk - let segment_path = base_path.join(I::relative_disk_path(segment_id)); + let segment_path = base_path.join(I::segment_file_name(segment_id)); let (items, state) = if segment_path.exists() { let file = File::open(&segment_path)?; @@ -487,19 +461,18 @@ impl Segment { Ok(Self::new(segment_id, items, state)) } - pub async fn persist(&mut self, base_path: &Path) -> StorageResult<()> { + pub async fn persist(&mut self, segments_dir: impl Into) -> StorageResult<()> { if self.state == SegmentState::Clean { return Ok(()); } - let path = base_path.join(I::relative_disk_path(self.segment_id)); + let segments_dir = segments_dir.into(); + let path = segments_dir.join(I::segment_file_name(self.segment_id)); if let Err(e) = fs::create_dir_all(path.parent().unwrap()) { return Err(StorageError::WriteFailed(format!("Failed to persist segment: {}", e))); } - self.state = SegmentState::Saving; - let mut buffer = Vec::new(); for item in self.items.iter() { @@ -527,7 +500,6 @@ impl Segment { self.items[offset] = item; - // Transition to Dirty state (from Clean, Dirty, or Saving) self.state = SegmentState::Dirty; self.last_accessed = std::time::Instant::now(); } @@ -592,8 +564,8 @@ mod tests { // This logic is a little tricky. Each cache can contain up to MAX_SEGMENTS segments in memory. // By storing MAX_SEGMENTS + 1 items, we ensure that the cache will evict the first introduced. // Then, by asking again in order starting in 0, we force the cache to load the evicted segment - // from disk, evicting at the same time the next, 1 in this case. Then we ask for the 1 that we - // know is evicted and so on. + // evicting at the same time the next, 1 in this case. Then we ask for the 1 that we know is + // evicted and so on. for i in 0..=MAX_SEGMENTS { let segment = cache.get_segment_mut(&i).await.expect("Failed to create a new segment"); @@ -608,7 +580,6 @@ mod tests { let segment = cache.get_segment_mut(&i).await.expect("Failed to load segment"); assert_eq!(segment.get(0..1), [FilterHeader::new_test(i)]); - assert!(segment.state == SegmentState::Clean); } } @@ -622,47 +593,59 @@ mod tests { .await .expect("Failed to create new segment_cache"); - let segment = cache.get_segment_mut(&0).await.expect("Failed to create a new segment"); - - assert!(segment.first_valid_offset().is_none()); - assert!(segment.last_valid_offset().is_none()); - assert_eq!(segment.state, SegmentState::Dirty); + cache.store_items_at_height(&items, 10).await.expect("Failed to store items"); - for (index, item) in items.iter().enumerate() { - segment.insert(*item, index as u32 + 10); - } + cache.persist(tmp_dir.path()).await; - assert_eq!(segment.first_valid_offset(), Some(10)); - assert_eq!(segment.last_valid_offset(), Some(19)); + let mut cache = SegmentCache::::load_or_new(tmp_dir.path()) + .await + .expect("Failed to load new segment_cache"); - assert!(segment.persist(tmp_dir.path()).await.is_ok()); + assert_eq!( + cache.get_items(10..20).await.expect("Failed to get items from segment cache"), + items + ); + } - cache.clear_in_memory(); - assert!(cache.segments.is_empty()); + #[tokio::test] + async fn test_segment_cache_get_insert() { + let tmp_dir = TempDir::new().unwrap(); - let segment = cache.get_segment_mut(&0).await.expect("Failed to load segment"); + const ITEMS_PER_SEGMENT: u32 = Segment::::ITEMS_PER_SEGMENT; - assert_eq!(segment.state, SegmentState::Clean); + let mut cache = SegmentCache::::load_or_new(tmp_dir.path()) + .await + .expect("Failed to create new segment_cache"); - assert_eq!(segment.get(10..20), items); + let items: Vec<_> = (0..ITEMS_PER_SEGMENT * 2 + ITEMS_PER_SEGMENT / 2) + .map(FilterHeader::new_test) + .collect(); - assert_eq!(segment.first_valid_offset(), Some(10)); - assert_eq!(segment.last_valid_offset(), Some(19)); + cache.store_items(&items).await.expect("Failed to store items"); - cache.clear_all().await.expect("Failed to clean on-memory and on-disk data"); - assert!(cache.segments.is_empty()); + assert_eq!( + items[0..ITEMS_PER_SEGMENT as usize], + cache.get_items(0..ITEMS_PER_SEGMENT).await.expect("Failed to get items") + ); - let segment = cache.get_segment(&0).await.expect("Failed to create a new segment"); + assert_eq!( + items[0..(ITEMS_PER_SEGMENT - 1) as usize], + cache.get_items(0..ITEMS_PER_SEGMENT - 1).await.expect("Failed to get items") + ); - assert!(segment.first_valid_offset().is_none()); - assert!(segment.last_valid_offset().is_none()); - assert_eq!(segment.state, SegmentState::Dirty); - } + assert_eq!( + items[0..(ITEMS_PER_SEGMENT + 1) as usize], + cache.get_items(0..ITEMS_PER_SEGMENT + 1).await.expect("Failed to get items") + ); - #[tokio::test] - async fn test_segment_cache_get_insert() { - // Cannot test the get/insert logic bcs it depends on the DiskStorageManager, test that struct properly or - // remove the necessity of it + assert_eq!( + items[(ITEMS_PER_SEGMENT - 1) as usize + ..(ITEMS_PER_SEGMENT * 2 + ITEMS_PER_SEGMENT / 2) as usize], + cache + .get_items(ITEMS_PER_SEGMENT - 1..ITEMS_PER_SEGMENT * 2 + ITEMS_PER_SEGMENT / 2) + .await + .expect("Failed to get items") + ); } #[tokio::test] diff --git a/dash-spv/src/storage/state.rs b/dash-spv/src/storage/state.rs deleted file mode 100644 index 937ac3d2a..000000000 --- a/dash-spv/src/storage/state.rs +++ /dev/null @@ -1,711 +0,0 @@ -//! State persistence and StorageManager trait implementation. - -use async_trait::async_trait; -use std::collections::HashMap; - -use dashcore::{block::Header as BlockHeader, BlockHash, Txid}; -#[cfg(test)] -use dashcore_hashes::Hash; - -use crate::error::StorageResult; -use crate::storage::manager::WorkerCommand; -use crate::storage::{MasternodeState, StorageManager}; -use crate::types::{ChainState, MempoolState, UnconfirmedTransaction}; - -use super::io::atomic_write; -use super::manager::DiskStorageManager; - -impl DiskStorageManager { - /// Store chain state to disk. - pub async fn store_chain_state(&mut self, state: &ChainState) -> StorageResult<()> { - // First store all headers - // For checkpoint sync, we need to store headers starting from the checkpoint height - self.store_headers_at_height(&state.headers, state.sync_base_height).await?; - - // Store filter headers - self.filter_headers - .write() - .await - .store_items(&state.filter_headers, state.sync_base_height, self) - .await?; - - // Store other state as JSON - let state_data = serde_json::json!({ - "last_chainlock_height": state.last_chainlock_height, - "last_chainlock_hash": state.last_chainlock_hash, - "current_filter_tip": state.current_filter_tip, - "last_masternode_diff_height": state.last_masternode_diff_height, - "sync_base_height": state.sync_base_height, - }); - - let path = self.base_path.join("state/chain.json"); - let json = state_data.to_string(); - atomic_write(&path, json.as_bytes()).await?; - - Ok(()) - } - - /// Load chain state from disk. - pub async fn load_chain_state(&self) -> StorageResult> { - let path = self.base_path.join("state/chain.json"); - if !path.exists() { - return Ok(None); - } - - let content = tokio::fs::read_to_string(path).await?; - let value: serde_json::Value = serde_json::from_str(&content).map_err(|e| { - crate::error::StorageError::Serialization(format!("Failed to parse chain state: {}", e)) - })?; - - let mut state = ChainState { - last_chainlock_height: value - .get("last_chainlock_height") - .and_then(|v| v.as_u64()) - .map(|h| h as u32), - last_chainlock_hash: value - .get("last_chainlock_hash") - .and_then(|v| v.as_str()) - .and_then(|s| s.parse().ok()), - current_filter_tip: value - .get("current_filter_tip") - .and_then(|v| v.as_str()) - .and_then(|s| s.parse().ok()), - masternode_engine: None, - last_masternode_diff_height: value - .get("last_masternode_diff_height") - .and_then(|v| v.as_u64()) - .map(|h| h as u32), - sync_base_height: value - .get("sync_base_height") - .and_then(|v| v.as_u64()) - .map(|h| h as u32) - .unwrap_or(0), - ..Default::default() - }; - - let range_start = state.sync_base_height; - if let Some(tip_height) = self.get_tip_height().await? { - state.headers = self.load_headers(range_start..tip_height + 1).await?; - } - if let Some(filter_tip_height) = self.get_filter_tip_height().await? { - state.filter_headers = - self.load_filter_headers(range_start..filter_tip_height + 1).await?; - } - - Ok(Some(state)) - } - - /// Store masternode state. - pub async fn store_masternode_state(&mut self, state: &MasternodeState) -> StorageResult<()> { - let path = self.base_path.join("state/masternode.json"); - let json = serde_json::to_string_pretty(state).map_err(|e| { - crate::error::StorageError::Serialization(format!( - "Failed to serialize masternode state: {}", - e - )) - })?; - - atomic_write(&path, json.as_bytes()).await?; - Ok(()) - } - - /// Load masternode state. - pub async fn load_masternode_state(&self) -> StorageResult> { - let path = self.base_path.join("state/masternode.json"); - if !path.exists() { - return Ok(None); - } - - let content = tokio::fs::read_to_string(path).await?; - let state = serde_json::from_str(&content).map_err(|e| { - crate::error::StorageError::Serialization(format!( - "Failed to deserialize masternode state: {}", - e - )) - })?; - - Ok(Some(state)) - } - - /// Store a ChainLock. - pub async fn store_chain_lock( - &mut self, - height: u32, - chain_lock: &dashcore::ChainLock, - ) -> StorageResult<()> { - let path = self.base_path.join("chainlocks").join(format!("chainlock_{:08}.bin", height)); - let data = bincode::serialize(chain_lock).map_err(|e| { - crate::error::StorageError::WriteFailed(format!( - "Failed to serialize chain lock: {}", - e - )) - })?; - - atomic_write(&path, &data).await?; - tracing::debug!("Stored chain lock at height {}", height); - Ok(()) - } - - /// Load a ChainLock. - pub async fn load_chain_lock(&self, height: u32) -> StorageResult> { - let path = self.base_path.join("chainlocks").join(format!("chainlock_{:08}.bin", height)); - - if !path.exists() { - return Ok(None); - } - - let data = tokio::fs::read(&path).await?; - let chain_lock = bincode::deserialize(&data).map_err(|e| { - crate::error::StorageError::ReadFailed(format!( - "Failed to deserialize chain lock: {}", - e - )) - })?; - - Ok(Some(chain_lock)) - } - - /// Get ChainLocks in a height range. - pub async fn get_chain_locks( - &self, - start_height: u32, - end_height: u32, - ) -> StorageResult> { - let chainlocks_dir = self.base_path.join("chainlocks"); - - if !chainlocks_dir.exists() { - return Ok(Vec::new()); - } - - let mut chain_locks = Vec::new(); - let mut entries = tokio::fs::read_dir(&chainlocks_dir).await?; - - while let Some(entry) = entries.next_entry().await? { - let file_name = entry.file_name(); - let file_name_str = file_name.to_string_lossy(); - - // Parse height from filename - if let Some(height_str) = - file_name_str.strip_prefix("chainlock_").and_then(|s| s.strip_suffix(".bin")) - { - if let Ok(height) = height_str.parse::() { - if height >= start_height && height <= end_height { - let path = entry.path(); - let data = tokio::fs::read(&path).await?; - if let Ok(chain_lock) = bincode::deserialize(&data) { - chain_locks.push((height, chain_lock)); - } - } - } - } - } - - // Sort by height - chain_locks.sort_by_key(|(h, _)| *h); - Ok(chain_locks) - } - - /// Store metadata. - pub async fn store_metadata(&mut self, key: &str, value: &[u8]) -> StorageResult<()> { - let path = self.base_path.join(format!("state/{}.dat", key)); - atomic_write(&path, value).await?; - Ok(()) - } - - /// Load metadata. - pub async fn load_metadata(&self, key: &str) -> StorageResult>> { - let path = self.base_path.join(format!("state/{}.dat", key)); - if !path.exists() { - return Ok(None); - } - - let data = tokio::fs::read(path).await?; - Ok(Some(data)) - } - - /// Clear all storage. - pub async fn clear(&mut self) -> StorageResult<()> { - // First, stop the background worker to avoid races with file deletion - self.stop_worker().await; - - // Clear in-memory state - self.block_headers.write().await.clear_in_memory(); - self.filter_headers.write().await.clear_in_memory(); - self.filters.write().await.clear_in_memory(); - - self.header_hash_index.write().await.clear(); - self.mempool_transactions.write().await.clear(); - *self.mempool_state.write().await = None; - - // Remove all files and directories under base_path - if self.base_path.exists() { - // Best-effort removal; if concurrent files appear, retry once - match tokio::fs::remove_dir_all(&self.base_path).await { - Ok(_) => {} - Err(e) => { - // Retry once after a short delay to handle transient races - if e.kind() == std::io::ErrorKind::Other - || e.kind() == std::io::ErrorKind::DirectoryNotEmpty - { - tokio::time::sleep(std::time::Duration::from_millis(50)).await; - tokio::fs::remove_dir_all(&self.base_path).await?; - } else { - return Err(crate::error::StorageError::Io(e)); - } - } - } - tokio::fs::create_dir_all(&self.base_path).await?; - } - - // Recreate expected subdirectories - tokio::fs::create_dir_all(self.base_path.join("headers")).await?; - tokio::fs::create_dir_all(self.base_path.join("filters")).await?; - tokio::fs::create_dir_all(self.base_path.join("state")).await?; - - // Restart the background worker for future operations - self.start_worker().await; - - Ok(()) - } - - /// Shutdown the storage manager. - pub async fn shutdown(&mut self) { - // Persist all dirty data - self.save_dirty().await; - - // Shutdown background worker - if let Some(tx) = self.worker_tx.take() { - // Save the header index before shutdown - let index = self.header_hash_index.read().await.clone(); - let _ = tx - .send(super::manager::WorkerCommand::SaveIndex { - index, - }) - .await; - let _ = tx.send(super::manager::WorkerCommand::Shutdown).await; - } - - if let Some(handle) = self.worker_handle.take() { - let _ = handle.await; - } - } - - /// Save all dirty segments to disk via background worker. - pub(super) async fn save_dirty(&self) { - self.filter_headers.write().await.persist_dirty(self).await; - self.block_headers.write().await.persist_dirty(self).await; - self.filters.write().await.persist_dirty(self).await; - - if let Some(tx) = &self.worker_tx { - // Save the index only if it has grown significantly (every 10k new entries) - let current_index_size = self.header_hash_index.read().await.len(); - let last_save_count = *self.last_index_save_count.read().await; - - // Save if index has grown by 10k entries, or if we've never saved before - if current_index_size >= last_save_count + 10_000 || last_save_count == 0 { - let index = self.header_hash_index.read().await.clone(); - let _ = tx - .send(WorkerCommand::SaveIndex { - index, - }) - .await; - - // Update the last save count - *self.last_index_save_count.write().await = current_index_size; - tracing::debug!( - "Scheduled index save (size: {}, last_save: {})", - current_index_size, - last_save_count - ); - } - } - } -} - -/// Mempool storage methods -impl DiskStorageManager { - /// Store a mempool transaction. - pub async fn store_mempool_transaction( - &mut self, - txid: &Txid, - tx: &UnconfirmedTransaction, - ) -> StorageResult<()> { - self.mempool_transactions.write().await.insert(*txid, tx.clone()); - Ok(()) - } - - /// Remove a mempool transaction. - pub async fn remove_mempool_transaction(&mut self, txid: &Txid) -> StorageResult<()> { - self.mempool_transactions.write().await.remove(txid); - Ok(()) - } - - /// Get a mempool transaction. - pub async fn get_mempool_transaction( - &self, - txid: &Txid, - ) -> StorageResult> { - Ok(self.mempool_transactions.read().await.get(txid).cloned()) - } - - /// Get all mempool transactions. - pub async fn get_all_mempool_transactions( - &self, - ) -> StorageResult> { - Ok(self.mempool_transactions.read().await.clone()) - } - - /// Store mempool state. - pub async fn store_mempool_state(&mut self, state: &MempoolState) -> StorageResult<()> { - *self.mempool_state.write().await = Some(state.clone()); - Ok(()) - } - - /// Load mempool state. - pub async fn load_mempool_state(&self) -> StorageResult> { - Ok(self.mempool_state.read().await.clone()) - } - - /// Clear mempool. - pub async fn clear_mempool(&mut self) -> StorageResult<()> { - self.mempool_transactions.write().await.clear(); - *self.mempool_state.write().await = None; - Ok(()) - } -} - -#[async_trait] -impl StorageManager for DiskStorageManager { - async fn store_headers(&mut self, headers: &[BlockHeader]) -> StorageResult<()> { - self.store_headers(headers).await - } - - async fn load_headers(&self, range: std::ops::Range) -> StorageResult> { - self.block_headers.write().await.get_items(range).await - } - - async fn get_header(&self, height: u32) -> StorageResult> { - Ok(self.block_headers.write().await.get_items(height..height + 1).await?.first().copied()) - } - - async fn get_tip_height(&self) -> StorageResult> { - Ok(self.block_headers.read().await.tip_height()) - } - - async fn store_filter_headers( - &mut self, - headers: &[dashcore::hash_types::FilterHeader], - ) -> StorageResult<()> { - let mut filter_headers = self.filter_headers.write().await; - let next_height = filter_headers.next_height(); - filter_headers.store_items(headers, next_height, self).await - } - - async fn load_filter_headers( - &self, - range: std::ops::Range, - ) -> StorageResult> { - self.filter_headers.write().await.get_items(range).await - } - - async fn get_filter_header( - &self, - height: u32, - ) -> StorageResult> { - Ok(self.filter_headers.write().await.get_items(height..height + 1).await?.first().copied()) - } - - async fn get_filter_tip_height(&self) -> StorageResult> { - Ok(self.filter_headers.read().await.tip_height()) - } - - async fn store_masternode_state(&mut self, state: &MasternodeState) -> StorageResult<()> { - Self::store_masternode_state(self, state).await - } - - async fn load_masternode_state(&self) -> StorageResult> { - Self::load_masternode_state(self).await - } - - async fn store_chain_state(&mut self, state: &ChainState) -> StorageResult<()> { - Self::store_chain_state(self, state).await - } - - async fn load_chain_state(&self) -> StorageResult> { - Self::load_chain_state(self).await - } - - async fn store_filter(&mut self, height: u32, filter: &[u8]) -> StorageResult<()> { - self.filters.write().await.store_items(&[filter.to_vec()], height, self).await - } - - async fn load_filters(&self, range: std::ops::Range) -> StorageResult>> { - self.filters.write().await.get_items(range).await - } - - async fn store_metadata(&mut self, key: &str, value: &[u8]) -> StorageResult<()> { - Self::store_metadata(self, key, value).await - } - - async fn load_metadata(&self, key: &str) -> StorageResult>> { - Self::load_metadata(self, key).await - } - - async fn clear(&mut self) -> StorageResult<()> { - Self::clear(self).await - } - - async fn clear_filters(&mut self) -> StorageResult<()> { - // Stop worker to prevent concurrent writes to filter directories - self.stop_worker().await; - - // Clear in-memory and on-disk filter headers segments - self.filter_headers.write().await.clear_all().await?; - self.filters.write().await.clear_all().await?; - - // Restart background worker for future operations - self.start_worker().await; - - Ok(()) - } - - async fn get_header_height_by_hash(&self, hash: &BlockHash) -> StorageResult> { - Self::get_header_height_by_hash(self, hash).await - } - - async fn store_chain_lock( - &mut self, - height: u32, - chain_lock: &dashcore::ChainLock, - ) -> StorageResult<()> { - Self::store_chain_lock(self, height, chain_lock).await - } - - async fn load_chain_lock(&self, height: u32) -> StorageResult> { - Self::load_chain_lock(self, height).await - } - - async fn get_chain_locks( - &self, - start_height: u32, - end_height: u32, - ) -> StorageResult> { - Self::get_chain_locks(self, start_height, end_height).await - } - - async fn store_mempool_transaction( - &mut self, - txid: &Txid, - tx: &UnconfirmedTransaction, - ) -> StorageResult<()> { - Self::store_mempool_transaction(self, txid, tx).await - } - - async fn remove_mempool_transaction(&mut self, txid: &Txid) -> StorageResult<()> { - Self::remove_mempool_transaction(self, txid).await - } - - async fn get_mempool_transaction( - &self, - txid: &Txid, - ) -> StorageResult> { - Self::get_mempool_transaction(self, txid).await - } - - async fn get_all_mempool_transactions( - &self, - ) -> StorageResult> { - Self::get_all_mempool_transactions(self).await - } - - async fn store_mempool_state(&mut self, state: &MempoolState) -> StorageResult<()> { - Self::store_mempool_state(self, state).await - } - - async fn load_mempool_state(&self) -> StorageResult> { - Self::load_mempool_state(self).await - } - - async fn clear_mempool(&mut self) -> StorageResult<()> { - Self::clear_mempool(self).await - } - - async fn shutdown(&mut self) -> StorageResult<()> { - Self::shutdown(self).await; - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use dashcore::{block::Version, pow::CompactTarget}; - use tempfile::TempDir; - - fn build_headers(count: usize) -> Vec { - let mut headers = Vec::with_capacity(count); - let mut prev_hash = BlockHash::from_byte_array([0u8; 32]); - - for i in 0..count { - let header = BlockHeader { - version: Version::from_consensus(1), - prev_blockhash: prev_hash, - merkle_root: dashcore::hashes::sha256d::Hash::from_byte_array( - [(i % 255) as u8; 32], - ) - .into(), - time: 1 + i as u32, - bits: CompactTarget::from_consensus(0x1d00ffff), - nonce: i as u32, - }; - prev_hash = header.block_hash(); - headers.push(header); - } - - headers - } - - #[tokio::test] - async fn test_load_headers() -> Result<(), Box> { - // Create a temporary directory for the test - let temp_dir = TempDir::new()?; - let mut storage = DiskStorageManager::new(temp_dir.path().to_path_buf()) - .await - .expect("Unable to create storage"); - - // Create a test header - let test_header = BlockHeader { - version: Version::from_consensus(1), - prev_blockhash: BlockHash::from_byte_array([1; 32]), - merkle_root: dashcore::hashes::sha256d::Hash::from_byte_array([2; 32]).into(), - time: 12345, - bits: CompactTarget::from_consensus(0x1d00ffff), - nonce: 67890, - }; - - // Store just one header - storage.store_headers(&[test_header]).await?; - - let loaded_headers = storage.load_headers(0..1).await?; - - // Should only get back the one header we stored - assert_eq!(loaded_headers.len(), 1); - assert_eq!(loaded_headers[0], test_header); - - Ok(()) - } - - #[tokio::test] - async fn test_checkpoint_storage_indexing() -> StorageResult<()> { - use dashcore::TxMerkleNode; - use tempfile::tempdir; - - let temp_dir = tempdir().expect("Failed to create temp dir"); - let mut storage = DiskStorageManager::new(temp_dir.path().to_path_buf()).await?; - - // Create test headers starting from checkpoint height - let checkpoint_height = 1_100_000; - let headers: Vec = (0..100) - .map(|i| BlockHeader { - version: Version::from_consensus(1), - prev_blockhash: BlockHash::from_byte_array([i as u8; 32]), - merkle_root: TxMerkleNode::from_byte_array([(i + 1) as u8; 32]), - time: 1234567890 + i, - bits: CompactTarget::from_consensus(0x1a2b3c4d), - nonce: 67890 + i, - }) - .collect(); - - let mut base_state = ChainState::new(); - base_state.sync_base_height = checkpoint_height; - storage.store_chain_state(&base_state).await?; - - storage.store_headers_at_height(&headers, checkpoint_height).await?; - - // Verify headers are stored at correct blockchain heights - let header_at_base = storage.get_header(checkpoint_height).await?; - assert_eq!( - header_at_base.expect("Header at base blockchain height should exist"), - headers[0] - ); - - let header_at_ending = storage.get_header(checkpoint_height + 99).await?; - assert_eq!( - header_at_ending.expect("Header at ending blockchain height should exist"), - headers[99] - ); - - // Test the reverse index (hash -> blockchain height) - let hash_0 = headers[0].block_hash(); - let height_0 = storage.get_header_height_by_hash(&hash_0).await?; - assert_eq!( - height_0, - Some(checkpoint_height), - "Hash should map to blockchain height 1,100,000" - ); - - let hash_99 = headers[99].block_hash(); - let height_99 = storage.get_header_height_by_hash(&hash_99).await?; - assert_eq!( - height_99, - Some(checkpoint_height + 99), - "Hash should map to blockchain height 1,100,099" - ); - - // Store chain state to persist sync_base_height - let mut chain_state = ChainState::new(); - chain_state.sync_base_height = checkpoint_height; - storage.store_chain_state(&chain_state).await?; - - // Force save to disk - storage.save_dirty().await; - tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; - - drop(storage); - - // Create a new storage instance to test index rebuilding - let storage2 = DiskStorageManager::new(temp_dir.path().to_path_buf()).await?; - - // Verify the index was rebuilt correctly - let height_after_rebuild = storage2.get_header_height_by_hash(&hash_0).await?; - assert_eq!( - height_after_rebuild, - Some(checkpoint_height), - "After index rebuild, hash should still map to blockchain height 1,100,000" - ); - - // Verify header can still be retrieved by blockchain height after reload - let header_after_reload = storage2.get_header(checkpoint_height).await?; - assert!( - header_after_reload.is_some(), - "Header at base blockchain height should exist after reload" - ); - assert_eq!(header_after_reload.unwrap(), headers[0]); - - Ok(()) - } - - #[tokio::test] - async fn test_shutdown_flushes_index() -> Result<(), Box> { - let temp_dir = TempDir::new()?; - let base_path = temp_dir.path().to_path_buf(); - let headers = build_headers(11_000); - let last_hash = headers.last().unwrap().block_hash(); - - { - let mut storage = DiskStorageManager::new(base_path.clone()).await?; - - storage.store_headers(&headers[..10_000]).await?; - storage.save_dirty().await; - - storage.store_headers(&headers[10_000..]).await?; - storage.shutdown().await; - } - - let storage = DiskStorageManager::new(base_path).await?; - let height = storage.get_header_height_by_hash(&last_hash).await?; - assert_eq!(height, Some(10_999)); - - Ok(()) - } -} diff --git a/dash-spv/src/storage/transactions.rs b/dash-spv/src/storage/transactions.rs new file mode 100644 index 000000000..480273c4c --- /dev/null +++ b/dash-spv/src/storage/transactions.rs @@ -0,0 +1,96 @@ +use std::{collections::HashMap, path::PathBuf}; + +use async_trait::async_trait; +use dashcore::Txid; + +use crate::{ + error::StorageResult, + storage::PersistentStorage, + types::{MempoolState, UnconfirmedTransaction}, +}; + +#[async_trait] +pub trait TransactionStorage { + async fn store_mempool_transaction( + &mut self, + txid: &Txid, + tx: &UnconfirmedTransaction, + ) -> StorageResult<()>; + + async fn remove_mempool_transaction(&mut self, txid: &Txid) -> StorageResult<()>; + + async fn get_mempool_transaction( + &self, + txid: &Txid, + ) -> StorageResult>; + + async fn get_all_mempool_transactions( + &self, + ) -> StorageResult>; + + async fn store_mempool_state(&mut self, state: &MempoolState) -> StorageResult<()>; + + async fn load_mempool_state(&self) -> StorageResult>; +} + +pub struct PersistentTransactionStorage { + mempool_transactions: HashMap, + mempool_state: Option, +} + +#[async_trait] +impl PersistentStorage for PersistentTransactionStorage { + async fn open(_storage_path: impl Into + Send) -> StorageResult { + let mempool_transactions = HashMap::new(); + let mempool_state = None; + + Ok(PersistentTransactionStorage { + mempool_transactions, + mempool_state, + }) + } + + async fn persist(&mut self, _storage_path: impl Into + Send) -> StorageResult<()> { + // This data is not currently being persisted + Ok(()) + } +} + +#[async_trait] +impl TransactionStorage for PersistentTransactionStorage { + async fn store_mempool_transaction( + &mut self, + txid: &Txid, + tx: &UnconfirmedTransaction, + ) -> StorageResult<()> { + self.mempool_transactions.insert(*txid, tx.clone()); + Ok(()) + } + + async fn remove_mempool_transaction(&mut self, txid: &Txid) -> StorageResult<()> { + self.mempool_transactions.remove(txid); + Ok(()) + } + + async fn get_mempool_transaction( + &self, + txid: &Txid, + ) -> StorageResult> { + Ok(self.mempool_transactions.get(txid).cloned()) + } + + async fn get_all_mempool_transactions( + &self, + ) -> StorageResult> { + Ok(self.mempool_transactions.clone()) + } + + async fn store_mempool_state(&mut self, state: &MempoolState) -> StorageResult<()> { + self.mempool_state = Some(state.clone()); + Ok(()) + } + + async fn load_mempool_state(&self) -> StorageResult> { + Ok(self.mempool_state.clone()) + } +} diff --git a/dash-spv/src/sync/filters/headers.rs b/dash-spv/src/sync/filters/headers.rs index 40ce1622f..f1f165949 100644 --- a/dash-spv/src/sync/filters/headers.rs +++ b/dash-spv/src/sync/filters/headers.rs @@ -82,13 +82,9 @@ impl SyncResult<(u32, u32, u32)> { - let header_tip_height = storage - .get_tip_height() - .await - .map_err(|e| SyncError::Storage(format!("Failed to get header tip height: {}", e)))? - .ok_or_else(|| { - SyncError::Storage("No headers available for filter sync".to_string()) - })?; + let header_tip_height = storage.get_tip_height().await.ok_or_else(|| { + SyncError::Storage("No headers available for filter sync".to_string()) + })?; let stop_height = self .find_height_for_block_hash(&cf_headers.stop_hash, storage, 0, header_tip_height) @@ -188,13 +184,9 @@ impl= header_tip_height { tracing::info!("Filter headers already synced to header tip"); @@ -773,11 +761,7 @@ impl header_tip) } diff --git a/dash-spv/src/sync/filters/retry.rs b/dash-spv/src/sync/filters/retry.rs index f998066d0..fe7103792 100644 --- a/dash-spv/src/sync/filters/retry.rs +++ b/dash-spv/src/sync/filters/retry.rs @@ -35,13 +35,9 @@ impl { config: ClientConfig, tip_manager: ChainTipManager, checkpoint_manager: CheckpointManager, - reorg_config: ReorgConfig, chain_state: Arc>, // WalletState removed - wallet functionality is now handled externally headers2_state: Headers2StateManager, - total_headers_synced: u32, syncing_headers: bool, last_sync_progress: std::time::Instant, headers2_failed: bool, @@ -84,11 +82,9 @@ impl SyncResult { - let start_time = std::time::Instant::now(); - let mut loaded_count = 0; - let mut tip_height = 0; + pub async fn load_headers_from_storage(&mut self, storage: &S) { // First, try to load the persisted chain state which may contain sync_base_height if let Ok(Some(stored_chain_state)) = storage.load_chain_state().await { tracing::info!( @@ -111,26 +104,11 @@ impl {}, chain_state.headers.len()={}", - batch_size, - previous_total, - self.total_headers_synced, - self.chain_state.read().await.headers.len() + "Header sync progress: processed {} headers in batch, total_headers_synced: {}", + headers.len() as u32, + storage.get_stored_headers_len().await, ); // Update chain tip manager with the last header in the batch if let Some(last_header) = headers.last() { - let final_height = self.chain_state.read().await.get_height(); + let final_height = storage.get_tip_height().await.unwrap_or(0); let chain_work = ChainWork::from_height_and_header(final_height, last_header); let tip = ChainTip::new(*last_header, final_height, chain_work); self.tip_manager @@ -290,7 +244,7 @@ impl, + storage: &S, ) -> SyncResult<()> { let block_locator = match base_hash { Some(hash) => vec![hash], None => { // Check if we're syncing from a checkpoint - if self.is_synced_from_checkpoint() - && !self.chain_state.read().await.headers.is_empty() - { + if self.is_synced_from_checkpoint() && storage.get_stored_headers_len().await > 0 { + let first_height = storage + .get_start_height() + .await + .ok_or(SyncError::Storage("Failed to get start height".to_string()))?; + let checkpoint_header = storage + .get_header(first_height) + .await + .map_err(|e| { + SyncError::Storage(format!("Failed to get first header: {}", e)) + })? + .ok_or(SyncError::Storage( + "Storage didn't return first header".to_string(), + ))?; + // Use the checkpoint hash from chain state - let checkpoint_hash = self.chain_state.read().await.headers[0].block_hash(); + let checkpoint_hash = checkpoint_header.block_hash(); tracing::info!( "📍 No base_hash provided but syncing from checkpoint at height {}. Using checkpoint hash: {}", self.get_sync_base_height(), @@ -348,7 +315,7 @@ impl { // No headers in storage - check if we're syncing from a checkpoint - if self.is_synced_from_checkpoint() - && !self.chain_state.read().await.headers.is_empty() - { - // We're syncing from a checkpoint and have the checkpoint header - let checkpoint_header = &self.chain_state.read().await.headers[0]; + if self.is_synced_from_checkpoint() && storage.get_stored_headers_len().await > 0 { let checkpoint_hash = checkpoint_header.block_hash(); tracing::info!( "No headers in storage but syncing from checkpoint at height {}. Using checkpoint hash: {}", @@ -545,8 +520,12 @@ impl 0 { let hash = checkpoint_header.block_hash(); tracing::info!("Using checkpoint hash for height {}: {}", height, hash); Some(hash) @@ -639,7 +617,7 @@ impl { // No headers in storage - check if we're syncing from a checkpoint if self.is_synced_from_checkpoint() { // Use the checkpoint hash from chain state - if !self.chain_state.read().await.headers.is_empty() { - let checkpoint_hash = - self.chain_state.read().await.headers[0].block_hash(); + if storage.get_stored_headers_len().await > 0 { + let checkpoint_hash = checkpoint_header.block_hash(); tracing::info!( "Using checkpoint hash for recovery: {} (chain state has {} headers, first header time: {})", checkpoint_hash, - self.chain_state.read().await.headers.len(), - self.chain_state.read().await.headers[0].time + storage.get_stored_headers_len().await, + checkpoint_header.time ); Some(checkpoint_hash) } else { @@ -720,7 +704,7 @@ impl u32 { - // Always use total_headers_synced which tracks the absolute blockchain height - self.total_headers_synced - } - - /// Get the tip hash - pub async fn get_tip_hash(&self) -> Option { - self.chain_state.read().await.tip_hash() + pub async fn get_chain_height(&self, storage: &S) -> u32 { + storage.get_tip_height().await.unwrap_or(0) } /// Get the sync base height (used when syncing from checkpoint) @@ -865,9 +839,7 @@ impl SyncResult { + pub async fn load_headers_from_storage(&mut self, storage: &S) { // Load headers into the header sync manager - let loaded_count = self.header_sync.load_headers_from_storage(storage).await?; - - if loaded_count > 0 { - tracing::info!("Sequential sync manager loaded {} headers from storage", loaded_count); - - // Update the current phase if we have headers - // This helps the sync manager understand where to resume from - if matches!(self.current_phase, SyncPhase::Idle) { - // We have headers but haven't started sync yet - // The phase will be properly set when start_sync is called - tracing::debug!("Headers loaded but sync not started yet"); - } - } - - Ok(loaded_count) + self.header_sync.load_headers_from_storage(storage).await; } /// Get the earliest wallet birth height hint for the configured network, if available. @@ -234,7 +220,7 @@ impl< let base_hash = self.get_base_hash_from_storage(storage).await?; // Request headers starting from our current tip - self.header_sync.request_headers(network, base_hash).await?; + self.header_sync.request_headers(network, base_hash, storage).await?; } else { // Otherwise start sync normally self.header_sync.start_sync(network, storage).await?; @@ -265,10 +251,7 @@ impl< &self, storage: &S, ) -> SyncResult> { - let current_tip_height = storage - .get_tip_height() - .await - .map_err(|e| SyncError::Storage(format!("Failed to get tip height: {}", e)))?; + let current_tip_height = storage.get_tip_height().await; let base_hash = match current_tip_height { None => None, @@ -284,11 +267,6 @@ impl< Ok(base_hash) } - /// Get the current chain height from the header sync manager - pub fn get_chain_height(&self) -> u32 { - self.header_sync.get_chain_height() - } - /// Get current sync progress template. /// /// **IMPORTANT**: This method returns a TEMPLATE ONLY. It does NOT query storage or network @@ -378,8 +356,8 @@ impl< } /// Update the chain state (used for checkpoint sync initialization) - pub fn update_chain_state_cache(&mut self, sync_base_height: u32, headers_len: u32) { - self.header_sync.update_cached_from_state_snapshot(sync_base_height, headers_len); + pub fn update_chain_state_cache(&mut self, sync_base_height: u32) { + self.header_sync.update_cached_from_state_snapshot(sync_base_height); } /// Get reference to the masternode engine if available. @@ -401,22 +379,7 @@ impl< } /// Get the actual blockchain height from storage height, accounting for checkpoints - pub(super) async fn get_blockchain_height_from_storage(&self, storage: &S) -> SyncResult { - let storage_height = storage - .get_tip_height() - .await - .map_err(|e| { - crate::error::SyncError::Storage(format!("Failed to get tip height: {}", e)) - })? - .unwrap_or(0); - - // Check if we're syncing from a checkpoint - if self.header_sync.is_synced_from_checkpoint() { - // For checkpoint sync, blockchain height = sync_base_height + storage_height - Ok(self.header_sync.get_sync_base_height() + storage_height) - } else { - // Normal sync: storage height IS the blockchain height - Ok(storage_height) - } + pub(super) async fn get_blockchain_height_from_storage(&self, storage: &S) -> u32 { + storage.get_tip_height().await.unwrap_or(0) } } diff --git a/dash-spv/src/sync/masternodes/manager.rs b/dash-spv/src/sync/masternodes/manager.rs index 065f26dbc..c5eebcbf0 100644 --- a/dash-spv/src/sync/masternodes/manager.rs +++ b/dash-spv/src/sync/masternodes/manager.rs @@ -391,11 +391,7 @@ impl { + Some(tip_height) => { let state = crate::storage::MasternodeState { last_height: tip_height, engine_state: Vec::new(), @@ -477,17 +473,11 @@ impl { + None => { tracing::warn!( "⚠️ Storage returned no tip height when persisting masternode state" ); } - Err(e) => { - tracing::warn!( - "⚠️ Failed to read tip height to persist masternode state: {}", - e - ); - } } } } @@ -518,13 +508,7 @@ impl { + Some(tip_height) => { let state = crate::storage::MasternodeState { last_height: tip_height, engine_state: Vec::new(), @@ -688,17 +672,11 @@ impl { + None => { tracing::warn!( "⚠️ Storage returned no tip height when persisting masternode state" ); } - Err(e) => { - tracing::warn!( - "⚠️ Failed to read tip height to persist masternode state: {}", - e - ); - } } } else { tracing::info!( diff --git a/dash-spv/src/sync/message_handlers.rs b/dash-spv/src/sync/message_handlers.rs index 027317c5c..e4479ad24 100644 --- a/dash-spv/src/sync/message_handlers.rs +++ b/dash-spv/src/sync/message_handlers.rs @@ -345,7 +345,7 @@ impl< storage: &mut S, transition_reason: &str, ) -> SyncResult<()> { - let blockchain_height = self.get_blockchain_height_from_storage(storage).await.unwrap_or(0); + let blockchain_height = self.get_blockchain_height_from_storage(storage).await; let should_transition = if let SyncPhase::DownloadingHeaders { current_height, diff --git a/dash-spv/src/sync/phase_execution.rs b/dash-spv/src/sync/phase_execution.rs index 77758d833..252f58ba6 100644 --- a/dash-spv/src/sync/phase_execution.rs +++ b/dash-spv/src/sync/phase_execution.rs @@ -32,7 +32,7 @@ impl< // Already prepared, just send the initial request let base_hash = self.get_base_hash_from_storage(storage).await?; - self.header_sync.request_headers(network, base_hash).await?; + self.header_sync.request_headers(network, base_hash, storage).await?; } else { // Not prepared yet, start sync normally self.header_sync.start_sync(network, storage).await?; @@ -43,47 +43,6 @@ impl< .. } => { tracing::info!("📥 Starting masternode list download phase"); - // Get the effective chain height from header sync which accounts for checkpoint base - let effective_height = self.header_sync.get_chain_height(); - let sync_base_height = self.header_sync.get_sync_base_height(); - - // Also get the actual tip height to verify (blockchain height) - let storage_tip = storage - .get_tip_height() - .await - .map_err(|e| SyncError::Storage(format!("Failed to get storage tip: {}", e)))?; - - // Debug: Check chain state - let chain_state = storage.load_chain_state().await.map_err(|e| { - SyncError::Storage(format!("Failed to load chain state: {}", e)) - })?; - let chain_state_height = chain_state.as_ref().map(|s| s.get_height()).unwrap_or(0); - - tracing::info!( - "Starting masternode sync: effective_height={}, sync_base={}, storage_tip={:?}, chain_state_height={}, expected_storage_index={}", - effective_height, - sync_base_height, - storage_tip, - chain_state_height, - if sync_base_height > 0 { effective_height.saturating_sub(sync_base_height) } else { effective_height } - ); - - // Use the minimum of effective height and what's actually in storage - let _safe_height = if let Some(tip) = storage_tip { - let storage_based_height = tip; - if storage_based_height < effective_height { - tracing::warn!( - "Chain state height {} exceeds storage height {}, using storage height", - effective_height, - storage_based_height - ); - storage_based_height - } else { - effective_height - } - } else { - effective_height - }; // Start masternode sync (unified processing) match self.masternode_sync.start_sync(network, storage).await { diff --git a/dash-spv/src/sync/transitions.rs b/dash-spv/src/sync/transitions.rs index 505e2a541..e8ce58e93 100644 --- a/dash-spv/src/sync/transitions.rs +++ b/dash-spv/src/sync/transitions.rs @@ -177,11 +177,7 @@ impl TransitionManager { match current_phase { SyncPhase::Idle => { // Always start with headers - let start_height = storage - .get_tip_height() - .await - .map_err(|e| SyncError::Storage(format!("Failed to get tip height: {}", e)))? - .unwrap_or(0); + let start_height = storage.get_tip_height().await.unwrap_or(0); Ok(Some(SyncPhase::DownloadingHeaders { start_time: Instant::now(), @@ -199,13 +195,7 @@ impl TransitionManager { .. } => { if self.config.enable_masternodes { - let header_tip = storage - .get_tip_height() - .await - .map_err(|e| { - SyncError::Storage(format!("Failed to get header tip: {}", e)) - })? - .unwrap_or(0); + let header_tip = storage.get_tip_height().await.unwrap_or(0); let mn_height = match storage.load_masternode_state().await { Ok(Some(state)) => state.last_height, @@ -417,11 +407,7 @@ impl TransitionManager { &self, storage: &S, ) -> SyncResult> { - let header_tip = storage - .get_tip_height() - .await - .map_err(|e| SyncError::Storage(format!("Failed to get header tip: {}", e)))? - .unwrap_or(0); + let header_tip = storage.get_tip_height().await.unwrap_or(0); let filter_tip = storage .get_filter_tip_height() diff --git a/dash-spv/src/types.rs b/dash-spv/src/types.rs index ec217b421..47dff94cd 100644 --- a/dash-spv/src/types.rs +++ b/dash-spv/src/types.rs @@ -245,21 +245,8 @@ impl DetailedSyncProgress { /// /// ## Checkpoint Sync /// When syncing from a checkpoint (not genesis), `sync_base_height` is non-zero. -/// The `headers` vector contains headers starting from the checkpoint, not from genesis. -/// Use `tip_height()` to get the absolute blockchain height. -/// -/// ## Memory Considerations -/// - headers: ~80 bytes per header -/// - filter_headers: 32 bytes per filter header -/// - At 2M blocks: ~160MB for headers, ~64MB for filter headers #[derive(Clone, Default)] pub struct ChainState { - /// Block headers indexed by height. - pub headers: Vec, - - /// Filter headers indexed by height. - pub filter_headers: Vec, - /// Last ChainLock height. pub last_chainlock_height: Option, @@ -289,28 +276,6 @@ impl ChainState { pub fn new_for_network(network: Network) -> Self { let mut state = Self::default(); - // Initialize with genesis block - let genesis_header = match network { - Network::Dash => { - // Use known genesis for mainnet - dashcore::blockdata::constants::genesis_block(network).header - } - Network::Testnet => { - // Use known genesis for testnet - dashcore::blockdata::constants::genesis_block(network).header - } - _ => { - // For other networks, use the existing genesis block function - dashcore::blockdata::constants::genesis_block(network).header - } - }; - - // Add genesis header to the chain state - state.headers.push(genesis_header); - - tracing::debug!("Initialized ChainState with genesis block - network: {:?}, hash: {}, headers_count: {}", - network, genesis_header.block_hash(), state.headers.len()); - // Initialize masternode engine for the network let mut engine = MasternodeListEngine::default_for_network(network); if let Some(genesis_hash) = network.known_genesis_block_hash() { @@ -329,74 +294,6 @@ impl ChainState { self.sync_base_height > 0 } - /// Get the current tip height. - pub fn tip_height(&self) -> u32 { - if self.headers.is_empty() { - // When headers is empty, sync_base_height represents our current position - // This happens when we're syncing from a checkpoint but haven't received headers yet - self.sync_base_height - } else { - // Normal case: base + number of headers - 1 - self.sync_base_height + self.headers.len() as u32 - 1 - } - } - - /// Get the current tip hash. - pub fn tip_hash(&self) -> Option { - self.headers.last().map(|h| h.block_hash()) - } - - /// Get header at the given height. - pub fn header_at_height(&self, height: u32) -> Option<&BlockHeader> { - if height < self.sync_base_height { - return None; // Height is before our sync base - } - let index = (height - self.sync_base_height) as usize; - self.headers.get(index) - } - - /// Get filter header at the given height. - pub fn filter_header_at_height(&self, height: u32) -> Option<&FilterHeader> { - if height < self.sync_base_height { - return None; // Height is before our sync base - } - let index = (height - self.sync_base_height) as usize; - self.filter_headers.get(index) - } - - /// Add headers to the chain. - pub fn add_headers(&mut self, headers: Vec) { - self.headers.extend(headers); - } - - /// Add filter headers to the chain. - pub fn add_filter_headers(&mut self, filter_headers: Vec) { - if let Some(last) = filter_headers.last() { - self.current_filter_tip = Some(*last); - } - self.filter_headers.extend(filter_headers); - } - - /// Get the tip header - pub fn get_tip_header(&self) -> Option { - self.headers.last().copied() - } - - /// Get the height - pub fn get_height(&self) -> u32 { - self.tip_height() - } - - /// Add a single header - pub fn add_header(&mut self, header: BlockHeader) { - self.headers.push(header); - } - - /// Remove the tip header (for reorgs) - pub fn remove_tip(&mut self) -> Option { - self.headers.pop() - } - /// Update chain lock status pub fn update_chain_lock(&mut self, height: u32, hash: BlockHash) { // Only update if this is a newer chain lock @@ -429,26 +326,6 @@ impl ChainState { Some(Vec::new()) } - /// Calculate the total chain work up to the tip - pub fn calculate_chain_work(&self) -> Option { - use crate::chain::chain_work::ChainWork; - - // If we have no headers, return None - if self.headers.is_empty() { - return None; - } - - // Start with zero work - let mut total_work = ChainWork::zero(); - - // Add work from each header - for header in &self.headers { - total_work = total_work.add_header(header); - } - - Some(total_work) - } - /// Initialize chain state from a checkpoint. pub fn init_from_checkpoint( &mut self, @@ -456,16 +333,9 @@ impl ChainState { checkpoint_header: BlockHeader, network: Network, ) { - // Clear any existing headers - self.headers.clear(); - self.filter_headers.clear(); - // Set sync base height to checkpoint self.sync_base_height = checkpoint_height; - // Add the checkpoint header as our first header - self.headers.push(checkpoint_header); - tracing::info!( "Initialized ChainState from checkpoint - height: {}, hash: {}, network: {:?}", checkpoint_height, @@ -497,8 +367,6 @@ impl ChainState { impl std::fmt::Debug for ChainState { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("ChainState") - .field("headers", &format!("{} headers", self.headers.len())) - .field("filter_headers", &format!("{} filter headers", self.filter_headers.len())) .field("last_chainlock_height", &self.last_chainlock_height) .field("last_chainlock_hash", &self.last_chainlock_hash) .field("current_filter_tip", &self.current_filter_tip) diff --git a/dash-spv/tests/edge_case_filter_sync_test.rs b/dash-spv/tests/edge_case_filter_sync_test.rs index 370cf88d8..c5d4760b5 100644 --- a/dash-spv/tests/edge_case_filter_sync_test.rs +++ b/dash-spv/tests/edge_case_filter_sync_test.rs @@ -144,7 +144,7 @@ async fn test_filter_sync_at_tip_edge_case() { storage.store_filter_headers(&filter_headers).await.unwrap(); // Verify initial state - let tip_height = storage.get_tip_height().await.unwrap().unwrap(); + let tip_height = storage.get_tip_height().await.unwrap(); let filter_tip_height = storage.get_filter_tip_height().await.unwrap().unwrap(); assert_eq!(tip_height, height - 1); // 0-indexed assert_eq!(filter_tip_height, height - 1); // 0-indexed diff --git a/dash-spv/tests/filter_header_verification_test.rs b/dash-spv/tests/filter_header_verification_test.rs index 0cb6a5fa5..e8753411e 100644 --- a/dash-spv/tests/filter_header_verification_test.rs +++ b/dash-spv/tests/filter_header_verification_test.rs @@ -197,7 +197,7 @@ async fn test_filter_header_verification_failure_reproduction() { let initial_headers = create_test_headers_range(1000, 5000); // Headers 1000-4999 storage.store_headers(&initial_headers).await.expect("Failed to store initial headers"); - let tip_height = storage.get_tip_height().await.unwrap().unwrap(); + let tip_height = storage.get_tip_height().await.unwrap(); println!("Initial header chain stored: tip height = {}", tip_height); assert_eq!(tip_height, 4999); @@ -361,7 +361,7 @@ async fn test_overlapping_batches_from_different_peers() { let initial_headers = create_test_headers_range(1, 3000); // Headers 1-2999 storage.store_headers(&initial_headers).await.expect("Failed to store initial headers"); - let tip_height = storage.get_tip_height().await.unwrap().unwrap(); + let tip_height = storage.get_tip_height().await.unwrap(); println!("Header chain stored: tip height = {}", tip_height); assert_eq!(tip_height, 2999); diff --git a/dash-spv/tests/handshake_test.rs b/dash-spv/tests/handshake_test.rs index d8cb6579f..0f5125992 100644 --- a/dash-spv/tests/handshake_test.rs +++ b/dash-spv/tests/handshake_test.rs @@ -72,7 +72,8 @@ async fn test_handshake_timeout() { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn test_network_manager_creation() { - let config = ClientConfig::new(Network::Dash); + let temp_dir = tempfile::TempDir::new().expect("Failed to create temporary directory"); + let config = ClientConfig::new(Network::Dash).with_storage_path(temp_dir.path().to_path_buf()); let network = PeerNetworkManager::new(&config).await; assert!(network.is_ok(), "Network manager creation should succeed"); diff --git a/dash-spv/tests/header_sync_test.rs b/dash-spv/tests/header_sync_test.rs index 2da0fdde4..668426dd5 100644 --- a/dash-spv/tests/header_sync_test.rs +++ b/dash-spv/tests/header_sync_test.rs @@ -5,7 +5,7 @@ use std::time::Duration; use dash_spv::{ client::{ClientConfig, DashSpvClient}, network::PeerNetworkManager, - storage::{DiskStorageManager, StorageManager}, + storage::{BlockHeaderStorage, ChainStateStorage, DiskStorageManager}, sync::{HeaderSyncManager, ReorgConfig}, types::{ChainState, ValidationMode}, }; @@ -25,12 +25,12 @@ async fn test_basic_header_sync_from_genesis() { // Create fresh storage starting from empty state let mut storage = - DiskStorageManager::new(TempDir::new().expect("Failed to create tmp dir").path().into()) + DiskStorageManager::new(TempDir::new().expect("Failed to create tmp dir").path()) .await .expect("Failed to create tmp storage"); // Verify empty initial state - assert_eq!(storage.get_tip_height().await.unwrap(), None); + assert_eq!(storage.get_tip_height().await, None); // Create test chain state for mainnet let chain_state = ChainState::new_for_network(Network::Dash); @@ -48,7 +48,7 @@ async fn test_header_sync_continuation() { let _ = env_logger::try_init(); let mut storage = - DiskStorageManager::new(TempDir::new().expect("Failed to create tmp dir").path().into()) + DiskStorageManager::new(TempDir::new().expect("Failed to create tmp dir").path()) .await .expect("Failed to create tmp storage"); @@ -57,7 +57,7 @@ async fn test_header_sync_continuation() { storage.store_headers(&existing_headers).await.expect("Failed to store existing headers"); // Verify we have the expected tip - assert_eq!(storage.get_tip_height().await.unwrap(), Some(99)); + assert_eq!(storage.get_tip_height().await, Some(99)); // Simulate adding more headers (continuation) let continuation_headers = create_test_header_chain_from(100, 50); @@ -67,7 +67,7 @@ async fn test_header_sync_continuation() { .expect("Failed to store continuation headers"); // Verify the chain extended properly - assert_eq!(storage.get_tip_height().await.unwrap(), Some(149)); + assert_eq!(storage.get_tip_height().await, Some(149)); // Verify continuity by checking some headers for height in 95..105 { @@ -83,7 +83,7 @@ async fn test_header_batch_processing() { let _ = env_logger::try_init(); let mut storage = - DiskStorageManager::new(TempDir::new().expect("Failed to create tmp dir").path().into()) + DiskStorageManager::new(TempDir::new().expect("Failed to create tmp dir").path()) .await .expect("Failed to create tmp storage"); @@ -102,7 +102,7 @@ async fn test_header_batch_processing() { let expected_tip = batch_end - 1; assert_eq!( - storage.get_tip_height().await.unwrap(), + storage.get_tip_height().await, Some(expected_tip as u32), "Tip height should be {} after batch {}-{}", expected_tip, @@ -112,7 +112,7 @@ async fn test_header_batch_processing() { } // Verify total count - let final_tip = storage.get_tip_height().await.unwrap(); + let final_tip = storage.get_tip_height().await; assert_eq!(final_tip, Some((total_headers - 1) as u32)); // Verify we can retrieve headers from different parts of the chain @@ -133,24 +133,24 @@ async fn test_header_sync_edge_cases() { let _ = env_logger::try_init(); let mut storage = - DiskStorageManager::new(TempDir::new().expect("Failed to create tmp dir").path().into()) + DiskStorageManager::new(TempDir::new().expect("Failed to create tmp dir").path()) .await .expect("Failed to create tmp storage"); // Test 1: Empty header batch let empty_headers: Vec = vec![]; storage.store_headers(&empty_headers).await.expect("Should handle empty header batch"); - assert_eq!(storage.get_tip_height().await.unwrap(), None); + assert_eq!(storage.get_tip_height().await, None); // Test 2: Single header let single_header = create_test_header_chain(1); storage.store_headers(&single_header).await.expect("Should handle single header"); - assert_eq!(storage.get_tip_height().await.unwrap(), Some(0)); + assert_eq!(storage.get_tip_height().await, Some(0)); // Test 3: Large batch let large_batch = create_test_header_chain_from(1, 5000); storage.store_headers(&large_batch).await.expect("Should handle large header batch"); - assert_eq!(storage.get_tip_height().await.unwrap(), Some(5000)); + assert_eq!(storage.get_tip_height().await, Some(5000)); // Test 4: Out-of-order access let header_4500 = storage.get_header(4500).await.unwrap(); @@ -171,7 +171,7 @@ async fn test_header_chain_validation() { let _ = env_logger::try_init(); let mut storage = - DiskStorageManager::new(TempDir::new().expect("Failed to create tmp dir").path().into()) + DiskStorageManager::new(TempDir::new().expect("Failed to create tmp dir").path()) .await .expect("Failed to create tmp storage"); @@ -191,7 +191,7 @@ async fn test_header_chain_validation() { storage.store_headers(&chain).await.expect("Failed to store header chain"); // Verify the chain is stored correctly - assert_eq!(storage.get_tip_height().await.unwrap(), Some(9)); + assert_eq!(storage.get_tip_height().await, Some(9)); // Verify we can retrieve the entire chain let retrieved_chain = storage.load_headers(0..10).await.unwrap(); @@ -209,7 +209,7 @@ async fn test_header_sync_performance() { let _ = env_logger::try_init(); let mut storage = - DiskStorageManager::new(TempDir::new().expect("Failed to create tmp dir").path().into()) + DiskStorageManager::new(TempDir::new().expect("Failed to create tmp dir").path()) .await .expect("Failed to create tmp storage"); @@ -229,7 +229,7 @@ async fn test_header_sync_performance() { let sync_duration = start_time.elapsed(); // Verify sync completed correctly - assert_eq!(storage.get_tip_height().await.unwrap(), Some((total_headers - 1) as u32)); + assert_eq!(storage.get_tip_height().await, Some((total_headers - 1) as u32)); // Performance assertions (these are rough benchmarks) assert!( @@ -261,9 +261,11 @@ async fn test_header_sync_performance() { #[tokio::test] async fn test_header_sync_with_client_integration() { let _ = env_logger::try_init(); + let temp_dir = tempfile::TempDir::new().expect("Failed to create temporary directory"); // Test header sync integration with the full client let config = ClientConfig::new(Network::Dash) + .with_storage_path(temp_dir.path().to_path_buf()) .with_validation_mode(ValidationMode::Basic) .with_connection_timeout(Duration::from_secs(10)); @@ -273,7 +275,7 @@ async fn test_header_sync_with_client_integration() { // Create storage manager let storage_manager = - DiskStorageManager::new(TempDir::new().expect("Failed to create tmp dir").path().into()) + DiskStorageManager::new(TempDir::new().expect("Failed to create tmp dir").path()) .await .expect("Failed to create tmp storage"); @@ -329,7 +331,7 @@ async fn test_header_storage_consistency() { let _ = env_logger::try_init(); let mut storage = - DiskStorageManager::new(TempDir::new().expect("Failed to create tmp dir").path().into()) + DiskStorageManager::new(TempDir::new().expect("Failed to create tmp dir").path()) .await .expect("Failed to create tmp storage"); @@ -338,7 +340,7 @@ async fn test_header_storage_consistency() { storage.store_headers(&headers).await.expect("Failed to store headers"); // Test consistency: get tip and verify it matches the last stored header - let tip_height = storage.get_tip_height().await.unwrap().unwrap(); + let tip_height = storage.get_tip_height().await.unwrap(); let tip_header = storage.get_header(tip_height).await.unwrap().unwrap(); let expected_tip = &headers[headers.len() - 1]; @@ -358,48 +360,6 @@ async fn test_header_storage_consistency() { info!("Header storage consistency test completed"); } -#[test_case(0, 0 ; "genesis_0_blocks")] -#[test_case(0, 1 ; "genesis_1_block")] -#[test_case(0, 60000 ; "genesis_60000_blocks")] -#[test_case(100, 0 ; "checkpoint_0_blocks")] -#[test_case(170000, 1 ; "checkpoint_1_block")] -#[test_case(12345, 60000 ; "checkpoint_60000_blocks")] -#[tokio::test] -async fn test_load_headers_from_storage(sync_base_height: u32, header_count: usize) { - // Setup: Create storage with 100 headers - let temp_dir = TempDir::new().expect("Failed to create temp dir"); - let mut storage = DiskStorageManager::new(temp_dir.path().to_path_buf()) - .await - .expect("Failed to create storage"); - - let test_headers = create_test_header_chain(header_count); - - // Store chain state - let mut chain_state = ChainState::new_for_network(Network::Dash); - chain_state.sync_base_height = sync_base_height; - chain_state.headers = test_headers.clone(); - storage.store_chain_state(&chain_state).await.expect("Failed to store chain state"); - - // Create HeaderSyncManager and load headers - let config = ClientConfig::new(Network::Dash); - let chain_state = Arc::new(RwLock::new(ChainState::new_for_network(Network::Dash))); - let mut header_sync = HeaderSyncManager::::new( - &config, - ReorgConfig::default(), - chain_state.clone(), - ) - .expect("Failed to create HeaderSyncManager"); - - // Load headers from storage - let loaded_count = - header_sync.load_headers_from_storage(&storage).await.expect("Failed to load headers"); - - let cs = chain_state.read().await; - - assert_eq!(loaded_count as usize, header_count, "Loaded count mismatch"); - assert_eq!(header_count, cs.headers.len(), "Chain state count mismatch"); -} - #[test_case(0, 1 ; "genesis_1_block")] #[test_case(0, 70000 ; "genesis_70000_blocks")] #[test_case(5000, 1 ; "checkpoint_1_block")] @@ -407,9 +367,8 @@ async fn test_load_headers_from_storage(sync_base_height: u32, header_count: usi #[tokio::test] async fn test_prepare_sync(sync_base_height: u32, header_count: usize) { let temp_dir = TempDir::new().expect("Failed to create temp dir"); - let mut storage = DiskStorageManager::new(temp_dir.path().to_path_buf()) - .await - .expect("Failed to create storage"); + let mut storage = + DiskStorageManager::new(temp_dir.path()).await.expect("Failed to create storage"); let headers = create_test_header_chain(header_count); let expected_tip_hash = headers.last().unwrap().block_hash(); @@ -417,8 +376,8 @@ async fn test_prepare_sync(sync_base_height: u32, header_count: usize) { // Create and store chain state let mut chain_state = ChainState::new_for_network(Network::Dash); chain_state.sync_base_height = sync_base_height; - chain_state.headers = headers; storage.store_chain_state(&chain_state).await.expect("Failed to store chain state"); + storage.store_headers(&headers).await.expect("Failed to store headers"); // Create HeaderSyncManager and load from storage let config = ClientConfig::new(Network::Dash); diff --git a/dash-spv/tests/integration_real_node_test.rs b/dash-spv/tests/integration_real_node_test.rs index 8979da6f6..e155a16f9 100644 --- a/dash-spv/tests/integration_real_node_test.rs +++ b/dash-spv/tests/integration_real_node_test.rs @@ -6,10 +6,11 @@ use std::net::SocketAddr; use std::time::{Duration, Instant}; +use dash_spv::storage::BlockHeaderStorage; use dash_spv::{ client::{ClientConfig, DashSpvClient}, network::{NetworkManager, PeerNetworkManager}, - storage::{DiskStorageManager, StorageManager}, + storage::DiskStorageManager, types::ValidationMode, }; use dashcore::Network; @@ -36,8 +37,7 @@ async fn create_test_client( // Create storage manager let storage_manager = - DiskStorageManager::new(TempDir::new().expect("Failed to create tmp dir").path().into()) - .await?; + DiskStorageManager::new(TempDir::new().expect("Failed to create tmp dir").path()).await?; // Create wallet manager let wallet = Arc::new(RwLock::new(WalletManager::::new())); @@ -200,13 +200,12 @@ async fn test_real_header_sync_up_to_10k() { config.peers.push(peer_addr); // Create fresh storage and client - let storage = - DiskStorageManager::new(TempDir::new().expect("Failed to create tmp dir").path().into()) - .await - .expect("Failed to create tmp storage"); + let storage = DiskStorageManager::new(TempDir::new().expect("Failed to create tmp dir").path()) + .await + .expect("Failed to create tmp storage"); // Verify starting from empty state - assert_eq!(storage.get_tip_height().await.unwrap(), None); + assert_eq!(storage.get_tip_height().await, None); let mut client = create_test_client(config.clone()).await.expect("Failed to create SPV client"); @@ -414,10 +413,9 @@ async fn test_real_header_chain_continuity() { config.peers.push(peer_addr); - let storage = - DiskStorageManager::new(TempDir::new().expect("Failed to create tmp dir").path().into()) - .await - .expect("Failed to create tmp storage"); + let storage = DiskStorageManager::new(TempDir::new().expect("Failed to create tmp dir").path()) + .await + .expect("Failed to create tmp storage"); let mut client = create_test_client(config).await.expect("Failed to create SPV client"); diff --git a/dash-spv/tests/peer_test.rs b/dash-spv/tests/peer_test.rs index 0ee6926ea..54da1d554 100644 --- a/dash-spv/tests/peer_test.rs +++ b/dash-spv/tests/peer_test.rs @@ -139,40 +139,6 @@ async fn test_peer_persistence() { } } -#[tokio::test] -async fn test_peer_disconnection() { - let _ = env_logger::builder().is_test(true).try_init(); - - let temp_dir = TempDir::new().unwrap(); - let temp_path = temp_dir.path().to_path_buf(); - let mut config = create_test_config(Network::Regtest, Some(temp_dir)); - - // Add manual test peers (would need actual regtest nodes running) - config.peers = vec!["127.0.0.1:19899".parse().unwrap(), "127.0.0.1:19898".parse().unwrap()]; - - // Create network manager - let network_manager = PeerNetworkManager::new(&config).await.unwrap(); - - // Create storage manager - let storage_manager = DiskStorageManager::new(temp_path).await.unwrap(); - - // Create wallet manager - let wallet = Arc::new(RwLock::new(WalletManager::::new())); - - let client = - DashSpvClient::new(config, network_manager, storage_manager, wallet).await.unwrap(); - - // Note: This test would require actual regtest nodes running - // For now, we just test that the API works - let test_addr: SocketAddr = "127.0.0.1:19899".parse().unwrap(); - - // Try to disconnect (will fail if not connected, but tests the API) - match client.disconnect_peer(&test_addr, "Test disconnection").await { - Ok(_) => println!("Disconnected peer {}", test_addr), - Err(e) => println!("Expected error disconnecting non-existent peer: {}", e), - } -} - #[tokio::test] async fn test_max_peer_limit() { use dash_spv::network::constants::MAX_PEERS; @@ -190,7 +156,7 @@ async fn test_max_peer_limit() { // Create storage manager let storage_manager = - DiskStorageManager::new(TempDir::new().expect("Failed to create tmp dir").path().into()) + DiskStorageManager::new(TempDir::new().expect("Failed to create tmp dir").path()) .await .expect("Failed to create tmp storage"); diff --git a/dash-spv/tests/reverse_index_test.rs b/dash-spv/tests/reverse_index_test.rs index 31bbb847f..e09d3097e 100644 --- a/dash-spv/tests/reverse_index_test.rs +++ b/dash-spv/tests/reverse_index_test.rs @@ -1,4 +1,4 @@ -use dash_spv::storage::{DiskStorageManager, StorageManager}; +use dash_spv::storage::{BlockHeaderStorage, DiskStorageManager, StorageManager}; use dashcore::block::Header as BlockHeader; use dashcore_hashes::Hash; use std::path::PathBuf; @@ -28,9 +28,6 @@ async fn test_reverse_index_disk_storage() { assert_eq!(height, Some(i as u32), "Height mismatch for header {}", i); } - // Add a small delay to ensure background worker processes save commands - tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; - storage.shutdown().await; } @@ -52,7 +49,7 @@ async fn test_reverse_index_disk_storage() { #[tokio::test] async fn test_clear_clears_index() { let mut storage = - DiskStorageManager::new(TempDir::new().expect("Failed to create tmp dir").path().into()) + DiskStorageManager::new(TempDir::new().expect("Failed to create tmp dir").path()) .await .expect("Failed to create tmp storage"); diff --git a/dash-spv/tests/rollback_test.rs b/dash-spv/tests/rollback_test.rs index d2424f972..7634648c6 100644 --- a/dash-spv/tests/rollback_test.rs +++ b/dash-spv/tests/rollback_test.rs @@ -42,7 +42,7 @@ async fn test_disk_storage_rollback() -> Result<(), Box> storage.store_headers(&headers).await?; // Verify we have 10 headers - let tip_height = storage.get_tip_height().await?; + let tip_height = storage.get_tip_height().await; assert_eq!(tip_height, Some(9)); // Load all headers to verify @@ -54,7 +54,7 @@ async fn test_disk_storage_rollback() -> Result<(), Box> // TODO: Test assertions commented out because rollback_to_height is not implemented // Verify tip height is now 5 - let _ = storage.get_tip_height().await?; + let _ = storage.get_tip_height().await; // assert_eq!(tip_height_after_rollback, Some(5)); // Verify we can only load headers up to height 5 diff --git a/dash-spv/tests/segmented_storage_debug.rs b/dash-spv/tests/segmented_storage_debug.rs index 611a5eaa0..1b10dd97d 100644 --- a/dash-spv/tests/segmented_storage_debug.rs +++ b/dash-spv/tests/segmented_storage_debug.rs @@ -1,6 +1,6 @@ //! Debug test for segmented storage. -use dash_spv::storage::{DiskStorageManager, StorageManager}; +use dash_spv::storage::{BlockHeaderStorage, DiskStorageManager, StorageManager}; use dashcore::block::{Header as BlockHeader, Version}; use dashcore::pow::CompactTarget; use dashcore::BlockHash; @@ -38,7 +38,7 @@ async fn test_basic_storage() { println!("Headers stored"); // Check tip height - let tip = storage.get_tip_height().await.unwrap(); + let tip = storage.get_tip_height().await; println!("Tip height: {:?}", tip); assert_eq!(tip, Some(9)); diff --git a/dash-spv/tests/segmented_storage_test.rs b/dash-spv/tests/segmented_storage_test.rs index 9b8995024..76d5e65f4 100644 --- a/dash-spv/tests/segmented_storage_test.rs +++ b/dash-spv/tests/segmented_storage_test.rs @@ -1,6 +1,9 @@ //! Tests for segmented disk storage implementation. -use dash_spv::storage::{DiskStorageManager, StorageManager}; +use dash_spv::storage::{ + BlockHeaderStorage, DiskStorageManager, FilterHeaderStorage, FilterStorage, MetadataStorage, + StorageManager, +}; use dashcore::block::{Header as BlockHeader, Version}; use dashcore::hash_types::FilterHeader; use dashcore::pow::CompactTarget; @@ -46,7 +49,7 @@ async fn test_segmented_storage_basic_operations() { } // Verify we can read them back - assert_eq!(storage.get_tip_height().await.unwrap(), Some(99_999)); + assert_eq!(storage.get_tip_height().await, Some(99_999)); // Check individual headers assert_eq!(storage.get_header(0).await.unwrap().unwrap().time, 0); @@ -76,7 +79,7 @@ async fn test_segmented_storage_persistence() { let mut storage = DiskStorageManager::new(path.clone()).await.unwrap(); // Verify storage starts empty - assert_eq!(storage.get_tip_height().await.unwrap(), None, "Storage should start empty"); + assert_eq!(storage.get_tip_height().await, None, "Storage should start empty"); let headers: Vec = (0..75_000).map(create_test_header).collect(); storage.store_headers(&headers).await.unwrap(); @@ -91,7 +94,7 @@ async fn test_segmented_storage_persistence() { { let storage = DiskStorageManager::new(path).await.unwrap(); - let actual_tip = storage.get_tip_height().await.unwrap(); + let actual_tip = storage.get_tip_height().await; if actual_tip != Some(74_999) { println!("Expected tip 74,999 but got {:?}", actual_tip); // Try to understand what's stored @@ -265,7 +268,7 @@ async fn test_background_save_timing() { // Verify data was saved { let storage = DiskStorageManager::new(path).await.unwrap(); - assert_eq!(storage.get_tip_height().await.unwrap(), Some(19_999)); + assert_eq!(storage.get_tip_height().await, Some(19_999)); assert_eq!(storage.get_header(15_000).await.unwrap().unwrap().time, 15_000); } } @@ -279,13 +282,13 @@ async fn test_clear_storage() { let headers: Vec = (0..10_000).map(create_test_header).collect(); storage.store_headers(&headers).await.unwrap(); - assert_eq!(storage.get_tip_height().await.unwrap(), Some(9_999)); + assert_eq!(storage.get_tip_height().await, Some(9_999)); // Clear storage storage.clear().await.unwrap(); // Verify everything is cleared - assert_eq!(storage.get_tip_height().await.unwrap(), None); + assert_eq!(storage.get_tip_height().await, None); assert_eq!(storage.get_header_height_by_hash(&headers[0].block_hash()).await.unwrap(), None); } @@ -311,7 +314,7 @@ async fn test_mixed_operations() { storage.store_metadata("test_key", b"test_value").await.unwrap(); // Verify everything - assert_eq!(storage.get_tip_height().await.unwrap(), Some(74_999)); + assert_eq!(storage.get_tip_height().await, Some(74_999)); assert_eq!(storage.get_filter_tip_height().await.unwrap(), Some(74_999)); let filters = storage.load_filters(1000..1001).await.unwrap(); diff --git a/dash-spv/tests/simple_header_test.rs b/dash-spv/tests/simple_header_test.rs index 40d0ce791..a21457188 100644 --- a/dash-spv/tests/simple_header_test.rs +++ b/dash-spv/tests/simple_header_test.rs @@ -3,7 +3,7 @@ use dash_spv::{ client::{ClientConfig, DashSpvClient}, network::PeerNetworkManager, - storage::{DiskStorageManager, StorageManager}, + storage::{BlockHeaderStorage, DiskStorageManager}, types::ValidationMode, }; use dashcore::Network; @@ -51,13 +51,12 @@ async fn test_simple_header_sync() { config.peers.push(peer_addr); // Create fresh storage - let storage = - DiskStorageManager::new(TempDir::new().expect("Failed to create tmp dir").path().into()) - .await - .expect("Failed to create tmp storage"); + let storage = DiskStorageManager::new(TempDir::new().expect("Failed to create tmp dir").path()) + .await + .expect("Failed to create tmp storage"); // Verify starting from empty state - assert_eq!(storage.get_tip_height().await.unwrap(), None); + assert_eq!(storage.get_tip_height().await, None); // Create network manager let network_manager = diff --git a/dash-spv/tests/simple_segmented_test.rs b/dash-spv/tests/simple_segmented_test.rs index 422bb78ed..9cea06a35 100644 --- a/dash-spv/tests/simple_segmented_test.rs +++ b/dash-spv/tests/simple_segmented_test.rs @@ -1,6 +1,6 @@ //! Simple test without background saving. -use dash_spv::storage::{DiskStorageManager, StorageManager}; +use dash_spv::storage::{BlockHeaderStorage, DiskStorageManager}; use dashcore::block::{Header as BlockHeader, Version}; use dashcore::pow::CompactTarget; use dashcore::BlockHash; @@ -28,7 +28,7 @@ async fn test_simple_storage() { let mut storage = DiskStorageManager::new(temp_dir.path().to_path_buf()).await.unwrap(); println!("Testing get_tip_height before storing anything..."); - let initial_tip = storage.get_tip_height().await.unwrap(); + let initial_tip = storage.get_tip_height().await; println!("Initial tip: {:?}", initial_tip); assert_eq!(initial_tip, None); @@ -40,7 +40,7 @@ async fn test_simple_storage() { println!("Single header stored"); println!("Checking tip height..."); - let tip = storage.get_tip_height().await.unwrap(); + let tip = storage.get_tip_height().await; println!("Tip height after storing one header: {:?}", tip); assert_eq!(tip, Some(0)); diff --git a/dash-spv/tests/storage_consistency_test.rs b/dash-spv/tests/storage_consistency_test.rs index 8bdd682b7..cdd166442 100644 --- a/dash-spv/tests/storage_consistency_test.rs +++ b/dash-spv/tests/storage_consistency_test.rs @@ -3,7 +3,7 @@ //! These tests are designed to expose the storage bug where get_tip_height() //! returns a value but get_header() at that height returns None. -use dash_spv::storage::{DiskStorageManager, StorageManager}; +use dash_spv::storage::{BlockHeaderStorage, DiskStorageManager, StorageManager}; use dashcore::block::{Header as BlockHeader, Version}; use dashcore::pow::CompactTarget; use dashcore::BlockHash; @@ -36,7 +36,7 @@ async fn test_tip_height_header_consistency_basic() { storage.store_headers(&headers).await.unwrap(); // Check consistency immediately - let tip_height = storage.get_tip_height().await.unwrap(); + let tip_height = storage.get_tip_height().await; println!("Tip height: {:?}", tip_height); if let Some(height) = tip_height { @@ -72,7 +72,7 @@ async fn test_tip_height_header_consistency_after_save() { // Wait for background save to complete sleep(Duration::from_secs(1)).await; - let tip_height = storage.get_tip_height().await.unwrap(); + let tip_height = storage.get_tip_height().await; println!("Phase 1 - Tip height: {:?}", tip_height); if let Some(height) = tip_height { @@ -87,7 +87,7 @@ async fn test_tip_height_header_consistency_after_save() { { let storage = DiskStorageManager::new(storage_path.clone()).await.unwrap(); - let tip_height = storage.get_tip_height().await.unwrap(); + let tip_height = storage.get_tip_height().await; println!("Phase 2 - Tip height after reload: {:?}", tip_height); if let Some(height) = tip_height { @@ -129,7 +129,7 @@ async fn test_tip_height_header_consistency_large_dataset() { storage.store_headers(&headers).await.unwrap(); // Check consistency after each batch - let tip_height = storage.get_tip_height().await.unwrap(); + let tip_height = storage.get_tip_height().await; if let Some(height) = tip_height { let header = storage.get_header(height).await.unwrap(); if header.is_none() { @@ -155,7 +155,7 @@ async fn test_tip_height_header_consistency_large_dataset() { } // Final consistency check - let final_tip = storage.get_tip_height().await.unwrap(); + let final_tip = storage.get_tip_height().await; println!("Final tip height: {:?}", final_tip); if let Some(height) = final_tip { @@ -206,7 +206,7 @@ async fn test_concurrent_tip_header_access() { let handle = tokio::spawn(async move { // Repeatedly check consistency for iteration in 0..100 { - let tip_height = storage.get_tip_height().await.unwrap(); + let tip_height = storage.get_tip_height().await; if let Some(height) = tip_height { let header = storage.get_header(height).await.unwrap(); @@ -278,7 +278,7 @@ async fn test_reproduce_filter_sync_bug() { storage.store_headers(&tip_header).await.unwrap(); // Now check what get_tip_height() returns - let reported_tip = storage.get_tip_height().await.unwrap(); + let reported_tip = storage.get_tip_height().await; println!("Storage reports tip height: {:?}", reported_tip); if let Some(tip_height) = reported_tip { @@ -346,7 +346,7 @@ async fn test_reproduce_filter_sync_bug_small() { storage.store_headers(&tip_header).await.unwrap(); // Now check what get_tip_height() returns - let reported_tip = storage.get_tip_height().await.unwrap(); + let reported_tip = storage.get_tip_height().await; println!("Storage reports tip height: {:?}", reported_tip); if let Some(tip_height) = reported_tip { @@ -406,7 +406,7 @@ async fn test_segment_boundary_consistency() { segment_size + 1, // Second in second segment ]; - let tip_height = storage.get_tip_height().await.unwrap().unwrap(); + let tip_height = storage.get_tip_height().await.unwrap(); println!("Tip height: {}", tip_height); for height in boundary_heights { @@ -461,7 +461,7 @@ async fn test_reproduce_tip_height_segment_eviction_race() { storage.store_headers(&headers).await.unwrap(); // Immediately check for race condition - let tip_height = storage.get_tip_height().await.unwrap(); + let tip_height = storage.get_tip_height().await; if let Some(height) = tip_height { // Try to access the tip header multiple times to catch race condition @@ -542,7 +542,7 @@ async fn test_concurrent_tip_height_access_with_eviction() { // Reduced from 50 to 20 iterations for iteration in 0..20 { // Get tip height - let tip_height = storage.get_tip_height().await.unwrap(); + let tip_height = storage.get_tip_height().await; if let Some(height) = tip_height { // Immediately try to access the tip header @@ -606,7 +606,7 @@ async fn test_concurrent_tip_height_access_with_eviction_heavy() { let handle = tokio::spawn(async move { for iteration in 0..50 { // Get tip height - let tip_height = storage.get_tip_height().await.unwrap(); + let tip_height = storage.get_tip_height().await; if let Some(height) = tip_height { // Immediately try to access the tip header @@ -659,7 +659,7 @@ async fn test_tip_height_segment_boundary_race() { storage.store_headers(&headers).await.unwrap(); // Verify tip is at segment boundary - let tip_height = storage.get_tip_height().await.unwrap(); + let tip_height = storage.get_tip_height().await; assert_eq!(tip_height, Some(segment_size - 1)); storage.shutdown().await; @@ -678,7 +678,7 @@ async fn test_tip_height_segment_boundary_race() { storage.store_headers(&headers).await.unwrap(); // After storing each segment, verify tip consistency - let reported_tip = storage.get_tip_height().await.unwrap(); + let reported_tip = storage.get_tip_height().await; if let Some(tip) = reported_tip { let header = storage.get_header(tip).await.unwrap(); if header.is_none() { @@ -698,7 +698,7 @@ async fn test_tip_height_segment_boundary_race() { } // But the current tip should always be accessible - let current_tip = storage.get_tip_height().await.unwrap(); + let current_tip = storage.get_tip_height().await; if let Some(tip) = current_tip { let header = storage.get_header(tip).await.unwrap(); assert!(header.is_some(), "Current tip header must always be accessible"); diff --git a/dash-spv/tests/storage_test.rs b/dash-spv/tests/storage_test.rs index 254a5162e..79833d09b 100644 --- a/dash-spv/tests/storage_test.rs +++ b/dash-spv/tests/storage_test.rs @@ -1,7 +1,7 @@ //! Integration tests for storage layer functionality. use dash_spv::error::StorageError; -use dash_spv::storage::{DiskStorageManager, StorageManager}; +use dash_spv::storage::{BlockHeaderStorage, DiskStorageManager, StorageManager}; use dashcore::{block::Header as BlockHeader, block::Version}; use dashcore_hashes::Hash; use tempfile::TempDir; @@ -57,7 +57,7 @@ async fn test_disk_storage_reopen_after_clean_shutdown() { assert!(storage.is_ok(), "Should reopen after clean shutdown"); let storage = storage.unwrap(); - let tip = storage.get_tip_height().await.unwrap(); + let tip = storage.get_tip_height().await; assert_eq!(tip, Some(4), "Data should persist across reopen"); } @@ -80,21 +80,25 @@ async fn test_disk_storage_concurrent_access_blocked() { } other => panic!("Expected DirectoryLocked error, got: {:?}", other), } - - // First storage manager should still be usable - assert!(_storage1.get_tip_height().await.is_ok()); } #[tokio::test] async fn test_disk_storage_lock_file_lifecycle() { let temp_dir = TempDir::new().expect("Failed to create temp directory"); let path = temp_dir.path().to_path_buf(); - let lock_path = path.join(".lock"); + let lock_path = { + let mut lock_file = path.clone(); + lock_file.set_extension("lock"); + lock_file + }; // Lock file created when storage opens { - let _storage = DiskStorageManager::new(path.clone()).await.unwrap(); + let mut storage = DiskStorageManager::new(path.clone()).await.unwrap(); assert!(lock_path.exists(), "Lock file should exist while storage is open"); + + storage.clear().await.expect("Failed to clear the storage"); + assert!(lock_path.exists(), "Lock file should exist after storage is cleared"); } // Lock file removed when storage drops diff --git a/dash-spv/tests/wallet_integration_test.rs b/dash-spv/tests/wallet_integration_test.rs index 3b00bcd36..d61f0f2cb 100644 --- a/dash-spv/tests/wallet_integration_test.rs +++ b/dash-spv/tests/wallet_integration_test.rs @@ -15,14 +15,18 @@ use key_wallet_manager::wallet_manager::WalletManager; /// Create a test SPV client with memory storage for integration testing. async fn create_test_client( ) -> DashSpvClient, PeerNetworkManager, DiskStorageManager> { - let config = ClientConfig::testnet().without_filters().without_masternodes(); + let temp_dir = tempfile::TempDir::new().expect("Failed to create temporary directory"); + let config = ClientConfig::testnet() + .without_filters() + .without_masternodes() + .with_storage_path(temp_dir.path().to_path_buf()); // Create network manager let network_manager = PeerNetworkManager::new(&config).await.unwrap(); // Create storage manager let storage_manager = - DiskStorageManager::new(TempDir::new().expect("Failed to create tmp dir").path().into()) + DiskStorageManager::new(TempDir::new().expect("Failed to create tmp dir").path()) .await .expect("Failed to create tmp storage");