# Blame | Last modification | View Log | Download
# Written by Bram Cohen
# see LICENSE.txt for license information

"""Poll-driven management of non-blocking listening and peer sockets."""

import socket
from errno import EWOULDBLOCK, ECONNREFUSED, EHOSTUNREACH
try:
    from select import poll, error, POLLIN, POLLOUT, POLLERR, POLLHUP
    # select.poll() takes its timeout in milliseconds; do_poll() scales
    # its argument (seconds) by this factor.
    timemult = 1000
except ImportError:
    # Fallback poll implementation; its timeout is passed through unscaled
    # (presumably seconds -- TODO confirm against selectpoll).
    from selectpoll import poll, error, POLLIN, POLLOUT, POLLERR, POLLHUP
    timemult = 1
from time import sleep
from clock import clock
import sys
from random import shuffle, randrange
from natpunch import UPnP_open_port, UPnP_close_port

# Poll mask for "readable or writable". The name shadows the builtin
# `all` inside this module; it is kept for backward compatibility.
all = POLLIN | POLLOUT

UPnP_ERROR = "unable to forward port via UPnP"


class SingleSocket:
    """One non-blocking connection tracked by a SocketHandler.

    Outgoing data is queued in ``buffer`` and sent by ``try_write``;
    ``last_hit`` is refreshed whenever data is read so the owning handler
    can time the connection out.
    """

    def __init__(self, socket_handler, sock, handler, ip = None):
        self.socket_handler = socket_handler
        self.socket = sock
        self.handler = handler
        self.buffer = []
        self.last_hit = clock()
        self.fileno = sock.fileno()
        self.connected = False
        # consecutive send attempts that made no progress
        self.skipped = 0
        try:
            self.ip = self.socket.getpeername()[0]
        except Exception:
            # Peer address not available (e.g. connect still in progress):
            # fall back to the caller-supplied address, if any.
            if ip is None:
                self.ip = 'unknown'
            else:
                self.ip = ip

    def get_ip(self, real=False):
        """Return the peer IP; if ``real`` is true, refresh it from the socket first."""
        if real:
            try:
                self.ip = self.socket.getpeername()[0]
            except Exception:
                # keep the cached value when the live address is unavailable
                pass
        return self.ip

    def close(self):
        """Close the socket and remove it from the owning SocketHandler.

        Does not notify the handler; SocketHandler._close_socket does that.
        """
        assert self.socket
        self.connected = False
        sock = self.socket
        self.socket = None
        self.buffer = []
        del self.socket_handler.single_sockets[self.fileno]
        self.socket_handler.poll.unregister(sock)
        sock.close()

    def shutdown(self, val):
        """Shut down part or all of the connection (see ``socket.shutdown``)."""
        self.socket.shutdown(val)

    def is_flushed(self):
        """Return True when no outgoing data is queued."""
        return not self.buffer

    def write(self, s):
        """Queue ``s`` for sending and try to send it immediately if the queue was empty."""
        assert self.socket is not None
        self.buffer.append(s)
        # If data was already queued, a flush is already pending via POLLOUT.
        if len(self.buffer) == 1:
            self.try_write()

    def try_write(self):
        """Send as much queued data as the socket accepts without blocking.

        Any socket error other than EWOULDBLOCK, or three consecutive
        attempts that send nothing, marks the connection dead by appending
        it to ``socket_handler.dead_from_write`` (closed later by
        ``close_dead``). Otherwise the poll registration is updated to ask
        for POLLOUT only while data remains queued.
        """
        if self.connected:
            dead = False
            try:
                while self.buffer:
                    buf = self.buffer[0]
                    amount = self.socket.send(buf)
                    if amount == 0:
                        self.skipped += 1
                        break
                    self.skipped = 0
                    if amount != len(buf):
                        # partial send: keep the unsent tail at the head
                        self.buffer[0] = buf[amount:]
                        break
                    del self.buffer[0]
            except socket.error as e:
                try:
                    dead = e.args[0] != EWOULDBLOCK
                except IndexError:
                    dead = True
                self.skipped += 1
            if self.skipped >= 3:
                dead = True
            if dead:
                self.socket_handler.dead_from_write.append(self)
                return
        if self.buffer:
            self.socket_handler.poll.register(self.socket, all)
        else:
            self.socket_handler.poll.register(self.socket, POLLIN)

    def set_handler(self, handler):
        """Replace the object that receives this connection's callbacks."""
        self.handler = handler


