UniqueSessionIdManager: Support for delaying releases of session ids, and picking...
[bitcoin:eloipool.git] / networkserver.py
1 # Eloipool - Python Bitcoin pool server
2 # Copyright (C) 2011-2012  Luke Dashjr <luke-jr+eloipool@utopios.org>
3 #
4 # This program is free software: you can redistribute it and/or modify
5 # it under the terms of the GNU Affero General Public License as
6 # published by the Free Software Foundation, either version 3 of the
7 # License, or (at your option) any later version.
8 #
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 # GNU Affero General Public License for more details.
13 #
14 # You should have received a copy of the GNU Affero General Public License
15 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
16
17 import asynchat
18 import logging
19 import os
20 import select
21 import socket
22 import threading
23 from time import time
24 import traceback
25 from util import ScheduleDict, WithNoop, tryErr
26
# Mask registered for every socket: readable (normal or priority data) plus
# error and hang-up conditions, all of which serve_forever routes to the
# handler's handle_read().
EPOLL_READ = select.EPOLLHUP | select.EPOLLERR | select.EPOLLPRI | select.EPOLLIN
# Extra mask OR-ed in while a handler has unsent output buffered.
EPOLL_WRITE = select.EPOLLOUT
29
class SocketHandler:
    """Base class for a single client connection served by AsyncSocketServer.

    Input is accumulated in ``ac_in_buffer`` and split on the current
    terminator (asynchat-compatible ``set_terminator``/``get_terminator``).
    Each piece goes to ``collect_incoming_data`` and ``found_terminator``.
    Output is sent immediately when possible and otherwise buffered in
    ``wbuf`` until epoll reports the socket writable.  ``wbuf is None``
    marks a connection torn down by ``handle_close``.

    NOTE(review): ``found_terminator``, ``logger``, ``use_encoding`` and
    ``encoding`` are not defined here; presumably subclasses supply them.
    """

    ac_in_buffer_size = 4096
    ac_out_buffer_size = 4096

    def handle_close(self):
        """Discard any pending output and close the connection now."""
        # Clearing wbuf first stops close() from deferring until output drains.
        self.wbuf = None
        self.close()

    def handle_error(self):
        """Log the current exception at debug level and close the connection."""
        self.logger.debug(traceback.format_exc())
        self.handle_close()

    def handle_read(self):
        """Read available data from the socket and process complete messages.

        Socket errors are routed to ``handle_error``.  Input arriving after
        a deferred close (``closeme``) is ignored.
        """
        try:
            data = self.recv(self.ac_in_buffer_size)
        except socket.error:
            self.handle_error()
            return

        if self.closeme:
            # All input is ignored from sockets we have "closed"
            return

        if isinstance(data, str) and self.use_encoding:
            # Encode the received text itself (previously passed the ``str``
            # type by mistake, which raised TypeError).
            data = bytes(data, self.encoding)
        self.ac_in_buffer = self.ac_in_buffer + data

        # Exposed on the server, presumably for debugging/inspection.
        self.server.lastReadbuf = self.ac_in_buffer

        self.handle_readbuf()

    collect_incoming_data = asynchat.async_chat._collect_incoming_data
    get_terminator = asynchat.async_chat.get_terminator
    set_terminator = asynchat.async_chat.set_terminator

    def handle_readbuf(self):
        """Consume ``ac_in_buffer`` according to the current terminator.

        The terminator may be falsy (collect everything), an int (collect
        that many bytes), or a bytes/str value or tuple of them.  A trailing
        partial terminator is left in the buffer for the next read.
        """
        while self.ac_in_buffer:
            lb = len(self.ac_in_buffer)
            terminator = self.get_terminator()
            if not terminator:
                # no terminator, collect it all
                self.collect_incoming_data(self.ac_in_buffer)
                self.ac_in_buffer = b''
            elif isinstance(terminator, int):
                # numeric terminator
                n = terminator
                if lb < n:
                    self.collect_incoming_data(self.ac_in_buffer)
                    self.ac_in_buffer = b''
                    self.terminator = self.terminator - lb
                else:
                    self.collect_incoming_data(self.ac_in_buffer[:n])
                    self.ac_in_buffer = self.ac_in_buffer[n:]
                    self.terminator = 0
                    self.found_terminator()
            else:
                # 3 cases:
                # 1) end of buffer matches terminator exactly:
                #    collect data, transition
                # 2) end of buffer matches some prefix:
                #    collect data to the prefix
                # 3) end of buffer does not match any prefix:
                #    collect data
                # NOTE: this supports multiple different terminators, but
                #       NOT ones that are prefixes of others...
                if isinstance(self.ac_in_buffer, type(terminator)):
                    terminator = (terminator,)
                termidx = tuple(map(self.ac_in_buffer.find, terminator))
                try:
                    index = min(x for x in termidx if x >= 0)
                except ValueError:
                    index = -1
                if index != -1:
                    # we found the terminator
                    if index > 0:
                        # don't bother reporting the empty string (source of subtle bugs)
                        self.collect_incoming_data(self.ac_in_buffer[:index])
                    specific_terminator = terminator[termidx.index(index)]
                    terminator_len = len(specific_terminator)
                    self.ac_in_buffer = self.ac_in_buffer[index + terminator_len:]
                    # This does the Right Thing if the terminator is changed here.
                    self.found_terminator()
                else:
                    # check for a prefix of the terminator
                    buf = self.ac_in_buffer
                    termidx = tuple(asynchat.find_prefix_at_end(buf, t) for t in terminator)
                    index = max(termidx)
                    if index:
                        if index != lb:
                            # we found a prefix, collect up to the prefix
                            self.collect_incoming_data(self.ac_in_buffer[:-index])
                            self.ac_in_buffer = self.ac_in_buffer[-index:]
                        break
                    else:
                        # no prefix, collect it all
                        self.collect_incoming_data(self.ac_in_buffer)
                        self.ac_in_buffer = b''

    def push(self, data):
        """Queue ``data`` for sending, writing as much as possible right away.

        Any unsent remainder is appended to ``wbuf`` and the socket is
        registered for write events so ``handle_write`` can flush it.  Data
        pushed after ``handle_close`` (``wbuf is None``) is discarded, the
        same way ``handle_write`` ignores such connections.
        """
        if self.wbuf is None:
            # Connection already torn down; previously crashed on len(None).
            return
        if not self.wbuf:
            # Try to send as much as we can immediately
            try:
                bs = self.socket.send(data)
            except socket.error:
                # Chances are we'll fail later, but anyway...
                bs = 0
            data = data[bs:]
            if not data:
                return
        self.wbuf += data
        self.server.register_socket_m(self.fd, EPOLL_READ | EPOLL_WRITE)

    def handle_timeout(self):
        """Scheduled-task callback: close the connection."""
        self.close()

    def handle_write(self):
        """Flush buffered output; stop write polling (or finish a deferred close) once empty."""
        if self.wbuf is None:
            # Socket was just closed by remote peer
            return
        bs = self.socket.send(self.wbuf)
        self.wbuf = self.wbuf[bs:]
        if not self.wbuf:
            if self.closeme:
                self.close()
                return
            self.server.register_socket_m(self.fd, EPOLL_READ)

    recv = asynchat.async_chat.recv

    def close(self):
        """Close the connection, or defer until buffered output is flushed.

        The OS socket is always released and ``fd`` reset to -1, even if
        unregistering from epoll fails (that error is still propagated).
        """
        if self.wbuf:
            self.closeme = True
            return
        if self.fd == -1:
            # Already closed
            return
        self.server.connections.pop(id(self), None)
        try:
            self.server.unregister_socket(self.fd)
        finally:
            # Previously an unregister failure skipped these and leaked the fd.
            self.fd = -1
            self.socket.close()
            self.changeTask(None)

    def boot(self):
        """Close the connection (subject to pending output) and drop buffered input."""
        self.close()
        self.ac_in_buffer = b''

    def changeTask(self, f, t = None):
        """Replace this connection's scheduled task with ``f`` at time ``t``.

        A falsy ``f`` just cancels the current task.  Errors from the task
        are routed back to this handler.
        """
        tryErr(self.server.rmSchedule, self._Task, IgnoredExceptions=KeyError)
        if f:
            self._Task = self.server.schedule(f, t, errHandler=self)
        else:
            self._Task = None

    def __init__(self, server, sock, addr):
        """Register ``sock`` with ``server`` and schedule a 15-second timeout.

        The timeout task (``handle_timeout``) closes the connection unless it
        is replaced via ``changeTask``.
        """
        self.ac_in_buffer = b''
        self.incoming = []
        self.wbuf = b''
        self.closeme = False
        self.server = server
        self.socket = sock
        self.addr = addr
        self._Task = None
        self.fd = sock.fileno()
        server.register_socket(self.fd, self)
        server.connections[id(self)] = self
        self.changeTask(self.handle_timeout, time() + 15)

    @classmethod
    def _register(cls, scls):
        """Copy public attributes of ``scls`` onto ``cls`` (mixin-style).

        ``final_init`` is chained so that both ``cls``'s and ``scls``'s
        versions run; this requires ``cls`` to already define ``final_init``.
        """
        for a in dir(scls):
            if a == 'final_init':
                f = lambda self, x=getattr(cls, a), y=getattr(scls, a): (x(self), y(self))
                setattr(cls, a, f)
                continue
            if a[0] == '_':
                continue
            setattr(cls, a, getattr(scls, a))
