Bugfix: Refactor sendReply (and derivatives) to use a RequestHandled exception when...
[bitcoin:eloipool.git] / jsonrpcserver.py
1 # Eloipool - Python Bitcoin pool server
2 # Copyright (C) 2011-2012  Luke Dashjr <luke-jr+eloipool@utopios.org>
3 #
4 # This program is free software: you can redistribute it and/or modify
5 # it under the terms of the GNU Affero General Public License as
6 # published by the Free Software Foundation, either version 3 of the
7 # License, or (at your option) any later version.
8 #
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 # GNU Affero General Public License for more details.
13 #
14 # You should have received a copy of the GNU Affero General Public License
15 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
16
17 import asynchat
18 from base64 import b64decode
19 from binascii import a2b_hex, b2a_hex
20 from copy import deepcopy
21 from datetime import datetime
22 from email.utils import formatdate
23 import json
24 import logging
25 try:
26         import midstate
27         assert midstate.SHA256(b'This is just a test, ignore it. I am making it over 64-bytes long.')[:8] == (0x755f1a94, 0x999b270c, 0xf358c014, 0xfd39caeb, 0x0dcc9ebc, 0x4694cd1a, 0x8e95678e, 0x75fac450)
28 except:
29         logging.getLogger('jsonrpcserver').warning('Error importing \'midstate\' module; work will not provide midstates')
30         midstate = None
31 import networkserver
32 import os
33 import re
34 import socket
35 from struct import pack
36 import threading
37 from time import mktime, time, sleep
38 import traceback
39 from util import RejectedShare, swap32
40
class WithinLongpoll(BaseException):
    """Raised by doLongpoll once the request is registered as a longpoll.

    Signals that the connection must stay open without a reply for now.
    """


class RequestAlreadyHandled(BaseException):
    """Raised by sendReply when a reply was already sent for this request."""


class RequestHandled(RequestAlreadyHandled):
    """Raised by sendReply after a complete reply has been queued."""


class RequestNotHandled(BaseException):
    """Raised when request handling returned without sending any reply."""
52
# TODO: keepalive/close

