Silence "No route to host" and "Connection timed out" errors when clients disappear
[bitcoin:eloipool.git] / networkserver.py
1 # Eloipool - Python Bitcoin pool server
2 # Copyright (C) 2011-2013  Luke Dashjr <luke-jr+eloipool@utopios.org>
3 #
4 # This program is free software: you can redistribute it and/or modify
5 # it under the terms of the GNU Affero General Public License as
6 # published by the Free Software Foundation, either version 3 of the
7 # License, or (at your option) any later version.
8 #
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 # GNU Affero General Public License for more details.
13 #
14 # You should have received a copy of the GNU Affero General Public License
15 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
16
17 import asynchat
18 import logging
19 import os
20 import select
21 import socket
22 import threading
23 from time import time
24 import traceback
25 from util import ScheduleDict, WithNoop, tryErr
26 from errno import EHOSTUNREACH, ETIMEDOUT
27
# Event mask for read interest.  Error and hang-up conditions are included so
# that serve_forever() routes them to the handler's handle_read() as well.
EPOLL_READ = (
    select.EPOLLIN
    | select.EPOLLPRI
    | select.EPOLLERR
    | select.EPOLLHUP
)
# Extra interest added while a handler has buffered output still to send.
EPOLL_WRITE = select.EPOLLOUT

