diff --git a/network/src/dht/peer_resolver.rs b/network/src/dht/peer_resolver.rs index a7ff254017..9477d71e47 100644 --- a/network/src/dht/peer_resolver.rs +++ b/network/src/dht/peer_resolver.rs @@ -1,4 +1,5 @@ use std::mem::ManuallyDrop; +use std::pin::pin; use std::sync::atomic::{AtomicU32, Ordering}; use std::sync::{Arc, Mutex, Weak}; use std::time::Duration; @@ -242,6 +243,8 @@ impl PeerResolverInner { data: &PeerResolverHandleData, prev_timings: &Option, ) -> Option<(Network, Arc)> { + use futures_util::future::Either; + struct Iter<'a> { backoff: Option>, data: &'a PeerResolverHandleData, @@ -290,6 +293,7 @@ impl PeerResolverInner { let is_stale = attempts > self.config.fast_retry_count as usize; // NOTE: Acquire network ref only during the operation. + let mut new_peer_added = self.dht_service.peer_added().notified(); { let network = self.weak_network.upgrade()?; if let Some(peer_info) = network.known_peers().get(&data.peer_id) @@ -345,7 +349,30 @@ impl PeerResolverInner { } let interval = iter.next().expect("retries iterator must be infinite"); - tokio::time::sleep(interval).await; + let mut sleep = pin!(tokio::time::sleep(interval)); + 'inner: loop { + match futures_util::future::select(&mut sleep, pin!(new_peer_added)).await { + // Backoff interval elapsed. + Either::Left(_) => break 'inner, + // A new peer has been discovered. + Either::Right(_) => { + new_peer_added = self.dht_service.peer_added().notified(); + + let network = self.weak_network.upgrade()?; + if let Some(peer_info) = network.known_peers().get(&data.peer_id) + && PeerResolverTimings::is_new_info(prev_timings, &peer_info) + { + tracing::trace!( + peer_id = %data.peer_id, + attempts, + is_stale, + "peer info exists", + ); + return Some((network, peer_info)); + } + } + } + } } } diff --git a/network/src/dht/query.rs b/network/src/dht/query.rs index d34354a7b0..71e2bd3cee 100644 --- a/network/src/dht/query.rs +++ b/network/src/dht/query.rs @@ -14,8 +14,8 @@ use tycho_util::time::now_sec; use tycho_util::{FastDashMap, FastHashMap, FastHashSet}; use crate::dht::config::DhtConfig; -use crate::dht::routing::{HandlesRoutingTable, SimpleRoutingTable}; -use crate::network::Network; +use crate::dht::routing::HandlesRoutingTable; +use crate::network::{KnownPeerHandle, KnownPeers, KnownPeersError, Network}; use crate::proto::dht::{NodeResponse, Value, ValueRef, ValueResponse, rpc}; use crate::types::{PeerId, PeerInfo, Request}; use crate::util::NetworkExt; @@ -119,7 +119,7 @@ pub enum DhtQueryMode { pub struct Query { network: Network, - candidates: SimpleRoutingTable, + candidates: HandlesRoutingTable, max_k: usize, timeout: Duration, } @@ -132,7 +132,7 @@ impl Query { config: &DhtConfig, mode: DhtQueryMode, ) -> Self { - let mut candidates = SimpleRoutingTable::new(PeerId(*target_id)); + let mut candidates = HandlesRoutingTable::new(PeerId(*target_id)); let random_id; let target_id_for_full = match mode { @@ -146,8 +146,10 @@ impl Query { let max_k = config.max_k; let timeout = config.request_timeout; - routing_table.visit_closest(target_id_for_full, max_k, |node| { - candidates.add(node.load_peer_info(), max_k, &Duration::MAX, Some); + routing_table.visit_closest(target_id_for_full, max_k, |handle| { + candidates.add(handle.load_peer_info(), max_k, &Duration::MAX, |_| { + Some(handle.clone()) + }); }); Self { @@ -171,29 +173,36 @@ impl Query { })); // Prepare request to initial candidates + let mut scheduled = FastHashSet::new(); let semaphore = Semaphore::new(MAX_PARALLEL_REQUESTS); let mut futures = FuturesUnordered::new(); + + let visit = |this: &Query, handle: &KnownPeerHandle| { + Self::visit::( + this.network.clone(), + handle.clone(), + request_body.clone(), + &semaphore, + this.timeout, + ) + }; + self.candidates - .visit_closest(self.local_id(), self.max_k, |node| { - futures.push(Self::visit::( - self.network.clone(), - node.clone(), - request_body.clone(), - &semaphore, - self.timeout, - )); + .visit_closest(self.local_id(), self.max_k, |handle| { + if scheduled.insert(handle.peer_info().id) { + futures.push(visit(&self, handle)); + } }); // Process responses and refill futures until the value is found or all peers are traversed - let mut visited = FastHashSet::new(); - while let Some((node, res)) = futures.next().await { + while let Some((handle, res)) = futures.next().await { match res { // Return the value if found Some(Ok(ValueResponse::Found(value))) => { let mut signature_checked = false; let is_valid = value.verify_ext(now_sec(), self.local_id(), &mut signature_checked); - tracing::debug!(peer_id = %node.id, is_valid, "found value"); + tracing::debug!(peer_id = %handle.peer_info().id, is_valid, "found value"); yield_on_complex(signature_checked).await; @@ -207,39 +216,42 @@ impl Query { // Refill futures from the nodes response Some(Ok(ValueResponse::NotFound(nodes))) => { let node_count = nodes.len(); - let has_new = self - .update_candidates(now_sec(), self.max_k, nodes, &mut visited) - .await; - tracing::debug!(peer_id = %node.id, count = node_count, has_new, "received nodes"); - - if !has_new { - // Do nothing if candidates were not changed - continue; - } + let known_peers = self.network.known_peers().clone(); + + // Update candidates. + let mut has_new = false; + process_only_valid(now_sec(), nodes, |peer_info| { + has_new |= self.candidates.add( + peer_info, + self.max_k, + &Duration::MAX, + |peer_info| Self::retain_candidate(&known_peers, peer_info), + ); + }) + .await; + + tracing::debug!( + peer_id = %handle.peer_info().id, + count = node_count, + has_new, + "received nodes", + ); // Add new nodes from the closest range self.candidates - .visit_closest(self.local_id(), self.max_k, |node| { - if visited.contains(&node.id) { - // Skip already visited nodes - return; + .visit_closest(self.local_id(), self.max_k, |handle| { + if scheduled.insert(handle.peer_info().id) { + futures.push(visit(&self, handle)); } - futures.push(Self::visit::( - self.network.clone(), - node.clone(), - request_body.clone(), - &semaphore, - self.timeout, - )); }); } // Do nothing on error Some(Err(e)) => { - tracing::warn!(peer_id = %node.id, "failed to query nodes: {e}"); + tracing::warn!(peer_id = %handle.peer_info().id, "failed to query nodes: {e}"); } // Do nothing on timeout None => { - tracing::warn!(peer_id = %node.id, "failed to query nodes: timeout"); + tracing::warn!(peer_id = %handle.peer_info().id, "failed to query nodes: timeout"); } } } @@ -257,65 +269,90 @@ impl Query { })); // Prepare request to initial candidates + let mut scheduled = FastHashSet::new(); + let mut candidate_depths = FastHashMap::new(); let semaphore = Semaphore::new(MAX_PARALLEL_REQUESTS); let mut futures = FuturesUnordered::new(); + + let visit = |this: &Query, node: &KnownPeerHandle, depth: usize| { + use futures_util::FutureExt; + + Self::visit::( + this.network.clone(), + node.clone(), + request_body.clone(), + &semaphore, + this.timeout, + ) + .map(move |res| (res, depth)) + }; + self.candidates .visit_closest(self.local_id(), self.max_k, |node| { - futures.push(Self::visit::( - self.network.clone(), - node.clone(), - request_body.clone(), - &semaphore, - self.timeout, - )); + let peer_id = node.peer_info().id; + if scheduled.insert(peer_id) { + candidate_depths.insert(peer_id, 0); + futures.push(visit(&self, node, 0)); + } }); // Process responses and refill futures until all peers are traversed - let mut current_depth = 0; let max_depth = depth.unwrap_or(usize::MAX); let mut result = FastHashMap::>::new(); - while let Some((node, res)) = futures.next().await { + while let Some(((node, res), query_depth)) = futures.next().await { match res { // Refill futures from the nodes response Some(Ok(NodeResponse { nodes })) => { - tracing::debug!(peer_id = %node.id, count = nodes.len(), "received nodes"); - if !self - .update_candidates_full(now_sec(), self.max_k, nodes, &mut result) - .await - { - // Do nothing if candidates were not changed - continue; - } + tracing::debug!(peer_id = %node.peer_info().id, count = nodes.len(), "received nodes"); + let known_peers = self.network.known_peers().clone(); + + // Update candidates. + process_only_valid(now_sec(), nodes, |peer_info| { + let discovered_depth = query_depth.saturating_add(1); + candidate_depths + .entry(peer_info.id) + .and_modify(|depth| *depth = (*depth).min(discovered_depth)) + .or_insert(discovered_depth); + + let peer_info = match result.entry(peer_info.id) { + // Insert a new entry + hash_map::Entry::Vacant(entry) => entry.insert(peer_info).clone(), + // Try to replace an old entry + hash_map::Entry::Occupied(mut entry) => { + if entry.get().created_at < peer_info.created_at { + *entry.get_mut() = peer_info; + } + entry.get().clone() + } + }; - current_depth += 1; - if current_depth >= max_depth { - // Stop on max depth - break; - } + self.candidates + .add(peer_info, self.max_k, &Duration::MAX, |peer_info| { + Self::retain_candidate(&known_peers, peer_info) + }); + }) + .await; // Add new nodes from the closest range self.candidates .visit_closest(self.local_id(), self.max_k, |node| { - if result.contains_key(&node.id) { - // Skip already visited nodes + let peer_id = node.peer_info().id; + let Some(&candidate_depth) = candidate_depths.get(&peer_id) else { return; + }; + + if candidate_depth <= max_depth && scheduled.insert(peer_id) { + futures.push(visit(&self, node, candidate_depth)); } - futures.push(Self::visit::( - self.network.clone(), - node.clone(), - request_body.clone(), - &semaphore, - self.timeout, - )); }); } // Do nothing on error Some(Err(e)) => { - tracing::warn!(peer_id = %node.id, "failed to query nodes: {e}"); + tracing::warn!(peer_id = %node.peer_info().id, "failed to query nodes: {e}"); } // Do nothing on timeout None => { - tracing::warn!(peer_id = %node.id, "failed to query nodes: timeout"); + tracing::warn!(peer_id = %node.peer_info().id, "failed to query nodes: timeout"); } } } @@ -324,70 +361,22 @@ impl Query { result } - async fn update_candidates( - &mut self, - now: u32, - max_k: usize, - nodes: Vec>, - visited: &mut FastHashSet, - ) -> bool { - let mut has_new = false; - process_only_valid(now, nodes, |node| { - // Insert a new entry - if visited.insert(node.id) { - self.candidates.add(node, max_k, &Duration::MAX, Some); - has_new = true; - } - }) - .await; - - has_new - } - - async fn update_candidates_full( - &mut self, - now: u32, - max_k: usize, - nodes: Vec>, - visited: &mut FastHashMap>, - ) -> bool { - let mut has_new = false; - process_only_valid(now, nodes, |node| { - match visited.entry(node.id) { - // Insert a new entry - hash_map::Entry::Vacant(entry) => { - let node = entry.insert(node).clone(); - self.candidates.add(node, max_k, &Duration::MAX, Some); - has_new = true; - } - // Try to replace an old entry - hash_map::Entry::Occupied(mut entry) => { - if entry.get().created_at < node.created_at { - *entry.get_mut() = node; - } - } - } - }) - .await; - - has_new - } - async fn visit( network: Network, - node: Arc, + handle: KnownPeerHandle, request_body: Bytes, semaphore: &Semaphore, timeout: Duration, - ) -> (Arc, Option>) + ) -> (KnownPeerHandle, Option>) where for<'a> T: tl_proto::TlRead<'a, Repr = tl_proto::Boxed>, { let Ok(_permit) = semaphore.acquire().await else { - return (node, None); + return (handle, None); }; - let req = network.query(&node.id, Request { + let peer_id = handle.peer_info().id; + let req = network.query(&peer_id, Request { version: Default::default(), body: request_body.clone(), }); @@ -399,7 +388,25 @@ impl Query { Err(_) => None, }; - (node, res) + (handle, res) + } + + fn retain_candidate( + known_peers: &KnownPeers, + peer_info: Arc, + ) -> Option { + if let Some(handle) = known_peers.make_handle(&peer_info.id, false) { + match handle.update_peer_info(&peer_info) { + Ok(()) | Err(KnownPeersError::OutdatedInfo) => return Some(handle), + Err(KnownPeersError::PeerBanned(_)) => return None, + } + } + + match known_peers.insert(peer_info.clone(), false) { + Ok(handle) => Some(handle), + Err(KnownPeersError::OutdatedInfo) => known_peers.make_handle(&peer_info.id, false), + Err(KnownPeersError::PeerBanned(_)) => None, + } } } @@ -516,3 +523,168 @@ where } const MAX_PARALLEL_REQUESTS: usize = 10; + +#[cfg(test)] +mod tests { + use std::net::Ipv4Addr; + + use tycho_crypto::ed25519; + + use super::*; + use crate::dht::{DhtClient, DhtService}; + use crate::proto::dht::PeerValueKeyName; + use crate::util::Router; + + struct Node { + network: Network, + dht: DhtClient, + } + + impl Node { + fn new() -> Self { + let key = rand::random::(); + let local_id = ed25519::PublicKey::from(&key).into(); + + let (_, dht_service) = DhtService::builder(local_id).build(); + let router = Router::builder().route(dht_service.clone()).build(); + let network = Network::builder() + .with_private_key(key.to_bytes()) + .build((Ipv4Addr::LOCALHOST, 0), router) + .unwrap(); + let dht = dht_service.make_client(&network); + + Self { network, dht } + } + + fn add_peer(&self, peer: &Self) { + self.dht.add_peer(peer.peer_info()).unwrap(); + } + + fn peer_info(&self) -> Arc { + Arc::new(self.network.sign_peer_info(now_sec(), 60)) + } + + fn peer_id(&self) -> PeerId { + *self.network.peer_id() + } + + fn ban_peer(&self, peer: &Self) { + self.network.known_peers().ban(&peer.peer_id()); + } + + fn store_local_peer_info(&self) { + let peer_info = self.peer_info(); + self.dht + .entry(PeerValueKeyName::NodeInfo) + .with_data(peer_info.as_ref()) + .store_locally() + .unwrap(); + } + + fn make_query(&self, target_id: &[u8; 32]) -> Query { + let routing_table = self.dht.inner.routing_table.lock().unwrap(); + Query::new( + self.network.clone(), + &routing_table, + target_id, + &DhtConfig::default(), + DhtQueryMode::Closest, + ) + } + } + + #[tokio::test] + async fn find_peers_with_depth_works() { + let [a, b, c, d, e] = std::array::from_fn(|_| Node::new()); + + a.add_peer(&b); + b.add_peer(&c); + c.add_peer(&d); + d.add_peer(&e); + + let target_id = rand::random(); + + let result = a.make_query(&target_id).find_peers(Some(0)).await; + assert!(!result.contains_key(a.network.peer_id())); + assert!(!result.contains_key(b.network.peer_id())); + assert!(result.contains_key(&c.peer_id())); + assert!(!result.contains_key(&d.peer_id())); + + let result = a.make_query(&target_id).find_peers(Some(1)).await; + assert!(result.contains_key(&c.peer_id())); + assert!(result.contains_key(&d.peer_id())); + assert!(!result.contains_key(&e.peer_id())); + + let result = a.make_query(&target_id).find_peers(Some(2)).await; + assert!(result.contains_key(&c.peer_id())); + assert!(result.contains_key(&d.peer_id())); + assert!(result.contains_key(&e.peer_id())); + } + + #[tokio::test] + async fn query_reuses_local_dht_storage() { + let [a, b, c, d] = std::array::from_fn(|_| Node::new()); + + a.add_peer(&b); + b.add_peer(&c); + c.add_peer(&d); + d.store_local_peer_info(); + + let peer_info = a + .dht + .entry(PeerValueKeyName::NodeInfo) + .find_value::(&d.peer_id()) + .await + .unwrap(); + + assert_eq!(peer_info.id, d.peer_id()); + } + + #[tokio::test] + async fn query_skips_banned_nodes() { + let [a, b, c, d] = std::array::from_fn(|_| Node::new()); + + a.add_peer(&b); + b.add_peer(&c); + c.add_peer(&d); + a.ban_peer(&c); + + let target_id = rand::random(); + let result = a.make_query(&target_id).find_peers(Some(2)).await; + + assert!(result.contains_key(&c.peer_id())); + assert!(!result.contains_key(&d.peer_id())); + } + + #[tokio::test] + async fn query_overrides_outdated_peer_info() { + let [a, b, c, d] = std::array::from_fn(|_| Node::new()); + + a.add_peer(&b); + b.add_peer(&c); + c.add_peer(&d); + + let newer_peer_info = c.peer_info(); + let older_peer_info = Arc::new( + c.network + .sign_peer_info(newer_peer_info.created_at.saturating_sub(10), 60), + ); + + let known_handle = a + .network + .known_peers() + .insert(newer_peer_info.clone(), false) + .unwrap(); + b.dht.add_peer(older_peer_info).unwrap(); + + let target_id = rand::random(); + let result = a.make_query(&target_id).find_peers(Some(1)).await; + + assert!(result.contains_key(&c.peer_id())); + assert!(result.contains_key(&d.peer_id())); + assert_eq!( + known_handle.load_peer_info().created_at, + newer_peer_info.created_at + ); + } +} diff --git a/network/src/dht/routing.rs b/network/src/dht/routing.rs index 46710fe1a1..74d1ea4b46 100644 --- a/network/src/dht/routing.rs +++ b/network/src/dht/routing.rs @@ -8,7 +8,6 @@ use crate::dht::{MAX_XOR_DISTANCE, xor_distance}; use crate::network::KnownPeerHandle; use crate::types::{PeerId, PeerInfo}; -pub(crate) type SimpleRoutingTable = RoutingTable>; pub(crate) type HandlesRoutingTable = RoutingTable; pub(crate) struct RoutingTable { @@ -315,6 +314,8 @@ mod tests { use super::*; use crate::util::make_peer_info_stub; + type SimpleRoutingTable = RoutingTable>; + const MAX_K: usize = 20; #[derive(Debug, Default)] diff --git a/network/src/overlay/background_tasks.rs b/network/src/overlay/background_tasks.rs index 297b931ad6..e76a69fa7d 100644 --- a/network/src/overlay/background_tasks.rs +++ b/network/src/overlay/background_tasks.rs @@ -60,15 +60,17 @@ impl OverlayServiceInner { let mut public_overlays_changed = Box::pin(public_overlays_notify.notified()); let mut public_overlays_state = None::; - let dht_peer_added = dht_service + let dht_peer_added_notify = dht_service .as_ref() .map(|s| s.peer_added()) .cloned() .unwrap_or_default(); + let mut dht_peer_added = Box::pin(dht_peer_added_notify.notified()); let empty_overlays = OverlayIdsQueue::default(); + let mut drain_empty_overlays = false; - loop { + 'outer: loop { let action = match &mut public_overlays_state { // Initial update for public overlays list None => Action::UpdatePublicOverlaysList(public_overlays_state.insert( @@ -80,7 +82,21 @@ impl OverlayServiceInner { }, )), // Default actions - Some(public_overlays_state) => { + Some(public_overlays_state) => 'action: { + if drain_empty_overlays && let Some(overlay_id) = empty_overlays.pop() { + tracing::debug!( + %overlay_id, + "force discover public overlay peers on new DHT peer", + ); + break 'action Action::DiscoverPublicOverlayEntries { + overlay_id, + tasks: &mut public_overlays_state.discover, + force: true, + }; + } else { + drain_empty_overlays = false; + } + tokio::select! { _ = &mut public_overlays_changed => { public_overlays_changed = Box::pin(public_overlays_notify.notified()); @@ -91,7 +107,7 @@ impl OverlayServiceInner { overlay_id: id, tasks: &mut public_overlays_state.exchange, }, - None => continue, + None => continue 'outer, }, overlay_id = public_overlays_state.discover.next() => match overlay_id { Some(id) => Action::DiscoverPublicOverlayEntries { @@ -99,36 +115,28 @@ impl OverlayServiceInner { tasks: &mut public_overlays_state.discover, force: false, }, - None => continue, + None => continue 'outer, }, overlay_id = public_overlays_state.collect.next() => match overlay_id { Some(id) => Action::CollectPublicEntries { overlay_id: id, tasks: &mut public_overlays_state.collect, }, - None => continue, + None => continue 'outer, }, overlay_id = public_overlays_state.store.next() => match overlay_id { Some(id) => Action::StorePublicEntries { overlay_id: id, tasks: &mut public_overlays_state.store, }, - None => continue, - }, - _ = dht_peer_added.notified(), if !empty_overlays.is_empty() => { - let Some(id) = empty_overlays.pop() else { - continue; - }; - tracing::debug!( - overlay_id = %id, - "force discover public overlay peers on new DHT peer", - ); - Action::DiscoverPublicOverlayEntries { - overlay_id: id, - tasks: &mut public_overlays_state.discover, - force: true, - } + None => continue 'outer, }, + _ = &mut dht_peer_added, if !empty_overlays.is_empty() => { + dht_peer_added = Box::pin(dht_peer_added_notify.notified()); + // Trigger `empty_overlays.pop()` on next retry. + drain_empty_overlays = true; + continue 'outer; + } } } };