Skip to content

Commit 18b6357

Browse files
committed
updated to new webgpu pipeline
2 parents 2cadada + 20da50e commit 18b6357

11 files changed

Lines changed: 2237 additions & 477 deletions

AGENTS.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ from ncca.ngl.mat4 import Mat4
8383
- **Variables**: snake_case (`camera_position`, `shader_program`)
8484
- **Constants**: UPPER_SNAKE_CASE (`MAX_LIGHTS`, `DEFAULT_SHADER`)
8585
- **Private members**: Single underscore (`_data`, `_internal_method`)
86+
- **Colour** is the correct spelling for **color** when referring to variable names
8687

8788
### Type Hints
8889
```python

src/ncca/ngl/webgpu/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
__author__ = "Jon Macey jmacey@bournemouth.ac.uk"
99
__license__ = "MIT"
1010

11+
from .pipeline_factory import PipelineFactory, PipelineType
1112
from .webgpu_constants import NGLToWebGPU
1213
from .webgpu_widget import WebGPUWidget
1314

14-
__all__ = ["WebGPUWidget", "NGLToWebGPU"]
15+
__all__ = ["WebGPUWidget", "NGLToWebGPU", "PipelineFactory", "PipelineType"]

src/ncca/ngl/webgpu/__main__.py

Lines changed: 367 additions & 3 deletions
Large diffs are not rendered by default.
Lines changed: 358 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,358 @@
1+
"""
2+
Abstract base classes for WebGPU rendering pipelines.
3+
Provides common functionality for buffer management, pipeline creation, and rendering.
4+
"""
5+
6+
from abc import ABC, abstractmethod
7+
from typing import Any, Dict, List, Optional, Tuple, Union
8+
9+
import numpy as np
10+
import wgpu
11+
12+
from .webgpu_constants import NGLToWebGPU
13+
14+
15+
class BaseWebGPUPipeline(ABC):
16+
"""
17+
Abstract base class for all WebGPU rendering pipelines.
18+
19+
Provides common functionality for:
20+
- Buffer management and creation
21+
- Pipeline configuration
22+
- Uniform buffer handling
23+
- Resource cleanup
24+
"""
25+
26+
def __init__(
27+
self,
28+
device: wgpu.GPUDevice,
29+
texture_format: wgpu.TextureFormat = wgpu.TextureFormat.rgba8unorm,
30+
depth_format: wgpu.TextureFormat = wgpu.TextureFormat.depth24plus,
31+
msaa_sample_count: int = 4,
32+
data_type: str = "Vec3",
33+
stride: int = 0,
34+
):
35+
"""
36+
Initialize base pipeline.
37+
38+
Args:
39+
device: WebGPU device
40+
texture_format: Color attachment format
41+
depth_format: Depth attachment format
42+
msaa_sample_count: Number of MSAA samples
43+
data_type: Vertex data type (e.g., "Vec3", "Vec2")
44+
stride: Vertex buffer stride. If 0, inferred from data_type
45+
"""
46+
self.device = device
47+
self.texture_format = texture_format
48+
self.depth_format = depth_format
49+
self.msaa_sample_count = msaa_sample_count
50+
self._data_type = data_type
51+
52+
if stride != 0:
53+
self._stride = stride
54+
else:
55+
self._stride = NGLToWebGPU.stride_from_type(self._data_type)
56+
57+
# Core pipeline resources
58+
self.pipeline: Optional[wgpu.GPURenderPipeline] = None
59+
self.uniform_buffer: Optional[wgpu.GPUBuffer] = None
60+
self.bind_group: Optional[wgpu.GPUBindGroup] = None
61+
62+
# Initialize uniform data structure
63+
self.uniform_data = np.zeros((), dtype=self.get_dtype())
64+
self._set_default_uniforms()
65+
66+
# Create the pipeline
67+
self._create_pipeline()
68+
69+
@abstractmethod
70+
def get_dtype(self) -> np.dtype:
71+
"""Get the numpy dtype for the uniform buffer structure."""
72+
pass
73+
74+
@abstractmethod
75+
def _get_shader_code(self) -> str:
76+
"""Get the WGSL shader code for this pipeline."""
77+
pass
78+
79+
@abstractmethod
80+
def _get_vertex_buffer_layouts(self) -> List[Dict[str, Any]]:
81+
"""Get vertex buffer layout configurations for the pipeline."""
82+
pass
83+
84+
@abstractmethod
85+
def _get_primitive_topology(self) -> wgpu.PrimitiveTopology:
86+
"""Get the primitive topology for the pipeline."""
87+
pass
88+
89+
@abstractmethod
90+
def _set_default_uniforms(self) -> None:
91+
"""Set default values for uniform data."""
92+
pass
93+
94+
@abstractmethod
95+
def _get_pipeline_label(self) -> str:
96+
"""Get the label for the pipeline."""
97+
pass
98+
99+
def _create_pipeline(self) -> None:
100+
"""Create the render pipeline and associated resources."""
101+
# Load shader
102+
shader_module = self.device.create_shader_module(code=self._get_shader_code())
103+
104+
# Create render pipeline
105+
self.pipeline = self.device.create_render_pipeline(
106+
label=self._get_pipeline_label(),
107+
layout="auto",
108+
vertex={
109+
"module": shader_module,
110+
"entry_point": "vertex_main",
111+
"buffers": self._get_vertex_buffer_layouts(),
112+
},
113+
fragment={
114+
"module": shader_module,
115+
"entry_point": "fragment_main",
116+
"targets": [{"format": self.texture_format}],
117+
},
118+
primitive={"topology": self._get_primitive_topology()},
119+
depth_stencil={
120+
"format": self.depth_format,
121+
"depth_write_enabled": True,
122+
"depth_compare": wgpu.CompareFunction.less,
123+
},
124+
multisample={"count": self.msaa_sample_count},
125+
)
126+
127+
# Create uniform buffer
128+
self.uniform_buffer = self.device.create_buffer_with_data(
129+
data=self.uniform_data.tobytes(),
130+
usage=int(wgpu.BufferUsage.UNIFORM | wgpu.BufferUsage.COPY_DST),
131+
label=f"{self._get_pipeline_label()}_uniform_buffer",
132+
)
133+
134+
# Create bind group
135+
bind_group_layout = self.pipeline.get_bind_group_layout(0)
136+
self.bind_group = self.device.create_bind_group(
137+
layout=bind_group_layout,
138+
entries=[
139+
{
140+
"binding": 0,
141+
"resource": {"buffer": self.uniform_buffer},
142+
}
143+
],
144+
)
145+
146+
def _create_or_update_buffer(
147+
self,
148+
current_buffer: Optional[wgpu.GPUBuffer],
149+
data: Union[np.ndarray, wgpu.GPUBuffer],
150+
usage: wgpu.BufferUsage,
151+
buffer_label: str,
152+
) -> Tuple[Optional[wgpu.GPUBuffer], int]:
153+
"""
154+
Create or update a GPU buffer with new data.
155+
156+
Args:
157+
current_buffer: Existing buffer (may be None)
158+
data: New data (numpy array or GPU buffer)
159+
usage: Buffer usage flags
160+
buffer_label: Label for the buffer
161+
162+
Returns:
163+
Tuple of (buffer, data_size)
164+
"""
165+
if isinstance(data, wgpu.GPUBuffer):
166+
# Use provided buffer directly
167+
return data, data.size
168+
169+
# Handle numpy array
170+
data_bytes = data.astype(np.float32).tobytes()
171+
data_size = len(data_bytes)
172+
173+
# Create new buffer if needed or existing one is too small
174+
if current_buffer is None or current_buffer.size < data_size:
175+
if current_buffer:
176+
current_buffer.destroy()
177+
buffer = self.device.create_buffer_with_data(
178+
data=data_bytes,
179+
usage=usage,
180+
label=buffer_label,
181+
)
182+
return buffer, data_size
183+
else:
184+
# Update existing buffer
185+
self.device.queue.write_buffer(current_buffer, 0, data_bytes)
186+
return current_buffer, data_size
187+
188+
def _process_vertex_data(
189+
self,
190+
data: Optional[Union[np.ndarray, wgpu.GPUBuffer]],
191+
default_value: Optional[np.ndarray] = None,
192+
padding_size: Optional[int] = None,
193+
buffer_label: str = "vertex_buffer",
194+
) -> Optional[Union[wgpu.GPUBuffer, Tuple[wgpu.GPUBuffer, int]]]:
195+
"""
196+
Process vertex data, handling numpy arrays, GPU buffers, and defaults.
197+
198+
Args:
199+
data: Input data (numpy array, GPU buffer, or None)
200+
default_value: Default value if data is None
201+
padding_size: Size to pad arrays to (for alignment)
202+
buffer_label: Label for created buffers
203+
204+
Returns:
205+
Processed buffer(s) or None
206+
"""
207+
if data is None and default_value is not None:
208+
data = default_value
209+
210+
if data is None:
211+
return None
212+
213+
if isinstance(data, wgpu.GPUBuffer):
214+
return data
215+
216+
# Handle numpy array
217+
if padding_size:
218+
# Pad array to specified size
219+
if data.ndim == 1:
220+
padded_data = np.zeros(padding_size, dtype=np.float32)
221+
padded_data[: len(data)] = data.astype(np.float32)
222+
else:
223+
padded_data = np.zeros((data.shape[0], padding_size), dtype=np.float32)
224+
padded_data[:, : data.shape[1]] = data.astype(np.float32)
225+
data = padded_data
226+
227+
buffer, _ = self._create_or_update_buffer(
228+
None, # Always create new for processed data
229+
data,
230+
wgpu.BufferUsage.VERTEX | wgpu.BufferUsage.COPY_DST,
231+
buffer_label,
232+
)
233+
return buffer
234+
235+
@abstractmethod
236+
def set_data(self, **kwargs) -> None:
237+
"""
238+
Set rendering data (vertices, colors, etc.).
239+
240+
Args:
241+
**kwargs: Pipeline-specific data parameters
242+
"""
243+
pass
244+
245+
@abstractmethod
246+
def update_uniforms(self, **kwargs) -> None:
247+
"""
248+
Update uniform buffer values.
249+
250+
Args:
251+
**kwargs: Pipeline-specific uniform parameters
252+
"""
253+
pass
254+
255+
@abstractmethod
256+
def render(self, render_pass: wgpu.GPURenderPassEncoder, **kwargs) -> None:
257+
"""
258+
Render using this pipeline.
259+
260+
Args:
261+
render_pass: Active render pass encoder
262+
**kwargs: Pipeline-specific render parameters
263+
"""
264+
pass
265+
266+
def cleanup(self) -> None:
267+
"""Release pipeline resources. Can be overridden for additional cleanup."""
268+
if self.uniform_buffer:
269+
self.uniform_buffer.destroy()
270+
271+
272+
class BasePointPipeline(BaseWebGPUPipeline):
273+
"""
274+
Base class for point rendering pipelines.
275+
276+
Provides common functionality for:
277+
- Point billboarding
278+
- Quad generation
279+
- Circle clipping in fragment shader
280+
"""
281+
282+
def _get_primitive_topology(self) -> wgpu.PrimitiveTopology:
283+
"""Points are rendered as triangle strips for quad generation."""
284+
return wgpu.PrimitiveTopology.triangle_strip
285+
286+
def _get_default_vertex_layouts(
287+
self, has_colour_buffer: bool = False
288+
) -> List[Dict[str, Any]]:
289+
"""
290+
Get default vertex buffer layouts for point rendering.
291+
292+
Args:
293+
has_colour_buffer: Whether to include colour buffer layout
294+
295+
Returns:
296+
List of vertex buffer layout configurations
297+
"""
298+
layouts = [
299+
{
300+
"array_stride": self._stride,
301+
"step_mode": "instance",
302+
"attributes": [
303+
{
304+
"format": NGLToWebGPU.vertex_format(self._data_type),
305+
"offset": 0,
306+
"shader_location": 0,
307+
},
308+
],
309+
},
310+
]
311+
312+
if has_colour_buffer:
313+
layouts.append(
314+
{
315+
"array_stride": NGLToWebGPU.stride_from_type("Vec3"),
316+
"step_mode": "instance",
317+
"attributes": [
318+
{
319+
"format": NGLToWebGPU.vertex_format("Vec3"),
320+
"offset": 0,
321+
"shader_location": 1,
322+
},
323+
],
324+
}
325+
)
326+
327+
return layouts
328+
329+
def _render_points(
330+
self,
331+
render_pass: wgpu.GPURenderPassEncoder,
332+
position_buffer: wgpu.GPUBuffer,
333+
colour_buffer: Optional[wgpu.GPUBuffer] = None,
334+
num_points: Optional[int] = None,
335+
) -> None:
336+
"""
337+
Common point rendering implementation.
338+
339+
Args:
340+
render_pass: Active render pass encoder
341+
position_buffer: Buffer containing point positions
342+
colour_buffer: Optional buffer containing point colours
343+
num_points: Number of points to render
344+
"""
345+
if position_buffer is None:
346+
return
347+
348+
count = num_points if num_points is not None else getattr(self, "num_points", 0)
349+
350+
render_pass.set_pipeline(self.pipeline)
351+
render_pass.set_bind_group(0, self.bind_group, [], 0, 999999)
352+
render_pass.set_vertex_buffer(0, position_buffer)
353+
354+
if colour_buffer:
355+
render_pass.set_vertex_buffer(1, colour_buffer)
356+
357+
# 4 vertices per quad for point rendering
358+
render_pass.draw(4, count)

0 commit comments

Comments
 (0)