Subversion Repositories svnkaklik

Rev

Details | Last modification | View Log

Rev Author Line No. Line
36 kaklik 1
# Written by Bram Cohen
2
# see LICENSE.txt for license information
3
 
4
from cStringIO import StringIO
5
from binascii import b2a_hex
6
from socket import error as socketerror
7
from urllib import quote
8
from traceback import print_exc
9
try:
10
    True
11
except:
12
    True = 1
13
    False = 0
14
 
15
# Upper bound on outgoing connections whose handshake is still in progress
# (checked by IncompleteCounter.toomany()).
MAX_INCOMPLETE = 8

# Handshake prefix: the protocol name (sent after its length byte) followed
# by eight reserved option bytes, all zero.
protocol_name = 'BitTorrent protocol'
option_pattern = '\x00' * 8
19
 
20
def toint(s):
21
    return long(b2a_hex(s), 16)
22
 
23
def tobinary(i):
    """Encode ``i`` as a 4-byte big-endian string.

    The low three bytes are masked to 0..255; the top byte is not, so on
    Python 2 (where chr() only accepts 0..255) values of 2**32 or more and
    negative values raise ValueError instead of being silently truncated.
    """
    low_bytes = [chr((i >> shift) & 0xFF) for shift in (16, 8, 0)]
    return chr(i >> 24) + ''.join(low_bytes)
26
 
27
# hexmap[b] is the two-character upper-case hex spelling of byte value b
# (e.g. hexmap[0xAB] == 'AB'); tohex() uses it as a lookup table.
hexchars = '0123456789ABCDEF'
hexmap = []
for i in range(256):
    # Take the high nibble by shifting rather than "(i & 0xF0) / 16", which
    # yields a float (an invalid index) under true division.
    hexmap.append(hexchars[i >> 4] + hexchars[i & 0x0F])
31
 
32
def tohex(s):
    """Return ``s`` spelled as upper-case hex, two characters per byte."""
    return ''.join([hexmap[ord(ch)] for ch in s])
37
 
38
def make_readable(s):
    """Return a human-readable rendering of the id ``s``.

    An empty or None id gives ''.  If ``quote`` leaves the string untouched,
    it is returned wrapped in double quotes.  Anything ``quote`` would
    percent-escape is shown as upper-case hex instead.
    """
    if not s:
        return ''
    if '%' not in quote(s):
        return '"' + s + '"'
    return tohex(s)
44
 
45
 
46
class IncompleteCounter:
    """Counter of outgoing connections whose handshake has not finished."""

    def __init__(self):
        self.c = 0

    def increment(self):
        self.c = self.c + 1

    def decrement(self):
        self.c = self.c - 1

    def toomany(self):
        """True once the count has reached MAX_INCOMPLETE."""
        return not (self.c < MAX_INCOMPLETE)
55
 
56
# Module-wide count of outgoing connections still handshaking; shared by every
# Connection (which increments/decrements it) and Encoder (which throttles on it).
incompletecounter = IncompleteCounter()
57
 
58
 
59
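# Wire format: handshake = length byte + protocol name, 8 reserved bytes,
# 20-byte download id, 20-byte peer id; after that, a stream of
# [4-byte big-endian length, message] frames.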
60
 
61
class Connection:
    """One peer connection speaking the BitTorrent handshake and framing.

    Incoming bytes are parsed by a small state machine.  ``next_len`` is the
    number of bytes the current state needs, and ``next_func`` is the handler
    that consumes exactly that many bytes.  A handler returns the next
    ``(length, handler)`` pair, or None to close the connection.
    """

    def __init__(self, Encoder, connection, id, ext_handshake=False):
        """Register the connection and send whatever handshake part is due.

        Encoder       -- owning Encoder (ids, limits, raw server, counters).
        connection    -- underlying connection object; used via write(),
                         close(), get_ip() and is_flushed().
        id            -- expected peer id for connections we initiate
                         (possibly empty), None for incoming connections.
        ext_handshake -- True when the caller already consumed the remote
                         handshake up to the download id (see
                         Encoder.externally_handshaked_connection_made);
                         parsing then starts at the peer id.
        """
        self.Encoder = Encoder
        self.connection = connection
        self.connecter = Encoder.connecter
        self.id = id
        self.readable_id = make_readable(id)
        # Outgoing connections are created with an id (possibly empty);
        # incoming ones pass None.
        self.locally_initiated = (id is not None)
        self.complete = False
        # Replaced by the connecter's send_keepalive once the handshake is done.
        self.keepalive = lambda: None
        self.closed = False
        self.buffer = StringIO()
        if self.locally_initiated:
            # Counted as incomplete until the handshake finishes or it dies.
            incompletecounter.increment()
        if self.locally_initiated or ext_handshake:
            self.connection.write(chr(len(protocol_name)) + protocol_name +
                option_pattern + self.Encoder.download_id)
        if ext_handshake:
            self.Encoder.connecter.external_connection_made += 1
            self.connection.write(self.Encoder.my_id)
            self.next_len, self.next_func = 20, self.read_peer_id
        else:
            self.next_len, self.next_func = 1, self.read_header_len
        # Handshake deadline: _auto_close drops the connection if it is still
        # incomplete when this fires (delay unit presumably seconds).
        self.Encoder.raw_server.add_task(self._auto_close, 15)

    def get_ip(self, real=False):
        """Return the remote IP as reported by the underlying connection."""
        return self.connection.get_ip(real)

    def get_id(self):
        """Return the peer id (None/empty until learned from the handshake)."""
        return self.id

    def get_readable_id(self):
        """Return the printable form of the peer id (see make_readable)."""
        return self.readable_id

    def is_locally_initiated(self):
        """True if we opened this connection, False if the peer did."""
        return self.locally_initiated

    def is_flushed(self):
        """Delegate to the underlying connection's is_flushed()."""
        return self.connection.is_flushed()

    def read_header_len(self, s):
        """First byte: must equal the length of the protocol name."""
        if ord(s) != len(protocol_name):
            return None
        return len(protocol_name), self.read_header

    def read_header(self, s):
        """Protocol name: must match exactly."""
        if s != protocol_name:
            return None
        return 8, self.read_reserved

    def read_reserved(self, s):
        """Eight reserved option bytes: ignored."""
        return 20, self.read_download_id

    def read_download_id(self, s):
        """20-byte download id: must be ours.

        For an incoming connection this is the point where we count it as
        external and answer with our complete handshake, including our id.
        """
        if s != self.Encoder.download_id:
            return None
        if not self.locally_initiated:
            self.Encoder.connecter.external_connection_made += 1
            self.connection.write(chr(len(protocol_name)) + protocol_name +
                option_pattern + self.Encoder.download_id + self.Encoder.my_id)
        return 20, self.read_peer_id

    def read_peer_id(self, s):
        """20-byte peer id: validate it and finish the handshake.

        Connections without an expected id adopt the one received.  Those with
        an expected id require an exact match.  The Encoder can then veto the
        connection through got_id().  On success, an outgoing connection sends
        our peer id and stops counting as incomplete, and the connection is
        handed to the connecter.
        """
        if not self.id:
            self.id = s
            self.readable_id = make_readable(s)
        elif s != self.id:
            return None
        self.complete = self.Encoder.got_id(self)
        if not self.complete:
            return None
        if self.locally_initiated:
            self.connection.write(self.Encoder.my_id)
            incompletecounter.decrement()
        c = self.Encoder.connecter.connection_made(self)
        self.keepalive = c.send_keepalive
        return 4, self.read_len

    def read_len(self, s):
        """4-byte big-endian frame length; frames over max_len are rejected."""
        length = toint(s)
        if length > self.Encoder.max_len:
            return None
        return length, self.read_message

    def read_message(self, s):
        """Frame payload: non-empty payloads go to the connecter."""
        # Zero-length frames are not forwarded (presumably keepalives).
        if s:
            self.connecter.got_message(self, s)
        return 4, self.read_len

    def read_dead(self, s):
        """Terminal state after a handler error: any further input closes."""
        return None

    def _auto_close(self):
        """Handshake deadline callback: close if still incomplete."""
        if not self.complete:
            self.close()

    def close(self):
        """Close the underlying connection (once) and unregister it."""
        if not self.closed:
            self.connection.close()
            self.sever()

    def sever(self):
        """Mark closed and unregister from the Encoder without closing I/O.

        A completed connection is reported lost to the connecter.  An
        unfinished outgoing one is removed from the incomplete count.
        """
        self.closed = True
        del self.Encoder.connections[self.connection]
        if self.complete:
            self.connecter.connection_lost(self)
        elif self.locally_initiated:
            incompletecounter.decrement()

    def send_message_raw(self, message):
        """Write already-framed bytes unless the connection is closed."""
        if not self.closed:
            self.connection.write(message)

    def data_came_in(self, connection, s):
        """Feed received bytes into the parsing state machine.

        ``s`` may hold part of a field or several fields at once.  Partial
        fields accumulate in ``self.buffer`` until ``next_len`` bytes are
        available.  ``connection`` is not used here.
        """
        # Report received byte count (presumably rate measurement).
        self.Encoder.measurefunc(len(s))
        while True:
            if self.closed:
                return
            i = self.next_len - self.buffer.tell()
            if i > len(s):
                # Not enough for the current field yet; stash and wait.
                self.buffer.write(s)
                return
            self.buffer.write(s[:i])
            s = s[i:]
            m = self.buffer.getvalue()
            self.buffer.reset()
            self.buffer.truncate()
            try:
                x = self.next_func(m)
            except:
                # Park the connection in a state that rejects further input,
                # then let the error propagate.  The catch-all is deliberate
                # and re-raises, so nothing is swallowed.
                self.next_len, self.next_func = 1, self.read_dead
                raise
            if x is None:
                self.close()
                return
            self.next_len, self.next_func = x

    def connection_flushed(self, connection):
        """Forward flush notifications once the handshake is complete."""
        if self.complete:
            self.connecter.connection_flushed(self)

    def connection_lost(self, connection):
        """Remote side went away: unregister unless already severed."""
        if self.Encoder.connections.has_key(connection):
            self.sever()
206
 
207
 
208
class Encoder:
209
    def __init__(self, connecter, raw_server, my_id, max_len,
210
            schedulefunc, keepalive_delay, download_id, 
211
            measurefunc, config):
212
        self.raw_server = raw_server
213
        self.connecter = connecter
214
        self.my_id = my_id
215
        self.max_len = max_len
216
        self.schedulefunc = schedulefunc
217
        self.keepalive_delay = keepalive_delay
218
        self.download_id = download_id
219
        self.measurefunc = measurefunc
220
        self.config = config
221
        self.connections = {}
222
        self.banned = {}
223
        self.to_connect = []
224
        self.paused = False
225
        if self.config['max_connections'] == 0:
226
            self.max_connections = 2 ** 30
227
        else:
228
            self.max_connections = self.config['max_connections']
229
        schedulefunc(self.send_keepalives, keepalive_delay)
230
 
231
    def send_keepalives(self):
232
        self.schedulefunc(self.send_keepalives, self.keepalive_delay)
233
        if self.paused:
234
            return
235
        for c in self.connections.values():
236
            c.keepalive()
237
 
238
    def start_connections(self, list):
239
        if not self.to_connect:
240
            self.raw_server.add_task(self._start_connection_from_queue)
241
        self.to_connect = list
242
 
243
    def _start_connection_from_queue(self):
244
        if self.connecter.external_connection_made:
245
            max_initiate = self.config['max_initiate']
246
        else:
247
            max_initiate = int(self.config['max_initiate']*1.5)
248
        cons = len(self.connections)
249
        if cons >= self.max_connections or cons >= max_initiate:
250
            delay = 60
251
        elif self.paused or incompletecounter.toomany():
252
            delay = 1
253
        else:
254
            delay = 0
255
            dns, id = self.to_connect.pop(0)
256
            self.start_connection(dns, id)
257
        if self.to_connect:
258
            self.raw_server.add_task(self._start_connection_from_queue, delay)
259
 
260
    def start_connection(self, dns, id):
261
        if ( self.paused
262
             or len(self.connections) >= self.max_connections
263
             or id == self.my_id
264
             or self.banned.has_key(dns[0]) ):
265
            return True
266
        for v in self.connections.values():
267
            if v is None:
268
                continue
269
            if id and v.id == id:
270
                return True
271
            ip = v.get_ip(True)
272
            if self.config['security'] and ip != 'unknown' and ip == dns[0]:
273
                return True
274
        try:
275
            c = self.raw_server.start_connection(dns)
276
            con = Connection(self, c, id)
277
            self.connections[c] = con
278
            c.set_handler(con)
279
        except socketerror:
280
            return False
281
        return True
282
 
283
    def _start_connection(self, dns, id):
284
        def foo(self=self, dns=dns, id=id):
285
            self.start_connection(dns, id)
286
 
287
        self.schedulefunc(foo, 0)
288
 
289
    def got_id(self, connection):
290
        if connection.id == self.my_id:
291
            self.connecter.external_connection_made -= 1
292
            return False
293
        ip = connection.get_ip(True)
294
        if self.config['security'] and self.banned.has_key(ip):
295
            return False
296
        for v in self.connections.values():
297
            if connection is not v:
298
                if connection.id == v.id:
299
                    return False
300
                if self.config['security'] and ip != 'unknown' and ip == v.get_ip(True):
301
                    v.close()
302
        return True
303
 
304
    def external_connection_made(self, connection):
305
        if self.paused or len(self.connections) >= self.max_connections:
306
            connection.close()
307
            return False
308
        con = Connection(self, connection, None)
309
        self.connections[connection] = con
310
        connection.set_handler(con)
311
        return True
312
 
313
    def externally_handshaked_connection_made(self, connection, options, already_read):
314
        if self.paused or len(self.connections) >= self.max_connections:
315
            connection.close()
316
            return False
317
        con = Connection(self, connection, None, True)
318
        self.connections[connection] = con
319
        connection.set_handler(con)
320
        if already_read:
321
            con.data_came_in(con, already_read)
322
        return True
323
 
324
    def close_all(self):
325
        for c in self.connections.values():
326
            c.close()
327
        self.connections = {}
328
 
329
    def ban(self, ip):
330
        self.banned[ip] = 1
331
 
332
    def pause(self, flag):
333
        self.paused = flag