removed all camelcase silliness, bumped version to 2.0

This commit is contained in:
Brian Muller 2019-01-09 11:27:10 -05:00
parent dbe41e3b08
commit cf9d490d64
16 changed files with 306 additions and 279 deletions

View File

@ -497,7 +497,7 @@ valid-metaclass-classmethod-first-arg=mcs
[DESIGN] [DESIGN]
# Maximum number of arguments for function / method # Maximum number of arguments for function / method
max-args=5 max-args=6
# Maximum number of attributes for a class (see R0902). # Maximum number of attributes for a class (see R0902).
max-attributes=7 max-attributes=7

View File

@ -1,6 +1,6 @@
language: python language: python
python: python:
- "3.5"
- "3.6" - "3.6"
- "3.7"
install: pip install . && pip install -r dev-requirements.txt install: pip install . && pip install -r dev-requirements.txt
script: python -m unittest script: py.test

View File

@ -1,5 +1,6 @@
pycodestyle==2.3.1 pycodestyle>=2.4.0
pylint==1.8.1 pylint>=2.2.2
sphinx>=1.6.5 sphinx>=1.6.5
sphinxcontrib-napoleon>=0.6.1 sphinxcontrib-napoleon>=0.6.1
sphinxcontrib-zopeext>=0.2.1 sphinxcontrib-zopeext>=0.2.1
py.test>=4.1.0

View File

@ -2,4 +2,4 @@
Kademlia is a Python implementation of the Kademlia protocol which Kademlia is a Python implementation of the Kademlia protocol which
utilizes the asyncio library. utilizes the asyncio library.
""" """
__version__ = "1.1" __version__ = "2.0"

View File

@ -4,10 +4,12 @@ import logging
from kademlia.node import Node, NodeHeap from kademlia.node import Node, NodeHeap
from kademlia.utils import gather_dict from kademlia.utils import gather_dict
log = logging.getLogger(__name__)
log = logging.getLogger(__name__) # pylint: disable=invalid-name
class SpiderCrawl(object): # pylint: disable=too-few-public-methods
class SpiderCrawl:
""" """
Crawl the network and look for given 160-bit keys. Crawl the network and look for given 160-bit keys.
""" """
@ -29,7 +31,7 @@ class SpiderCrawl(object):
self.alpha = alpha self.alpha = alpha
self.node = node self.node = node
self.nearest = NodeHeap(self.node, self.ksize) self.nearest = NodeHeap(self.node, self.ksize)
self.lastIDsCrawled = [] self.last_ids_crawled = []
log.info("creating spider with peers: %s", peers) log.info("creating spider with peers: %s", peers)
self.nearest.push(peers) self.nearest.push(peers)
@ -38,7 +40,7 @@ class SpiderCrawl(object):
Get either a value or list of nodes. Get either a value or list of nodes.
Args: Args:
rpcmethod: The protocol's callfindValue or callFindNode. rpcmethod: The protocol's callfindValue or call_find_node.
The process: The process:
1. calls find_* to current ALPHA nearest not already queried nodes, 1. calls find_* to current ALPHA nearest not already queried nodes,
@ -51,18 +53,18 @@ class SpiderCrawl(object):
""" """
log.info("crawling network with nearest: %s", str(tuple(self.nearest))) log.info("crawling network with nearest: %s", str(tuple(self.nearest)))
count = self.alpha count = self.alpha
if self.nearest.getIDs() == self.lastIDsCrawled: if self.nearest.get_ids() == self.last_ids_crawled:
count = len(self.nearest) count = len(self.nearest)
self.lastIDsCrawled = self.nearest.getIDs() self.last_ids_crawled = self.nearest.get_ids()
ds = {} dicts = {}
for peer in self.nearest.getUncontacted()[:count]: for peer in self.nearest.get_uncontacted()[:count]:
ds[peer.id] = rpcmethod(peer, self.node) dicts[peer.id] = rpcmethod(peer, self.node)
self.nearest.markContacted(peer) self.nearest.mark_contacted(peer)
found = await gather_dict(ds) found = await gather_dict(dicts)
return await self._nodesFound(found) return await self._nodes_found(found)
async def _nodesFound(self, responses): async def _nodes_found(self, responses):
raise NotImplementedError raise NotImplementedError
@ -71,55 +73,55 @@ class ValueSpiderCrawl(SpiderCrawl):
SpiderCrawl.__init__(self, protocol, node, peers, ksize, alpha) SpiderCrawl.__init__(self, protocol, node, peers, ksize, alpha)
# keep track of the single nearest node without value - per # keep track of the single nearest node without value - per
# section 2.3 so we can set the key there if found # section 2.3 so we can set the key there if found
self.nearestWithoutValue = NodeHeap(self.node, 1) self.nearest_without_value = NodeHeap(self.node, 1)
async def find(self): async def find(self):
""" """
Find either the closest nodes or the value requested. Find either the closest nodes or the value requested.
""" """
return await self._find(self.protocol.callFindValue) return await self._find(self.protocol.call_find_value)
async def _nodesFound(self, responses): async def _nodes_found(self, responses):
""" """
Handle the result of an iteration in _find. Handle the result of an iteration in _find.
""" """
toremove = [] toremove = []
foundValues = [] found_values = []
for peerid, response in responses.items(): for peerid, response in responses.items():
response = RPCFindResponse(response) response = RPCFindResponse(response)
if not response.happened(): if not response.happened():
toremove.append(peerid) toremove.append(peerid)
elif response.hasValue(): elif response.has_value():
foundValues.append(response.getValue()) found_values.append(response.get_value())
else: else:
peer = self.nearest.getNodeById(peerid) peer = self.nearest.get_node(peerid)
self.nearestWithoutValue.push(peer) self.nearest_without_value.push(peer)
self.nearest.push(response.getNodeList()) self.nearest.push(response.get_node_list())
self.nearest.remove(toremove) self.nearest.remove(toremove)
if len(foundValues) > 0: if found_values:
return await self._handleFoundValues(foundValues) return await self._handle_found_values(found_values)
if self.nearest.allBeenContacted(): if self.nearest.have_contacted_all():
# not found! # not found!
return None return None
return await self.find() return await self.find()
async def _handleFoundValues(self, values): async def _handle_found_values(self, values):
""" """
We got some values! Exciting. But let's make sure We got some values! Exciting. But let's make sure
they're all the same or freak out a little bit. Also, they're all the same or freak out a little bit. Also,
make sure we tell the nearest node that *didn't* have make sure we tell the nearest node that *didn't* have
the value to store it. the value to store it.
""" """
valueCounts = Counter(values) value_counts = Counter(values)
if len(valueCounts) != 1: if len(value_counts) != 1:
log.warning("Got multiple values for key %i: %s", log.warning("Got multiple values for key %i: %s",
self.node.long_id, str(values)) self.node.long_id, str(values))
value = valueCounts.most_common(1)[0][0] value = value_counts.most_common(1)[0][0]
peerToSaveTo = self.nearestWithoutValue.popleft() peer = self.nearest_without_value.popleft()
if peerToSaveTo is not None: if peer:
await self.protocol.callStore(peerToSaveTo, self.node.id, value) await self.protocol.call_store(peer, self.node.id, value)
return value return value
@ -128,9 +130,9 @@ class NodeSpiderCrawl(SpiderCrawl):
""" """
Find the closest nodes. Find the closest nodes.
""" """
return await self._find(self.protocol.callFindNode) return await self._find(self.protocol.call_find_node)
async def _nodesFound(self, responses): async def _nodes_found(self, responses):
""" """
Handle the result of an iteration in _find. Handle the result of an iteration in _find.
""" """
@ -140,15 +142,15 @@ class NodeSpiderCrawl(SpiderCrawl):
if not response.happened(): if not response.happened():
toremove.append(peerid) toremove.append(peerid)
else: else:
self.nearest.push(response.getNodeList()) self.nearest.push(response.get_node_list())
self.nearest.remove(toremove) self.nearest.remove(toremove)
if self.nearest.allBeenContacted(): if self.nearest.have_contacted_all():
return list(self.nearest) return list(self.nearest)
return await self.find() return await self.find()
class RPCFindResponse(object): class RPCFindResponse:
def __init__(self, response): def __init__(self, response):
""" """
A wrapper for the result of a RPC find. A wrapper for the result of a RPC find.
@ -166,13 +168,13 @@ class RPCFindResponse(object):
""" """
return self.response[0] return self.response[0]
def hasValue(self): def has_value(self):
return isinstance(self.response[1], dict) return isinstance(self.response[1], dict)
def getValue(self): def get_value(self):
return self.response[1]['value'] return self.response[1]['value']
def getNodeList(self): def get_node_list(self):
""" """
Get the node list in the response. If there's no value, this should Get the node list in the response. If there's no value, this should
be set. be set.

View File

@ -13,10 +13,11 @@ from kademlia.node import Node
from kademlia.crawling import ValueSpiderCrawl from kademlia.crawling import ValueSpiderCrawl
from kademlia.crawling import NodeSpiderCrawl from kademlia.crawling import NodeSpiderCrawl
log = logging.getLogger(__name__) log = logging.getLogger(__name__) # pylint: disable=invalid-name
class Server(object): # pylint: disable=too-many-instance-attributes
class Server:
""" """
High level view of a node instance. This is the object that should be High level view of a node instance. This is the object that should be
created to start listening as an active node on the network. created to start listening as an active node on the network.
@ -83,22 +84,22 @@ class Server(object):
Refresh buckets that haven't had any lookups in the last hour Refresh buckets that haven't had any lookups in the last hour
(per section 2.3 of the paper). (per section 2.3 of the paper).
""" """
ds = [] results = []
for node_id in self.protocol.getRefreshIDs(): for node_id in self.protocol.get_refresh_ids():
node = Node(node_id) node = Node(node_id)
nearest = self.protocol.router.findNeighbors(node, self.alpha) nearest = self.protocol.router.find_neighbors(node, self.alpha)
spider = NodeSpiderCrawl(self.protocol, node, nearest, spider = NodeSpiderCrawl(self.protocol, node, nearest,
self.ksize, self.alpha) self.ksize, self.alpha)
ds.append(spider.find()) results.append(spider.find())
# do our crawling # do our crawling
await asyncio.gather(*ds) await asyncio.gather(*results)
# now republish keys older than one hour # now republish keys older than one hour
for dkey, value in self.storage.iteritemsOlderThan(3600): for dkey, value in self.storage.iter_older_than(3600):
await self.set_digest(dkey, value) await self.set_digest(dkey, value)
def bootstrappableNeighbors(self): def bootstrappable_neighbors(self):
""" """
Get a :class:`list` of (ip, port) :class:`tuple` pairs suitable for Get a :class:`list` of (ip, port) :class:`tuple` pairs suitable for
use as an argument to the bootstrap method. use as an argument to the bootstrap method.
@ -108,7 +109,7 @@ class Server(object):
storing them if this server is going down for a while. When it comes storing them if this server is going down for a while. When it comes
back up, the list of nodes can be used to bootstrap. back up, the list of nodes can be used to bootstrap.
""" """
neighbors = self.protocol.router.findNeighbors(self.node) neighbors = self.protocol.router.find_neighbors(self.node)
return [tuple(n)[-2:] for n in neighbors] return [tuple(n)[-2:] for n in neighbors]
async def bootstrap(self, addrs): async def bootstrap(self, addrs):
@ -145,8 +146,8 @@ class Server(object):
if self.storage.get(dkey) is not None: if self.storage.get(dkey) is not None:
return self.storage.get(dkey) return self.storage.get(dkey)
node = Node(dkey) node = Node(dkey)
nearest = self.protocol.router.findNeighbors(node) nearest = self.protocol.router.find_neighbors(node)
if len(nearest) == 0: if not nearest:
log.warning("There are no known neighbors to get key %s", key) log.warning("There are no known neighbors to get key %s", key)
return None return None
spider = ValueSpiderCrawl(self.protocol, node, nearest, spider = ValueSpiderCrawl(self.protocol, node, nearest,
@ -172,8 +173,8 @@ class Server(object):
""" """
node = Node(dkey) node = Node(dkey)
nearest = self.protocol.router.findNeighbors(node) nearest = self.protocol.router.find_neighbors(node)
if len(nearest) == 0: if not nearest:
log.warning("There are no known neighbors to set key %s", log.warning("There are no known neighbors to set key %s",
dkey.hex()) dkey.hex())
return False return False
@ -184,14 +185,14 @@ class Server(object):
log.info("setting '%s' on %s", dkey.hex(), list(map(str, nodes))) log.info("setting '%s' on %s", dkey.hex(), list(map(str, nodes)))
# if this node is close too, then store here as well # if this node is close too, then store here as well
biggest = max([n.distanceTo(node) for n in nodes]) biggest = max([n.distance_to(node) for n in nodes])
if self.node.distanceTo(node) < biggest: if self.node.distance_to(node) < biggest:
self.storage[dkey] = value self.storage[dkey] = value
ds = [self.protocol.callStore(n, dkey, value) for n in nodes] results = [self.protocol.call_store(n, dkey, value) for n in nodes]
# return true only if at least one store call succeeded # return true only if at least one store call succeeded
return any(await asyncio.gather(*ds)) return any(await asyncio.gather(*results))
def saveState(self, fname): def save_state(self, fname):
""" """
Save the state of this node (the alpha/ksize/id/immediate neighbors) Save the state of this node (the alpha/ksize/id/immediate neighbors)
to a cache file with the given fname. to a cache file with the given fname.
@ -201,29 +202,29 @@ class Server(object):
'ksize': self.ksize, 'ksize': self.ksize,
'alpha': self.alpha, 'alpha': self.alpha,
'id': self.node.id, 'id': self.node.id,
'neighbors': self.bootstrappableNeighbors() 'neighbors': self.bootstrappable_neighbors()
} }
if len(data['neighbors']) == 0: if not data['neighbors']:
log.warning("No known neighbors, so not writing to cache.") log.warning("No known neighbors, so not writing to cache.")
return return
with open(fname, 'wb') as f: with open(fname, 'wb') as file:
pickle.dump(data, f) pickle.dump(data, file)
@classmethod @classmethod
def loadState(self, fname): def load_state(cls, fname):
""" """
Load the state of this node (the alpha/ksize/id/immediate neighbors) Load the state of this node (the alpha/ksize/id/immediate neighbors)
from a cache file with the given fname. from a cache file with the given fname.
""" """
log.info("Loading state from %s", fname) log.info("Loading state from %s", fname)
with open(fname, 'rb') as f: with open(fname, 'rb') as file:
data = pickle.load(f) data = pickle.load(file)
s = Server(data['ksize'], data['alpha'], data['id']) svr = Server(data['ksize'], data['alpha'], data['id'])
if len(data['neighbors']) > 0: if data['neighbors']:
s.bootstrap(data['neighbors']) svr.bootstrap(data['neighbors'])
return s return svr
def saveStateRegularly(self, fname, frequency=600): def save_state_regularly(self, fname, frequency=600):
""" """
Save the state of node with a given regularity to the given Save the state of node with a given regularity to the given
filename. filename.
@ -233,10 +234,10 @@ class Server(object):
frequency: Frequency in seconds that the state should be saved. frequency: Frequency in seconds that the state should be saved.
By default, 10 minutes. By default, 10 minutes.
""" """
self.saveState(fname) self.save_state(fname)
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
self.save_state_loop = loop.call_later(frequency, self.save_state_loop = loop.call_later(frequency,
self.saveStateRegularly, self.save_state_regularly,
fname, fname,
frequency) frequency)
@ -246,13 +247,11 @@ def check_dht_value_type(value):
Checks to see if the type of the value is a valid type for Checks to see if the type of the value is a valid type for
placing in the dht. placing in the dht.
""" """
typeset = set( typeset = [
[
int, int,
float, float,
bool, bool,
str, str,
bytes, bytes
] ]
) return type(value) in typeset # pylint: disable=unidiomatic-typecheck
return type(value) in typeset

