Trigger merkletree update when block inv received from bitcoin p2p
[bitcoin:eloipool.git] / bitcoin / node.py
1 # Eloipool - Python Bitcoin pool server
2 # Copyright (C) 2011-2012  Luke Dashjr <luke-jr+eloipool@utopios.org>
3 #
4 # This program is free software: you can redistribute it and/or modify
5 # it under the terms of the GNU Affero General Public License as
6 # published by the Free Software Foundation, either version 3 of the
7 # License, or (at your option) any later version.
8 #
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 # GNU Affero General Public License for more details.
13 #
14 # You should have received a copy of the GNU Affero General Public License
15 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
16
17 from .varlen import varlenDecode, varlenEncode
18 import asynchat
19 from binascii import b2a_hex
20 from collections import deque
21 import logging
22 import networkserver
23 import re
24 import socket
25 from struct import pack, unpack
26 from time import time
27 from util import dblsha, tryErr
28
# Largest p2p message payload accepted from a peer (2 MiB); a header that
# announces a bigger payload makes BitcoinLink.handle_readbuf raise.
MAX_PACKET_PAYLOAD = 2 * 1024 * 1024
30
def makeNetAddr(addr):
        """Serialize ``addr`` as a timestamped p2p network address (30 bytes).

        ``addr`` is a ``(host, port, ...)`` tuple; only the first two items
        are used.  ``host`` may be a dotted-quad IPv4 string, or an IPv6
        literal such as the ``'::ffff:a.b.c.d'`` form that BitcoinLink stores
        in its ``dest`` attribute.

        Layout: 4-byte little-endian timestamp (current time), 8 zero bytes
        (the services field of the Bitcoin ``net_addr`` structure), the
        16-byte IPv6 address (IPv4 is mapped into ``::ffff:0:0/96``), and the
        2-byte big-endian port.

        Raises ValueError / struct.error for a malformed IPv4 string (as
        before), and OSError for a malformed IPv6 literal.
        """
        host, port = addr[0], addr[1]
        timestamp = pack('<L', int(time()))
        services = b'\0' * 8
        if ':' in host:
                # IPv6 literal, including the IPv4-mapped form used for outbound links.
                aIP = socket.inet_pton(socket.AF_INET6, host)
        else:
                # Dotted-quad IPv4, mapped into ::ffff:0:0/96.
                aIP = b'\0' * 10 + b'\xff\xff' + pack('>BBBB', *map(int, host.split('.')))
        aPort = pack('>H', port)
        return timestamp + services + aIP + aPort
