First draft of XML Shader reference implementation.
[bsnes:xml-shaders.git] / reference / shaderreader.py
1 #!/usr/bin/python
2 from xml.etree import ElementTree as ET
3 from OpenGL.GL import *
4 from OpenGL.GL import shaders
5
6
# Ways a shader pass can size its output framebuffer along one axis
# (interpreted by ShaderPass.calculateFramebufferSize).
SCALE_METHOD_FIXED = "fixed"          # size is the attribute value itself
SCALE_METHOD_INPUT_SCALE = "input"    # size = input size * attribute value
SCALE_METHOD_OUTPUT_SCALE = "output"  # size = final size * attribute value
SCALE_METHOD_NONE = None              # no scaling attribute was given


# XML element tags of a shader document.
ELEM_SHADER = "shader"
ELEM_VERTEX = "vertex"
ELEM_FRAGMENT = "fragment"

# Attributes of the <shader/> root element.
ATTR_LANGUAGE = "language"
ATTR_LANGUAGE_GLSL = "GLSL"           # the only language currently accepted

# Attributes read from the <fragment/> element that ends a pass.
ATTR_FILTER = "filter"
ATTR_FILTER_NEAREST = "nearest"       # used when no filter attribute is given
ATTR_FILTER_LINEAR = "linear"
ATTR_SIZE = "size"                    # fixed size for both axes (parsed as int)
ATTR_SIZE_X = "size_x"
ATTR_SIZE_Y = "size_y"
ATTR_SCALE = "scale"                  # multiple of the input size (float)
ATTR_SCALE_X = "scale_x"
ATTR_SCALE_Y = "scale_y"
ATTR_OUTSCALE = "outscale"            # multiple of the final size (float)
ATTR_OUTSCALE_X = "outscale_x"
ATTR_OUTSCALE_Y = "outscale_y"
31
32
class ShaderReaderException(Exception):
    """Raised when an XML shader document is malformed or unsupported."""
35
36
class ShaderPass(object):
    """One compiled shader pass: a GL program plus its output-size rules.

    Built from the elements of a single pass (as grouped by parse_shader:
    an optional <vertex/> followed by a <fragment/>). The last element
    supplies the pass's ``filter`` attribute and its scaling attributes.
    """

    def __init__(self, elements):
        """Compile *elements* into a GL program and read the pass settings.

        :param elements: sequence of ElementTree elements. Elements tagged
            ``vertex`` are compiled as vertex shaders; every other element
            is compiled as a fragment shader. ``elements[-1]`` provides the
            pass attributes.
        :raises ShaderReaderException: if ``filter`` is not ``nearest`` or
            ``linear``, if a scaling attribute cannot be parsed, or if two
            attributes set the scaling of the same axis.
        """
        parts = []
        for elem in elements:
            if elem.tag == ELEM_VERTEX:
                shaderType = GL_VERTEX_SHADER
            else:
                shaderType = GL_FRAGMENT_SHADER
            parts.append(shaders.compileShader(elem.text, shaderType))

        self.programID = shaders.compileProgram(*parts)

        fragmentElem = elements[-1]

        self.filterMethod = fragmentElem.get(ATTR_FILTER, ATTR_FILTER_NEAREST)
        if self.filterMethod not in (ATTR_FILTER_NEAREST, ATTR_FILTER_LINEAR):
            raise ShaderReaderException("'filter' attribute should be "
                    "'nearest' or 'linear', not %r" % (self.filterMethod,))

        self.horizontalScaleMethod = SCALE_METHOD_NONE
        self.horizontalScaleValue = None
        self.verticalScaleMethod = SCALE_METHOD_NONE
        self.verticalScaleValue = None

        # (both-axes attr, x attr, y attr, value parser, scale method).
        # Order matters: each axis may be set only once, so any second
        # attribute targeting an already-set axis raises in _set_scale.
        scaleAttributes = (
            (ATTR_SIZE, ATTR_SIZE_X, ATTR_SIZE_Y, int,
                SCALE_METHOD_FIXED),
            (ATTR_SCALE, ATTR_SCALE_X, ATTR_SCALE_Y, float,
                SCALE_METHOD_INPUT_SCALE),
            (ATTR_OUTSCALE, ATTR_OUTSCALE_X, ATTR_OUTSCALE_Y, float,
                SCALE_METHOD_OUTPUT_SCALE),
        )
        for bothAttr, xAttr, yAttr, parser, method in scaleAttributes:
            self._set_scale(fragmentElem, bothAttr, parser, True, method)
            self._set_scale(fragmentElem, bothAttr, parser, False, method)
            self._set_scale(fragmentElem, xAttr, parser, True, method)
            self._set_scale(fragmentElem, yAttr, parser, False, method)

    def __repr__(self):
        """Summarise the pass; the script entry point pprints these."""
        return ("%s(programID=%r, filterMethod=%r, horizontal=(%r, %r), "
                "vertical=(%r, %r))" % (
                    type(self).__name__, self.programID, self.filterMethod,
                    self.horizontalScaleMethod, self.horizontalScaleValue,
                    self.verticalScaleMethod, self.verticalScaleValue,
                ))

    def _set_scale(self, elem, fieldName, parser, horiz, method):
        """Apply scaling attribute *fieldName* of *elem* to one axis.

        Does nothing if *elem* lacks the attribute.

        :param elem: element whose ``attrib`` is read.
        :param fieldName: attribute name to look up.
        :param parser: callable converting the attribute text (``int`` or
            ``float`` here).
        :param horiz: True to set the horizontal axis, False for vertical.
        :param method: one of the SCALE_METHOD_* values to record.
        :raises ShaderReaderException: if *parser* rejects the text with
            ValueError, or the chosen axis already has a scaling value.
        """
        if fieldName not in elem.attrib:
            return

        rawValue = elem.attrib[fieldName]
        try:
            value = parser(rawValue)
        except ValueError:
            raise ShaderReaderException("Field %s should have a value "
                    "compatible with %r, not %r" % (
                        fieldName, parser, rawValue,
                    )
                )

        if horiz:
            if self.horizontalScaleValue is not None:
                # The attribute name is now interpolated; previously the
                # message showed a literal "%s".
                raise ShaderReaderException("Can't apply attribute %s "
                        "because this shader pass already sets a "
                        "horizontal scaling method." % (fieldName,))
            self.horizontalScaleMethod = method
            self.horizontalScaleValue = value
        else:
            if self.verticalScaleValue is not None:
                raise ShaderReaderException("Can't apply attribute %s "
                        "because this shader pass already sets a "
                        "vertical scaling method." % (fieldName,))
            self.verticalScaleMethod = method
            self.verticalScaleValue = value

    @staticmethod
    def _scaled_size(method, value, inputSize, finalSize):
        """Return one axis of the output size for *method* and *value*."""
        if method == SCALE_METHOD_INPUT_SCALE:
            return inputSize * value
        if method == SCALE_METHOD_OUTPUT_SCALE:
            return finalSize * value
        if method == SCALE_METHOD_FIXED:
            return value
        return None

    def calculateFramebufferSize(self, inputW, inputH, finalW, finalH):
        """Return ``(width, height)`` of this pass's output framebuffer.

        :param inputW: input width, multiplied by a ``scale`` value.
        :param inputH: input height, multiplied by a ``scale`` value.
        :param finalW: final width, multiplied by an ``outscale`` value.
        :param finalH: final height, multiplied by an ``outscale`` value.
        :returns: a 2-tuple. An axis with no scaling attribute yields None.
            Products with float scale values are returned unrounded.
        """
        outputW = self._scaled_size(self.horizontalScaleMethod,
                self.horizontalScaleValue, inputW, finalW)
        outputH = self._scaled_size(self.verticalScaleMethod,
                self.verticalScaleValue, inputH, finalH)
        return outputW, outputH
