Bugfix: submitblock method returns a rejection reason or null, not a boolean
[bitcoin:eloipool.git] / jsonrpcserver.py
1 # Eloipool - Python Bitcoin pool server
2 # Copyright (C) 2011-2012  Luke Dashjr <luke-jr+eloipool@utopios.org>
3 #
4 # This program is free software: you can redistribute it and/or modify
5 # it under the terms of the GNU Affero General Public License as
6 # published by the Free Software Foundation, either version 3 of the
7 # License, or (at your option) any later version.
8 #
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 # GNU Affero General Public License for more details.
13 #
14 # You should have received a copy of the GNU Affero General Public License
15 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
16
17 import asynchat
18 from base64 import b64decode
19 from binascii import a2b_hex, b2a_hex
20 from copy import deepcopy
21 from datetime import datetime
22 from email.utils import formatdate
23 import json
24 import logging
25 try:
26         import midstate
27         assert midstate.SHA256(b'This is just a test, ignore it. I am making it over 64-bytes long.')[:8] == (0x755f1a94, 0x999b270c, 0xf358c014, 0xfd39caeb, 0x0dcc9ebc, 0x4694cd1a, 0x8e95678e, 0x75fac450)
28 except:
29         logging.getLogger('jsonrpcserver').warning('Error importing \'midstate\' module; work will not provide midstates')
30         midstate = None
31 import networkserver
32 import os
33 import re
34 import socket
35 from struct import pack
36 import threading
37 from time import mktime, time, sleep
38 import traceback
39 from util import RejectedShare, swap32
40
class WithinLongpoll(BaseException):
        """Control-flow signal: the current request has become a parked long-poll.

        Raised by JSONRPCHandler.doLongpoll; handle_request re-raises it and
        found_terminator catches it so the request state is not reset while the
        connection waits. It derives from BaseException rather than Exception,
        presumably so generic ``except Exception`` handlers do not swallow it.
        """
