Blame | Last modification | View Log | Download
# Written by Bram Cohen
# see LICENSE.txt for license information
import socket
from errno import EWOULDBLOCK, ECONNREFUSED, EHOSTUNREACH
try:
from select import poll, error, POLLIN, POLLOUT, POLLERR, POLLHUP
timemult = 1000
except ImportError:
from selectpoll import poll, error, POLLIN, POLLOUT, POLLERR, POLLHUP
timemult = 1
from time import sleep
from clock import clock
import sys
from random import shuffle, randrange
from natpunch import UPnP_open_port, UPnP_close_port
# from BT1.StreamCheck import StreamCheck
# import inspect
try:
True
except:
True = 1
False = 0
all = POLLIN | POLLOUT
UPnP_ERROR = "unable to forward port via UPnP"
class SingleSocket:
def __init__(self, socket_handler, sock, handler, ip = None):
self.socket_handler = socket_handler
self.socket = sock
self.handler = handler
self.buffer = []
self.last_hit = clock()
self.fileno = sock.fileno()
self.connected = False
self.skipped = 0
# self.check = StreamCheck()
try:
self.ip = self.socket.getpeername()[0]
except:
if ip is None:
self.ip = 'unknown'
else:
self.ip = ip
def get_ip(self, real=False):
if real:
try:
self.ip = self.socket.getpeername()[0]
except:
pass
return self.ip
def close(self):
'''
for x in xrange(5,0,-1):
try:
f = inspect.currentframe(x).f_code
print (f.co_filename,f.co_firstlineno,f.co_name)
del f
except:
pass
print ''
'''
assert self.socket
self.connected = False
sock = self.socket
self.socket = None
self.buffer = []
del self.socket_handler.single_sockets[self.fileno]
self.socket_handler.poll.unregister(sock)
sock.close()
def shutdown(self, val):
self.socket.shutdown(val)
def is_flushed(self):
return not self.buffer
def write(self, s):
# self.check.write(s)
assert self.socket is not None
self.buffer.append(s)
if len(self.buffer) == 1:
self.try_write()
def try_write(self):
if self.connected:
dead = False
try:
while self.buffer:
buf = self.buffer[0]
amount = self.socket.send(buf)
if amount == 0:
self.skipped += 1
break
self.skipped = 0
if amount != len(buf):
self.buffer[0] = buf[amount:]
break
del self.buffer[0]
except socket.error, e:
try:
dead = e[0] != EWOULDBLOCK
except:
dead = True
self.skipped += 1
if self.skipped >= 3:
dead = True
if dead:
self.socket_handler.dead_from_write.append(self)
return
if self.buffer:
self.socket_handler.poll.register(self.socket, all)
else:
self.socket_handler.poll.register(self.socket, POLLIN)
def set_handler(self, handler):
self.handler = handler
class SocketHandler:
def __init__(self, timeout, ipv6_enable, readsize = 100000):
self.timeout = timeout
self.ipv6_enable = ipv6_enable
self.readsize = readsize
self.poll = poll()
# {socket: SingleSocket}
self.single_sockets = {}
self.dead_from_write = []
self.max_connects = 1000
self.port_forwarded = None
self.servers = {}
def scan_for_timeouts(self):
t = clock() - self.timeout
tokill = []
for s in self.single_sockets.values():
if s.last_hit < t:
tokill.append(s)
for k in tokill:
if k.socket is not None:
self._close_socket(k)
def bind(self, port, bind = '', reuse = False, ipv6_socket_style = 1, upnp = 0):
port = int(port)
addrinfos = []
self.servers = {}
self.interfaces = []
# if bind != "" thread it as a comma seperated list and bind to all
# addresses (can be ips or hostnames) else bind to default ipv6 and
# ipv4 address
if bind:
if self.ipv6_enable:
socktype = socket.AF_UNSPEC
else:
socktype = socket.AF_INET
bind = bind.split(',')
for addr in bind:
if sys.version_info < (2,2):
addrinfos.append((socket.AF_INET, None, None, None, (addr, port)))
else:
addrinfos.extend(socket.getaddrinfo(addr, port,
socktype, socket.SOCK_STREAM))
else:
if self.ipv6_enable:
addrinfos.append([socket.AF_INET6, None, None, None, ('', port)])
if not addrinfos or ipv6_socket_style != 0:
addrinfos.append([socket.AF_INET, None, None, None, ('', port)])
for addrinfo in addrinfos:
try:
server = socket.socket(addrinfo[0], socket.SOCK_STREAM)
if reuse:
server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
server.setblocking(0)
server.bind(addrinfo[4])
self.servers[server.fileno()] = server
if bind:
self.interfaces.append(server.getsockname()[0])
server.listen(64)
self.poll.register(server, POLLIN)
except socket.error, e:
for server in self.servers.values():
try:
server.close()
except:
pass
if self.ipv6_enable and ipv6_socket_style == 0 and self.servers:
raise socket.error('blocked port (may require ipv6_binds_v4 to be set)')
raise socket.error(str(e))
if not self.servers:
raise socket.error('unable to open server port')
if upnp:
if not UPnP_open_port(port):
for server in self.servers.values():
try:
server.close()
except:
pass
self.servers = None
self.interfaces = None
raise socket.error(UPnP_ERROR)
self.port_forwarded = port
self.port = port
def find_and_bind(self, minport, maxport, bind = '', reuse = False,
ipv6_socket_style = 1, upnp = 0, randomizer = False):
e = 'maxport less than minport - no ports to check'
if maxport-minport < 50 or not randomizer:
portrange = range(minport, maxport+1)
if randomizer:
shuffle(portrange)
portrange = portrange[:20] # check a maximum of 20 ports
else:
portrange = []
while len(portrange) < 20:
listen_port = randrange(minport, maxport+1)
if not listen_port in portrange:
portrange.append(listen_port)
for listen_port in portrange:
try:
self.bind(listen_port, bind,
ipv6_socket_style = ipv6_socket_style, upnp = upnp)
return listen_port
except socket.error, e:
pass
raise socket.error(str(e))
def set_handler(self, handler):
self.handler = handler
def start_connection_raw(self, dns, socktype = socket.AF_INET, handler = None):
if handler is None:
handler = self.handler
sock = socket.socket(socktype, socket.SOCK_STREAM)
sock.setblocking(0)
try:
sock.connect_ex(dns)
except socket.error:
raise
except Exception, e:
raise socket.error(str(e))
self.poll.register(sock, POLLIN)
s = SingleSocket(self, sock, handler, dns[0])
self.single_sockets[sock.fileno()] = s
return s
def start_connection(self, dns, handler = None, randomize = False):
if handler is None:
handler = self.handler
if sys.version_info < (2,2):
s = self.start_connection_raw(dns,socket.AF_INET,handler)
else:
if self.ipv6_enable:
socktype = socket.AF_UNSPEC
else:
socktype = socket.AF_INET
try:
addrinfos = socket.getaddrinfo(dns[0], int(dns[1]),
socktype, socket.SOCK_STREAM)
except socket.error, e:
raise
except Exception, e:
raise socket.error(str(e))
if randomize:
shuffle(addrinfos)
for addrinfo in addrinfos:
try:
s = self.start_connection_raw(addrinfo[4],addrinfo[0],handler)
break
except:
pass
else:
raise socket.error('unable to connect')
return s
def _sleep(self):
sleep(1)
def handle_events(self, events):
for sock, event in events:
s = self.servers.get(sock)
if s:
if event & (POLLHUP | POLLERR) != 0:
self.poll.unregister(s)
s.close()
del self.servers[sock]
print "lost server socket"
elif len(self.single_sockets) < self.max_connects:
try:
newsock, addr = s.accept()
newsock.setblocking(0)
nss = SingleSocket(self, newsock, self.handler)
self.single_sockets[newsock.fileno()] = nss
self.poll.register(newsock, POLLIN)
self.handler.external_connection_made(nss)
except socket.error:
self._sleep()
else:
s = self.single_sockets.get(sock)
if not s:
continue
s.connected = True
if (event & (POLLHUP | POLLERR)):
self._close_socket(s)
continue
if (event & POLLIN):
try:
s.last_hit = clock()
data = s.socket.recv(100000)
if not data:
self._close_socket(s)
else:
s.handler.data_came_in(s, data)
except socket.error, e:
code, msg = e
if code != EWOULDBLOCK:
self._close_socket(s)
continue
if (event & POLLOUT) and s.socket and not s.is_flushed():
s.try_write()
if s.is_flushed():
s.handler.connection_flushed(s)
def close_dead(self):
while self.dead_from_write:
old = self.dead_from_write
self.dead_from_write = []
for s in old:
if s.socket:
self._close_socket(s)
def _close_socket(self, s):
s.close()
s.handler.connection_lost(s)
def do_poll(self, t):
r = self.poll.poll(t*timemult)
if r is None:
connects = len(self.single_sockets)
to_close = int(connects*0.05)+1 # close 5% of sockets
self.max_connects = connects-to_close
closelist = self.single_sockets.values()
shuffle(closelist)
closelist = closelist[:to_close]
for sock in closelist:
self._close_socket(sock)
return []
return r
def get_stats(self):
return { 'interfaces': self.interfaces,
'port': self.port,
'upnp': self.port_forwarded is not None }
def shutdown(self):
for ss in self.single_sockets.values():
try:
ss.close()
except:
pass
for server in self.servers.values():
try:
server.close()
except:
pass
if self.port_forwarded is not None:
UPnP_close_port(self.port_forwarded)