class SocketHandler:
    """Owns a poll object, the listening sockets and all SingleSocket connections."""

    def __init__(self, timeout, ipv6_enable, readsize = 100000):
        self.timeout = timeout
        self.ipv6_enable = ipv6_enable
        # maximum number of bytes requested per recv() call
        self.readsize = readsize
        self.poll = poll()
        # {fileno: SingleSocket}
        self.single_sockets = {}
        self.dead_from_write = []
        self.max_connects = 1000
        self.port_forwarded = None
        # {fileno: listening socket}
        self.servers = {}
        # set by bind(); initialised here so get_stats() works before binding
        self.interfaces = []
        self.port = None

    def scan_for_timeouts(self):
        """Close every connection whose last read is older than ``timeout`` (clock() units)."""
        t = clock() - self.timeout
        tokill = [s for s in self.single_sockets.values() if s.last_hit < t]
        for k in tokill:
            # a callback run while closing an earlier one may have closed it
            if k.socket is not None:
                self._close_socket(k)

    def _close_servers(self):
        """Unregister from poll and close every listening socket (best effort)."""
        for server in list(self.servers.values()):
            try:
                self.poll.unregister(server)
            except Exception:
                pass
            try:
                server.close()
            except Exception:
                pass

    def bind(self, port, bind = '', reuse = False, ipv6_socket_style = 1, upnp = 0):
        """Open non-blocking listening sockets on ``port`` and register them for POLLIN.

        ``bind`` is an optional comma-separated list of addresses or host
        names; when empty, the IPv6 wildcard is used if IPv6 is enabled, and
        the IPv4 wildcard is used unless IPv6 is enabled and
        ``ipv6_socket_style`` is 0. ``reuse`` sets SO_REUSEADDR. If ``upnp``
        is true the port is also forwarded via UPnP.

        Raises socket.error when a socket cannot be opened or UPnP
        forwarding fails; every socket opened so far is then unregistered
        and closed.
        """
        port = int(port)
        addrinfos = []
        self.servers = {}
        self.interfaces = []
        # If bind != "" treat it as a comma-separated list and bind to all
        # listed addresses (IPs or host names); otherwise bind to the default
        # IPv6 and/or IPv4 wildcard address.
        if bind:
            if self.ipv6_enable:
                socktype = socket.AF_UNSPEC
            else:
                socktype = socket.AF_INET
            bind = bind.split(',')
            for addr in bind:
                if sys.version_info < (2,2):
                    addrinfos.append((socket.AF_INET, None, None, None, (addr, port)))
                else:
                    addrinfos.extend(socket.getaddrinfo(addr, port,
                                                        socktype, socket.SOCK_STREAM))
        else:
            if self.ipv6_enable:
                addrinfos.append([socket.AF_INET6, None, None, None, ('', port)])
            if not addrinfos or ipv6_socket_style != 0:
                addrinfos.append([socket.AF_INET, None, None, None, ('', port)])
        for addrinfo in addrinfos:
            server = None
            try:
                server = socket.socket(addrinfo[0], socket.SOCK_STREAM)
                if reuse:
                    server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                server.setblocking(0)
                server.bind(addrinfo[4])
                self.servers[server.fileno()] = server
                if bind:
                    self.interfaces.append(server.getsockname()[0])
                server.listen(64)
                self.poll.register(server, POLLIN)
            except socket.error as e:
                # The socket that just failed may not be in self.servers yet.
                if server is not None:
                    try:
                        server.close()
                    except Exception:
                        pass
                had_servers = bool(self.servers)
                self._close_servers()
                self.servers = {}
                if self.ipv6_enable and ipv6_socket_style == 0 and had_servers:
                    raise socket.error('blocked port (may require ipv6_binds_v4 to be set)')
                raise socket.error(str(e))
        if not self.servers:
            raise socket.error('unable to open server port')
        if upnp:
            if not UPnP_open_port(port):
                self._close_servers()
                # empty (not None) so shutdown()/handle_events() keep working
                self.servers = {}
                self.interfaces = []
                raise socket.error(UPnP_ERROR)
            self.port_forwarded = port
        self.port = port

    def find_and_bind(self, minport, maxport, bind = '', reuse = False,
                      ipv6_socket_style = 1, upnp = 0, randomizer = False):
        """Bind to a free port in [minport, maxport] and return it.

        Without ``randomizer`` every port is tried in ascending order. With
        it, at most 20 distinct ports are tried in random order. Raises
        socket.error with the last bind error if no port works.
        """
        e = 'maxport less than minport - no ports to check'
        if maxport-minport < 50 or not randomizer:
            portrange = list(range(minport, maxport+1))
            if randomizer:
                shuffle(portrange)
                portrange = portrange[:20]  # check a maximum of 20 ports
        else:
            portrange = []
            while len(portrange) < 20:
                listen_port = randrange(minport, maxport+1)
                if listen_port not in portrange:
                    portrange.append(listen_port)
        for listen_port in portrange:
            try:
                self.bind(listen_port, bind, reuse = reuse,
                          ipv6_socket_style = ipv6_socket_style, upnp = upnp)
                return listen_port
            except socket.error as err:
                # keep the error past the except block (Python 3 unbinds `err`)
                e = err
        raise socket.error(str(e))

    def set_handler(self, handler):
        """Set the default handler for new connections."""
        self.handler = handler

    def start_connection_raw(self, dns, socktype = socket.AF_INET, handler = None):
        """Start a non-blocking connect to address ``dns`` and return its SingleSocket.

        The connect completes asynchronously; connect_ex's return code is
        not inspected here. Raises socket.error (closing the new socket) if
        the connect attempt itself raises.
        """
        if handler is None:
            handler = self.handler
        sock = socket.socket(socktype, socket.SOCK_STREAM)
        try:
            sock.setblocking(0)
            sock.connect_ex(dns)
        except socket.error:
            sock.close()
            raise
        except Exception as e:
            sock.close()
            raise socket.error(str(e))
        self.poll.register(sock, POLLIN)
        s = SingleSocket(self, sock, handler, dns[0])
        self.single_sockets[sock.fileno()] = s
        return s

    def start_connection(self, dns, handler = None, randomize = False):
        """Resolve ``dns`` = (host, port) and connect to the first address that works.

        With ``randomize`` the resolved addresses are tried in random order.
        Raises socket.error if resolution fails or no address can be used.
        """
        if handler is None:
            handler = self.handler
        if sys.version_info < (2,2):
            s = self.start_connection_raw(dns, socket.AF_INET, handler)
        else:
            if self.ipv6_enable:
                socktype = socket.AF_UNSPEC
            else:
                socktype = socket.AF_INET
            try:
                addrinfos = socket.getaddrinfo(dns[0], int(dns[1]),
                                               socktype, socket.SOCK_STREAM)
            except socket.error:
                raise
            except Exception as e:
                raise socket.error(str(e))
            if randomize:
                shuffle(addrinfos)
            for addrinfo in addrinfos:
                try:
                    s = self.start_connection_raw(addrinfo[4], addrinfo[0], handler)
                    break
                except Exception:
                    # best effort: try the next resolved address
                    pass
            else:
                raise socket.error('unable to connect')
        return s

    def _sleep(self):
        """Back off for one second (used after a failed accept)."""
        sleep(1)

    def handle_events(self, events):
        """Dispatch (fileno, event) pairs returned by do_poll().

        Listening sockets accept new connections (up to ``max_connects``);
        connections read incoming data, flush queued output and are closed
        on hang-up, error or EOF.
        """
        for sock, event in events:
            s = self.servers.get(sock)
            if s is not None:
                if event & (POLLHUP | POLLERR):
                    self.poll.unregister(s)
                    s.close()
                    del self.servers[sock]
                    print("lost server socket")
                elif len(self.single_sockets) < self.max_connects:
                    try:
                        newsock, addr = s.accept()
                        newsock.setblocking(0)
                        nss = SingleSocket(self, newsock, self.handler)
                        self.single_sockets[newsock.fileno()] = nss
                        self.poll.register(newsock, POLLIN)
                        self.handler.external_connection_made(nss)
                    except socket.error:
                        self._sleep()
            else:
                s = self.single_sockets.get(sock)
                if s is None:
                    continue
                s.connected = True
                if event & (POLLHUP | POLLERR):
                    self._close_socket(s)
                    continue
                if event & POLLIN:
                    try:
                        s.last_hit = clock()
                        data = s.socket.recv(self.readsize)
                        if not data:
                            self._close_socket(s)
                        else:
                            s.handler.data_came_in(s, data)
                    except socket.error as e:
                        # socket.error does not always carry (code, msg)
                        code = e.args[0] if e.args else None
                        if code != EWOULDBLOCK:
                            self._close_socket(s)
                            continue
                if (event & POLLOUT) and s.socket and not s.is_flushed():
                    s.try_write()
                    if s.is_flushed():
                        s.handler.connection_flushed(s)

    def close_dead(self):
        """Close connections that try_write() marked dead.

        Repeats until the list stays empty, since callbacks run while
        closing may mark further connections dead.
        """
        while self.dead_from_write:
            old = self.dead_from_write
            self.dead_from_write = []
            for s in old:
                if s.socket:
                    self._close_socket(s)

    def _close_socket(self, s):
        """Close ``s`` and notify its handler via connection_lost()."""
        s.close()
        s.handler.connection_lost(s)

    def do_poll(self, t):
        """Wait up to ``t`` seconds for events and return the poll result.

        If the poll object returns None (presumably a resource-exhaustion
        signal from the fallback implementation -- TODO confirm), lower
        ``max_connects`` and close about 5% of connections at random, then
        return an empty list.
        """
        r = self.poll.poll(t*timemult)
        if r is None:
            connects = len(self.single_sockets)
            to_close = int(connects*0.05)+1  # close 5% of sockets
            self.max_connects = connects-to_close
            closelist = list(self.single_sockets.values())
            shuffle(closelist)
            closelist = closelist[:to_close]
            for sock in closelist:
                # a callback may already have closed this one
                if sock.socket is not None:
                    self._close_socket(sock)
            return []
        return r

    def get_stats(self):
        """Return bound interfaces, listening port and whether UPnP forwarding is active."""
        return { 'interfaces': self.interfaces,
                 'port': self.port,
                 'upnp': self.port_forwarded is not None }

    def shutdown(self):
        """Close all connections and listening sockets and remove any UPnP mapping."""
        # close() removes entries from single_sockets, so iterate a copy
        for ss in list(self.single_sockets.values()):
            try:
                ss.close()
            except Exception:
                pass
        self._close_servers()
        if self.port_forwarded is not None:
            UPnP_close_port(self.port_forwarded)