diff --git a/kademlia/routing.py b/kademlia/routing.py index 4a7279f..2db2185 100644 --- a/kademlia/routing.py +++ b/kademlia/routing.py @@ -5,14 +5,14 @@ import asyncio from itertools import chain from collections import OrderedDict -from kademlia.utils import OrderedSet, shared_prefix, bytes_to_bit_string +from kademlia.utils import shared_prefix, bytes_to_bit_string class KBucket: def __init__(self, rangeLower, rangeUpper, ksize): self.range = (rangeLower, rangeUpper) self.nodes = OrderedDict() - self.replacement_nodes = OrderedSet() + self.replacement_nodes = OrderedDict() self.touch_last_updated() self.ksize = ksize @@ -26,21 +26,23 @@ class KBucket: midpoint = (self.range[0] + self.range[1]) / 2 one = KBucket(self.range[0], midpoint, self.ksize) two = KBucket(midpoint + 1, self.range[1], self.ksize) - for node in chain(self.nodes.values(), self.replacement_nodes): + nodes = chain(self.nodes.values(), self.replacement_nodes.values()) + for node in nodes: bucket = one if node.long_id <= midpoint else two bucket.add_node(node) return (one, two) def remove_node(self, node): - if node.id not in self.nodes: - return + if node.id in self.replacement_nodes: + del self.replacement_nodes[node.id] - # delete node, and see if we can add a replacement - del self.nodes[node.id] - if self.replacement_nodes: - newnode = self.replacement_nodes.pop() - self.nodes[newnode.id] = newnode + if node.id in self.nodes: + del self.nodes[node.id] + + if self.replacement_nodes: + newnode_id, newnode = self.replacement_nodes.popitem() + self.nodes[newnode_id] = newnode def has_in_range(self, node): return self.range[0] <= node.long_id <= self.range[1] @@ -62,7 +64,9 @@ class KBucket: elif len(self) < self.ksize: self.nodes[node.id] = node else: - self.replacement_nodes.push(node) + if node.id in self.replacement_nodes: + del self.replacement_nodes[node.id] + self.replacement_nodes[node.id] = node return False return True diff --git a/kademlia/tests/test_routing.py b/kademlia/tests/test_routing.py index 0523127..4cf84eb 100644 --- a/kademlia/tests/test_routing.py +++ b/kademlia/tests/test_routing.py @@ -1,5 +1,6 @@ import unittest +from random import shuffle from kademlia.routing import KBucket, TableTraverser from kademlia.tests.utils import mknode, FakeProtocol @@ -31,6 +32,31 @@ class KBucketTest(unittest.TestCase): for index, node in enumerate(bucket.get_nodes()): self.assertEqual(node, nodes[index]) + def test_remove_node(self): + k = 3 + bucket = KBucket(0, 10, k) + nodes = [mknode() for _ in range(10)] + for node in nodes: + bucket.add_node(node) + + replacement_nodes = bucket.replacement_nodes + self.assertEqual(list(bucket.nodes.values()), nodes[:k]) + self.assertEqual(list(replacement_nodes.values()), nodes[k:]) + + bucket.remove_node(nodes.pop()) + self.assertEqual(list(bucket.nodes.values()), nodes[:k]) + self.assertEqual(list(replacement_nodes.values()), nodes[k:]) + + bucket.remove_node(nodes.pop(0)) + self.assertEqual(list(bucket.nodes.values()), nodes[:k-1] + nodes[-1:]) + self.assertEqual(list(replacement_nodes.values()), nodes[k-1:-1]) + + shuffle(nodes) + for node in nodes: + bucket.remove_node(node) + self.assertEqual(len(bucket), 0) + self.assertEqual(len(replacement_nodes), 0) + def test_in_range(self): bucket = KBucket(0, 10, 10) self.assertTrue(bucket.has_in_range(mknode(intid=5))) diff --git a/kademlia/tests/test_utils.py b/kademlia/tests/test_utils.py index 6caa25c..9fcdd86 100644 --- a/kademlia/tests/test_utils.py +++ b/kademlia/tests/test_utils.py @@ -1,7 +1,7 @@ import hashlib import unittest -from kademlia.utils import digest, shared_prefix, OrderedSet +from kademlia.utils import digest, shared_prefix class UtilsTest(unittest.TestCase): @@ -24,13 +24,3 @@ class UtilsTest(unittest.TestCase): args = ['hi'] self.assertEqual(shared_prefix(args), 'hi') - - -class OrderedSetTest(unittest.TestCase): - def test_order(self): - oset = OrderedSet() - oset.push('1') - oset.push('1') - oset.push('2') - oset.push('1') - self.assertEqual(oset, ['2', '1']) diff --git a/kademlia/utils.py b/kademlia/utils.py index da90085..319c53e 100644 --- a/kademlia/utils.py +++ b/kademlia/utils.py @@ -18,22 +18,6 @@ def digest(string): return hashlib.sha1(string).digest() -class OrderedSet(list): - """ - Acts like a list in all ways, except in the behavior of the - :meth:`push` method. - """ - - def push(self, thing): - """ - 1. If the item exists in the list, it's removed - 2. The item is pushed to the end of the list - """ - if thing in self: - self.remove(thing) - self.append(thing) - - def shared_prefix(args): """ Find the shared prefix between the strings.