0,0 → 1,375 |
# Written by Bram Cohen |
# see LICENSE.txt for license information |
|
import socket |
from errno import EWOULDBLOCK, ECONNREFUSED, EHOSTUNREACH |
try: |
from select import poll, error, POLLIN, POLLOUT, POLLERR, POLLHUP |
timemult = 1000 |
except ImportError: |
from selectpoll import poll, error, POLLIN, POLLOUT, POLLERR, POLLHUP |
timemult = 1 |
from time import sleep |
from clock import clock |
import sys |
from random import shuffle, randrange |
from natpunch import UPnP_open_port, UPnP_close_port |
# from BT1.StreamCheck import StreamCheck |
# import inspect |
try:
    True
except NameError:
    # Interpreters older than Python 2.2.1 have no True/False; define
    # module-level stand-ins. Assigning through globals() (equivalent to a
    # plain module-level assignment) keeps this code compilable on
    # interpreters where True/False are keywords.
    globals()['True'] = 1
    globals()['False'] = 0

# Poll mask "readable or writable"; SingleSocket.try_write registers a socket
# with it while output is still queued.
# NOTE: this name shadows the builtin all() throughout this module.
all = POLLIN | POLLOUT

# Message of the socket.error raised by SocketHandler.bind when UPnP port
# forwarding fails.
UPnP_ERROR = "unable to forward port via UPnP"
|
class SingleSocket:
    """One non-blocking connection owned by a SocketHandler.

    Outgoing data is queued in ``self.buffer`` and drained by ``try_write``.
    While data is pending the owning handler's poll object is registered for
    write readiness as well as read readiness.
    """

    def __init__(self, socket_handler, sock, handler, ip = None):
        self.socket_handler = socket_handler
        self.socket = sock
        self.handler = handler
        self.buffer = []             # pending outgoing chunks, oldest first
        self.last_hit = clock()      # refreshed on read events; drives idle timeouts
        self.fileno = sock.fileno()  # key into socket_handler.single_sockets
        self.connected = False
        self.skipped = 0             # consecutive write attempts that made no progress
        try:
            self.ip = self.socket.getpeername()[0]
        except socket.error:
            # Not connected yet (e.g. an outgoing non-blocking connect):
            # fall back to the address supplied by the caller, if any.
            if ip is None:
                self.ip = 'unknown'
            else:
                self.ip = ip

    def get_ip(self, real=False):
        """Return the peer IP address as a string.

        If ``real`` is true, re-query the live socket first. Failure, or an
        already-closed connection, leaves the cached value in place.
        """
        if real and self.socket is not None:
            try:
                self.ip = self.socket.getpeername()[0]
            except socket.error:
                pass
        return self.ip

    def close(self):
        """Close the socket and remove it from the owning handler.

        Must be called at most once (asserted). The protocol handler is not
        notified here.
        """
        assert self.socket
        self.connected = False
        sock = self.socket
        self.socket = None
        self.buffer = []
        try:
            del self.socket_handler.single_sockets[self.fileno]
            self.socket_handler.poll.unregister(sock)
        finally:
            # Release the descriptor even if the bookkeeping above failed.
            sock.close()

    def shutdown(self, val):
        """Shut down one or both directions of the socket (``val`` as for socket.shutdown)."""
        self.socket.shutdown(val)

    def is_flushed(self):
        """Return True when no outgoing data is queued."""
        return not self.buffer

    def write(self, s):
        """Queue ``s`` for sending; try to send at once if the queue was empty."""
        assert self.socket is not None
        self.buffer.append(s)
        if len(self.buffer) == 1:
            self.try_write()

    def try_write(self):
        """Send as much queued data as the socket accepts without blocking.

        For a connected socket, any error other than EWOULDBLOCK, or three
        consecutive attempts without progress, marks the connection dead. The
        connection is then appended to ``socket_handler.dead_from_write`` and
        the method returns without touching poll registration. Otherwise the
        socket is (re)registered for POLLIN, plus POLLOUT while data remains.
        """
        if self.connected:
            dead = False
            try:
                while self.buffer:
                    buf = self.buffer[0]
                    amount = self.socket.send(buf)
                    if amount == 0:
                        self.skipped += 1
                        break
                    self.skipped = 0
                    if amount != len(buf):
                        # Partial send: keep the unsent tail at the front.
                        self.buffer[0] = buf[amount:]
                        break
                    del self.buffer[0]
            except socket.error:
                err = sys.exc_info()[1]
                # Read the errno from err.args: indexing the exception itself
                # is not portable. An error without an errno is fatal.
                dead = not err.args or err.args[0] != EWOULDBLOCK
                self.skipped += 1
            if self.skipped >= 3:
                dead = True
            if dead:
                self.socket_handler.dead_from_write.append(self)
                return
        if self.buffer:
            self.socket_handler.poll.register(self.socket, all)
        else:
            self.socket_handler.poll.register(self.socket, POLLIN)

    def set_handler(self, handler):
        """Replace the protocol handler that receives this connection's events."""
        self.handler = handler
