fixed bootstrap process

This commit is contained in:
Brian Muller 2014-01-02 23:06:12 -05:00
parent d29a532abd
commit 75a97fb8f1
7 changed files with 129 additions and 88 deletions

2
.gitignore vendored
View File

@ -2,4 +2,4 @@ apidoc
*.pyc *.pyc
build build
dist dist
rpcudp.egg-info kademlia.egg-info

View File

@ -1,11 +1,11 @@
PYDOCTOR=pydoctor PYDOCTOR=pydoctor
docs: docs:
$(PYDOCTOR) --make-html --html-output apidoc --add-package rpcudp --project-name=rpcudp --project-url=http://github.com/bmuller/rpcudp --html-use-sorttable --html-use-splitlinks --html-shorten-lists $(PYDOCTOR) --make-html --html-output apidoc --add-package kademlia --project-name=kademlia --project-url=http://github.com/bmuller/kademlia --html-use-sorttable --html-use-splitlinks --html-shorten-lists
lint: lint:
pep8 --ignore=E303,E251,E201,E202 ./rpcudp --max-line-length=140 pep8 --ignore=E303,E251,E201,E202 ./rpcudp --max-line-length=140
find ./rpcudp -name '*.py' | xargs pyflakes find ./kademlia -name '*.py' | xargs pyflakes
install: install:
python setup.py install python setup.py install

16
example.py Normal file
View File

@ -0,0 +1,16 @@
"""
Minimal two-node example: start two local servers and bootstrap the
second one off the first, printing whatever the bootstrap crawl returns.
"""
import sys

from twisted.internet import reactor
from twisted.python import log

from kademlia.network import Server

log.startLogging(sys.stdout)

# First server, constructed with port 1234; it is the peer the second
# server bootstraps against.
one = Server(1234)


def done(found):
    """Report the result of the bootstrap crawl."""
    # Single pre-formatted argument so this works as both the Python 2
    # print statement and the Python 3 print function.
    print("Found nodes: %s" % (found,))


def stop(_):
    """Shut the reactor down regardless of how the bootstrap ended."""
    reactor.stop()


two = Server(5678)
d = two.bootstrap([('127.0.0.1', 1234)])
d.addCallback(done)
# Without an errback a failed bootstrap would go unreported and the
# reactor would keep running forever.
d.addErrback(log.err)
d.addBoth(stop)
reactor.run()

View File

@ -1,58 +1,13 @@
import hashlib import hashlib
import random import random
import heapq
from twisted.internet import log, defer from twisted.internet import defer
from twisted.python import log
from kademlia.protocol import KademliaProtocol from kademlia.protocol import KademliaProtocol
from kademlia.utils import deferredDict from kademlia.utils import deferredDict
from kademlia.storage import ForgetfulStorage from kademlia.storage import ForgetfulStorage
from kademlia.node import Node, NodeHeap
ALPHA = 3
class NodeHeap(object):
    """
    A heap of (distance, node) pairs that only exposes the C{maxsize}
    closest nodes through iteration, C{len} and the id helpers.
    """

    def __init__(self, maxsize):
        """
        @param maxsize: The maximum number of nodes visible at any time.
        """
        self.heap = []
        self.contacted = set()
        self.maxsize = maxsize

    def remove(self, peerIDs):
        """
        Remove a list of peer ids from this heap.  Note that while this
        heap retains a constant visible size (based on the iterator), its
        actual size may be quite a bit larger than what's exposed.  Therefore,
        removal of nodes may not change the visible size as previously added
        nodes suddenly become visible.
        """
        peerIDs = set(peerIDs)
        if not peerIDs:
            return
        # Filter once and re-heapify instead of re-pushing every entry.
        kept = [entry for entry in self.heap if entry[1].id not in peerIDs]
        heapq.heapify(kept)
        self.heap = kept

    def allBeenContacted(self):
        """Return True when every visible node has been marked contacted."""
        return not self.getUncontacted()

    def getIDs(self):
        """Return the ids of the visible nodes, closest first."""
        return [n.id for n in self]

    def markContacted(self, node):
        """Record that the given node has been contacted."""
        self.contacted.add(node.id)

    def push(self, distance, node):
        """
        Add a node at the given distance.  A node whose id is already in
        the heap is ignored, so the visible window never fills up with
        duplicates and equal-distance entries for the same id never force
        a comparison between node objects.
        """
        if any(known.id == node.id for _, known in self.heap):
            return
        heapq.heappush(self.heap, (distance, node))

    def __len__(self):
        return min(len(self.heap), self.maxsize)

    def __iter__(self):
        # Yield the nodes themselves, not the (distance, node) pairs;
        # getIDs and getUncontacted rely on reading n.id from each item.
        nearest = heapq.nsmallest(self.maxsize, self.heap)
        return iter([node for _, node in nearest])

    def getUncontacted(self):
        """Return the visible nodes that have not been marked contacted."""
        return [n for n in self if n.id not in self.contacted]