# errno values from recv() that only mean the client went away; handle_read()
# closes the connection quietly on these instead of going through handle_error().
_DISCONNECTED = frozenset([EHOSTUNREACH, ETIMEDOUT])
32
class SocketHandler:
    """Per-connection handler driven by AsyncSocketServer's epoll loop.

    Borrows asynchat.async_chat's recv, terminator accessors and data
    collection without inheriting from it.  Outgoing data is kept in
    ``wbuf`` and flushed from handle_write(); ``closeme`` defers close()
    until that buffer has drained.
    """

    ac_in_buffer_size = 4096
    ac_out_buffer_size = 4096

    # Defaults mirroring asynchat.async_chat's class attributes; handle_read()
    # consults them when recv() hands back text instead of bytes.
    use_encoding = 0
    encoding = 'latin-1'

    def handle_close(self):
        """Peer went away: discard pending output and close immediately."""
        # Clearing wbuf first makes close() skip its deferred-close path and
        # makes handle_write()/push() treat the connection as gone.
        self.wbuf = None
        self.close()

    def handle_error(self):
        """Log the current exception at debug level and drop the connection."""
        self.logger.debug(traceback.format_exc())
        self.handle_close()

    # NOTE: This function checks for socket-closed condition and calls handle_close
    recv = asynchat.async_chat.recv

    def handle_read(self):
        """Read available data into ``ac_in_buffer`` and process it.

        Socket errors whose errno is in _DISCONNECTED close the connection
        quietly; any other socket error is passed to handle_error().
        """
        try:
            data = self.recv(self.ac_in_buffer_size)
        except socket.error as why:
            # This silences some additional expected socket errors
            # not automatically dealt with by asyncore.
            # Errors raised without arguments have no errno to match.
            code = why.args[0] if why.args else None
            if code in _DISCONNECTED:
                self.handle_close()
            else:
                self.handle_error()
            return

        if self.closeme:
            # All input is ignored from sockets we have "closed"
            return

        if isinstance(data, str) and self.use_encoding:
            # Encode the received text so it can be appended to the bytes buffer.
            data = bytes(data, self.encoding)
        self.ac_in_buffer = self.ac_in_buffer + data

        self.server.lastReadbuf = self.ac_in_buffer

        self.handle_readbuf()

    collect_incoming_data = asynchat.async_chat._collect_incoming_data
    get_terminator = asynchat.async_chat.get_terminator
    set_terminator = asynchat.async_chat.set_terminator

    def handle_readbuf(self):
        """Split ``ac_in_buffer`` on the current terminator.

        Data preceding each terminator is passed to collect_incoming_data()
        and then found_terminator() is called.  An integer terminator counts
        bytes; a falsy terminator collects everything.  A trailing partial
        match of a terminator stays in ``ac_in_buffer`` for the next read.
        """
        while self.ac_in_buffer:
            lb = len(self.ac_in_buffer)
            terminator = self.get_terminator()
            if not terminator:
                # no terminator, collect it all
                self.collect_incoming_data(self.ac_in_buffer)
                self.ac_in_buffer = b''
            elif isinstance(terminator, int):
                # numeric terminator
                n = terminator
                if lb < n:
                    self.collect_incoming_data(self.ac_in_buffer)
                    self.ac_in_buffer = b''
                    self.terminator = self.terminator - lb
                else:
                    self.collect_incoming_data(self.ac_in_buffer[:n])
                    self.ac_in_buffer = self.ac_in_buffer[n:]
                    self.terminator = 0
                    self.found_terminator()
            else:
                # 3 cases:
                # 1) end of buffer matches terminator exactly:
                #    collect data, transition
                # 2) end of buffer matches some prefix:
                #    collect data to the prefix
                # 3) end of buffer does not match any prefix:
                #    collect data
                # NOTE: this supports multiple different terminators, but
                #       NOT ones that are prefixes of others...
                if isinstance(self.ac_in_buffer, type(terminator)):
                    # A single terminator: treat it as a one-element tuple.
                    terminator = (terminator,)
                termidx = tuple(map(self.ac_in_buffer.find, terminator))
                try:
                    index = min(x for x in termidx if x >= 0)
                except ValueError:
                    index = -1
                if index != -1:
                    # we found the terminator
                    if index > 0:
                        # don't bother reporting the empty string (source of subtle bugs)
                        self.collect_incoming_data(self.ac_in_buffer[:index])
                    specific_terminator = terminator[termidx.index(index)]
                    terminator_len = len(specific_terminator)
                    self.ac_in_buffer = self.ac_in_buffer[index + terminator_len:]
                    # This does the Right Thing if the terminator is changed here.
                    self.found_terminator()
                else:
                    # check for a prefix of the terminator
                    termidx = tuple(map(lambda a: asynchat.find_prefix_at_end(self.ac_in_buffer, a), terminator))
                    index = max(termidx)
                    if index:
                        if index != lb:
                            # we found a prefix, collect up to the prefix
                            self.collect_incoming_data(self.ac_in_buffer[:-index])
                            self.ac_in_buffer = self.ac_in_buffer[-index:]
                        break
                    else:
                        # no prefix, collect it all
                        self.collect_incoming_data(self.ac_in_buffer)
                        self.ac_in_buffer = b''

    def push(self, data):
        """Queue ``data`` for sending, writing as much as possible right away.

        Data pushed to a connection that is already closed is discarded.
        "Closed" means either the peer closed it (``wbuf`` is None) or
        close() released the descriptor (``fd`` is -1).
        """
        if self.wbuf is None or self.fd == -1:
            # Nothing can be delivered any more; handle_write() likewise gives
            # up once the peer has closed the socket.
            return
        if not self.wbuf:
            # Nothing queued yet, so an immediate send keeps ordering intact.
            try:
                bs = self.socket.send(data)
            except socket.error:
                # Chances are we'll fail later, but anyway...
                bs = 0
            data = data[bs:]
            if not data:
                return
        self.wbuf += data
        # Ask epoll to report writability so handle_write() flushes the rest.
        self.server.register_socket_m(self.fd, EPOLL_READ | EPOLL_WRITE)

    def handle_timeout(self):
        """Scheduled timeout: close (after flushing any pending output)."""
        self.close()

    def handle_write(self):
        """Flush as much of ``wbuf`` as the socket accepts."""
        if self.wbuf is None:
            # Socket was just closed by remote peer
            return
        bs = self.socket.send(self.wbuf)
        self.wbuf = self.wbuf[bs:]
        if not self.wbuf:
            if self.closeme:
                # A close() was deferred until the buffer drained.
                self.close()
                return
            # Nothing left to send; stop asking epoll for writability.
            self.server.register_socket_m(self.fd, EPOLL_READ)

    def close(self):
        """Close the connection, or defer until pending output is flushed.

        Calling it again after the connection is closed does nothing.  The OS
        socket is closed and ``fd`` reset to -1 even if unregistering from
        the server raises; that exception still propagates.
        """
        if self.wbuf:
            # Still have data to send: handle_write() closes once drained.
            self.closeme = True
            return
        if self.fd == -1:
            # Already closed
            return
        self.server.connections.pop(id(self), None)
        try:
            self.server.unregister_socket(self.fd)
        finally:
            # Always release the scheduled task and the descriptor, otherwise
            # a failed unregister would leak the socket.
            try:
                self.changeTask(None)
            finally:
                self.socket.close()
                self.fd = -1

    def boot(self):
        """Close the connection and discard any unprocessed input."""
        self.close()
        self.ac_in_buffer = b''

    def changeTask(self, f, t = None):
        """Replace this connection's scheduled task.

        Any existing task is removed from the server's schedule.  If ``f`` is
        truthy it is scheduled at time ``t``, with this handler registered as
        its error handler; otherwise no task remains.
        """
        tryErr(self.server.rmSchedule, self._Task, IgnoredExceptions=KeyError)
        if f:
            self._Task = self.server.schedule(f, t, errHandler=self)
        else:
            self._Task = None

    def __init__(self, server, sock, addr):
        """Register ``sock`` with ``server`` and schedule a 15 second timeout."""
        self.ac_in_buffer = b''
        self.incoming = []
        self.wbuf = b''
        self.closeme = False
        self.server = server
        self.socket = sock
        self.addr = addr
        self._Task = None
        self.fd = sock.fileno()
        server.register_socket(self.fd, self)
        server.connections[id(self)] = self
        self.changeTask(self.handle_timeout, time() + 15)

    @classmethod
    def _register(cls, scls):
        """Copy the public attributes of ``scls`` onto this class.

        A ``final_init`` in ``scls`` is chained after the existing one (both
        are called) instead of replacing it.  Names starting with ``_`` are
        skipped.
        """
        for a in dir(scls):
            if a == 'final_init':
                # Default arguments bind the two functions now, not at call time.
                f = lambda self, x=getattr(cls, a), y=getattr(scls, a): (x(self), y(self))
                setattr(cls, a, f)
                continue
            if a[0] == '_':
                continue
            setattr(cls, a, getattr(scls, a))
