started writing tests

This commit is contained in:
Brian Muller 2014-01-03 17:50:49 -05:00
parent 75a97fb8f1
commit f7e8bd9442
14 changed files with 275 additions and 71 deletions

1
.gitignore vendored
View File

@ -1,3 +1,4 @@
_trial_temp
apidoc apidoc
*.pyc *.pyc
build build

View File

@ -4,8 +4,11 @@ docs:
$(PYDOCTOR) --make-html --html-output apidoc --add-package kademlia --project-name=kademlia --project-url=http://github.com/bmuller/kademlia --html-use-sorttable --html-use-splitlinks --html-shorten-lists $(PYDOCTOR) --make-html --html-output apidoc --add-package kademlia --project-name=kademlia --project-url=http://github.com/bmuller/kademlia --html-use-sorttable --html-use-splitlinks --html-shorten-lists
lint: lint:
pep8 --ignore=E303,E251,E201,E202 ./rpcudp --max-line-length=140 pep8 --ignore=E303,E251,E201,E202 ./kademlia --max-line-length=140
find ./kademlia -name '*.py' | xargs pyflakes find ./kademlia -name '*.py' | xargs pyflakes
install: install:
python setup.py install python setup.py install
test:
trial kademlia

22
example/example.py Normal file
View File

@ -0,0 +1,22 @@
from twisted.internet import reactor
from twisted.python import log
from kademlia.network import Server
import sys
# Example client: create a Server on port 5678, bootstrap it against a node
# expected at 127.0.0.1:1234, then store one key/value pair on the network.

# Send all Twisted log output to stdout.
log.startLogging(sys.stdout)
# Callback that prints a looked-up value and stops the reactor. It is only
# referenced by the commented-out line in get() below, so it is unused for now.
def quit(result):
print "Key result:", result
reactor.stop()
# Callback fired once server.set() has finished; it currently just stops the
# reactor. The value lookup that would chain into quit() is disabled.
def get(result, server):
reactor.stop()
#return server.get("a key").addCallback(quit)
# Callback fired with the result of bootstrap(): log it, then store a value
# and continue with get() once the store completes.
def done(found, server):
log.msg("Found nodes: %s" % found)
return server.set("a key", "a value").addCallback(get, server)
two = Server(5678)
# The extra positional argument is handed to done() as `server`.
two.bootstrap([('127.0.0.1', 1234)]).addCallback(done, two)
# Blocks until one of the callbacks above calls reactor.stop().
reactor.run()

View File

@ -4,13 +4,5 @@ from kademlia.network import Server
import sys import sys
log.startLogging(sys.stdout) log.startLogging(sys.stdout)
one = Server(1234) one = Server(1234)
def done(found):
print "Found nodes: ", found
reactor.stop()
two = Server(5678)
two.bootstrap([('127.0.0.1', 1234)]).addCallback(done)
reactor.run() reactor.run()

View File

@ -1,2 +1,7 @@
"""
Kademlia is a Python implementation of the Kademlia protocol for U{Twisted <http://twistedmatrix.com/trac/>}.
@author: Brian Muller U{bamuller@gmail.com}
"""
version_info = (0, 0) version_info = (0, 0)
version = '.'.join(map(str, version_info)) version = '.'.join(map(str, version_info))

63
kademlia/log.py Normal file
View File