class SpiderCrawl(object): class SpiderCrawl(object):
@ -64,15 +19,16 @@ class SpiderCrawl(object):
# yet queried # yet queried
# repeat, unless nearest list has all been queried, then ur done # repeat, unless nearest list has all been queried, then ur done
def __init__(self, protocol, node, peers): def __init__(self, protocol, node, peers, ksize, alpha):
self.protocol = protocol self.protocol = protocol
self.nearest = NodeHeap(KSIZE) self.ksize = ksize
self.alpha = alpha
self.nearest = NodeHeap(self.ksize)
self.node = node self.node = node
self.lastIDsCrawled = [] self.lastIDsCrawled = []
for peer in peers: for peer in peers:
self.nearest.push(self.node.distanceTo(peer), peer) self.nearest.push(self.node.distanceTo(peer), peer)
def findNodes(self): def findNodes(self):
return self.find(self.protocol.callFindNode) return self.find(self.protocol.callFindNode)
@ -85,7 +41,7 @@ class SpiderCrawl(object):
return d.addCallback(handle) return d.addCallback(handle)
def find(self, rpcmethod): def find(self, rpcmethod):
count = ALPHA count = self.alpha
if self.nearest.getIDs() == self.lastIDsCrawled: if self.nearest.getIDs() == self.lastIDsCrawled:
count = len(self.nearest) count = len(self.nearest)
self.lastIDsCrawled = self.nearest.getIDs() self.lastIDsCrawled = self.nearest.getIDs()
@ -118,21 +74,32 @@ class SpiderCrawl(object):
class Server: class Server:
def __init__(self, port, ksize=20, alpha=3): def __init__(self, port, ksize=20, alpha=3):
self.ksize = ksize
self.alpha = alpha
# 160 bit random id # 160 bit random id
rid = hashlib.sha1(str(random.getrandbits(255))).digest() rid = hashlib.sha1(str(random.getrandbits(255))).digest()
storage = ForgetfulStorage() storage = ForgetfulStorage()
self.node = Node('127.0.0.1', port, rid) self.node = Node('127.0.0.1', port, rid)
self.prototcol = KademliaProtocol(self.node, storage, ksize, alpha) self.protocol = KademliaProtocol(self.node, storage, ksize, alpha)
def bootstrap(self, nodes): def bootstrap(self, addrs):
nodes = [ Node(*n) for n in nodes ] def initTable(results):
spider = NetworkSpider(self.protocol, self.node, nodes) nodes = []
return spider.findNodes() for addr, result in results.items():
if result[0]:
nodes.append(Node(addr[0], addr[1], result[1]))
spider = SpiderCrawl(self.protocol, self.node, nodes, self.ksize, self.alpha)
return spider.findNodes()
ds = {}
for addr in addrs:
ds[addr] = self.protocol.ping(addr, self.node.id)
return deferredDict(ds).addCallback(initTable)
def get(self, key): def get(self, key):
node = Node(None, None, key) node = Node(None, None, key)
nearest = self.router.findNeighbors(node, ALPHA) nearest = self.router.findNeighbors(node)
spider = NetworkSpider(self.protocol, node, nearest) spider = SpiderCrawl(self.protocol, node, nearest, self.ksize, self.alpha)
return spider.findValue() return spider.findValue()
def set(self, key, value): def set(self, key, value):
@ -141,6 +108,6 @@ class Server:
ds = [self.protocol.callStore(node) for node in nodes] ds = [self.protocol.callStore(node) for node in nodes]
return defer.gatherResults(ds) return defer.gatherResults(ds)
node = Node(None, None, key) node = Node(None, None, key)
nearest = self.router.findNeighbors(node, ALPHA) nearest = self.router.findNeighbors(node)
spider = NetworkSpider(self.protocol, nearest) spider = SpiderCrawl(self.protocol, nearest, self.ksize, self.alpha)
return spider.findNodes(node).addCallback(store) return spider.findNodes(node).addCallback(store)

