Subversion Repositories svnkaklik

Rev

Go to most recent revision | Details | Last modification | View Log

Rev Author Line No. Line
36 kaklik 1
# Written by Bram Cohen
2
# see LICENSE.txt for license information
3
 
4
import socket
5
from errno import EWOULDBLOCK, ECONNREFUSED, EHOSTUNREACH
6
try:
7
    from select import poll, error, POLLIN, POLLOUT, POLLERR, POLLHUP
8
    timemult = 1000
9
except ImportError:
10
    from selectpoll import poll, error, POLLIN, POLLOUT, POLLERR, POLLHUP
11
    timemult = 1
12
from time import sleep
13
from clock import clock
14
import sys
15
from random import shuffle, randrange
16
from natpunch import UPnP_open_port, UPnP_close_port
17
# from BT1.StreamCheck import StreamCheck
18
# import inspect
19
try:
20
    True
21
except:
22
    True = 1
23
    False = 0
24
 
25
all = POLLIN | POLLOUT
26
 
27
UPnP_ERROR = "unable to forward port via UPnP"
28
 
29
class SingleSocket:
30
    def __init__(self, socket_handler, sock, handler, ip = None):
31
        self.socket_handler = socket_handler
32
        self.socket = sock
33
        self.handler = handler
34
        self.buffer = []
35
        self.last_hit = clock()
36
        self.fileno = sock.fileno()
37
        self.connected = False
38
        self.skipped = 0
39
#        self.check = StreamCheck()
40
        try:
41
            self.ip = self.socket.getpeername()[0]
42
        except:
43
            if ip is None:
44
                self.ip = 'unknown'
45
            else:
46
                self.ip = ip
47
 
48
    def get_ip(self, real=False):
49
        if real:
50
            try:
51
                self.ip = self.socket.getpeername()[0]
52
            except:
53
                pass
54
        return self.ip
55
 
56
    def close(self):
57
        '''
58
        for x in xrange(5,0,-1):
59
            try:
60
                f = inspect.currentframe(x).f_code
61
                print (f.co_filename,f.co_firstlineno,f.co_name)
62
                del f
63
            except:
64
                pass
65
        print ''
66
        '''
67
        assert self.socket
68
        self.connected = False
69
        sock = self.socket
70
        self.socket = None
71
        self.buffer = []
72
        del self.socket_handler.single_sockets[self.fileno]
73
        self.socket_handler.poll.unregister(sock)
74
        sock.close()
75
 
76
    def shutdown(self, val):
77
        self.socket.shutdown(val)
78
 
79
    def is_flushed(self):
80
        return not self.buffer
81
 
82
    def write(self, s):
83
#        self.check.write(s)
84
        assert self.socket is not None
85
        self.buffer.append(s)
86
        if len(self.buffer) == 1:
87
            self.try_write()
88
 
89
    def try_write(self):
90
        if self.connected:
91
            dead = False
92
            try:
93
                while self.buffer:
94
                    buf = self.buffer[0]
95
                    amount = self.socket.send(buf)
96
                    if amount == 0:
97
                        self.skipped += 1
98
                        break
99
                    self.skipped = 0
100
                    if amount != len(buf):
101
                        self.buffer[0] = buf[amount:]
102
                        break
103
                    del self.buffer[0]
104
            except socket.error, e:
105
                try:
106
                    dead = e[0] != EWOULDBLOCK
107
                except:
108
                    dead = True
109
                self.skipped += 1
110
            if self.skipped >= 3:
111
                dead = True
112
            if dead:
113
                self.socket_handler.dead_from_write.append(self)
114
                return
115
        if self.buffer:
116
            self.socket_handler.poll.register(self.socket, all)
117
        else:
118
            self.socket_handler.poll.register(self.socket, POLLIN)
119
 
120
    def set_handler(self, handler):
121
        self.handler = handler
122
 
123
class SocketHandler:
124
    def __init__(self, timeout, ipv6_enable, readsize = 100000):
125
        self.timeout = timeout
126
        self.ipv6_enable = ipv6_enable
127
        self.readsize = readsize
128
        self.poll = poll()
129
        # {socket: SingleSocket}
130
        self.single_sockets = {}
131
        self.dead_from_write = []
132
        self.max_connects = 1000
133
        self.port_forwarded = None