218
class NetworkListener:
    """Listening socket that hands accepted connections to ``server``.

    Each accepted connection is made non-blocking and passed to
    ``server.RequestHandlerClass(server, conn, addr)``, unless
    ``server.rejecting`` is set, in which case it is closed immediately.
    """
    logger = logging.getLogger('SocketListener')

    def __init__(self, server, server_address, address_family = socket.AF_INET6):
        self.server = server
        self.server_address = server_address
        self.address_family = address_family
        tryErr(self.setup_socket, server_address, Logger=self.logger, ErrorMsg=server_address)

    def _makebind_py(self, server_address):
        """Create a non-blocking stream socket bound to ``server_address``.

        The socket is closed again if any step fails, so nothing leaks.
        """
        sock = socket.socket(self.address_family, socket.SOCK_STREAM)
        try:
            sock.setblocking(0)
            try:
                sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            except socket.error:
                # Address reuse is a convenience; binding may still succeed.
                pass
            sock.bind(server_address)
        except BaseException:
            sock.close()
            raise
        return sock

    def _makebind_su(self, server_address):
        """Obtain a bound IPv6 socket from the ``bindservice`` helper.

        ``server_address`` is a ``(node, service)`` pair; falsy parts are
        passed as empty strings.  Raises NotImplementedError for
        non-AF_INET6 listeners.
        """
        if self.address_family != socket.AF_INET6:
            raise NotImplementedError

        from bindservice import bindservice
        (node, service) = server_address
        if not node: node = ''
        if not service: service = ''
        fd = bindservice(str(node), str(service))
        try:
            sock = socket.fromfd(fd, socket.AF_INET6, socket.SOCK_STREAM)
        finally:
            # socket.fromfd() duplicates the descriptor; close the original
            # so it does not leak (and keep the address bound) forever.
            os.close(fd)
        sock.setblocking(0)
        return sock

    def _makebind(self, *a, **ka):
        """Bind directly, falling back to _makebind_su() on failure.

        If the fallback fails as well, the original bind error is re-raised.
        """
        try:
            return self._makebind_py(*a, **ka)
        except Exception:
            try:
                return self._makebind_su(*a, **ka)
            except Exception:
                pass
            # Re-raises the error from _makebind_py, not the fallback's.
            raise

    def setup_socket(self, server_address):
        """Bind, listen (backlog 100) and register the listening socket."""
        sock = self._makebind(server_address)
        try:
            sock.listen(100)
            self.server.register_socket(sock.fileno(), self)
        except BaseException:
            # Don't leak the bound socket if it cannot be put into service.
            sock.close()
            raise
        self.socket = sock

    def handle_read(self):
        """Accept one pending connection and hand it to the server."""
        server = self.server
        conn, addr = self.socket.accept()
        if server.rejecting:
            conn.close()
            return
        conn.setblocking(False)
        server.RequestHandlerClass(server, conn, addr)

    def handle_error(self):
        # Ignore errors... like socket closing on the queue
        pass