# Every block header handed out by JSONRPCHandler.doJSON_getwork, with
# bytes 68:72 cut out (presumably nTime -- see the FIXME there), used to
# detect duplicate work.  Entries are never removed.
_CheckForDupesHACK = dict()
class JSONRPCHandler(networkserver.SocketHandler):
    """Per-connection HTTP handler for the mining JSON-RPC interface.

    Bytes read from the socket are split by ``handle_readbuf`` into a header
    block and (when Content-Length is given) a body.  Each header with a
    matching ``doHeader_<lowercased name>`` method is passed to it; the body
    is decoded as a JSON-RPC call and dispatched to ``doJSON_<method>``.
    Requests for ``/LP`` are parked as longpolls until the server calls
    ``wakeLongpoll``.

    Flow control is exception based: ``sendReply`` raises ``RequestHandled``
    once a complete reply is queued (or ``RequestAlreadyHandled`` if one was
    already sent), and ``doLongpoll`` raises ``WithinLongpoll``.
    """

    # Reason phrases for the HTTP status line; other codes get 'Eligius'.
    HTTPStatus = {
        200: 'OK',
        401: 'Unauthorized',
        404: 'Not Found',
        405: 'Method Not Allowed',
        500: 'Internal Server Error',
    }

    # Extra headers for longpoll replies.  The None value makes sendReply
    # omit the X-Long-Polling header it would otherwise add to 200 replies.
    LPHeaders = {
        'X-Long-Polling': None,
    }

    logger = logging.getLogger('JSONRPCHandler')

    def sendReply(self, status=200, body=b'', headers=None):
        """Queue an HTTP/1.1 response on the connection.

        :param status: HTTP status code.
        :param body: response body bytes, or None to send only the headers
            with ``Transfer-Encoding: chunked``; the caller then pushes the
            chunks itself.
        :param headers: optional mapping of extra headers; entries whose
            value is None are not sent.
        :raises RequestAlreadyHandled: if a reply was already sent for the
            current request.
        :raises RequestHandled: after a complete (non-chunked) reply has
            been queued.  Only the chunked case returns normally.
        """
        if self.replySent:
            raise RequestAlreadyHandled
        buf = "HTTP/1.1 %d %s\r\n" % (status, self.HTTPStatus.get(status, 'Eligius'))
        headers = dict(headers) if headers else {}
        # Current time as an RFC 1123 GMT date; time() avoids the DST
        # ambiguity of round-tripping a naive local datetime via mktime.
        headers['Date'] = formatdate(timeval=time(), localtime=False, usegmt=True)
        headers.setdefault('Server', 'Eloipool')
        if body is None:
            headers.setdefault('Transfer-Encoding', 'chunked')
        else:
            headers['Content-Length'] = len(body)
        if status == 200:
            headers.setdefault('Content-Type', 'application/json')
            headers.setdefault('X-Long-Polling', '/LP')
            headers.setdefault('X-Roll-NTime', 'expire=120')
        elif body and body[0] == 123:  # b'{'
            headers.setdefault('Content-Type', 'application/json')
        for k, v in headers.items():
            if v is None:
                continue
            buf += "%s: %s\r\n" % (k, v)
        buf += "\r\n"
        buf = buf.encode('utf8')
        self.replySent = True
        if body is None:
            self.push(buf)
            return
        self.push(buf + body)
        raise RequestHandled

    def doError(self, reason='', code=100):
        """Reply with HTTP 500 carrying a JSON-RPC error object.

        :param reason: error message, JSON-encoded into the reply.
        :param code: JSON-RPC error code.
        :raises RequestHandled: via sendReply (or RequestAlreadyHandled if a
            reply was already sent).
        """
        reason = json.dumps(reason)
        reason = r'{"result":null,"id":null,"error":{"name":"JSONRPCError","code":%d,"message":%s}}' % (code, reason)
        return self.sendReply(500, reason.encode('utf8'))

    def doHeader_authorization(self, value):
        """Take the username from an HTTP Basic ``Authorization`` header.

        Only the part before the first ':' of the decoded credentials is
        kept, as ``self.Username``.  A malformed header gets an error reply.
        """
        value = value.split(b' ')
        if len(value) != 2 or value[0] != b'Basic':
            return self.doError('Bad Authorization header')
        try:
            credentials = b64decode(value[1])
            self.Username = credentials.split(b':')[0].decode('utf8')
        except ValueError:
            # binascii.Error (bad base64) and UnicodeDecodeError are both
            # ValueError subclasses
            return self.doError('Bad Authorization header')

    def doHeader_content_length(self, value):
        """Record the request body length."""
        self.CL = int(value)

    def doHeader_user_agent(self, value):
        """Record the User-Agent and enable client-specific quirks.

        Phoenix clients reporting a version below 1.7.1 get the 'NELH'
        quirk, which disables early longpoll headers for them.
        """
        self.reqinfo['UA'] = value
        if value[:9] != b'phoenix/v':
            return
        try:
            version = tuple(map(int, value[9:].split(b'.')))
        except ValueError:
            # Unparseable version: apply no quirks
            return
        if version < (1, 7, 1):
            self.quirks['NELH'] = None

    def doHeader_x_minimum_wait(self, value):
        """Record the client's requested minimum longpoll wait (seconds)."""
        self.reqinfo['MinWait'] = int(value)

    def doHeader_x_mining_extensions(self, value):
        """Record the space-separated, lowercased list of client extensions."""
        self.extensions = value.decode('ascii').lower().split(' ')

    def doAuthenticate(self):
        """Reply 401 asking for HTTP Basic credentials (raises RequestHandled)."""
        self.sendReply(401, headers={'WWW-Authenticate': 'Basic realm="Eligius"'})

    def doLongpoll(self):
        """Park the current request as a longpoll until the server wakes it.

        Unless the client has the NELH quirk, the 200 headers and the
        opening '{' of the body are sent right away as chunked data, with a
        one-space keepalive chunk every 45 seconds.  The earliest wake time
        is now + X-Minimum-Wait (default 15 seconds).

        :raises WithinLongpoll: always, so the connection stays open.
        """
        timeNow = time()

        self._LP = True
        if 'NELH' not in self.quirks:
            # [NOT No] Early Longpoll Headers
            self.sendReply(200, body=None, headers=self.LPHeaders)
            self.push(b"1\r\n{\r\n")
            self.changeTask(self._chunkedKA, timeNow + 45)
        else:
            self.changeTask(None)

        waitTime = self.reqinfo.get('MinWait', 15)  # TODO: make default configurable
        self.waitTime = waitTime + timeNow

        totfromme = self.LPTrack()
        self.server._LPClients[id(self)] = self
        self.logger.debug("New LP client; %d total; %d from %s" % (len(self.server._LPClients), totfromme, self.addr[0]))

        raise WithinLongpoll

    def _chunkedKA(self):
        """Send a one-space chunk as keepalive and reschedule in 45 seconds."""
        # Keepalive via chunked transfer encoding
        self.push(b"1\r\n \r\n")
        self.changeTask(self._chunkedKA, time() + 45)

    def LPTrack(self):
        """Count one more longpoll from this client's IP; return the new count."""
        tracking = self.server.LPTracking
        myip = self.addr[0]
        tracking[myip] = tracking.get(myip, 0) + 1
        return tracking[myip]

    def LPUntrack(self):
        """Undo one LPTrack for this client's IP."""
        self.server.LPTracking[self.addr[0]] -= 1

    def cleanupLP(self):
        """Unregister a still-pending longpoll; called when the connection closes."""
        if not self._LP:
            return
        self.changeTask(None)
        try:
            del self.server._LPClients[id(self)]
        except KeyError:
            # Already taken by the server's longpoll broadcast
            pass
        self.LPUntrack()

    def wakeLongpoll(self):
        """Answer the parked longpoll with fresh work.

        Reschedules itself if the client's minimum wait has not elapsed.
        Otherwise the work is sent as the rest of the chunked body (or, for
        NELH clients, as a normal reply), and the connection is reset for
        the next request.
        """
        now = time()
        if now < self.waitTime:
            self.changeTask(self.wakeLongpoll, self.waitTime)
            return
        self.changeTask(None)

        rv = self.doJSON_getwork()
        # Untrack only once work exists: if getwork raised, _LP is still set
        # and cleanupLP untracks on close, so the counter drops exactly once.
        self.LPUntrack()

        rv['submitold'] = True
        rv = {'id': 1, 'error': None, 'result': rv}
        rv = json.dumps(rv).encode('utf8')
        if 'NELH' not in self.quirks:
            rv = rv[1:]  # strip the '{' we already sent
            self.push(('%x' % len(rv)).encode('utf8') + b"\r\n" + rv + b"\r\n0\r\n\r\n")
            self.reset_request()
            return

        try:
            self.sendReply(200, body=rv, headers=self.LPHeaders)
            raise RequestNotHandled
        except RequestHandled:
            # Expected
            pass
        finally:
            self.reset_request()

    def doJSON(self, data):
        """Decode a JSON-RPC request body, dispatch it and send the reply.

        Calls ``doJSON_<method>`` with the positional ``params``.  Failures
        are reported through doError.  The result is sent with any headers
        the method stored in ``self._JSONHeaders``.

        :param data: raw request body bytes.
        :raises RequestHandled: once the reply has been sent.
        """
        # TODO: handle JSON errors
        try:
            data = json.loads(data.decode('utf8'))
            method = 'doJSON_' + str(data['method']).lower()
        except ValueError:
            # JSONDecodeError and UnicodeDecodeError are ValueError subclasses
            return self.doError(r'Parse error')
        except (TypeError, KeyError):
            # Not a JSON object, or no "method" member
            return self.doError(r'Bad call')
        if not hasattr(self, method):
            return self.doError(r'Procedure not found')
        # TODO: handle errors as JSON-RPC
        self._JSONHeaders = {}
        params = data.setdefault('params', ())
        try:
            rv = getattr(self, method)(*tuple(params))
        except Exception as e:
            self.logger.error(("Error during JSON-RPC call: %s%s\n" % (method, params)) + traceback.format_exc())
            return self.doError(r'Service error: %s' % (e,))
        rv = {'id': data.get('id'), 'error': None, 'result': rv}
        try:
            rv = json.dumps(rv)
        except (TypeError, ValueError):
            return self.doError(r'Error encoding reply in JSON')
        return self.sendReply(200, rv.encode('utf8'), headers=self._JSONHeaders)

    # Fixed fields of every getwork reply; 'data' holds what follows the
    # 80-byte header (presumably SHA-256 padding, word-swapped).
    getwork_rv_template = {
        'data': '000000800000000000000000000000000000000000000000000000000000000000000000000000000000000080020000',
        'target': 'ffffffffffffffffffffffffffffffffffffffffffffffffffffffff00000000',
        'hash1': '00000000000000000000000000000000000000000000000000000000000000000000008000000000000000000000000000000000000000000000000000010000',
    }

    def doJSON_getwork(self, data=None):
        """Return a getwork job, or submit ``data`` as a share if given.

        A 'midstate' field is included when the midstate module is usable
        and the client did not list 'midstate' in X-Mining-Extensions.
        """
        if data is not None:
            return self.doJSON_submitwork(data)
        rv = dict(self.getwork_rv_template)
        hdr = self.server.getBlockHeader(self.Username)

        # FIXME: this assumption breaks with internal rollntime
        # NOTE: noncerange needs to set nonce to start value at least
        uhdr = hdr[:68] + hdr[72:]
        if uhdr in _CheckForDupesHACK:
            raise self.server.RaiseRedFlags(RuntimeError('issuing duplicate work'))
        _CheckForDupesHACK[uhdr] = None

        data = b2a_hex(swap32(hdr)).decode('utf8') + rv['data']
        # TODO: endian shuffle etc
        rv['data'] = data
        if midstate and 'midstate' not in self.extensions:
            h = midstate.SHA256(hdr)[:8]
            rv['midstate'] = b2a_hex(pack('<LLLLLLLL', *h)).decode('ascii')
        return rv

    def doJSON_submitwork(self, datax):
        """Submit a getwork-format share; return whether it was accepted.

        A rejection reason is returned in the X-Reject-Reason header.
        """
        data = swap32(a2b_hex(datax))[:80]
        share = {
            'data': data,
            '_origdata': datax,
            'username': self.Username,
            'remoteHost': self.addr[0],
        }
        try:
            self.server.receiveShare(share)
        except RejectedShare as rej:
            self._JSONHeaders['X-Reject-Reason'] = str(rej)
            return False
        return True

    getmemorypool_rv_template = {
        'mutable': [],
        'noncerange': '00000000ffffffff',
        'target': '00000000ffffffffffffffffffffffffffffffffffffffffffffffffffffffff',
        'version': 1,
    }

    def doJSON_getmemorypool(self, data=None):
        """Return a block template, or submit ``data`` as a block if given."""
        if data is not None:
            return self.doJSON_submitblock(data)

        rv = dict(self.getmemorypool_rv_template)
        (_, merkleTree, cb, prevBlock, bits) = self.server.getBlockTemplate(self.Username)
        rv['previousblockhash'] = b2a_hex(prevBlock[::-1]).decode('ascii')
        # Every transaction except the first (the coinbase, sent separately)
        rv['transactions'] = [b2a_hex(txn.data).decode('ascii') for txn in merkleTree.data[1:]]
        now = int(time())
        rv['time'] = now
        # FIXME: ensure mintime is always >= real mintime, both here and in share acceptance
        rv['mintime'] = now - 180
        rv['maxtime'] = now + 120
        rv['bits'] = b2a_hex(bits[::-1]).decode('ascii')
        t = deepcopy(merkleTree.data[0])
        t.setCoinbase(cb)
        t.assemble()
        rv['coinbasetxn'] = b2a_hex(t.data).decode('ascii')
        return rv

    def doJSON_submitblock(self, data):
        """Submit full block data (header plus body) as a share; return acceptance."""
        data = a2b_hex(data)
        share = {
            'data': data[:80],
            'blkdata': data[80:],
            'username': self.Username,
            'remoteHost': self.addr[0],
        }
        try:
            self.server.receiveShare(share)
        except RejectedShare as rej:
            self._JSONHeaders['X-Reject-Reason'] = str(rej)
            return False
        return True

    def doJSON_setworkaux(self, k, hexv=None):
        """Set server.aux[k] from hex, or delete it when hexv is empty.

        Only the server's SecretUser may call this; anyone else gets a 401.
        """
        if self.Username != self.server.SecretUser:
            self.doAuthenticate()
            return None
        if hexv:
            self.server.aux[k] = a2b_hex(hexv)
        else:
            del self.server.aux[k]
        return True

    def handle_close(self):
        """Drop any pending longpoll registration, then close."""
        self.cleanupLP()
        super().handle_close()

    def handle_request(self):
        """Authenticate, route and dispatch the current request.

        :raises RequestHandled: once a reply has been sent.
        :raises RequestAlreadyHandled: if a reply had already been sent.
        :raises WithinLongpoll: when the request was parked as a longpoll.
        """
        if not self.Username:
            return self.doAuthenticate()
        if self.method not in (b'GET', b'POST'):
            return self.sendReply(405)
        if self.path not in (b'/', b'/LP', b'/LP/'):
            return self.sendReply(404)
        try:
            if self.path[:3] == b'/LP':
                return self.doLongpoll()
            data = b''.join(self.incoming)
            return self.doJSON(data)
        except (socket.error, WithinLongpoll, RequestAlreadyHandled):
            # Includes RequestHandled; a second reply must not be attempted
            raise
        except Exception:
            self.logger.error(traceback.format_exc())
            return self.doError('uncaught error')

    def parse_headers(self, hs):
        """Parse the request line and headers of a new request.

        Resets per-request state, then calls ``doHeader_<lowercased name>``
        for every "Name: value" line that has a handler.  If the request
        line has no path, the connection is closed and ``self.path`` stays
        None.
        """
        self.CL = None
        self.Username = None
        self.method = None
        self.path = None
        hs = re.split(br'\r?\n', hs)
        data = hs.pop(0).split(b' ')
        try:
            self.method = data[0]
            self.path = data[1]
        except IndexError:
            self.close()
            return
        self.extensions = []
        self.reqinfo = {}
        self.quirks = {}
        for line in hs:
            data = tuple(part.strip() for part in line.split(b':', 1))
            if len(data) != 2:
                # Not a "Name: value" line
                continue
            method = 'doHeader_' + data[0].decode('ascii', 'replace').lower()
            if hasattr(self, method):
                try:
                    getattr(self, method)(data[1])
                except RequestAlreadyHandled:
                    # Ignore multiple errors and such
                    pass

    def found_terminator(self):
        """Handle the end of a header block or of a request body.

        After the headers, switch to reading Content-Length body bytes if
        needed.  Once the request is complete, handle it.  Then reset for
        the next request, unless it was parked as a longpoll.
        """
        if self.reading_headers:
            inbuf = b"".join(self.incoming)
            self.incoming = []
            inbuf = inbuf.lstrip(b'\r\n')
            if not inbuf:
                # Only blank lines so far
                return

            self.reading_headers = False
            self.parse_headers(inbuf)
            if self.path is None:
                # Malformed request line; parse_headers closed the connection
                return
            if self.CL:
                self.set_terminator(self.CL)
                return

        self.set_terminator(None)
        try:
            self.handle_request()
            raise RequestNotHandled
        except RequestAlreadyHandled:
            # RequestHandled (reply just sent) or RequestAlreadyHandled (a
            # reply, e.g. an error from a header handler, was already sent):
            # either way the request is finished.
            self.reset_request()
        except WithinLongpoll:
            pass
        except (Exception, RequestNotHandled):
            self.logger.error(traceback.format_exc())

    def handle_error(self):
        """Log the current exception at debug level and close the connection."""
        self.logger.debug(traceback.format_exc())
        self.handle_close()

    get_terminator = asynchat.async_chat.get_terminator
    set_terminator = asynchat.async_chat.set_terminator

    def handle_readbuf(self):
        """Consume ac_in_buffer according to the current terminator.

        Appears adapted from asynchat's buffer loop.  It additionally accepts
        a tuple of alternative bytes terminators.  An int terminator is a
        byte count, and a falsy terminator collects everything.
        """
        while self.ac_in_buffer:
            lb = len(self.ac_in_buffer)
            terminator = self.get_terminator()
            if not terminator:
                # no terminator, collect it all
                self.collect_incoming_data(self.ac_in_buffer)
                self.ac_in_buffer = b''
            elif isinstance(terminator, int):
                # numeric terminator
                n = terminator
                if lb < n:
                    self.collect_incoming_data(self.ac_in_buffer)
                    self.ac_in_buffer = b''
                    self.terminator = self.terminator - lb
                else:
                    self.collect_incoming_data(self.ac_in_buffer[:n])
                    self.ac_in_buffer = self.ac_in_buffer[n:]
                    self.terminator = 0
                    self.found_terminator()
            else:
                # 3 cases:
                # 1) end of buffer matches terminator exactly:
                #    collect data, transition
                # 2) end of buffer matches some prefix:
                #    collect data to the prefix
                # 3) end of buffer does not match any prefix:
                #    collect data
                # NOTE: this supports multiple different terminators, but
                #       NOT ones that are prefixes of others...
                if isinstance(self.ac_in_buffer, type(terminator)):
                    terminator = (terminator,)
                termidx = tuple(map(self.ac_in_buffer.find, terminator))
                try:
                    index = min(x for x in termidx if x >= 0)
                except ValueError:
                    index = -1
                if index != -1:
                    # we found the terminator
                    if index > 0:
                        # don't bother reporting the empty string (source of subtle bugs)
                        self.collect_incoming_data(self.ac_in_buffer[:index])
                    specific_terminator = terminator[termidx.index(index)]
                    terminator_len = len(specific_terminator)
                    self.ac_in_buffer = self.ac_in_buffer[index+terminator_len:]
                    # This does the Right Thing if the terminator is changed here.
                    self.found_terminator()
                else:
                    # check for a prefix of the terminator
                    termidx = tuple(map(lambda a: asynchat.find_prefix_at_end(self.ac_in_buffer, a), terminator))
                    index = max(termidx)
                    if index:
                        if index != lb:
                            # we found a prefix, collect up to the prefix
                            self.collect_incoming_data(self.ac_in_buffer[:-index])
                            self.ac_in_buffer = self.ac_in_buffer[-index:]
                        break
                    else:
                        # no prefix, collect it all
                        self.collect_incoming_data(self.ac_in_buffer)
                        self.ac_in_buffer = b''

    def reset_request(self):
        """Prepare to read the next request on this connection.

        Clears reply and buffer state and waits for a blank line ending the
        headers.  It also schedules handle_timeout 15 seconds from now.
        """
        self.replySent = False
        self.incoming = []
        self.set_terminator((b"\n\n", b"\r\n\r\n"))
        self.reading_headers = True
        self._LP = False
        self.changeTask(self.handle_timeout, time() + 15)

    def collect_incoming_data(self, data):
        """Buffer received data into self.incoming (asynchat helper)."""
        asynchat.async_chat._collect_incoming_data(self, data)

    def __init__(self, *a, **ka):
        super().__init__(*a, **ka)
        self.reset_request()