134
        self.servers = {}
135
 
136
    def scan_for_timeouts(self):
137
        t = clock() - self.timeout
138
        tokill = []
139
        for s in self.single_sockets.values():
140
            if s.last_hit < t:
141
                tokill.append(s)
142
        for k in tokill:
143
            if k.socket is not None:
144
                self._close_socket(k)
145
 
146
    def bind(self, port, bind = '', reuse = False, ipv6_socket_style = 1, upnp = 0):
147
        port = int(port)
148
        addrinfos = []
149
        self.servers = {}
150
        self.interfaces = []
151
        # if bind != "" thread it as a comma seperated list and bind to all
152
        # addresses (can be ips or hostnames) else bind to default ipv6 and
153
        # ipv4 address
154
        if bind:
155
            if self.ipv6_enable:
156
                socktype = socket.AF_UNSPEC
157
            else:
158
                socktype = socket.AF_INET
159
            bind = bind.split(',')
160
            for addr in bind:
161
                if sys.version_info < (2,2):
162
                    addrinfos.append((socket.AF_INET, None, None, None, (addr, port)))
163
                else:
164
                    addrinfos.extend(socket.getaddrinfo(addr, port,
165
                                               socktype, socket.SOCK_STREAM))
166
        else:
167
            if self.ipv6_enable:
168
                addrinfos.append([socket.AF_INET6, None, None, None, ('', port)])
169
            if not addrinfos or ipv6_socket_style != 0:
170
                addrinfos.append([socket.AF_INET, None, None, None, ('', port)])
171
        for addrinfo in addrinfos:
172
            try:
173
                server = socket.socket(addrinfo[0], socket.SOCK_STREAM)
174
                if reuse:
175
                    server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
176
                server.setblocking(0)
177
                server.bind(addrinfo[4])
178
                self.servers[server.fileno()] = server
179
                if bind:
180
                    self.interfaces.append(server.getsockname()[0])
181
                server.listen(64)
182
                self.poll.register(server, POLLIN)
183
            except socket.error, e:
184
                for server in self.servers.values():
185
                    try:
186
                        server.close()
187
                    except:
188
                        pass
189
                if self.ipv6_enable and ipv6_socket_style == 0 and self.servers:
190
                    raise socket.error('blocked port (may require ipv6_binds_v4 to be set)')
191
                raise socket.error(str(e))
192
        if not self.servers:
193
            raise socket.error('unable to open server port')
194
        if upnp:
195
            if not UPnP_open_port(port):
196
                for server in self.servers.values():
197
                    try:
198
                        server.close()
199
                    except:
200
                        pass
201
                    self.servers = None
202
                    self.interfaces = None
203
                raise socket.error(UPnP_ERROR)
204
            self.port_forwarded = port
205
        self.port = port
206
 
207
    def find_and_bind(self, minport, maxport, bind = '', reuse = False,
208
                      ipv6_socket_style = 1, upnp = 0, randomizer = False):
209
        e = 'maxport less than minport - no ports to check'
210
        if maxport-minport < 50 or not randomizer:
211
            portrange = range(minport, maxport+1)
212
            if randomizer:
213
                shuffle(portrange)
214
                portrange = portrange[:20]  # check a maximum of 20 ports
215
        else:
216
            portrange = []
217
            while len(portrange) < 20:
218
                listen_port = randrange(minport, maxport+1)
219
                if not listen_port in portrange:
220
                    portrange.append(listen_port)
221
        for listen_port in portrange:
222
            try:
223
                self.bind(listen_port, bind,
224
                               ipv6_socket_style = ipv6_socket_style, upnp = upnp)
225
                return listen_port
226
            except socket.error, e:
227
                pass
228
        raise socket.error(str(e))
229
 
230
 
231
    def set_handler(self, handler):
232
        self.handler = handler
233
 
234
 
235
    def start_connection_raw(self, dns, socktype = socket.AF_INET, handler = None):
236
        if handler is None:
237
            handler = self.handler
238
        sock = socket.socket(socktype, socket.SOCK_STREAM)
239
        sock.setblocking(0)
240
        try:
241
            sock.connect_ex(dns)
242
        except socket.error:
243
            raise
244
        except Exception, e:
245
            raise socket.error(str(e))