139
140
def parse_shader(data):
    """Parse an XML shader document into a list of ShaderPass objects.

    The root element must be ``<shader/>`` with ``language="GLSL"``. Its
    ``<vertex/>`` and ``<fragment/>`` children are grouped into passes:
    each pass is an optional ``<vertex/>`` followed by a ``<fragment/>``,
    which closes the pass. Other child elements are ignored.

    :param data: the XML document, as accepted by ``ET.fromstring``.
    :returns: a non-empty list of ShaderPass, in document order.
    :raises ShaderReaderException: if the root tag or language is wrong,
        a pass is malformed (two vertex elements in a row, or a vertex
        element with no following fragment), or no pass is found.
    """
    root = ET.fromstring(data)

    if root.tag != ELEM_SHADER:
        raise ShaderReaderException("Root element of XML shader should be "
                "<shader/>")

    if root.get(ATTR_LANGUAGE) != ATTR_LANGUAGE_GLSL:
        raise ShaderReaderException("Currently, only GLSL shaders supported.")

    shaderPasses = []
    shaderParts = []

    # Iterate the element itself: Element.getchildren() was deprecated and
    # removed in Python 3.9; direct iteration works on 2.7 and 3.x.
    for elem in root:
        if elem.tag not in (ELEM_VERTEX, ELEM_FRAGMENT):
            continue

        if elem.tag == ELEM_VERTEX:
            if shaderParts:
                raise ShaderReaderException("Found a new <vertex/> element "
                        "before the previous shader pass was complete. This "
                        "shader pass is malformed.")

            shaderParts.append(elem)

            continue

        # We've got a fragment shader, ending this shader pass.
        shaderParts.append(elem)
        shaderPasses.append(ShaderPass(shaderParts))

        shaderParts = []

    # Without this check, a trailing <vertex/> with no <fragment/> after it
    # would be silently dropped.
    if shaderParts:
        raise ShaderReaderException("Found a <vertex/> element with no "
                "following <fragment/> element. This shader pass is "
                "malformed.")

    if not shaderPasses:
        raise ShaderReaderException("No shaders found in the shader file.")

    return shaderPasses
178
179
if __name__ == "__main__":
    # Command-line entry point: parse one shader file and pprint its passes.
    import sys
    from pprint import pprint

    if len(sys.argv) != 2:
        # Previously a missing argument crashed with IndexError.
        sys.exit("usage: %s SHADER_FILE" % (sys.argv[0],))

    # Read raw bytes so the XML parser honours the document's own encoding
    # declaration instead of the locale's default text encoding.
    with open(sys.argv[1], "rb") as handle:
        data = handle.read()

    pprint(parse_shader(data))