View File

@ -4,15 +4,15 @@ import heapq
class Node: class Node:
def __init__(self, node_id, ip=None, port=None): def __init__(self, node_id, ip=None, port=None):
self.id = node_id self.id = node_id # pylint: disable=invalid-name
self.ip = ip self.ip = ip # pylint: disable=invalid-name
self.port = port self.port = port
self.long_id = int(node_id.hex(), 16) self.long_id = int(node_id.hex(), 16)
def sameHomeAs(self, node): def same_home_as(self, node):
return self.ip == node.ip and self.port == node.port return self.ip == node.ip and self.port == node.port
def distanceTo(self, node): def distance_to(self, node):
""" """
Get the distance between this node and another. Get the distance between this node and another.
""" """
@ -31,7 +31,7 @@ class Node:
return "%s:%s" % (self.ip, str(self.port)) return "%s:%s" % (self.ip, str(self.port))
class NodeHeap(object): class NodeHeap:
""" """
A heap of nodes ordered by distance to a given node. A heap of nodes ordered by distance to a given node.
""" """
@ -47,7 +47,7 @@ class NodeHeap(object):
self.contacted = set() self.contacted = set()
self.maxsize = maxsize self.maxsize = maxsize
def remove(self, peerIDs): def remove(self, peers):
""" """
Remove a list of peer ids from this heap. Note that while this Remove a list of peer ids from this heap. Note that while this
heap retains a constant visible size (based on the iterator), it's heap retains a constant visible size (based on the iterator), it's
@ -55,34 +55,32 @@ class NodeHeap(object):
removal of nodes may not change the visible size as previously added removal of nodes may not change the visible size as previously added
nodes suddenly become visible. nodes suddenly become visible.
""" """
peerIDs = set(peerIDs) peers = set(peers)
if len(peerIDs) == 0: if not peers:
return return
nheap = [] nheap = []
for distance, node in self.heap: for distance, node in self.heap:
if node.id not in peerIDs: if node.id not in peers:
heapq.heappush(nheap, (distance, node)) heapq.heappush(nheap, (distance, node))
self.heap = nheap self.heap = nheap
def getNodeById(self, node_id): def get_node(self, node_id):
for _, node in self.heap: for _, node in self.heap:
if node.id == node_id: if node.id == node_id:
return node return node
return None return None
def allBeenContacted(self): def have_contacted_all(self):
return len(self.getUncontacted()) == 0 return len(self.get_uncontacted()) == 0
def getIDs(self): def get_ids(self):
return [n.id for n in self] return [n.id for n in self]
def markContacted(self, node): def mark_contacted(self, node):
self.contacted.add(node.id) self.contacted.add(node.id)
def popleft(self): def popleft(self):
if len(self) > 0: return heapq.heappop(self.heap)[1] if self else None
return heapq.heappop(self.heap)[1]
return None
def push(self, nodes): def push(self, nodes):
""" """
@ -95,7 +93,7 @@ class NodeHeap(object):
for node in nodes: for node in nodes:
if node not in self: if node not in self:
distance = self.node.distanceTo(node) distance = self.node.distance_to(node)
heapq.heappush(self.heap, (distance, node)) heapq.heappush(self.heap, (distance, node))
def __len__(self): def __len__(self):
@ -106,10 +104,10 @@ class NodeHeap(object):
return iter(map(itemgetter(1), nodes)) return iter(map(itemgetter(1), nodes))
def __contains__(self, node): def __contains__(self, node):
for _, n in self.heap: for _, other in self.heap:
if node.id == n.id: if node.id == other.id:
return True return True
return False return False
def getUncontacted(self): def get_uncontacted(self):
return [n for n in self if n.id not in self.contacted] return [n for n in self if n.id not in self.contacted]