209
class NetworkListener:
    """Listening socket whose accepted connections are handed to ``server``.

    The socket is registered with the server's epoll loop.  Each readable
    event accepts one connection and wraps it in
    ``server.RequestHandlerClass``, unless ``server.rejecting`` is set.
    """

    logger = logging.getLogger('SocketListener')

    def __init__(self, server, server_address, address_family = socket.AF_INET6):
        self.server = server
        self.server_address = server_address
        self.address_family = address_family
        # Setup goes through util.tryErr, which presumably logs failures
        # (tagged with the address) instead of raising -- confirm in util.
        tryErr(self.setup_socket, server_address, Logger=self.logger, ErrorMsg=server_address)

    def _makebind_py(self, server_address):
        """Create a non-blocking stream socket bound to ``server_address``.

        The socket is closed again if configuration or ``bind`` fails.
        """
        sock = socket.socket(self.address_family, socket.SOCK_STREAM)
        try:
            sock.setblocking(0)
            try:
                sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            except socket.error:
                pass
            sock.bind(server_address)
        except BaseException:
            # Don't leak the descriptor when e.g. bind() hits EADDRINUSE.
            sock.close()
            raise
        return sock

    def _makebind_su(self, server_address):
        """Fallback binder using the external ``bindservice`` helper (IPv6 only).

        ``server_address`` is a (node, service) pair; falsy parts become ''.
        """
        if self.address_family != socket.AF_INET6:
            raise NotImplementedError

        from bindservice import bindservice
        (node, service) = server_address
        if not node: node = ''
        if not service: service = ''
        fd = bindservice(str(node), str(service))
        sock = socket.fromfd(fd, socket.AF_INET6, socket.SOCK_STREAM)
        sock.setblocking(0)
        return sock

    def _makebind(self, *a, **ka):
        """Bind via ``_makebind_py``, falling back to ``_makebind_su``.

        If the fallback also fails, the error from ``_makebind_py`` is raised.
        """
        try:
            return self._makebind_py(*a, **ka)
        except Exception:
            try:
                return self._makebind_su(*a, **ka)
            except Exception:
                # Fallback unavailable or failed too; report the primary error.
                pass
            raise

    def setup_socket(self, server_address):
        """Bind, listen and register the listening socket with the server.

        On failure the socket is closed and the error propagated.
        """
        sock = self._makebind(server_address)
        try:
            sock.listen(100)
            self.server.register_socket(sock.fileno(), self)
        except BaseException:
            sock.close()
            raise
        self.socket = sock

    def handle_read(self):
        """Accept one pending connection and hand it to the request handler class."""
        server = self.server
        conn, addr = self.socket.accept()
        if server.rejecting:
            conn.close()
            return
        conn.setblocking(False)
        # Presumably a SocketHandler subclass, whose __init__ registers the
        # new connection with the server; the instance needn't be kept here.
        server.RequestHandlerClass(server, conn, addr)

    def handle_error(self):
        # Ignore errors... like socket closing on the queue
        pass
