Bugfix: Clear SocketHandler._Task when removing it
[bitcoin:eloipool.git] / jsonrpcserver.py
1 # Eloipool - Python Bitcoin pool server
2 # Copyright (C) 2011-2012  Luke Dashjr <luke-jr+eloipool@utopios.org>
3 #
4 # This program is free software: you can redistribute it and/or modify
5 # it under the terms of the GNU Affero General Public License as
6 # published by the Free Software Foundation, either version 3 of the
7 # License, or (at your option) any later version.
8 #
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 # GNU Affero General Public License for more details.
13 #
14 # You should have received a copy of the GNU Affero General Public License
15 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
16
17 import asynchat
18 from base64 import b64decode
19 from binascii import a2b_hex, b2a_hex
20 from datetime import datetime
21 from email.utils import formatdate
22 import json
23 import logging
24 try:
25         import midstate
26         assert midstate.SHA256(b'This is just a test, ignore it. I am making it over 64-bytes long.')[:8] == (0x755f1a94, 0x999b270c, 0xf358c014, 0xfd39caeb, 0x0dcc9ebc, 0x4694cd1a, 0x8e95678e, 0x75fac450)
27 except:
28         logging.getLogger('jsonrpcserver').warning('Error importing \'midstate\' module; work will not provide midstates')
29         midstate = None
30 import os
31 import re
32 import select
33 import socket
34 from struct import pack
35 import threading
36 from time import mktime, time, sleep
37 import traceback
38 from util import RejectedShare, ScheduleDict, swap32, tryErr
39
class WithinLongpoll(BaseException):
	"""Signal that the current request has been parked as a longpoll.

	Raised by JSONRPCHandler.doLongpoll().  It derives from BaseException
	rather than Exception, and handle_request() explicitly re-raises it ahead
	of its catch-all, so that found_terminator() can catch it and skip
	resetting the request state.
	"""
42
# epoll event masks used when registering/modifying descriptors: "readable"
# also covers priority data, errors and hang-ups so those wake handle_read().
EPOLL_READ = (select.EPOLLHUP | select.EPOLLERR
              | select.EPOLLPRI | select.EPOLLIN)
