Skip to content

Commit a18139d

Browse files
committed
updated single colour pipelines to add set_colour method
1 parent d8b84f3 commit a18139d

3 files changed

Lines changed: 85 additions & 11 deletions

File tree

src/ncca/ngl/webgpu/line_pipeline.py

Lines changed: 45 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
Handles line rendering with customizable color and projection.
44
"""
55

6-
from typing import Optional
6+
from typing import Optional, Tuple
77

88
import numpy as np
99
import wgpu
@@ -101,9 +101,11 @@ def __init__(
101101

102102
def get_dtype(self) -> np.dtype:
103103
"""Get the data type of the pipeline."""
104-
return np.dtype([
105-
("MVP", "float32", (4, 4)),
106-
])
104+
return np.dtype(
105+
[
106+
("MVP", "float32", (4, 4)),
107+
]
108+
)
107109

108110
def _get_shader_code(self) -> str:
109111
"""Get the WGSL shader code for this pipeline."""
@@ -193,7 +195,9 @@ def update_uniforms(self, **kwargs) -> None:
193195
if "mvp" in kwargs and kwargs["mvp"] is not None:
194196
self.uniform_data["MVP"] = kwargs["mvp"]
195197

196-
self.device.queue.write_buffer(self.uniform_buffer, 0, self.uniform_data.tobytes())
198+
self.device.queue.write_buffer(
199+
self.uniform_buffer, 0, self.uniform_data.tobytes()
200+
)
197201

198202
def render(self, render_pass: wgpu.GPURenderPassEncoder, **kwargs) -> None:
199203
"""
@@ -247,9 +251,10 @@ def __init__(
247251
msaa_sample_count: int = 4,
248252
stride: int = 0,
249253
topology: wgpu.PrimitiveTopology = wgpu.PrimitiveTopology.line_list,
254+
colour: Tuple[float, float, float] = (1.0, 1.0, 1.0),
250255
):
251256
"""
252-
Initialize the line rendering pipeline.
257+
Initialize line rendering pipeline.
253258
254259
Args:
255260
device: WebGPU device
@@ -258,10 +263,12 @@ def __init__(
258263
msaa_sample_count: Number of MSAA samples
259264
stride: The stride of the vertex buffer. If 0, it is inferred from data_type.
260265
topology: Primitive topology (line_list or line_strip)
266+
colour: RGB color tuple for lines (default white)
261267
"""
262268
# Pipeline-specific buffer tracking
263269
self.vertex_buffer: Optional[wgpu.GPUBuffer] = None
264270
self.num_vertices: int = 0
271+
self._colour = np.array(colour, dtype=np.float32)
265272