279
class _Waker:
    """Read end of the pipe that AsyncSocketServer.wakeup() writes to.

    Registered with the server's epoll so that a write to the pipe makes a
    blocked poll return; handle_read() consumes one wakeup byte.
    """

    def __init__(self, server, fd):
        self.server = server
        self.fd = fd
        owner_name = server.__class__.__name__
        self.logger = logging.getLogger('Waker for %s' % (owner_name,))

    def handle_read(self):
        """Consume a single byte from the pipe, logging EOF as an error."""
        chunk = os.read(self.fd, 1)
        if len(chunk) == 0:
            self.logger.error('Got EOF on socket')
        self.logger.debug('Read wakeup')
291
class AsyncSocketServer:
    """epoll-driven event loop with a timer schedule.

    Handler objects are registered per file descriptor and receive
    handle_read()/handle_write() calls from serve_forever(), which also runs
    scheduled callables.  A subclass may set ``waker = True`` to get a pipe
    that wakeup() writes to in order to interrupt a pending poll, and
    ``schMT = True`` to guard the schedule with a threading.Lock instead of
    a no-op context manager.
    """
    logger = logging.getLogger('SocketServer')

    waker = False
    schMT = False

    def __init__(self, RequestHandlerClass):
        if not hasattr(self, 'ServerName'):
            self.ServerName = 'Eloipool'

        self.RequestHandlerClass = RequestHandlerClass

        self.running = False
        self.keepgoing = True
        self.rejecting = False
        self.lastidle = 0
        # Current phase of serve_forever(); None while not serving.
        self.doing = None

        self._epoll = select.epoll()
        self._fd = {}
        self.connections = {}

        self._sch = ScheduleDict()
        # id(task) -> object whose handle_error/handle_close run if it fails
        self._schEH = {}
        if self.schMT:
            self._schLock = threading.Lock()
        else:
            self._schLock = WithNoop

        self.TrustedForwarders = ()

        if self.waker:
            (r, w) = os.pipe()
            o = _Waker(self, r)
            self.register_socket(r, o)
            # From here on, ``waker`` holds the pipe's write end.
            self.waker = w

    def register_socket(self, fd, o, eventmask = EPOLL_READ):
        """Watch ``fd`` with epoll and dispatch its events to handler ``o``."""
        self._epoll.register(fd, eventmask)
        self._fd[fd] = o

    def register_socket_m(self, fd, eventmask):
        """Change the event mask of an already-registered ``fd``.

        Raises socket.error, carrying the original errno, if epoll rejects it.
        """
        try:
            self._epoll.modify(fd, eventmask)
        except IOError as e:
            raise socket.error(*e.args) from e

    def unregister_socket(self, fd):
        """Stop watching ``fd`` and forget its handler.

        Raises KeyError if ``fd`` has no handler, or socket.error (carrying
        the original errno) if epoll rejects the removal.
        """
        del self._fd[fd]
        try:
            self._epoll.unregister(fd)
        except IOError as e:
            raise socket.error(*e.args) from e

    def schedule(self, task, startTime, errHandler=None):
        """Run callable ``task`` at ``startTime``; return ``task``."""
        with self._schLock:
            self._sch[task] = startTime
            if errHandler:
                self._schEH[id(task)] = errHandler
        return task

    def rmSchedule(self, task):
        """Cancel ``task`` (raises KeyError if it is not scheduled)."""
        with self._schLock:
            del self._sch[task]
            k = id(task)
            if k in self._schEH:
                del self._schEH[k]

    def pre_schedule(self):
        """Hook run at the start of every loop iteration."""
        pass

    def wakeup(self):
        """Interrupt a blocked poll in serve_forever() via the waker pipe."""
        if not self.waker:
            raise NotImplementedError('Class `%s\' did not enable waker' % (self.__class__.__name__))
        os.write(self.waker, b'\1')  # to break out of the epoll

    def final_init(self):
        """Hook run once when serve_forever() starts."""
        pass

    def boot_all(self):
        """Boot every current connection, ignoring individual failures."""
        # Snapshot first: booting a connection removes it from the dict.
        conns = tuple(self.connections.values())
        for c in conns:
            tryErr(lambda: c.boot())

    def _runDueTasks(self, timeNow):
        """Run every scheduled task due at ``timeNow``.

        Returns the number of seconds until the next task, or -1 if the
        schedule is empty.
        """
        if not len(self._sch):
            return -1
        while True:
            with self._schLock:
                if not len(self._sch):
                    return -1
                timeNext = self._sch.nextTime()
                if timeNow < timeNext:
                    return timeNext - timeNow
                f = self._sch.shift()
                # Claim the error handler under the same lock that
                # schedule()/rmSchedule() hold while maintaining _schEH.
                EH = self._schEH.pop(id(f), None)
            try:
                f()
            except socket.error:
                if EH: tryErr(EH.handle_error)
            except:
                # A failing task must not stop the loop; log and drop its owner.
                self.logger.error(traceback.format_exc())
                if EH: tryErr(EH.handle_close)

    def _dispatchEvents(self, events):
        """Deliver epoll ``(fd, eventmask)`` pairs to the registered handlers."""
        for (fd, e) in events:
            o = self._fd.get(fd)
            if o is None:
                # e.g. unregistered by an earlier handler in this same batch
                continue
            self.lastHandler = o
            try:
                if e & EPOLL_READ:
                    o.handle_read()
                if e & EPOLL_WRITE:
                    o.handle_write()
            except socket.error:
                tryErr(o.handle_error)
            except:
                self.logger.error(traceback.format_exc())
                tryErr(o.handle_error)

    def serve_forever(self):
        """Run the event loop until ``self.keepgoing`` becomes false.

        Each iteration calls pre_schedule(), runs due scheduled tasks, polls
        epoll (blocking for at most one second) and dispatches events.
        ``running`` and ``doing`` are reset however the loop exits.
        """
        self.running = True
        try:
            self.final_init()
            while self.keepgoing:
                self.doing = 'pre-schedule'
                self.pre_schedule()
                self.doing = 'schedule'
                timeNow = time()
                timeout = self._runDueTasks(timeNow)
                if self.lastidle < timeNow - 1:
                    # No idle poll seen for over a second: don't block.
                    timeout = 0
                elif timeout < 0 or timeout > 1:
                    # Wake at least once per second so keepgoing and
                    # pre_schedule() are rechecked regularly.
                    timeout = 1

                self.doing = 'poll'
                try:
                    events = self._epoll.poll(timeout=timeout)
                except (IOError, select.error):
                    continue
                except:
                    self.logger.error(traceback.format_exc())
                    continue
                self.doing = 'events'
                if not events:
                    self.lastidle = time()
                self._dispatchEvents(events)
        finally:
            self.doing = None
            self.running = False