Bugfix: Check fd still has a handler object, in case it has been destroyed in the...
[bitcoin:eloipool.git] / jsonrpcserver.py
1 import asynchat
2 from base64 import b64decode
3 from binascii import a2b_hex, b2a_hex
4 from datetime import datetime
5 from email.utils import formatdate
6 import json
7 import logging
# Optional accelerated SHA256 midstate module.  It is only used when it can be
# imported AND reproduces a known digest; otherwise midstate support is disabled.
try:
        import midstate
        # Explicit comparison rather than `assert`, so the self-test still runs under -O
        if midstate.SHA256(b'This is just a test, ignore it. I am making it over 64-bytes long.')[:8] != (0x755f1a94, 0x999b270c, 0xf358c014, 0xfd39caeb, 0x0dcc9ebc, 0x4694cd1a, 0x8e95678e, 0x75fac450):
                raise ValueError('midstate self-test failed')
except Exception:
        # Missing or broken module: fall back to not sending midstate at all.
        # (KeyboardInterrupt/SystemExit are deliberately not swallowed.)
        midstate = None
13 import os
14 import re
15 import select
16 import socket
17 from struct import pack
18 import threading
19 from time import mktime, time, sleep
20 import traceback
21 from util import RejectedShare, swap32, tryErr
22
class WithinLongpoll(BaseException):
        """Raised by a handler once a request has been parked as a longpoll.

        Derives from BaseException (not Exception) so that generic
        exception handlers on the request path do not intercept it.
        """
25
# epoll event masks: "readable" also covers priority data, errors and hang-ups,
# so all of those are routed to a handler's handle_read().
EPOLL_READ = (
        select.EPOLLIN
        | select.EPOLLPRI
        | select.EPOLLERR
        | select.EPOLLHUP
)
EPOLL_WRITE = select.EPOLLOUT