View File

@ -1,12 +1,14 @@
from operator import itemgetter
import heapq
class Node: class Node:
def __init__(self, ip, port, id=None): def __init__(self, ip, port, id):
self.ip = ip self.ip = ip
self.port = port self.port = port
self.id = id self.id = id
if id is not None: self.long_id = long(id.encode('hex'), 16)
self.long_id = long(id.encode('hex'), 16)
def distnaceTo(self, node): def distanceTo(self, node):
return self.long_id ^ node.long_id return self.long_id ^ node.long_id
def __iter__(self): def __iter__(self):
@ -14,3 +16,52 @@ class Node:
Enables use of Node as a tuple - i.e., tuple(node) works. Enables use of Node as a tuple - i.e., tuple(node) works.
""" """
return iter([self.ip, self.port, self.id]) return iter([self.ip, self.port, self.id])
def __repr__(self):
return repr([self.ip, self.port, self.long_id])
class NodeHeap(object):
    """
    A heap of (distance, node) pairs that only exposes the C{maxsize}
    closest nodes through iteration, C{len} and the id helpers.
    """

    def __init__(self, maxsize):
        """
        @param maxsize: The maximum number of nodes visible at any time.
        """
        self.heap = []
        self.contacted = set()
        self.maxsize = maxsize

    def remove(self, peerIDs):
        """
        Remove a list of peer ids from this heap.  Note that while this
        heap retains a constant visible size (based on the iterator), its
        actual size may be quite a bit larger than what's exposed.  Therefore,
        removal of nodes may not change the visible size as previously added
        nodes suddenly become visible.
        """
        peerIDs = set(peerIDs)
        if not peerIDs:
            return
        # Filter once and re-heapify instead of re-pushing every entry.
        kept = [entry for entry in self.heap if entry[1].id not in peerIDs]
        heapq.heapify(kept)
        self.heap = kept

    def allBeenContacted(self):
        """Return True when every visible node has been marked contacted."""
        return not self.getUncontacted()

    def getIDs(self):
        """Return the ids of the visible nodes, closest first."""
        return [n.id for n in self]

    def markContacted(self, node):
        """Record that the given node has been contacted."""
        self.contacted.add(node.id)

    def push(self, distance, node):
        """
        Add a node at the given distance.  A node whose id is already in
        the heap is ignored: otherwise the same peer reported twice would
        occupy several visible slots, and the equal-distance tuples would
        fall back to comparing the node objects themselves.
        """
        if any(known.id == node.id for _, known in self.heap):
            return
        heapq.heappush(self.heap, (distance, node))

    def __len__(self):
        return min(len(self.heap), self.maxsize)

    def __iter__(self):
        # Iterate over the nodes only, closest first, dropping the distances.
        nearest = heapq.nsmallest(self.maxsize, self.heap)
        return iter([node for _, node in nearest])

    def getUncontacted(self):
        """Return the visible nodes that have not been marked contacted."""
        return [n for n in self if n.id not in self.contacted]

View File

