Skip to content

Commit 111cd5d

Browse files
shinaokaclaude
andauthored
feat: add OpEmitter trait for eager AD execution (#1)
Minimal trait with single method add_op(). FragmentBuilder implements it. AD transpose rules can target this trait to enable both graph-building and eager execution through the same code path. Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent b361eac commit 111cd5d

3 files changed

Lines changed: 46 additions & 1 deletion

File tree

Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,5 @@ version = "0.1.0"
44
edition = "2021"
55
license = "MIT OR Apache-2.0"
66
publish = false
7+
8+
[workspace]

src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,5 @@ pub mod resolve;
99
pub mod traits;
1010
pub mod types;
1111

12-
pub use traits::{EvalGraphOp, GraphOp};
12+
pub use traits::{EvalGraphOp, GraphOp, OpEmitter};
1313
pub use types::{GlobalOpKey, GlobalValKey, LocalOpId, LocalValId, OpMode, ValRef};

src/traits.rs

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
use std::hash::Hash;
22

3+
use crate::fragment::FragmentBuilder;
4+
use crate::types::{LocalValId, OpMode, ValRef};
5+
36
/// Operation node trait. `computegraph` is fully generic over this abstraction.
47
///
58
/// `GraphOp` captures the metadata of an operation (input/output counts,
@@ -37,6 +40,46 @@ pub trait GraphOp: Clone + std::fmt::Debug + Hash + Eq + Send + Sync + 'static {
3740
fn n_outputs(&self) -> usize;
3841
}
3942

43+
/// Minimal trait for emitting operations into a computation context.
44+
///
45+
/// AD transpose rules use only this interface, enabling both graph-building
46+
/// (`FragmentBuilder`) and eager execution through the same code.
47+
///
48+
/// # Examples
49+
///
50+
/// ```ignore
51+
/// use computegraph::{FragmentBuilder, GraphOp, OpEmitter, OpMode, ValRef};
52+
///
53+
/// #[derive(Clone, Debug, Hash, PartialEq, Eq)]
54+
/// enum UnaryOp {
55+
/// Identity,
56+
/// }
57+
///
58+
/// impl GraphOp for UnaryOp {
59+
/// type Operand = f64;
60+
/// type Context = ();
61+
/// type InputKey = &'static str;
62+
///
63+
/// fn n_inputs(&self) -> usize { 1 }
64+
/// fn n_outputs(&self) -> usize { 1 }
65+
/// }
66+
///
67+
/// let mut builder = FragmentBuilder::<UnaryOp>::new();
68+
/// let x = builder.add_input("x");
69+
/// let ys = builder.add_op(UnaryOp::Identity, vec![ValRef::Local(x)], OpMode::Primal);
70+
/// assert_eq!(ys.len(), 1);
71+
/// ```
72+
pub trait OpEmitter<Op: GraphOp> {
73+
/// Emits an operation with the given inputs and mode, returning output ids.
74+
fn add_op(&mut self, op: Op, inputs: Vec<ValRef<Op>>, mode: OpMode) -> Vec<LocalValId>;
75+
}
76+
77+
impl<Op: GraphOp> OpEmitter<Op> for FragmentBuilder<Op> {
78+
fn add_op(&mut self, op: Op, inputs: Vec<ValRef<Op>>, mode: OpMode) -> Vec<LocalValId> {
79+
FragmentBuilder::add_op(self, op, inputs, mode)
80+
}
81+
}
82+
4083
/// Extension trait that adds evaluation capability to a [`GraphOp`].
4184
///
4285
/// # Examples

0 commit comments

Comments
 (0)