EPOLL_WRITE = select.EPOLLOUT
45
# TODO: keepalive/close
# Module-wide set (stored as dict keys) of block headers already handed out,
# with the 4 bytes at offset 68..72 removed; doJSON_getwork() consults it to
# refuse issuing duplicate work.  Nothing visible in this file ever removes
# entries, so it grows for the life of the process.
_CheckForDupesHACK = dict()
class JSONRPCHandler:
	"""Per-connection HTTP/1.1 JSON-RPC (getwork) request handler.

	One instance wraps one accepted socket.  JSONRPCServer's epoll loop calls
	handle_read()/handle_write(); request timeouts, longpoll keepalives and
	deferred longpoll wakeups run as tasks scheduled through changeTask().
	Input framing borrows get_terminator/set_terminator/recv from
	asynchat.async_chat and re-implements its handle_read().
	"""
	HTTPStatus = {
		200: 'OK',
		401: 'Unauthorized',
		404: 'Not Found',
		405: 'Method Not Allowed',
		500: 'Internal Server Error',
	}
	
	# Headers for longpoll responses; the None value suppresses the
	# X-Long-Polling header that sendReply() would otherwise add on a 200.
	LPHeaders = {
		'X-Long-Polling': None,
	}
	
	logger = logging.getLogger('JSONRPCHandler')
	
	ac_in_buffer_size = 4096
	ac_out_buffer_size = 4096
	# Same defaults as asynchat.async_chat; only consulted by handle_read()
	# if recv() ever returns str instead of bytes.
	use_encoding = 0
	encoding = 'latin-1'
	
	def sendReply(self, status=200, body=b'', headers=None):
		"""Queue an HTTP/1.1 response for sending.

		status: HTTP status code; the reason phrase comes from HTTPStatus
		  (falling back to 'Eligius' for unknown codes).
		body: response bytes, or None to start a chunked response whose
		  chunks are pushed separately (used for longpolling).
		headers: optional mapping of extra headers, copied before use;
		  entries whose value is None are not sent.
		"""
		buf = "HTTP/1.1 %d %s\r\n" % (status, self.HTTPStatus.get(status, 'Eligius'))
		headers = dict(headers) if headers else {}
		# Format the current epoch time directly: converting a naive local
		# timetuple back with mktime() is ambiguous around DST transitions.
		headers['Date'] = formatdate(timeval=time(), localtime=False, usegmt=True)
		headers.setdefault('Server', 'Eloipool')
		if body is None:
			headers.setdefault('Transfer-Encoding', 'chunked')
			body = b''
		else:
			headers['Content-Length'] = len(body)
		if status == 200:
			headers.setdefault('Content-Type', 'application/json')
			headers.setdefault('X-Long-Polling', '/LP')
			headers.setdefault('X-Roll-NTime', 'expire=120')
		for k, v in headers.items():
			if v is None: continue
			buf += "%s: %s\r\n" % (k, v)
		buf += "\r\n"
		buf = buf.encode('utf8')
		buf += body
		self.push(buf)
	
	def doError(self, reason = ''):
		"""Send a 500 response whose body is the UTF-8 encoded reason."""
		return self.sendReply(500, reason.encode('utf8'))
	
	def doHeader_authorization(self, value):
		"""Take the username from an HTTP Basic Authorization header.

		Only the part before the first ':' is kept; the password is not
		checked here.  A malformed header just leaves Username unset, so
		handle_request() answers with a single 401 challenge (previously a
		500 was sent here *and* a 401 by handle_request for the same request).
		"""
		value = value.split(b' ')
		if len(value) != 2 or value[0] != b'Basic':
			return
		try:
			value = b64decode(value[1])
			self.Username = value.split(b':')[0].decode('utf8')
		except ValueError:
			# binascii.Error (bad base64) and UnicodeDecodeError are both
			# ValueError subclasses; treat them as missing credentials.
			return
	
	def doHeader_content_length(self, value):
		"""Record the request body length from Content-Length."""
		self.CL = int(value)
	
	def doHeader_user_agent(self, value):
		"""Record the User-Agent and set per-client quirks from it.

		Old phoenix versions (per the version test below) get the 'NELH'
		quirk, which disables early longpoll headers.
		"""
		self.reqinfo['UA'] = value
		try:
			if value[:9] == b'phoenix/v':
				v = tuple(map(int, value[9:].split(b'.')))
				if v[0] < 2 and v[1] < 8 and v[2] < 1:
					self.quirks['NELH'] = None
		except (ValueError, IndexError):
			# Unparseable version string: apply no quirks.
			pass
	
	def doHeader_x_minimum_wait(self, value):
		"""Record the client's requested minimum longpoll wait (seconds)."""
		self.reqinfo['MinWait'] = int(value)
	
	def doHeader_x_mining_extensions(self, value):
		"""Record the space-separated, lower-cased mining extensions list."""
		self.extensions = value.decode('ascii').lower().split(' ')
	
	def doAuthenticate(self):
		"""Send a 401 challenge for HTTP Basic authentication."""
		self.sendReply(401, headers={'WWW-Authenticate': 'Basic realm="Eligius"'})
	
	def doLongpoll(self):
		"""Park the current request as a longpoll and raise WithinLongpoll.

		Unless the client has the NELH quirk, the response headers and the
		first chunk ('{') are sent immediately and a keepalive chunk is
		scheduled every 45 seconds.  The client is registered in
		server._LPClients to be answered by wakeLongpoll() later.
		"""
		timeNow = time()
		
		self._LP = True
		if 'NELH' not in self.quirks:
			# [NOT No] Early Longpoll Headers
			self.sendReply(200, body=None, headers=self.LPHeaders)
			self.push(b"1\r\n{\r\n")
			self.changeTask(self._chunkedKA, timeNow + 45)
		else:
			self.changeTask(None)
		
		waitTime = self.reqinfo.get('MinWait', 15)  # TODO: make default configurable
		self.waitTime = waitTime + timeNow
		
		totfromme = self.LPTrack()
		self.server._LPClients[id(self)] = self
		self.logger.debug("New LP client; %d total; %d from %s" % (len(self.server._LPClients), totfromme, self.addr[0]))
		
		raise WithinLongpoll
	
	def _chunkedKA(self):
		"""Send a one-space chunk and reschedule itself 45 seconds later."""
		# Keepalive via chunked transfer encoding
		self.push(b"1\r\n \r\n")
		self.changeTask(self._chunkedKA, time() + 45)
	
	def LPTrack(self):
		"""Count one more longpoll from this client's IP; return the new count."""
		tracking = self.server.LPTracking
		myip = self.addr[0]
		tracking[myip] = tracking.get(myip, 0) + 1
		return tracking[myip]
	
	def LPUntrack(self):
		"""Undo one LPTrack() for this client's IP."""
		self.server.LPTracking[self.addr[0]] -= 1
	
	def cleanupLP(self):
		"""Drop this client's pending longpoll, if any (on connection close)."""
		# Called when the connection is closed
		if not self._LP:
			return
		self.changeTask(None)
		try:
			del self.server._LPClients[id(self)]
		except KeyError:
			pass
		self.LPUntrack()
	
	def wakeLongpoll(self):
		"""Answer the pending longpoll with fresh work.

		If the client's minimum wait (X-Minimum-Wait, default 15s) has not
		elapsed yet, reschedule for that time instead.
		"""
		now = time()
		if now < self.waitTime:
			self.changeTask(self.wakeLongpoll, self.waitTime)
			return
		else:
			self.changeTask(None)
		
		self.LPUntrack()
		# This longpoll is being answered now.  Clear the flag so that, if
		# anything below raises and the connection is torn down,
		# cleanupLP() does not untrack this client a second time.
		self._LP = False
		
		rv = self.doJSON_getwork()
		rv['submitold'] = True
		rv = {'id': 1, 'error': None, 'result': rv}
		rv = json.dumps(rv)
		rv = rv.encode('utf8')
		if 'NELH' not in self.quirks:
			rv = rv[1:]  # strip the '{' we already sent
			self.push(('%x' % len(rv)).encode('utf8') + b"\r\n" + rv + b"\r\n0\r\n\r\n")
		else:
			self.sendReply(200, body=rv, headers=self.LPHeaders)
		
		self.reset_request()
	
	def doJSON(self, data):
		"""Dispatch a JSON-RPC request body to the matching doJSON_<method>.

		A None result means the method has already sent its own reply.
		"""
		# TODO: handle JSON errors
		data = data.decode('utf8')
		data = json.loads(data)
		method = 'doJSON_' + str(data['method']).lower()
		if not hasattr(self, method):
			return self.doError('No such method')
		# TODO: handle errors as JSON-RPC
		self._JSONHeaders = {}
		rv = getattr(self, method)(*tuple(data.get('params', ())))
		if rv is None:
			return
		rv = {'id': data['id'], 'error': None, 'result': rv}
		rv = json.dumps(rv)
		rv = rv.encode('utf8')
		return self.sendReply(200, rv, headers=self._JSONHeaders)
	
	getwork_rv_template = {
		'data': '000000800000000000000000000000000000000000000000000000000000000000000000000000000000000080020000',
		'target': 'ffffffffffffffffffffffffffffffffffffffffffffffffffffffff00000000',
		'hash1': '00000000000000000000000000000000000000000000000000000000000000000000008000000000000000000000000000000000000000000000000000010000',
	}
	def doJSON_getwork(self, data=None):
		"""Return a getwork result for self.Username.

		A non-None data argument is treated as a share submission instead.
		Raises (via server.RaiseRedFlags) if the same header would be issued
		twice.  A midstate is included when the midstate module is available
		and the client did not list the 'midstate' extension.
		"""
		if data is not None:
			return self.doJSON_submitwork(data)
		rv = dict(self.getwork_rv_template)
		hdr = self.server.getBlockHeader(self.Username)
		
		# FIXME: this assumption breaks with internal rollntime
		# NOTE: noncerange needs to set nonce to start value at least
		global _CheckForDupesHACK
		uhdr = hdr[:68] + hdr[72:]
		if uhdr in _CheckForDupesHACK:
			raise self.server.RaiseRedFlags(RuntimeError('issuing duplicate work'))
		_CheckForDupesHACK[uhdr] = None
		
		data = b2a_hex(swap32(hdr)).decode('utf8') + rv['data']
		# TODO: endian shuffle etc
		rv['data'] = data
		if midstate and 'midstate' not in self.extensions:
			h = midstate.SHA256(hdr)[:8]
			rv['midstate'] = b2a_hex(pack('<LLLLLLLL', *h)).decode('ascii')
		return rv
	
	def doJSON_submitwork(self, datax):
		"""Submit a share; return True if accepted.

		On RejectedShare, set the X-Reject-Reason response header and return
		False.
		"""
		data = swap32(a2b_hex(datax))[:80]
		share = {
			'data': data,
			'_origdata' : datax,
			'username': self.Username,
			'remoteHost': self.addr[0],
		}
		try:
			self.server.receiveShare(share)
		except RejectedShare as rej:
			self._JSONHeaders['X-Reject-Reason'] = str(rej)
			return False
		return True
	
	def doJSON_setworkaux(self, k, hexv = None):
		"""Set server.aux[k] from hex (or delete it when hexv is empty).

		Only server.SecretUser may call this; anyone else gets a 401.
		"""
		if self.Username != self.server.SecretUser:
			self.doAuthenticate()
			return None
		if hexv:
			self.server.aux[k] = a2b_hex(hexv)
		else:
			del self.server.aux[k]
		return True
	
	def handle_close(self):
		"""Tear down the connection, discarding any unsent output."""
		self.cleanupLP()
		# With no pending output, close() unregisters and closes immediately.
		self.wbuf = None
		self.close()
	
	def handle_request(self):
		"""Route a fully-received request (auth, method, path, then JSON)."""
		if not self.Username:
			return self.doAuthenticate()
		if self.method not in (b'GET', b'POST'):
			return self.sendReply(405)
		if self.path not in (b'/', b'/LP', b'/LP/'):
			return self.sendReply(404)
		try:
			if self.path[:3] == b'/LP':
				return self.doLongpoll()
			data = b''.join(self.incoming)
			return self.doJSON(data)
		except socket.error:
			raise
		except WithinLongpoll:
			raise
		except:
			self.logger.error(traceback.format_exc())
			return self.doError('uncaught error')
	
	def parse_headers(self, hs):
		"""Parse the request line and headers, dispatching to doHeader_* methods.

		On a malformed request line the connection is closed and self.path
		is left as None.
		"""
		self.CL = None
		self.Username = None
		self.method = None
		self.path = None
		hs = re.split(br'\r?\n', hs)
		data = hs.pop(0).split(b' ')
		try:
			self.method = data[0]
			self.path = data[1]
		except IndexError:
			self.close()
			return
		self.extensions = []
		self.reqinfo = {}
		self.quirks = {}
		while True:
			try:
				data = hs.pop(0)
			except IndexError:
				break
			data = tuple(map(lambda a: a.strip(), data.split(b':', 1)))
			method = 'doHeader_' + data[0].decode('ascii').lower()
			if hasattr(self, method):
				getattr(self, method)(data[1])
	
	def found_terminator(self):
		"""Handle a complete header block or request body.

		While reading headers, the collected bytes (minus leading blank
		lines) are parsed; if a Content-Length was given the terminator
		switches to that byte count to collect the body.  Otherwise, or once
		the body has arrived, the request is handled and, unless it became
		a longpoll, the handler resets for the next request.
		"""
		if self.reading_headers:
			inbuf = b"".join(self.incoming)
			self.incoming = []
			m = re.match(br'^[\r\n]+', inbuf)
			if m:
				inbuf = inbuf[len(m.group(0)):]
			if not inbuf:
				return
			
			self.reading_headers = False
			self.parse_headers(inbuf)
			if self.path is None:
				# Malformed request line: parse_headers() already closed the
				# connection.  Drop buffered input so handle_read() stops
				# feeding it, and do not try to respond.
				self.ac_in_buffer = b''
				return
			if self.CL:
				self.set_terminator(self.CL)
				return
		
		self.set_terminator(None)
		try:
			self.handle_request()
			self.reset_request()
		except WithinLongpoll:
			pass
	
	def handle_error(self):
		"""Log the current exception at debug level and close the connection."""
		self.logger.debug(traceback.format_exc())
		self.handle_close()
	
	get_terminator = asynchat.async_chat.get_terminator
	set_terminator = asynchat.async_chat.set_terminator
	
	def handle_read(self):
		"""Read available data and feed complete units to found_terminator().

		Adapted from asynchat.async_chat.handle_read(), extended so the
		terminator may also be a tuple of alternative byte strings.
		"""
		try:
			data = self.recv (self.ac_in_buffer_size)
		except socket.error:
			self.handle_error()
			return
		
		if self.closeme:
			# All input is ignored from sockets we have "closed"
			return
		
		if isinstance(data, str) and self.use_encoding:
			# Encode the received text (this used to pass the str type itself).
			data = bytes(data, self.encoding)
		self.ac_in_buffer = self.ac_in_buffer + data
		
		# Continue to search for self.terminator in self.ac_in_buffer,
		# while calling self.collect_incoming_data.  The while loop
		# is necessary because we might read several data+terminator
		# combos with a single recv(4096).
		
		while self.ac_in_buffer:
			lb = len(self.ac_in_buffer)
			terminator = self.get_terminator()
			if not terminator:
				# no terminator, collect it all
				self.collect_incoming_data (self.ac_in_buffer)
				self.ac_in_buffer = b''
			elif isinstance(terminator, int):
				# numeric terminator
				n = terminator
				if lb < n:
					self.collect_incoming_data (self.ac_in_buffer)
					self.ac_in_buffer = b''
					self.terminator = self.terminator - lb
				else:
					self.collect_incoming_data (self.ac_in_buffer[:n])
					self.ac_in_buffer = self.ac_in_buffer[n:]
					self.terminator = 0
					self.found_terminator()
			else:
				# 3 cases:
				# 1) end of buffer matches terminator exactly:
				#    collect data, transition
				# 2) end of buffer matches some prefix:
				#    collect data to the prefix
				# 3) end of buffer does not match any prefix:
				#    collect data
				# NOTE: this supports multiple different terminators, but
				#       NOT ones that are prefixes of others...
				if isinstance(self.ac_in_buffer, type(terminator)):
					terminator = (terminator,)
				termidx = tuple(map(self.ac_in_buffer.find, terminator))
				try:
					index = min(x for x in termidx if x >= 0)
				except ValueError:
					index = -1
				if index != -1:
					# we found the terminator
					if index > 0:
						# don't bother reporting the empty string (source of subtle bugs)
						self.collect_incoming_data (self.ac_in_buffer[:index])
					specific_terminator = terminator[termidx.index(index)]
					terminator_len = len(specific_terminator)
					self.ac_in_buffer = self.ac_in_buffer[index+terminator_len:]
					# This does the Right Thing if the terminator is changed here.
					self.found_terminator()
				else:
					# check for a prefix of the terminator
					termidx = tuple(map(lambda a: asynchat.find_prefix_at_end (self.ac_in_buffer, a), terminator))
					index = max(termidx)
					if index:
						if index != lb:
							# we found a prefix, collect up to the prefix
							self.collect_incoming_data (self.ac_in_buffer[:-index])
							self.ac_in_buffer = self.ac_in_buffer[-index:]
						break
					else:
						# no prefix, collect it all
						self.collect_incoming_data (self.ac_in_buffer)
						self.ac_in_buffer = b''
	
	def reset_request(self):
		"""Prepare to read the next request's headers (15 second timeout)."""
		self.incoming = []
		self.set_terminator( (b"\n\n", b"\r\n\r\n") )
		self.reading_headers = True
		self._LP = False
		self.changeTask(self.handle_timeout, time() + 15)
	
	def collect_incoming_data(self, data):
		"""Append received data to self.incoming (asynchat's default)."""
		asynchat.async_chat._collect_incoming_data(self, data)
	
	def push(self, data):
		"""Queue bytes for sending and ask epoll for writability."""
		self.wbuf += data
		self.server.register_socket_m(self.fd, EPOLL_READ | EPOLL_WRITE)
	
	def handle_timeout(self):
		"""Close the connection when the scheduled timeout fires."""
		self.close()
	
	def handle_write(self):
		"""Send as much queued output as the socket accepts.

		Once the buffer drains, either finish a deferred close() or go back
		to waiting only for readability.
		"""
		if self.wbuf is None:
			# Socket was just closed by remote peer
			return
		bs = self.socket.send(self.wbuf)
		self.wbuf = self.wbuf[bs:]
		if not len(self.wbuf):
			if self.closeme:
				self.close()
				return
			self.server.register_socket_m(self.fd, EPOLL_READ)
	
	recv = asynchat.async_chat.recv
	
	def close(self):
		"""Close the socket, deferring until pending output has been sent."""
		if self.wbuf:
			self.closeme = True
			return
		self.server.unregister_socket(self.fd)
		self.changeTask(None)
		self.socket.close()
	
	def changeTask(self, f, t = None):
		"""Replace this connection's scheduled task with f at time t.

		A falsy f just cancels the current task.  A KeyError from
		rmSchedule (presumably: task already ran or none scheduled) is
		ignored.
		"""
		tryErr(self.server.rmSchedule, self._Task, IgnoredExceptions=KeyError)
		if f:
			self._Task = self.server.schedule(f, t, errHandler=self)
		else:
			self._Task = None
	
	def __init__(self, server, sock, addr):
		self.ac_in_buffer = b''
		self.wbuf = b''
		self.closeme = False
		self.server = server
		self.socket = sock
		self.addr = addr
		self._Task = None
		self.reset_request()
		self.fd = sock.fileno()
		server.register_socket(self.fd, self)
		self.changeTask(self.handle_timeout, time() + 15)