View File

@ -8,37 +8,37 @@ from kademlia.node import Node
from kademlia.routing import RoutingTable from kademlia.routing import RoutingTable
from kademlia.utils import digest from kademlia.utils import digest
log = logging.getLogger(__name__) log = logging.getLogger(__name__) # pylint: disable=invalid-name
class KademliaProtocol(RPCProtocol): class KademliaProtocol(RPCProtocol):
def __init__(self, sourceNode, storage, ksize): def __init__(self, source_node, storage, ksize):
RPCProtocol.__init__(self) RPCProtocol.__init__(self)
self.router = RoutingTable(self, ksize, sourceNode) self.router = RoutingTable(self, ksize, source_node)
self.storage = storage self.storage = storage
self.sourceNode = sourceNode self.source_node = source_node
def getRefreshIDs(self): def get_refresh_ids(self):
""" """
Get ids to search for to keep old buckets up to date. Get ids to search for to keep old buckets up to date.
""" """
ids = [] ids = []
for bucket in self.router.getLonelyBuckets(): for bucket in self.router.lonely_buckets():
rid = random.randint(*bucket.range).to_bytes(20, byteorder='big') rid = random.randint(*bucket.range).to_bytes(20, byteorder='big')
ids.append(rid) ids.append(rid)
return ids return ids
def rpc_stun(self, sender): def rpc_stun(self, sender): # pylint: disable=no-self-use
return sender return sender
def rpc_ping(self, sender, nodeid): def rpc_ping(self, sender, nodeid):
source = Node(nodeid, sender[0], sender[1]) source = Node(nodeid, sender[0], sender[1])
self.welcomeIfNewNode(source) self.welcome_if_new(source)
return self.sourceNode.id return self.source_node.id
def rpc_store(self, sender, nodeid, key, value): def rpc_store(self, sender, nodeid, key, value):
source = Node(nodeid, sender[0], sender[1]) source = Node(nodeid, sender[0], sender[1])
self.welcomeIfNewNode(source) self.welcome_if_new(source)
log.debug("got a store request from %s, storing '%s'='%s'", log.debug("got a store request from %s, storing '%s'='%s'",
sender, key.hex(), value) sender, key.hex(), value)
self.storage[key] = value self.storage[key] = value
@ -48,42 +48,42 @@ class KademliaProtocol(RPCProtocol):
log.info("finding neighbors of %i in local table", log.info("finding neighbors of %i in local table",
int(nodeid.hex(), 16)) int(nodeid.hex(), 16))
source = Node(nodeid, sender[0], sender[1]) source = Node(nodeid, sender[0], sender[1])
self.welcomeIfNewNode(source) self.welcome_if_new(source)
node = Node(key) node = Node(key)
neighbors = self.router.findNeighbors(node, exclude=source) neighbors = self.router.find_neighbors(node, exclude=source)
return list(map(tuple, neighbors)) return list(map(tuple, neighbors))
def rpc_find_value(self, sender, nodeid, key): def rpc_find_value(self, sender, nodeid, key):
source = Node(nodeid, sender[0], sender[1]) source = Node(nodeid, sender[0], sender[1])
self.welcomeIfNewNode(source) self.welcome_if_new(source)
value = self.storage.get(key, None) value = self.storage.get(key, None)
if value is None: if value is None:
return self.rpc_find_node(sender, nodeid, key) return self.rpc_find_node(sender, nodeid, key)
return {'value': value} return {'value': value}
async def callFindNode(self, nodeToAsk, nodeToFind): async def call_find_node(self, node_to_ask, node_to_find):
address = (nodeToAsk.ip, nodeToAsk.port) address = (node_to_ask.ip, node_to_ask.port)
result = await self.find_node(address, self.sourceNode.id, result = await self.find_node(address, self.source_node.id,
nodeToFind.id) node_to_find.id)
return self.handleCallResponse(result, nodeToAsk) return self.handle_call_response(result, node_to_ask)
async def callFindValue(self, nodeToAsk, nodeToFind): async def call_find_value(self, node_to_ask, node_to_find):
address = (nodeToAsk.ip, nodeToAsk.port) address = (node_to_ask.ip, node_to_ask.port)
result = await self.find_value(address, self.sourceNode.id, result = await self.find_value(address, self.source_node.id,
nodeToFind.id) node_to_find.id)
return self.handleCallResponse(result, nodeToAsk) return self.handle_call_response(result, node_to_ask)
async def callPing(self, nodeToAsk): async def call_ping(self, node_to_ask):
address = (nodeToAsk.ip, nodeToAsk.port) address = (node_to_ask.ip, node_to_ask.port)
result = await self.ping(address, self.sourceNode.id) result = await self.ping(address, self.source_node.id)
return self.handleCallResponse(result, nodeToAsk) return self.handle_call_response(result, node_to_ask)
async def callStore(self, nodeToAsk, key, value): async def call_store(self, node_to_ask, key, value):
address = (nodeToAsk.ip, nodeToAsk.port) address = (node_to_ask.ip, node_to_ask.port)
result = await self.store(address, self.sourceNode.id, key, value) result = await self.store(address, self.source_node.id, key, value)
return self.handleCallResponse(result, nodeToAsk) return self.handle_call_response(result, node_to_ask)
def welcomeIfNewNode(self, node): def welcome_if_new(self, node):
""" """
Given a new node, send it all the keys/values it should be storing, Given a new node, send it all the keys/values it should be storing,
then add it to the routing table. then add it to the routing table.
@ -97,32 +97,32 @@ class KademliaProtocol(RPCProtocol):
is closer than the closest in that list, then store the key/value is closer than the closest in that list, then store the key/value
on the new node (per section 2.5 of the paper) on the new node (per section 2.5 of the paper)
""" """
if not self.router.isNewNode(node): if not self.router.is_new_node(node):
return return
log.info("never seen %s before, adding to router", node) log.info("never seen %s before, adding to router", node)
for key, value in self.storage.items(): for key, value in self.storage:
keynode = Node(digest(key)) keynode = Node(digest(key))
neighbors = self.router.findNeighbors(keynode) neighbors = self.router.find_neighbors(keynode)
if len(neighbors) > 0: if neighbors:
last = neighbors[-1].distanceTo(keynode) last = neighbors[-1].distance_to(keynode)
newNodeClose = node.distanceTo(keynode) < last new_node_close = node.distance_to(keynode) < last
first = neighbors[0].distanceTo(keynode) first = neighbors[0].distance_to(keynode)
thisNodeClosest = self.sourceNode.distanceTo(keynode) < first this_closest = self.source_node.distance_to(keynode) < first
if len(neighbors) == 0 or (newNodeClose and thisNodeClosest): if not neighbors or (new_node_close and this_closest):
asyncio.ensure_future(self.callStore(node, key, value)) asyncio.ensure_future(self.call_store(node, key, value))
self.router.addContact(node) self.router.add_contact(node)
def handleCallResponse(self, result, node): def handle_call_response(self, result, node):
""" """
If we get a response, add the node to the routing table. If If we get a response, add the node to the routing table. If
we get no response, make sure it's removed from the routing table. we get no response, make sure it's removed from the routing table.
""" """
if not result[0]: if not result[0]:
log.warning("no response from %s, removing from router", node) log.warning("no response from %s, removing from router", node)
self.router.removeContact(node) self.router.remove_contact(node)
return result return result
log.info("got successful response from %s", node) log.info("got successful response from %s", node)
self.welcomeIfNewNode(node) self.welcome_if_new(node)
return result return result

