removed all camelcase silliness, bumped version to 2.0

This commit is contained in:
Brian Muller 2019-01-09 11:27:10 -05:00
parent dbe41e3b08
commit cf9d490d64
16 changed files with 306 additions and 279 deletions

View File

@ -497,7 +497,7 @@ valid-metaclass-classmethod-first-arg=mcs
[DESIGN]
# Maximum number of arguments for function / method
max-args=5
max-args=6
# Maximum number of attributes for a class (see R0902).
max-attributes=7

View File

@ -1,6 +1,6 @@
language: python
python:
- "3.5"
- "3.6"
- "3.7"
install: pip install . && pip install -r dev-requirements.txt
script: python -m unittest
script: py.test

View File

@ -1,5 +1,6 @@
pycodestyle==2.3.1
pylint==1.8.1
pycodestyle>=2.4.0
pylint>=2.2.2
sphinx>=1.6.5
sphinxcontrib-napoleon>=0.6.1
sphinxcontrib-zopeext>=0.2.1
py.test>=4.1.0

View File

@ -2,4 +2,4 @@
Kademlia is a Python implementation of the Kademlia protocol which
utilizes the asyncio library.
"""
__version__ = "1.1"
__version__ = "2.0"

View File

@ -4,10 +4,12 @@ import logging
from kademlia.node import Node, NodeHeap
from kademlia.utils import gather_dict
log = logging.getLogger(__name__)
log = logging.getLogger(__name__) # pylint: disable=invalid-name
class SpiderCrawl(object):
# pylint: disable=too-few-public-methods
class SpiderCrawl:
"""
Crawl the network and look for given 160-bit keys.
"""
@ -29,7 +31,7 @@ class SpiderCrawl(object):
self.alpha = alpha
self.node = node
self.nearest = NodeHeap(self.node, self.ksize)
self.lastIDsCrawled = []
self.last_ids_crawled = []
log.info("creating spider with peers: %s", peers)
self.nearest.push(peers)
@ -38,7 +40,7 @@ class SpiderCrawl(object):
Get either a value or list of nodes.
Args:
rpcmethod: The protocol's callfindValue or callFindNode.
rpcmethod: The protocol's callfindValue or call_find_node.
The process:
1. calls find_* to current ALPHA nearest not already queried nodes,
@ -51,18 +53,18 @@ class SpiderCrawl(object):
"""
log.info("crawling network with nearest: %s", str(tuple(self.nearest)))
count = self.alpha
if self.nearest.getIDs() == self.lastIDsCrawled:
if self.nearest.get_ids() == self.last_ids_crawled:
count = len(self.nearest)
self.lastIDsCrawled = self.nearest.getIDs()
self.last_ids_crawled = self.nearest.get_ids()
ds = {}
for peer in self.nearest.getUncontacted()[:count]:
ds[peer.id] = rpcmethod(peer, self.node)
self.nearest.markContacted(peer)
found = await gather_dict(ds)
return await self._nodesFound(found)
dicts = {}
for peer in self.nearest.get_uncontacted()[:count]:
dicts[peer.id] = rpcmethod(peer, self.node)
self.nearest.mark_contacted(peer)
found = await gather_dict(dicts)
return await self._nodes_found(found)
async def _nodesFound(self, responses):
async def _nodes_found(self, responses):
raise NotImplementedError
@ -71,55 +73,55 @@ class ValueSpiderCrawl(SpiderCrawl):
SpiderCrawl.__init__(self, protocol, node, peers, ksize, alpha)
# keep track of the single nearest node without value - per
# section 2.3 so we can set the key there if found
self.nearestWithoutValue = NodeHeap(self.node, 1)
self.nearest_without_value = NodeHeap(self.node, 1)
async def find(self):
"""
Find either the closest nodes or the value requested.
"""
return await self._find(self.protocol.callFindValue)
return await self._find(self.protocol.call_find_value)
async def _nodesFound(self, responses):
async def _nodes_found(self, responses):
"""
Handle the result of an iteration in _find.
"""
toremove = []
foundValues = []
found_values = []
for peerid, response in responses.items():
response = RPCFindResponse(response)
if not response.happened():
toremove.append(peerid)
elif response.hasValue():
foundValues.append(response.getValue())
elif response.has_value():
found_values.append(response.get_value())
else:
peer = self.nearest.getNodeById(peerid)
self.nearestWithoutValue.push(peer)
self.nearest.push(response.getNodeList())
peer = self.nearest.get_node(peerid)
self.nearest_without_value.push(peer)
self.nearest.push(response.get_node_list())
self.nearest.remove(toremove)
if len(foundValues) > 0:
return await self._handleFoundValues(foundValues)
if self.nearest.allBeenContacted():
if found_values:
return await self._handle_found_values(found_values)
if self.nearest.have_contacted_all():
# not found!
return None
return await self.find()
async def _handleFoundValues(self, values):
async def _handle_found_values(self, values):
"""
We got some values! Exciting. But let's make sure
they're all the same or freak out a little bit. Also,
make sure we tell the nearest node that *didn't* have
the value to store it.
"""
valueCounts = Counter(values)
if len(valueCounts) != 1:
value_counts = Counter(values)
if len(value_counts) != 1:
log.warning("Got multiple values for key %i: %s",
self.node.long_id, str(values))
value = valueCounts.most_common(1)[0][0]
value = value_counts.most_common(1)[0][0]
peerToSaveTo = self.nearestWithoutValue.popleft()
if peerToSaveTo is not None:
await self.protocol.callStore(peerToSaveTo, self.node.id, value)
peer = self.nearest_without_value.popleft()
if peer:
await self.protocol.call_store(peer, self.node.id, value)
return value
@ -128,9 +130,9 @@ class NodeSpiderCrawl(SpiderCrawl):
"""
Find the closest nodes.
"""
return await self._find(self.protocol.callFindNode)
return await self._find(self.protocol.call_find_node)
async def _nodesFound(self, responses):
async def _nodes_found(self, responses):
"""
Handle the result of an iteration in _find.
"""
@ -140,15 +142,15 @@ class NodeSpiderCrawl(SpiderCrawl):
if not response.happened():
toremove.append(peerid)
else:
self.nearest.push(response.getNodeList())
self.nearest.push(response.get_node_list())
self.nearest.remove(toremove)
if self.nearest.allBeenContacted():
if self.nearest.have_contacted_all():
return list(self.nearest)
return await self.find()
class RPCFindResponse(object):
class RPCFindResponse:
def __init__(self, response):
"""
A wrapper for the result of a RPC find.
@ -166,13 +168,13 @@ class RPCFindResponse(object):
"""
return self.response[0]
def hasValue(self):
def has_value(self):
return isinstance(self.response[1], dict)
def getValue(self):
def get_value(self):
return self.response[1]['value']
def getNodeList(self):
def get_node_list(self):
"""
Get the node list in the response. If there's no value, this should
be set.

View File

@ -13,10 +13,11 @@ from kademlia.node import Node
from kademlia.crawling import ValueSpiderCrawl
from kademlia.crawling import NodeSpiderCrawl
log = logging.getLogger(__name__)
log = logging.getLogger(__name__) # pylint: disable=invalid-name
class Server(object):
# pylint: disable=too-many-instance-attributes
class Server:
"""
High level view of a node instance. This is the object that should be
created to start listening as an active node on the network.
@ -83,22 +84,22 @@ class Server(object):
Refresh buckets that haven't had any lookups in the last hour
(per section 2.3 of the paper).
"""
ds = []
for node_id in self.protocol.getRefreshIDs():
results = []
for node_id in self.protocol.get_refresh_ids():
node = Node(node_id)
nearest = self.protocol.router.findNeighbors(node, self.alpha)
nearest = self.protocol.router.find_neighbors(node, self.alpha)
spider = NodeSpiderCrawl(self.protocol, node, nearest,
self.ksize, self.alpha)
ds.append(spider.find())
results.append(spider.find())
# do our crawling
await asyncio.gather(*ds)
await asyncio.gather(*results)
# now republish keys older than one hour
for dkey, value in self.storage.iteritemsOlderThan(3600):
for dkey, value in self.storage.iter_older_than(3600):
await self.set_digest(dkey, value)
def bootstrappableNeighbors(self):
def bootstrappable_neighbors(self):
"""
Get a :class:`list` of (ip, port) :class:`tuple` pairs suitable for
use as an argument to the bootstrap method.
@ -108,7 +109,7 @@ class Server(object):
storing them if this server is going down for a while. When it comes
back up, the list of nodes can be used to bootstrap.
"""
neighbors = self.protocol.router.findNeighbors(self.node)
neighbors = self.protocol.router.find_neighbors(self.node)
return [tuple(n)[-2:] for n in neighbors]
async def bootstrap(self, addrs):
@ -145,8 +146,8 @@ class Server(object):
if self.storage.get(dkey) is not None:
return self.storage.get(dkey)
node = Node(dkey)
nearest = self.protocol.router.findNeighbors(node)
if len(nearest) == 0:
nearest = self.protocol.router.find_neighbors(node)
if not nearest:
log.warning("There are no known neighbors to get key %s", key)
return None
spider = ValueSpiderCrawl(self.protocol, node, nearest,
@ -172,8 +173,8 @@ class Server(object):
"""
node = Node(dkey)
nearest = self.protocol.router.findNeighbors(node)
if len(nearest) == 0:
nearest = self.protocol.router.find_neighbors(node)
if not nearest:
log.warning("There are no known neighbors to set key %s",
dkey.hex())
return False
@ -184,14 +185,14 @@ class Server(object):
log.info("setting '%s' on %s", dkey.hex(), list(map(str, nodes)))
# if this node is close too, then store here as well
biggest = max([n.distanceTo(node) for n in nodes])
if self.node.distanceTo(node) < biggest:
biggest = max([n.distance_to(node) for n in nodes])
if self.node.distance_to(node) < biggest:
self.storage[dkey] = value
ds = [self.protocol.callStore(n, dkey, value) for n in nodes]
results = [self.protocol.call_store(n, dkey, value) for n in nodes]
# return true only if at least one store call succeeded
return any(await asyncio.gather(*ds))
return any(await asyncio.gather(*results))
def saveState(self, fname):
def save_state(self, fname):
"""
Save the state of this node (the alpha/ksize/id/immediate neighbors)
to a cache file with the given fname.
@ -201,29 +202,29 @@ class Server(object):
'ksize': self.ksize,
'alpha': self.alpha,
'id': self.node.id,
'neighbors': self.bootstrappableNeighbors()
'neighbors': self.bootstrappable_neighbors()
}
if len(data['neighbors']) == 0:
if not data['neighbors']:
log.warning("No known neighbors, so not writing to cache.")
return
with open(fname, 'wb') as f:
pickle.dump(data, f)
with open(fname, 'wb') as file:
pickle.dump(data, file)
@classmethod
def loadState(self, fname):
def load_state(cls, fname):
"""
Load the state of this node (the alpha/ksize/id/immediate neighbors)
from a cache file with the given fname.
"""
log.info("Loading state from %s", fname)
with open(fname, 'rb') as f:
data = pickle.load(f)
s = Server(data['ksize'], data['alpha'], data['id'])
if len(data['neighbors']) > 0:
s.bootstrap(data['neighbors'])
return s
with open(fname, 'rb') as file:
data = pickle.load(file)
svr = Server(data['ksize'], data['alpha'], data['id'])
if data['neighbors']:
svr.bootstrap(data['neighbors'])
return svr
def saveStateRegularly(self, fname, frequency=600):
def save_state_regularly(self, fname, frequency=600):
"""
Save the state of node with a given regularity to the given
filename.
@ -233,10 +234,10 @@ class Server(object):
frequency: Frequency in seconds that the state should be saved.
By default, 10 minutes.
"""
self.saveState(fname)
self.save_state(fname)
loop = asyncio.get_event_loop()
self.save_state_loop = loop.call_later(frequency,
self.saveStateRegularly,
self.save_state_regularly,
fname,
frequency)
@ -246,13 +247,11 @@ def check_dht_value_type(value):
Checks to see if the type of the value is a valid type for
placing in the dht.
"""
typeset = set(
[
int,
float,
bool,
str,
bytes,
]
)
return type(value) in typeset
typeset = [
int,
float,
bool,
str,
bytes
]
return type(value) in typeset # pylint: disable=unidiomatic-typecheck

