Document asynchat-inherited recv function behaviour
[bitcoin:eloipool.git] / networkserver.py
1 # Eloipool - Python Bitcoin pool server
2 # Copyright (C) 2011-2013  Luke Dashjr <luke-jr+eloipool@utopios.org>
3 #
4 # This program is free software: you can redistribute it and/or modify
5 # it under the terms of the GNU Affero General Public License as
6 # published by the Free Software Foundation, either version 3 of the
7 # License, or (at your option) any later version.
8 #
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 # GNU Affero General Public License for more details.
13 #
14 # You should have received a copy of the GNU Affero General Public License
15 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
16
17 import asynchat
18 import logging
19 import os
20 import select
21 import socket
22 import threading
23 from time import time
24 import traceback
25 from util import ScheduleDict, WithNoop, tryErr
26
# Event mask meaning "dispatch to handle_read": plain/priority input, plus
# error and hang-up conditions so a failed or closed peer is noticed by the
# read path (serve_forever calls handle_read whenever e & EPOLL_READ).
EPOLL_READ = (
	select.EPOLLHUP
	| select.EPOLLERR
	| select.EPOLLPRI
	| select.EPOLLIN
)
# Event mask meaning "dispatch to handle_write" (socket can accept output).
EPOLL_WRITE = select.EPOLLOUT
29
class SocketHandler:
	"""Per-connection handler driven by AsyncSocketServer's epoll loop.

	Rather than inheriting from asynchat.async_chat, selected async_chat
	functions are bound as class attributes and the line-splitting logic is
	reimplemented in handle_readbuf.  Attributes read here but never set here
	(``logger``, ``found_terminator``, ``terminator``, and ``use_encoding`` /
	``encoding`` when recv yields str) are presumably provided by subclasses
	or via set_terminator() — NOTE(review): confirm against subclasses.
	"""
	ac_in_buffer_size = 4096
	ac_out_buffer_size = 4096
	
	def handle_close(self):
		# Discard pending output so close() does not defer; a None wbuf also
		# tells push()/handle_write() that the connection is gone.
		self.wbuf = None
		self.close()
	
	def handle_error(self):
		# Log the active exception at debug level, then drop the connection.
		self.logger.debug(traceback.format_exc())
		self.handle_close()
	
	# asyncore.dispatcher.recv (reached through asynchat.async_chat): reads up
	# to buffer_size bytes from self.socket.  When the peer has closed the
	# connection (empty read, or a "disconnected" errno) it calls
	# self.handle_close() itself and returns an empty result; any other
	# socket error is raised to the caller.
	recv = asynchat.async_chat.recv
	
	def handle_read(self):
		"""Read available input from the socket and process it.

		Socket errors are routed to handle_error().  Input arriving after a
		deferred close (``closeme``) is ignored.  Otherwise the data is
		appended to ac_in_buffer, mirrored to ``server.lastReadbuf``, and
		handed to handle_readbuf().
		"""
		try:
			data = self.recv(self.ac_in_buffer_size)
		except socket.error:
			self.handle_error()
			return
		
		if self.closeme:
			# All input is ignored from sockets we have "closed"
			return
		
		# NOTE: when recv() detected a peer close it returned an empty result
		# (after calling handle_close); that falls through harmlessly and
		# handle_readbuf only re-examines what was already buffered.
		if isinstance(data, str) and self.use_encoding:
			# Encode the received text itself into bytes for the buffer
			data = bytes(data, self.encoding)
		self.ac_in_buffer = self.ac_in_buffer + data
		
		self.server.lastReadbuf = self.ac_in_buffer
		
		self.handle_readbuf()
	
	# Appends a chunk to self.incoming (asynchat's default collector).
	collect_incoming_data = asynchat.async_chat._collect_incoming_data
	get_terminator = asynchat.async_chat.get_terminator
	set_terminator = asynchat.async_chat.set_terminator
	
	def handle_readbuf(self):
		"""Split ac_in_buffer on the current terminator(s).

		Adapted from asynchat.async_chat.handle_read, but works on data that
		is already buffered.  The terminator (from get_terminator()) may be:
		  * false (e.g. None): everything buffered is collected;
		  * an int n: collect n bytes (possibly across several reads), then
		    reset the terminator to 0 and call found_terminator();
		  * a bytes value, or a sequence of them: collect up to the earliest
		    match, drop the matched terminator, call found_terminator().
		A trailing partial terminator is left in ac_in_buffer for next time.
		"""
		while self.ac_in_buffer:
			lb = len(self.ac_in_buffer)
			terminator = self.get_terminator()
			if not terminator:
				# no terminator, collect it all
				self.collect_incoming_data (self.ac_in_buffer)
				self.ac_in_buffer = b''
			elif isinstance(terminator, int):
				# numeric terminator
				n = terminator
				if lb < n:
					self.collect_incoming_data (self.ac_in_buffer)
					self.ac_in_buffer = b''
					self.terminator = self.terminator - lb
				else:
					self.collect_incoming_data (self.ac_in_buffer[:n])
					self.ac_in_buffer = self.ac_in_buffer[n:]
					self.terminator = 0
					self.found_terminator()
			else:
				# 3 cases:
				# 1) end of buffer matches terminator exactly:
				#    collect data, transition
				# 2) end of buffer matches some prefix:
				#    collect data to the prefix
				# 3) end of buffer does not match any prefix:
				#    collect data
				# NOTE: this supports multiple different terminators, but
				#       NOT ones that are prefixes of others...
				if isinstance(self.ac_in_buffer, type(terminator)):
					# A single terminator: treat it as a one-element tuple
					terminator = (terminator,)
				termidx = tuple(map(self.ac_in_buffer.find, terminator))
				try:
					index = min(x for x in termidx if x >= 0)
				except ValueError:
					index = -1
				if index != -1:
					# we found the terminator
					if index > 0:
						# don't bother reporting the empty string (source of subtle bugs)
						self.collect_incoming_data (self.ac_in_buffer[:index])
					specific_terminator = terminator[termidx.index(index)]
					terminator_len = len(specific_terminator)
					self.ac_in_buffer = self.ac_in_buffer[index+terminator_len:]
					# This does the Right Thing if the terminator is changed here.
					self.found_terminator()
				else:
					# check for a prefix of the terminator
					termidx = tuple(map(lambda a: asynchat.find_prefix_at_end (self.ac_in_buffer, a), terminator))
					index = max(termidx)
					if index:
						if index != lb:
							# we found a prefix, collect up to the prefix
							self.collect_incoming_data (self.ac_in_buffer[:-index])
							self.ac_in_buffer = self.ac_in_buffer[-index:]
						break
					else:
						# no prefix, collect it all
						self.collect_incoming_data (self.ac_in_buffer)
						self.ac_in_buffer = b''
	
	def push(self, data):
		"""Send *data*, buffering whatever cannot be written immediately.

		When nothing is queued an immediate send is attempted; any unsent
		remainder is appended to ``wbuf`` and the fd is re-registered for
		write readiness.  After handle_close() (``wbuf`` is None) the data is
		dropped, matching handle_write().
		"""
		if self.wbuf is None:
			# Socket was closed by remote peer; nothing can be sent any more
			return
		if not self.wbuf:
			# Try to send as much as we can immediately
			try:
				bs = self.socket.send(data)
			except socket.error:
				# Chances are we'll fail later, but anyway...
				bs = 0
			data = data[bs:]
			if not data:
				return
		self.wbuf += data
		self.server.register_socket_m(self.fd, EPOLL_READ | EPOLL_WRITE)
	
	def handle_timeout(self):
		self.close()
	
	def handle_write(self):
		"""Flush as much of ``wbuf`` as the socket accepts.

		Once the buffer drains, either finish a deferred close (``closeme``)
		or go back to watching for reads only.
		"""
		if self.wbuf is None:
			# Socket was just closed by remote peer
			return
		bs = self.socket.send(self.wbuf)
		self.wbuf = self.wbuf[bs:]
		if not self.wbuf:
			if self.closeme:
				self.close()
				return
			self.server.register_socket_m(self.fd, EPOLL_READ)
	
	def close(self):
		"""Close the connection, deferring while buffered output remains.

		With data still in ``wbuf`` only ``closeme`` is set; handle_write()
		calls close() again once the buffer drains.  Otherwise the handler is
		removed from the server, its scheduled task is cancelled and the
		socket is closed.  Once closed, ``fd`` is -1 and later calls return
		early.
		"""
		if self.wbuf:
			self.closeme = True
			return
		if self.fd == -1:
			# Already closed
			return
		self.server.connections.pop(id(self), None)
		try:
			self.server.unregister_socket(self.fd)
		finally:
			# Release the task and the OS socket even if unregistering failed,
			# so such a failure cannot leak the descriptor.
			self.changeTask(None)
			self.socket.close()
			self.fd = -1
	
	def boot(self):
		# Close the connection and discard any unprocessed input.
		self.close()
		self.ac_in_buffer = b''
	
	def changeTask(self, f, t = None):
		"""Replace this handler's scheduled task with *f* at time *t*.

		The previous task is removed first (KeyError from the scheduler is
		ignored).  A false *f* just cancels the current task.
		"""
		tryErr(self.server.rmSchedule, self._Task, IgnoredExceptions=KeyError)
		if f:
			self._Task = self.server.schedule(f, t, errHandler=self)
		else:
			self._Task = None
	
	def __init__(self, server, sock, addr):
		"""Wrap accepted socket *sock* (peer *addr*) and register it with *server*."""
		self.ac_in_buffer = b''  # received bytes not yet consumed
		self.incoming = []  # chunks collected toward the current terminator
		self.wbuf = b''  # pending output; None after handle_close()
		self.closeme = False  # close requested while output was pending
		self.server = server
		self.socket = sock
		self.addr = addr
		self._Task = None
		self.fd = sock.fileno()
		server.register_socket(self.fd, self)
		server.connections[id(self)] = self
		# Call handle_timeout in 15 seconds unless the task is replaced via
		# changeTask() first (presumably by subclasses — confirm).
		self.changeTask(self.handle_timeout, time() + 15)
	
	@classmethod
	def _register(cls, scls):
		"""Mix the public attributes of *scls* into this class.

		Every attribute of *scls* whose name does not start with '_' is copied
		onto cls, except ``final_init``, which is chained: the existing
		cls.final_init runs first, then scls.final_init.  NOTE(review):
		getattr(cls, 'final_init') raises AttributeError if cls has none, and
		SocketHandler itself defines no final_init.
		"""
		for a in dir(scls):
			if a == 'final_init':
				f = lambda self, x=getattr(cls, a), y=getattr(scls, a): (x(self), y(self))
				setattr(cls, a, f)
				continue
			if a[0] == '_':
				continue
			setattr(cls, a, getattr(scls, a))