@ -0,0 +1,63 @@
import sys
from twisted.python import log
# Log levels. A smaller number is treated as more important: FileLogObserver
# emits an event only when its configured level is >= the event's level, so
# its default level of WARNING passes WARNING, ERROR and CRITICAL and drops
# DEBUG and INFO.
# NOTE(review): INFO (5) ranks as noisier than DEBUG (4), the reverse of the
# stdlib logging module's ordering -- confirm this is intended.
INFO = 5
DEBUG = 4
WARNING = 3
ERROR = 2
CRITICAL = 1
class FileLogObserver(log.FileLogObserver):
    """
    A C{twisted.python.log.FileLogObserver} that filters events by log level.
    """

    def __init__(self, f=None, level=WARNING, default=DEBUG):
        """
        @param f: File-like object to write to; C{sys.stdout} is used when
        this is not given (or is otherwise falsy).
        @param level: Events whose C{loglevel} is numerically greater than
        this value are dropped.
        @param default: The level assumed for events that carry no
        C{loglevel} key.
        """
        log.FileLogObserver.__init__(self, f or sys.stdout)
        self.level = level
        self.default = default

    def emit(self, eventDict):
        """
        Pass the event on to the parent observer when it is an error, carries
        a failure, or is at a level this observer lets through.
        """
        eventLevel = eventDict.get('loglevel', self.default)
        alwaysEmit = eventDict['isError'] or 'failure' in eventDict
        if alwaysEmit or self.level >= eventLevel:
            log.FileLogObserver.emit(self, eventDict)
class Logger:
    """
    Wrapper around C{log.msg} that attaches a fixed set of keyword arguments
    and a C{loglevel} to every message.
    """

    def __init__(self, **kwargs):
        """
        @param kwargs: Keyword arguments merged into every message sent
        through this logger.
        """
        self.kwargs = kwargs

    def msg(self, message, **kw):
        """
        Send C{message} to C{log.msg}.

        The keyword arguments given at construction time are merged over
        C{kw} (so they win on conflict), and a non-string C{system} value is
        replaced by the name of its class.
        """
        kw.update(self.kwargs)
        # A missing 'system' falls back to '' here, which is a str and so
        # leaves kw untouched.
        if not isinstance(kw.get('system', ''), str):
            kw['system'] = kw['system'].__class__.__name__
        log.msg(message, **kw)

    def _tagged(self, level, template, message, kw):
        """
        Stamp C{kw} with C{level} and send C{message} formatted into C{template}.
        """
        kw['loglevel'] = level
        self.msg(template % message, **kw)

    def info(self, message, **kw):
        """Log C{message} at level C{INFO}."""
        self._tagged(INFO, "[INFO] %s", message, kw)

    def debug(self, message, **kw):
        """Log C{message} at level C{DEBUG}."""
        self._tagged(DEBUG, "[DEBUG] %s", message, kw)

    def warning(self, message, **kw):
        """Log C{message} at level C{WARNING}."""
        self._tagged(WARNING, "[WARNING] %s", message, kw)

    def error(self, message, **kw):
        """Log C{message} at level C{ERROR}."""
        self._tagged(ERROR, "[ERROR] %s", message, kw)

    def critical(self, message, **kw):
        """Log C{message} at level C{CRITICAL}."""
        self._tagged(CRITICAL, "[CRITICAL] %s", message, kw)
# Create the shared logger only if the name is not already bound, so an
# existing instance is kept if this code runs again in the same namespace
# (presumably to survive a module reload).
try:
    theLogger
except NameError:
    theLogger = Logger()

# Module-level shortcuts bound to the shared logger.
msg, info, debug = theLogger.msg, theLogger.info, theLogger.debug
warning, error, critical = theLogger.warning, theLogger.error, theLogger.critical

View File

