-
Notifications
You must be signed in to change notification settings - Fork 106
Expand file tree
/
Copy pathrust_gpu_shader.rs
More file actions
97 lines (84 loc) · 3.17 KB
/
rust_gpu_shader.rs
File metadata and controls
97 lines (84 loc) · 3.17 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
use crate::scaffold::shader::{SpirvShader, WgpuShader};
use anyhow::Context;
use spirv_builder::{ModuleResult, SpirvBuilder};
use std::borrow::Cow;
use std::path::PathBuf;
use std::{env, fs};
/// A compute shader written in Rust and compiled to SPIR-V with `spirv-builder`.
#[derive(Debug, Clone)]
pub struct RustComputeShader {
    /// Path to the shader crate, passed to `SpirvBuilder::new`.
    pub path: PathBuf,
    /// Target passed to `SpirvBuilder::new`, e.g. `spirv-unknown-vulkan1.1`.
    pub target: String,
    /// Extra SPIR-V capabilities enabled on the builder before compiling.
    pub capabilities: Vec<spirv_builder::Capability>,
}
impl RustComputeShader {
    /// Describes the shader crate at `path`, compiled for `spirv-unknown-vulkan1.1`
    /// with no extra capabilities.
    pub fn new<P: Into<PathBuf>>(path: P) -> Self {
        Self::with_target(path, "spirv-unknown-vulkan1.1")
    }

    /// Describes the shader crate at `path`, compiled for an explicit `target`,
    /// with no extra capabilities.
    pub fn with_target<P: Into<PathBuf>>(path: P, target: impl Into<String>) -> Self {
        let path = path.into();
        let target = target.into();
        Self {
            path,
            target,
            capabilities: Vec::new(),
        }
    }

    /// Appends one SPIR-V capability to enable on the builder and returns the
    /// updated description, so calls can be chained.
    pub fn with_capability(self, capability: spirv_builder::Capability) -> Self {
        let mut shader = self;
        shader.capabilities.push(capability);
        shader
    }
}
impl SpirvShader for RustComputeShader {
fn spirv_bytes(&self) -> anyhow::Result<(Vec<u8>, String)> {
let mut builder = SpirvBuilder::new(&self.path, &self.target)
.release(true)
.multimodule(false)
.shader_panic_strategy(spirv_builder::ShaderPanicStrategy::SilentExit)
.preserve_bindings(true);
for capability in &self.capabilities {
builder = builder.capability(*capability);
}
let artifact = builder.build().context("SpirvBuilder::build() failed")?;
if artifact.entry_points.len() != 1 {
anyhow::bail!(
"Expected exactly one entry point, found {}",
artifact.entry_points.len()
);
}
let entry_point = artifact.entry_points.into_iter().next().unwrap();
let shader_bytes = match artifact.module {
ModuleResult::SingleModule(path) => fs::read(&path)
.with_context(|| format!("reading spv file '{}' failed", path.display()))?,
ModuleResult::MultiModule(_modules) => {
anyhow::bail!("MultiModule modules produced");
}
};
Ok((shader_bytes, entry_point))
}
}
impl WgpuShader for RustComputeShader {
    /// Compiles the shader via `spirv_bytes` and wraps the resulting SPIR-V words
    /// in a `wgpu::ShaderModule` labelled "Compute Shader".
    ///
    /// Returns the module together with the single entry point name reported by
    /// `spirv_bytes`.
    ///
    /// # Errors
    /// Propagates compilation errors from `spirv_bytes`, and fails if the binary
    /// is not a well-formed sequence of SPIR-V words (see `spirv_words`).
    fn create_module(
        &self,
        device: &wgpu::Device,
    ) -> anyhow::Result<(wgpu::ShaderModule, Option<String>)> {
        let (shader_bytes, entry_point) = self.spirv_bytes()?;
        let shader_words = spirv_words(&shader_bytes)?;
        let module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("Compute Shader"),
            source: wgpu::ShaderSource::SpirV(Cow::Owned(shader_words)),
        });
        Ok((module, Some(entry_point)))
    }
}

/// Converts a raw SPIR-V binary into the 32-bit words `wgpu` expects.
///
/// Bytes are copied word by word through `u32::from_ne_bytes` rather than
/// reinterpreted in place. A `Vec<u8>` is only guaranteed 1-byte alignment,
/// so reinterpreting it as `&[u32]` (e.g. with `bytemuck::cast_slice`) can
/// panic on a misaligned allocation. If the first word is the SPIR-V magic
/// number in the opposite byte order, every word is byte-swapped.
///
/// # Errors
/// Returns an error if the length is not a multiple of 4, if the binary is
/// empty, or if the first word is not the SPIR-V magic number in either order.
fn spirv_words(bytes: &[u8]) -> anyhow::Result<Vec<u32>> {
    const SPIRV_MAGIC: u32 = 0x0723_0203;

    if !bytes.len().is_multiple_of(4) {
        anyhow::bail!("SPIR-V binary length is not a multiple of 4");
    }
    let mut words: Vec<u32> = bytes
        .chunks_exact(4)
        .map(|chunk| u32::from_ne_bytes(chunk.try_into().expect("chunks_exact(4) yields 4-byte chunks")))
        .collect();

    match words.first().copied() {
        Some(SPIRV_MAGIC) => {}
        Some(magic) if magic == SPIRV_MAGIC.swap_bytes() => {
            words.iter_mut().for_each(|word| *word = word.swap_bytes());
        }
        Some(other) => anyhow::bail!("SPIR-V binary has invalid magic number {other:#010x}"),
        None => anyhow::bail!("SPIR-V binary is empty"),
    }
    Ok(words)
}
/// Uses the directory named by the `CARGO_MANIFEST_DIR` environment variable
/// as the shader crate path, with the defaults of [`RustComputeShader::new`].
///
/// The variable is read at runtime from the process environment, not baked in
/// at compile time.
///
/// # Panics
/// Panics if `CARGO_MANIFEST_DIR` is not set in the process environment
/// (e.g. when a binary is launched directly rather than through Cargo).
impl Default for RustComputeShader {
    fn default() -> Self {
        // `var_os` rather than `var`: a manifest path that is not valid UTF-8 is
        // still a usable path and must not be misreported as "not set".
        let manifest_dir =
            env::var_os("CARGO_MANIFEST_DIR").expect("CARGO_MANIFEST_DIR not set");
        Self::new(PathBuf::from(manifest_dir))
    }
}