@ -1,4 +1,4 @@
from twisted.internet import log from twisted.python import log
from rpcudp.protocol import RPCProtocol from rpcudp.protocol import RPCProtocol
from kademlia.node import Node from kademlia.node import Node
@ -8,14 +8,14 @@ from kademlia.routing import RoutingTable
class KademliaProtocol(RPCProtocol): class KademliaProtocol(RPCProtocol):
def __init__(self, node, storage, ksize, alpha): def __init__(self, node, storage, ksize, alpha):
RPCProtocol.__init__(self, node.port) RPCProtocol.__init__(self, node.port)
self.router = RoutingTable(self) self.router = RoutingTable(self, ksize, alpha)
self.storage = storage self.storage = storage
self.sourceID = node.id self.sourceID = node.id
def rpc_ping(self, sender, nodeid): def rpc_ping(self, sender, nodeid):
source = Node(sender[0], sender[1], nodeid) source = Node(sender[0], sender[1], nodeid)
self.router.addContact(source) self.router.addContact(source)
return "pong" return self.sourceID
def rpc_store(self, sender, nodeid, key, value): def rpc_store(self, sender, nodeid, key, value):
source = Node(sender[0], sender[1], nodeid) source = Node(sender[0], sender[1], nodeid)
@ -25,7 +25,8 @@ class KademliaProtocol(RPCProtocol):
def rpc_find_node(self, sender, nodeid, key): def rpc_find_node(self, sender, nodeid, key):
source = Node(sender[0], sender[1], nodeid) source = Node(sender[0], sender[1], nodeid)
self.router.addContact(source) self.router.addContact(source)
return map(tuple, self.table.findNeighbors(Node(None, None, key)) node = Node(None, None, key)
return map(tuple, self.router.findNeighbors(node))
def rpc_find_value(self, sender, nodeid, key): def rpc_find_value(self, sender, nodeid, key):
source = Node(sender[0], sender[1], nodeid) source = Node(sender[0], sender[1], nodeid)
@ -38,24 +39,24 @@ class KademliaProtocol(RPCProtocol):
def callFindNode(self, nodeToAsk, nodeToFind): def callFindNode(self, nodeToAsk, nodeToFind):
address = (nodeToAsk.ip, nodeToAsk.port) address = (nodeToAsk.ip, nodeToAsk.port)
d = self.find_node(address, self.sourceID, nodeToFind.id) d = self.find_node(address, self.sourceID, nodeToFind.id)
return d.addCallback(handleCallResponse, nodetoAsk) return d.addCallback(self.handleCallResponse, nodeToAsk)
def callFindValue(self, nodeToAsk, nodeToFind): def callFindValue(self, nodeToAsk, nodeToFind):
address = (nodeToAsk.ip, nodeToAsk.port) address = (nodeToAsk.ip, nodeToAsk.port)
d = self.find_value(address, self.sourceID, nodeToFind.id) d = self.find_value(address, self.sourceID, nodeToFind.id)
return d.addCallback(handleCallResponse, nodetoAsk) return d.addCallback(self.handleCallResponse, nodeToAsk)
def callPing(self, nodeToAsk): def callPing(self, nodeToAsk):
address = (nodeToAsk.ip, nodeToAsk.port) address = (nodeToAsk.ip, nodeToAsk.port)
d = self.ping(address, self.sourceID) d = self.ping(address, self.sourceID)
return d.addCallback(handleCallResponse, nodetoAsk) return d.addCallback(self.handleCallResponse, nodeToAsk)
def callStore(self, nodeToAsk, key, value): def callStore(self, nodeToAsk, key, value):
address = (nodeToAsk.ip, nodeToAsk.port) address = (nodeToAsk.ip, nodeToAsk.port)
d = self.store(address, self.sourceID, key, value) d = self.store(address, self.sourceID, key, value)
return d.addCallback(handleCallResponse, nodetoAsk) return d.addCallback(self.handleCallResponse, nodeToAsk)
def handleCallResponse(self, result): def handleCallResponse(self, result, node):
""" """
If we get a response, add the node to the routing table. If If we get a response, add the node to the routing table. If
we get no response, make sure it's removed from the routing table. we get no response, make sure it's removed from the routing table.

View File

@ -1,13 +1,16 @@
import heapq import heapq
import time
from collections import OrderedDict from collections import OrderedDict
from twisted.internet.task import LoopingCall from twisted.internet.task import LoopingCall
from twisted.internet import defer
class KBucket(object): class KBucket(object):
def __init__(self, rangeLower, rangeUpper): def __init__(self, rangeLower, rangeUpper, ksize):
self.range = (rangeLower, rangeUpper) self.range = (rangeLower, rangeUpper)
self.nodes = OrderedDict() self.nodes = OrderedDict()
self.touchLastUpdated() self.touchLastUpdated()
self.ksize = ksize
def touchLastUpdated(self): def touchLastUpdated(self):
self.lastUpdated = time.time() self.lastUpdated = time.time()
@ -17,8 +20,8 @@ class KBucket(object):
def split(self): def split(self):
midpoint = self.range[1] - ((self.range[1] - self.range[0]) / 2) midpoint = self.range[1] - ((self.range[1] - self.range[0]) / 2)
one = KBucket(self.range[0], midpoint) one = KBucket(self.range[0], midpoint, self.ksize)
two = KBucket(midpoint+1, self.range[1]) two = KBucket(midpoint+1, self.range[1], self.ksize)
for node in self.nodes.values(): for node in self.nodes.values():
bucket = one if node.long_id <= midpoint else two bucket = one if node.long_id <= midpoint else two
bucket.nodes[node.id] = node bucket.nodes[node.id] = node
@ -39,7 +42,7 @@ class KBucket(object):
if node.id in self.nodes: if node.id in self.nodes:
del self.nodes[node.id] del self.nodes[node.id]
self.nodes[node.id] = node self.nodes[node.id] = node
elif len(self) < KSIZE: elif len(self) < self.ksize:
self.nodes[node.id] = node self.nodes[node.id] = node
else: else:
return False return False
@ -58,7 +61,7 @@ class KBucket(object):
class TableTraverser(object): class TableTraverser(object):
def __init__(self, table, startNode): def __init__(self, table, startNode):
index = table.getBucketFor(startNode) index = table.getBucketFor(startNode)
bucket[index].touchLastUpdated() table.buckets[index].touchLastUpdated()
self.currentNodes = table.buckets[index].getNodes() self.currentNodes = table.buckets[index].getNodes()
self.leftBuckets = table.buckets[:index] self.leftBuckets = table.buckets[:index]
self.rightBuckets = table.buckets[(index+1):] self.rightBuckets = table.buckets[(index+1):]
@ -88,9 +91,11 @@ class TableTraverser(object):
class RoutingTable(object): class RoutingTable(object):
def __init__(self, protocol): def __init__(self, protocol, ksize, alpha):
self.protocol = protocol self.protocol = protocol
self.buckets = [KBucket(0, 2**160)] self.ksize = ksize
self.alpha = alpha
self.buckets = [KBucket(0, 2**160, ksize)]
LoopingCall(self.refresh).start(3600) LoopingCall(self.refresh).start(3600)
def splitBucket(self, index): def splitBucket(self, index):
@ -105,7 +110,7 @@ class RoutingTable(object):
for bucket in self.buckets: for bucket in self.buckets:
if bucket.lastUpdated < (time.time() - 3600): if bucket.lastUpdated < (time.time() - 3600):
node = Node(None, None, random.randint(*bucket.range)) node = Node(None, None, random.randint(*bucket.range))
nearest = self.findNeighbors(node, ALPHA) nearest = self.findNeighbors(node, self.alpha)
spider = NetworkSpider(self.protocol, node, nearest) spider = NetworkSpider(self.protocol, node, nearest)
ds.append(spider.findNodes()) ds.append(spider.findNodes())
return defer.gatherResults(ds) return defer.gatherResults(ds)
@ -115,7 +120,7 @@ class RoutingTable(object):
self.buckets[index].removeNode(node) self.buckets[index].removeNode(node)
def addContact(self, node): def addContact(self, node):
index = self.getBucketFor(self, node) index = self.getBucketFor(node)
bucket = self.buckets[index] bucket = self.buckets[index]
# this will succeed unless the bucket is full # this will succeed unless the bucket is full
@ -136,7 +141,8 @@ class RoutingTable(object):
if node.long_id < bucket.range[1]: if node.long_id < bucket.range[1]:
return index return index
def findNeighbors(self, node, k=KSIZE): def findNeighbors(self, node, k=None):
k = k or self.ksize
nodes = [] nodes = []
for neighbor in TableTraverser(self, node): for neighbor in TableTraverser(self, node):
if neighbor.id != node.id: if neighbor.id != node.id: