started writing tests

Brian Muller 2014-01-03 17:50:49 -05:00
parent 75a97fb8f1
commit f7e8bd9442
14 changed files with 275 additions and 71 deletions

.gitignore vendored
View File

@@ -1,3 +1,4 @@
_trial_temp
apidoc
*.pyc
build

View File

@@ -4,8 +4,11 @@ docs:
$(PYDOCTOR) --make-html --html-output apidoc --add-package kademlia --project-name=kademlia --project-url=http://github.com/bmuller/kademlia --html-use-sorttable --html-use-splitlinks --html-shorten-lists
lint:
pep8 --ignore=E303,E251,E201,E202 ./rpcudp --max-line-length=140
pep8 --ignore=E303,E251,E201,E202 ./kademlia --max-line-length=140
find ./kademlia -name '*.py' | xargs pyflakes
install:
python setup.py install
test:
trial kademlia

example/example.py Normal file
View File

@@ -0,0 +1,22 @@
from twisted.internet import reactor
from twisted.python import log
from kademlia.network import Server
import sys
log.startLogging(sys.stdout)
def quit(result):
print "Key result:", result
reactor.stop()
def get(result, server):
reactor.stop()
#return server.get("a key").addCallback(quit)
def done(found, server):
log.msg("Found nodes: %s" % found)
return server.set("a key", "a value").addCallback(get, server)
two = Server(5678)
two.bootstrap([('127.0.0.1', 1234)]).addCallback(done, two)
reactor.run()

View File

@@ -4,13 +4,5 @@ from kademlia.network import Server
import sys
log.startLogging(sys.stdout)
one = Server(1234)
def done(found):
print "Found nodes: ", found
reactor.stop()
two = Server(5678)
two.bootstrap([('127.0.0.1', 1234)]).addCallback(done)
reactor.run()

View File

@@ -1,2 +1,7 @@
"""
Kademlia is a Python implementation of the Kademlia protocol for U{Twisted <http://twistedmatrix.com/trac/>}.
@author: Brian Muller U{bamuller@gmail.com}
"""
version_info = (0, 0)
version = '.'.join(map(str, version_info))

kademlia/log.py Normal file
View File

@@ -0,0 +1,63 @@
import sys
from twisted.python import log
INFO = 5
DEBUG = 4
WARNING = 3
ERROR = 2
CRITICAL = 1
class FileLogObserver(log.FileLogObserver):
def __init__(self, f=None, level=WARNING, default=DEBUG):
log.FileLogObserver.__init__(self, f or sys.stdout)
self.level = level
self.default = default
def emit(self, eventDict):
ll = eventDict.get('loglevel', self.default)
if eventDict['isError'] or 'failure' in eventDict or self.level >= ll:
log.FileLogObserver.emit(self, eventDict)
class Logger:
def __init__(self, **kwargs):
self.kwargs = kwargs
def msg(self, message, **kw):
kw.update(self.kwargs)
if 'system' in kw and not isinstance(kw['system'], str):
kw['system'] = kw['system'].__class__.__name__
log.msg(message, **kw)
def info(self, message, **kw):
kw['loglevel'] = INFO
self.msg("[INFO] %s" % message, **kw)
def debug(self, message, **kw):
kw['loglevel'] = DEBUG
self.msg("[DEBUG] %s" % message, **kw)
def warning(self, message, **kw):
kw['loglevel'] = WARNING
self.msg("[WARNING] %s" % message, **kw)
def error(self, message, **kw):
kw['loglevel'] = ERROR
self.msg("[ERROR] %s" % message, **kw)
def critical(self, message, **kw):
kw['loglevel'] = CRITICAL
self.msg("[CRITICAL] %s" % message, **kw)
try:
theLogger
except NameError:
theLogger = Logger()
msg = theLogger.msg
info = theLogger.info
debug = theLogger.debug
warning = theLogger.warning
error = theLogger.error
critical = theLogger.critical
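
A minimal usage sketch for the new log module above, showing how it could be wired into a Twisted application. Only FileLogObserver, Logger, and the level constants come from the file in this commit; the observer setup and the "example" system name are illustrative assumptions.

# Illustrative wiring only (assumed usage, not code from this commit).
import sys
from twisted.python import log as twisted_log
from kademlia.log import FileLogObserver, Logger, WARNING

# Install the filtering observer; with level=WARNING (3), only WARNING,
# ERROR, and CRITICAL messages pass the "self.level >= loglevel" check.
observer = FileLogObserver(sys.stdout, level=WARNING)
twisted_log.startLoggingWithObserver(observer.emit, setStdout=False)

logger = Logger(system="example")
logger.warning("emitted: level 3 <= observer level 3")
logger.debug("suppressed: level 4 > observer level 3")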

View File

@@ -1,48 +1,78 @@
import hashlib
"""
Package for interacting on the network at a high level.
"""
import random
from twisted.internet.task import LoopingCall
from twisted.internet import defer
from twisted.python import log
from kademlia.log import Logger
from kademlia.protocol import KademliaProtocol
from kademlia.utils import deferredDict
from kademlia.utils import deferredDict, digest
from kademlia.storage import ForgetfulStorage
from kademlia.node import Node, NodeHeap
class SpiderCrawl(object):
# call find_node to current ALPHA nearest not already queried,
# ...adding results to current nearest
# current nearest list needs to keep track of who has been queried already
# sort by nearest, keep KSIZE
# if list is same as last time, next call should be to everyone not
# yet queried
# repeat, unless nearest list has all been queried, then ur done
"""
Crawl the network and look for given 160-bit keys.
"""
def __init__(self, protocol, node, peers, ksize, alpha):
"""
Create a new C{SpiderCrawl}er.
@param protocol: a C{KademliaProtocol} instance.
@param node: A C{Node} representing the key we're looking for
@param peers: A list of C{Node}s that provide the entry point for the network
@param ksize: The value for k based on the paper
@param alpha: The value for alpha based on the paper
"""
self.protocol = protocol
self.ksize = ksize
self.alpha = alpha
self.nearest = NodeHeap(self.ksize)
self.node = node
self.lastIDsCrawled = []
self.log = Logger(system=self)
self.log.info("creating spider with peers: %s" % peers)
for peer in peers:
self.nearest.push(self.node.distanceTo(peer), peer)
def findNodes(self):
return self.find(self.protocol.callFindNode)
"""
Find the closest nodes.
"""
return self._find(self.protocol.callFindNode)
def findValue(self):
"""
Find either the closest nodes or the value requested.
"""
def handle(result):
if isinstance(result, dict):
return result['value']
return None
d = self.find(self.protocol.callFindValue)
d = self._find(self.protocol.callFindValue)
return d.addCallback(handle)
def find(self, rpcmethod):
def _find(self, rpcmethod):
"""
Get either a value or list of nodes.
@param rpcmethod: The protocol's C{callfindValue} or C{callFindNode}.
The process:
1. calls find_* to current ALPHA nearest not already queried nodes,
adding results to current nearest list of k nodes.
2. current nearest list needs to keep track of who has been queried already
sort by nearest, keep KSIZE
3. if list is same as last time, next call should be to everyone not
yet queried
4. repeat, unless nearest list has all been queried, then ur done
"""
count = self.alpha
if self.nearest.getIDs() == self.lastIDsCrawled:
self.log.info("last iteration same as current - checking all in list now")
count = len(self.nearest)
self.lastIDsCrawled = self.nearest.getIDs()
@@ -50,9 +80,13 @@ class SpiderCrawl(object):
for peer in self.nearest.getUncontacted()[:count]:
ds[peer.id] = rpcmethod(peer, self.node)
self.nearest.markContacted(peer)
return deferredDict(ds).addCallback(self.nodesFound)
return deferredDict(ds).addCallback(self._nodesFound)
def nodesFound(self, responses):
def _nodesFound(self, responses):
"""
Handle the result of an iteration in C{_find}.
"""
print "got some responses: ", responses
toremove = []
for peerid, response in responses.items():
# response will be a tuple of (<response received>, <value>)
@@ -61,6 +95,7 @@ class SpiderCrawl(object):
if not response[0]:
toremove.push(peerid)
elif isinstance(response[1], dict):
self.log.debug("found value for %i" % self.node.long_id)
return response[1]
for nodeple in (response[1] or []):
peer = Node(*nodeple)
@@ -72,17 +107,47 @@ class SpiderCrawl(object):
return self.findNodes()
class Server:
class Server(object):
"""
High level view of a node instance. This is the object that should be created
to start listening as an active node on the network.
"""
def __init__(self, port, ksize=20, alpha=3):
"""
Create a server instance. This will start listening on the given port.
@param port: UDP port to listen on
@param k: The k parameter from the paper
@param alpha: The alpha parameter from the paper
"""
self.ksize = ksize
self.alpha = alpha
# 160 bit random id
rid = hashlib.sha1(str(random.getrandbits(255))).digest()
self.log = Logger(system=self)
storage = ForgetfulStorage()
self.node = Node('127.0.0.1', port, rid)
self.protocol = KademliaProtocol(self.node, storage, ksize, alpha)
self.node = Node('127.0.0.1', port, digest(random.getrandbits(255)))
self.protocol = KademliaProtocol(self.node, storage, ksize)
self.refreshLoop = LoopingCall(self.refreshTable).start(3600)
def refreshTable(self):
"""
Refresh buckets that haven't had any lookups in the last hour
(per section 2.3 of the paper).
"""
ds = []
for id in self.protocol.getRefreshIDs():
node = Node(None, None, id)
nearest = self.protocol.router.findNeighbors(node, self.alpha)
spider = SpiderCrawl(self.protocol, node, nearest)
ds.append(spider.findNodes())
return defer.gatherResults(ds)
def bootstrap(self, addrs):
"""
Bootstrap the server by connecting to other known nodes in the network.
@param addrs: A C{list} of (ip, port) C{tuple}s
"""
def initTable(results):
nodes = []
for addr, result in results.items():
@@ -97,17 +162,29 @@ class Server:
return deferredDict(ds).addCallback(initTable)
def get(self, key):
node = Node(None, None, key)
nearest = self.router.findNeighbors(node)
"""
Get a key if the network has it.
@return: C{None} if not found, the value otherwise.
"""
node = Node(None, None, digest(key))
nearest = self.protocol.router.findNeighbors(node)
spider = SpiderCrawl(self.protocol, node, nearest, self.ksize, self.alpha)
return spider.findValue()
def set(self, key, value):
# TODO - if no one responds, freak out
"""
Set the given key to the given value in the network.
TODO - if no one responds, freak out
"""
self.log.debug("setting '%s' = '%s' on network" % (key, value))
dkey = digest(key)
def store(nodes):
ds = [self.protocol.callStore(node) for node in nodes]
self.log.info("setting '%s' on %s" % (key, map(str, nodes)))
ds = [self.protocol.callStore(node, dkey, value) for node in nodes]
return defer.gatherResults(ds)
node = Node(None, None, key)
nearest = self.router.findNeighbors(node)
spider = SpiderCrawl(self.protocol, nearest, self.ksize, self.alpha)
return spider.findNodes(node).addCallback(store)
node = Node(None, None, dkey)
nearest = self.protocol.router.findNeighbors(node)
spider = SpiderCrawl(self.protocol, node, nearest, self.ksize, self.alpha)
return spider.findNodes().addCallback(store)
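
The _find docstring above lays out the lookup loop in four steps. Below is a self-contained, hedged sketch of that loop using bare integer ids and a synchronous query() callable in place of Node objects and the Twisted RPCs; none of these names exist in the package, they only illustrate the narrowing behaviour. In a unit test, query could simply look peers up in a fixed in-memory topology.

# Hedged illustration of the four-step lookup in _find (stand-in names, no Twisted).
import heapq

def iterative_find(target_id, seed_ids, query, k=20, alpha=3):
    nearest = set(seed_ids)          # candidate node ids, trimmed to k each round
    contacted = set()
    last_round = None
    while True:
        ordered = heapq.nsmallest(k, nearest, key=lambda n: n ^ target_id)
        # Step 3: if nothing changed since last round, widen to everyone unqueried.
        count = len(ordered) if ordered == last_round else alpha
        last_round = ordered
        pending = [n for n in ordered if n not in contacted][:count]
        if not pending:              # Step 4: every candidate queried -- done.
            return ordered
        for node_id in pending:      # Step 1: ask the alpha nearest unqueried nodes.
            contacted.add(node_id)
            nearest.update(query(node_id, target_id))   # Step 2: merge their answers.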

View File

@@ -1,6 +1,7 @@
from operator import itemgetter
import heapq
class Node:
def __init__(self, ip, port, id):
self.ip = ip
@@ -8,6 +9,9 @@ class Node:
self.id = id
self.long_id = long(id.encode('hex'), 16)
def sameHomeAs(self, node):
return self.ip == node.ip and self.port == node.port
def distanceTo(self, node):
return self.long_id ^ node.long_id
@@ -20,6 +24,9 @@ class Node:
def __repr__(self):
return repr([self.ip, self.port, self.long_id])
def __str__(self):
return "%s:%s" % (self.ip, str(self.port))
class NodeHeap(object):
def __init__(self, maxsize):
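
For reference, the metric behind distanceTo above is plain XOR over the 160-bit SHA-1 ids. A short illustration (Python 2 style, matching the code in this commit); the "node-a"/"node-b" inputs are arbitrary examples.

# Two ids that agree in their high-order bits have a small XOR distance.
import hashlib

id_a = hashlib.sha1("node-a").digest()            # 20-byte (160-bit) id
id_b = hashlib.sha1("node-b").digest()
long_a = long(id_a.encode('hex'), 16)             # same conversion as Node.long_id
long_b = long(id_b.encode('hex'), 16)
print "distance:", long_a ^ long_b                # what distanceTo(node) returns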

View File

@@ -1,16 +1,27 @@
from twisted.python import log
import random
from rpcudp.protocol import RPCProtocol
from kademlia.node import Node
from kademlia.routing import RoutingTable
from kademlia.log import Logger
class KademliaProtocol(RPCProtocol):
def __init__(self, node, storage, ksize, alpha):
def __init__(self, node, storage, ksize):
RPCProtocol.__init__(self, node.port)
self.router = RoutingTable(self, ksize, alpha)
self.router = RoutingTable(self, ksize)
self.storage = storage
self.sourceID = node.id
self.log = Logger(system=self)
def getRefreshIDs(self):
"""
Get ids to search for to keep old buckets up to date.
"""
ids = []
for bucket in self.router.getLonelyBuckets():
ids.append(random.randint(*bucket.range))
return ids
def rpc_ping(self, sender, nodeid):
source = Node(sender[0], sender[1], nodeid)
@@ -20,13 +31,15 @@ class KademliaProtocol(RPCProtocol):
def rpc_store(self, sender, nodeid, key, value):
source = Node(sender[0], sender[1], nodeid)
self.router.addContact(source)
self.log.debug("got a store request from %s, storing value" % str(sender))
self.storage[key] = value
def rpc_find_node(self, sender, nodeid, key):
self.log.info("finding neighbors of %i in local table" % long(nodeid.encode('hex'), 16))
source = Node(sender[0], sender[1], nodeid)
self.router.addContact(source)
node = Node(None, None, key)
return map(tuple, self.router.findNeighbors(node))
return map(tuple, self.router.findNeighbors(node, exclude=source))
def rpc_find_value(self, sender, nodeid, key):
source = Node(sender[0], sender[1], nodeid)
@@ -62,7 +75,9 @@ class KademliaProtocol(RPCProtocol):
we get no response, make sure it's removed from the routing table.
"""
if result[0]:
self.log.info("got response from %s, adding to router" % node)
self.router.addContact(node)
else:
self.log.debug("no response from %s, removing from router" % node)
self.router.removeContact(node)
return result
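
The new getRefreshIDs above simply picks one random id inside each stale bucket's range. A hedged sketch of that idea with literal ranges standing in for KBucket objects (the ranges below are made up for illustration):

# Stand-in ranges; the real code reads bucket.range from the routing table.
import random

lonely_bucket_ranges = [(0, 2 ** 80 - 1), (2 ** 80, 2 ** 160 - 1)]
refresh_ids = [random.randint(low, high) for low, high in lonely_bucket_ranges]
# A find_node lookup on each id lands inside its bucket and so refreshes it.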

View File

@@ -1,9 +1,8 @@
import heapq
import time
import operator
from collections import OrderedDict
from twisted.internet.task import LoopingCall
from twisted.internet import defer
class KBucket(object):
def __init__(self, rangeLower, rangeUpper, ksize):
@@ -32,7 +31,7 @@ class KBucket(object):
del self.nodes[node.id]
def hasInRange(self, node):
return rangeLower <= node.long_id <= rangeUpper
return self.range[0] <= node.long_id <= self.range[1]
def addNode(self, node):
"""
@@ -91,12 +90,10 @@ class TableTraverser(object):
class RoutingTable(object):
def __init__(self, protocol, ksize, alpha):
def __init__(self, protocol, ksize):
self.protocol = protocol
self.ksize = ksize
self.alpha = alpha
self.buckets = [KBucket(0, 2 ** 160, ksize)]
LoopingCall(self.refresh).start(3600)
def splitBucket(self, index):
one, two = self.buckets[index].split()
@@ -105,15 +102,12 @@ class RoutingTable(object):
# todo split one/two if needed based on section 4.2
def refresh(self):
ds = []
for bucket in self.buckets:
if bucket.lastUpdated < (time.time() - 3600):
node = Node(None, None, random.randint(*bucket.range))
nearest = self.findNeighbors(node, self.alpha)
spider = NetworkSpider(self.protocol, node, nearest)
ds.append(spider.findNodes())
return defer.gatherResults(ds)
def getLonelyBuckets(self):
"""
Get all of the buckets that haven't been updated in over
an hour.
"""
return [b for b in self.buckets if b.lastUpdated < (time.time() - 3600)]
def removeContact(self, node):
index = self.getBucketFor(self, node)
@@ -141,12 +135,13 @@ class RoutingTable(object):
if node.long_id < bucket.range[1]:
return index
def findNeighbors(self, node, k=None):
def findNeighbors(self, node, k=None, exclude=None):
k = k or self.ksize
nodes = []
for neighbor in TableTraverser(self, node):
if neighbor.id != node.id:
heapq.heappush(nodes, (node.distanceFrom(neighbor), neighbor))
if neighbor.id != node.id and (exclude is None or not neighbor.sameHomeAs(exclude)):
heapq.heappush(nodes, (node.distanceTo(neighbor), neighbor))
if len(nodes) == k:
break
return heapq.nsmallest(k, nodes)
return map(operator.itemgetter(1), heapq.nsmallest(k, nodes))
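
The new findNeighbors return path above is a heap keyed on XOR distance followed by nsmallest. A compact, hedged sketch of the same selection over bare integer ids (k_nearest is illustrative, not part of the commit):

# Pick the k ids closest to target by XOR distance (Python 2, matching the repo).
import heapq
import operator

def k_nearest(target, candidates, k=20):
    heap = []
    for candidate in candidates:
        if candidate != target:
            heapq.heappush(heap, (target ^ candidate, candidate))
    return map(operator.itemgetter(1), heapq.nsmallest(k, heap))

print k_nearest(0b1010, [0b1000, 0b0001, 0b1111, 0b1011], k=2)   # [11, 8]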

View File

@@ -1,6 +1,7 @@
import time
from collections import OrderedDict
class ForgetfulStorage(object):
def __init__(self, ttl=7200):
"""
@@ -24,6 +25,12 @@ class ForgetfulStorage(object):
for _ in xrange(pop):
self.data.popitem(first=True)
def get(self, key, default=None):
self.cull()
if key in self.data:
return self.data[key][1]
return default
def __getitem__(self, key):
self.cull()
return self.data[key][1]
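
A small, hedged usage sketch for ForgetfulStorage as it stands after this commit. The ttl value is arbitrary, and __setitem__ is assumed from its use in protocol.py's rpc_store rather than shown in this hunk.

# Entries older than ttl seconds are culled before each read.
from kademlia.storage import ForgetfulStorage

store = ForgetfulStorage(ttl=10)            # keep values for roughly ten seconds
store["some key"] = "some value"            # same call rpc_store makes
print store.get("some key")                 # "some value" while still fresh
print store.get("missing", "fallback")      # default returned for unknown keys
print store["some key"]                     # __getitem__ works too; raises KeyError after culling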

View File

View File

@@ -0,0 +1,11 @@
import random
import hashlib
from twisted.trial import unittest
from kademlia.node import Node, NodeHeap
class NodeTest(unittest.TestCase):
def test_longID(self):
rid = hashlib.sha1(str(random.getrandbits(255))).digest()
n = Node(None, None, rid)
self.assertEqual(n.long_id, long(rid.encode('hex'), 16))
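
A hedged example of another small trial test in the same style, exercising distanceTo from node.py; this test is not part of the commit, it only shows how the suite could grow.

# Illustrative only -- not included in this commit.
import hashlib
import random
from twisted.trial import unittest
from kademlia.node import Node

class DistanceTest(unittest.TestCase):
    def test_distanceIsXor(self):
        rid_a = hashlib.sha1(str(random.getrandbits(255))).digest()
        rid_b = hashlib.sha1(str(random.getrandbits(255))).digest()
        a, b = Node(None, None, rid_a), Node(None, None, rid_b)
        expected = long(rid_a.encode('hex'), 16) ^ long(rid_b.encode('hex'), 16)
        self.assertEqual(a.distanceTo(b), expected)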

View File

@@ -1,9 +1,15 @@
"""
General catchall for functions that don't make sense as methods.
"""
import hashlib
from twisted.internet import defer
def digest(s):
if not isinstance(s, str):
s = str(s)
return hashlib.sha1(s).digest()
def deferredDict(d):
"""
Just like a C{defer.DeferredList} but instead accepts and returns a C{dict}.
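
The docstring above describes deferredDict as a dict-shaped DeferredList. Purely as an illustration of that idea (the project's actual function body is not shown here, and this sketch is not it), such a helper could be built like this:

# Hedged sketch only: collapse {key: Deferred} into a Deferred firing with {key: result}.
from twisted.internet import defer

def deferred_dict_sketch(d):
    if not d:
        return defer.succeed({})
    keys = list(d.keys())
    dl = defer.DeferredList([d[k] for k in keys])
    def rebuild(results):
        # DeferredList fires with [(success, value), ...] in input order.
        return dict(zip(keys, [value for _, value in results]))
    return dl.addCallback(rebuild)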