NetworkServer: Keep track of when the server was last idle
[bitcoin:eloipool.git] / networkserver.py
1 # Eloipool - Python Bitcoin pool server
2 # Copyright (C) 2011-2013  Luke Dashjr <luke-jr+eloipool@utopios.org>
3 #
4 # This program is free software: you can redistribute it and/or modify
5 # it under the terms of the GNU Affero General Public License as
6 # published by the Free Software Foundation, either version 3 of the
7 # License, or (at your option) any later version.
8 #
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 # GNU Affero General Public License for more details.
13 #
14 # You should have received a copy of the GNU Affero General Public License
15 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
16
17 import asynchat
18 import logging
19 import os
20 import select
21 import socket
22 import threading
23 from time import time
24 import traceback
25 from util import ScheduleDict, WithNoop, tryErr
26
# Event mask used for every registered descriptor: readable data, urgent
# data, plus the error and hang-up conditions so they reach handle_read().
EPOLL_READ = (
        select.EPOLLIN
        | select.EPOLLPRI
        | select.EPOLLERR
        | select.EPOLLHUP
)
# Extra bit OR-ed in while a handler still has buffered output to flush.
EPOLL_WRITE = select.EPOLLOUT
29
class SocketHandler:
        """Per-connection handler driven by the server's epoll loop.

        On construction an instance registers its descriptor with the server,
        adds itself to ``server.connections`` and schedules ``handle_timeout``
        15 seconds in the future (use ``changeTask`` to replace or cancel it).

        This class references ``found_terminator``, ``logger``, ``use_encoding``
        and ``encoding`` without defining them; they must be supplied elsewhere
        (for example by a class merged in through ``_register``).
        """
        # Maximum number of bytes requested from recv() per read event.
        ac_in_buffer_size = 4096
        ac_out_buffer_size = 4096

        def handle_close(self):
                """Discard any unsent output and close the connection right away."""
                # close() defers while wbuf is non-empty; None makes it close now
                # and tells handle_write() that the socket is gone.
                self.wbuf = None
                self.close()

        def handle_error(self):
                """Log the current traceback at debug level and close the connection."""
                self.logger.debug(traceback.format_exc())
                self.handle_close()

        def handle_read(self):
                """Receive pending data, append it to the input buffer and parse it."""
                try:
                        data = self.recv(self.ac_in_buffer_size)
                except socket.error:
                        self.handle_error()
                        return

                if self.closeme:
                        # All input is ignored from sockets we have "closed"
                        return

                if isinstance(data, str) and self.use_encoding:
                        # Encode the received text itself; passing the builtin
                        # `str` type here would raise TypeError.
                        data = bytes(data, self.encoding)
                self.ac_in_buffer = self.ac_in_buffer + data

                # Kept on the server so the most recent input can be inspected.
                self.server.lastReadbuf = self.ac_in_buffer

                self.handle_readbuf()

        collect_incoming_data = asynchat.async_chat._collect_incoming_data
        get_terminator = asynchat.async_chat.get_terminator
        set_terminator = asynchat.async_chat.set_terminator

        def handle_readbuf(self):
                """Split the input buffer on the current terminator.

                The terminator may be falsy (collect everything), an int (a fixed
                number of bytes), a single value of the buffer's type, or a
                sequence of such values. ``collect_incoming_data`` receives the
                payload and ``found_terminator`` is called at each boundary.
                """
                while self.ac_in_buffer:
                        lb = len(self.ac_in_buffer)
                        terminator = self.get_terminator()
                        if not terminator:
                                # no terminator, collect it all
                                self.collect_incoming_data(self.ac_in_buffer)
                                self.ac_in_buffer = b''
                        elif isinstance(terminator, int):
                                # numeric terminator
                                n = terminator
                                if lb < n:
                                        self.collect_incoming_data(self.ac_in_buffer)
                                        self.ac_in_buffer = b''
                                        self.terminator = self.terminator - lb
                                else:
                                        self.collect_incoming_data(self.ac_in_buffer[:n])
                                        self.ac_in_buffer = self.ac_in_buffer[n:]
                                        self.terminator = 0
                                        self.found_terminator()
                        else:
                                # 3 cases:
                                # 1) end of buffer matches terminator exactly:
                                #    collect data, transition
                                # 2) end of buffer matches some prefix:
                                #    collect data to the prefix
                                # 3) end of buffer does not match any prefix:
                                #    collect data
                                # NOTE: this supports multiple different terminators, but
                                #       NOT ones that are prefixes of others...
                                if isinstance(self.ac_in_buffer, type(terminator)):
                                        # A single terminator: treat it as a 1-tuple
                                        terminator = (terminator,)
                                termidx = tuple(map(self.ac_in_buffer.find, terminator))
                                try:
                                        # Earliest terminator present in the buffer
                                        index = min(x for x in termidx if x >= 0)
                                except ValueError:
                                        index = -1
                                if index != -1:
                                        # we found the terminator
                                        if index > 0:
                                                # don't bother reporting the empty string (source of subtle bugs)
                                                self.collect_incoming_data(self.ac_in_buffer[:index])
                                        specific_terminator = terminator[termidx.index(index)]
                                        terminator_len = len(specific_terminator)
                                        self.ac_in_buffer = self.ac_in_buffer[index+terminator_len:]
                                        # This does the Right Thing if the terminator is changed here.
                                        self.found_terminator()
                                else:
                                        # check for a prefix of the terminator
                                        termidx = tuple(map(lambda a: asynchat.find_prefix_at_end(self.ac_in_buffer, a), terminator))
                                        index = max(termidx)
                                        if index:
                                                if index != lb:
                                                        # we found a prefix, collect up to the prefix
                                                        self.collect_incoming_data(self.ac_in_buffer[:-index])
                                                        self.ac_in_buffer = self.ac_in_buffer[-index:]
                                                # Wait for more input to complete the terminator
                                                break
                                        else:
                                                # no prefix, collect it all
                                                self.collect_incoming_data(self.ac_in_buffer)
                                                self.ac_in_buffer = b''

        def push(self, data):
                """Send ``data``, buffering whatever cannot be written immediately."""
                if not len(self.wbuf):
                        # Try to send as much as we can immediately
                        try:
                                bs = self.socket.send(data)
                        except Exception:
                                # Chances are we'll fail later, but anyway...
                                bs = 0
                        data = data[bs:]
                        if not len(data):
                                return
                self.wbuf += data
                # Ask the event loop to call handle_write() once the socket is writable.
                self.server.register_socket_m(self.fd, EPOLL_READ | EPOLL_WRITE)

        def handle_timeout(self):
                """Scheduled task: close the connection."""
                self.close()

        def handle_write(self):
                """Flush as much buffered output as the socket accepts."""
                if self.wbuf is None:
                        # Socket was just closed by remote peer
                        return
                bs = self.socket.send(self.wbuf)
                self.wbuf = self.wbuf[bs:]
                if not len(self.wbuf):
                        if self.closeme:
                                # A close() was deferred until the buffer drained
                                self.close()
                                return
                        # Nothing left to write: stop polling for writability
                        self.server.register_socket_m(self.fd, EPOLL_READ)

        recv = asynchat.async_chat.recv

        def close(self):
                """Close the connection, deferring while output is still buffered.

                With pending output only ``closeme`` is set; ``handle_write``
                completes the close once the buffer is empty. Calling this on an
                already closed handler is a no-op.
                """
                if self.wbuf:
                        self.closeme = True
                        return
                if self.fd == -1:
                        # Already closed
                        return
                try:
                        del self.server.connections[id(self)]
                except KeyError:
                        pass
                try:
                        self.server.unregister_socket(self.fd)
                finally:
                        # Release the task and the socket even when unregistering
                        # failed, otherwise the descriptor would never be closed.
                        self.changeTask(None)
                        self.socket.close()
                        self.fd = -1

        def boot(self):
                """Request a close and throw away any unparsed input."""
                self.close()
                self.ac_in_buffer = b''

        def changeTask(self, f, t = None):
                """Replace this handler's scheduled task with ``f`` at time ``t``.

                A falsy ``f`` only cancels the current task.
                """
                tryErr(self.server.rmSchedule, self._Task, IgnoredExceptions=KeyError)
                if f:
                        self._Task = self.server.schedule(f, t, errHandler=self)
                else:
                        self._Task = None

        def __init__(self, server, sock, addr):
                """Attach to ``sock`` (peer ``addr``) and register with ``server``."""
                self.ac_in_buffer = b''
                self.incoming = []
                self.wbuf = b''
                self.closeme = False
                self.server = server
                self.socket = sock
                self.addr = addr
                self._Task = None
                self.fd = sock.fileno()
                server.register_socket(self.fd, self)
                server.connections[id(self)] = self
                # Drop the connection unless something reschedules within 15 seconds
                self.changeTask(self.handle_timeout, time() + 15)

        @classmethod
        def _register(cls, scls):
                """Copy the public attributes of ``scls`` onto ``cls``.

                ``final_init`` is chained instead of replaced: the new attribute
                calls the existing ``cls.final_init`` first, then the one from
                ``scls``. Names starting with an underscore are skipped.
                """
                for a in dir(scls):
                        if a == 'final_init':
                                # Defaults bind the current functions at definition time
                                f = lambda self, x=getattr(cls, a), y=getattr(scls, a): (x(self), y(self))
                                setattr(cls, a, f)
                                continue
                        if a[0] == '_':
                                continue
                        setattr(cls, a, getattr(scls, a))