View File

@ -4,21 +4,21 @@ import operator
import asyncio import asyncio
from collections import OrderedDict from collections import OrderedDict
from kademlia.utils import OrderedSet, sharedPrefix, bytesToBitString from kademlia.utils import OrderedSet, shared_prefix, bytes_to_bit_string
class KBucket(object): class KBucket:
def __init__(self, rangeLower, rangeUpper, ksize): def __init__(self, rangeLower, rangeUpper, ksize):
self.range = (rangeLower, rangeUpper) self.range = (rangeLower, rangeUpper)
self.nodes = OrderedDict() self.nodes = OrderedDict()
self.replacementNodes = OrderedSet() self.replacement_nodes = OrderedSet()
self.touchLastUpdated() self.touch_last_updated()
self.ksize = ksize self.ksize = ksize
def touchLastUpdated(self): def touch_last_updated(self):
self.lastUpdated = time.monotonic() self.last_updated = time.monotonic()
def getNodes(self): def get_nodes(self):
return list(self.nodes.values()) return list(self.nodes.values())
def split(self): def split(self):
@ -30,23 +30,23 @@ class KBucket(object):
bucket.nodes[node.id] = node bucket.nodes[node.id] = node
return (one, two) return (one, two)
def removeNode(self, node): def remove_node(self, node):
if node.id not in self.nodes: if node.id not in self.nodes:
return return
# delete node, and see if we can add a replacement # delete node, and see if we can add a replacement
del self.nodes[node.id] del self.nodes[node.id]
if len(self.replacementNodes) > 0: if self.replacement_nodes:
newnode = self.replacementNodes.pop() newnode = self.replacement_nodes.pop()
self.nodes[newnode.id] = newnode self.nodes[newnode.id] = newnode
def hasInRange(self, node): def has_in_range(self, node):
return self.range[0] <= node.long_id <= self.range[1] return self.range[0] <= node.long_id <= self.range[1]
def isNewNode(self, node): def is_new_node(self, node):
return node.id not in self.nodes return node.id not in self.nodes
def addNode(self, node): def add_node(self, node):
""" """
Add a C{Node} to the C{KBucket}. Return True if successful, Add a C{Node} to the C{KBucket}. Return True if successful,
False if the bucket is full. False if the bucket is full.
@ -60,14 +60,14 @@ class KBucket(object):
elif len(self) < self.ksize: elif len(self) < self.ksize:
self.nodes[node.id] = node self.nodes[node.id] = node
else: else:
self.replacementNodes.push(node) self.replacement_nodes.push(node)
return False return False
return True return True
def depth(self): def depth(self):
vals = self.nodes.values() vals = self.nodes.values()
sp = sharedPrefix([bytesToBitString(n.id) for n in vals]) sprefix = shared_prefix([bytes_to_bit_string(n.id) for n in vals])
return len(sp) return len(sprefix)
def head(self): def head(self):
return list(self.nodes.values())[0] return list(self.nodes.values())[0]
@ -79,13 +79,13 @@ class KBucket(object):
return len(self.nodes) return len(self.nodes)
class TableTraverser(object): class TableTraverser:
def __init__(self, table, startNode): def __init__(self, table, startNode):
index = table.getBucketFor(startNode) index = table.get_bucket_for(startNode)
table.buckets[index].touchLastUpdated() table.buckets[index].touch_last_updated()
self.currentNodes = table.buckets[index].getNodes() self.current_nodes = table.buckets[index].get_nodes()
self.leftBuckets = table.buckets[:index] self.left_buckets = table.buckets[:index]
self.rightBuckets = table.buckets[(index + 1):] self.right_buckets = table.buckets[(index + 1):]
self.left = True self.left = True
def __iter__(self): def __iter__(self):
@ -95,23 +95,23 @@ class TableTraverser(object):
""" """
Pop an item from the left subtree, then right, then left, etc. Pop an item from the left subtree, then right, then left, etc.
""" """
if len(self.currentNodes) > 0: if self.current_nodes:
return self.currentNodes.pop() return self.current_nodes.pop()
if self.left and len(self.leftBuckets) > 0: if self.left and self.left_buckets:
self.currentNodes = self.leftBuckets.pop().getNodes() self.current_nodes = self.left_buckets.pop().get_nodes()
self.left = False self.left = False
return next(self) return next(self)
if len(self.rightBuckets) > 0: if self.right_buckets:
self.currentNodes = self.rightBuckets.pop(0).getNodes() self.current_nodes = self.right_buckets.pop(0).get_nodes()
self.left = True self.left = True
return next(self) return next(self)
raise StopIteration raise StopIteration
class RoutingTable(object): class RoutingTable:
def __init__(self, protocol, ksize, node): def __init__(self, protocol, ksize, node):
""" """
@param node: The node that represents this server. It won't @param node: The node that represents this server. It won't
@ -126,58 +126,60 @@ class RoutingTable(object):
def flush(self): def flush(self):
self.buckets = [KBucket(0, 2 ** 160, self.ksize)] self.buckets = [KBucket(0, 2 ** 160, self.ksize)]
def splitBucket(self, index): def split_bucket(self, index):
one, two = self.buckets[index].split() one, two = self.buckets[index].split()
self.buckets[index] = one self.buckets[index] = one
self.buckets.insert(index + 1, two) self.buckets.insert(index + 1, two)
def getLonelyBuckets(self): def lonely_buckets(self):
""" """
Get all of the buckets that haven't been updated in over Get all of the buckets that haven't been updated in over
an hour. an hour.
""" """
hrago = time.monotonic() - 3600 hrago = time.monotonic() - 3600
return [b for b in self.buckets if b.lastUpdated < hrago] return [b for b in self.buckets if b.last_updated < hrago]
def removeContact(self, node): def remove_contact(self, node):
index = self.getBucketFor(node) index = self.get_bucket_for(node)
self.buckets[index].removeNode(node) self.buckets[index].remove_node(node)
def isNewNode(self, node): def is_new_node(self, node):
index = self.getBucketFor(node) index = self.get_bucket_for(node)
return self.buckets[index].isNewNode(node) return self.buckets[index].is_new_node(node)
def addContact(self, node): def add_contact(self, node):
index = self.getBucketFor(node) index = self.get_bucket_for(node)
bucket = self.buckets[index] bucket = self.buckets[index]
# this will succeed unless the bucket is full # this will succeed unless the bucket is full
if bucket.addNode(node): if bucket.add_node(node):
return return
# Per section 4.2 of paper, split if the bucket has the node # Per section 4.2 of paper, split if the bucket has the node
# in its range or if the depth is not congruent to 0 mod 5 # in its range or if the depth is not congruent to 0 mod 5
if bucket.hasInRange(self.node) or bucket.depth() % 5 != 0: if bucket.has_in_range(self.node) or bucket.depth() % 5 != 0:
self.splitBucket(index) self.split_bucket(index)
self.addContact(node) self.add_contact(node)
else: else:
asyncio.ensure_future(self.protocol.callPing(bucket.head())) asyncio.ensure_future(self.protocol.call_ping(bucket.head()))
def getBucketFor(self, node): def get_bucket_for(self, node):
""" """
Get the index of the bucket that the given node would fall into. Get the index of the bucket that the given node would fall into.
""" """
for index, bucket in enumerate(self.buckets): for index, bucket in enumerate(self.buckets):
if node.long_id < bucket.range[1]: if node.long_id < bucket.range[1]:
return index return index
# we should never be here, but make linter happy
return None
def findNeighbors(self, node, k=None, exclude=None): def find_neighbors(self, node, k=None, exclude=None):
k = k or self.ksize k = k or self.ksize
nodes = [] nodes = []
for neighbor in TableTraverser(self, node): for neighbor in TableTraverser(self, node):
notexcluded = exclude is None or not neighbor.sameHomeAs(exclude) notexcluded = exclude is None or not neighbor.same_home_as(exclude)
if neighbor.id != node.id and notexcluded: if neighbor.id != node.id and notexcluded:
heapq.heappush(nodes, (node.distanceTo(neighbor), neighbor)) heapq.heappush(nodes, (node.distance_to(neighbor), neighbor))
if len(nodes) == k: if len(nodes) == k:
break break

