fix -b, -w options
[shedskin:mainline.git] / ss.py
1 #!/usr/bin/env python
2
3 # *** SHED SKIN Python-to-C++ Compiler 0.0.28 ***
4 # Copyright 2005-2008 Mark Dufour; License GNU GPL version 3 (See LICENSE)
5
6 from compiler import *
7 from compiler.ast import *
8 from compiler.visitor import *
9
10 from shared import *
11 from graph import *
12 from cpp import *
13 from infer import *
14
15 import sys, string, copy, getopt, os.path, textwrap, traceback
16 from distutils import sysconfig
17
18 from backward import *
19
20 # --- XXX description, confusion_misc?
21 def confusion_misc(): 
22     confusion = set()
23
24     # --- tuple2
25
26     # use regular tuple if both elements have the same type representation
27     cl = defclass('tuple')
28     var1 = lookupvar('first', cl)
29     var2 = lookupvar('second', cl)
30     if not var1 or not var2: return # XXX ?
31
32     for dcpa in getgx().tuple2.copy():
33         getgx().tuple2.remove(dcpa)
34
35     # use regular tuple template for tuples used in addition
36     for node in getgx().merged_all:
37         if isinstance(node, CallFunc):
38             if isinstance(node.node, Getattr) and node.node.attrname in ['__add__','__iadd__'] and not isinstance(node.args[0], Const):
39
40                 tupletypes = set()
41                 for types in [getgx().merged_all[node.node.expr], getgx().merged_all[node.args[0]]]:
42                     for t in types: 
43                         if t[0].ident == 'tuple':  
44                             if t[1] in getgx().tuple2:
45                                 getgx().tuple2.remove(t[1])
46                                 getgx().types[getgx().cnode[var1, t[1], 0]].update(getgx().types[getgx().cnode[var2, t[1], 0]])
47
48                             tupletypes.update(getgx().types[getgx().cnode[var1, t[1], 0]])
49
50 def analysis(source, testing=False):
51     if testing: 
52         gx = newgx()
53         setgx(gx)
54         ast = parse(source+'\n')
55     else:
56         gx = getgx()
57         ast = parsefile(source)
58
59     mv = None
60     setmv(mv)
61
62     # --- build dataflow graph from source code
63     gx.main_module = parse_module(gx.main_mod, ast)
64     gx.main_module.filename = gx.main_mod+'.py'
65     gx.modules[gx.main_mod] = gx.main_module
66     mv = gx.main_module.mv
67     setmv(mv)
68
69     # --- seed class_.__name__ attributes..
70     for cl in getgx().allclasses:
71         if cl.ident == 'class_':
72             var = defaultvar('__name__', cl)
73             getgx().types[inode(var)] = set([(defclass('str_'), 0)])
74
75     # --- number classes (-> constant-time subclass check)
76     number_classes()
77
78     # --- non-ifa: copy classes for each allocation site
79     for cl in getgx().allclasses:
80         if cl.ident in ['int_','float_','none', 'class_','str_']: continue
81
82         if cl.ident == 'list':
83             cl.dcpa = len(getgx().list_types)+2
84         elif cl.ident == '__iter': # XXX huh
85             pass
86         else:
87             cl.dcpa = 2
88
89         for dcpa in range(1, cl.dcpa): 
90             class_copy(cl, dcpa)
91
92     var = defaultvar('unit', defclass('str_'))
93     getgx().types[inode(var)] = set([(defclass('str_'), 0)])
94
95     #printstate()
96     #printconstraints()
97
98     # --- filters
99     #merge = merged(getgx().types)
100     #apply_filters(getgx().types.copy(), merge)
101    
102     # --- cartesian product algorithm & iterative flow analysis
103     iterative_dataflow_analysis()
104     #propagate()
105
106     #merge = merged(getgx().types)
107     #apply_filters(getgx().types, merge)
108
109     for cl in getgx().allclasses:
110         for name in cl.vars:
111             if name in cl.parent.vars and not name.startswith('__'):
112                 error("instance variable '%s' of class '%s' shadows class variable" % (name, cl.ident))
113
114     getgx().merged_all = merged(getgx().types) #, inheritance=True)
115     getgx().merge_dcpa = merged(getgx().types, dcpa=True)
116
117     mv = getgx().main_module.mv
118     setmv(mv)
119     propagate() # XXX remove 
120
121     getgx().merged_all = merged(getgx().types) #, inheritance=True)
122     getgx().merged_inh = merged(getgx().types, inheritance=True)
123
124     # --- determine template parameters
125     template_parameters()
126
127     # --- detect inheritance stuff
128     upgrade_variables()
129     getgx().merged_all = merged(getgx().types)
130
131     getgx().merged_inh = merged(getgx().types, inheritance=True)
132
133     analyze_virtuals()
134
135     # --- determine integer/float types that cannot be unboxed
136     confused_vars()
137     # --- check other sources of confusion
138     confusion_misc() 
139
140     getgx().merge_dcpa = merged(getgx().types, dcpa=True)
141     getgx().merged_all = merged(getgx().types) #, inheritance=True) # XXX
142
143     # --- determine which classes need an __init__ method
144     for node, types in getgx().merged_all.items():
145         if isinstance(node, CallFunc):
146             objexpr, ident, _ , method_call, _, _, _ = analyze_callfunc(node)
147             if method_call and ident == '__init__':
148                 for t in getgx().merged_all[objexpr]:
149                     t[0].has_init = True
150
151     # --- determine which classes need copy, deepcopy methods
152     if 'copy' in getgx().modules:
153         func = getgx().modules['copy'].funcs['copy']
154         var = func.vars[func.formals[0]]
155         for cl in set([t[0] for t in getgx().merged_inh[var]]):
156             cl.has_copy = True # XXX transitive, modeling
157
158         func = getgx().modules['copy'].funcs['deepcopy']
159         var = func.vars[func.formals[0]]
160         for cl in set([t[0] for t in getgx().merged_inh[var]]):
161             cl.has_deepcopy = True # XXX transitive, modeling
162
163     # --- add inheritance relationships for non-original Nodes (and tempvars?); XXX register more, right solution?
164     for func in getgx().allfuncs:
165         #if not func.mv.module.builtin and func.ident == '__init__':
166         if func in getgx().inheritance_relations: 
167             #print 'inherited from', func, getgx().inheritance_relations[func]
168             for inhfunc in getgx().inheritance_relations[func]:
169                 for a, b in zip(func.registered, inhfunc.registered):
170                     #print a, '->', b 
171                     inherit_rec(a, b)
172
173                 for a, b in zip(func.registered_tempvars, inhfunc.registered_tempvars): # XXX more general
174                     getgx().inheritance_tempvars.setdefault(a, []).append(b)
175
176     getgx().merged_inh = merged(getgx().types, inheritance=True) # XXX why X times
177
178     # --- finally, generate C++ code and Makefiles.. :-)
179
180     #printstate()
181     #printconstraints()
182     generate_code()
183     #generate_bindings()
184
185     #print 'cnode!'
186     #for (a,b) in getgx().cnode.items():
187     #    print a, b
188    # for (a,b) in getgx().types.items():
189    #     print a, b
190
191     # error for dynamic expression (XXX before codegen)
192     for node in getgx().merged_all:
193         if isinstance(node, Node) and not isinstance(node, AssAttr) and not inode(node).mv.module.builtin:
194             typesetreprnew(node, inode(node).parent) 
195
196     return gx
197
198 # --- annotate original code; use above function to merge results to original code dimensions
199 def annotate():
200     def paste(expr, text):
201         if not expr.lineno: return
202         if (expr,0,0) in getgx().cnode and inode(expr).mv != mv: return # XXX
203         line = source[expr.lineno-1][:-1]
204         if '#' in line: line = line[:line.index('#')]
205         if text != '':
206             text = '# '+text
207         line = string.rstrip(line)
208         if text != '' and len(line) < 40: line += (40-len(line))*' '
209         source[expr.lineno-1] = line 
210         if text: source[expr.lineno-1] += ' ' + text
211         source[expr.lineno-1] += '\n'
212
213     for module in getgx().modules.values(): 
214         mv = module.mv
215         setmv(mv)
216
217         # merge type information for nodes in module XXX inheritance across modules?
218         merge = merged([n for n in getgx().types if n.mv == mv], inheritance=True)
219
220         source = open(module.filename).readlines()
221
222         # --- constants/names/attributes
223         for expr in merge:
224             if isinstance(expr, (Const, Name)):
225                 paste(expr, typesetreprnew(expr, inode(expr).parent, False))
226         for expr in merge:
227             if isinstance(expr, Getattr):
228                 paste(expr, typesetreprnew(expr, inode(expr).parent, False))
229         for expr in merge:
230             if isinstance(expr, (Tuple,List,Dict)):
231                 paste(expr, typesetreprnew(expr, inode(expr).parent, False))
232
233         # --- instance variables
234         funcs = getmv().funcs.values()
235         for cl in getmv().classes.values():
236             labels = [var.name+': '+typesetreprnew(var, cl, False) for var in cl.vars.values() if var in merge and merge[var] and not var.name.startswith('__')] 
237             if labels: paste(cl.node, ', '.join(labels))
238             funcs += cl.funcs.values()
239
240         # --- function variables
241         for func in funcs:
242             if not func.node or func.node in getgx().inherited: continue
243             vars = [func.vars[f] for f in func.formals]
244             labels = [var.name+': '+typesetreprnew(var, func, False) for var in vars if not var.name.startswith('__')]
245             paste(func.node, ', '.join(labels))
246
247         # --- callfuncs
248         for callfunc, _ in getmv().callfuncs:
249             if isinstance(callfunc.node, Getattr):
250                 if not isinstance(callfunc.node, (fakeGetattr, fakeGetattr2, fakeGetattr3)):
251                     paste(callfunc.node.expr, typesetreprnew(callfunc, inode(callfunc).parent, False))
252             else: 
253                 paste(callfunc.node, typesetreprnew(callfunc, inode(callfunc).parent, False))
254
255         # --- higher-level crap (listcomps, returns, assignments, prints)
256         for expr in merge: 
257             if isinstance(expr, ListComp):
258                 paste(expr, typesetreprnew(expr, inode(expr).parent, False))
259             elif isinstance(expr, Return):
260                 paste(expr, typesetreprnew(expr.value, inode(expr).parent, False))
261             elif isinstance(expr, (AssTuple, AssList)):
262                 paste(expr, typesetreprnew(expr, inode(expr).parent, False))
263             elif isinstance(expr, (Print,Printnl)):
264                 paste(expr, ', '.join([typesetreprnew(child, inode(child).parent, False) for child in expr.nodes]))
265
266         # --- assignments
267         for expr in merge: 
268             if isinstance(expr, Assign):
269                 pairs = assign_rec(expr.nodes[0], expr.expr)
270                 paste(expr, ', '.join([typesetreprnew(r, inode(r).parent, False) for (l,r) in pairs]))
271             elif isinstance(expr, AugAssign):
272                 paste(expr, typesetreprnew(expr.expr, inode(expr).parent, False))
273
274         # --- output annotated file (skip if no write permission)
275         if not module.builtin: 
276             try:
277                 out = open(os.path.join(getgx().output_dir, module.filename[:-3]+'.ss.py'),'w')
278                 out.write(''.join(source))
279                 out.close()
280             except IOError:
281                 pass
282
283 # --- generate C++ and Makefiles
284 def generate_code():
285     print '[generating c++ code..]'
286
287     ident = getgx().main_module.ident 
288
289     if sys.platform == 'win32':
290         pyver = '%d%d' % sys.version_info[:2]
291     else:
292         pyver = sysconfig.get_config_var('VERSION')
293
294         includes = '-I' + sysconfig.get_python_inc() + ' ' + \
295                    '-I' + sysconfig.get_python_inc(plat_specific=True)
296
297         ldflags = sysconfig.get_config_var('LIBS') + ' ' + \
298                   sysconfig.get_config_var('SYSLIBS') + ' ' + \
299                   '-lpython'+pyver 
300         if not sysconfig.get_config_var('Py_ENABLE_SHARED'):
301             ldflags += ' -L' + sysconfig.get_config_var('LIBPL')
302
303     if getgx().extension_module: 
304         if sys.platform == 'win32': ident += '.pyd'
305         else: ident += '.so'
306
307     # --- generate C++ files
308     mods = getgx().modules.values()
309     for module in mods:
310         if not module.builtin:
311             # create output directory if necessary
312             if getgx().output_dir:
313                 output_dir = os.path.join(getgx().output_dir, module.dir)
314                 if not os.path.exists(output_dir):
315                     os.makedirs(output_dir)
316
317             gv = generateVisitor(module)
318             mv = module.mv 
319             setmv(mv)
320             gv.func_pointers(False)
321             walk(module.ast, gv)
322             gv.out.close()
323             gv.header_file()
324             gv.out.close()
325             gv.insert_consts(declare=False)
326             gv.insert_consts(declare=True)
327             gv.insert_includes()
328
329     # --- generate Makefile
330     makefile = file(os.path.join(getgx().output_dir, 'Makefile'), 'w')
331
332     cppfiles = ' '.join([m.filename[:-3].replace(' ', '\ ')+'.cpp' for m in mods])
333     hppfiles = ' '.join([m.filename[:-3].replace(' ', '\ ')+'.hpp' for m in mods])
334
335     # import flags
336     if getgx().flags: flags = getgx().flags
337     elif os.path.isfile('FLAGS'): flags = 'FLAGS'
338     else: flags = connect_paths(getgx().sysdir, 'FLAGS')
339
340     for line in file(flags):
341         line = line[:-1]
342
343         if line[:line.find('=')].strip() == 'CCFLAGS': 
344             line += ' -I. -I'+getgx().libdir.replace(' ', '\ ')
345             if sys.platform == 'darwin' and os.path.isdir('/opt/local/include'): 
346                 line += ' -I/opt/local/include' # macports... and fink?
347             if not getgx().wrap_around_check: line += ' -DNOWRAP' 
348             if getgx().bounds_checking: line += ' -DBOUNDS' 
349             if getgx().extension_module: 
350                 if sys.platform == 'win32': line += ' -Ic:/Python%s/include -D__SS_BIND' % pyver
351                 else: line += ' -g -fPIC -D__SS_BIND ' + includes
352
353         elif line[:line.find('=')].strip() == 'LFLAGS': 
354             if sys.platform == 'darwin' and os.path.isdir('/opt/local/lib'):  
355                 line += ' -L/opt/local/lib'
356             if getgx().extension_module: 
357                 if sys.platform == 'win32': line += ' -shared -Lc:/Python%s/libs -lpython%s' % (pyver, pyver) 
358                 elif sys.platform == 'darwin': line += ' -bundle -Xlinker -dynamic ' + ldflags
359                 elif sys.platform == 'sunos5': line += ' -shared -Xlinker ' + ldflags
360                 else: line += ' -shared -Xlinker -export-dynamic ' + ldflags
361
362             if 're' in [m.ident for m in mods]:
363                 line += ' -lpcre'
364             if 'socket' in [m.ident for m in mods]:
365                 if sys.platform == 'win32':
366                     line += ' -lws2_32'
367                 elif sys.platform == 'sunos5':
368                     line += ' -lsocket -lnsl'
369
370         print >>makefile, line
371     print >>makefile
372
373     print >>makefile, 'all:\t'+ident+'\n'
374
375     if not getgx().extension_module:
376         print >>makefile, 'run:\tall'
377         print >>makefile, '\t./'+ident+'\n'
378
379         print >>makefile, 'full:'
380         print >>makefile, '\tshedskin '+ident+'; $(MAKE) run\n'
381
382     print >>makefile, 'CPPFILES='+cppfiles
383     print >>makefile, 'HPPFILES='+hppfiles+'\n'
384
385     print >>makefile, ident+':\t$(CPPFILES) $(HPPFILES)'
386     print >>makefile, '\t$(CC) $(CCFLAGS) $(CPPFILES) $(LFLAGS) -o '+ident+'\n'
387
388     if sys.platform == 'win32':
389         ident += '.exe'
390     print >>makefile, 'clean:'
391     print >>makefile, '\trm '+ident
392
393     makefile.close()
394
395 def usage():
396     print """Usage: shedskin [OPTION]... FILE
397
398  -a --noann             Don't output annotated source code
399  -b --bounds            Enable bounds checking
400  -d --dir               Specify alternate directory for output files
401  -e --extmod            Generate extension module
402  -f --flags             Provide alternate Makefile flags
403  -i --infinite          Try to avoid infinite analysis time 
404  -w --nowrap            Disable wrap-around checking 
405 """
406     sys.exit()
407
408 def main():
409     gx = newgx()
410     setgx(gx)
411
412     print '*** SHED SKIN Python-to-C++ Compiler 0.0.28 ***'
413     print 'Copyright 2005-2008 Mark Dufour; License GNU GPL version 3 (See LICENSE)'
414     print '(Please send bug reports here: mark.dufour@gmail.com)'
415     print
416     
417     # --- some checks
418     major, minor = sys.version_info[:2]
419     if major != 2 or minor < 3:
420         print '*ERROR* Shed Skin is not compatible with this version of Python'
421         sys.exit()
422
423 #    if sys.platform == 'win32' and os.path.isdir('c:/mingw'):
424 #        print '*ERROR* please rename or remove c:/mingw, as it conflicts with Shed Skin'
425 #        sys.exit()
426
427     # --- command-line options
428     try:
429         opts, args = getopt.getopt(sys.argv[1:], 'heibwf:ad:', ['infinite', 'extmod', 'bounds', 'nowrap', 'flags=', 'dir='])
430     except getopt.GetoptError:
431         usage()
432     
433     for o, a in opts:
434         if o in ['-h', '--help']: usage()
435         if o in ['-b', '--bounds']: getgx().bounds_checking = True
436         if o in ['-e', '--extmod']: getgx().extension_module = True
437         if o in ['-a', '--noann']: getgx().annotation = False
438         if o in ['-i', '--infinite']: getgx().avoid_loops = True
439         if o in ['-d', '--dir']: getgx().output_dir = a
440         if o in ['-w', '--nowrap']: getgx().wrap_around_check = False
441         if o in ['-f', '--flags']: 
442             if not os.path.isfile(a): 
443                 print "*ERROR* no such file: '%s'" % a
444                 sys.exit()
445             getgx().flags = a
446
447     # --- argument
448     if len(args) != 1:
449         usage()
450     name = args[0]
451     if not name.endswith('.py'):
452         name += '.py'
453     if not os.path.isfile(name): 
454         print "*ERROR* no such file: '%s'" % name
455         sys.exit()
456     gx.main_mod = name[:-3]
457         
458     # --- analyze & annotate
459     analysis(name)
460     if getgx().annotation:
461         annotate()
462
463 if __name__ == '__main__':
464     main()