210
class NetworkListener:
	"""Listening socket registered with an AsyncSocketServer.

	Each accepted connection is made non-blocking and wrapped in
	``server.RequestHandlerClass(server, conn, addr)``, unless the server is
	currently ``rejecting`` connections.
	"""
	logger = logging.getLogger('SocketListener')
	
	def __init__(self, server, server_address, address_family = socket.AF_INET6):
		self.server = server
		self.server_address = server_address
		self.address_family = address_family
		# Setup errors go through tryErr (presumably logged with the address
		# as the message rather than raised — confirm in util.tryErr).
		tryErr(self.setup_socket, server_address, Logger=self.logger, ErrorMsg=server_address)
	
	def _makebind_py(self, server_address):
		"""Create a non-blocking stream socket bound to *server_address*.

		The socket is closed again if any setup step (e.g. bind) fails.
		"""
		sock = socket.socket(self.address_family, socket.SOCK_STREAM)
		try:
			sock.setblocking(0)
			try:
				sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
			except socket.error:
				# Best effort: binding may still succeed without it
				pass
			sock.bind(server_address)
		except BaseException:
			# Don't leak the descriptor when binding fails
			sock.close()
			raise
		return sock
	
	def _makebind_su(self, server_address):
		"""Bind an IPv6 socket through the external ``bindservice`` helper.

		*server_address* is a (node, service) pair; empty/None parts are passed
		as ''.  Only AF_INET6 is supported.
		"""
		if self.address_family != socket.AF_INET6:
			raise NotImplementedError
		
		from bindservice import bindservice
		(node, service) = server_address
		if not node: node = ''
		if not service: service = ''
		fd = bindservice(str(node), str(service))
		# NOTE(review): socket.fromfd() duplicates fd; if bindservice hands
		# ownership of fd to the caller, the original descriptor is never
		# closed here — confirm and os.close(fd) if so.
		sock = socket.fromfd(fd, socket.AF_INET6, socket.SOCK_STREAM)
		sock.setblocking(0)
		return sock
	
	def _makebind(self, *a, **ka):
		"""Bind via _makebind_py, falling back to _makebind_su.

		If the fallback also fails, the original _makebind_py error is
		re-raised and the fallback's error is discarded.
		"""
		try:
			return self._makebind_py(*a, **ka)
		except Exception:
			try:
				return self._makebind_su(*a, **ka)
			except Exception:
				pass
			# Re-raises the _makebind_py exception being handled
			raise
	
	def setup_socket(self, server_address):
		"""Bind, listen and register the listening socket with the server."""
		sock = self._makebind(server_address)
		try:
			sock.listen(100)
			self.server.register_socket(sock.fileno(), self)
		except BaseException:
			# Don't leak the bound socket if it cannot be put into service
			sock.close()
			raise
		self.socket = sock
	
	def handle_read(self):
		"""Accept one pending connection and hand it to a request handler."""
		server = self.server
		conn, addr = self.socket.accept()
		if server.rejecting:
			conn.close()
			return
		conn.setblocking(False)
		server.RequestHandlerClass(server, conn, addr)
	
	def handle_error(self):
		# Ignore errors... like socket closing on the queue
		pass
