Add filter="nearest" to some old shaders that need it.
[bsnes:xml-shaders.git] / reference / shaderreader.py
1 #!/usr/bin/python
2 from xml.etree import ElementTree as ET
3 from OpenGL.GL import *
4 from OpenGL.GL import shaders
5
6
# How a shader pass derives the size of its output on one axis.
SCALE_METHOD_NONE = None              # no size/scale attribute set for the axis
SCALE_METHOD_FIXED = "fixed"          # size/size_x/size_y: the value is the size
SCALE_METHOD_INPUT_SCALE = "input"    # scale/scale_x/scale_y: factor on input size
SCALE_METHOD_OUTPUT_SCALE = "output"  # outscale*: factor on the final size
11
12
# XML element names: the root element and the per-pass shader sources.
ELEM_SHADER, ELEM_VERTEX, ELEM_FRAGMENT = "shader", "vertex", "fragment"

# Attribute of the root <shader/> element and the only accepted value.
ATTR_LANGUAGE, ATTR_LANGUAGE_GLSL = "language", "GLSL"

# Texture filtering attribute of a <fragment/> element and its values.
ATTR_FILTER = "filter"
ATTR_FILTER_NEAREST, ATTR_FILTER_LINEAR = "nearest", "linear"
ATTR_FILTER_DEFAULT = ATTR_FILTER_LINEAR  # used when 'filter' is absent

# Output-size attributes of a <fragment/> element. The plain name applies to
# both axes; the _x/_y variants apply to one axis only.
ATTR_SIZE, ATTR_SIZE_X, ATTR_SIZE_Y = "size", "size_x", "size_y"
ATTR_SCALE, ATTR_SCALE_X, ATTR_SCALE_Y = "scale", "scale_x", "scale_y"
ATTR_OUTSCALE, ATTR_OUTSCALE_X, ATTR_OUTSCALE_Y = (
    "outscale", "outscale_x", "outscale_y")
32
33
class ShaderReaderException(Exception):
    """Raised when an XML shader document is malformed or unsupported."""
36
37
class ShaderPass(object):
    """One compiled pass of an XML shader.

    parse_shader() builds a pass from an optional <vertex/> element followed
    by a <fragment/> element. The last element given (the fragment) also
    carries the pass options: the texture ``filter`` attribute and the
    optional size/scale attributes that determine the output size.
    """

    def __init__(self, elements):
        """Compile *elements* into a GL program and read the pass options.

        :param elements: sequence of ElementTree elements. Elements tagged
            ``vertex`` are compiled as vertex shaders, every other element
            as a fragment shader. The last element's attributes configure
            the pass.
        :raises ShaderReaderException: on an invalid ``filter`` value, an
            unparseable size/scale value, or two attributes that both set
            the scaling of the same axis.
        """
        parts = []
        for elem in elements:
            if elem.tag == ELEM_VERTEX:
                shaderType = GL_VERTEX_SHADER
            else:
                shaderType = GL_FRAGMENT_SHADER
            parts.append(shaders.compileShader(elem.text, shaderType))

        self.programID = shaders.compileProgram(*parts)

        fragmentElem = elements[-1]

        self.filterMethod = fragmentElem.get(ATTR_FILTER, ATTR_FILTER_DEFAULT)
        if self.filterMethod not in (ATTR_FILTER_NEAREST, ATTR_FILTER_LINEAR):
            raise ShaderReaderException("'filter' attribute should be "
                    "'nearest' or 'linear', not %r" % (self.filterMethod,))

        self.horizontalScaleMethod = SCALE_METHOD_NONE
        self.horizontalScaleValue = None
        self.verticalScaleMethod = SCALE_METHOD_NONE
        self.verticalScaleValue = None

        # (attribute, value parser, scaling method, axes it sets), where
        # True means the horizontal axis. The order decides which
        # attribute is reported when two of them set the same axis.
        scaleAttributes = (
            (ATTR_SIZE, int, SCALE_METHOD_FIXED, (True, False)),
            (ATTR_SIZE_X, int, SCALE_METHOD_FIXED, (True,)),
            (ATTR_SIZE_Y, int, SCALE_METHOD_FIXED, (False,)),
            (ATTR_SCALE, float, SCALE_METHOD_INPUT_SCALE, (True, False)),
            (ATTR_SCALE_X, float, SCALE_METHOD_INPUT_SCALE, (True,)),
            (ATTR_SCALE_Y, float, SCALE_METHOD_INPUT_SCALE, (False,)),
            (ATTR_OUTSCALE, float, SCALE_METHOD_OUTPUT_SCALE, (True, False)),
            (ATTR_OUTSCALE_X, float, SCALE_METHOD_OUTPUT_SCALE, (True,)),
            (ATTR_OUTSCALE_Y, float, SCALE_METHOD_OUTPUT_SCALE, (False,)),
        )
        for attrName, parser, method, axes in scaleAttributes:
            for horiz in axes:
                self._set_scale(fragmentElem, attrName, parser, horiz, method)

    def __repr__(self):
        """Summarise the pass: program id, filter and per-axis scaling."""
        return "%s(programID=%r, filterMethod=%r, horizontal=%r, vertical=%r)" % (
            type(self).__name__,
            self.programID,
            self.filterMethod,
            (self.horizontalScaleMethod, self.horizontalScaleValue),
            (self.verticalScaleMethod, self.verticalScaleValue),
        )

    def _set_scale(self, elem, fieldName, parser, horiz, method):
        """Apply attribute *fieldName* of *elem* as the scaling of one axis.

        Does nothing when *elem* has no such attribute. Otherwise the value
        is converted with *parser* (``int`` or ``float`` at the call sites)
        and stored with *method* as the horizontal (*horiz* true) or
        vertical scaling of this pass.

        :raises ShaderReaderException: if the value cannot be parsed, or if
            an earlier attribute already set the scaling of that axis.
        """
        if fieldName not in elem.attrib:
            return

        rawValue = elem.attrib[fieldName]
        try:
            value = parser(rawValue)
        except ValueError:
            raise ShaderReaderException("Field %s should have a value "
                    "compatible with %r, not %r" % (
                        fieldName, parser, rawValue,
                    )
                )

        if horiz:
            if self.horizontalScaleValue is not None:
                # Format the attribute name into the message; previously the
                # "%s" placeholder was left unfilled.
                raise ShaderReaderException("Can't apply attribute %s "
                        "because this shader pass already sets a "
                        "horizontal scaling method." % (fieldName,))
            self.horizontalScaleMethod = method
            self.horizontalScaleValue = value
        else:
            if self.verticalScaleValue is not None:
                raise ShaderReaderException("Can't apply attribute %s "
                        "because this shader pass already sets a "
                        "vertical scaling method." % (fieldName,))
            self.verticalScaleMethod = method
            self.verticalScaleValue = value

    def calculateFramebufferSize(self, inputW, inputH, finalW, finalH):
        """Return ``(width, height)`` of this pass's output framebuffer.

        For each axis: a fixed size is returned as parsed (``int``); an
        input scale multiplies *inputW*/*inputH*; an output scale multiplies
        *finalW*/*finalH*. Scale factors are parsed as ``float``, so those
        results may be non-integral. An axis with no scaling method yields
        ``None``.
        """
        outputW = self._scaled_size(self.horizontalScaleMethod,
                self.horizontalScaleValue, inputW, finalW)
        outputH = self._scaled_size(self.verticalScaleMethod,
                self.verticalScaleValue, inputH, finalH)
        return outputW, outputH

    @staticmethod
    def _scaled_size(method, value, inputSize, finalSize):
        """Compute one output dimension for the given scaling method."""
        if method == SCALE_METHOD_INPUT_SCALE:
            return inputSize * value
        if method == SCALE_METHOD_OUTPUT_SCALE:
            return finalSize * value
        if method == SCALE_METHOD_FIXED:
            return value
        return None

    def requiresImplicitPass(self):
        """Return True if either axis has a size/scale attribute set."""
        return (self.horizontalScaleMethod != SCALE_METHOD_NONE
                or self.verticalScaleMethod != SCALE_METHOD_NONE)
