# Eloipool - Python Bitcoin pool server
# Copyright (C) 2011-2012  Luke Dashjr <luke-jr+eloipool@utopios.org>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import asynchat
from base64 import b64decode
from binascii import a2b_hex, b2a_hex
from copy import deepcopy
from datetime import datetime
from email.utils import formatdate
import json
import logging
try:
	import midstate
	assert midstate.SHA256(b'This is just a test, ignore it. I am making it over 64-bytes long.')[:8] == (0x755f1a94, 0x999b270c, 0xf358c014, 0xfd39caeb, 0x0dcc9ebc, 0x4694cd1a, 0x8e95678e, 0x75fac450)
except:
	logging.getLogger('jsonrpcserver').warning('Error importing \'midstate\' module; work will not provide midstates')
	midstate = None
import networkserver
import os
import re
import socket
from struct import pack
import threading
from time import mktime, time, sleep
import traceback
from util import RejectedShare, swap32

class WithinLongpoll(BaseException):
	pass

class RequestAlreadyHandled(BaseException):
	pass

class RequestHandled(RequestAlreadyHandled):
	pass

class RequestNotHandled(BaseException):
	pass

# TODO: keepalive/close
_CheckForDupesHACK = {}
class JSONRPCHandler(networkserver.SocketHandler):
	HTTPStatus = {
		200: 'OK',
		401: 'Unauthorized',
		404: 'Not Found',
		405: 'Method Not Allowed',
		500: 'Internal Server Error',
	}
	
	LPHeaders = {
		'X-Long-Polling': None,
	}
	
	logger = logging.getLogger('JSONRPCHandler')
	
	def sendReply(self, status=200, body=b'', headers=None):
		if self.replySent:
			raise RequestAlreadyHandled
		buf = "HTTP/1.1 %d %s\r\n" % (status, self.HTTPStatus.get(status, 'Eligius'))
		headers = dict(headers) if headers else {}
		headers['Date'] = formatdate(timeval=mktime(datetime.now().timetuple()), localtime=False, usegmt=True)
		headers.setdefault('Server', 'Eloipool')
		if body is None:
			headers.setdefault('Transfer-Encoding', 'chunked')
		else:
			headers['Content-Length'] = len(body)
		if status == 200:
			headers.setdefault('Content-Type', 'application/json')
			headers.setdefault('X-Long-Polling', '/LP')
			headers.setdefault('X-Roll-NTime', 'expire=120')
		elif body and body[0] == 123:  # b'{'
			headers.setdefault('Content-Type', 'application/json')
		for k, v in headers.items():
			if v is None: continue
			buf += "%s: %s\r\n" % (k, v)
		buf += "\r\n"
		buf = buf.encode('utf8')
		self.replySent = True
		if body is None:
			self.push(buf)
			return
		buf += body
		self.push(buf)
		raise RequestHandled
	
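	# Illustrative only: for a typical getwork reply, sendReply() above ends up
	# pushing a response shaped roughly like
	#
	#   HTTP/1.1 200 OK
	#   Date: <RFC 1123 date>
	#   Server: Eloipool
	#   Content-Length: <len(body)>
	#   Content-Type: application/json
	#   X-Long-Polling: /LP
	#   X-Roll-NTime: expire=120
	#
	#   {"id": ..., "error": null, "result": {...}}
	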
	def doError(self, reason = '', code = 100):
		reason = json.dumps(reason)
		reason = r'{"result":null,"id":null,"error":{"name":"JSONRPCError","code":%d,"message":%s}}' % (code, reason)
		return self.sendReply(500, reason.encode('utf8'))
	
	def doHeader_authorization(self, value):
		value = value.split(b' ')
		if len(value) != 2 or value[0] != b'Basic':
			return self.doError('Bad Authorization header')
		value = b64decode(value[1])
		value = value.split(b':')[0]
		self.Username = value.decode('utf8')
	
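	# Illustrative only: the header parsed above is standard HTTP Basic auth,
	# e.g. 'Authorization: Basic dXNlcm5hbWU6eA==' where the base64 payload
	# decodes to b'username:x'.  Only the part before the first ':' is kept,
	# so the password is ignored and the worker name becomes self.Username.
	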
	def doHeader_connection(self, value):
		if value == b'close':
			self.quirks['close'] = None
	
	def doHeader_content_length(self, value):
		self.CL = int(value)
	
	def doHeader_user_agent(self, value):
		self.reqinfo['UA'] = value
		quirks = self.quirks
		try:
			if value[:9] == b'phoenix/v':
				v = tuple(map(int, value[9:].split(b'.')))
				if v[0] < 2 and v[1] < 8 and v[2] < 1:
					quirks['NELH'] = None
		except:
			pass
		self.quirks = quirks
	
	def doHeader_x_minimum_wait(self, value):
		self.reqinfo['MinWait'] = int(value)
	
	def doHeader_x_mining_extensions(self, value):
		self.extensions = value.decode('ascii').lower().split(' ')
	
	def doAuthenticate(self):
		self.sendReply(401, headers={'WWW-Authenticate': 'Basic realm="Eligius"'})
	
	def doLongpoll(self):
		timeNow = time()
		
		self._LP = True
		if 'NELH' not in self.quirks:
			# [NOT No] Early Longpoll Headers: this client can take the HTTP
			# headers (and the opening '{' of the JSON reply) right away,
			# before the longpoll actually completes
			self.sendReply(200, body=None, headers=self.LPHeaders)
			self.push(b"1\r\n{\r\n")
			self.changeTask(self._chunkedKA, timeNow + 45)
		else:
			self.changeTask(None)
		
		waitTime = self.reqinfo.get('MinWait', 15)  # TODO: make default configurable
		self.waitTime = waitTime + timeNow
		
		totfromme = self.LPTrack()
		self.server._LPClients[id(self)] = self
		self.logger.debug("New LP client; %d total; %d from %s" % (len(self.server._LPClients), totfromme, self.addr[0]))
		
		raise WithinLongpoll
	
	def _chunkedKA(self):
		# Keepalive via chunked transfer encoding
		self.push(b"1\r\n \r\n")
		self.changeTask(self._chunkedKA, time() + 45)
	
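	# Illustrative only: the byte strings pushed above follow HTTP/1.1 chunked
	# transfer encoding, where each chunk is '<size in hex>\r\n<data>\r\n'.
	# b"1\r\n{\r\n" is a one-byte chunk holding '{' (the start of the eventual
	# JSON reply) and b"1\r\n \r\n" is a one-byte chunk holding a space, sent
	# every 45 seconds purely to keep the idle longpoll connection alive.
	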
	def LPTrack(self):
		myip = self.addr[0]
		if myip not in self.server.LPTracking:
			self.server.LPTracking[myip] = 0
		self.server.LPTracking[myip] += 1
		return self.server.LPTracking[myip]
	
	def LPUntrack(self):
		self.server.LPTracking[self.addr[0]] -= 1
	
	def cleanupLP(self):
		# Called when the connection is closed
		if not self._LP:
			return
		self.changeTask(None)
		try:
			del self.server._LPClients[id(self)]
		except KeyError:
			pass
		self.LPUntrack()
	
	def wakeLongpoll(self):
		now = time()
		if now < self.waitTime:
			self.changeTask(self.wakeLongpoll, self.waitTime)
			return
		else:
			self.changeTask(None)
		
		self.LPUntrack()
		
		rv = self.doJSON_getwork()
		rv['submitold'] = True
		rv = {'id': 1, 'error': None, 'result': rv}
		rv = json.dumps(rv)
		rv = rv.encode('utf8')
		if 'NELH' not in self.quirks:
			rv = rv[1:]  # strip the '{' we already sent
			self.push(('%x' % len(rv)).encode('utf8') + b"\r\n" + rv + b"\r\n0\r\n\r\n")
			self.reset_request()
			return
		
		try:
			self.sendReply(200, body=rv, headers=self.LPHeaders)
			raise RequestNotHandled
		except RequestHandled:
			# Expected
			pass
		finally:
			self.reset_request()
	
	def doJSON(self, data):
		# TODO: handle JSON errors
		data = data.decode('utf8')
		try:
			data = json.loads(data)
			method = 'doJSON_' + str(data['method']).lower()
		except ValueError:
			return self.doError(r'Parse error')
		except TypeError:
			return self.doError(r'Bad call')
		if not hasattr(self, method):
			return self.doError(r'Procedure not found')
		# TODO: handle errors as JSON-RPC
		self._JSONHeaders = {}
		params = data.setdefault('params', ())
		try:
			rv = getattr(self, method)(*tuple(data['params']))
		except Exception as e:
			self.logger.error(("Error during JSON-RPC call: %s%s\n" % (method, params)) + traceback.format_exc())
			return self.doError(r'Service error: %s' % (e,))
		rv = {'id': data['id'], 'error': None, 'result': rv}
		try:
			rv = json.dumps(rv)
		except:
			return self.doError(r'Error encoding reply in JSON')
		rv = rv.encode('utf8')
		return self.sendReply(200, rv, headers=self._JSONHeaders)
	
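	# Illustrative only: doJSON() above dispatches JSON-RPC request bodies to
	# the doJSON_* methods below, e.g. a miner posting
	#   {"id": 1, "method": "getwork", "params": []}
	# is routed to doJSON_getwork(), and the return value is wrapped as
	#   {"id": 1, "error": null, "result": <rv>}
	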
	getwork_rv_template = {
		'data': '000000800000000000000000000000000000000000000000000000000000000000000000000000000000000080020000',
		'target': 'ffffffffffffffffffffffffffffffffffffffffffffffffffffffff00000000',
		'hash1': '00000000000000000000000000000000000000000000000000000000000000000000008000000000000000000000000000000000000000000000000000010000',
	}
	def doJSON_getwork(self, data=None):
		if data is not None:
			return self.doJSON_submitwork(data)
		rv = dict(self.getwork_rv_template)
		hdr = self.server.getBlockHeader(self.Username)
		
		# FIXME: this assumption breaks with internal rollntime
		# NOTE: noncerange needs to set nonce to start value at least
		global _CheckForDupesHACK
		uhdr = hdr[:68] + hdr[72:]
		if uhdr in _CheckForDupesHACK:
			raise self.server.RaiseRedFlags(RuntimeError('issuing duplicate work'))
		_CheckForDupesHACK[uhdr] = None
		
		data = b2a_hex(swap32(hdr)).decode('utf8') + rv['data']
		# TODO: endian shuffle etc
		rv['data'] = data
		if midstate and 'midstate' not in self.extensions:
			h = midstate.SHA256(hdr)[:8]
			rv['midstate'] = b2a_hex(pack('<LLLLLLLL', *h)).decode('ascii')
		return rv
	
	def doJSON_submitwork(self, datax):
		data = swap32(a2b_hex(datax))[:80]
		share = {
			'data': data,
			'_origdata': datax,
			'username': self.Username,
			'remoteHost': self.addr[0],
		}
		try:
			self.server.receiveShare(share)
		except RejectedShare as rej:
			self._JSONHeaders['X-Reject-Reason'] = str(rej)
			return False
		return True
	
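	# Illustrative only: a "getwork" call with a parameter is a share
	# submission, e.g.
	#   {"id": 2, "method": "getwork", "params": ["<256 hex chars>"]}
	# where the parameter is the 128-byte 'data' field handed out above with
	# the solved nonce filled in; doJSON_submitwork() un-swaps it and keeps
	# only the 80-byte block header for receiveShare().
	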
	getmemorypool_rv_template = {
		'mutable': [
			'coinbase/append',
		],
		'noncerange': '00000000ffffffff',
		'target': '00000000ffffffffffffffffffffffffffffffffffffffffffffffffffffffff',
		'version': 1,
	}
	def doJSON_getmemorypool(self, data=None):
		if data is not None:
			return self.doJSON_submitblock(data)
		
		rv = dict(self.getmemorypool_rv_template)
		MC = self.server.getBlockTemplate(self.Username)
		(dummy, merkleTree, cb, prevBlock, bits) = MC
		rv['previousblockhash'] = b2a_hex(prevBlock[::-1]).decode('ascii')
		tl = []
		for txn in merkleTree.data[1:]:
			tl.append(b2a_hex(txn.data).decode('ascii'))
		rv['transactions'] = tl
		now = int(time())
		rv['time'] = now
		# FIXME: ensure mintime is always >= real mintime, both here and in share acceptance
		rv['mintime'] = now - 180
		rv['maxtime'] = now + 120
		rv['bits'] = b2a_hex(bits[::-1]).decode('ascii')
		t = deepcopy(merkleTree.data[0])
		t.setCoinbase(cb)
		t.assemble()
		rv['coinbasetxn'] = b2a_hex(t.data).decode('ascii')
		return rv
	
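	# Illustrative only: a parameterless
	#   {"id": 3, "method": "getmemorypool", "params": []}
	# call returns the template assembled above (previousblockhash, transactions,
	# coinbasetxn, time/mintime/maxtime, bits, plus the static template keys),
	# from which the client builds its own block.
	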
	def doJSON_submitblock(self, data):
		data = a2b_hex(data)
		share = {
			'data': data[:80],
			'blkdata': data[80:],
			'username': self.Username,
			'remoteHost': self.addr[0],
		}
		try:
			self.server.receiveShare(share)
		except RejectedShare as rej:
			self._JSONHeaders['X-Reject-Reason'] = str(rej)
			return False
		return True
	
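	# Illustrative only: "getmemorypool" with a parameter is a block submission;
	# the single param is the hex-encoded serialized block, which is split above
	# into the 80-byte header ('data') and the remaining transaction data
	# ('blkdata') before being handed to receiveShare().
	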
	def doJSON_setworkaux(self, k, hexv = None):
		if self.Username != self.server.SecretUser:
			self.doAuthenticate()
			return None
		if hexv:
			self.server.aux[k] = a2b_hex(hexv)
		else:
			del self.server.aux[k]
		return True
	
	def handle_close(self):
		self.cleanupLP()
		super().handle_close()
	
	def handle_request(self):
		if not self.Username:
			return self.doAuthenticate()
		if self.method not in (b'GET', b'POST'):
			return self.sendReply(405)
		if self.path not in (b'/', b'/LP', b'/LP/'):
			return self.sendReply(404)
		try:
			if self.path[:3] == b'/LP':
				return self.doLongpoll()
			data = b''.join(self.incoming)
			return self.doJSON(data)
		except socket.error:
			raise
		except WithinLongpoll:
			raise
		except RequestHandled:
			raise
		except:
			self.logger.error(traceback.format_exc())
			return self.doError('uncaught error')
	
	def parse_headers(self, hs):
		self.CL = None
		self.Username = None
		self.method = None
		self.path = None
		hs = re.split(br'\r?\n', hs)
		data = hs.pop(0).split(b' ')
		try:
			self.method = data[0]
			self.path = data[1]
		except IndexError:
			self.close()
			return
		self.extensions = []
		self.reqinfo = {}
		self.quirks = {}
		if data[2:] != [b'HTTP/1.1']:
			self.quirks['close'] = None
		self.quirks['NELH'] = None  # FIXME: identify which clients have a problem with this
		while True:
			try:
				data = hs.pop(0)
			except IndexError:
				break
			data = tuple(map(lambda a: a.strip(), data.split(b':', 1)))
			method = 'doHeader_' + data[0].decode('ascii').lower()
			if hasattr(self, method):
				try:
					getattr(self, method)(data[1])
				except RequestAlreadyHandled:
					# Ignore multiple errors and such
					pass
	
	def found_terminator(self):
		if self.reading_headers:
			inbuf = b"".join(self.incoming)
			self.incoming = []
			m = re.match(br'^[\r\n]+', inbuf)
			if m:
				inbuf = inbuf[len(m.group(0)):]
			if not inbuf:
				return
			
			self.reading_headers = False
			self.parse_headers(inbuf)
			if self.CL:
				self.set_terminator(self.CL)
				return
		
		self.set_terminator(None)
		try:
			self.handle_request()
			raise RequestNotHandled
		except RequestHandled:
			self.reset_request()
		except WithinLongpoll:
			pass
		except:
			self.logger.error(traceback.format_exc())
	
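	# Illustrative only: requests are read in two phases.  While reading_headers
	# is True the terminator is a blank line, so found_terminator() first sees
	# the request line plus headers; if a Content-Length was parsed, the
	# terminator becomes that byte count and the next call delivers the body,
	# after which handle_request() runs the actual GET/POST (or longpoll) logic.
	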
	def handle_error(self):
		self.logger.debug(traceback.format_exc())
		self.handle_close()
	
	get_terminator = asynchat.async_chat.get_terminator
	set_terminator = asynchat.async_chat.set_terminator
	
	def handle_readbuf(self):
		while self.ac_in_buffer:
			lb = len(self.ac_in_buffer)
			terminator = self.get_terminator()
			if not terminator:
				# no terminator, collect it all
				self.collect_incoming_data(self.ac_in_buffer)
				self.ac_in_buffer = b''
			elif isinstance(terminator, int):
				# numeric terminator
				n = terminator
				if lb < n:
					self.collect_incoming_data(self.ac_in_buffer)
					self.ac_in_buffer = b''
					self.terminator = self.terminator - lb
				else:
					self.collect_incoming_data(self.ac_in_buffer[:n])
					self.ac_in_buffer = self.ac_in_buffer[n:]
					self.terminator = 0
					self.found_terminator()
			else:
				# 3 cases:
				# 1) end of buffer matches terminator exactly:
				#    collect data, transition
				# 2) end of buffer matches some prefix:
				#    collect data to the prefix
				# 3) end of buffer does not match any prefix:
				#    collect data
				# NOTE: this supports multiple different terminators, but
				#       NOT ones that are prefixes of others...
				if isinstance(self.ac_in_buffer, type(terminator)):
					terminator = (terminator,)
				termidx = tuple(map(self.ac_in_buffer.find, terminator))
				try:
					index = min(x for x in termidx if x >= 0)
				except ValueError:
					index = -1
				if index != -1:
					# we found the terminator
					if index > 0:
						# don't bother reporting the empty string (source of subtle bugs)
						self.collect_incoming_data(self.ac_in_buffer[:index])
					specific_terminator = terminator[termidx.index(index)]
					terminator_len = len(specific_terminator)
					self.ac_in_buffer = self.ac_in_buffer[index+terminator_len:]
					# This does the Right Thing if the terminator is changed here.
					self.found_terminator()
				else:
					# check for a prefix of the terminator
					termidx = tuple(map(lambda a: asynchat.find_prefix_at_end(self.ac_in_buffer, a), terminator))
					index = max(termidx)
					if index:
						if index != lb:
							# we found a prefix, collect up to the prefix
							self.collect_incoming_data(self.ac_in_buffer[:-index])
							self.ac_in_buffer = self.ac_in_buffer[-index:]
						break
					else:
						# no prefix, collect it all
						self.collect_incoming_data(self.ac_in_buffer)
						self.ac_in_buffer = b''
	
	def reset_request(self):
		self.replySent = False
		self.incoming = []
		self.set_terminator((b"\n\n", b"\r\n\r\n"))
		self.reading_headers = True
		self._LP = False
		self.changeTask(self.handle_timeout, time() + 150)
		if 'close' in self.quirks:
			self.close()
	
	def collect_incoming_data(self, data):
		asynchat.async_chat._collect_incoming_data(self, data)
	
	def __init__(self, *a, **ka):
		super().__init__(*a, **ka)
		self.quirks = {}
		self.reset_request()

setattr(JSONRPCHandler, 'doHeader_content-length', JSONRPCHandler.doHeader_content_length)
setattr(JSONRPCHandler, 'doHeader_user-agent', JSONRPCHandler.doHeader_user_agent)
setattr(JSONRPCHandler, 'doHeader_x-minimum-wait', JSONRPCHandler.doHeader_x_minimum_wait)
setattr(JSONRPCHandler, 'doHeader_x-mining-extensions', JSONRPCHandler.doHeader_x_mining_extensions)

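# parse_headers() lowercases each header name and dispatches to
# 'doHeader_' + name without translating '-', so handlers for hyphenated
# headers cannot be defined with a plain 'def'; they are attached here
# under their literal (hyphen-containing) names instead.
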
JSONRPCListener = networkserver.NetworkListener

class JSONRPCServer(networkserver.AsyncSocketServer):
	logger = logging.getLogger('JSONRPCServer')
	
	waker = True
	
	def __init__(self, *a, **ka):
		ka.setdefault('RequestHandlerClass', JSONRPCHandler)
		super().__init__(*a, **ka)
		
		self.SecretUser = None
		
		self.LPRequest = False
		self._LPClients = {}
		self._LPWaitTime = time() + 15
		
		self.LPTracking = {}
	
	def pre_schedule(self):
		if self.LPRequest == 1:
			self._LPsch()
	
	def wakeLongpoll(self):
		if self.LPRequest:
			self.logger.info('Ignoring longpoll attempt while another is waiting')
			return
		self.LPRequest = 1
		self.wakeup()
	
	def _LPsch(self):
		now = time()
		if self._LPWaitTime > now:
			delay = self._LPWaitTime - now
			self.logger.info('Waiting %.3g seconds to longpoll' % (delay,))
			self.schedule(self._actualLP, self._LPWaitTime)
			self.LPRequest = 2
		else:
			self._actualLP()
	
	def _actualLP(self):
		self.LPRequest = False
		C = tuple(self._LPClients.values())
		self._LPClients = {}
		if not C:
			self.logger.info('Nobody to longpoll')
			return
		OC = len(C)
		self.logger.debug("%d clients to wake up..." % (OC,))
		
		now = time()
		
		for ic in C:
			try:
				ic.wakeLongpoll()
			except socket.error:
				OC -= 1
				# Ignore socket errors; let the main event loop take care of them later
			except:
				OC -= 1
				self.logger.debug('Error waking longpoll handler:\n' + traceback.format_exc())
		
		self._LPWaitTime = time()
		self.logger.info('Longpoll woke up %d clients in %.3f seconds' % (OC, self._LPWaitTime - now))
		self._LPWaitTime += 5  # TODO: make configurable: minimum time between longpolls
	
	def TopLPers(self, n = 0x10):
		tmp = list(self.LPTracking.keys())
		tmp.sort(key=lambda k: self.LPTracking[k])
		for jerk in map(lambda k: (k, self.LPTracking[k]), tmp[-n:]):
			print(jerk)
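
# A minimal client-side sketch of the getwork exchange handled above, for
# illustration only; the host, port, and 'user:x' credentials are placeholder
# assumptions rather than anything configured by this module.  It is guarded
# so it never runs when the pool imports this module.
if __name__ == '__main__':
	import http.client
	from base64 import b64encode
	
	conn = http.client.HTTPConnection('localhost', 8332)
	body = json.dumps({'id': 1, 'method': 'getwork', 'params': []})
	hdrs = {
		'Authorization': 'Basic ' + b64encode(b'user:x').decode('ascii'),
		'Content-Type': 'application/json',
	}
	# POST to '/' (the only non-longpoll path handle_request accepts)
	conn.request('POST', '/', body, hdrs)
	resp = conn.getresponse()
	print(resp.status, resp.reason)
	print(json.loads(resp.read().decode('utf8')))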