@ -1,48 +1,78 @@
import hashlib """
Package for interacting on the network at a high level.
"""
import random import random
from twisted.internet.task import LoopingCall
from twisted.internet import defer from twisted.internet import defer
from twisted.python import log
from kademlia.log import Logger
from kademlia.protocol import KademliaProtocol from kademlia.protocol import KademliaProtocol
from kademlia.utils import deferredDict from kademlia.utils import deferredDict, digest
from kademlia.storage import ForgetfulStorage from kademlia.storage import ForgetfulStorage
from kademlia.node import Node, NodeHeap from kademlia.node import Node, NodeHeap
class SpiderCrawl(object): class SpiderCrawl(object):
# call find_node to current ALPHA nearest not already queried, """
# ...adding results to current nearest Crawl the network and look for given 160-bit keys.
# current nearest list needs to keep track of who has been queried already """
# sort by nearest, keep KSIZE
# if list is same as last time, next call should be to everyone not
# yet queried
# repeat, unless nearest list has all been queried, then ur done
def __init__(self, protocol, node, peers, ksize, alpha): def __init__(self, protocol, node, peers, ksize, alpha):
"""
Create a new C{SpiderCrawl}er.
@param protocol: a C{KademliaProtocol} instance.
@param node: A C{Node} representing the key we're looking for
@param peers: A list of C{Node}s that provide the entry point for the network
@param ksize: The value for k based on the paper
@param alpha: The value for alpha based on the paper
"""
self.protocol = protocol self.protocol = protocol
self.ksize = ksize self.ksize = ksize
self.alpha = alpha self.alpha = alpha
self.nearest = NodeHeap(self.ksize) self.nearest = NodeHeap(self.ksize)
self.node = node self.node = node
self.lastIDsCrawled = [] self.lastIDsCrawled = []
self.log = Logger(system=self)
self.log.info("creating spider with peers: %s" % peers)
for peer in peers: for peer in peers:
self.nearest.push(self.node.distanceTo(peer), peer) self.nearest.push(self.node.distanceTo(peer), peer)
def findNodes(self): def findNodes(self):
return self.find(self.protocol.callFindNode) """
Find the closest nodes.
"""
return self._find(self.protocol.callFindNode)
def findValue(self): def findValue(self):
"""
Find either the closest nodes or the value requested.
"""
def handle(result): def handle(result):
if isinstance(result, dict): if isinstance(result, dict):
return result['value'] return result['value']
return None return None
d = self.find(self.protocol.callFindValue) d = self._find(self.protocol.callFindValue)
return d.addCallback(handle) return d.addCallback(handle)
def find(self, rpcmethod): def _find(self, rpcmethod):
"""
Get either a value or list of nodes.
@param rpcmethod: The protocol's C{callfindValue} or C{callFindNode}.
The process:
1. calls find_* to current ALPHA nearest not already queried nodes,
adding results to current nearest list of k nodes.
2. current nearest list needs to keep track of who has been queried already
sort by nearest, keep KSIZE
3. if list is same as last time, next call should be to everyone not
yet queried
4. repeat, unless nearest list has all been queried, then ur done
"""
count = self.alpha count = self.alpha
if self.nearest.getIDs() == self.lastIDsCrawled: if self.nearest.getIDs() == self.lastIDsCrawled:
self.log.info("last iteration same as current - checking all in list now")
count = len(self.nearest) count = len(self.nearest)
self.lastIDsCrawled = self.nearest.getIDs() self.lastIDsCrawled = self.nearest.getIDs()
@ -50,9 +80,13 @@ class SpiderCrawl(object):
for peer in self.nearest.getUncontacted()[:count]: for peer in self.nearest.getUncontacted()[:count]:
ds[peer.id] = rpcmethod(peer, self.node) ds[peer.id] = rpcmethod(peer, self.node)
self.nearest.markContacted(peer) self.nearest.markContacted(peer)
return deferredDict(ds).addCallback(self.nodesFound) return deferredDict(ds).addCallback(self._nodesFound)
def nodesFound(self, responses): def _nodesFound(self, responses):
"""
Handle the result of an iteration in C{_find}.
"""
print "got some responses: ", responses
toremove = [] toremove = []
for peerid, response in responses.items(): for peerid, response in responses.items():
# response will be a tuple of (<response received>, <value>) # response will be a tuple of (<response received>, <value>)
@ -61,6 +95,7 @@ class SpiderCrawl(object):
if not response[0]: if not response[0]:
toremove.push(peerid) toremove.push(peerid)
elif isinstance(response[1], dict): elif isinstance(response[1], dict):
self.log.debug("found value for %i" % self.node.long_id)
return response[1] return response[1]
for nodeple in (response[1] or []): for nodeple in (response[1] or []):
peer = Node(*nodeple) peer = Node(*nodeple)
@ -72,17 +107,47 @@ class SpiderCrawl(object):
return self.findNodes() return self.findNodes()
class Server: class Server(object):
"""
High level view of a node instance. This is the object that should be created
to start listening as an active node on the network.
"""
def __init__(self, port, ksize=20, alpha=3): def __init__(self, port, ksize=20, alpha=3):
"""
Create a server instance. This will start listening on the given port.
@param port: UDP port to listen on
@param k: The k parameter from the paper
@param alpha: The alpha parameter from the paper
"""
self.ksize = ksize self.ksize = ksize
self.alpha = alpha self.alpha = alpha
# 160 bit random id self.log = Logger(system=self)
rid = hashlib.sha1(str(random.getrandbits(255))).digest()
storage = ForgetfulStorage() storage = ForgetfulStorage()
self.node = Node('127.0.0.1', port, rid) self.node = Node('127.0.0.1', port, digest(random.getrandbits(255)))
self.protocol = KademliaProtocol(self.node, storage, ksize, alpha) self.protocol = KademliaProtocol(self.node, storage, ksize)
self.refreshLoop = LoopingCall(self.refreshTable).start(3600)
def refreshTable(self):
    """
    Refresh buckets that haven't had any lookups in the last hour
    (per section 2.3 of the paper).

    One C{SpiderCrawl} is started for every id returned by the protocol's
    C{getRefreshIDs}, seeded with the C{alpha} nearest known neighbors.

    @return: The C{Deferred} produced by C{defer.gatherResults} over the
    C{findNodes} call of every crawl.
    """
    ds = []
    for refreshID in self.protocol.getRefreshIDs():
        # NOTE(review): getRefreshIDs builds its ids with random.randint, while
        # Node calls id.encode('hex') on the id it is given -- confirm that an
        # integer id is acceptable here.
        node = Node(None, None, refreshID)
        nearest = self.protocol.router.findNeighbors(node, self.alpha)
        # SpiderCrawl.__init__ requires ksize and alpha; leaving them out
        # raised a TypeError on every refresh.
        spider = SpiderCrawl(self.protocol, node, nearest, self.ksize, self.alpha)
        ds.append(spider.findNodes())
    return defer.gatherResults(ds)