43
44 # TODO: keepalive/close
45 _CheckForDupesHACK = {}
class JSONRPCHandler(networkserver.SocketHandler):
        """Per-connection HTTP/1.1 handler speaking JSON-RPC to mining clients.

        Requests are read incrementally (handle_readbuf / found_terminator).
        A request for / is dispatched to a ``doJSON_<method>`` method. A request
        for /LP parks the connection as a long-poll client until the server
        wakes it up.
        """
        HTTPStatus = {
                200: 'OK',
                401: 'Unauthorized',
                404: 'Not Found',
                405: 'Method Not Allowed',
                500: 'Internal Server Error',
        }
        
        # Extra headers for long-poll replies. sendReply skips None values and
        # only setdefault()s its own defaults, so this suppresses X-Long-Polling.
        LPHeaders = {
                'X-Long-Polling': None,
        }
        
        logger = logging.getLogger('JSONRPCHandler')
        
        def sendReply(self, status=200, body=b'', headers=None):
                """Queue an HTTP response on the connection.

                :param status: HTTP status code; codes missing from HTTPStatus get
                        the reason phrase 'Eligius'.
                :param body: response body bytes, or None to start a chunked
                        response whose chunks the caller pushes itself.
                :param headers: optional extra headers. A None value suppresses the
                        header, including the defaults added here.
                """
                buf = "HTTP/1.1 %d %s\r\n" % (status, self.HTTPStatus.get(status, 'Eligius'))
                headers = dict(headers) if headers else {}
                # Format the epoch time directly: round-tripping a naive local
                # datetime through mktime() is ambiguous around DST changes.
                headers['Date'] = formatdate(timeval=time(), localtime=False, usegmt=True)
                headers.setdefault('Server', 'Eloipool')
                if body is None:
                        headers.setdefault('Transfer-Encoding', 'chunked')
                        body = b''
                else:
                        headers['Content-Length'] = len(body)
                if status == 200:
                        headers.setdefault('Content-Type', 'application/json')
                        headers.setdefault('X-Long-Polling', '/LP')
                        headers.setdefault('X-Roll-NTime', 'expire=120')
                elif body and body[0] == 123:  # b'{'
                        headers.setdefault('Content-Type', 'application/json')
                for k, v in headers.items():
                        if v is None: continue
                        buf += "%s: %s\r\n" % (k, v)
                buf += "\r\n"
                buf = buf.encode('utf8')
                buf += body
                self.push(buf)
                # Lets doJSON tell whether a method already answered the request.
                self._replySent = True
        
        def doError(self, reason = '', code = 100):
                """Send a 500 response carrying a JSON-RPC error object."""
                reason = json.dumps(reason)
                reason = r'{"result":null,"id":null,"error":{"name":"JSONRPCError","code":%d,"message":%s}}' % (code, reason)
                return self.sendReply(500, reason.encode('utf8'))
        
        def doHeader_authorization(self, value):
                """Take the username from an HTTP Basic Authorization header.

                The password part is ignored. A malformed header (wrong scheme,
                bad base64, non-UTF-8 username) leaves self.Username unset. It does
                not reply here: handle_request then sends exactly one 401, instead
                of two responses going out for one request.
                """
                parts = value.split(b' ')
                if len(parts) != 2 or parts[0] != b'Basic':
                        self.logger.debug('Bad Authorization header from %s', self.addr[0])
                        return
                try:
                        username = b64decode(parts[1]).split(b':')[0].decode('utf8')
                except ValueError:
                        # binascii.Error and UnicodeDecodeError both subclass ValueError
                        self.logger.debug('Undecodable Authorization header from %s', self.addr[0])
                        return
                self.Username = username
        
        def doHeader_content_length(self, value):
                """Record the request body length; found_terminator reads that many bytes."""
                self.CL = int(value)
        
        def doHeader_user_agent(self, value):
                """Record the User-Agent and derive client quirks from it.

                Sets the 'NELH' (No Early Longpoll Headers) quirk for some
                phoenix versions.
                """
                self.reqinfo['UA'] = value
                quirks = self.quirks
                try:
                        if value[:9] == b'phoenix/v':
                                v = tuple(map(int, value[9:].split(b'.')))
                                # NOTE(review): this per-component test is not an
                                # ordered version comparison (1.7.0 matches, 1.6.5
                                # does not); confirm the intended range before changing.
                                if v[0] < 2 and v[1] < 8 and v[2] < 1:
                                        quirks['NELH'] = None
                except (ValueError, IndexError):
                        # Unparseable or too-short version: apply no quirks
                        pass
                self.quirks = quirks
        
        def doHeader_x_minimum_wait(self, value):
                """Minimum seconds a long-poll must wait before being answered."""
                self.reqinfo['MinWait'] = int(value)
        
        def doHeader_x_mining_extensions(self, value):
                """Record the client's space-separated, lower-cased extension list."""
                self.extensions = value.decode('ascii').lower().split(' ')
        
        def doAuthenticate(self):
                """Send a 401 Basic-auth challenge."""
                self.sendReply(401, headers={'WWW-Authenticate': 'Basic realm="Eligius"'})
        
        def doLongpoll(self):
                """Park this connection as a long-poll client; always raises WithinLongpoll.

                Unless the client has the NELH quirk, a chunked 200 response and the
                opening '{' of the JSON body are sent at once, followed by a 1-byte
                keepalive chunk every 45 seconds.
                """
                timeNow = time()
                
                self._LP = True
                if 'NELH' not in self.quirks:
                        # [NOT No] Early Longpoll Headers
                        self.sendReply(200, body=None, headers=self.LPHeaders)
                        self.push(b"1\r\n{\r\n")
                        self.changeTask(self._chunkedKA, timeNow + 45)
                else:
                        self.changeTask(None)
                
                waitTime = self.reqinfo.get('MinWait', 15)  # TODO: make default configurable
                self.waitTime = waitTime + timeNow
                
                totfromme = self.LPTrack()
                self.server._LPClients[id(self)] = self
                self.logger.debug("New LP client; %d total; %d from %s" % (len(self.server._LPClients), totfromme, self.addr[0]))
                
                raise WithinLongpoll
        
        def _chunkedKA(self):
                """Send a one-space chunk as keepalive and reschedule in 45 seconds."""
                # Keepalive via chunked transfer encoding
                self.push(b"1\r\n \r\n")
                self.changeTask(self._chunkedKA, time() + 45)
        
        def LPTrack(self):
                """Count one more parked long-poll from this IP; return the new count."""
                myip = self.addr[0]
                if myip not in self.server.LPTracking:
                        self.server.LPTracking[myip] = 0
                self.server.LPTracking[myip] += 1
                return self.server.LPTracking[myip]
        
        def LPUntrack(self):
                """Undo one LPTrack for this connection's IP."""
                self.server.LPTracking[self.addr[0]] -= 1
        
        def cleanupLP(self):
                """Forget this connection's long-poll registration, if any."""
                # Called when the connection is closed
                if not self._LP:
                        return
                self.changeTask(None)
                try:
                        del self.server._LPClients[id(self)]
                except KeyError:
                        pass
                self.LPUntrack()
        
        def wakeLongpoll(self):
                """Answer the parked long-poll with fresh getwork data.

                If the client's minimum wait (X-Minimum-Wait, default 15s) has not
                elapsed yet, reschedules itself for that time instead.
                """
                now = time()
                if now < self.waitTime:
                        self.changeTask(self.wakeLongpoll, self.waitTime)
                        return
                else:
                        self.changeTask(None)
                
                self.LPUntrack()
                
                rv = self.doJSON_getwork()
                rv['submitold'] = True
                rv = {'id': 1, 'error': None, 'result': rv}
                rv = json.dumps(rv)
                rv = rv.encode('utf8')
                if 'NELH' not in self.quirks:
                        rv = rv[1:]  # strip the '{' we already sent
                        self.push(('%x' % len(rv)).encode('utf8') + b"\r\n" + rv + b"\r\n0\r\n\r\n")
                else:
                        self.sendReply(200, body=rv, headers=self.LPHeaders)
                
                self.reset_request()
        
        def doJSON(self, data):
                """Decode a JSON-RPC request body, dispatch it, and send the reply.

                Calls ``doJSON_<method>`` (method lower-cased) with the request's
                params as positional arguments. If the method returns None after
                sending its own response (e.g. via doAuthenticate), nothing more is
                sent. A None return without a response is a genuine null result:
                submitblock returns None when a block is accepted, and it must
                still get a reply.
                """
                # TODO: handle JSON errors
                try:
                        data = json.loads(data.decode('utf8'))
                        method = 'doJSON_' + str(data['method']).lower()
                except ValueError:
                        # Bad UTF-8 (UnicodeDecodeError) or bad JSON
                        return self.doError(r'Parse error')
                except (TypeError, KeyError):
                        # Not a JSON object, or no "method" member
                        return self.doError(r'Bad call')
                if not hasattr(self, method):
                        return self.doError(r'Procedure not found')
                # TODO: handle errors as JSON-RPC
                self._JSONHeaders = {}
                params = data.setdefault('params', ())
                self._replySent = False
                try:
                        rv = getattr(self, method)(*tuple(params))
                except Exception as e:
                        self.logger.error(("Error during JSON-RPC call: %s%s\n" % (method, params)) + traceback.format_exc())
                        return self.doError(r'Service error: %s' % (e,))
                if rv is None and self._replySent:
                        # response was already sent (eg, authentication request)
                        return
                rv = {'id': data.get('id'), 'error': None, 'result': rv}
                try:
                        rv = json.dumps(rv)
                except (TypeError, ValueError):
                        return self.doError(r'Error encoding reply in JSON')
                rv = rv.encode('utf8')
                return self.sendReply(200, rv, headers=self._JSONHeaders)
        
        getwork_rv_template = {
                'data': '000000800000000000000000000000000000000000000000000000000000000000000000000000000000000080020000',
                'target': 'ffffffffffffffffffffffffffffffffffffffffffffffffffffffff00000000',
                'hash1': '00000000000000000000000000000000000000000000000000000000000000000000008000000000000000000000000000000000000000000000000000010000',
        }
        def doJSON_getwork(self, data=None):
                """getwork: with data, submit a share; without, return new work.

                New work includes a midstate when the midstate module loaded, unless
                the client listed 'midstate' in X-Mining-Extensions. Issuing the
                same header twice (ignoring bytes 68-72) raises a red flag.
                """
                if data is not None:
                        return self.doJSON_submitwork(data)
                rv = dict(self.getwork_rv_template)
                hdr = self.server.getBlockHeader(self.Username)
                
                # FIXME: this assumption breaks with internal rollntime
                # NOTE: noncerange needs to set nonce to start value at least
                global _CheckForDupesHACK
                uhdr = hdr[:68] + hdr[72:]
                if uhdr in _CheckForDupesHACK:
                        raise self.server.RaiseRedFlags(RuntimeError('issuing duplicate work'))
                _CheckForDupesHACK[uhdr] = None
                
                data = b2a_hex(swap32(hdr)).decode('utf8') + rv['data']
                # TODO: endian shuffle etc
                rv['data'] = data
                if midstate and 'midstate' not in self.extensions:
                        h = midstate.SHA256(hdr)[:8]
                        rv['midstate'] = b2a_hex(pack('<LLLLLLLL', *h)).decode('ascii')
                return rv
        
        def doJSON_submitwork(self, datax):
                """Submit a getwork share; True if accepted.

                On rejection, returns False and puts the reason in the
                X-Reject-Reason response header.
                """
                data = swap32(a2b_hex(datax))[:80]
                share = {
                        'data': data,
                        '_origdata' : datax,
                        'username': self.Username,
                        'remoteHost': self.addr[0],
                }
                try:
                        self.server.receiveShare(share)
                except RejectedShare as rej:
                        self._JSONHeaders['X-Reject-Reason'] = str(rej)
                        return False
                return True
        
        getmemorypool_rv_template = {
                'mutable': [],
                'noncerange': '00000000ffffffff',
                'target': '00000000ffffffffffffffffffffffffffffffffffffffffffffffffffffffff',
                'version': 1,
        }
        def doJSON_getmemorypool(self, data=None):
                """getmemorypool: with data, submit a block (True if accepted);
                without, return a block template built from the server."""
                if data is not None:
                        # submitblock returns None on acceptance, else a reason string
                        return self.doJSON_submitblock(data) is None
                
                rv = dict(self.getmemorypool_rv_template)
                MC = self.server.getBlockTemplate(self.Username)
                (dummy, merkleTree, cb, prevBlock, bits) = MC
                rv['previousblockhash'] = b2a_hex(prevBlock[::-1]).decode('ascii')
                tl = []
                for txn in merkleTree.data[1:]:
                        tl.append(b2a_hex(txn.data).decode('ascii'))
                rv['transactions'] = tl
                now = int(time())
                rv['time'] = now
                # FIXME: ensure mintime is always >= real mintime, both here and in share acceptance
                rv['mintime'] = now - 180
                rv['maxtime'] = now + 120
                rv['bits'] = b2a_hex(bits[::-1]).decode('ascii')
                t = deepcopy(merkleTree.data[0])
                t.setCoinbase(cb)
                t.assemble()
                rv['coinbasetxn'] = b2a_hex(t.data).decode('ascii')
                return rv
        
        def doJSON_submitblock(self, data):
                """submitblock: submit a hex-encoded block.

                Returns None if accepted, otherwise the rejection reason string.
                doJSON sends a None result as JSON null.
                """
                data = a2b_hex(data)
                share = {
                        'data': data[:80],
                        'blkdata': data[80:],
                        'username': self.Username,
                        'remoteHost': self.addr[0],
                }
                try:
                        self.server.receiveShare(share)
                except RejectedShare as rej:
                        return str(rej)
                return None
        
        def doJSON_setworkaux(self, k, hexv = None):
                """Privileged: set server.aux[k] from hex, or delete it if hexv is empty.

                Only server.SecretUser may call this; anyone else gets a 401
                (sent here, so None is returned).
                """
                if self.Username != self.server.SecretUser:
                        self.doAuthenticate()
                        return None
                if hexv:
                        self.server.aux[k] = a2b_hex(hexv)
                else:
                        del self.server.aux[k]
                return True
        
        def handle_close(self):
                self.cleanupLP()
                super().handle_close()
        
        def handle_request(self):
                """Route a fully-read request: auth, method/path checks, then LP or JSON-RPC."""
                if not self.Username:
                        return self.doAuthenticate()
                if not self.method in (b'GET', b'POST'):
                        return self.sendReply(405)
                if not self.path in (b'/', b'/LP', b'/LP/'):
                        return self.sendReply(404)
                try:
                        if self.path[:3] == b'/LP':
                                return self.doLongpoll()
                        data = b''.join(self.incoming)
                        return self.doJSON(data)
                except socket.error:
                        raise
                except WithinLongpoll:
                        raise
                except Exception:
                        # Top-level boundary: log and report instead of dropping the request
                        self.logger.error(traceback.format_exc())
                        return self.doError('uncaught error')
        
        def parse_headers(self, hs):
                """Parse the request line and headers, resetting per-request state.

                For each "Name: value" header, calls ``doHeader_<name>`` (name
                lower-cased) with the stripped value if such a method exists.
                Lines without a colon or with a non-ASCII name are ignored.
                """
                self.CL = None
                self.Username = None
                self.method = None
                self.path = None
                hs = re.split(br'\r?\n', hs)
                data = hs.pop(0).split(b' ')
                try:
                        self.method = data[0]
                        self.path = data[1]
                except IndexError:
                        self.close()
                        return
                self.extensions = []
                self.reqinfo = {}
                self.quirks = {}
                for line in hs:
                        name, sep, value = line.partition(b':')
                        if not sep:
                                continue
                        try:
                                name = name.strip().decode('ascii').lower()
                        except UnicodeDecodeError:
                                continue
                        method = 'doHeader_' + name
                        if hasattr(self, method):
                                getattr(self, method)(value.strip())
        
        def found_terminator(self):
                """Advance the request state machine when a terminator is reached.

                In the header phase: skip leading blank lines, parse headers, and
                either switch to reading Content-Length body bytes or handle the
                request at once. Otherwise the body is complete: handle the request.
                """
                if self.reading_headers:
                        inbuf = b"".join(self.incoming)
                        self.incoming = []
                        m = re.match(br'^[\r\n]+', inbuf)
                        if m:
                                inbuf = inbuf[len(m.group(0)):]
                        if not inbuf:
                                return
                        
                        self.reading_headers = False
                        self.parse_headers(inbuf)
                        if self.CL:
                                self.set_terminator(self.CL)
                                return
                
                self.set_terminator(None)
                try:
                        self.handle_request()
                        self.reset_request()
                except WithinLongpoll:
                        # Connection is parked; wakeLongpoll resets it later
                        pass
        
        def handle_error(self):
                self.logger.debug(traceback.format_exc())
                self.handle_close()
        
        get_terminator = asynchat.async_chat.get_terminator
        set_terminator = asynchat.async_chat.set_terminator
        
        def handle_readbuf(self):
                """Feed ac_in_buffer through the terminator logic (adapted from asynchat).

                Unlike asynchat, the terminator may be a tuple of alternatives.
                """
                while self.ac_in_buffer:
                        lb = len(self.ac_in_buffer)
                        terminator = self.get_terminator()
                        if not terminator:
                                # no terminator, collect it all
                                self.collect_incoming_data (self.ac_in_buffer)
                                self.ac_in_buffer = b''
                        elif isinstance(terminator, int):
                                # numeric terminator
                                n = terminator
                                if lb < n:
                                        self.collect_incoming_data (self.ac_in_buffer)
                                        self.ac_in_buffer = b''
                                        self.terminator = self.terminator - lb
                                else:
                                        self.collect_incoming_data (self.ac_in_buffer[:n])
                                        self.ac_in_buffer = self.ac_in_buffer[n:]
                                        self.terminator = 0
                                        self.found_terminator()
                        else:
                                # 3 cases:
                                # 1) end of buffer matches terminator exactly:
                                #    collect data, transition
                                # 2) end of buffer matches some prefix:
                                #    collect data to the prefix
                                # 3) end of buffer does not match any prefix:
                                #    collect data
                                # NOTE: this supports multiple different terminators, but
                                #       NOT ones that are prefixes of others...
                                if isinstance(self.ac_in_buffer, type(terminator)):
                                        terminator = (terminator,)
                                termidx = tuple(map(self.ac_in_buffer.find, terminator))
                                try:
                                        index = min(x for x in termidx if x >= 0)
                                except ValueError:
                                        index = -1
                                if index != -1:
                                        # we found the terminator
                                        if index > 0:
                                                # don't bother reporting the empty string (source of subtle bugs)
                                                self.collect_incoming_data (self.ac_in_buffer[:index])
                                        specific_terminator = terminator[termidx.index(index)]
                                        terminator_len = len(specific_terminator)
                                        self.ac_in_buffer = self.ac_in_buffer[index+terminator_len:]
                                        # This does the Right Thing if the terminator is changed here.
                                        self.found_terminator()
                                else:
                                        # check for a prefix of the terminator
                                        termidx = tuple(map(lambda a: asynchat.find_prefix_at_end (self.ac_in_buffer, a), terminator))
                                        index = max(termidx)
                                        if index:
                                                if index != lb:
                                                        # we found a prefix, collect up to the prefix
                                                        self.collect_incoming_data (self.ac_in_buffer[:-index])
                                                        self.ac_in_buffer = self.ac_in_buffer[-index:]
                                                break
                                        else:
                                                # no prefix, collect it all
                                                self.collect_incoming_data (self.ac_in_buffer)
                                                self.ac_in_buffer = b''
        
        def reset_request(self):
                """Prepare to read the next request's headers, with a 15-second timeout."""
                self.incoming = []
                self.set_terminator( (b"\n\n", b"\r\n\r\n") )
                self.reading_headers = True
                self._LP = False
                self._replySent = False
                self.changeTask(self.handle_timeout, time() + 15)
        
        def collect_incoming_data(self, data):
                asynchat.async_chat._collect_incoming_data(self, data)
        
        def __init__(self, *a, **ka):
                super().__init__(*a, **ka)
                self.reset_request()