266273
super().__init__(
267274
device=device,
@@ -275,9 +282,13 @@ def __init__(
275282

276283
def get_dtype(self) -> np.dtype:
277284
"""Get the data type of the pipeline."""
278-
return np.dtype([
279-
("MVP", "float32", (4, 4)),
280-
])
285+
return np.dtype(
286+
[
287+
("MVP", "float32", (4, 4)),
288+
("Colour", "float32", 3),
289+
("padding", "float32", 1),
290+
]
291+
)
281292

282293
def _get_shader_code(self) -> str:
283294
"""Get the WGSL shader code for this pipeline."""
@@ -335,11 +346,35 @@ def update_uniforms(self, **kwargs) -> None:
335346
Args:
336347
**kwargs: Pipeline-specific uniform parameters
337348
- mvp: 4x4 projection matrix
349+
- colour: RGB color tuple
338350
"""
339351
if "mvp" in kwargs and kwargs["mvp"] is not None:
340352
self.uniform_data["MVP"] = kwargs["mvp"]
341353

342-
self.device.queue.write_buffer(self.uniform_buffer, 0, self.uniform_data.tobytes())
354+
if "colour" in kwargs and kwargs["colour"] is not None:
355+
colour = np.array(kwargs["colour"], dtype=np.float32)
356+
if colour.shape == (3,):
357+
self.uniform_data["Colour"] = colour
358+
self._colour = colour
359+
360+
self.device.queue.write_buffer(
361+
self.uniform_buffer, 0, self.uniform_data.tobytes()
362+
)
363+
364+
def set_color(self, colour: Tuple[float, float, float]) -> None:
365+
"""
366+
Set the color for the lines.
367+
368+
Args:
369+
colour: RGB color tuple
370+
"""
371+
colour_array = np.array(colour, dtype=np.float32)
372+
if colour_array.shape == (3,):
373+
self.uniform_data["Colour"] = colour_array
374+
self._colour = colour_array
375+
self.device.queue.write_buffer(
376+
self.uniform_buffer, 0, self.uniform_data.tobytes()
377+
)
343378

344379
def render(self, render_pass: wgpu.GPURenderPassEncoder, **kwargs) -> None:
345380
"""

src/ncca/ngl/webgpu/pipeline_shaders.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,8 @@
295295
// LineShader.wgsl
296296
struct Uniforms {
297297
MVP: mat4x4<f32>,
298+
Colour: vec3<f32>,
299+
padding: f32,
298300
};
299301
300302
@binding(0) @group(0) var<uniform> uniforms: Uniforms;
@@ -306,7 +308,7 @@
306308
307309
@fragment
308310
fn fragment_main() -> @location(0) vec4<f32> {
309-
return vec4<f32>(1.0, 1.0, 1.0, 1.0); // Grey color for grid lines
311+
return vec4<f32>(uniforms.Colour, 1.0);
310312
}
311313
"""
312314

tests/test_triangle_pipeline_coverage.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,43 @@ def test_triangle_pipeline_single_colour_color_functionality(webgpu_device):
8787
)
8888

8989

90+
def test_line_pipeline_single_colour_color_functionality(webgpu_device):
91+
"""Test LinePipelineSingleColour color setting functionality."""
92+
from ncca.ngl.webgpu.line_pipeline import LinePipelineSingleColour
93+
94+
# Test default color (white)
95+
pipeline = LinePipelineSingleColour(webgpu_device)
96+
assert np.array_equal(pipeline._colour, np.array([1.0, 1.0, 1.0], dtype=np.float32))
97+
98+
# Test setting color via constructor
99+
red_color = (1.0, 0.0, 0.0)
100+
pipeline_red = LinePipelineSingleColour(webgpu_device, colour=red_color)
101+
assert np.array_equal(pipeline_red._colour, np.array(red_color, dtype=np.float32))
102+
103+
# Test setting color via update_uniforms
104+
green_color = (0.0, 1.0, 0.0)
105+
pipeline.update_uniforms(colour=green_color)
106+
assert np.array_equal(pipeline._colour, np.array(green_color, dtype=np.float32))
107+
108+
# Test setting color via set_color method
109+
blue_color = (0.0, 0.0, 1.0)
110+
pipeline.set_color(blue_color)
111+
assert np.array_equal(pipeline._colour, np.array(blue_color, dtype=np.float32))
112+
113+
# Test uniform buffer structure includes color
114+
dtype = pipeline.get_dtype()
115+
assert "Colour" in dtype.names
116+
colour_field = dtype.fields["Colour"]
117+
assert colour_field[0].shape == (3,) # 3 components
118+
119+
# Test uniform data gets updated
120+
orange_color = (1.0, 0.5, 0.0)
121+
pipeline.update_uniforms(colour=orange_color)
122+
assert np.array_equal(
123+
pipeline.uniform_data["Colour"], np.array(orange_color, dtype=np.float32)
124+
)
125+
126+
90127
def test_triangle_pipeline_multi_colour_color_processing_tuple_result(webgpu_device):
91128
"""Test multi-colour pipeline when color processing returns tuple to hit lines 240-243."""
92129
pipeline = TrianglePipelineMultiColour(webgpu_device)

0 commit comments

Comments
 (0)