246
        self.poll.register(sock, POLLIN)
247
        s = SingleSocket(self, sock, handler, dns[0])
248
        self.single_sockets[sock.fileno()] = s
249
        return s
250
 
251
 
252
    def start_connection(self, dns, handler = None, randomize = False):
253
        if handler is None:
254
            handler = self.handler
255
        if sys.version_info < (2,2):
256
            s = self.start_connection_raw(dns,socket.AF_INET,handler)
257
        else:
258
            if self.ipv6_enable:
259
                socktype = socket.AF_UNSPEC
260
            else:
261
                socktype = socket.AF_INET
262
            try:
263
                addrinfos = socket.getaddrinfo(dns[0], int(dns[1]),
264
                                               socktype, socket.SOCK_STREAM)
265
            except socket.error, e:
266
                raise
267
            except Exception, e:
268
                raise socket.error(str(e))
269
            if randomize:
270
                shuffle(addrinfos)
271
            for addrinfo in addrinfos:
272
                try:
273
                    s = self.start_connection_raw(addrinfo[4],addrinfo[0],handler)
274
                    break
275
                except:
276
                    pass
277
            else:
278
                raise socket.error('unable to connect')
279
        return s
280
 
281
 
282
    def _sleep(self):
283
        sleep(1)
284
 
285
    def handle_events(self, events):
286
        for sock, event in events:
287
            s = self.servers.get(sock)
288
            if s:
289
                if event & (POLLHUP | POLLERR) != 0:
290
                    self.poll.unregister(s)
291
                    s.close()
292
                    del self.servers[sock]
293
                    print "lost server socket"
294
                elif len(self.single_sockets) < self.max_connects:
295
                    try:
296
                        newsock, addr = s.accept()
297
                        newsock.setblocking(0)
298
                        nss = SingleSocket(self, newsock, self.handler)
299
                        self.single_sockets[newsock.fileno()] = nss
300
                        self.poll.register(newsock, POLLIN)
301
                        self.handler.external_connection_made(nss)
302
                    except socket.error:
303
                        self._sleep()
304
            else:
305
                s = self.single_sockets.get(sock)
306
                if not s:
307
                    continue
308
                s.connected = True
309
                if (event & (POLLHUP | POLLERR)):
310
                    self._close_socket(s)
311
                    continue
312
                if (event & POLLIN):
313
                    try:
314
                        s.last_hit = clock()
315
                        data = s.socket.recv(100000)
316
                        if not data:
317
                            self._close_socket(s)
318
                        else:
319
                            s.handler.data_came_in(s, data)
320
                    except socket.error, e:
321
                        code, msg = e
322
                        if code != EWOULDBLOCK:
323
                            self._close_socket(s)
324
                            continue
325
                if (event & POLLOUT) and s.socket and not s.is_flushed():
326
                    s.try_write()
327
                    if s.is_flushed():
328
                        s.handler.connection_flushed(s)
329
 
330
    def close_dead(self):
331
        while self.dead_from_write:
332
            old = self.dead_from_write
333
            self.dead_from_write = []
334
            for s in old:
335
                if s.socket:
336
                    self._close_socket(s)
337
 
338
    def _close_socket(self, s):
339
        s.close()
340
        s.handler.connection_lost(s)
341
 
342
    def do_poll(self, t):
343
        r = self.poll.poll(t*timemult)
344
        if r is None:
345
            connects = len(self.single_sockets)
346
            to_close = int(connects*0.05)+1 # close 5% of sockets
347
            self.max_connects = connects-to_close
348
            closelist = self.single_sockets.values()
349
            shuffle(closelist)
350
            closelist = closelist[:to_close]
351
            for sock in closelist:
352
                self._close_socket(sock)
353
            return []
354
        return r     
355
 
356
    def get_stats(self):
357
        return { 'interfaces': self.interfaces,
358
                 'port': self.port,
359
                 'upnp': self.port_forwarded is not None }
360
 
361
 
362
    def shutdown(self):
363
        for ss in self.single_sockets.values():
364
            try:
365
                ss.close()
366
            except:
367
                pass
368
        for server in self.servers.values():
369
            try:
370
                server.close()
371
            except:
372
                pass
373
        if self.port_forwarded is not None:
374
            UPnP_close_port(self.port_forwarded)
375