474         
# parse_headers looks up 'doHeader_' + <lower-cased header name> verbatim, so
# hyphenated header names need aliases to the underscore-named handler methods.
for _hyphenated in ('content-length', 'user-agent', 'x-minimum-wait', 'x-mining-extensions'):
        setattr(JSONRPCHandler, 'doHeader_' + _hyphenated, getattr(JSONRPCHandler, 'doHeader_' + _hyphenated.replace('-', '_')))
del _hyphenated
479
# The JSON-RPC server uses the generic networkserver listener unchanged;
# this alias gives it a JSON-RPC-specific name.
JSONRPCListener = networkserver.NetworkListener
481
class JSONRPCServer(networkserver.AsyncSocketServer):
        """Socket server for JSONRPCHandler connections that also coordinates
        long-poll broadcasts.

        A broadcast is requested with wakeLongpoll() and carried out via
        pre_schedule() (presumably invoked by the networkserver event loop).
        Broadcasts are rate-limited: the next one may start no earlier than
        5 seconds after the previous one finished (15 seconds after startup).
        """
        logger = logging.getLogger('JSONRPCServer')
        
        # NOTE(review): presumably makes AsyncSocketServer set up the wakeup
        # mechanism used by self.wakeup() — confirm in networkserver.
        waker = True
        
        def __init__(self, *a, **ka):
                ka.setdefault('RequestHandlerClass', JSONRPCHandler)
                super().__init__(*a, **ka)
                
                # Username allowed to call privileged methods (setworkaux). Handlers
                # reject empty usernames, so None effectively disables them.
                self.SecretUser = None
                
                # Long-poll request state: False = idle, 1 = requested (picked up by
                # pre_schedule), 2 = deferred by the rate limit and scheduled.
                self.LPRequest = False
                # Parked long-poll handlers, keyed by id(handler)
                self._LPClients = {}
                # Earliest time the next long-poll broadcast may run
                self._LPWaitTime = time() + 15
                
                # Number of parked long-poll connections per remote IP
                self.LPTracking = {}
        
        def pre_schedule(self):
                """Start a requested long-poll broadcast, if one is pending."""
                if self.LPRequest == 1:
                        self._LPsch()
        
        def wakeLongpoll(self):
                """Request a long-poll broadcast; ignored if one is already pending.

                Only sets a flag and wakes the event loop; the broadcast happens in
                pre_schedule.
                """
                if self.LPRequest:
                        self.logger.info('Ignoring longpoll attempt while another is waiting')
                        return
                self.LPRequest = 1
                self.wakeup()
        
        def _LPsch(self):
                """Broadcast now, or schedule the broadcast for _LPWaitTime if rate-limited."""
                now = time()
                if self._LPWaitTime > now:
                        delay = self._LPWaitTime - now
                        self.logger.info('Waiting %.3g seconds to longpoll' % (delay,))
                        self.schedule(self._actualLP, self._LPWaitTime)
                        self.LPRequest = 2
                else:
                        self._actualLP()
        
        def _actualLP(self):
                """Wake every parked long-poll client and start the rate-limit window."""
                self.LPRequest = False
                C = tuple(self._LPClients.values())
                self._LPClients = {}
                if not C:
                        self.logger.info('Nobody to longpoll')
                        return
                OC = len(C)
                self.logger.debug("%d clients to wake up..." % (OC,))
                
                now = time()
                
                for ic in C:
                        try:
                                ic.wakeLongpoll()
                        except Exception:
                                # Every client was already removed from _LPClients, so
                                # letting one failure escape would leave the rest
                                # parked forever and skip the rate-limit update below.
                                self.logger.error('Error waking longpoll client:\n' + traceback.format_exc())
                
                self._LPWaitTime = time()
                self.logger.info('Longpoll woke up %d clients in %.3f seconds' % (OC, self._LPWaitTime - now))
                self._LPWaitTime += 5  # TODO: make configurable: minimum time between longpolls
        
        def TopLPers(self, n = 0x10):
                """Debug aid: print (ip, parked long-poll count) for the top n IPs, ascending.

                Note that n=0 prints every IP, because tmp[-0:] is the whole list.
                """
                tmp = list(self.LPTracking.keys())
                tmp.sort(key=lambda k: self.LPTracking[k])
                for jerk in map(lambda k: (k, self.LPTracking[k]), tmp[-n:]):
                        print(jerk)