"""Source code for kademlia.routing: k-buckets, bucket traversal, and the routing table."""

import heapq
import time
import operator
import asyncio

from itertools import chain
from collections import OrderedDict
from kademlia.utils import shared_prefix, bytes_to_bit_string


class KBucket:
    """A k-bucket: the contacts whose ``long_id`` lies in an inclusive range.

    ``nodes`` holds at most ``ksize`` active contacts, ordered from least to
    most recently added/refreshed. ``replacement_nodes`` holds overflow
    contacts (at most ``ksize * replacementNodeFactor``). They are used to
    refill the bucket when an active contact is removed (section 4.1 of the
    Kademlia paper).
    """

    def __init__(self, rangeLower, rangeUpper, ksize, replacementNodeFactor=5):
        """
        @param rangeLower: Smallest ``long_id`` covered (inclusive).
        @param rangeUpper: Largest ``long_id`` covered (inclusive).
        @param ksize: Maximum number of active contacts.
        @param replacementNodeFactor: Multiplier of ``ksize`` that bounds the
            replacement cache.
        """
        self.range = (rangeLower, rangeUpper)
        self.nodes = OrderedDict()
        self.replacement_nodes = OrderedDict()
        self.touch_last_updated()
        self.ksize = ksize
        # Remembered so split() gives both halves the same cache policy.
        self.replacement_node_factor = replacementNodeFactor
        self.max_replacement_nodes = self.ksize * replacementNodeFactor

    def touch_last_updated(self):
        """Record the current monotonic time as the bucket's last update."""
        self.last_updated = time.monotonic()

    def get_nodes(self):
        """Return the active contacts as a new list, oldest first."""
        return list(self.nodes.values())

    def split(self):
        """Split the bucket's range in two and redistribute its contacts.

        Active and replacement contacts are re-added through add_node, so a
        former replacement may become active in one of the halves. Both
        halves inherit this bucket's ``ksize`` and replacement factor.

        @return: Tuple ``(lower_half, upper_half)``.
        """
        midpoint = (self.range[0] + self.range[1]) // 2
        one = KBucket(self.range[0], midpoint, self.ksize,
                      self.replacement_node_factor)
        two = KBucket(midpoint + 1, self.range[1], self.ksize,
                      self.replacement_node_factor)
        for node in chain(self.nodes.values(), self.replacement_nodes.values()):
            # Inclusive ranges: the midpoint itself belongs to the lower half.
            bucket = one if node.long_id <= midpoint else two
            bucket.add_node(node)
        return (one, two)

    def remove_node(self, node):
        """Forget ``node``.

        If ``node`` was an active contact, the most recently added replacement
        (if any) is promoted into the active set.
        """
        if node.id in self.replacement_nodes:
            del self.replacement_nodes[node.id]

        if node.id in self.nodes:
            del self.nodes[node.id]

            if self.replacement_nodes:
                newnode_id, newnode = self.replacement_nodes.popitem()
                self.nodes[newnode_id] = newnode

    def has_in_range(self, node):
        """Return True if ``node.long_id`` lies within this bucket's range."""
        return self.range[0] <= node.long_id <= self.range[1]

    def is_new_node(self, node):
        """Return True if ``node`` is not an active contact of this bucket."""
        return node.id not in self.nodes

    def add_node(self, node):
        """
        Add a C{Node} to the C{KBucket}.  Return True if successful,
        False if the bucket is full.

        If the bucket is full, keep track of node in a replacement list,
        per section 4.1 of the paper.
        """
        if node.id in self.nodes:
            # Re-insert so the contact moves to the most-recent end.
            del self.nodes[node.id]
            self.nodes[node.id] = node
        elif len(self) < self.ksize:
            self.nodes[node.id] = node
        else:
            if node.id in self.replacement_nodes:
                del self.replacement_nodes[node.id]
            self.replacement_nodes[node.id] = node
            # Evict the oldest replacements once the cache is over capacity.
            while len(self.replacement_nodes) > self.max_replacement_nodes:
                self.replacement_nodes.popitem(last=False)
            return False
        return True

    def depth(self):
        """Return the length of the bit-string prefix shared by all active ids.

        NOTE(review): RoutingTable.add_contact calls this only on a full
        bucket. Behavior on an empty bucket depends on ``shared_prefix`` and
        is unverified.
        """
        vals = self.nodes.values()
        sprefix = shared_prefix([bytes_to_bit_string(n.id) for n in vals])
        return len(sprefix)

    def head(self):
        """Return the least recently added/refreshed active contact.

        @raise IndexError: If the bucket has no active contacts.
        """
        try:
            return next(iter(self.nodes.values()))
        except StopIteration:
            raise IndexError("head() of an empty KBucket") from None

    def __getitem__(self, node_id):
        """Return the active contact with id ``node_id``, or None."""
        return self.nodes.get(node_id, None)

    def __len__(self):
        """Return the number of active contacts."""
        return len(self.nodes)