36
class BitcoinLink(networkserver.SocketHandler):
        """One peer connection speaking the Bitcoin p2p wire protocol.

        Received bytes accumulate in ``ac_in_buffer``; ``handle_readbuf``
        splits them into framed messages and dispatches each one to a
        ``doCmd_<command>`` method when such a method exists.  Messages with
        no matching handler are ignored.
        """
        logger = logging.getLogger('BitcoinLink')
        
        def __init__(self, *a, **ka):
                """Wrap an accepted connection, or open an outbound one.

                If a ``dest`` keyword (a ``(host, port, ...)`` tuple) is given,
                an IPv6 TCP socket is connected to it (IPv4 hosts are first
                rewritten to the ``::ffff:`` mapped form) and our ``version``
                message is sent immediately.  All other arguments are passed
                on to ``networkserver.SocketHandler``.

                If the outbound ``connect`` fails, the socket is closed and the
                error is re-raised.
                """
                dest = ka.pop('dest', None)
                if dest:
                        # Initiate outbound connection
                        try:
                                if ':' not in dest[0]:
                                        dest = ('::ffff:' + dest[0],) + tuple(dest[1:])
                        except Exception:
                                # Best-effort normalisation only: leave dest as given if
                                # it is not an indexable sequence with a str host.
                                pass
                        sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
                        try:
                                sock.connect(dest)
                        except BaseException:
                                # Don't leak the descriptor when the connection fails.
                                sock.close()
                                raise
                        ka['sock'] = sock
                        ka['addr'] = dest
                super().__init__(*a, **ka)
                self.dest = dest
                self.sentVersion = False
                self.changeTask(None)  # FIXME: TEMPORARY
                if dest:
                        self.pushVersion()
        
        def handle_readbuf(self):
                """Parse and dispatch every complete message in ``ac_in_buffer``.

                Framing, as parsed here: 4-byte network magic (``server.netid``),
                12-byte NUL-padded command, 4-byte little-endian payload length,
                4-byte checksum (first 4 bytes of ``dblsha(payload)``), payload.

                Bytes preceding the next magic are discarded.  An incomplete
                header or payload is left in the buffer for the next call.  A
                message with a bad checksum is dropped and parsing continues
                with whatever follows it.

                Raises RuntimeError if a header announces a payload larger than
                MAX_PACKET_PAYLOAD.
                """
                netid = self.server.netid
                while self.ac_in_buffer:
                        if self.ac_in_buffer[:4] != netid:
                                # Resynchronise on the next occurrence of the magic.
                                p = self.ac_in_buffer.find(netid)
                                if p == -1:
                                        # Keep only a trailing partial magic, if any.
                                        p = asynchat.find_prefix_at_end(self.ac_in_buffer, netid)
                                        if p:
                                                self.ac_in_buffer = self.ac_in_buffer[-p:]
                                        else:
                                                self.ac_in_buffer = b''
                                        break
                                self.ac_in_buffer = self.ac_in_buffer[p:]
                        
                        if len(self.ac_in_buffer) < 0x18:
                                # Header not complete yet; unpacking the length field
                                # now would raise struct.error.
                                break
                        
                        # 'replace' so a garbage command from a peer cannot raise; the
                        # resulting name simply matches no doCmd_ handler.
                        cmd = self.ac_in_buffer[4:0x10].rstrip(b'\0').decode('utf8', 'replace')
                        payloadLen = unpack('<L', self.ac_in_buffer[0x10:0x14])[0]
                        if payloadLen > MAX_PACKET_PAYLOAD:
                                raise RuntimeError('Packet payload is too long (%d bytes)' % (payloadLen,))
                        payloadEnd = payloadLen + 0x18
                        if len(self.ac_in_buffer) < payloadEnd:
                                # Don't have the whole packet yet
                                break
                        
                        cksum = self.ac_in_buffer[0x14:0x18]
                        payload = self.ac_in_buffer[0x18:payloadEnd]
                        self.ac_in_buffer = self.ac_in_buffer[payloadEnd:]
                        
                        realcksum = dblsha(payload)[:4]
                        if realcksum != cksum:
                                self.logger.debug('Wrong checksum on `%s\' message (%s vs actual:%s); ignoring', cmd, b2a_hex(cksum), b2a_hex(realcksum))
                                # The bad message is already consumed; keep processing any
                                # further complete messages instead of stalling them.
                                continue
                        
                        handler = getattr(self, 'doCmd_' + cmd, None)
                        if handler is not None:
                                handler(payload)
        
        def pushMessage(self, *a, **ka):
                """Frame a message with ``server.makeMessage`` and queue it via ``self.push``."""
                self.push(self.server.makeMessage(*a, **ka))
        
        def makeVersion(self):
                """Build the payload of our ``version`` message.

                Fields: int32 protocol version, uint64 services, int64
                timestamp, two 26-byte address fields (currently NUL-filled),
                uint64 nonce, varlen-prefixed user agent, int32 start height (0).
                """
                r = pack('<lQq26s26sQ',
                        60000,              # version
                        0,                  # services bitfield
                        int(time()),        # timestamp
                        b'',                # FIXME: other-side address
                        b'',                # FIXME: my-side address
                        self.server.nonce,  # nonce
                )
                UA = self.server.userAgent
                r += varlenEncode(len(UA)) + UA
                r += b'\0\0\0\0'         # start_height
                return r
        
        def pushVersion(self):
                """Send our ``version`` message, at most once per link."""
                if self.sentVersion:
                        return
                self.pushMessage('version', self.makeVersion())
                self.sentVersion = True
        
        def doCmd_inv(self, payload):
                """Handle an ``inv`` message: a varlen count followed by 36-byte entries.

                Each entry is a 4-byte little-endian type and a 32-byte hash and
                is passed to ``doInv_<type>`` when such a method exists.  If the
                peer's count exceeds the entries actually present, only the
                complete entries are processed.
                """
                (invCount, payload) = varlenDecode(payload)
                # Never trust the peer-supplied count beyond what the payload holds.
                invCount = min(invCount, len(payload) // 36)
                # Index by offset rather than re-slicing the tail each time, which
                # would copy the remaining payload once per entry (O(n^2)).
                for off in range(0, invCount * 36, 36):
                        invType = unpack('<I', payload[off:off + 4])[0]
                        invHash = payload[off + 4:off + 36]
                        handler = getattr(self, 'doInv_%s' % (invType,), None)
                        if handler is not None:
                                handler(invHash)
        
        def doInv_2(self, blkhash):  # MSG_BLOCK
                """Forward an announced block hash to ``server.newBlock``."""
                # The hash is reversed only for the human-readable log line.
                self.logger.debug('Received block inv over p2p for %s' % (b2a_hex(blkhash[::-1]).decode('ascii'),))
                self.server.newBlock(blkhash)
        
        def doCmd_version(self, payload):
                """Answer a peer's ``version`` with our own (if not yet sent) and a ``verack``."""
                # FIXME: check for loopbacks
                self.pushVersion()
                # FIXME: don't send verack to ancient clients
                self.pushMessage('verack')
138
class BitcoinNode(networkserver.AsyncSocketServer):
        """Server owning the set of BitcoinLink peer connections.

        Holds the network magic (``netid``), user agent and nonce used when
        framing messages, plus a queue of outbound messages that
        ``pre_schedule`` broadcasts to every registered connection.
        """
        logger = logging.getLogger('BitcoinNode')
        
        # NOTE(review): presumably asks AsyncSocketServer to set up the wakeup
        # channel used by self.wakeup() in submitBlock; confirm in networkserver.
        waker = True
        
        def __init__(self, netid, *a, **ka):
                """Create the node for network magic ``netid`` (4 bytes).

                Other arguments go to ``networkserver.AsyncSocketServer``; the
                handler class defaults to BitcoinLink.
                """
                ka.setdefault('RequestHandlerClass', BitcoinLink)
                super().__init__(*a, **ka)
                self.netid = netid
                self.userAgent = b'/BitcoinNode:0.1/'
                self.nonce = 0  # FIXME
                self._om = deque()  # outbound messages awaiting broadcast
        
        def pre_schedule(self):
                """Broadcast every queued outbound message to all sockets in ``_fd``.

                Objects whose ``push`` raises are skipped (logged at debug
                level).  Logs, per message, how many pushes succeeded.
                """
                queue = self._om
                while queue:
                        m = queue.popleft()
                        cmd = m[4:0x10].rstrip(b'\0').decode('utf8')
                        sent = 0
                        for c in self._fd.values():
                                try:
                                        c.push(m)
                                except Exception as e:
                                        # Best-effort broadcast: NOTE(review) entries in _fd are
                                        # presumably not all peer links (no push); skip them.
                                        self.logger.debug('Could not push `%s\' to %r: %r', cmd, c, e)
                                else:
                                        sent += 1
                        self.logger.info('Sent `%s\' to %d nodes', cmd, sent)
        
        def makeMessage(self, cmd, payload = b''):
                """Frame ``payload`` as a p2p message for this network.

                Layout: ``netid`` + 12-byte NUL-padded command + 4-byte
                little-endian payload length + first 4 bytes of
                ``dblsha(payload)`` + payload.

                Raises ValueError if ``cmd`` encodes to more than 12 bytes.
                """
                cmd = cmd.encode('utf8')
                if len(cmd) > 12:
                        # An assert would vanish under -O and yield a malformed header
                        # (b'\0' * negative == b''), so validate explicitly.
                        raise ValueError('Command name too long (%d bytes): %r' % (len(cmd), cmd))
                cmd = cmd.ljust(12, b'\0')
                
                cksum = dblsha(payload)[:4]
                payloadLen = pack('<L', len(payload))
                return self.netid + cmd + payloadLen + cksum + payload
        
        def submitBlock(self, payload):
                """Queue a ``block`` message for broadcast and wake the event loop."""
                self._om.append(self.makeMessage('block', payload))
                self.wakeup()
        
        def newBlock(self, blkhash):
                """Hook called by BitcoinLink.doInv_2 when a peer announces a block.

                ``blkhash`` is the raw 32-byte hash from the inv entry (doInv_2
                reverses it only for display).  No-op here.  NOTE(review):
                presumably overridden or reassigned by the owner (e.g. to
                trigger a merkle tree update); confirm at the call site.
                """
                pass