[bitcoin:eloipool.git] / jsonrpcserver.py
import asynchat
import asyncore
from base64 import b64decode
from binascii import a2b_hex, b2a_hex
from datetime import datetime
from email.utils import formatdate
import json
import logging
try:
	import midstate
	assert midstate.SHA256(b'This is just a test, ignore it. I am making it over 64-bytes long.')[:8] == (0x755f1a94, 0x999b270c, 0xf358c014, 0xfd39caeb, 0x0dcc9ebc, 0x4694cd1a, 0x8e95678e, 0x75fac450)
except:
	midstate = None
import os
import re
import select
import socket
from struct import pack
import threading
from time import mktime, time, sleep
import traceback
from util import RejectedShare, swap32

class WithinLongpoll(BaseException):
	pass

# TODO: keepalive/close
_CheckForDupesHACK = {}
class JSONRPCHandler(asynchat.async_chat):
	HTTPStatus = {
		200: 'OK',
		401: 'Unauthorized',
		404: 'Not Found',
		405: 'Method Not Allowed',
		500: 'Internal Server Error',
	}
	
	logger = logging.getLogger('JSONRPCHandler')
	
	def sendReply(self, status=200, body=b'', headers=None):
		buf = "HTTP/1.1 %d %s\r\n" % (status, self.HTTPStatus.get(status, 'Eligius'))
		headers = dict(headers) if headers else {}
		headers['Date'] = formatdate(timeval=mktime(datetime.now().timetuple()), localtime=False, usegmt=True)
		if body is None:
			headers.setdefault('Transfer-Encoding', 'chunked')
			body = b''
		else:
			headers['Content-Length'] = len(body)
		if status == 200:
			headers.setdefault('Content-Type', 'application/json')
			headers.setdefault('X-Long-Polling', '/LP')
			headers.setdefault('X-Roll-NTime', 'expire=120')
		for k, v in headers.items():
			if v is None: continue
			buf += "%s: %s\r\n" % (k, v)
		buf += "\r\n"
		buf = buf.encode('utf8')
		buf += body
		self.push(buf)
	
	def doError(self, reason=''):
		return self.sendReply(500, reason.encode('utf8'))
	
	def doHeader_authorization(self, value):
		value = value.split(b' ')
		if len(value) != 2 or value[0] != b'Basic':
			return self.doError('Bad Authorization header')
		value = b64decode(value[1])
		value = value.split(b':')[0]
		self.Username = value.decode('utf8')
	
	def doHeader_content_length(self, value):
		self.CL = int(value)
	
	def doHeader_x_minimum_wait(self, value):
		self.reqinfo['MinWait'] = int(value)
	
	def doHeader_x_mining_extensions(self, value):
		self.extensions = value.decode('ascii').lower().split(' ')
	
	def doAuthenticate(self):
		self.sendReply(401, headers={'WWW-Authenticate': 'Basic realm="Eligius"'})
	
	def doLongpoll(self):
		self.sendReply(200, body=None)
		self.push(b"1\r\n{\r\n")
		waitTime = self.reqinfo.get('MinWait', 15)  # TODO: make default configurable
		timeNow = time()
		self.waitTime = waitTime + timeNow
		
		with self.server._LPLock:
			self.server._LPClients[id(self)] = self
			self.server.schedule(self._chunkedKA, timeNow + 45)
			self.logger.debug("New LP client; %d total" % (len(self.server._LPClients),))
		
		raise WithinLongpoll
	
	def _chunkedKA(self):
		# Keepalive via chunked transfer encoding
		self.push(b"1\r\n \r\n")
		self.server.schedule(self._chunkedKA, time() + 45)
	
	def cleanupLP(self):
		# Called when either the connection is closed, or LP is about to be sent
		with self.server._LPLock:
			try:
				del self.server._LPClients[id(self)]
			except KeyError:
				pass
		self.server.rmSchedule(self._chunkedKA, self.wakeLongpoll)
	
	def wakeLongpoll(self):
		self.cleanupLP()
		
		now = time()
		if now < self.waitTime:
			self.server.schedule(self.wakeLongpoll, self.waitTime)
			return
		
		rv = self.doJSON_getwork()
		rv = {'id': 1, 'error': None, 'result': rv}
		rv = json.dumps(rv)
		rv = rv.encode('utf8')
		rv = rv[1:]  # strip the '{' we already sent
		self.push(('%x' % len(rv)).encode('utf8') + b"\r\n" + rv + b"\r\n0\r\n\r\n")
		
		self.reset_request()
	
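	# Illustration (added here, not in the upstream source): on the wire, a
	# longpoll client served by doLongpoll/_chunkedKA/wakeLongpoll sees an
	# HTTP chunked-encoded body roughly like:
	#
	#   1\r\n{\r\n          <- opening 1-byte chunk, sent immediately
	#   1\r\n \r\n          <- 1-byte keepalive chunk, about every 45 seconds
	#   <hex length>\r\n"id": 1, "error": null, "result": {...}}\r\n
	#   0\r\n\r\n           <- terminating chunk ends the response
	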
	def doJSON(self, data):
		# TODO: handle JSON errors
		data = data.decode('utf8')
		data = json.loads(data)
		method = 'doJSON_' + str(data['method']).lower()
		if not hasattr(self, method):
			return self.doError('No such method')
		# TODO: handle errors as JSON-RPC
		self._JSONHeaders = {}
		rv = getattr(self, method)(*tuple(data['params']))
		if rv is None:
			return
		rv = {'id': data['id'], 'error': None, 'result': rv}
		rv = json.dumps(rv)
		rv = rv.encode('utf8')
		return self.sendReply(200, rv, headers=self._JSONHeaders)
	
	getwork_rv_template = {
		'data': '000000800000000000000000000000000000000000000000000000000000000000000000000000000000000080020000',
		'target': 'ffffffffffffffffffffffffffffffffffffffffffffffffffffffff00000000',
		'hash1': '00000000000000000000000000000000000000000000000000000000000000000000008000000000000000000000000000000000000000000000000000010000',
	}
	def doJSON_getwork(self, data=None):
		if data is not None:
			return self.doJSON_submitwork(data)
		rv = dict(self.getwork_rv_template)
		hdr = self.server.getBlockHeader(self.Username)
		
		# FIXME: this assumption breaks with internal rollntime
		# NOTE: noncerange needs to set nonce to start value at least
		global _CheckForDupesHACK
		uhdr = hdr[:68] + hdr[72:]
		if uhdr in _CheckForDupesHACK:
			raise self.server.RaiseRedFlags(RuntimeError('issuing duplicate work'))
		_CheckForDupesHACK[uhdr] = None
		
		data = b2a_hex(swap32(hdr)).decode('utf8') + rv['data']
		# TODO: endian shuffle etc
		rv['data'] = data
		if midstate and 'midstate' not in self.extensions:
			h = midstate.SHA256(hdr)[:8]
			rv['midstate'] = b2a_hex(pack('<LLLLLLLL', *h)).decode('ascii')
		return rv
	
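	# Illustration (added, not in the upstream source; values are placeholders):
	# a typical getwork exchange handled by doJSON/doJSON_getwork looks roughly
	# like the following, with the response headers coming from sendReply():
	#
	#   POST / HTTP/1.1
	#   Authorization: Basic <base64 of "username:password">
	#   Content-Length: <length>
	#
	#   {"id": 1, "method": "getwork", "params": []}
	#
	#   HTTP/1.1 200 OK
	#   Content-Type: application/json
	#   X-Long-Polling: /LP
	#   X-Roll-NTime: expire=120
	#
	#   {"id": 1, "error": null, "result": {"data": "<256 hex digits>",
	#    "target": "<64 hex digits>", "hash1": "<128 hex digits>",
	#    "midstate": "<64 hex digits>"}}
	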
	def doJSON_submitwork(self, datax):
		data = swap32(a2b_hex(datax))[:80]
		share = {
			'data': data,
			'_origdata': datax,
			'username': self.Username,
			'remoteHost': self.addr,
		}
		try:
			self.server.receiveShare(share)
		except RejectedShare as rej:
			self._JSONHeaders['X-Reject-Reason'] = str(rej)
			return False
		return True
	
	def doJSON_setworkaux(self, k, hexv=None):
		if self.Username != self.server.SecretUser:
			self.doAuthenticate()
			return None
		if hexv:
			self.server.aux[k] = a2b_hex(hexv)
		else:
			del self.server.aux[k]
		return True
	
	def handle_close(self):
		self.cleanupLP()
		self.close()
	
	def handle_request(self):
		if not self.Username:
			return self.doAuthenticate()
		if self.method not in (b'GET', b'POST'):
			return self.sendReply(405)
		if self.path not in (b'/', b'/LP', b'/LP/'):
			return self.sendReply(404)
		try:
			if self.path[:3] == b'/LP':
				return self.doLongpoll()
			data = b''.join(self.incoming)
			return self.doJSON(data)
		except socket.error:
			raise
		except WithinLongpoll:
			raise
		except:
			self.logger.error(traceback.format_exc())
			return self.doError('uncaught error')
	
	def parse_headers(self, hs):
		hs = re.split(br'\r?\n', hs)
		data = hs.pop(0).split(b' ')
		self.method = data[0]
		self.path = data[1]
		self.CL = None
		self.extensions = []
		self.Username = None
		self.reqinfo = {}
		while True:
			try:
				data = hs.pop(0)
			except IndexError:
				break
			data = tuple(map(lambda a: a.strip(), data.split(b':', 1)))
			method = 'doHeader_' + data[0].decode('ascii').lower()
			if hasattr(self, method):
				getattr(self, method)(data[1])
	
	def found_terminator(self):
		if self.reading_headers:
			self.reading_headers = False
			self.parse_headers(b"".join(self.incoming))
			self.incoming = []
			if self.CL:
				self.set_terminator(self.CL)
				return
		
		self.set_terminator(None)
		try:
			self.handle_request()
			self.reset_request()
		except WithinLongpoll:
			pass
	
	def handle_read(self):
		# Essentially asynchat.async_chat.handle_read, extended to accept a
		# tuple of alternative terminators (used for b"\n\n" / b"\r\n\r\n").
		try:
			data = self.recv(self.ac_in_buffer_size)
		except socket.error:
			self.handle_error()
			return
		
		if isinstance(data, str) and self.use_encoding:
			data = bytes(data, self.encoding)
		self.ac_in_buffer = self.ac_in_buffer + data
		
		# Continue to search for self.terminator in self.ac_in_buffer,
		# while calling self.collect_incoming_data.  The while loop
		# is necessary because we might read several data+terminator
		# combos with a single recv(4096).
		
		while self.ac_in_buffer:
			lb = len(self.ac_in_buffer)
			terminator = self.get_terminator()
			if not terminator:
				# no terminator, collect it all
				self.collect_incoming_data(self.ac_in_buffer)
				self.ac_in_buffer = b''
			elif isinstance(terminator, int):
				# numeric terminator
				n = terminator
				if lb < n:
					self.collect_incoming_data(self.ac_in_buffer)
					self.ac_in_buffer = b''
					self.terminator = self.terminator - lb
				else:
					self.collect_incoming_data(self.ac_in_buffer[:n])
					self.ac_in_buffer = self.ac_in_buffer[n:]
					self.terminator = 0
					self.found_terminator()
			else:
				# 3 cases:
				# 1) end of buffer matches terminator exactly:
				#    collect data, transition
				# 2) end of buffer matches some prefix:
				#    collect data to the prefix
				# 3) end of buffer does not match any prefix:
				#    collect data
				# NOTE: this supports multiple different terminators, but
				#       NOT ones that are prefixes of others...
				if isinstance(self.ac_in_buffer, type(terminator)):
					# a single bytes terminator; treat it as a 1-tuple
					terminator = (terminator,)
				termidx = tuple(map(self.ac_in_buffer.find, terminator))
				try:
					index = min(x for x in termidx if x >= 0)
				except ValueError:
					index = -1
				if index != -1:
					# we found the terminator
					if index > 0:
						# don't bother reporting the empty string (source of subtle bugs)
						self.collect_incoming_data(self.ac_in_buffer[:index])
					specific_terminator = terminator[termidx.index(index)]
					terminator_len = len(specific_terminator)
					self.ac_in_buffer = self.ac_in_buffer[index+terminator_len:]
					# This does the Right Thing if the terminator is changed here.
					self.found_terminator()
				else:
					# check for a prefix of the terminator
					termidx = tuple(map(lambda a: asynchat.find_prefix_at_end(self.ac_in_buffer, a), terminator))
					index = min(x for x in termidx if x >= 0)
					if index:
						if index != lb:
							# we found a prefix, collect up to the prefix
							self.collect_incoming_data(self.ac_in_buffer[:-index])
							self.ac_in_buffer = self.ac_in_buffer[-index:]
						break
					else:
						# no prefix, collect it all
						self.collect_incoming_data(self.ac_in_buffer)
						self.ac_in_buffer = b''
	
	def reset_request(self):
		self.incoming = []
		self.set_terminator((b"\n\n", b"\r\n\r\n"))
		self.reading_headers = True
	
	def collect_incoming_data(self, data):
		asynchat.async_chat._collect_incoming_data(self, data)
	
	def __init__(self, sock, addr, map):
		asynchat.async_chat.__init__(self, sock=sock, map=map)
		self.addr = addr
		self.reset_request()
	
# HTTP header names contain hyphens, which cannot appear in a def name, so the
# hyphenated handler aliases are attached with setattr.
setattr(JSONRPCHandler, 'doHeader_content-length', JSONRPCHandler.doHeader_content_length)
setattr(JSONRPCHandler, 'doHeader_x-minimum-wait', JSONRPCHandler.doHeader_x_minimum_wait)
setattr(JSONRPCHandler, 'doHeader_x-mining-extensions', JSONRPCHandler.doHeader_x_mining_extensions)

class JSONRPCServer(asyncore.dispatcher):
	def __init__(self, server_address, RequestHandlerClass=JSONRPCHandler, map={}):
		self.logger = logging.getLogger('JSONRPCServer')
		
		asyncore.dispatcher.__init__(self, map=map)
		self.map = map
		self.create_socket(socket.AF_INET6, socket.SOCK_STREAM)
		self.set_reuse_addr()
		self.bind(server_address)
		self.listen(100)
		#self.server_address = self.socket.getsockname()
		
		self.SecretUser = None
		
		self._schLock = threading.RLock()
		self._sch = []
		
		self._LPClients = {}
		self._LPLock = threading.RLock()
		self._LPWaitTime = time() + 15
		self._LPILock = threading.Lock()
		self._LPI = False
		self._LPWLock = threading.Lock()
	
	def handle_accept(self):
		conn, addr = self.accept()
		h = JSONRPCHandler(conn, addr, map=self.map)
		h.server = self
	
	def schedule(self, task, startTime):
		with self._schLock:
			IL = len(self._sch)
			while IL and startTime < self._sch[IL-1][0]:
				IL -= 1
			self._sch.insert(IL, (startTime, task))
	
	def rmSchedule(self, *tasks):
		tasks = list(tasks)
		with self._schLock:
			i = 0
			SL = len(self._sch)
			while i < SL:
				if self._sch[i][1] in tasks:
					tasks.remove(self._sch[i][1])
					del self._sch[i]
					if not tasks:
						break
					SL -= 1
				else:
					i += 1
	
	def serve_forever(self):
		while True:
			with self._schLock:
				if len(self._sch):
					timeNow = time()
					while True:
						timeNext = self._sch[0][0]
						if timeNow < timeNext:
							timeout = timeNext - timeNow
							break
						f = self._sch.pop(0)[1]
						f()
						if not len(self._sch):
							timeout = None
							break
				else:
					timeout = None
			try:
				asyncore.loop(use_poll=True, map=self.map, timeout=timeout, count=1)
			except select.error:
				pass
	
	def wakeLongpoll(self):
		self.logger.debug("(LPILock)")
		with self._LPILock:
			if self._LPI:
				self.logger.info('Ignoring longpoll attempt while another is waiting')
				return
			self._LPI = True
		
		th = threading.Thread(target=self._LPthread)
		th.daemon = True
		th.start()
	
	def _LPthread(self):
		self.logger.debug("(LPWLock)")
		with self._LPWLock:
			now = time()
			if self._LPWaitTime > now:
				delay = self._LPWaitTime - now
				self.logger.info('Waiting %.3g seconds to longpoll' % (delay,))
				sleep(delay)
			
			self._LPI = False
			
			self._actualLP()
	
	def _actualLP(self):
		self.logger.debug("(LPLock)")
		with self._LPLock:
			C = tuple(self._LPClients.values())
			if not C:
				self.logger.info('Nobody to longpoll')
				return
			OC = len(C)
			self.logger.debug("%d clients to wake up..." % (OC,))
			
			now = time()
			
			for ic in C:
				ic.wakeLongpoll()
			
			self._LPWaitTime = time()
			self.logger.info('Longpoll woke up %d clients in %.3f seconds' % (OC, self._LPWaitTime - now))
			self._LPWaitTime += 5  # TODO: make configurable: minimum time between longpolls
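
# Usage sketch (added; not part of the upstream file, and the names of the
# collaborating objects below are illustrative). The handler reaches back into
# its server for getBlockHeader(), receiveShare(), RaiseRedFlags(), SecretUser
# and aux, so the embedding pool is expected to supply those before entering
# the event loop; roughly:
#
#   server = JSONRPCServer(('::', 8337))                 # port is an example
#   server.getBlockHeader = workSource.getBlockHeader    # hypothetical work provider
#   server.receiveShare = shareLogger.receiveShare       # hypothetical share handler
#   server.RaiseRedFlags = lambda e: e                   # hypothetical alert hook; returns the exception to raise
#   server.aux = {}
#   server.SecretUser = 'changeme'
#   server.serve_forever()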