271
272 class _Waker:
273         def __init__(self, server, fd):
274                 self.server = server
275                 self.fd = fd
276                 self.logger = logging.getLogger('Waker for %s' % (server.__class__.__name__,))
277         
278         def handle_read(self):
279                 data = os.read(self.fd, 1)
280                 if not data:
281                         self.logger.error('Got EOF on socket')
282                 self.logger.debug('Read wakeup')
283
class AsyncSocketServer:
	"""epoll-based event loop with a time-ordered task scheduler.

	Subclasses may set class attributes:
	  * ``waker`` - when true, __init__ creates a pipe whose read end is
	    registered (via _Waker) so wakeup() can break out of the epoll wait;
	  * ``schMT`` - when true, scheduler state is guarded by a real
	    threading.Lock instead of ``WithNoop``.
	``ServerName`` defaults to 'Eloipool' if the subclass does not set it.
	"""
	logger = logging.getLogger('SocketServer')
	
	waker = False
	schMT = False
	
	def __init__(self, RequestHandlerClass):
		if not hasattr(self, 'ServerName'):
			self.ServerName = 'Eloipool'
		
		self.RequestHandlerClass = RequestHandlerClass
		
		self.running = False
		self.keepgoing = True
		self.rejecting = False
		self.lastidle = 0
		
		self._epoll = select.epoll()
		self._fd = {}  # fd -> handler object registered for it
		self.connections = {}  # id(handler) -> handler
		
		self._sch = ScheduleDict()
		self._schEH = {}  # id(task) -> handler whose error methods guard it
		if self.schMT:
			self._schLock = threading.Lock()
		else:
			# presumably a no-op context manager (util.WithNoop)
			self._schLock = WithNoop
		
		self.TrustedForwarders = ()
		
		if self.waker:
			(r, w) = os.pipe()
			o = _Waker(self, r)
			self.register_socket(r, o)
			self.waker = w
	
	def register_socket(self, fd, o, eventmask = EPOLL_READ):
		"""Start polling *fd* for *eventmask* and dispatch its events to *o*."""
		self._epoll.register(fd, eventmask)
		self._fd[fd] = o
	
	def register_socket_m(self, fd, eventmask):
		"""Change the event mask of an already registered *fd*."""
		try:
			self._epoll.modify(fd, eventmask)
		except IOError:
			raise socket.error
	
	def unregister_socket(self, fd):
		"""Stop polling *fd* and forget its handler."""
		del self._fd[fd]
		try:
			self._epoll.unregister(fd)
		except IOError:
			raise socket.error
	
	def schedule(self, task, startTime, errHandler=None):
		"""Run *task* at *startTime*; errors go to *errHandler* if given.

		Returns *task*, which is also the handle for rmSchedule().
		"""
		with self._schLock:
			self._sch[task] = startTime
			if errHandler:
				self._schEH[id(task)] = errHandler
		return task
	
	def rmSchedule(self, task):
		"""Cancel *task*; a missing task raises from the schedule lookup."""
		with self._schLock:
			del self._sch[task]
			k = id(task)
			if k in self._schEH:
				del self._schEH[k]
	
	def pre_schedule(self):
		# Hook run at the start of every loop iteration.
		pass
	
	def wakeup(self):
		"""Interrupt a blocking poll by writing a byte to the waker pipe."""
		if not self.waker:
			raise NotImplementedError('Class `%s\' did not enable waker' % (self.__class__.__name__))
		os.write(self.waker, b'\1')  # to break out of the epoll
	
	def final_init(self):
		# Hook run once by serve_forever before the loop starts.
		pass
	
	def boot_all(self):
		"""Boot every current connection, ignoring per-connection errors."""
		# Snapshot first: booting removes entries from self.connections
		conns = tuple(self.connections.values())
		for c in conns:
			tryErr(c.boot)
	
	def serve_forever(self):
		"""Run the event loop until ``keepgoing`` becomes false.

		Each iteration runs pre_schedule(), executes every scheduled task that
		is due, then polls (timeout at most 1 second, or 0 when the last idle
		poll was more than a second ago) and dispatches the events to the
		registered handlers.  ``doing`` records the current phase.
		"""
		self.running = True
		self.final_init()
		while self.keepgoing:
			self.doing = 'pre-schedule'
			self.pre_schedule()
			self.doing = 'schedule'
			timeNow = time()
			if len(self._sch):
				while True:
					with self._schLock:
						if not len(self._sch):
							timeout = -1
							break
						timeNext = self._sch.nextTime()
						if timeNow < timeNext:
							timeout = timeNext - timeNow
							break
						f = self._sch.shift()
						# Take the error handler under the same lock that
						# schedule()/rmSchedule() hold when touching _schEH
						EH = self._schEH.pop(id(f), None)
					try:
						f()
					except socket.error:
						if EH: tryErr(EH.handle_error)
					except:
						self.logger.error(traceback.format_exc())
						if EH: tryErr(EH.handle_close)
			else:
				timeout = -1
			if self.lastidle < timeNow - 1:
				timeout = 0
			elif timeout < 0 or timeout > 1:
				timeout = 1
			
			self.doing = 'poll'
			try:
				events = self._epoll.poll(timeout=timeout)
			except (IOError, select.error):
				continue
			except:
				self.logger.error(traceback.format_exc())
				continue
			self.doing = 'events'
			if not events:
				self.lastidle = time()
			for (fd, e) in events:
				o = self._fd.get(fd)
				if o is None:
					# Unregistered by an earlier handler in this same batch
					continue
				self.lastHandler = o
				try:
					if e & EPOLL_READ:
						o.handle_read()
					if e & EPOLL_WRITE:
						o.handle_write()
				except socket.error:
					tryErr(o.handle_error)
				except:
					self.logger.error(traceback.format_exc())
					tryErr(o.handle_error)
		self.doing = None
		self.running = False