def bootstrap(self, addrs): def bootstrap(self, addrs):
"""
Bootstrap the server by connecting to other known nodes in the network.
@param addrs: A C{list} of (ip, port) C{tuple}s
"""
def initTable(results): def initTable(results):
nodes = [] nodes = []
for addr, result in results.items(): for addr, result in results.items():
@ -97,17 +162,29 @@ class Server:
return deferredDict(ds).addCallback(initTable) return deferredDict(ds).addCallback(initTable)
def get(self, key): def get(self, key):
node = Node(None, None, key) """
nearest = self.router.findNeighbors(node) Get a key if the network has it.
@return: C{None} if not found, the value otherwise.
"""
node = Node(None, None, digest(key))
nearest = self.protocol.router.findNeighbors(node)
spider = SpiderCrawl(self.protocol, node, nearest, self.ksize, self.alpha) spider = SpiderCrawl(self.protocol, node, nearest, self.ksize, self.alpha)
return spider.findValue() return spider.findValue()
def set(self, key, value): def set(self, key, value):
# TODO - if no one responds, freak out """
Set the given key to the given value in the network.
TODO - if no one responds, freak out
"""
self.log.debug("setting '%s' = '%s' on network" % (key, value))
dkey = digest(key)
def store(nodes): def store(nodes):
ds = [self.protocol.callStore(node) for node in nodes] self.log.info("setting '%s' on %s" % (key, map(str, nodes)))
ds = [self.protocol.callStore(node, dkey, value) for node in nodes]
return defer.gatherResults(ds) return defer.gatherResults(ds)
node = Node(None, None, key) node = Node(None, None, dkey)
nearest = self.router.findNeighbors(node) nearest = self.protocol.router.findNeighbors(node)
spider = SpiderCrawl(self.protocol, nearest, self.ksize, self.alpha) spider = SpiderCrawl(self.protocol, node, nearest, self.ksize, self.alpha)
return spider.findNodes(node).addCallback(store) return spider.findNodes().addCallback(store)

View File

@ -1,6 +1,7 @@
from operator import itemgetter from operator import itemgetter
import heapq import heapq
class Node: class Node:
def __init__(self, ip, port, id): def __init__(self, ip, port, id):
self.ip = ip self.ip = ip
@ -8,6 +9,9 @@ class Node:
self.id = id self.id = id
self.long_id = long(id.encode('hex'), 16) self.long_id = long(id.encode('hex'), 16)
def sameHomeAs(self, node):
    """
    Tell whether C{node} has the same ip and port as this node.
    """
    if self.ip != node.ip:
        return False
    return self.port == node.port
def distanceTo(self, node): def distanceTo(self, node):
return self.long_id ^ node.long_id return self.long_id ^ node.long_id
@ -20,6 +24,9 @@ class Node:
def __repr__(self): def __repr__(self):
return repr([self.ip, self.port, self.long_id]) return repr([self.ip, self.port, self.long_id])
def __str__(self):
    """
    Render this node as C{ip:port}.
    """
    address = (self.ip, str(self.port))
    return "%s:%s" % address
class NodeHeap(object): class NodeHeap(object):
def __init__(self, maxsize): def __init__(self, maxsize):

View File

@ -1,16 +1,27 @@
from twisted.python import log import random
from rpcudp.protocol import RPCProtocol from rpcudp.protocol import RPCProtocol
from kademlia.node import Node from kademlia.node import Node
from kademlia.routing import RoutingTable from kademlia.routing import RoutingTable
from kademlia.log import Logger
class KademliaProtocol(RPCProtocol): class KademliaProtocol(RPCProtocol):
def __init__(self, node, storage, ksize, alpha): def __init__(self, node, storage, ksize):
RPCProtocol.__init__(self, node.port) RPCProtocol.__init__(self, node.port)
self.router = RoutingTable(self, ksize, alpha) self.router = RoutingTable(self, ksize)
self.storage = storage self.storage = storage
self.sourceID = node.id self.sourceID = node.id
self.log = Logger(system=self)
def getRefreshIDs(self):
    """
    Get ids to search for to keep old buckets up to date.

    @return: A C{list} with one random integer, drawn from the bucket's
    range, for every bucket returned by the router's C{getLonelyBuckets}.
    """
    lonely = self.router.getLonelyBuckets()
    return [random.randint(*bucket.range) for bucket in lonely]
def rpc_ping(self, sender, nodeid): def rpc_ping(self, sender, nodeid):
source = Node(sender[0], sender[1], nodeid) source = Node(sender[0], sender[1], nodeid)
@ -20,13 +31,15 @@ class KademliaProtocol(RPCProtocol):
def rpc_store(self, sender, nodeid, key, value): def rpc_store(self, sender, nodeid, key, value):
source = Node(sender[0], sender[1], nodeid) source = Node(sender[0], sender[1], nodeid)
self.router.addContact(source) self.router.addContact(source)
self.log.debug("got a store request from %s, storing value" % str(sender))
self.storage[key] = value self.storage[key] = value
def rpc_find_node(self, sender, nodeid, key): def rpc_find_node(self, sender, nodeid, key):
self.log.info("finding neighbors of %i in local table" % long(nodeid.encode('hex'), 16))
source = Node(sender[0], sender[1], nodeid) source = Node(sender[0], sender[1], nodeid)
self.router.addContact(source) self.router.addContact(source)
node = Node(None, None, key) node = Node(None, None, key)
return map(tuple, self.router.findNeighbors(node)) return map(tuple, self.router.findNeighbors(node, exclude=source))
def rpc_find_value(self, sender, nodeid, key): def rpc_find_value(self, sender, nodeid, key):
source = Node(sender[0], sender[1], nodeid) source = Node(sender[0], sender[1], nodeid)
@ -62,7 +75,9 @@ class KademliaProtocol(RPCProtocol):
we get no response, make sure it's removed from the routing table. we get no response, make sure it's removed from the routing table.
""" """
if result[0]: if result[0]:
self.log.info("got response from %s, adding to router" % node)
self.router.addContact(node) self.router.addContact(node)
else: else:
self.log.debug("no response from %s, removing from router" % node)
self.router.removeContact(node) self.router.removeContact(node)
return result return result

View File

@ -1,9 +1,8 @@
import heapq import heapq
import time import time
import operator
from collections import OrderedDict from collections import OrderedDict
from twisted.internet.task import LoopingCall
from twisted.internet import defer
class KBucket(object): class KBucket(object):
def __init__(self, rangeLower, rangeUpper, ksize): def __init__(self, rangeLower, rangeUpper, ksize):
@ -21,7 +20,7 @@ class KBucket(object):
def split(self): def split(self):
midpoint = self.range[1] - ((self.range[1] - self.range[0]) / 2) midpoint = self.range[1] - ((self.range[1] - self.range[0]) / 2)
one = KBucket(self.range[0], midpoint, self.ksize) one = KBucket(self.range[0], midpoint, self.ksize)
two = KBucket(midpoint+1, self.range[1], self.ksize) two = KBucket(midpoint + 1, self.range[1], self.ksize)
for node in self.nodes.values(): for node in self.nodes.values():
bucket = one if node.long_id <= midpoint else two bucket = one if node.long_id <= midpoint else two
bucket.nodes[node.id] = node bucket.nodes[node.id] = node
@ -32,7 +31,7 @@ class KBucket(object):
del self.nodes[node.id] del self.nodes[node.id]
def hasInRange(self, node): def hasInRange(self, node):
return rangeLower <= node.long_id <= rangeUpper return self.range[0] <= node.long_id <= self.range[1]
def addNode(self, node): def addNode(self, node):
""" """
@ -64,7 +63,7 @@ class TableTraverser(object):
table.buckets[index].touchLastUpdated() table.buckets[index].touchLastUpdated()
self.currentNodes = table.buckets[index].getNodes() self.currentNodes = table.buckets[index].getNodes()
self.leftBuckets = table.buckets[:index] self.leftBuckets = table.buckets[:index]
self.rightBuckets = table.buckets[(index+1):] self.rightBuckets = table.buckets[(index + 1):]
self.left = True self.left = True
def __iter__(self): def __iter__(self):
@ -91,29 +90,24 @@ class TableTraverser(object):
class RoutingTable(object): class RoutingTable(object):
def __init__(self, protocol, ksize, alpha): def __init__(self, protocol, ksize):
self.protocol = protocol self.protocol = protocol
self.ksize = ksize self.ksize = ksize
self.alpha = alpha self.buckets = [KBucket(0, 2 ** 160, ksize)]
self.buckets = [KBucket(0, 2**160, ksize)]
LoopingCall(self.refresh).start(3600)
def splitBucket(self, index): def splitBucket(self, index):
one, two = self.buckets[index].split() one, two = self.buckets[index].split()
self.buckets[index] = one self.buckets[index] = one
self.buckets.insert(index+1, two) self.buckets.insert(index + 1, two)
# todo split one/two if needed based on section 4.2 # todo split one/two if needed based on section 4.2
def refresh(self): def getLonelyBuckets(self):
ds = [] """
for bucket in self.buckets: Get all of the buckets that haven't been updated in over
if bucket.lastUpdated < (time.time() - 3600): an hour.
node = Node(None, None, random.randint(*bucket.range)) """
nearest = self.findNeighbors(node, self.alpha) return [b for b in self.buckets if b.lastUpdated < (time.time() - 3600)]
spider = NetworkSpider(self.protocol, node, nearest)
ds.append(spider.findNodes())
return defer.gatherResults(ds)
def removeContact(self, node): def removeContact(self, node):
index = self.getBucketFor(self, node) index = self.getBucketFor(self, node)
@ -141,12 +135,13 @@ class RoutingTable(object):
if node.long_id < bucket.range[1]: if node.long_id < bucket.range[1]:
return index return index
def findNeighbors(self, node, k=None): def findNeighbors(self, node, k=None, exclude=None):
k = k or self.ksize k = k or self.ksize
nodes = [] nodes = []
for neighbor in TableTraverser(self, node): for neighbor in TableTraverser(self, node):
if neighbor.id != node.id: if neighbor.id != node.id and (exclude is None or not neighbor.sameHomeAs(exclude)):
heapq.heappush(nodes, (node.distanceFrom(neighbor), neighbor)) heapq.heappush(nodes, (node.distanceTo(neighbor), neighbor))
if len(nodes) == k: if len(nodes) == k:
break break
return heapq.nsmallest(k, nodes)
return map(operator.itemgetter(1), heapq.nsmallest(k, nodes))

View File

@ -1,6 +1,7 @@
import time import time
from collections import OrderedDict from collections import OrderedDict
class ForgetfulStorage(object): class ForgetfulStorage(object):
def __init__(self, ttl=7200): def __init__(self, ttl=7200):
""" """
@ -24,6 +25,12 @@ class ForgetfulStorage(object):
for _ in xrange(pop): for _ in xrange(pop):
self.data.popitem(first=True) self.data.popitem(first=True)
def get(self, key, default=None):
    """
    Cull the store, then return the value kept for C{key}, or C{default}
    when the key is not present.
    """
    self.cull()
    # Entries are indexed with [1] to reach the value
    # (presumably they are (timestamp, value) pairs -- confirm).
    return self.data[key][1] if key in self.data else default
def __getitem__(self, key): def __getitem__(self, key):
self.cull() self.cull()
return self.data[key][1] return self.data[key][1]

View File

View File

@ -0,0 +1,11 @@
import random
import hashlib
from twisted.trial import unittest
from kademlia.node import Node, NodeHeap
class NodeTest(unittest.TestCase):
    """
    Tests for C{kademlia.node.Node}.
    """

    def test_longID(self):
        """
        C{long_id} must equal the raw id read as one hexadecimal integer.
        """
        seed = str(random.getrandbits(255))
        rid = hashlib.sha1(seed).digest()
        expected = long(rid.encode('hex'), 16)
        self.assertEqual(Node(None, None, rid).long_id, expected)

View File

@ -1,9 +1,15 @@
""" """
General catchall for functions that don't make sense as methods. General catchall for functions that don't make sense as methods.
""" """
import hashlib
from twisted.internet import defer from twisted.internet import defer
def digest(s):
    """
    Return the 20-byte SHA-1 digest of C{s}.

    @param s: The value to hash. Byte strings are hashed as they are, text
    is encoded as UTF-8 first, and any other object is hashed via its
    C{str()} form.
    @return: The raw (binary) digest.
    """
    # Leave byte strings and text alone; stringify everything else.
    if not isinstance(s, (bytes, type(u''))):
        s = str(s)
    # Text must be encoded explicitly: calling str() on non-ASCII unicode
    # raised UnicodeEncodeError, and sha1 only accepts bytes.
    if not isinstance(s, bytes):
        s = s.encode('utf-8')
    return hashlib.sha1(s).digest()
def deferredDict(d): def deferredDict(d):
""" """
Just like a C{defer.DeferredList} but instead accepts and returns a C{dict}. Just like a C{defer.DeferredList} but instead accepts and returns a C{dict}.