# parse_headers() looks handlers up as 'doHeader_' + the lower-cased header
# name, but HTTP header names contain hyphens, which cannot appear in a Python
# identifier.  Register hyphenated aliases for the underscore-named methods.
for _hdrName in ('content_length', 'user_agent', 'x_minimum_wait', 'x_mining_extensions'):
	setattr(JSONRPCHandler, 'doHeader_' + _hdrName.replace('_', '-'), getattr(JSONRPCHandler, 'doHeader_' + _hdrName))
del _hdrName
487
class JSONRPCListener:
	"""Listening socket that hands each accepted connection to a handler.

	It registers itself with the server's epoll loop.  serve_forever() calls
	handle_read() when a connection is pending, and calls handle_error() or
	handle_close() if that raises, so all three must exist here.
	"""
	logger = logging.getLogger('JSONRPCListener')
	
	def __init__(self, server, server_address):
		self.server = server
		tryErr(self.setup_socket, server_address, Logger=self.logger, ErrorMsg=server_address)
	
	def setup_socket(self, server_address):
		"""Create, bind and register a non-blocking IPv6 listening socket.

		If any step fails, the new socket is closed before the error
		propagates, so a failed bind does not leak a file descriptor.
		"""
		sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
		try:
			sock.setblocking(0)
			try:
				sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
			except socket.error:
				pass
			sock.bind(server_address)
			sock.listen(100)
			self.server.register_socket(sock.fileno(), self)
		except BaseException:
			sock.close()
			raise
		self.socket = sock
	
	def handle_read(self):
		"""Accept one pending connection and wrap it in RequestHandlerClass."""
		server = self.server
		conn, addr = self.socket.accept()
		# The handler registers itself with the server; no reference kept here.
		server.RequestHandlerClass(server, conn, addr)
	
	def handle_error(self):
		"""Log a socket.error raised by handle_read() and keep listening.

		This runs, for instance, when accept() fails.  Without this method,
		serve_forever() would hit an AttributeError while handling the error
		and stop.
		"""
		self.logger.error(traceback.format_exc())
	
	def handle_close(self):
		"""Ignore the close request serve_forever() makes on other errors.

		serve_forever() has already logged the traceback.  The listening
		socket must stay open, so nothing is closed here.
		"""
		pass