View File

@ -28,7 +28,7 @@ class IStorage:
""" """
raise NotImplementedError raise NotImplementedError
def iteritemsOlderThan(self, secondsOld): def iter_older_than(self, seconds_old):
""" """
Return the an iterator over (key, value) tuples for items older Return the an iterator over (key, value) tuples for items older
than the given secondsOld. than the given secondsOld.
@ -57,7 +57,7 @@ class ForgetfulStorage(IStorage):
self.cull() self.cull()
def cull(self): def cull(self):
for _, _ in self.iteritemsOlderThan(self.ttl): for _, _ in self.iter_older_than(self.ttl):
self.data.popitem(last=False) self.data.popitem(last=False)
def get(self, key, default=None): def get(self, key, default=None):
@ -70,27 +70,23 @@ class ForgetfulStorage(IStorage):
self.cull() self.cull()
return self.data[key][1] return self.data[key][1]
def __iter__(self):
self.cull()
return iter(self.data)
def __repr__(self): def __repr__(self):
self.cull() self.cull()
return repr(self.data) return repr(self.data)
def iteritemsOlderThan(self, secondsOld): def iter_older_than(self, seconds_old):
minBirthday = time.monotonic() - secondsOld min_birthday = time.monotonic() - seconds_old
zipped = self._tripleIterable() zipped = self._triple_iter()
matches = takewhile(lambda r: minBirthday >= r[1], zipped) matches = takewhile(lambda r: min_birthday >= r[1], zipped)
return list(map(operator.itemgetter(0, 2), matches)) return list(map(operator.itemgetter(0, 2), matches))
def _tripleIterable(self): def _triple_iter(self):
ikeys = self.data.keys() ikeys = self.data.keys()
ibirthday = map(operator.itemgetter(0), self.data.values()) ibirthday = map(operator.itemgetter(0), self.data.values())
ivalues = map(operator.itemgetter(1), self.data.values()) ivalues = map(operator.itemgetter(1), self.data.values())
return zip(ikeys, ibirthday, ivalues) return zip(ikeys, ibirthday, ivalues)
def items(self): def __iter__(self):
self.cull() self.cull()
ikeys = self.data.keys() ikeys = self.data.keys()
ivalues = map(operator.itemgetter(1), self.data.values()) ivalues = map(operator.itemgetter(1), self.data.values())