|
class SocketHandler:
    """Poll-driven owner of listening sockets and peer connections.

    Peer connections are SingleSocket objects keyed by file descriptor.
    ``do_poll`` collects events and ``handle_events`` dispatches them to each
    connection's protocol handler.
    """

    def __init__(self, timeout, ipv6_enable, readsize = 100000):
        # Idle limit used by scan_for_timeouts (same units as clock()).
        self.timeout = timeout
        self.ipv6_enable = ipv6_enable
        self.readsize = readsize        # max bytes per recv() in handle_events
        self.poll = poll()
        # {fileno: SingleSocket}
        self.single_sockets = {}
        # SingleSockets whose writes failed; closed by close_dead().
        self.dead_from_write = []
        # New connections are not accepted once this many are open.
        self.max_connects = 1000
        self.port_forwarded = None      # port opened via UPnP, if any
        # {fileno: listening socket}
        self.servers = {}
        # Filled in by bind(); initialised so get_stats() works before it.
        self.interfaces = []
        self.port = None

    def scan_for_timeouts(self):
        """Close every connection idle for longer than ``self.timeout``."""
        cutoff = clock() - self.timeout
        stale = [s for s in self.single_sockets.values() if s.last_hit < cutoff]
        for s in stale:
            # A connection_lost callback earlier in this loop may have
            # closed it already.
            if s.socket is not None:
                self._close_socket(s)

    def bind(self, port, bind = '', reuse = False, ipv6_socket_style = 1, upnp = 0):
        """Open listening sockets on ``port``.

        ``bind`` is an optional comma-separated list of addresses or host
        names to listen on. When it is empty, listen on the wildcard address:
        IPv6 if enabled, and IPv4 too unless ``ipv6_socket_style`` is 0. With
        ``upnp`` set, also forward the port through UPnP. Any listening
        sockets from a previous bind() are closed first. On failure every
        socket opened here is unregistered and closed, then socket.error is
        raised.
        """
        port = int(port)
        addrinfos = []
        self._close_servers()
        self.interfaces = []
        if bind:
            if self.ipv6_enable:
                family = socket.AF_UNSPEC
            else:
                family = socket.AF_INET
            for addr in bind.split(','):
                if sys.version_info < (2,2):
                    addrinfos.append((socket.AF_INET, None, None, None, (addr, port)))
                else:
                    addrinfos.extend(socket.getaddrinfo(addr, port,
                                                        family, socket.SOCK_STREAM))
        else:
            if self.ipv6_enable:
                addrinfos.append([socket.AF_INET6, None, None, None, ('', port)])
            if not addrinfos or ipv6_socket_style != 0:
                addrinfos.append([socket.AF_INET, None, None, None, ('', port)])
        for addrinfo in addrinfos:
            server = None
            try:
                server = socket.socket(addrinfo[0], socket.SOCK_STREAM)
                if reuse:
                    server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                server.setblocking(0)
                server.bind(addrinfo[4])
                self.servers[server.fileno()] = server
                if bind:
                    self.interfaces.append(server.getsockname()[0])
                server.listen(64)
                self.poll.register(server, POLLIN)
            except socket.error:
                err = sys.exc_info()[1]
                # Decide on the hint before cleanup empties self.servers.
                blocked_v6 = (self.ipv6_enable and ipv6_socket_style == 0
                              and self.servers)
                self._close_servers()
                self.interfaces = []
                if server is not None:
                    # The failing socket may not have reached self.servers
                    # (bind() is the usual failure point); closing twice is
                    # harmless.
                    server.close()
                if blocked_v6:
                    raise socket.error('blocked port (may require ipv6_binds_v4 to be set)')
                raise socket.error(str(err))
        if not self.servers:
            raise socket.error('unable to open server port')
        if upnp:
            if not UPnP_open_port(port):
                self._close_servers()
                self.interfaces = []
                raise socket.error(UPnP_ERROR)
            self.port_forwarded = port
        self.port = port

    def _close_servers(self):
        """Unregister and close all listening sockets, then forget them."""
        for server in list(self.servers.values()):
            try:
                self.poll.unregister(server)
            except Exception:
                pass    # never registered (failed before poll.register)
            try:
                server.close()
            except Exception:
                pass
        self.servers = {}

    def find_and_bind(self, minport, maxport, bind = '', reuse = False,
                      ipv6_socket_style = 1, upnp = 0, randomizer = False):
        """Bind to a usable port in [minport, maxport] and return it.

        Without ``randomizer``, ports are tried in ascending order. With it,
        at most 20 distinct ports are tried in random order. If no port can
        be bound, socket.error is raised carrying the last failure.
        """
        last_error = 'maxport less than minport - no ports to check'
        if maxport-minport < 50 or not randomizer:
            portrange = list(range(minport, maxport+1))
            if randomizer:
                shuffle(portrange)
                portrange = portrange[:20] # check a maximum of 20 ports
        else:
            portrange = []
            while len(portrange) < 20:
                listen_port = randrange(minport, maxport+1)
                if listen_port not in portrange:
                    portrange.append(listen_port)
        for listen_port in portrange:
            try:
                self.bind(listen_port, bind, reuse = reuse,
                          ipv6_socket_style = ipv6_socket_style, upnp = upnp)
                return listen_port
            except socket.error:
                last_error = sys.exc_info()[1]
        raise socket.error(str(last_error))

    def set_handler(self, handler):
        """Set the default protocol handler for new connections."""
        self.handler = handler

    def start_connection_raw(self, dns, socktype = socket.AF_INET, handler = None):
        """Start a non-blocking connect to address ``dns``; return its SingleSocket.

        ``socktype`` is the address family. If the connect call raises, the
        socket is closed and the error is re-raised as socket.error.
        """
        if handler is None:
            handler = self.handler
        sock = socket.socket(socktype, socket.SOCK_STREAM)
        sock.setblocking(0)
        try:
            # connect_ex's return code is ignored. Completion or failure
            # surfaces later as a poll event in handle_events.
            sock.connect_ex(dns)
        except Exception:
            err = sys.exc_info()[1]
            sock.close()
            if isinstance(err, socket.error):
                raise err
            raise socket.error(str(err))
        self.poll.register(sock, POLLIN)
        s = SingleSocket(self, sock, handler, dns[0])
        self.single_sockets[sock.fileno()] = s
        return s

    def start_connection(self, dns, handler = None, randomize = False):
        """Connect to ``dns`` = (host, port), trying each resolved address in turn.

        Returns the SingleSocket for the first address whose connect could be
        started. Raises socket.error('unable to connect') if none could.
        """
        if handler is None:
            handler = self.handler
        if sys.version_info < (2,2):
            return self.start_connection_raw(dns, socket.AF_INET, handler)
        if self.ipv6_enable:
            family = socket.AF_UNSPEC
        else:
            family = socket.AF_INET
        try:
            addrinfos = socket.getaddrinfo(dns[0], int(dns[1]),
                                           family, socket.SOCK_STREAM)
        except socket.error:
            raise
        except Exception:
            raise socket.error(str(sys.exc_info()[1]))
        if randomize:
            shuffle(addrinfos)
        for addrinfo in addrinfos:
            try:
                return self.start_connection_raw(addrinfo[4], addrinfo[0], handler)
            except Exception:
                pass    # try the next resolved address
        raise socket.error('unable to connect')

    def _sleep(self):
        """Back off for one second (used after a failed accept)."""
        sleep(1)

    def handle_events(self, events):
        """Dispatch the (fileno, eventmask) pairs returned by do_poll.

        Listening sockets accept new connections while fewer than
        max_connects are open. Peer sockets are closed on hang-up/error,
        read from on POLLIN, and flushed on POLLOUT.
        """
        for sock, event in events:
            s = self.servers.get(sock)
            if s:
                if event & (POLLHUP | POLLERR):
                    self.poll.unregister(s)
                    s.close()
                    del self.servers[sock]
                    print("lost server socket")
                elif len(self.single_sockets) < self.max_connects:
                    try:
                        newsock, addr = s.accept()
                        newsock.setblocking(0)
                        nss = SingleSocket(self, newsock, self.handler)
                        self.single_sockets[newsock.fileno()] = nss
                        self.poll.register(newsock, POLLIN)
                        self.handler.external_connection_made(nss)
                    except socket.error:
                        err = sys.exc_info()[1]
                        # Back off on real failures, but not on EWOULDBLOCK,
                        # which only means nothing was pending. Sleeping in
                        # that case would stall every connection for 1s.
                        if not err.args or err.args[0] != EWOULDBLOCK:
                            self._sleep()
            else:
                s = self.single_sockets.get(sock)
                if not s:
                    continue
                s.connected = True
                if event & (POLLHUP | POLLERR):
                    self._close_socket(s)
                    continue
                if event & POLLIN:
                    try:
                        s.last_hit = clock()
                        data = s.socket.recv(self.readsize)
                        if not data:
                            self._close_socket(s)
                        else:
                            s.handler.data_came_in(s, data)
                    except socket.error:
                        err = sys.exc_info()[1]
                        # Use err.args instead of tuple-unpacking the
                        # exception, which fails unless it has exactly
                        # two arguments.
                        if not err.args or err.args[0] != EWOULDBLOCK:
                            # data_came_in may already have closed it.
                            if s.socket is not None:
                                self._close_socket(s)
                            continue
                if (event & POLLOUT) and s.socket and not s.is_flushed():
                    s.try_write()
                    if s.is_flushed():
                        s.handler.connection_flushed(s)

    def close_dead(self):
        """Close connections whose writes failed, including any queued while closing."""
        while self.dead_from_write:
            old = self.dead_from_write
            self.dead_from_write = []
            for s in old:
                if s.socket:
                    self._close_socket(s)

    def _close_socket(self, s):
        """Close ``s`` and notify its handler via connection_lost."""
        s.close()
        s.handler.connection_lost(s)

    def do_poll(self, t):
        """Wait up to ``t`` seconds for events and return them as a list.

        If the poll object returns None, a random ~5% of connections (at
        least one) are closed, max_connects is lowered, and [] is returned.
        NOTE(review): None is presumably the selectpoll fallback's
        resource-exhaustion signal; confirm. With no open connections this
        sets max_connects to -1, after which nothing more is accepted.
        """
        r = self.poll.poll(t*timemult)
        if r is None:
            connects = len(self.single_sockets)
            to_close = int(connects*0.05)+1 # close 5% of sockets
            self.max_connects = connects-to_close
            closelist = list(self.single_sockets.values())
            shuffle(closelist)
            for sock in closelist[:to_close]:
                # A connection_lost callback may already have closed it.
                if sock.socket is not None:
                    self._close_socket(sock)
            return []
        return r

    def get_stats(self):
        """Return the bound interfaces, the bound port and whether UPnP forwarding is active."""
        return { 'interfaces': self.interfaces,
                 'port': self.port,
                 'upnp': self.port_forwarded is not None }

    def shutdown(self):
        """Close all connections and listening sockets and remove any UPnP mapping.

        Protocol handlers are not notified (connection_lost is not called).
        """
        # Copy first: SingleSocket.close() removes entries from the dict.
        for ss in list(self.single_sockets.values()):
            try:
                ss.close()
            except Exception:
                pass
        self._close_servers()
        if self.port_forwarded is not None:
            UPnP_close_port(self.port_forwarded)
            self.port_forwarded = None
|