511
class _JSONRPCLongpoll:
	"""Read end of the pipe JSONRPCServer uses to interrupt epoll.

	JSONRPCServer.wakeLongpoll() writes a byte to the pipe.  Reading it here
	only drains the wakeup; the server loop then notices LPRequest and
	schedules the longpoll itself.
	"""
	logger = logging.getLogger('JSONRPCLongpoll')
	
	def __init__(self, server, fd):
		self.server = server
		self.fd = fd
	
	def handle_read(self):
		"""Consume one wakeup byte from the pipe."""
		# Woken up by longpoll request
		data = os.read(self.fd, 1)
		if not data:
			self.logger.error('Got EOF on socket')
		self.logger.debug('Read wakeup on longpoll pipe')
	
	def handle_error(self):
		"""Log a read error reported by serve_forever(); the pipe stays registered."""
		# serve_forever() dispatches socket.error (OSError) here; without this
		# method it would fail with AttributeError while handling the error.
		self.logger.error(traceback.format_exc())
	
	def handle_close(self):
		"""Ignore close requests; the wakeup pipe lives as long as the server."""
		pass
525
class JSONRPCServer:
	"""Epoll event loop serving JSON-RPC handlers.

	It also runs a simple time-ordered task scheduler and coordinates
	longpoll wakeups.  Handlers additionally use getBlockHeader,
	receiveShare, aux and RaiseRedFlags on the server, which are not
	defined here (presumably supplied by a subclass or the caller).
	"""
	def __init__(self, server_address=None, RequestHandlerClass=JSONRPCHandler):
		self.logger = logging.getLogger('JSONRPCServer')
		
		self.RequestHandlerClass = RequestHandlerClass
		
		self.SecretUser = None
		
		self._epoll = select.epoll()
		# fd -> object providing handle_read/handle_write/handle_error/handle_close
		self._fd = {}
		
		# Scheduled tasks and, keyed by id(task), their error handlers
		self._sch = ScheduleDict()
		self._schEH = {}
		
		# False, 1 (wakeup requested) or 2 (wakeup scheduled for later)
		self.LPRequest = False
		self._LPClients = {}
		self._LPWaitTime = time() + 15
		# Self-pipe: writing to it breaks serve_forever() out of epoll.poll()
		(r, w) = os.pipe()
		o = _JSONRPCLongpoll(self, r)
		self.register_socket(r, o)
		self._LPSock = w
		
		# client IP -> number of outstanding longpolls
		self.LPTracking = {}
		
		self._lo = []
		if server_address:
			JSONRPCListener(self, server_address)
	
	def register_socket(self, fd, o, eventmask = EPOLL_READ):
		"""Watch fd with epoll and dispatch its events to o."""
		self._epoll.register(fd, eventmask)
		self._fd[fd] = o
	
	def register_socket_m(self, fd, eventmask):
		"""Change the epoll event mask of fd; failures surface as socket.error."""
		try:
			self._epoll.modify(fd, eventmask)
		except IOError:
			raise socket.error
	
	def unregister_socket(self, fd):
		"""Stop watching fd; failures surface as socket.error."""
		del self._fd[fd]
		try:
			self._epoll.unregister(fd)
		except IOError:
			raise socket.error
	
	def schedule(self, task, startTime, errHandler=None):
		"""Run task() at startTime; errors go to errHandler's handlers. Returns task."""
		self._sch[task] = startTime
		if errHandler:
			self._schEH[id(task)] = errHandler
		return task
	
	def rmSchedule(self, task):
		"""Cancel a scheduled task (KeyError from _sch if it is not scheduled)."""
		del self._sch[task]
		k = id(task)
		if k in self._schEH:
			del self._schEH[k]
	
	def serve_forever(self):
		"""Run the event loop forever.

		Each iteration runs all due tasks, polls epoll until the next task
		is due (or indefinitely), dispatches events to the registered
		objects, then processes a pending longpoll request.
		"""
		while True:
			if len(self._sch):
				timeNow = time()
				while True:
					timeNext = self._sch.nextTime()
					if timeNow < timeNext:
						timeout = timeNext - timeNow
						break
					f = self._sch.shift()
					k = id(f)
					EH = None
					if k in self._schEH:
						EH = self._schEH[k]
						del self._schEH[k]
					try:
						f()
					except socket.error:
						if EH: tryErr(EH.handle_error)
					except:
						self.logger.error(traceback.format_exc())
						if EH: tryErr(EH.handle_close)
					if not len(self._sch):
						timeout = -1
						break
			else:
				timeout = -1
			
			try:
				events = self._epoll.poll(timeout=timeout)
			except select.error:
				continue
			for (fd, e) in events:
				o = self._fd[fd]
				try:
					if e & EPOLL_READ:
						o.handle_read()
					if e & EPOLL_WRITE:
						o.handle_write()
				except socket.error:
					tryErr(o.handle_error)
				except:
					self.logger.error(traceback.format_exc())
					tryErr(o.handle_close)
			if self.LPRequest == 1:
				self._LPsch()
	
	def wakeLongpoll(self):
		"""Request that all waiting longpoll clients be answered."""
		if self.LPRequest:
			self.logger.info('Ignoring longpoll attempt while another is waiting')
			return
		self.LPRequest = 1
		os.write(self._LPSock, b'\1')  # to break out of the epoll
	
	def _LPsch(self):
		"""Run the longpoll now, or schedule it once the minimum gap has passed."""
		now = time()
		if self._LPWaitTime > now:
			delay = self._LPWaitTime - now
			self.logger.info('Waiting %.3g seconds to longpoll' % (delay,))
			self.schedule(self._actualLP, self._LPWaitTime)
			self.LPRequest = 2
		else:
			self._actualLP()
	
	def _actualLP(self):
		"""Wake every client waiting in _LPClients.

		Each client's errors are handled individually.  This method can be
		called directly from serve_forever() outside any try block, and the
		clients have already been removed from _LPClients.  One failing
		client must therefore neither stop the event loop nor leave the
		remaining clients un-woken.
		"""
		self.LPRequest = False
		C = tuple(self._LPClients.values())
		self._LPClients = {}
		if not C:
			self.logger.info('Nobody to longpoll')
			return
		OC = len(C)
		self.logger.debug("%d clients to wake up..." % (OC,))
		
		now = time()
		
		for ic in C:
			# Same error dispatch as serve_forever() uses for events and tasks
			try:
				ic.wakeLongpoll()
			except socket.error:
				tryErr(ic.handle_error)
			except Exception:
				self.logger.error(traceback.format_exc())
				tryErr(ic.handle_close)
		
		self._LPWaitTime = time()
		self.logger.info('Longpoll woke up %d clients in %.3f seconds' % (OC, self._LPWaitTime - now))
		self._LPWaitTime += 5  # TODO: make configurable: minimum time between longpolls
	
	def TopLPers(self, n = 0x10):
		"""Print (ip, count) for the n IPs with the most outstanding longpolls, ascending."""
		tracking = self.LPTracking
		for ip in sorted(tracking, key=lambda k: tracking[k])[-n:]:
			print((ip, tracking[ip]))