class TableTraverser:
    """Iterate over a routing table's contacts, starting near a given node.

    First yields the contacts of the bucket ``startNode`` falls into. Then it
    alternates between the nearest unvisited bucket on the left (lower index)
    and on the right (higher index). When one side runs out, the remaining
    buckets on the other side are still visited, so every contact in the
    table is eventually yielded.
    """

    def __init__(self, table, startNode):
        index = table.get_bucket_for(startNode)
        table.buckets[index].touch_last_updated()
        self.current_nodes = table.buckets[index].get_nodes()
        self.left_buckets = table.buckets[:index]
        self.right_buckets = table.buckets[(index + 1):]
        # Side to draw from next when both sides still have buckets.
        self.left = True

    def __iter__(self):
        return self

    def __next__(self):
        """
        Pop an item from the left subtree, then right, then left, etc.

        Within a bucket, contacts are popped from the end of get_nodes().
        """
        while not self.current_nodes:
            # Take the left side on its turn, or whenever the right is empty.
            take_left = bool(self.left_buckets) and (
                self.left or not self.right_buckets)
            if take_left:
                self.current_nodes = self.left_buckets.pop().get_nodes()
                self.left = False
            elif self.right_buckets:
                self.current_nodes = self.right_buckets.pop(0).get_nodes()
                self.left = True
            else:
                raise StopIteration
        return self.current_nodes.pop()
class RoutingTable:
    """Routing table: an ordered list of KBuckets with contiguous ranges."""

    def __init__(self, protocol, ksize, node):
        """
        @param protocol: Object whose ``call_ping`` is used to probe the head
            of a full bucket.
        @param ksize: Maximum number of active contacts per bucket.
        @param node: The node that represents this server.  It won't
            be added to the routing table, but will be needed later to
            determine which buckets to split or not.
        """
        self.node = node
        self.protocol = protocol
        self.ksize = ksize
        # Strong references to in-flight ping tasks, so they are not
        # garbage-collected before they finish (see asyncio.ensure_future).
        self._ping_tasks = set()
        self.flush()

    def flush(self):
        """Reset the table to a single empty bucket covering the id space."""
        self.buckets = [KBucket(0, 2 ** 160, self.ksize)]

    def split_bucket(self, index):
        """Replace bucket ``index`` with its two halves, in order."""
        one, two = self.buckets[index].split()
        self.buckets[index] = one
        self.buckets.insert(index + 1, two)

    def lonely_buckets(self):
        """
        Get all of the buckets that haven't been updated in over an hour.
        """
        hrago = time.monotonic() - 3600
        return [b for b in self.buckets if b.last_updated < hrago]

    def remove_contact(self, node):
        """Remove ``node`` from the bucket its id falls into."""
        index = self.get_bucket_for(node)
        self.buckets[index].remove_node(node)

    def is_new_node(self, node):
        """Return True if ``node`` is not an active contact of its bucket."""
        index = self.get_bucket_for(node)
        return self.buckets[index].is_new_node(node)

    def add_contact(self, node):
        """Add ``node`` to its bucket, splitting or pinging when it is full.

        If the bucket is full, it is split and the add retried when the
        bucket's range contains this server's node, or its depth is not a
        multiple of 5. Otherwise ``node`` stays in the replacement cache
        (KBucket.add_node already put it there) and the bucket's head is
        pinged in the background. Presumably the protocol evicts the head if
        the ping fails — confirm in the protocol.
        """
        index = self.get_bucket_for(node)
        bucket = self.buckets[index]

        # this will succeed unless the bucket is full
        if bucket.add_node(node):
            return

        # Per section 4.2 of paper, split if the bucket has the node
        # in its range or if the depth is not congruent to 0 mod 5
        if bucket.has_in_range(self.node) or bucket.depth() % 5 != 0:
            self.split_bucket(index)
            self.add_contact(node)
        else:
            task = asyncio.ensure_future(self.protocol.call_ping(bucket.head()))
            self._ping_tasks.add(task)
            task.add_done_callback(self._ping_tasks.discard)

    def get_bucket_for(self, node):
        """
        Get the index of the bucket that the given node would fall into.

        Bucket ranges are inclusive, so the first bucket whose upper bound is
        >= ``node.long_id`` is the one that contains it.

        @return: The bucket index, or None if ``node.long_id`` is beyond the
            last bucket's range.
        """
        for index, bucket in enumerate(self.buckets):
            if node.long_id <= bucket.range[1]:
                return index
        # we should never be here, but make linter happy
        return None

    def find_neighbors(self, node, k=None, exclude=None):
        """Return up to ``k`` known contacts closest to ``node``.

        Contacts are ranked by ``node.distance_to``. ``node`` itself and any
        contact that ``same_home_as(exclude)`` are skipped. Traversal starts
        at ``node``'s bucket and fans outward. It stops once ``k`` candidates
        are collected, so the result is the best ``k`` among those visited,
        not necessarily the global best.

        @param k: Result size; a falsy value means ``self.ksize``.
        @param exclude: Optional node whose address should be excluded.
        """
        k = k or self.ksize
        nodes = []
        for neighbor in TableTraverser(self, node):
            notexcluded = exclude is None or not neighbor.same_home_as(exclude)
            if neighbor.id != node.id and notexcluded:
                heapq.heappush(nodes, (node.distance_to(neighbor), neighbor))
            if len(nodes) == k:
                break

        return list(map(operator.itemgetter(1), heapq.nsmallest(k, nodes)))