Bugfix: Stratum: Replies should not be sent if request id is null
[bitcoin:eloipool.git] / networkserver.py
1 # Eloipool - Python Bitcoin pool server
2 # Copyright (C) 2011-2012  Luke Dashjr <luke-jr+eloipool@utopios.org>
3 #
4 # This program is free software: you can redistribute it and/or modify
5 # it under the terms of the GNU Affero General Public License as
6 # published by the Free Software Foundation, either version 3 of the
7 # License, or (at your option) any later version.
8 #
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 # GNU Affero General Public License for more details.
13 #
14 # You should have received a copy of the GNU Affero General Public License
15 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
16
17 import asynchat
18 import logging
19 import os
20 import select
21 import socket
22 from time import time
23 import traceback
24 from util import ScheduleDict, tryErr
25
# Event bits that the server loop dispatches to a handler's handle_read():
# readable data, priority data, and the error / hang-up conditions.
26 EPOLL_READ = select.EPOLLIN | select.EPOLLPRI | select.EPOLLERR | select.EPOLLHUP
# Event bit that the server loop dispatches to a handler's handle_write();
# SocketHandler.push() adds it while output is queued and handle_write()
# drops it again once the queue is empty.
27 EPOLL_WRITE = select.EPOLLOUT
28
class SocketHandler:
    """Per-connection handler driven by the server's epoll loop.

    Inbound bytes are accumulated in ``ac_in_buffer`` and split on the
    current terminator (the splitting logic mirrors ``asynchat.async_chat``);
    outbound bytes that could not be sent immediately are queued in ``wbuf``.

    Several names used below are not defined in this class and must be
    supplied elsewhere (for example through ``_register`` or a subclass):
    ``found_terminator``, ``logger``, ``use_encoding`` and ``encoding``.
    """

    ac_in_buffer_size = 4096
    ac_out_buffer_size = 4096

    def handle_close(self):
        """Cancel the pending task, discard unsent output and close."""
        self.changeTask(None)
        # None (rather than b'') marks the connection as dead: close() will
        # not wait for a flush and later write attempts are ignored.
        self.wbuf = None
        self.close()

    def handle_error(self):
        """Log the active traceback at debug level and close the connection."""
        self.logger.debug(traceback.format_exc())
        self.handle_close()

    def handle_read(self):
        """Read one chunk from the socket and feed it to the line splitter."""
        try:
            data = self.recv(self.ac_in_buffer_size)
        except socket.error:
            self.handle_error()
            return

        if self.closeme:
            # All input is ignored from sockets we have "closed"
            return

        if isinstance(data, str) and self.use_encoding:
            # Encode the received text itself (previously the ``str`` type was
            # passed to bytes(), which always raised TypeError).
            data = bytes(data, self.encoding)
        self.ac_in_buffer = self.ac_in_buffer + data

        # Kept on the server so the most recent input can be inspected.
        self.server.lastReadbuf = self.ac_in_buffer

        self.handle_readbuf()

    collect_incoming_data = asynchat.async_chat._collect_incoming_data
    get_terminator = asynchat.async_chat.get_terminator
    set_terminator = asynchat.async_chat.set_terminator

    def handle_readbuf(self):
        """Consume ``ac_in_buffer``, calling ``found_terminator`` per message.

        The terminator may be falsy (collect everything), an int (collect
        exactly that many bytes), a bytes value, or a sequence of alternative
        bytes terminators.
        """
        while self.ac_in_buffer:
            lb = len(self.ac_in_buffer)
            terminator = self.get_terminator()
            if not terminator:
                # no terminator, collect it all
                self.collect_incoming_data(self.ac_in_buffer)
                self.ac_in_buffer = b''
            elif isinstance(terminator, int):
                # numeric terminator
                n = terminator
                if lb < n:
                    self.collect_incoming_data(self.ac_in_buffer)
                    self.ac_in_buffer = b''
                    self.terminator = self.terminator - lb
                else:
                    self.collect_incoming_data(self.ac_in_buffer[:n])
                    self.ac_in_buffer = self.ac_in_buffer[n:]
                    self.terminator = 0
                    self.found_terminator()
            else:
                # 3 cases:
                # 1) end of buffer matches terminator exactly:
                #    collect data, transition
                # 2) end of buffer matches some prefix:
                #    collect data to the prefix
                # 3) end of buffer does not match any prefix:
                #    collect data
                # NOTE: this supports multiple different terminators, but
                #       NOT ones that are prefixes of others...
                if isinstance(self.ac_in_buffer, type(terminator)):
                    # A single terminator: normalise to a 1-tuple
                    terminator = (terminator,)
                termidx = tuple(map(self.ac_in_buffer.find, terminator))
                try:
                    # Earliest position at which any terminator occurs
                    index = min(x for x in termidx if x >= 0)
                except ValueError:
                    index = -1
                if index != -1:
                    # we found the terminator
                    if index > 0:
                        # don't bother reporting the empty string (source of subtle bugs)
                        self.collect_incoming_data(self.ac_in_buffer[:index])
                    specific_terminator = terminator[termidx.index(index)]
                    terminator_len = len(specific_terminator)
                    self.ac_in_buffer = self.ac_in_buffer[index + terminator_len:]
                    # This does the Right Thing if the terminator is changed here.
                    self.found_terminator()
                else:
                    # check for a prefix of the terminator
                    termidx = tuple(map(lambda a: asynchat.find_prefix_at_end(self.ac_in_buffer, a), terminator))
                    index = max(termidx)
                    if index:
                        if index != lb:
                            # we found a prefix, collect up to the prefix
                            self.collect_incoming_data(self.ac_in_buffer[:-index])
                            self.ac_in_buffer = self.ac_in_buffer[-index:]
                        # Wait for more input before deciding about the prefix
                        break
                    else:
                        # no prefix, collect it all
                        self.collect_incoming_data(self.ac_in_buffer)
                        self.ac_in_buffer = b''

    def push(self, data):
        """Send ``data``, queueing whatever the socket does not accept now.

        Data pushed after the connection was torn down (``wbuf`` is None) is
        silently dropped, matching the guard in ``handle_write``.
        """
        if self.wbuf is None:
            # Socket was already closed; nothing can be sent any more
            return
        if not len(self.wbuf):
            # Try to send as much as we can immediately
            try:
                bs = self.socket.send(data)
            except Exception:
                # Chances are we'll fail later, but anyway...
                bs = 0
            data = data[bs:]
            if not len(data):
                return
        self.wbuf += data
        # Ask the event loop to tell us when the socket becomes writable
        self.server.register_socket_m(self.fd, EPOLL_READ | EPOLL_WRITE)

    def handle_timeout(self):
        """Scheduled callback: close the connection."""
        self.close()

    def handle_write(self):
        """Flush as much of ``wbuf`` as the socket accepts."""
        if self.wbuf is None:
            # Socket was just closed by remote peer
            return
        bs = self.socket.send(self.wbuf)
        self.wbuf = self.wbuf[bs:]
        if not len(self.wbuf):
            if self.closeme:
                # A close() was deferred until the queue drained
                self.close()
                return
            # Nothing left to write: stop listening for writability
            self.server.register_socket_m(self.fd, EPOLL_READ)

    recv = asynchat.async_chat.recv

    def close(self):
        """Close the socket, deferring until queued output has been sent."""
        if self.wbuf:
            # handle_write() completes the close once wbuf is empty
            self.closeme = True
            return
        if self.fd == -1:
            # Already closed
            return
        self.server.unregister_socket(self.fd)
        self.socket.close()
        self.fd = -1

    def changeTask(self, f, t = None):
        """Replace this handler's scheduled task with ``f`` at time ``t``.

        A falsy ``f`` only cancels the current task.
        """
        tryErr(self.server.rmSchedule, self._Task, IgnoredExceptions=KeyError)
        if f:
            self._Task = self.server.schedule(f, t, errHandler=self)

    def __init__(self, server, sock, addr):
        """Register ``sock`` with ``server`` and arm a 15 second timeout."""
        self.ac_in_buffer = b''
        self.incoming = []
        self.wbuf = b''
        self.closeme = False
        self.server = server
        self.socket = sock
        self.addr = addr
        self._Task = None
        self.fd = sock.fileno()
        server.register_socket(self.fd, self)
        self.changeTask(self.handle_timeout, time() + 15)

    @classmethod
    def _register(cls, scls):
        """Copy the public attributes of ``scls`` onto ``cls``.

        ``final_init`` is chained instead of replaced: the new attribute calls
        the existing ``cls.final_init`` and then ``scls.final_init``.  Names
        starting with an underscore are skipped.
        """
        for a in dir(scls):
            if a == 'final_init':
                # Defaults bind the two callables now, not at call time
                f = lambda self, x=getattr(cls, a), y=getattr(scls, a): (x(self), y(self))
                setattr(cls, a, f)
                continue
            if a[0] == '_':
                continue
            setattr(cls, a, getattr(scls, a))