# TODO: keepalive/close
# Every work header handed out (minus its bytes 68..72) is remembered here so
# that issuing the same work twice can be detected.
_CheckForDupesHACK = dict()
class JSONRPCHandler:
        """One HTTP/JSON-RPC connection, driven by the owning server's epoll loop.

        The server calls handle_read()/handle_write() when the socket is ready.
        Incoming bytes are split asynchat-style on a terminator: first the blank
        line ending the headers, then (if Content-Length was given) that many
        body bytes.  Replies are appended to self.wbuf by push() and flushed by
        handle_write().

        NOTE(review): push() may be called from the longpoll thread while
        handle_write() runs in the epoll thread; self.wbuf is not protected by
        a lock -- confirm whether that is acceptable.
        """

        HTTPStatus = {
                200: 'OK',
                401: 'Unauthorized',
                404: 'Not Found',
                405: 'Method Not Allowed',
                500: 'Internal Server Error',
        }

        logger = logging.getLogger('JSONRPCHandler')

        ac_in_buffer_size = 4096
        ac_out_buffer_size = 4096

        # Read by the set_terminator() borrowed from asynchat when given a str;
        # this class only ever deals in bytes.
        use_encoding = 0
        encoding = 'latin-1'

        # True while this connection is counted in server.LPTracking
        _LPTracked = False

        def sendReply(self, status=200, body=b'', headers=None):
                """Queue an HTTP response.

                status  -- numeric HTTP status; unknown codes get the reason 'Eligius'
                body    -- bytes payload, or None to start a chunked response
                           (the caller then pushes the chunks itself)
                headers -- optional mapping of extra headers; a value of None
                           suppresses that header
                """
                headers = dict(headers) if headers else {}
                # formatdate() defaults to the current time; usegmt gives the HTTP date format
                headers['Date'] = formatdate(localtime=False, usegmt=True)
                if body is None:
                        headers.setdefault('Transfer-Encoding', 'chunked')
                        body = b''
                else:
                        headers['Content-Length'] = len(body)
                if status == 200:
                        headers.setdefault('Content-Type', 'application/json')
                        headers.setdefault('X-Long-Polling', '/LP')
                        headers.setdefault('X-Roll-NTime', 'expire=120')
                lines = ["HTTP/1.1 %d %s" % (status, self.HTTPStatus.get(status, 'Eligius'))]
                lines.extend("%s: %s" % (k, v) for k, v in headers.items() if v is not None)
                head = "\r\n".join(lines) + "\r\n\r\n"
                self.push(head.encode('utf8') + body)

        def doError(self, reason = ''):
                """Queue a 500 response whose body is `reason`."""
                return self.sendReply(500, reason.encode('utf8'))

        def doHeader_authorization(self, value):
                """Extract the user name from a Basic Authorization header.

                Anything malformed (wrong scheme, bad base64, undecodable name)
                leaves self.Username as None, so handle_request() answers with a
                single 401 instead of raising or queuing a second response.
                """
                self.Username = None
                parts = value.split(None, 1)
                # The auth scheme name is matched case-insensitively
                if len(parts) != 2 or parts[0].lower() != b'basic':
                        return
                try:
                        credentials = b64decode(parts[1].strip())
                        self.Username = credentials.split(b':')[0].decode('utf8')
                except ValueError:
                        # binascii.Error and UnicodeDecodeError are both ValueErrors
                        self.Username = None

        def doHeader_content_length(self, value):
                """Record the request body length; raises ValueError if not a non-negative integer."""
                length = int(value)
                if length < 0:
                        raise ValueError('negative Content-Length')
                self.CL = length

        def doHeader_x_minimum_wait(self, value):
                """Record the client's requested minimum longpoll wait (seconds)."""
                self.reqinfo['MinWait'] = int(value)

        def doHeader_x_mining_extensions(self, value):
                """Record the space-separated, lower-cased list of extensions the client announced."""
                self.extensions = value.decode('ascii').lower().split(' ')

        def doAuthenticate(self):
                """Queue a 401 challenge for HTTP Basic authentication."""
                self.sendReply(401, headers={'WWW-Authenticate': 'Basic realm="Eligius"'})

        def doLongpoll(self):
                """Park this request as a longpoll.

                Starts a chunked 200 response, sends the opening '{' of the JSON
                reply, registers with the server's longpoll clients and schedules
                keepalive chunks.  Always raises WithinLongpoll so the caller
                does not reset the request state.
                """
                self.sendReply(200, body=None)
                self.push(b"1\r\n{\r\n")
                waitTime = self.reqinfo.get('MinWait', 15)  # TODO: make default configurable
                timeNow = time()
                self.waitTime = waitTime + timeNow

                totfromme = self.LPTrack()
                with self.server._LPLock:
                        self.server._LPClients[id(self)] = self
                        self.server.schedule(self._chunkedKA, timeNow + 45)
                        self.logger.debug("New LP client; %d total; %d from %s" % (len(self.server._LPClients), totfromme, self.addr[0]))

                raise WithinLongpoll

        def _chunkedKA(self):
                """Send a one-space chunk as keepalive and reschedule itself 45 seconds later."""
                # Keepalive via chunked transfer encoding
                self.push(b"1\r\n \r\n")
                self.server.schedule(self._chunkedKA, time() + 45)

        def LPTrack(self):
                """Count this connection in the per-IP longpoll tally; returns the new tally."""
                tracking = self.server.LPTracking
                myip = self.addr[0]
                tracking[myip] = tracking.get(myip, 0) + 1
                self._LPTracked = True
                return tracking[myip]

        def LPUntrack(self):
                """Remove this connection from the per-IP longpoll tally.

                Safe to call repeatedly: only a connection that is currently
                counted is decremented, so the tally is never decremented twice
                for one longpoll.
                """
                if not self._LPTracked:
                        return
                self._LPTracked = False
                self.server.LPTracking[self.addr[0]] -= 1

        def cleanupLP(self):
                """Drop all longpoll state for this connection (called when it closes)."""
                with self.server._LPLock:
                        self.server._LPClients.pop(id(self), None)
                        # Also covers a client that was already taken out of
                        # _LPClients but whose wake-up is still only scheduled.
                        self.LPUntrack()
                self.server.rmSchedule(self._chunkedKA, self.wakeLongpoll)

        def wakeLongpoll(self):
                """Finish a parked longpoll by sending fresh work.

                If the client's minimum wait has not elapsed yet, the wake-up is
                rescheduled for that time instead.
                """
                now = time()
                if now < self.waitTime:
                        self.server.schedule(self.wakeLongpoll, self.waitTime)
                        return

                self.LPUntrack()

                self.server.rmSchedule(self._chunkedKA)

                rv = self.doJSON_getwork()
                rv = {'id': 1, 'error': None, 'result': rv}
                rv = json.dumps(rv)
                rv = rv.encode('utf8')
                rv = rv[1:]  # strip the '{' we already sent
                # Final data chunk followed by the zero-length terminating chunk
                self.push(('%x' % len(rv)).encode('utf8') + b"\r\n" + rv + b"\r\n0\r\n\r\n")

                self.reset_request()

        def doJSON(self, data):
                """Decode a JSON-RPC request body and dispatch it to doJSON_<method>.

                A handler returning None is taken to have replied on its own;
                any other value is wrapped in a JSON-RPC result and sent as 200.
                """
                # TODO: handle JSON errors
                request = json.loads(data.decode('utf8'))
                method = 'doJSON_' + str(request['method']).lower()
                if not hasattr(self, method):
                        return self.doError('No such method')
                # TODO: handle errors as JSON-RPC
                self._JSONHeaders = {}
                rv = getattr(self, method)(*tuple(request.get('params', ())))
                if rv is None:
                        return
                reply = json.dumps({'id': request['id'], 'error': None, 'result': rv})
                return self.sendReply(200, reply.encode('utf8'), headers=self._JSONHeaders)

        getwork_rv_template = {
                'data': '000000800000000000000000000000000000000000000000000000000000000000000000000000000000000080020000',
                'target': 'ffffffffffffffffffffffffffffffffffffffffffffffffffffffff00000000',
                'hash1': '00000000000000000000000000000000000000000000000000000000000000000000008000000000000000000000000000000000000000000000000000010000',
        }
        def doJSON_getwork(self, data=None):
                """JSON-RPC getwork: with `data` submit a share, without it issue new work.

                Returns the share verdict (bool) or a dict based on
                getwork_rv_template.  Raises whatever server.RaiseRedFlags()
                returns if the same header would be issued twice.
                """
                if data is not None:
                        return self.doJSON_submitwork(data)
                rv = dict(self.getwork_rv_template)
                hdr = self.server.getBlockHeader(self.Username)

                # FIXME: this assumption breaks with internal rollntime
                # NOTE: noncerange needs to set nonce to start value at least
                # Bytes 68..72 of the header are left out of the duplicate check
                uhdr = hdr[:68] + hdr[72:]
                if uhdr in _CheckForDupesHACK:
                        raise self.server.RaiseRedFlags(RuntimeError('issuing duplicate work'))
                _CheckForDupesHACK[uhdr] = None

                # TODO: endian shuffle etc
                rv['data'] = b2a_hex(swap32(hdr)).decode('utf8') + rv['data']
                if midstate and 'midstate' not in self.extensions:
                        h = midstate.SHA256(hdr)[:8]
                        rv['midstate'] = b2a_hex(pack('<LLLLLLLL', *h)).decode('ascii')
                return rv

        def doJSON_submitwork(self, datax):
                """Hand a submitted share (hex string) to the server.

                Returns True if accepted, False if the server raised
                RejectedShare; the reason is then sent as X-Reject-Reason.
                """
                data = swap32(a2b_hex(datax))[:80]
                share = {
                        'data': data,
                        '_origdata' : datax,
                        'username': self.Username,
                        'remoteHost': self.addr[0],
                }
                try:
                        self.server.receiveShare(share)
                except RejectedShare as rej:
                        self._JSONHeaders['X-Reject-Reason'] = str(rej)
                        return False
                return True

        def doJSON_setworkaux(self, k, hexv = None):
                """Set (hexv given) or delete (hexv empty) server.aux[k]; only for server.SecretUser."""
                if self.Username != self.server.SecretUser:
                        self.doAuthenticate()
                        return None
                if hexv:
                        self.server.aux[k] = a2b_hex(hexv)
                else:
                        del self.server.aux[k]
                return True

        def handle_close(self):
                """Release longpoll state, then close the connection."""
                self.cleanupLP()
                self.close()

        def handle_request(self):
                """Validate and dispatch one fully received request.

                Raises WithinLongpoll when the request was parked as a longpoll;
                socket errors propagate; any other Exception is logged and
                answered with a 500.
                """
                if not self.Username:
                        return self.doAuthenticate()
                if self.method not in (b'GET', b'POST'):
                        return self.sendReply(405)
                if self.path not in (b'/', b'/LP', b'/LP/'):
                        return self.sendReply(404)
                try:
                        if self.path[:3] == b'/LP':
                                return self.doLongpoll()
                        data = b''.join(self.incoming)
                        return self.doJSON(data)
                except socket.error:
                        raise
                except WithinLongpoll:
                        raise
                except Exception:
                        # Not a bare except: KeyboardInterrupt/SystemExit must not become a 500
                        self.logger.error(traceback.format_exc())
                        return self.doError('uncaught error')

        def parse_headers(self, hs):
                """Parse the request line and headers contained in `hs` (bytes).

                Resets all per-request state, then calls doHeader_<name> for each
                header that has such a method; dashes in the header name may be
                spelled as underscores in the method name.  Lines without a
                colon and headers whose handler raises ValueError are ignored.

                Returns True on success, or False after closing the connection
                because the request line was unusable.
                """
                # Reset first, so no later stage ever sees missing or stale attributes
                self.method = None
                self.path = None
                self.CL = None
                self.extensions = []
                self.Username = None
                self.reqinfo = {}

                hs = re.split(br'\r?\n', hs)
                reqline = hs.pop(0).split(b' ')
                if len(reqline) < 2:
                        self.close()
                        return False
                self.method = reqline[0]
                self.path = reqline[1]

                for line in hs:
                        key, colon, value = line.partition(b':')
                        if not colon:
                                continue
                        try:
                                key = key.strip().decode('ascii').lower()
                        except UnicodeDecodeError:
                                continue
                        name = 'doHeader_' + key
                        handler = getattr(self, name, None) or getattr(self, name.replace('-', '_'), None)
                        if handler is None:
                                continue
                        try:
                                handler(value.strip())
                        except ValueError:
                                self.logger.debug("Ignoring malformed %s header", key)
                return True

        def found_terminator(self):
                """React to the end of the headers or the end of the body."""
                if self.reading_headers:
                        self.reading_headers = False
                        raw = b"".join(self.incoming)
                        self.incoming = []
                        if self.parse_headers(raw) is False:
                                # Connection was closed; make handle_read() stop
                                # feeding us whatever else is buffered.
                                self.ac_in_buffer = b''
                                return
                        if self.CL:
                                # Body follows: read exactly Content-Length bytes next
                                self.set_terminator(self.CL)
                                return

                self.set_terminator(None)
                try:
                        self.handle_request()
                        self.reset_request()
                except WithinLongpoll:
                        # Request stays open; wakeLongpoll() resets the state later
                        pass

        def handle_error(self):
                """Log the current exception and close the connection."""
                self.logger.error(traceback.format_exc())
                self.handle_close()

        get_terminator = asynchat.async_chat.get_terminator
        set_terminator = asynchat.async_chat.set_terminator

        def handle_read (self):
                """Read available bytes and split them on the current terminator.

                Variant of asynchat's handle_read that also accepts a tuple of
                alternative terminators.
                """
                try:
                        data = self.recv (self.ac_in_buffer_size)
                except socket.error as why:
                        self.handle_error()
                        return

                if isinstance(data, str) and self.use_encoding:
                        data = bytes(data, self.encoding)
                self.ac_in_buffer = self.ac_in_buffer + data

                # Continue to search for self.terminator in self.ac_in_buffer,
                # while calling self.collect_incoming_data.  The while loop
                # is necessary because we might read several data+terminator
                # combos with a single recv(4096).

                while self.ac_in_buffer:
                        lb = len(self.ac_in_buffer)
                        terminator = self.get_terminator()
                        if not terminator:
                                # no terminator, collect it all
                                self.collect_incoming_data (self.ac_in_buffer)
                                self.ac_in_buffer = b''
                        elif isinstance(terminator, int):
                                # numeric terminator
                                n = terminator
                                if lb < n:
                                        self.collect_incoming_data (self.ac_in_buffer)
                                        self.ac_in_buffer = b''
                                        self.terminator = self.terminator - lb
                                else:
                                        self.collect_incoming_data (self.ac_in_buffer[:n])
                                        self.ac_in_buffer = self.ac_in_buffer[n:]
                                        self.terminator = 0
                                        self.found_terminator()
                        else:
                                # 3 cases:
                                # 1) end of buffer matches terminator exactly:
                                #    collect data, transition
                                # 2) end of buffer matches some prefix:
                                #    collect data to the prefix
                                # 3) end of buffer does not match any prefix:
                                #    collect data
                                # NOTE: this supports multiple different terminators, but
                                #       NOT ones that are prefixes of others...
                                if isinstance(self.ac_in_buffer, type(terminator)):
                                        terminator = (terminator,)
                                termidx = tuple(map(self.ac_in_buffer.find, terminator))
                                try:
                                        index = min(x for x in termidx if x >= 0)
                                except ValueError:
                                        index = -1
                                if index != -1:
                                        # we found the terminator
                                        if index > 0:
                                                # don't bother reporting the empty string (source of subtle bugs)
                                                self.collect_incoming_data (self.ac_in_buffer[:index])
                                        specific_terminator = terminator[termidx.index(index)]
                                        terminator_len = len(specific_terminator)
                                        self.ac_in_buffer = self.ac_in_buffer[index+terminator_len:]
                                        # This does the Right Thing if the terminator is changed here.
                                        self.found_terminator()
                                else:
                                        # check for a prefix of the terminator
                                        termidx = tuple(map(lambda a: asynchat.find_prefix_at_end (self.ac_in_buffer, a), terminator))
                                        # Keep the LONGEST partial match: with the shortest one, a
                                        # terminator split across two reads (e.g. b"\r\n\r" + b"\n")
                                        # would be collected as data and never recognised.
                                        index = max(termidx)
                                        if index:
                                                if index != lb:
                                                        # we found a prefix, collect up to the prefix
                                                        self.collect_incoming_data (self.ac_in_buffer[:-index])
                                                        self.ac_in_buffer = self.ac_in_buffer[-index:]
                                                break
                                        else:
                                                # no prefix, collect it all
                                                self.collect_incoming_data (self.ac_in_buffer)
                                                self.ac_in_buffer = b''

        def reset_request(self):
                """Prepare for the next request: expect headers ended by a blank line."""
                self.incoming = []
                self.set_terminator( (b"\n\n", b"\r\n\r\n") )
                self.reading_headers = True

        def collect_incoming_data(self, data):
                """Append a received fragment to self.incoming."""
                asynchat.async_chat._collect_incoming_data(self, data)

        def push(self, data):
                """Queue bytes for sending and ask the server to watch for writability."""
                self.wbuf += data
                self.server.register_socket_m(self.socket, EPOLL_READ | EPOLL_WRITE)

        def handle_write(self):
                """Send as much of the queued output as the socket accepts.

                Once the queue is empty the socket is watched for reads only.
                A socket error is handled like in handle_read(): logged, and the
                connection is closed.
                """
                try:
                        sent = self.socket.send(self.wbuf)
                except socket.error:
                        self.handle_error()
                        return
                self.wbuf = self.wbuf[sent:]
                if not self.wbuf:
                        self.server.register_socket_m(self.socket, EPOLL_READ)

        recv = asynchat.async_chat.recv

        def close(self):
                """Unregister the socket from the server and close it."""
                self.server.unregister_socket(self.socket)
                self.socket.close()

        def __init__(self, server, sock, addr):
                """Attach to an accepted connection and register it with `server`.

                server -- owning server (provides socket registration, scheduling, work)
                sock   -- the connected socket
                addr   -- peer address; addr[0] is used as the client's IP
                """
                self.ac_in_buffer = b''
                self.wbuf = b''
                self.server = server
                self.socket = sock
                self.addr = addr
                self.reset_request()
                server.register_socket(sock, self)