209
class NetworkListener:
        """Listening socket that creates a request handler for each accepted connection."""
        logger = logging.getLogger('SocketListener')

        def __init__(self, server, server_address, address_family = socket.AF_INET6):
                """Bind ``server_address`` and register the listener with ``server``.

                Socket setup runs through ``tryErr``, with the address passed as
                the error message.
                """
                self.server = server
                self.server_address = server_address
                self.address_family = address_family
                tryErr(self.setup_socket, server_address, Logger=self.logger, ErrorMsg=server_address)

        def _makebind_py(self, server_address):
                """Create and bind a non-blocking stream socket using plain Python."""
                sock = socket.socket(self.address_family, socket.SOCK_STREAM)
                try:
                        sock.setblocking(0)
                        try:
                                sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                        except socket.error:
                                pass
                        sock.bind(server_address)
                except BaseException:
                        # Do not leak the descriptor when binding fails
                        sock.close()
                        raise
                return sock

        def _makebind_su(self, server_address):
                """Bind through the ``bindservice`` module (AF_INET6 only).

                Raises NotImplementedError for any other address family.
                """
                if self.address_family != socket.AF_INET6:
                        raise NotImplementedError

                from bindservice import bindservice
                (node, service) = server_address
                if not node: node = ''
                if not service: service = ''
                fd = bindservice(str(node), str(service))
                try:
                        sock = socket.fromfd(fd, socket.AF_INET6, socket.SOCK_STREAM)
                finally:
                        # socket.fromfd() duplicates the descriptor, so the one
                        # returned by bindservice must be closed here.
                        # NOTE(review): assumes bindservice hands over ownership
                        # of the descriptor - confirm.
                        os.close(fd)
                try:
                        sock.setblocking(0)
                except BaseException:
                        sock.close()
                        raise
                return sock

        def _makebind(self, *a, **ka):
                """Bind with ``_makebind_py``, falling back to ``_makebind_su``.

                If both fail, the error from ``_makebind_py`` is the one raised.
                """
                try:
                        return self._makebind_py(*a, **ka)
                except Exception:
                        try:
                                return self._makebind_su(*a, **ka)
                        except Exception:
                                pass
                        # Re-raises the _makebind_py error
                        raise

        def setup_socket(self, server_address):
                """Bind, start listening (backlog 100) and register with the server."""
                sock = self._makebind(server_address)
                try:
                        sock.listen(100)
                        self.server.register_socket(sock.fileno(), self)
                except BaseException:
                        # Do not leave a bound socket behind that nobody polls
                        sock.close()
                        raise
                self.socket = sock

        def handle_read(self):
                """Accept one connection and hand it to the server's handler class.

                The connection is closed immediately while ``server.rejecting``
                is set.
                """
                server = self.server
                conn, addr = self.socket.accept()
                if server.rejecting:
                        conn.close()
                        return
                conn.setblocking(False)
                # The instance is not kept here; SocketHandler.__init__ registers
                # itself with the server.
                server.RequestHandlerClass(server, conn, addr)

        def handle_error(self):
                """Deliberately ignore errors raised while accepting."""
                # Ignore errors... like socket closing on the queue
                pass