197
class NetworkListener:
    """Listening socket that creates a request handler per connection."""

    logger = logging.getLogger('SocketListener')

    def __init__(self, server, server_address, address_family = socket.AF_INET6):
        """Bind ``server_address`` and register the listener with ``server``.

        Set-up runs through ``tryErr`` with this class's logger, using the
        address as the error message.
        """
        self.server = server
        self.server_address = server_address
        self.address_family = address_family
        tryErr(self.setup_socket, server_address, Logger=self.logger, ErrorMsg=server_address)

    def _makebind_py(self, server_address):
        """Create and bind a non-blocking stream socket.

        The socket is closed again if any step fails, so a failed bind does
        not leak a file descriptor.
        """
        sock = socket.socket(self.address_family, socket.SOCK_STREAM)
        try:
            sock.setblocking(0)
            try:
                sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            except socket.error:
                # SO_REUSEADDR is only a convenience; carry on without it
                pass
            sock.bind(server_address)
        except BaseException:
            sock.close()
            raise
        return sock

    def _makebind_su(self, server_address):
        """Obtain a bound IPv6 socket through the ``bindservice`` helper.

        Raises NotImplementedError for any address family but AF_INET6.
        """
        if self.address_family != socket.AF_INET6:
            raise NotImplementedError

        from bindservice import bindservice
        (node, service) = server_address
        if not node: node = ''
        if not service: service = ''
        fd = bindservice(str(node), str(service))
        try:
            sock = socket.fromfd(fd, socket.AF_INET6, socket.SOCK_STREAM)
        finally:
            # socket.fromfd() duplicates the descriptor, so the one handed
            # back by bindservice must be closed here or it is leaked.
            os.close(fd)
        sock.setblocking(0)
        return sock

    def _makebind(self, *a, **ka):
        """Bind normally, falling back to ``_makebind_su`` on failure.

        If the fallback fails too, the error from the first attempt is the
        one that propagates.
        """
        try:
            return self._makebind_py(*a, **ka)
        except Exception:
            try:
                return self._makebind_su(*a, **ka)
            except Exception:
                pass
            # Re-raise the original _makebind_py failure
            raise

    def setup_socket(self, server_address):
        """Bind, start listening and register with the server's event loop."""
        sock = self._makebind(server_address)
        try:
            sock.listen(100)
            self.server.register_socket(sock.fileno(), self)
        except BaseException:
            # Do not leave a bound but unused socket behind
            sock.close()
            raise
        self.socket = sock

    def handle_read(self):
        """Accept one pending connection and hand it to a new handler."""
        server = self.server
        conn, addr = self.socket.accept()
        conn.setblocking(False)
        # The handler registers itself with the server in its constructor
        server.RequestHandlerClass(server, conn, addr)

    def handle_error(self):
        # Ignore errors... like socket closing on the queue
        pass