147
148
def parse_shader(data):
    """Parse an XML shader document into a list of compiled shader passes.

    The root element must be ``<shader language="GLSL">``. Its ``<vertex/>``
    and ``<fragment/>`` children are grouped into passes: each pass is an
    optional ``<vertex/>`` followed by a ``<fragment/>``, and every
    ``<fragment/>`` closes a pass. Other child elements are ignored.

    :param data: the XML document, as accepted by ``ElementTree.fromstring``.
    :returns: a non-empty list of ``ShaderPass`` objects in document order.
    :raises ShaderReaderException: if the root element or language is wrong,
        a ``<vertex/>`` is not followed by a ``<fragment/>``, or no pass is
        found.
    """
    root = ET.fromstring(data)

    if root.tag != ELEM_SHADER:
        raise ShaderReaderException("Root element of XML shader should be "
                "<shader/>")

    if root.get(ATTR_LANGUAGE) != ATTR_LANGUAGE_GLSL:
        raise ShaderReaderException("Currently, only GLSL shaders supported.")

    shaderPasses = []
    shaderParts = []

    # Iterate the element directly: Element.getchildren() was deprecated and
    # removed in Python 3.9.
    for elem in root:
        if elem.tag == ELEM_VERTEX:
            if shaderParts:
                raise ShaderReaderException("Found a new <vertex/> element "
                        "before the previous shader pass was complete. This "
                        "shader pass is malformed.")
            shaderParts.append(elem)
        elif elem.tag == ELEM_FRAGMENT:
            # We've got a fragment shader, ending this shader pass.
            shaderParts.append(elem)
            shaderPasses.append(ShaderPass(shaderParts))
            shaderParts = []

    # A trailing <vertex/> with no <fragment/> after it used to be dropped
    # silently; reject it like other malformed passes.
    if shaderParts:
        raise ShaderReaderException("Found a <vertex/> element that is not "
                "followed by a <fragment/> element. This shader pass is "
                "incomplete.")

    if not shaderPasses:
        raise ShaderReaderException("No shaders found in the shader file.")

    return shaderPasses
186
187
if __name__ == "__main__":
    # Debug entry point: parse the XML shader file named on the command line
    # and pretty-print the resulting shader passes.
    # NOTE(review): ShaderPass compiles GL shaders, which presumably needs a
    # current OpenGL context that this script does not create -- confirm.
    import sys
    from pprint import pprint

    if len(sys.argv) < 2:
        sys.exit("usage: %s SHADER_FILE" % (sys.argv[0],))

    # Read raw bytes so the XML parser honours the document's own encoding
    # declaration rather than the locale's default text encoding.
    with open(sys.argv[1], "rb") as handle:
        data = handle.read()

    pprint(parse_shader(data))