270
class _Waker:
        """Handler for the read end of the server's wake-up pipe.

        AsyncSocketServer.__init__ registers an instance for the pipe's read
        descriptor; AsyncSocketServer.wakeup() writes one byte to the other
        end, which makes the poll return and this handler run.
        """

        def __init__(self, server, fd):
                self.server = server
                self.fd = fd
                owner = server.__class__.__name__
                self.logger = logging.getLogger('Waker for %s' % (owner,))

        def handle_read(self):
                """Consume a single wake-up byte from the pipe."""
                if not os.read(self.fd, 1):
                        # An empty read means the write end has been closed
                        self.logger.error('Got EOF on socket')
                self.logger.debug('Read wakeup')
282
class AsyncSocketServer:
        """Single-threaded epoll event loop with a simple task scheduler.

        Subclasses may set ``waker = True`` to get a wake-up pipe (see
        ``wakeup``) and ``schMT = True`` to guard the schedule with a real
        lock; ``ServerName`` defaults to 'Eloipool' when not already defined.
        """
        logger = logging.getLogger('SocketServer')

        # Set True in a subclass to create the wake-up pipe; __init__ then
        # replaces it with the pipe's write descriptor.
        waker = False
        # Set True in a subclass to protect the schedule with threading.Lock
        schMT = False

        def __init__(self, RequestHandlerClass):
                """Create the epoll object and the empty bookkeeping structures."""
                if not hasattr(self, 'ServerName'):
                        self.ServerName = 'Eloipool'

                self.RequestHandlerClass = RequestHandlerClass

                self.running = False
                self.keepgoing = True
                self.rejecting = False
                # Time of the last poll that returned no events
                self.lastidle = 0

                self._epoll = select.epoll()
                # fd -> handler object
                self._fd = {}
                # id(handler) -> handler
                self.connections = {}

                self._sch = ScheduleDict()
                # id(task) -> error handler for that task
                self._schEH = {}
                if self.schMT:
                        self._schLock = threading.Lock()
                else:
                        self._schLock = WithNoop

                self.TrustedForwarders = ()

                if self.waker:
                        (r, w) = os.pipe()
                        o = _Waker(self, r)
                        self.register_socket(r, o)
                        self.waker = w

        def register_socket(self, fd, o, eventmask = EPOLL_READ):
                """Start polling ``fd`` and dispatch its events to handler ``o``."""
                self._epoll.register(fd, eventmask)
                self._fd[fd] = o

        def register_socket_m(self, fd, eventmask):
                """Change the event mask of an already registered ``fd``.

                Raises socket.error, carrying the details of the underlying
                failure, when epoll rejects the change.
                """
                try:
                        self._epoll.modify(fd, eventmask)
                except IOError as e:
                        # Keep errno/message instead of raising an empty error
                        raise socket.error(*e.args) from e

        def unregister_socket(self, fd):
                """Forget ``fd`` and stop polling it.

                Raises KeyError for an unknown ``fd`` and socket.error when epoll
                fails to unregister it.
                """
                del self._fd[fd]
                try:
                        self._epoll.unregister(fd)
                except IOError as e:
                        # Keep errno/message instead of raising an empty error
                        raise socket.error(*e.args) from e

        def schedule(self, task, startTime, errHandler=None):
                """Schedule callable ``task`` for ``startTime`` and return it.

                ``errHandler``, when given, has ``handle_error`` or
                ``handle_close`` called if the task raises (see serve_forever).
                """
                with self._schLock:
                        self._sch[task] = startTime
                        if errHandler:
                                self._schEH[id(task)] = errHandler
                return task

        def rmSchedule(self, task):
                """Remove ``task`` from the schedule along with its error handler."""
                with self._schLock:
                        del self._sch[task]
                        k = id(task)
                        if k in self._schEH:
                                del self._schEH[k]

        def pre_schedule(self):
                """Hook run at the start of every loop iteration; no-op here."""
                pass

        def wakeup(self):
                """Interrupt a blocking poll by writing to the wake-up pipe.

                Raises NotImplementedError when the class did not enable ``waker``.
                """
                if not self.waker:
                        raise NotImplementedError('Class `%s\' did not enable waker' % (self.__class__.__name__))
                os.write(self.waker, b'\1')  # to break out of the epoll

        def final_init(self):
                """Hook run once when serve_forever starts; no-op here."""
                pass

        def boot_all(self):
                """Call ``boot()`` on every current connection."""
                # Snapshot first: boot() may remove entries from self.connections
                conns = tuple(self.connections.values())
                for c in conns:
                        tryErr(lambda c=c: c.boot())

        def serve_forever(self):
                """Run scheduled tasks and dispatch socket events until ``keepgoing`` is false.

                ``self.doing`` names the current phase and ``self.lastHandler``
                the handler most recently dispatched to.
                """
                self.running = True
                self.final_init()
                while self.keepgoing:
                        self.doing = 'pre-schedule'
                        self.pre_schedule()
                        self.doing = 'schedule'
                        timeNow = time()
                        if len(self._sch):
                                while True:
                                        with self._schLock:
                                                if not len(self._sch):
                                                        timeout = -1
                                                        break
                                                timeNext = self._sch.nextTime()
                                                if timeNow < timeNext:
                                                        timeout = timeNext - timeNow
                                                        break
                                                f = self._sch.shift()
                                        # The error handler is used at most once per task
                                        EH = self._schEH.pop(id(f), None)
                                        try:
                                                f()
                                        except socket.error:
                                                if EH: tryErr(EH.handle_error)
                                        except:
                                                # Deliberately broad: a failing task must not stop the loop
                                                self.logger.error(traceback.format_exc())
                                                if EH: tryErr(EH.handle_close)
                        else:
                                timeout = -1
                        if self.lastidle < timeNow - 1:
                                # Not idle for over a second: poll without blocking
                                timeout = 0
                        elif timeout < 0 or timeout > 1:
                                # Never block for more than one second
                                timeout = 1

                        self.doing = 'poll'
                        try:
                                events = self._epoll.poll(timeout=timeout)
                        except (IOError, select.error):
                                continue
                        except:
                                self.logger.error(traceback.format_exc())
                                continue
                        self.doing = 'events'
                        if not events:
                                self.lastidle = time()
                        for (fd, e) in events:
                                o = self._fd.get(fd)
                                if o is None:
                                        # Unregistered by a handler that ran earlier in
                                        # this batch; a KeyError here would end the loop.
                                        continue
                                self.lastHandler = o
                                try:
                                        if e & EPOLL_READ:
                                                o.handle_read()
                                        if e & EPOLL_WRITE:
                                                o.handle_write()
                                except socket.error:
                                        tryErr(o.handle_error)
                                except:
                                        # Deliberately broad: a failing handler must not stop the loop
                                        self.logger.error(traceback.format_exc())
                                        tryErr(o.handle_error)
                self.doing = None
                self.running = False