|
1 | 1 | use std::hash::Hash; |
2 | 2 |
|
| 3 | +use crate::fragment::FragmentBuilder; |
| 4 | +use crate::types::{LocalValId, OpMode, ValRef}; |
| 5 | + |
3 | 6 | /// Operation node trait. `computegraph` is fully generic over this abstraction. |
4 | 7 | /// |
5 | 8 | /// `GraphOp` captures the metadata of an operation (input/output counts, |
@@ -37,6 +40,46 @@ pub trait GraphOp: Clone + std::fmt::Debug + Hash + Eq + Send + Sync + 'static { |
37 | 40 | fn n_outputs(&self) -> usize; |
38 | 41 | } |
39 | 42 |
|
| 43 | +/// Minimal trait for emitting operations into a computation context. |
| 44 | +/// |
| 45 | +/// AD transpose rules use only this interface, enabling both graph-building |
| 46 | +/// (`FragmentBuilder`) and eager execution through the same code. |
| 47 | +/// |
| 48 | +/// # Examples |
| 49 | +/// |
| 50 | +/// ```ignore |
| 51 | +/// use computegraph::{FragmentBuilder, GraphOp, OpEmitter, OpMode, ValRef}; |
| 52 | +/// |
| 53 | +/// #[derive(Clone, Debug, Hash, PartialEq, Eq)] |
| 54 | +/// enum UnaryOp { |
| 55 | +/// Identity, |
| 56 | +/// } |
| 57 | +/// |
| 58 | +/// impl GraphOp for UnaryOp { |
| 59 | +/// type Operand = f64; |
| 60 | +/// type Context = (); |
| 61 | +/// type InputKey = &'static str; |
| 62 | +/// |
| 63 | +/// fn n_inputs(&self) -> usize { 1 } |
| 64 | +/// fn n_outputs(&self) -> usize { 1 } |
| 65 | +/// } |
| 66 | +/// |
| 67 | +/// let mut builder = FragmentBuilder::<UnaryOp>::new(); |
| 68 | +/// let x = builder.add_input("x"); |
| 69 | +/// let ys = builder.add_op(UnaryOp::Identity, vec![ValRef::Local(x)], OpMode::Primal); |
| 70 | +/// assert_eq!(ys.len(), 1); |
| 71 | +/// ``` |
| 72 | +pub trait OpEmitter<Op: GraphOp> { |
| 73 | + /// Emits an operation with the given inputs and mode, returning output ids. |
| 74 | + fn add_op(&mut self, op: Op, inputs: Vec<ValRef<Op>>, mode: OpMode) -> Vec<LocalValId>; |
| 75 | +} |
| 76 | + |
| 77 | +impl<Op: GraphOp> OpEmitter<Op> for FragmentBuilder<Op> { |
| 78 | + fn add_op(&mut self, op: Op, inputs: Vec<ValRef<Op>>, mode: OpMode) -> Vec<LocalValId> { |
| 79 | + FragmentBuilder::add_op(self, op, inputs, mode) |
| 80 | + } |
| 81 | +} |
| 82 | + |
40 | 83 | /// Extension trait that adds evaluation capability to a [`GraphOp`]. |
41 | 84 | /// |
42 | 85 | /// # Examples |
|
0 commit comments