395         
# Header names contain dashes, which cannot appear in a method name written
# with `def`, so the underscore-named handlers are additionally published
# under the literal (lower-cased) header names that parse_headers looks up.
for _hdr in ('content-length', 'x-minimum-wait', 'x-mining-extensions'):
        setattr(JSONRPCHandler, 'doHeader_' + _hdr, getattr(JSONRPCHandler, 'doHeader_' + _hdr.replace('-', '_')))
del _hdr
399
class JSONRPCServer:
        """epoll-based server: accepts connections, runs timed tasks, wakes longpolls.

        The listening socket and every connection are registered in one epoll
        object; self._fd maps file descriptors to the object whose
        handle_read()/handle_write() is called when the descriptor is ready.
        self._sch is a list of (startTime, task) kept sorted by startTime.
        """

        def __init__(self, server_address, RequestHandlerClass=JSONRPCHandler):
                """Create the server and, if server_address is truthy, start listening on it.

                RequestHandlerClass is instantiated as
                RequestHandlerClass(server, sock, addr) for each accepted connection.
                """
                self.logger = logging.getLogger('JSONRPCServer')

                self.RequestHandlerClass = RequestHandlerClass

                self.SecretUser = None

                self._epoll = select.epoll()
                self._fd = {}

                self._schLock = threading.RLock()
                self._sch = []

                self._LPClients = {}
                self._LPLock = threading.RLock()
                self._LPWaitTime = time() + 15
                self._LPILock = threading.Lock()
                self._LPI = False
                self._LPWLock = threading.Lock()

                self.LPTracking = {}

                if server_address:
                        self.setup_socket(server_address)

        def setup_socket(self, server_address):
                """Open a non-blocking IPv6 listening socket on server_address and register it."""
                sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
                sock.setblocking(0)
                try:
                        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                except socket.error:
                        # Best effort only; binding may still succeed without it
                        pass
                sock.bind(server_address)
                sock.listen(100)
                self.register_socket(sock, self)
                self.socket = sock

        def register_socket(self, sock, o, eventmask = EPOLL_READ):
                """Watch `sock` for `eventmask` and dispatch its events to object `o`."""
                fd = sock.fileno()
                self._epoll.register(fd, eventmask)
                self._fd[fd] = o

        def register_socket_m(self, sock, eventmask):
                """Change the event mask of an already registered socket."""
                fd = sock.fileno()
                self._epoll.modify(fd, eventmask)

        def unregister_socket(self, sock):
                """Stop watching `sock` and forget its handler object."""
                fd = sock.fileno()
                self._epoll.unregister(fd)
                del self._fd[fd]

        def handle_read(self):
                """Accept one pending connection and wrap it in a handler.

                The handler registers itself with this server in its constructor.
                """
                conn, addr = self.socket.accept()
                self.RequestHandlerClass(self, conn, addr)

        def schedule(self, task, startTime):
                """Queue callable `task` to run at time `startTime` (seconds since the epoch).

                The task is inserted after all entries that are due at the same
                time or earlier, so equal times run in insertion order.
                """
                with self._schLock:
                        IL = len(self._sch)
                        while IL and startTime < self._sch[IL-1][0]:
                                IL -= 1
                        self._sch.insert(IL, (startTime, task))

        def rmSchedule(self, *tasks):
                """Remove the earliest scheduled occurrence of each given task, if any."""
                tasks = list(tasks)
                with self._schLock:
                        i = 0
                        SL = len(self._sch)
                        while i < SL:
                                if self._sch[i][1] in tasks:
                                        tasks.remove(self._sch[i][1])
                                        del self._sch[i]
                                        if not tasks:
                                                break
                                        SL -= 1
                                else:
                                        i += 1

        def _runDueTasks(self):
                """Run all scheduled tasks that are due.

                Returns the number of seconds until the next task, or -1 (poll
                without timeout) if nothing is scheduled.
                """
                with self._schLock:
                        timeNow = time()
                        while self._sch:
                                timeNext = self._sch[0][0]
                                if timeNow < timeNext:
                                        return timeNext - timeNow
                                f = self._sch.pop(0)[1]
                                try:
                                        f()
                                except Exception:
                                        # One failing task must not take the whole server loop down
                                        self.logger.error(traceback.format_exc())
                        return -1

        def serve_forever(self):
                """Main loop: run due tasks, then dispatch socket events; never returns normally."""
                while True:
                        timeout = self._runDueTasks()
                        try:
                                events = self._epoll.poll(timeout=timeout)
                        except select.error:
                                continue
                        for (fd, e) in events:
                                o = self._fd.get(fd)
                                if o is None: continue
                                if e & EPOLL_READ:
                                        tryErr(self.logger, o.handle_read)
                                # handle_read may have closed the connection (or the fd may
                                # already belong to someone else); only write if it is still ours
                                if e & EPOLL_WRITE and self._fd.get(fd) is o:
                                        tryErr(self.logger, o.handle_write)

        def wakeLongpoll(self):
                """Request a longpoll wake-up; ignored if one is already pending.

                The actual work happens in a daemon thread running _LPthread().
                """
                self.logger.debug("(LPILock)")
                with self._LPILock:
                        if self._LPI:
                                self.logger.info('Ignoring longpoll attempt while another is waiting')
                                return
                        self._LPI = True

                th = threading.Thread(target=self._LPthread)
                th.daemon = True
                th.start()

        def _LPthread(self):
                """Thread body: wait until the minimum longpoll interval has passed, then wake clients."""
                self.logger.debug("(LPWLock)")
                with self._LPWLock:
                        now = time()
                        if self._LPWaitTime > now:
                                delay = self._LPWaitTime - now
                                self.logger.info('Waiting %.3g seconds to longpoll' % (delay,))
                                sleep(delay)

                        # From here on a new wakeLongpoll() request is accepted again
                        self._LPI = False

                        self._actualLP()

        def _actualLP(self):
                """Wake every parked longpoll client and set the earliest time for the next round."""
                self.logger.debug("(LPLock)")
                with self._LPLock:
                        C = tuple(self._LPClients.values())
                        self._LPClients = {}
                        if not C:
                                self.logger.info('Nobody to longpoll')
                                return
                        OC = len(C)
                        self.logger.debug("%d clients to wake up..." % (OC,))

                        now = time()

                        for ic in C:
                                try:
                                        ic.wakeLongpoll()
                                except Exception:
                                        # The clients were already removed from _LPClients, so a
                                        # failure on one of them must not strand all the others
                                        self.logger.error(traceback.format_exc())

                        self._LPWaitTime = time()
                        self.logger.info('Longpoll woke up %d clients in %.3f seconds' % (OC, self._LPWaitTime - now))
                        self._LPWaitTime += 5  # TODO: make configurable: minimum time between longpolls