View File

@ -4,15 +4,15 @@ import heapq
class Node:
def __init__(self, node_id, ip=None, port=None):
self.id = node_id
self.ip = ip
self.id = node_id # pylint: disable=invalid-name
self.ip = ip # pylint: disable=invalid-name
self.port = port
self.long_id = int(node_id.hex(), 16)
def sameHomeAs(self, node):
def same_home_as(self, node):
return self.ip == node.ip and self.port == node.port
def distanceTo(self, node):
def distance_to(self, node):
"""
Get the distance between this node and another.
"""
@ -31,7 +31,7 @@ class Node:
return "%s:%s" % (self.ip, str(self.port))
class NodeHeap(object):
class NodeHeap:
"""
A heap of nodes ordered by distance to a given node.
"""
@ -47,7 +47,7 @@ class NodeHeap(object):
self.contacted = set()
self.maxsize = maxsize
def remove(self, peerIDs):
def remove(self, peers):
"""
Remove a list of peer ids from this heap. Note that while this
heap retains a constant visible size (based on the iterator), it's
@ -55,34 +55,32 @@ class NodeHeap(object):
removal of nodes may not change the visible size as previously added
nodes suddenly become visible.
"""
peerIDs = set(peerIDs)
if len(peerIDs) == 0:
peers = set(peers)
if not peers:
return
nheap = []
for distance, node in self.heap:
if node.id not in peerIDs:
if node.id not in peers:
heapq.heappush(nheap, (distance, node))
self.heap = nheap
def getNodeById(self, node_id):
def get_node(self, node_id):
for _, node in self.heap:
if node.id == node_id:
return node
return None
def allBeenContacted(self):
return len(self.getUncontacted()) == 0
def have_contacted_all(self):
return len(self.get_uncontacted()) == 0
def getIDs(self):
def get_ids(self):
return [n.id for n in self]
def markContacted(self, node):
def mark_contacted(self, node):
self.contacted.add(node.id)
def popleft(self):
if len(self) > 0:
return heapq.heappop(self.heap)[1]
return None
return heapq.heappop(self.heap)[1] if self else None
def push(self, nodes):
"""
@ -95,7 +93,7 @@ class NodeHeap(object):
for node in nodes:
if node not in self:
distance = self.node.distanceTo(node)
distance = self.node.distance_to(node)
heapq.heappush(self.heap, (distance, node))
def __len__(self):
@ -106,10 +104,10 @@ class NodeHeap(object):
return iter(map(itemgetter(1), nodes))
def __contains__(self, node):
for _, n in self.heap:
if node.id == n.id:
for _, other in self.heap:
if node.id == other.id:
return True
return False
def getUncontacted(self):
def get_uncontacted(self):
return [n for n in self if n.id not in self.contacted]

View File

@ -8,37 +8,37 @@ from kademlia.node import Node
from kademlia.routing import RoutingTable
from kademlia.utils import digest
log = logging.getLogger(__name__)
log = logging.getLogger(__name__) # pylint: disable=invalid-name
class KademliaProtocol(RPCProtocol):
def __init__(self, sourceNode, storage, ksize):
def __init__(self, source_node, storage, ksize):
RPCProtocol.__init__(self)
self.router = RoutingTable(self, ksize, sourceNode)
self.router = RoutingTable(self, ksize, source_node)
self.storage = storage
self.sourceNode = sourceNode
self.source_node = source_node
def getRefreshIDs(self):
def get_refresh_ids(self):
"""
Get ids to search for to keep old buckets up to date.
"""
ids = []
for bucket in self.router.getLonelyBuckets():
for bucket in self.router.lonely_buckets():
rid = random.randint(*bucket.range).to_bytes(20, byteorder='big')
ids.append(rid)
return ids
def rpc_stun(self, sender):
def rpc_stun(self, sender): # pylint: disable=no-self-use
return sender
def rpc_ping(self, sender, nodeid):
source = Node(nodeid, sender[0], sender[1])
self.welcomeIfNewNode(source)
return self.sourceNode.id
self.welcome_if_new(source)
return self.source_node.id
def rpc_store(self, sender, nodeid, key, value):
source = Node(nodeid, sender[0], sender[1])
self.welcomeIfNewNode(source)
self.welcome_if_new(source)
log.debug("got a store request from %s, storing '%s'='%s'",
sender, key.hex(), value)
self.storage[key] = value
@ -48,42 +48,42 @@ class KademliaProtocol(RPCProtocol):
log.info("finding neighbors of %i in local table",
int(nodeid.hex(), 16))
source = Node(nodeid, sender[0], sender[1])
self.welcomeIfNewNode(source)
self.welcome_if_new(source)
node = Node(key)
neighbors = self.router.findNeighbors(node, exclude=source)
neighbors = self.router.find_neighbors(node, exclude=source)
return list(map(tuple, neighbors))
def rpc_find_value(self, sender, nodeid, key):
source = Node(nodeid, sender[0], sender[1])
self.welcomeIfNewNode(source)
self.welcome_if_new(source)
value = self.storage.get(key, None)
if value is None:
return self.rpc_find_node(sender, nodeid, key)
return {'value': value}
async def callFindNode(self, nodeToAsk, nodeToFind):
address = (nodeToAsk.ip, nodeToAsk.port)
result = await self.find_node(address, self.sourceNode.id,
nodeToFind.id)
return self.handleCallResponse(result, nodeToAsk)
async def call_find_node(self, node_to_ask, node_to_find):
address = (node_to_ask.ip, node_to_ask.port)
result = await self.find_node(address, self.source_node.id,
node_to_find.id)
return self.handle_call_response(result, node_to_ask)
async def callFindValue(self, nodeToAsk, nodeToFind):
address = (nodeToAsk.ip, nodeToAsk.port)
result = await self.find_value(address, self.sourceNode.id,
nodeToFind.id)
return self.handleCallResponse(result, nodeToAsk)
async def call_find_value(self, node_to_ask, node_to_find):
address = (node_to_ask.ip, node_to_ask.port)
result = await self.find_value(address, self.source_node.id,
node_to_find.id)
return self.handle_call_response(result, node_to_ask)
async def callPing(self, nodeToAsk):
address = (nodeToAsk.ip, nodeToAsk.port)
result = await self.ping(address, self.sourceNode.id)
return self.handleCallResponse(result, nodeToAsk)
async def call_ping(self, node_to_ask):
address = (node_to_ask.ip, node_to_ask.port)
result = await self.ping(address, self.source_node.id)
return self.handle_call_response(result, node_to_ask)
async def callStore(self, nodeToAsk, key, value):
address = (nodeToAsk.ip, nodeToAsk.port)
result = await self.store(address, self.sourceNode.id, key, value)
return self.handleCallResponse(result, nodeToAsk)
async def call_store(self, node_to_ask, key, value):
address = (node_to_ask.ip, node_to_ask.port)
result = await self.store(address, self.source_node.id, key, value)
return self.handle_call_response(result, node_to_ask)
def welcomeIfNewNode(self, node):
def welcome_if_new(self, node):
"""
Given a new node, send it all the keys/values it should be storing,
then add it to the routing table.
@ -97,32 +97,32 @@ class KademliaProtocol(RPCProtocol):
is closer than the closest in that list, then store the key/value
on the new node (per section 2.5 of the paper)
"""
if not self.router.isNewNode(node):
if not self.router.is_new_node(node):
return
log.info("never seen %s before, adding to router", node)
for key, value in self.storage.items():
for key, value in self.storage:
keynode = Node(digest(key))
neighbors = self.router.findNeighbors(keynode)
if len(neighbors) > 0:
last = neighbors[-1].distanceTo(keynode)
newNodeClose = node.distanceTo(keynode) < last
first = neighbors[0].distanceTo(keynode)
thisNodeClosest = self.sourceNode.distanceTo(keynode) < first
if len(neighbors) == 0 or (newNodeClose and thisNodeClosest):
asyncio.ensure_future(self.callStore(node, key, value))
self.router.addContact(node)
neighbors = self.router.find_neighbors(keynode)
if neighbors:
last = neighbors[-1].distance_to(keynode)
new_node_close = node.distance_to(keynode) < last
first = neighbors[0].distance_to(keynode)
this_closest = self.source_node.distance_to(keynode) < first
if not neighbors or (new_node_close and this_closest):
asyncio.ensure_future(self.call_store(node, key, value))
self.router.add_contact(node)
def handleCallResponse(self, result, node):
def handle_call_response(self, result, node):
"""
If we get a response, add the node to the routing table. If
we get no response, make sure it's removed from the routing table.
"""
if not result[0]:
log.warning("no response from %s, removing from router", node)
self.router.removeContact(node)
self.router.remove_contact(node)
return result
log.info("got successful response from %s", node)
self.welcomeIfNewNode(node)
self.welcome_if_new(node)
return result

View File

@ -4,21 +4,21 @@ import operator
import asyncio
from collections import OrderedDict
from kademlia.utils import OrderedSet, sharedPrefix, bytesToBitString
from kademlia.utils import OrderedSet, shared_prefix, bytes_to_bit_string
class KBucket(object):
class KBucket:
def __init__(self, rangeLower, rangeUpper, ksize):
self.range = (rangeLower, rangeUpper)
self.nodes = OrderedDict()
self.replacementNodes = OrderedSet()
self.touchLastUpdated()
self.replacement_nodes = OrderedSet()
self.touch_last_updated()
self.ksize = ksize
def touchLastUpdated(self):
self.lastUpdated = time.monotonic()
def touch_last_updated(self):
self.last_updated = time.monotonic()
def getNodes(self):
def get_nodes(self):
return list(self.nodes.values())
def split(self):
@ -30,23 +30,23 @@ class KBucket(object):
bucket.nodes[node.id] = node
return (one, two)
def removeNode(self, node):
def remove_node(self, node):
if node.id not in self.nodes:
return
# delete node, and see if we can add a replacement
del self.nodes[node.id]
if len(self.replacementNodes) > 0:
newnode = self.replacementNodes.pop()
if self.replacement_nodes:
newnode = self.replacement_nodes.pop()
self.nodes[newnode.id] = newnode
def hasInRange(self, node):
def has_in_range(self, node):
return self.range[0] <= node.long_id <= self.range[1]
def isNewNode(self, node):
def is_new_node(self, node):
return node.id not in self.nodes
def addNode(self, node):
def add_node(self, node):
"""
Add a C{Node} to the C{KBucket}. Return True if successful,
False if the bucket is full.
@ -60,14 +60,14 @@ class KBucket(object):
elif len(self) < self.ksize:
self.nodes[node.id] = node
else:
self.replacementNodes.push(node)
self.replacement_nodes.push(node)
return False
return True
def depth(self):
vals = self.nodes.values()
sp = sharedPrefix([bytesToBitString(n.id) for n in vals])
return len(sp)
sprefix = shared_prefix([bytes_to_bit_string(n.id) for n in vals])
return len(sprefix)
def head(self):
return list(self.nodes.values())[0]
@ -79,13 +79,13 @@ class KBucket(object):
return len(self.nodes)
class TableTraverser(object):
class TableTraverser:
def __init__(self, table, startNode):
index = table.getBucketFor(startNode)
table.buckets[index].touchLastUpdated()
self.currentNodes = table.buckets[index].getNodes()
self.leftBuckets = table.buckets[:index]
self.rightBuckets = table.buckets[(index + 1):]
index = table.get_bucket_for(startNode)
table.buckets[index].touch_last_updated()
self.current_nodes = table.buckets[index].get_nodes()
self.left_buckets = table.buckets[:index]
self.right_buckets = table.buckets[(index + 1):]
self.left = True
def __iter__(self):
@ -95,23 +95,23 @@ class TableTraverser(object):
"""
Pop an item from the left subtree, then right, then left, etc.
"""
if len(self.currentNodes) > 0:
return self.currentNodes.pop()
if self.current_nodes:
return self.current_nodes.pop()
if self.left and len(self.leftBuckets) > 0:
self.currentNodes = self.leftBuckets.pop().getNodes()
if self.left and self.left_buckets:
self.current_nodes = self.left_buckets.pop().get_nodes()
self.left = False
return next(self)
if len(self.rightBuckets) > 0:
self.currentNodes = self.rightBuckets.pop(0).getNodes()
if self.right_buckets:
self.current_nodes = self.right_buckets.pop(0).get_nodes()
self.left = True
return next(self)
raise StopIteration
class RoutingTable(object):
class RoutingTable:
def __init__(self, protocol, ksize, node):
"""
@param node: The node that represents this server. It won't
@ -126,58 +126,60 @@ class RoutingTable(object):
def flush(self):
self.buckets = [KBucket(0, 2 ** 160, self.ksize)]
def splitBucket(self, index):
def split_bucket(self, index):
one, two = self.buckets[index].split()
self.buckets[index] = one
self.buckets.insert(index + 1, two)
def getLonelyBuckets(self):
def lonely_buckets(self):
"""
Get all of the buckets that haven't been updated in over
an hour.
"""
hrago = time.monotonic() - 3600
return [b for b in self.buckets if b.lastUpdated < hrago]
return [b for b in self.buckets if b.last_updated < hrago]
def removeContact(self, node):
index = self.getBucketFor(node)
self.buckets[index].removeNode(node)
def remove_contact(self, node):
index = self.get_bucket_for(node)
self.buckets[index].remove_node(node)
def isNewNode(self, node):
index = self.getBucketFor(node)
return self.buckets[index].isNewNode(node)
def is_new_node(self, node):
index = self.get_bucket_for(node)
return self.buckets[index].is_new_node(node)
def addContact(self, node):
index = self.getBucketFor(node)
def add_contact(self, node):
index = self.get_bucket_for(node)
bucket = self.buckets[index]
# this will succeed unless the bucket is full
if bucket.addNode(node):
if bucket.add_node(node):
return
# Per section 4.2 of paper, split if the bucket has the node
# in its range or if the depth is not congruent to 0 mod 5
if bucket.hasInRange(self.node) or bucket.depth() % 5 != 0:
self.splitBucket(index)
self.addContact(node)
if bucket.has_in_range(self.node) or bucket.depth() % 5 != 0:
self.split_bucket(index)
self.add_contact(node)
else:
asyncio.ensure_future(self.protocol.callPing(bucket.head()))
asyncio.ensure_future(self.protocol.call_ping(bucket.head()))
def getBucketFor(self, node):
def get_bucket_for(self, node):
"""
Get the index of the bucket that the given node would fall into.
"""
for index, bucket in enumerate(self.buckets):
if node.long_id < bucket.range[1]:
return index
# we should never be here, but make linter happy
return None
def findNeighbors(self, node, k=None, exclude=None):
def find_neighbors(self, node, k=None, exclude=None):
k = k or self.ksize
nodes = []
for neighbor in TableTraverser(self, node):
notexcluded = exclude is None or not neighbor.sameHomeAs(exclude)
notexcluded = exclude is None or not neighbor.same_home_as(exclude)
if neighbor.id != node.id and notexcluded:
heapq.heappush(nodes, (node.distanceTo(neighbor), neighbor))
heapq.heappush(nodes, (node.distance_to(neighbor), neighbor))
if len(nodes) == k:
break

View File

@ -28,7 +28,7 @@ class IStorage:
"""
raise NotImplementedError
def iteritemsOlderThan(self, secondsOld):
def iter_older_than(self, seconds_old):
"""
Return the an iterator over (key, value) tuples for items older
than the given secondsOld.
@ -57,7 +57,7 @@ class ForgetfulStorage(IStorage):
self.cull()
def cull(self):
for _, _ in self.iteritemsOlderThan(self.ttl):
for _, _ in self.iter_older_than(self.ttl):
self.data.popitem(last=False)
def get(self, key, default=None):
@ -70,27 +70,23 @@ class ForgetfulStorage(IStorage):
self.cull()
return self.data[key][1]
def __iter__(self):
self.cull()
return iter(self.data)
def __repr__(self):
self.cull()
return repr(self.data)
def iteritemsOlderThan(self, secondsOld):
minBirthday = time.monotonic() - secondsOld
zipped = self._tripleIterable()
matches = takewhile(lambda r: minBirthday >= r[1], zipped)
def iter_older_than(self, seconds_old):
min_birthday = time.monotonic() - seconds_old
zipped = self._triple_iter()
matches = takewhile(lambda r: min_birthday >= r[1], zipped)
return list(map(operator.itemgetter(0, 2), matches))
def _tripleIterable(self):
def _triple_iter(self):
ikeys = self.data.keys()
ibirthday = map(operator.itemgetter(0), self.data.values())
ivalues = map(operator.itemgetter(1), self.data.values())
return zip(ikeys, ibirthday, ivalues)
def items(self):
def __iter__(self):
self.cull()
ikeys = self.data.keys()
ivalues = map(operator.itemgetter(1), self.data.values())

View File

@ -8,31 +8,31 @@ from kademlia.tests.utils import mknode
class NodeTest(unittest.TestCase):
def test_longID(self):
def test_long_id(self):
rid = hashlib.sha1(str(random.getrandbits(255)).encode()).digest()
n = Node(rid)
self.assertEqual(n.long_id, int(rid.hex(), 16))
node = Node(rid)
self.assertEqual(node.long_id, int(rid.hex(), 16))
def test_distanceCalculation(self):
def test_distance_calculation(self):
ridone = hashlib.sha1(str(random.getrandbits(255)).encode())
ridtwo = hashlib.sha1(str(random.getrandbits(255)).encode())
shouldbe = int(ridone.hexdigest(), 16) ^ int(ridtwo.hexdigest(), 16)
none = Node(ridone.digest())
ntwo = Node(ridtwo.digest())
self.assertEqual(none.distanceTo(ntwo), shouldbe)
self.assertEqual(none.distance_to(ntwo), shouldbe)
class NodeHeapTest(unittest.TestCase):
def test_maxSize(self):
n = NodeHeap(mknode(intid=0), 3)
self.assertEqual(0, len(n))
def test_max_size(self):
node = NodeHeap(mknode(intid=0), 3)
self.assertEqual(0, len(node))
for d in range(10):
n.push(mknode(intid=d))
self.assertEqual(3, len(n))
for digit in range(10):
node.push(mknode(intid=digit))
self.assertEqual(3, len(node))
self.assertEqual(3, len(list(n)))
self.assertEqual(3, len(list(node)))
def test_iteration(self):
heap = NodeHeap(mknode(intid=0), 5)

View File

@ -7,53 +7,53 @@ from kademlia.tests.utils import mknode, FakeProtocol
class KBucketTest(unittest.TestCase):
def test_split(self):
bucket = KBucket(0, 10, 5)
bucket.addNode(mknode(intid=5))
bucket.addNode(mknode(intid=6))
bucket.add_node(mknode(intid=5))
bucket.add_node(mknode(intid=6))
one, two = bucket.split()
self.assertEqual(len(one), 1)
self.assertEqual(one.range, (0, 5))
self.assertEqual(len(two), 1)
self.assertEqual(two.range, (6, 10))
def test_addNode(self):
def test_add_node(self):
# when full, return false
bucket = KBucket(0, 10, 2)
self.assertTrue(bucket.addNode(mknode()))
self.assertTrue(bucket.addNode(mknode()))
self.assertFalse(bucket.addNode(mknode()))
self.assertTrue(bucket.add_node(mknode()))
self.assertTrue(bucket.add_node(mknode()))
self.assertFalse(bucket.add_node(mknode()))
self.assertEqual(len(bucket), 2)
# make sure when a node is double added it's put at the end
bucket = KBucket(0, 10, 3)
nodes = [mknode(), mknode(), mknode()]
for node in nodes:
bucket.addNode(node)
for index, node in enumerate(bucket.getNodes()):
bucket.add_node(node)
for index, node in enumerate(bucket.get_nodes()):
self.assertEqual(node, nodes[index])
def test_inRange(self):
def test_in_range(self):
bucket = KBucket(0, 10, 10)
self.assertTrue(bucket.hasInRange(mknode(intid=5)))
self.assertFalse(bucket.hasInRange(mknode(intid=11)))
self.assertTrue(bucket.hasInRange(mknode(intid=10)))
self.assertTrue(bucket.hasInRange(mknode(intid=0)))
self.assertTrue(bucket.has_in_range(mknode(intid=5)))
self.assertFalse(bucket.has_in_range(mknode(intid=11)))
self.assertTrue(bucket.has_in_range(mknode(intid=10)))
self.assertTrue(bucket.has_in_range(mknode(intid=0)))
class RoutingTableTest(unittest.TestCase):
def setUp(self):
self.id = mknode().id
self.id = mknode().id # pylint: disable=invalid-name
self.protocol = FakeProtocol(self.id)
self.router = self.protocol.router
def test_addContact(self):
self.router.addContact(mknode())
def test_add_contact(self):
self.router.add_contact(mknode())
self.assertTrue(len(self.router.buckets), 1)
self.assertTrue(len(self.router.buckets[0].nodes), 1)
class TableTraverserTest(unittest.TestCase):
def setUp(self):
self.id = mknode().id
self.id = mknode().id # pylint: disable=invalid-name
self.protocol = FakeProtocol(self.id)
self.router = self.protocol.router
@ -70,8 +70,8 @@ class TableTraverserTest(unittest.TestCase):
buckets = []
for i in range(5):
bucket = KBucket(2 * i, 2 * i + 1, 2)
bucket.addNode(nodes[2 * i])
bucket.addNode(nodes[2 * i + 1])
bucket.add_node(nodes[2 * i])
bucket.add_node(nodes[2 * i + 1])
buckets.append(bucket)
# replace router's bucket with our test buckets

View File

@ -0,0 +1,29 @@
import unittest
from kademlia.storage import ForgetfulStorage
class ForgetfulStorageTest(unittest.TestCase):
def test_storing(self):
storage = ForgetfulStorage(10)
storage['one'] = 'two'
self.assertEqual(storage['one'], 'two')
def test_forgetting(self):
storage = ForgetfulStorage(0)
storage['one'] = 'two'
self.assertEqual(storage.get('one'), None)
def test_iter(self):
storage = ForgetfulStorage(10)
storage['one'] = 'two'
for key, value in storage:
self.assertEqual(key, 'one')
self.assertEqual(value, 'two')
def test_iter_old(self):
storage = ForgetfulStorage(10)
storage['one'] = 'two'
for key, value in storage.iter_older_than(0):
self.assertEqual(key, 'one')
self.assertEqual(value, 'two')

View File

@ -1,36 +1,36 @@
import hashlib
import unittest
from kademlia.utils import digest, sharedPrefix, OrderedSet
from kademlia.utils import digest, shared_prefix, OrderedSet
class UtilsTest(unittest.TestCase):
def test_digest(self):
d = hashlib.sha1(b'1').digest()
self.assertEqual(d, digest(1))
dig = hashlib.sha1(b'1').digest()
self.assertEqual(dig, digest(1))
d = hashlib.sha1(b'another').digest()
self.assertEqual(d, digest('another'))
dig = hashlib.sha1(b'another').digest()
self.assertEqual(dig, digest('another'))
def test_sharedPrefix(self):
def test_shared_prefix(self):
args = ['prefix', 'prefixasdf', 'prefix', 'prefixxxx']
self.assertEqual(sharedPrefix(args), 'prefix')
self.assertEqual(shared_prefix(args), 'prefix')
args = ['p', 'prefixasdf', 'prefix', 'prefixxxx']
self.assertEqual(sharedPrefix(args), 'p')
self.assertEqual(shared_prefix(args), 'p')
args = ['one', 'two']
self.assertEqual(sharedPrefix(args), '')
self.assertEqual(shared_prefix(args), '')
args = ['hi']
self.assertEqual(sharedPrefix(args), 'hi')
self.assertEqual(shared_prefix(args), 'hi')
class OrderedSetTest(unittest.TestCase):
def test_order(self):
o = OrderedSet()
o.push('1')
o.push('1')
o.push('2')
o.push('1')
self.assertEqual(o, ['2', '1'])
oset = OrderedSet()
oset.push('1')
oset.push('1')
oset.push('2')
oset.push('1')
self.assertEqual(oset, ['2', '1'])

View File

@ -9,7 +9,7 @@ from kademlia.node import Node
from kademlia.routing import RoutingTable
def mknode(node_id=None, ip=None, port=None, intid=None):
def mknode(node_id=None, ip_addy=None, port=None, intid=None):
"""
Make a node. Created a random id if not specified.
"""
@ -18,11 +18,11 @@ def mknode(node_id=None, ip=None, port=None, intid=None):
if not node_id:
randbits = str(random.getrandbits(255))
node_id = hashlib.sha1(randbits.encode()).digest()
return Node(node_id, ip, port)
return Node(node_id, ip_addy, port)
class FakeProtocol:
def __init__(self, sourceID, ksize=20):
self.router = RoutingTable(self, ksize, Node(sourceID))
class FakeProtocol: # pylint: disable=too-few-public-methods
def __init__(self, source_id, ksize=20):
self.router = RoutingTable(self, ksize, Node(source_id))
self.storage = {}
self.sourceID = sourceID
self.source_id = source_id

View File

@ -6,16 +6,16 @@ import operator
import asyncio
async def gather_dict(d):
cors = list(d.values())
async def gather_dict(dic):
cors = list(dic.values())
results = await asyncio.gather(*cors)
return dict(zip(d.keys(), results))
return dict(zip(dic.keys(), results))
def digest(s):
if not isinstance(s, bytes):
s = str(s).encode('utf8')
return hashlib.sha1(s).digest()
def digest(string):
if not isinstance(string, bytes):
string = str(string).encode('utf8')
return hashlib.sha1(string).digest()
class OrderedSet(list):
@ -34,7 +34,7 @@ class OrderedSet(list):
self.append(thing)
def sharedPrefix(args):
def shared_prefix(args):
"""
Find the shared prefix between the strings.
@ -52,6 +52,6 @@ def sharedPrefix(args):
return args[0][:i]
def bytesToBitString(bites):
def bytes_to_bit_string(bites):
bits = [bin(bite)[2:].rjust(8, '0') for bite in bites]
return "".join(bits)