505         
# parse_headers looks up 'doHeader_' + the lowercased header name, and HTTP
# header names contain '-', which a def cannot.  So also expose these
# handlers under their hyphenated names.
for _hdrname in ('content-length', 'user-agent', 'x-minimum-wait', 'x-mining-extensions'):
    setattr(
        JSONRPCHandler,
        'doHeader_' + _hdrname,
        getattr(JSONRPCHandler, 'doHeader_' + _hdrname.replace('-', '_')),
    )
del _hdrname

JSONRPCListener = networkserver.NetworkListener
512
class JSONRPCServer(networkserver.AsyncSocketServer):
    """Socket server for JSONRPCHandler connections that broadcasts longpolls."""

    logger = logging.getLogger('JSONRPCServer')

    # NOTE(review): presumably asks AsyncSocketServer to provide the
    # wakeup() mechanism used by wakeLongpoll -- confirm in networkserver.
    waker = True

    def __init__(self, *a, **ka):
        ka.setdefault('RequestHandlerClass', JSONRPCHandler)
        super().__init__(*a, **ka)

        # Only this username may call the setworkaux JSON-RPC method
        self.SecretUser = None

        # Longpoll broadcast state: False = idle, 1 = requested (picked up
        # by pre_schedule), 2 = already scheduled for _LPWaitTime
        self.LPRequest = False
        # Parked longpoll handlers, keyed by id(handler)
        self._LPClients = {}
        # Earliest time the next longpoll broadcast may run
        self._LPWaitTime = time() + 15

        # Tracked longpoll requests per client IP (see JSONRPCHandler.LPTrack)
        self.LPTracking = {}

    def pre_schedule(self):
        """Start a broadcast that wakeLongpoll requested, if any."""
        if self.LPRequest == 1:
            self._LPsch()

    def wakeLongpoll(self):
        """Request that all parked longpoll clients be woken with new work.

        Ignored while another broadcast is already pending.
        """
        if self.LPRequest:
            self.logger.info('Ignoring longpoll attempt while another is waiting')
            return
        self.LPRequest = 1
        self.wakeup()

    def _LPsch(self):
        """Broadcast now, or schedule the broadcast for _LPWaitTime."""
        now = time()
        if self._LPWaitTime > now:
            delay = self._LPWaitTime - now
            self.logger.info('Waiting %.3g seconds to longpoll' % (delay,))
            self.schedule(self._actualLP, self._LPWaitTime)
            self.LPRequest = 2
        else:
            self._actualLP()

    def _actualLP(self):
        """Wake every parked longpoll client, then set the next earliest broadcast.

        An exception from one client is logged so the remaining clients are
        still woken and _LPWaitTime is still updated.
        """
        self.LPRequest = False
        C = tuple(self._LPClients.values())
        self._LPClients = {}
        if not C:
            self.logger.info('Nobody to longpoll')
            return
        OC = len(C)
        self.logger.debug("%d clients to wake up..." % (OC,))

        now = time()

        for ic in C:
            try:
                ic.wakeLongpoll()
            except Exception:
                self.logger.error('Error waking longpoll client:\n' + traceback.format_exc())

        self._LPWaitTime = time()
        self.logger.info('Longpoll woke up %d clients in %.3f seconds' % (OC, self._LPWaitTime - now))
        self._LPWaitTime += 5  # TODO: make configurable: minimum time between longpolls

    def TopLPers(self, n=0x10):
        """Print the n IPs with the most tracked longpolls, as (ip, count), fewest first.

        n=0 prints every tracked IP.
        """
        ranked = sorted(self.LPTracking.items(), key=lambda item: item[1])
        for jerk in ranked[-n:]:
            print(jerk)
573                 for jerk in map(lambda k: (k, self.LPTracking[k]), tmp[-n:]):
574                         print(jerk)