View File

@ -8,31 +8,31 @@ from kademlia.tests.utils import mknode
class NodeTest(unittest.TestCase): class NodeTest(unittest.TestCase):
def test_longID(self): def test_long_id(self):
rid = hashlib.sha1(str(random.getrandbits(255)).encode()).digest() rid = hashlib.sha1(str(random.getrandbits(255)).encode()).digest()
n = Node(rid) node = Node(rid)
self.assertEqual(n.long_id, int(rid.hex(), 16)) self.assertEqual(node.long_id, int(rid.hex(), 16))
def test_distanceCalculation(self): def test_distance_calculation(self):
ridone = hashlib.sha1(str(random.getrandbits(255)).encode()) ridone = hashlib.sha1(str(random.getrandbits(255)).encode())
ridtwo = hashlib.sha1(str(random.getrandbits(255)).encode()) ridtwo = hashlib.sha1(str(random.getrandbits(255)).encode())
shouldbe = int(ridone.hexdigest(), 16) ^ int(ridtwo.hexdigest(), 16) shouldbe = int(ridone.hexdigest(), 16) ^ int(ridtwo.hexdigest(), 16)
none = Node(ridone.digest()) none = Node(ridone.digest())
ntwo = Node(ridtwo.digest()) ntwo = Node(ridtwo.digest())
self.assertEqual(none.distanceTo(ntwo), shouldbe) self.assertEqual(none.distance_to(ntwo), shouldbe)
class NodeHeapTest(unittest.TestCase): class NodeHeapTest(unittest.TestCase):
def test_maxSize(self): def test_max_size(self):
n = NodeHeap(mknode(intid=0), 3) node = NodeHeap(mknode(intid=0), 3)
self.assertEqual(0, len(n)) self.assertEqual(0, len(node))
for d in range(10): for digit in range(10):
n.push(mknode(intid=d)) node.push(mknode(intid=digit))
self.assertEqual(3, len(n)) self.assertEqual(3, len(node))
self.assertEqual(3, len(list(n))) self.assertEqual(3, len(list(node)))
def test_iteration(self): def test_iteration(self):
heap = NodeHeap(mknode(intid=0), 5) heap = NodeHeap(mknode(intid=0), 5)

View File

@ -7,53 +7,53 @@ from kademlia.tests.utils import mknode, FakeProtocol
class KBucketTest(unittest.TestCase): class KBucketTest(unittest.TestCase):
def test_split(self): def test_split(self):
bucket = KBucket(0, 10, 5) bucket = KBucket(0, 10, 5)
bucket.addNode(mknode(intid=5)) bucket.add_node(mknode(intid=5))
bucket.addNode(mknode(intid=6)) bucket.add_node(mknode(intid=6))
one, two = bucket.split() one, two = bucket.split()
self.assertEqual(len(one), 1) self.assertEqual(len(one), 1)
self.assertEqual(one.range, (0, 5)) self.assertEqual(one.range, (0, 5))
self.assertEqual(len(two), 1) self.assertEqual(len(two), 1)
self.assertEqual(two.range, (6, 10)) self.assertEqual(two.range, (6, 10))
def test_addNode(self): def test_add_node(self):
# when full, return false # when full, return false
bucket = KBucket(0, 10, 2) bucket = KBucket(0, 10, 2)
self.assertTrue(bucket.addNode(mknode())) self.assertTrue(bucket.add_node(mknode()))
self.assertTrue(bucket.addNode(mknode())) self.assertTrue(bucket.add_node(mknode()))
self.assertFalse(bucket.addNode(mknode())) self.assertFalse(bucket.add_node(mknode()))
self.assertEqual(len(bucket), 2) self.assertEqual(len(bucket), 2)
# make sure when a node is double added it's put at the end # make sure when a node is double added it's put at the end
bucket = KBucket(0, 10, 3) bucket = KBucket(0, 10, 3)
nodes = [mknode(), mknode(), mknode()] nodes = [mknode(), mknode(), mknode()]
for node in nodes: for node in nodes:
bucket.addNode(node) bucket.add_node(node)
for index, node in enumerate(bucket.getNodes()): for index, node in enumerate(bucket.get_nodes()):
self.assertEqual(node, nodes[index]) self.assertEqual(node, nodes[index])
def test_inRange(self): def test_in_range(self):
bucket = KBucket(0, 10, 10) bucket = KBucket(0, 10, 10)
self.assertTrue(bucket.hasInRange(mknode(intid=5))) self.assertTrue(bucket.has_in_range(mknode(intid=5)))
self.assertFalse(bucket.hasInRange(mknode(intid=11))) self.assertFalse(bucket.has_in_range(mknode(intid=11)))
self.assertTrue(bucket.hasInRange(mknode(intid=10))) self.assertTrue(bucket.has_in_range(mknode(intid=10)))
self.assertTrue(bucket.hasInRange(mknode(intid=0))) self.assertTrue(bucket.has_in_range(mknode(intid=0)))
class RoutingTableTest(unittest.TestCase): class RoutingTableTest(unittest.TestCase):
def setUp(self): def setUp(self):
self.id = mknode().id self.id = mknode().id # pylint: disable=invalid-name
self.protocol = FakeProtocol(self.id) self.protocol = FakeProtocol(self.id)
self.router = self.protocol.router self.router = self.protocol.router
def test_addContact(self): def test_add_contact(self):
self.router.addContact(mknode()) self.router.add_contact(mknode())
self.assertTrue(len(self.router.buckets), 1) self.assertTrue(len(self.router.buckets), 1)
self.assertTrue(len(self.router.buckets[0].nodes), 1) self.assertTrue(len(self.router.buckets[0].nodes), 1)
class TableTraverserTest(unittest.TestCase): class TableTraverserTest(unittest.TestCase):
def setUp(self): def setUp(self):
self.id = mknode().id self.id = mknode().id # pylint: disable=invalid-name
self.protocol = FakeProtocol(self.id) self.protocol = FakeProtocol(self.id)
self.router = self.protocol.router self.router = self.protocol.router
@ -70,8 +70,8 @@ class TableTraverserTest(unittest.TestCase):
buckets = [] buckets = []
for i in range(5): for i in range(5):
bucket = KBucket(2 * i, 2 * i + 1, 2) bucket = KBucket(2 * i, 2 * i + 1, 2)
bucket.addNode(nodes[2 * i]) bucket.add_node(nodes[2 * i])
bucket.addNode(nodes[2 * i + 1]) bucket.add_node(nodes[2 * i + 1])
buckets.append(bucket) buckets.append(bucket)
# replace router's bucket with our test buckets # replace router's bucket with our test buckets