255
class _Waker:
    """Read side of the pipe used to interrupt the server's epoll wait."""

    def __init__(self, server, fd):
        """Remember the owning ``server`` and the readable pipe ``fd``."""
        self.server = server
        self.fd = fd
        self.logger = logging.getLogger('Waker for %s' % (server.__class__.__name__,))

    def handle_read(self):
        """Consume one wake-up byte from the pipe.

        An empty read (EOF) is logged as an error; it is not additionally
        reported as a wake-up, since nothing was read.
        """
        data = os.read(self.fd, 1)
        if not data:
            self.logger.error('Got EOF on socket')
            return
        self.logger.debug('Read wakeup')
267
class AsyncSocketServer:
    """Single-threaded epoll event loop with a simple task scheduler.

    Handlers are registered per file descriptor and must provide
    ``handle_read`` / ``handle_write`` / ``handle_error`` as needed.
    Subclasses may set the class attribute ``waker`` to a true value to get
    a pipe that ``wakeup`` can use to interrupt the poll.
    """

    logger = logging.getLogger('SocketServer')

    # Truthy in a subclass to enable the wake-up pipe; replaced by the
    # pipe's write descriptor in __init__.
    waker = False

    def __init__(self, RequestHandlerClass):
        """Create the epoll object, handler table and schedule."""
        if not hasattr(self, 'ServerName'):
            self.ServerName = 'Eloipool'

        self.RequestHandlerClass = RequestHandlerClass

        self.running = False
        self.keepgoing = True

        self._epoll = select.epoll()
        self._fd = {}  # fd -> handler object

        self._sch = ScheduleDict()  # task -> start time
        self._schEH = {}  # id(task) -> error handler

        self.TrustedForwarders = ()

        if self.waker:
            (r, w) = os.pipe()
            o = _Waker(self, r)
            self.register_socket(r, o)
            self.waker = w

    def register_socket(self, fd, o, eventmask = EPOLL_READ):
        """Watch ``fd`` for ``eventmask`` and route its events to ``o``."""
        self._epoll.register(fd, eventmask)
        self._fd[fd] = o

    def register_socket_m(self, fd, eventmask):
        """Change the event mask of an already registered ``fd``.

        Raises socket.error (carrying the original errno details) on failure.
        """
        try:
            self._epoll.modify(fd, eventmask)
        except IOError as e:
            raise socket.error(*e.args) from e

    def unregister_socket(self, fd):
        """Forget the handler for ``fd`` and stop watching it.

        Raises KeyError if ``fd`` is unknown, or socket.error (carrying the
        original errno details) if epoll refuses to unregister it.
        """
        del self._fd[fd]
        try:
            self._epoll.unregister(fd)
        except IOError as e:
            raise socket.error(*e.args) from e

    def schedule(self, task, startTime, errHandler=None):
        """Schedule callable ``task`` for ``startTime`` and return it.

        ``errHandler``, if given, has ``handle_error`` / ``handle_close``
        called when the task raises.
        """
        self._sch[task] = startTime
        if errHandler:
            self._schEH[id(task)] = errHandler
        return task

    def rmSchedule(self, task):
        """Remove ``task`` from the schedule along with its error handler."""
        del self._sch[task]
        self._schEH.pop(id(task), None)

    def pre_schedule(self):
        """Hook run at the start of every loop iteration; no-op by default."""
        pass

    def wakeup(self):
        """Interrupt a blocking poll by writing to the wake-up pipe."""
        if not self.waker:
            raise NotImplementedError('Class `%s\' did not enable waker' % (self.__class__.__name__))
        os.write(self.waker, b'\1')  # to break out of the epoll

    def final_init(self):
        """Hook run once when ``serve_forever`` starts; no-op by default."""
        pass

    def serve_forever(self):
        """Run due tasks and dispatch socket events until ``keepgoing`` is false.

        ``self.doing`` records the current phase of the loop.
        """
        self.running = True
        self.final_init()
        while self.keepgoing:
            self.doing = 'pre-schedule'
            self.pre_schedule()
            self.doing = 'schedule'
            if len(self._sch):
                timeNow = time()
                while True:
                    timeNext = self._sch.nextTime()
                    if timeNow < timeNext:
                        # Next task is in the future: sleep until then
                        timeout = timeNext - timeNow
                        break
                    f = self._sch.shift()
                    EH = self._schEH.pop(id(f), None)
                    try:
                        f()
                    except socket.error:
                        if EH: tryErr(EH.handle_error)
                    except:
                        self.logger.error(traceback.format_exc())
                        if EH: tryErr(EH.handle_close)
                    if not len(self._sch):
                        timeout = -1
                        break
            else:
                # Nothing scheduled: block until a socket event arrives
                timeout = -1

            self.doing = 'poll'
            try:
                events = self._epoll.poll(timeout=timeout)
            except (IOError, select.error):
                continue
            except:
                self.logger.error(traceback.format_exc())
                # Without this, `events` would be undefined (first pass) or
                # a stale list from the previous pass.
                continue
            self.doing = 'events'
            for (fd, e) in events:
                o = self._fd.get(fd)
                if o is None:
                    # Unregistered by a handler earlier in this same batch
                    continue
                self.lastHandler = o
                try:
                    if e & EPOLL_READ:
                        o.handle_read()
                    if e & EPOLL_WRITE:
                        o.handle_write()
                except socket.error:
                    tryErr(o.handle_error)
                except:
                    self.logger.error(traceback.format_exc())
                    tryErr(o.handle_error)
        self.doing = None
        self.running = False