270
271 class _Waker:
272         def __init__(self, server, fd):
273                 self.server = server
274                 self.fd = fd
275                 self.logger = logging.getLogger('Waker for %s' % (server.__class__.__name__,))
276         
277         def handle_read(self):
278                 data = os.read(self.fd, 1)
279                 if not data:
280                         self.logger.error('Got EOF on socket')
281                 self.logger.debug('Read wakeup')
282
class AsyncSocketServer:
    """Epoll-based event loop with a time-ordered task scheduler.

    Subclasses can set ``waker = True`` to get a self-pipe usable via
    ``wakeup()``.  ``schMT = True`` guards the scheduler with a real
    ``threading.Lock``, presumably so tasks can be scheduled from other
    threads.
    """

    logger = logging.getLogger('SocketServer')

    waker = False
    schMT = False

    def __init__(self, RequestHandlerClass):
        if not hasattr(self, 'ServerName'):
            self.ServerName = 'Eloipool'

        self.RequestHandlerClass = RequestHandlerClass

        self.running = False
        self.keepgoing = True
        self.rejecting = False

        self._epoll = select.epoll()
        self._fd = {}  # fd -> handler object receiving its events
        self.connections = {}

        self._sch = ScheduleDict()
        self._schEH = {}  # id(task) -> handler notified if the task fails
        if self.schMT:
            self._schLock = threading.Lock()
        else:
            self._schLock = WithNoop

        self.TrustedForwarders = ()

        if self.waker:
            (r, w) = os.pipe()
            o = _Waker(self, r)
            self.register_socket(r, o)
            self.waker = w

    def register_socket(self, fd, o, eventmask = EPOLL_READ):
        """Start polling ``fd`` for ``eventmask`` and dispatch its events to ``o``."""
        self._epoll.register(fd, eventmask)
        self._fd[fd] = o

    def register_socket_m(self, fd, eventmask):
        """Change the event mask of an already-registered ``fd``.

        Raises socket.error (carrying the original errno) on failure.
        """
        try:
            self._epoll.modify(fd, eventmask)
        except IOError as e:
            # Previously raised an empty socket.error, losing errno/message.
            raise socket.error(*e.args) from e

    def unregister_socket(self, fd):
        """Stop polling ``fd`` and forget its handler.

        Raises KeyError for an unknown fd, or socket.error if epoll fails.
        """
        del self._fd[fd]
        try:
            self._epoll.unregister(fd)
        except IOError as e:
            raise socket.error(*e.args) from e

    def schedule(self, task, startTime, errHandler=None):
        """Run ``task()`` at ``startTime``.

        On failure ``errHandler`` (if given) is notified.  Returns ``task``.
        """
        with self._schLock:
            self._sch[task] = startTime
            if errHandler:
                self._schEH[id(task)] = errHandler
        return task

    def rmSchedule(self, task):
        """Cancel a scheduled ``task``; raises KeyError if it is not scheduled."""
        with self._schLock:
            del self._sch[task]
            self._schEH.pop(id(task), None)

    def pre_schedule(self):
        """Hook run at the start of every loop iteration; no-op by default."""
        pass

    def wakeup(self):
        """Interrupt a blocking poll from another context via the waker pipe."""
        if not self.waker:
            raise NotImplementedError('Class `%s\' did not enable waker' % (self.__class__.__name__,))
        os.write(self.waker, b'\1')  # to break out of the epoll

    def final_init(self):
        """Hook run once when ``serve_forever`` starts; no-op by default."""
        pass

    def boot_all(self):
        """Boot every current connection, suppressing per-connection errors via tryErr."""
        conns = tuple(self.connections.values())
        for c in conns:
            tryErr(c.boot)

    def serve_forever(self):
        """Run the loop until ``keepgoing`` becomes false.

        Each iteration calls ``pre_schedule`` and runs every due scheduled
        task; the time until the next pending task becomes the poll timeout.
        It then polls epoll and dispatches read/write events to the
        registered handlers.  ``self.doing`` records the current phase.
        """
        self.running = True
        self.final_init()
        while self.keepgoing:
            self.doing = 'pre-schedule'
            self.pre_schedule()
            self.doing = 'schedule'
            timeout = -1  # block indefinitely unless a task is pending
            if len(self._sch):
                timeNow = time()
                while True:
                    with self._schLock:
                        if not len(self._sch):
                            timeout = -1
                            break
                        timeNext = self._sch.nextTime()
                        if timeNow < timeNext:
                            timeout = timeNext - timeNow
                            break
                        f = self._sch.shift()
                        # Taken under the same lock schedule()/rmSchedule()
                        # use to mutate _schEH (was read outside it).
                        EH = self._schEH.pop(id(f), None)
                    try:
                        f()
                    except socket.error:
                        if EH: tryErr(EH.handle_error)
                    except:
                        self.logger.error(traceback.format_exc())
                        if EH: tryErr(EH.handle_close)

            self.doing = 'poll'
            try:
                events = self._epoll.poll(timeout=timeout)
            except (IOError, select.error):
                continue
            except:
                self.logger.error(traceback.format_exc())
                # Don't fall through: ``events`` would be stale or unbound.
                continue
            self.doing = 'events'
            for (fd, e) in events:
                o = self._fd.get(fd)
                if o is None:
                    # Unregistered by an earlier handler in this same batch.
                    continue
                self.lastHandler = o
                try:
                    if e & EPOLL_READ:
                        o.handle_read()
                    if e & EPOLL_WRITE:
                        o.handle_write()
                except socket.error:
                    tryErr(o.handle_error)
                except:
                    self.logger.error(traceback.format_exc())
                    tryErr(o.handle_error)
        self.doing = None
        self.running = False