View File

@ -0,0 +1,29 @@
import unittest
from kademlia.storage import ForgetfulStorage
class ForgetfulStorageTest(unittest.TestCase):
def test_storing(self):
storage = ForgetfulStorage(10)
storage['one'] = 'two'
self.assertEqual(storage['one'], 'two')
def test_forgetting(self):
storage = ForgetfulStorage(0)
storage['one'] = 'two'
self.assertEqual(storage.get('one'), None)
def test_iter(self):
storage = ForgetfulStorage(10)
storage['one'] = 'two'
for key, value in storage:
self.assertEqual(key, 'one')
self.assertEqual(value, 'two')
def test_iter_old(self):
storage = ForgetfulStorage(10)
storage['one'] = 'two'
for key, value in storage.iter_older_than(0):
self.assertEqual(key, 'one')
self.assertEqual(value, 'two')

View File

@ -1,36 +1,36 @@
import hashlib import hashlib
import unittest import unittest
from kademlia.utils import digest, sharedPrefix, OrderedSet from kademlia.utils import digest, shared_prefix, OrderedSet
class UtilsTest(unittest.TestCase): class UtilsTest(unittest.TestCase):
def test_digest(self): def test_digest(self):
d = hashlib.sha1(b'1').digest() dig = hashlib.sha1(b'1').digest()
self.assertEqual(d, digest(1)) self.assertEqual(dig, digest(1))
d = hashlib.sha1(b'another').digest() dig = hashlib.sha1(b'another').digest()
self.assertEqual(d, digest('another')) self.assertEqual(dig, digest('another'))
def test_sharedPrefix(self): def test_shared_prefix(self):
args = ['prefix', 'prefixasdf', 'prefix', 'prefixxxx'] args = ['prefix', 'prefixasdf', 'prefix', 'prefixxxx']
self.assertEqual(sharedPrefix(args), 'prefix') self.assertEqual(shared_prefix(args), 'prefix')
args = ['p', 'prefixasdf', 'prefix', 'prefixxxx'] args = ['p', 'prefixasdf', 'prefix', 'prefixxxx']
self.assertEqual(sharedPrefix(args), 'p') self.assertEqual(shared_prefix(args), 'p')
args = ['one', 'two'] args = ['one', 'two']
self.assertEqual(sharedPrefix(args), '') self.assertEqual(shared_prefix(args), '')
args = ['hi'] args = ['hi']
self.assertEqual(sharedPrefix(args), 'hi') self.assertEqual(shared_prefix(args), 'hi')
class OrderedSetTest(unittest.TestCase): class OrderedSetTest(unittest.TestCase):
def test_order(self): def test_order(self):
o = OrderedSet() oset = OrderedSet()
o.push('1') oset.push('1')
o.push('1') oset.push('1')
o.push('2') oset.push('2')
o.push('1') oset.push('1')
self.assertEqual(o, ['2', '1']) self.assertEqual(oset, ['2', '1'])

View File

@ -9,7 +9,7 @@ from kademlia.node import Node
from kademlia.routing import RoutingTable from kademlia.routing import RoutingTable
def mknode(node_id=None, ip=None, port=None, intid=None): def mknode(node_id=None, ip_addy=None, port=None, intid=None):
""" """
Make a node. Created a random id if not specified. Make a node. Created a random id if not specified.
""" """
@ -18,11 +18,11 @@ def mknode(node_id=None, ip=None, port=None, intid=None):
if not node_id: if not node_id:
randbits = str(random.getrandbits(255)) randbits = str(random.getrandbits(255))
node_id = hashlib.sha1(randbits.encode()).digest() node_id = hashlib.sha1(randbits.encode()).digest()
return Node(node_id, ip, port) return Node(node_id, ip_addy, port)
class FakeProtocol: class FakeProtocol: # pylint: disable=too-few-public-methods
def __init__(self, sourceID, ksize=20): def __init__(self, source_id, ksize=20):
self.router = RoutingTable(self, ksize, Node(sourceID)) self.router = RoutingTable(self, ksize, Node(source_id))
self.storage = {} self.storage = {}
self.sourceID = sourceID self.source_id = source_id

View File

@ -6,16 +6,16 @@ import operator
import asyncio import asyncio
async def gather_dict(d): async def gather_dict(dic):
cors = list(d.values()) cors = list(dic.values())
results = await asyncio.gather(*cors) results = await asyncio.gather(*cors)
return dict(zip(d.keys(), results)) return dict(zip(dic.keys(), results))
def digest(s): def digest(string):
if not isinstance(s, bytes): if not isinstance(string, bytes):
s = str(s).encode('utf8') string = str(string).encode('utf8')
return hashlib.sha1(s).digest() return hashlib.sha1(string).digest()
class OrderedSet(list): class OrderedSet(list):
@ -34,7 +34,7 @@ class OrderedSet(list):
self.append(thing) self.append(thing)
def sharedPrefix(args): def shared_prefix(args):
""" """
Find the shared prefix between the strings. Find the shared prefix between the strings.
@ -52,6 +52,6 @@ def sharedPrefix(args):
return args[0][:i] return args[0][:i]
def bytesToBitString(bites): def bytes_to_bit_string(bites):
bits = [bin(bite)[2:].rjust(8, '0') for bite in bites] bits = [bin(bite)[2:].rjust(8, '0') for bite in bites]
return "".join(bits) return "".join(bits)