diff --git a/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/Celu.cs b/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/Celu.cs
new file mode 100644
index 00000000..fcbc08ce
--- /dev/null
+++ b/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/Celu.cs
@@ -0,0 +1,50 @@
+using System;
+using System.ComponentModel;
+using System.Reactive.Linq;
+using static TorchSharp.torch;
+using static TorchSharp.torch.nn;
+
+namespace Bonsai.ML.Torch.NeuralNets.ActivationFunction;
+
+///
+/// Represents an operator that creates a continuously differentiable exponential linear unit (CELU) activation function.
+///
+///
+/// See for more information.
+///
+[Description("Creates a continuously differentiable exponential linear unit (CELU) activation function.")]
+[DisplayName("CELU")]
+public class Celu
+{
+ ///
+ /// The alpha value for the CELU activation function.
+ ///
+ [Description("The alpha value for the CELU activation function.")]
+ public double Alpha { get; set; } = 1D;
+
+ ///
+ /// If set to true, will do this operation in-place.
+ ///
+ [Description("If set to true, will do this operation in-place.")]
+ public bool Inplace { get; set; } = false;
+
+ ///
+ /// Creates a continuously differentiable exponential linear unit (CELU) module.
+ ///
+ ///
+ public IObservable Process()
+ {
+ return Observable.Return(CELU(Alpha, Inplace));
+ }
+
+ ///
+ /// Creates a continuously differentiable exponential linear unit (CELU) module.
+ ///
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable source)
+ {
+ return source.Select(_ => CELU(Alpha, Inplace));
+ }
+}
diff --git a/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/Elu.cs b/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/Elu.cs
new file mode 100644
index 00000000..039902cb
--- /dev/null
+++ b/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/Elu.cs
@@ -0,0 +1,50 @@
+using System;
+using System.ComponentModel;
+using System.Reactive.Linq;
+using static TorchSharp.torch;
+using static TorchSharp.torch.nn;
+
+namespace Bonsai.ML.Torch.NeuralNets.ActivationFunction;
+
+///
+/// Represents an operator that creates an exponential linear unit (ELU) activation function.
+///
+///
+/// See for more information.
+///
+[Description("Creates an exponential linear unit (ELU) activation function.")]
+[DisplayName("ELU")]
+public class Elu
+{
+ ///
+ /// The alpha value for the ELU activation function.
+ ///
+ [Description("The alpha value for the ELU activation function")]
+ public double Alpha { get; set; } = 1D;
+
+ ///
+ /// If set to true, will do this operation in-place.
+ ///
+ [Description("If set to true, will do this operation in-place")]
+ public bool Inplace { get; set; } = false;
+
+ ///
+ /// Creates an exponential linear unit (ELU) module.
+ ///
+ ///
+ public IObservable Process()
+ {
+ return Observable.Return(ELU(Alpha, Inplace));
+ }
+
+ ///
+ /// Creates an exponential linear unit (ELU) module.
+ ///
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable source)
+ {
+ return source.Select(_ => ELU(Alpha, Inplace));
+ }
+}
diff --git a/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/Gelu.cs b/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/Gelu.cs
new file mode 100644
index 00000000..978c9bf4
--- /dev/null
+++ b/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/Gelu.cs
@@ -0,0 +1,44 @@
+using System;
+using System.ComponentModel;
+using System.Reactive.Linq;
+using static TorchSharp.torch;
+using static TorchSharp.torch.nn;
+
+namespace Bonsai.ML.Torch.NeuralNets.ActivationFunction;
+
+///
+/// Represents an operator that creates a gaussian error linear unit (GELU) activation function.
+///
+///
+/// See for more information.
+///
+[Description("Creates a gaussian error linear unit (GELU) activation function.")]
+[DisplayName("GELU")]
+public class Gelu
+{
+ ///
+ /// If set to true, will do this operation in-place.
+ ///
+ [Description("If set to true, will do this operation in-place.")]
+ public bool InPlace { get; set; } = false;
+
+ ///
+ /// Creates a gaussian error linear unit (GELU) module.
+ ///
+ ///
+ public IObservable Process()
+ {
+ return Observable.Return(GELU(InPlace));
+ }
+
+ ///
+ /// Creates a gaussian error linear unit (GELU) module.
+ ///
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable source)
+ {
+ return source.Select(_ => GELU(InPlace));
+ }
+}
diff --git a/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/Glu.cs b/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/Glu.cs
new file mode 100644
index 00000000..a022ec4d
--- /dev/null
+++ b/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/Glu.cs
@@ -0,0 +1,44 @@
+using System;
+using System.ComponentModel;
+using System.Reactive.Linq;
+using static TorchSharp.torch;
+using static TorchSharp.torch.nn;
+
+namespace Bonsai.ML.Torch.NeuralNets.ActivationFunction;
+
+///
+/// Represents an operator that creates a gated linear unit (GLU) module.
+///
+///
+/// See for more information.
+///
+[Description("Creates a gated linear unit (GLU) module.")]
+[DisplayName("GLU")]
+public class Glu
+{
+ ///
+ /// The dimension on which to split the input tensor.
+ ///
+ [Description("The dimension on which to split the input tensor.")]
+ public long Dim { get; set; } = -1;
+
+ ///
+ /// Creates a gated linear unit (GLU) module.
+ ///
+ ///
+ public IObservable Process()
+ {
+ return Observable.Return(GLU(Dim));
+ }
+
+ ///
+ /// Creates a gated linear unit (GLU) module.
+ ///
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable source)
+ {
+ return source.Select(_ => GLU(Dim));
+ }
+}
diff --git a/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/Hardshrink.cs b/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/Hardshrink.cs
new file mode 100644
index 00000000..53cb6ff3
--- /dev/null
+++ b/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/Hardshrink.cs
@@ -0,0 +1,43 @@
+using System;
+using System.ComponentModel;
+using System.Reactive.Linq;
+using static TorchSharp.torch;
+using static TorchSharp.torch.nn;
+
+namespace Bonsai.ML.Torch.NeuralNets.ActivationFunction;
+
+///
+/// Represents an operator that creates a Hardshrink module.
+///
+///
+/// See for more information.
+///
+[Description("Creates a Hardshrink module.")]
+public class Hardshrink
+{
+ ///
+ /// The lambda parameter for the Hardshrink function.
+ ///
+ [Description("The lambda parameter for the Hardshrink function")]
+ public double Lambda { get; set; } = 0.5D;
+
+ ///
+ /// Creates a Hardshrink module.
+ ///
+ ///
+ public IObservable Process()
+ {
+ return Observable.Return(Hardshrink(Lambda));
+ }
+
+ ///
+ /// Creates a Hardshrink module.
+ ///
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable source)
+ {
+ return source.Select(_ => Hardshrink(Lambda));
+ }
+}
diff --git a/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/Hardsigmoid.cs b/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/Hardsigmoid.cs
new file mode 100644
index 00000000..1a5f4156
--- /dev/null
+++ b/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/Hardsigmoid.cs
@@ -0,0 +1,43 @@
+using System;
+using System.ComponentModel;
+using System.Reactive.Linq;
+using static TorchSharp.torch;
+using static TorchSharp.torch.nn;
+
+namespace Bonsai.ML.Torch.NeuralNets.ActivationFunction;
+
+///
+/// Represents an operator that creates a Hardsigmoid module.
+///
+///
+/// See for more information.
+///
+[Description("Creates a Hardsigmoid module.")]
+public class Hardsigmoid
+{
+ ///
+ /// If set to true, will do this operation in-place.
+ ///
+ [Description("If set to true, will do this operation in-place")]
+ public bool Inplace { get; set; } = false;
+
+ ///
+ /// Creates a Hardsigmoid module.
+ ///
+ ///
+ public IObservable Process()
+ {
+ return Observable.Return(Hardsigmoid(Inplace));
+ }
+
+ ///
+ /// Creates a Hardsigmoid module.
+ ///
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable source)
+ {
+ return source.Select(_ => Hardsigmoid(Inplace));
+ }
+}
diff --git a/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/Hardswish.cs b/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/Hardswish.cs
new file mode 100644
index 00000000..73d2a0c2
--- /dev/null
+++ b/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/Hardswish.cs
@@ -0,0 +1,43 @@
+using System;
+using System.ComponentModel;
+using System.Reactive.Linq;
+using static TorchSharp.torch;
+using static TorchSharp.torch.nn;
+
+namespace Bonsai.ML.Torch.NeuralNets.ActivationFunction;
+
+///
+/// Represents an operator that creates a Hardswish module.
+///
+///
+/// See for more information.
+///
+[Description("Creates a Hardswish module.")]
+public class Hardswish
+{
+ ///
+ /// If set to true, will do this operation in-place.
+ ///
+ [Description("If set to true, will do this operation in-place")]
+ public bool Inplace { get; set; } = false;
+
+ ///
+ /// Creates a Hardswish module.
+ ///
+ ///
+ public IObservable Process()
+ {
+ return Observable.Return(Hardswish(Inplace));
+ }
+
+ ///
+ /// Creates a Hardswish module.
+ ///
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable source)
+ {
+ return source.Select(_ => Hardswish(Inplace));
+ }
+}
diff --git a/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/Hardtanh.cs b/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/Hardtanh.cs
new file mode 100644
index 00000000..3abbbc81
--- /dev/null
+++ b/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/Hardtanh.cs
@@ -0,0 +1,55 @@
+using System;
+using System.ComponentModel;
+using System.Reactive.Linq;
+using static TorchSharp.torch;
+using static TorchSharp.torch.nn;
+
+namespace Bonsai.ML.Torch.NeuralNets.ActivationFunction;
+
+///
+/// Represents an operator that creates a Hardtanh module.
+///
+///
+/// See for more information.
+///
+[Description("Creates a Hardtanh module.")]
+public class Hardtanh
+{
+ ///
+ /// The minimum value of the linear region range.
+ ///
+ [Description("The minimum value of the linear region range.")]
+ public double MinVal { get; set; } = -1D;
+
+ ///
+ /// The maximum value of the linear region range.
+ ///
+ [Description("The maximum value of the linear region range.")]
+ public double MaxVal { get; set; } = 1D;
+
+ ///
+ /// If set to true, will do this operation in-place.
+ ///
+ [Description("If set to true, will do this operation in-place")]
+ public bool Inplace { get; set; } = false;
+
+ ///
+ /// Creates a Hardtanh module.
+ ///
+ ///
+ public IObservable Process()
+ {
+ return Observable.Return(Hardtanh(MinVal, MaxVal, Inplace));
+ }
+
+ ///
+ /// Creates a Hardtanh module.
+ ///
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable source)
+ {
+ return source.Select(_ => Hardtanh(MinVal, MaxVal, Inplace));
+ }
+}
diff --git a/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/LeakyRelu.cs b/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/LeakyRelu.cs
new file mode 100644
index 00000000..e8fa6cee
--- /dev/null
+++ b/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/LeakyRelu.cs
@@ -0,0 +1,50 @@
+using System;
+using System.ComponentModel;
+using System.Reactive.Linq;
+using static TorchSharp.torch;
+using static TorchSharp.torch.nn;
+
+namespace Bonsai.ML.Torch.NeuralNets.ActivationFunction;
+
+///
+/// Represents an operator that creates a leaky rectified linear unit (LeakyReLU) activation function.
+///
+///
+/// See for more information.
+///
+[Description("Creates a leaky rectified linear unit (LeakyReLU) activation function.")]
+[DisplayName("LeakyReLU")]
+public class LeakyRelu
+{
+ ///
+ /// The angle of the negative slope.
+ ///
+ [Description("The angle of the negative slope.")]
+ public double NegativeSlope { get; set; } = 0.01D;
+
+ ///
+ /// If set to true, will do this operation in-place.
+ ///
+ [Description("If set to true, will do this operation in-place.")]
+ public bool Inplace { get; set; } = false;
+
+ ///
+ /// Creates a leaky rectified linear unit (LeakyReLU) module.
+ ///
+ ///
+ public IObservable Process()
+ {
+ return Observable.Return(LeakyReLU(NegativeSlope, Inplace));
+ }
+
+ ///
+ /// Creates a leaky rectified linear unit (LeakyReLU) module.
+ ///
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable source)
+ {
+ return source.Select(_ => LeakyReLU(NegativeSlope, Inplace));
+ }
+}
diff --git a/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/LogSigmoid.cs b/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/LogSigmoid.cs
new file mode 100644
index 00000000..99f03b29
--- /dev/null
+++ b/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/LogSigmoid.cs
@@ -0,0 +1,37 @@
+using System;
+using System.ComponentModel;
+using System.Reactive.Linq;
+using static TorchSharp.torch;
+using static TorchSharp.torch.nn;
+
+namespace Bonsai.ML.Torch.NeuralNets.ActivationFunction;
+
+///
+/// Represents an operator that creates a log sigmoid module.
+///
+///
+/// See for more information.
+///
+[Description("Creates a log sigmoid module.")]
+public class LogSigmoid
+{
+ ///
+ /// Creates a LogSigmoid module.
+ ///
+ ///
+ public IObservable Process()
+ {
+ return Observable.Return(LogSigmoid());
+ }
+
+ ///
+ /// Creates a LogSigmoid module.
+ ///
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable source)
+ {
+ return source.Select(_ => LogSigmoid());
+ }
+}
diff --git a/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/LogSoftmax.cs b/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/LogSoftmax.cs
new file mode 100644
index 00000000..b9d14bb7
--- /dev/null
+++ b/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/LogSoftmax.cs
@@ -0,0 +1,43 @@
+using System;
+using System.ComponentModel;
+using System.Reactive.Linq;
+using static TorchSharp.torch;
+using static TorchSharp.torch.nn;
+
+namespace Bonsai.ML.Torch.NeuralNets.ActivationFunction;
+
+///
+/// Represents an operator that creates a log softmax activation function.
+///
+///
+/// See for more information.
+///
+[Description("Creates a log softmax activation function.")]
+public class LogSoftmax
+{
+ ///
+ /// The dimension along which LogSoftmax will be computed.
+ ///
+ [Description("The dimension along which LogSoftmax will be computed")]
+ public long Dim { get; set; }
+
+ ///
+ /// Creates a LogSoftmax module.
+ ///
+ ///
+ public IObservable Process()
+ {
+ return Observable.Return(LogSoftmax(Dim));
+ }
+
+ ///
+ /// Creates a LogSoftmax module.
+ ///
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable source)
+ {
+ return source.Select(_ => LogSoftmax(Dim));
+ }
+}
diff --git a/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/Mish.cs b/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/Mish.cs
new file mode 100644
index 00000000..1877e79f
--- /dev/null
+++ b/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/Mish.cs
@@ -0,0 +1,37 @@
+using System;
+using System.ComponentModel;
+using System.Reactive.Linq;
+using static TorchSharp.torch;
+using static TorchSharp.torch.nn;
+
+namespace Bonsai.ML.Torch.NeuralNets.ActivationFunction;
+
+///
+/// Represents an operator that creates a mish activation function.
+///
+///
+/// See for more information.
+///
+[Description("Creates a mish activation function.")]
+public class Mish
+{
+ ///
+ /// Creates a Mish module.
+ ///
+ ///
+ public IObservable Process()
+ {
+ return Observable.Return(Mish());
+ }
+
+ ///
+ /// Creates a Mish module.
+ ///
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable source)
+ {
+ return source.Select(_ => Mish());
+ }
+}
diff --git a/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/MultiheadAttention.cs b/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/MultiheadAttention.cs
new file mode 100644
index 00000000..76370225
--- /dev/null
+++ b/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/MultiheadAttention.cs
@@ -0,0 +1,85 @@
+using System;
+using System.ComponentModel;
+using System.Reactive.Linq;
+using static TorchSharp.torch;
+using static TorchSharp.torch.nn;
+
+namespace Bonsai.ML.Torch.NeuralNets.ActivationFunction;
+
+///
+/// Represents an operator that creates a multi-head attention mechanism.
+///
+///
+/// See for more information.
+///
+[Description("Creates a multi-head attention mechanism.")]
+public class MultiheadAttention
+{
+ ///
+ /// The dimension of the model.
+ ///
+ [Description("The dimension of the model.")]
+ public long EmbeddedDim { get; set; }
+
+ ///
+ /// The number of parallel attention heads.
+ ///
+ [Description("The number of parallel attention heads.")]
+ public long NumHeads { get; set; }
+
+ ///
+ /// The dropout probability on attended weights.
+ ///
+ [Description("The dropout probability on attended weights.")]
+ public double Dropout { get; set; } = 0D;
+
+ ///
+ /// If true, adds a learnable bias to the input/output projection layers.
+ ///
+ [Description("If true, adds a learnable bias to the input/output projection layers.")]
+ public bool Bias { get; set; } = true;
+
+ ///
+ /// If true, adds bias to the key and value sequences at dimension 0.
+ ///
+ [Description("If true, adds bias to the key and value sequences at dimension 0.")]
+ public bool AddBiasKeyValue { get; set; } = false;
+
+ ///
+ /// If true, adds a new batch of zeros to the key and value sequences at dimension 1.
+ ///
+ [Description("If true, adds a new batch of zeros to the key and value sequences at dimension 1.")]
+ public bool AddZeroAttention { get; set; } = false;
+
+ ///
+ /// The total number of features for keys.
+ ///
+ [Description("The total number of features for keys.")]
+ public long? KeyDim { get; set; } = null;
+
+ ///
+ /// The total number of features for values.
+ ///
+ [Description("The total number of features for values.")]
+ public long? ValueDim { get; set; } = null;
+
+ ///
+ /// Creates a Multi-head attention module.
+ ///
+ ///
+ public IObservable Process()
+ {
+ return Observable.Return(MultiheadAttention(EmbeddedDim, NumHeads, Dropout, Bias, AddBiasKeyValue, AddZeroAttention, KeyDim, ValueDim));
+ }
+
+ ///
+ /// Creates a Multi-head attention module.
+ ///
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable source)
+ {
+ return source.Select(_ => MultiheadAttention(EmbeddedDim, NumHeads, Dropout, Bias, AddBiasKeyValue, AddZeroAttention, KeyDim, ValueDim));
+ }
+}
diff --git a/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/Prelu.cs b/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/Prelu.cs
new file mode 100644
index 00000000..ea8a0beb
--- /dev/null
+++ b/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/Prelu.cs
@@ -0,0 +1,65 @@
+using System;
+using System.ComponentModel;
+using System.Reactive.Linq;
+using System.Xml.Serialization;
+using static TorchSharp.torch;
+using static TorchSharp.torch.nn;
+
+namespace Bonsai.ML.Torch.NeuralNets.ActivationFunction;
+
+///
+/// Represents an operator that creates a parametric rectified linear unit (PReLU) activation function.
+///
+///
+/// See for more information.
+///
+[Description("Creates a parametric rectified linear unit (PReLU) activation function.")]
+[DisplayName("PReLU")]
+public class Prelu
+{
+ ///
+ /// The number of parameters to learn.
+ ///
+ [Description("The number of parameters to learn.")]
+ public long NumParameters { get; set; }
+
+ ///
+ /// The initial value for the learnable parameters.
+ ///
+ [Description("The initial value for the learnable parameters.")]
+ public double InitialValue { get; set; } = 0.25D;
+
+ ///
+ /// The desired device of returned tensor.
+ ///
+ [XmlIgnore]
+ [Description("The desired device of returned tensor.")]
+ public Device Device { get; set; } = null;
+
+ ///
+ /// The desired data type of returned tensor.
+ ///
+ [Description("The desired data type of returned tensor.")]
+ [TypeConverter(typeof(ScalarTypeConverter))]
+ public ScalarType? Type { get; set; } = null;
+
+ ///
+ /// Creates a parametric rectified linear unit (PReLU) activation function.
+ ///
+ ///
+ public IObservable Process()
+ {
+ return Observable.Return(PReLU(NumParameters, InitialValue, Device, Type));
+ }
+
+ ///
+ /// Creates a parametric rectified linear unit (PReLU) activation function.
+ ///
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable source)
+ {
+ return source.Select(_ => PReLU(NumParameters, InitialValue, Device, Type));
+ }
+}
diff --git a/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/Relu.cs b/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/Relu.cs
new file mode 100644
index 00000000..53107768
--- /dev/null
+++ b/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/Relu.cs
@@ -0,0 +1,44 @@
+using System;
+using System.ComponentModel;
+using System.Reactive.Linq;
+using static TorchSharp.torch;
+using static TorchSharp.torch.nn;
+
+namespace Bonsai.ML.Torch.NeuralNets.ActivationFunction;
+
+///
+/// Represents an operator that creates a rectified linear unit (ReLU) activation function.
+///
+///
+/// See for more information.
+///
+[Description("Creates a rectified linear unit (ReLU) activation function.")]
+[DisplayName("ReLU")]
+public class Relu
+{
+ ///
+ /// If set to true, will do this operation in-place.
+ ///
+ [Description("If set to true, will do this operation in-place.")]
+ public bool Inplace { get; set; } = false;
+
+ ///
+ /// Creates a rectified linear unit (ReLU) module.
+ ///
+ ///
+ public IObservable Process()
+ {
+ return Observable.Return(ReLU(Inplace));
+ }
+
+ ///
+ /// Creates a rectified linear unit (ReLU) module.
+ ///
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable source)
+ {
+ return source.Select(_ => ReLU(Inplace));
+ }
+}
diff --git a/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/ReluBounded.cs b/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/ReluBounded.cs
new file mode 100644
index 00000000..80ca0788
--- /dev/null
+++ b/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/ReluBounded.cs
@@ -0,0 +1,44 @@
+using System;
+using System.ComponentModel;
+using System.Reactive.Linq;
+using static TorchSharp.torch;
+using static TorchSharp.torch.nn;
+
+namespace Bonsai.ML.Torch.NeuralNets.ActivationFunction;
+
+///
+/// Represents an operator that creates a bounded rectified linear unit (ReLU6) activation function.
+///
+///
+/// See for more information.
+///
+[Description("Creates a bounded rectified linear unit (ReLU6) activation function.")]
+[DisplayName("ReLU6")]
+public class ReluBounded
+{
+ ///
+ /// If set to true, will do this operation in-place.
+ ///
+ [Description("If set to true, will do this operation in-place.")]
+ public bool Inplace { get; set; } = false;
+
+ ///
+ /// Creates a bounded rectified linear unit (ReLU6) module.
+ ///
+ ///
+ public IObservable Process()
+ {
+ return Observable.Return(ReLU6(Inplace));
+ }
+
+ ///
+ /// Creates a bounded rectified linear unit (ReLU6) module.
+ ///
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable source)
+ {
+ return source.Select(_ => ReLU6(Inplace));
+ }
+}
diff --git a/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/Rrelu.cs b/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/Rrelu.cs
new file mode 100644
index 00000000..632b8edc
--- /dev/null
+++ b/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/Rrelu.cs
@@ -0,0 +1,56 @@
+using System;
+using System.ComponentModel;
+using System.Reactive.Linq;
+using static TorchSharp.torch;
+using static TorchSharp.torch.nn;
+
+namespace Bonsai.ML.Torch.NeuralNets.ActivationFunction;
+
+///
+/// Represents an operator that creates a randomized leaky rectified linear unit (RReLU) module.
+///
+///
+/// See for more information.
+///
+[Description("Creates a randomized leaky rectified linear unit (RReLU) module.")]
+[DisplayName("RReLU")]
+public class Rrelu
+{
+ ///
+ /// The lower bound of the uniform distribution.
+ ///
+ [Description("The lower bound of the uniform distribution.")]
+ public double Lower { get; set; } = 0.125D;
+
+ ///
+ /// The upper bound of the uniform distribution.
+ ///
+ [Description("The upper bound of the uniform distribution.")]
+ public double Upper { get; set; } = 0.3333333333333333D;
+
+ ///
+ /// If set to true, will do this operation in-place.
+ ///
+ [Description("If set to true, will do this operation in-place.")]
+ public bool Inplace { get; set; } = false;
+
+ ///
+ /// Creates a randomized leaky rectified linear unit (RReLU) module.
+ ///
+ ///
+ public IObservable Process()
+ {
+ return Observable.Return(RReLU(Lower, Upper, Inplace));
+ }
+
+ ///
+ /// Creates a randomized leaky rectified linear unit (RReLU) module.
+ ///
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable source)
+ {
+ return source.Select(_ => RReLU(Lower, Upper, Inplace));
+ }
+}
diff --git a/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/Selu.cs b/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/Selu.cs
new file mode 100644
index 00000000..6d460b9e
--- /dev/null
+++ b/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/Selu.cs
@@ -0,0 +1,44 @@
+using System;
+using System.ComponentModel;
+using System.Reactive.Linq;
+using static TorchSharp.torch;
+using static TorchSharp.torch.nn;
+
+namespace Bonsai.ML.Torch.NeuralNets.ActivationFunction;
+
+///
+/// Represents an operator that creates a scaled exponential linear unit (SELU) activation function.
+///
+///
+/// See for more information.
+///
+[Description("Creates a scaled exponential linear unit (SELU) activation function.")]
+[DisplayName("SeLU")]
+public class Selu
+{
+ ///
+ /// If set to true, will do this operation in-place.
+ ///
+ [Description("If set to true, will do this operation in-place.")]
+ public bool Inplace { get; set; } = false;
+
+ ///
+ /// Creates a scaled exponential linear unit (SELU) module.
+ ///
+ ///
+ public IObservable Process()
+ {
+ return Observable.Return(SELU(Inplace));
+ }
+
+ ///
+ /// Creates a scaled exponential linear unit (SELU) module.
+ ///
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable source)
+ {
+ return source.Select(_ => SELU(Inplace));
+ }
+}
diff --git a/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/Sigmoid.cs b/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/Sigmoid.cs
new file mode 100644
index 00000000..17460059
--- /dev/null
+++ b/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/Sigmoid.cs
@@ -0,0 +1,37 @@
+using System;
+using System.ComponentModel;
+using System.Reactive.Linq;
+using static TorchSharp.torch;
+using static TorchSharp.torch.nn;
+
+namespace Bonsai.ML.Torch.NeuralNets.ActivationFunction;
+
+///
+/// Represents an operator that creates a sigmoid activation function.
+///
+///
+/// See for more information.
+///
+[Description("Creates a sigmoid activation function.")]
+public class Sigmoid
+{
+ ///
+ /// Creates a sigmoid module.
+ ///
+ ///
+ public IObservable Process()
+ {
+ return Observable.Return(Sigmoid());
+ }
+
+ ///
+ /// Creates a sigmoid module.
+ ///
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable source)
+ {
+ return source.Select(_ => Sigmoid());
+ }
+}
diff --git a/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/Silu.cs b/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/Silu.cs
new file mode 100644
index 00000000..a6f88818
--- /dev/null
+++ b/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/Silu.cs
@@ -0,0 +1,38 @@
+using System;
+using System.ComponentModel;
+using System.Reactive.Linq;
+using static TorchSharp.torch;
+using static TorchSharp.torch.nn;
+
+namespace Bonsai.ML.Torch.NeuralNets.ActivationFunction;
+
+///
+/// Represents an operator that creates a sigmoid weighted linear unit (SiLU) activation function.
+///
+///
+/// See for more information.
+///
+[Description("Creates a sigmoid weighted linear unit (SiLU) activation function.")]
+[DisplayName("SiLU")]
+public class Silu
+{
+ ///
+ /// Creates a sigmoid weighted linear unit (SiLU) module.
+ ///
+ ///
+ public IObservable Process()
+ {
+ return Observable.Return(SiLU());
+ }
+
+ ///
+ /// Creates a sigmoid weighted linear unit (SiLU) module.
+ ///
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable source)
+ {
+ return source.Select(_ => SiLU());
+ }
+}
diff --git a/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/Softmax.cs b/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/Softmax.cs
new file mode 100644
index 00000000..a77ba716
--- /dev/null
+++ b/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/Softmax.cs
@@ -0,0 +1,43 @@
+using System;
+using System.ComponentModel;
+using System.Reactive.Linq;
+using static TorchSharp.torch;
+using static TorchSharp.torch.nn;
+
+namespace Bonsai.ML.Torch.NeuralNets.ActivationFunction;
+
+///
+/// Represents an operator that creates a softmax activation function.
+///
+///
+/// See for more information.
+///
+[Description("Creates a softmax activation function.")]
+public class Softmax
+{
+ ///
+ /// The dimension along which Softmax will be computed.
+ ///
+ [Description("The dimension along which Softmax will be computed.")]
+ public long Dim { get; set; }
+
+ ///
+ /// Creates a Softmax module.
+ ///
+ ///
+ public IObservable Process()
+ {
+ return Observable.Return(Softmax(Dim));
+ }
+
+ ///
+ /// Creates a Softmax module.
+ ///
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable source)
+ {
+ return source.Select(_ => Softmax(Dim));
+ }
+}
diff --git a/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/Softmax2d.cs b/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/Softmax2d.cs
new file mode 100644
index 00000000..55ad3c75
--- /dev/null
+++ b/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/Softmax2d.cs
@@ -0,0 +1,37 @@
+using System;
+using System.ComponentModel;
+using System.Reactive.Linq;
+using static TorchSharp.torch;
+using static TorchSharp.torch.nn;
+
+namespace Bonsai.ML.Torch.NeuralNets.ActivationFunction;
+
+///
+/// Represents an operator that creates a 2d softmax activation function.
+///
+///
+/// See for more information.
+///
+[Description("Creates a 2d softmax activation function.")]
+public class Softmax2d
+{
+ ///
+ /// Creates a Softmax2d module.
+ ///
+ ///
+ public IObservable Process()
+ {
+ return Observable.Return(Softmax2d());
+ }
+
+ ///
+ /// Creates a Softmax2d module.
+ ///
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable source)
+ {
+ return source.Select(_ => Softmax2d());
+ }
+}
diff --git a/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/Softmin.cs b/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/Softmin.cs
new file mode 100644
index 00000000..2fe4119f
--- /dev/null
+++ b/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/Softmin.cs
@@ -0,0 +1,43 @@
+using System;
+using System.ComponentModel;
+using System.Reactive.Linq;
+using static TorchSharp.torch;
+using static TorchSharp.torch.nn;
+
+namespace Bonsai.ML.Torch.NeuralNets.ActivationFunction;
+
+///
+/// Represents an operator that creates a softmin activation function.
+///
+///
+/// See for more information.
+///
+[Description("Creates a softmin activation function.")]
+public class Softmin
+{
+ ///
+ /// The dimension along which softmin will be computed.
+ ///
+ [Description("The dimension along which softmin will be computed.")]
+ public long Dim { get; set; }
+
+ ///
+ /// Creates a Softmin module.
+ ///
+ ///
+ public IObservable Process()
+ {
+ return Observable.Return(Softmin(Dim));
+ }
+
+ ///
+ /// Creates a Softmin module.
+ ///
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable source)
+ {
+ return source.Select(_ => Softmin(Dim));
+ }
+}
diff --git a/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/Softplus.cs b/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/Softplus.cs
new file mode 100644
index 00000000..2f9a918b
--- /dev/null
+++ b/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/Softplus.cs
@@ -0,0 +1,49 @@
+using System;
+using System.ComponentModel;
+using System.Reactive.Linq;
+using static TorchSharp.torch;
+using static TorchSharp.torch.nn;
+
+namespace Bonsai.ML.Torch.NeuralNets.ActivationFunction;
+
+///
+/// Represents an operator that creates a softplus activation function.
+///
+///
+/// See for more information.
+///
+[Description("Creates a softplus module.")]
+public class Softplus
+{
+ ///
+ /// The beta value for the softplus formula.
+ ///
+ [Description("The beta value for the softplus formula.")]
+ public double Beta { get; set; } = 1D;
+
+ ///
+ /// The threshold value for which values above it use a linear function.
+ ///
+ [Description("The threshold value for which values above it use a linear function.")]
+ public double Threshold { get; set; } = 20D;
+
+ ///
+ /// Creates a Softplus module.
+ ///
+ ///
+ public IObservable Process()
+ {
+ return Observable.Return(Softplus(Beta, Threshold));
+ }
+
+ ///
+ /// Creates a Softplus module.
+ ///
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable source)
+ {
+ return source.Select(_ => Softplus(Beta, Threshold));
+ }
+}
diff --git a/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/Softshrink.cs b/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/Softshrink.cs
new file mode 100644
index 00000000..2f00f8fa
--- /dev/null
+++ b/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/Softshrink.cs
@@ -0,0 +1,43 @@
+using System;
+using System.ComponentModel;
+using System.Reactive.Linq;
+using static TorchSharp.torch;
+using static TorchSharp.torch.nn;
+
+namespace Bonsai.ML.Torch.NeuralNets.ActivationFunction;
+
+///
+/// Represents an operator that creates a softshrink activation function.
+///
+///
+/// See for more information.
+///
+[Description("Creates a softshrink activation function.")]
+public class Softshrink
+{
+ ///
+ /// The lambda value for the softshrink formula.
+ ///
+ [Description("The lambda value for the softshrink formula.")]
+ public double Lambda { get; set; } = 0.5D;
+
+ ///
+ /// Creates a Softshrink module.
+ ///
+ ///
+ public IObservable Process()
+ {
+ return Observable.Return(Softshrink(Lambda));
+ }
+
+ ///
+ /// Creates a Softshrink module.
+ ///
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable source)
+ {
+ return source.Select(_ => Softshrink(Lambda));
+ }
+}
diff --git a/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/Softsign.cs b/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/Softsign.cs
new file mode 100644
index 00000000..1c94bdb1
--- /dev/null
+++ b/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/Softsign.cs
@@ -0,0 +1,37 @@
+using System;
+using System.ComponentModel;
+using System.Reactive.Linq;
+using static TorchSharp.torch;
+using static TorchSharp.torch.nn;
+
+namespace Bonsai.ML.Torch.NeuralNets.ActivationFunction;
+
+///
+/// Represents an operator that creates a softsign activation function.
+///
+///
+/// See for more information.
+///
+[Description("Creates a softsign activation function.")]
+public class Softsign
+{
+ ///
+ /// Creates a Softsign module.
+ ///
+ ///
+ public IObservable Process()
+ {
+ return Observable.Return(Softsign());
+ }
+
+ ///
+ /// Creates a Softsign module.
+ ///
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable source)
+ {
+ return source.Select(_ => Softsign());
+ }
+}
diff --git a/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/Tanh.cs b/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/Tanh.cs
new file mode 100644
index 00000000..851ae14a
--- /dev/null
+++ b/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/Tanh.cs
@@ -0,0 +1,37 @@
+using System;
+using System.ComponentModel;
+using System.Reactive.Linq;
+using static TorchSharp.torch;
+using static TorchSharp.torch.nn;
+
+namespace Bonsai.ML.Torch.NeuralNets.ActivationFunction;
+
+///
+/// Represents an operator that creates a hyperbolic tangent (tanh) activation function.
+///
+///
+/// See for more information.
+///
+[Description("Creates a hyperbolic tangent (tanh) activation function.")]
+public class Tanh
+{
+ ///
+ /// Creates a Tanh module.
+ ///
+ ///
+ public IObservable Process()
+ {
+ return Observable.Return(Tanh());
+ }
+
+ ///
+ /// Creates a Tanh module.
+ ///
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable source)
+ {
+ return source.Select(_ => Tanh());
+ }
+}
diff --git a/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/Tanhshrink.cs b/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/Tanhshrink.cs
new file mode 100644
index 00000000..904f3c41
--- /dev/null
+++ b/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/Tanhshrink.cs
@@ -0,0 +1,37 @@
+using System;
+using System.ComponentModel;
+using System.Reactive.Linq;
+using static TorchSharp.torch;
+using static TorchSharp.torch.nn;
+
+namespace Bonsai.ML.Torch.NeuralNets.ActivationFunction;
+
+///
+/// Represents an operator that creates a tanhshrink activation function.
+///
+///
+/// See for more information.
+///
+[Description("Creates a tanhshrink activation function.")]
+public class Tanhshrink
+{
+ ///
+ /// Creates a Tanhshrink module.
+ ///
+ ///
+ public IObservable Process()
+ {
+ return Observable.Return(Tanhshrink());
+ }
+
+ ///
+ /// Creates a Tanhshrink module.
+ ///
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable source)
+ {
+ return source.Select(_ => Tanhshrink());
+ }
+}
diff --git a/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/Threshold.cs b/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/Threshold.cs
new file mode 100644
index 00000000..1b31c2ca
--- /dev/null
+++ b/src/Bonsai.ML.Torch/NeuralNets/ActivationFunction/Threshold.cs
@@ -0,0 +1,55 @@
+using System;
+using System.ComponentModel;
+using System.Reactive.Linq;
+using static TorchSharp.torch;
+using static TorchSharp.torch.nn;
+
+namespace Bonsai.ML.Torch.NeuralNets.ActivationFunction;
+
+///
+/// Represents an operator that creates a threshold activation function.
+///
+///
+/// See for more information.
+///
+[Description("Represents an operator that creates a threshold activation function.")]
+public class Threshold
+{
+ ///
+ /// The threshold value.
+ ///
+ [Description("The threshold value.")]
+ public double ThresholdValue { get; set; }
+
+ ///
+ /// The value used to replace values below the threshold.
+ ///
+ [Description("The value used to replace values below the threshold.")]
+ public double FillValue { get; set; }
+
+ ///
+ /// If set to true, will do this operation in-place.
+ ///
+ [Description("If set to true, will do this operation in-place.")]
+ public bool Inplace { get; set; } = false;
+
+ ///
+ /// Creates a Threshold module.
+ ///
+ ///
+ public IObservable Process()
+ {
+ return Observable.Return(Threshold(ThresholdValue, FillValue, Inplace));
+ }
+
+ ///
+ /// Creates a Threshold module.
+ ///
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable source)
+ {
+ return source.Select(_ => Threshold(ThresholdValue, FillValue, Inplace));
+ }
+}
diff --git a/src/Bonsai.ML.Torch/NeuralNets/ActivationFunctionBuilder.cs b/src/Bonsai.ML.Torch/NeuralNets/ActivationFunctionBuilder.cs
new file mode 100644
index 00000000..6c2bdf5a
--- /dev/null
+++ b/src/Bonsai.ML.Torch/NeuralNets/ActivationFunctionBuilder.cs
@@ -0,0 +1,69 @@
+using System.ComponentModel;
+using System.Xml.Serialization;
+
+namespace Bonsai.ML.Torch.NeuralNets;
+
+///
+/// Represents an operator that creates an activation function.
+///
+[XmlInclude(typeof(ActivationFunction.Celu))]
+[XmlInclude(typeof(ActivationFunction.Elu))]
+[XmlInclude(typeof(ActivationFunction.Glu))]
+[XmlInclude(typeof(ActivationFunction.Gelu))]
+[XmlInclude(typeof(ActivationFunction.Hardshrink))]
+[XmlInclude(typeof(ActivationFunction.Hardsigmoid))]
+[XmlInclude(typeof(ActivationFunction.Hardswish))]
+[XmlInclude(typeof(ActivationFunction.Hardtanh))]
+[XmlInclude(typeof(ActivationFunction.LeakyRelu))]
+[XmlInclude(typeof(ActivationFunction.LogSigmoid))]
+[XmlInclude(typeof(ActivationFunction.LogSoftmax))]
+[XmlInclude(typeof(ActivationFunction.Mish))]
+[XmlInclude(typeof(ActivationFunction.MultiheadAttention))]
+[XmlInclude(typeof(ActivationFunction.Prelu))]
+[XmlInclude(typeof(ActivationFunction.Rrelu))]
+[XmlInclude(typeof(ActivationFunction.Relu))]
+[XmlInclude(typeof(ActivationFunction.ReluBounded))]
+[XmlInclude(typeof(ActivationFunction.Selu))]
+[XmlInclude(typeof(ActivationFunction.Sigmoid))]
+[XmlInclude(typeof(ActivationFunction.Silu))]
+[XmlInclude(typeof(ActivationFunction.Softmax))]
+[XmlInclude(typeof(ActivationFunction.Softmax2d))]
+[XmlInclude(typeof(ActivationFunction.Softmin))]
+[XmlInclude(typeof(ActivationFunction.Softplus))]
+[XmlInclude(typeof(ActivationFunction.Softshrink))]
+[XmlInclude(typeof(ActivationFunction.Softsign))]
+[XmlInclude(typeof(ActivationFunction.Tanh))]
+[XmlInclude(typeof(ActivationFunction.Tanhshrink))]
+[XmlInclude(typeof(ActivationFunction.Threshold))]
+[DefaultProperty(nameof(ActivationFunction))]
+[Combinator]
+[Description("Creates an activation function.")]
+[WorkflowElementCategory(ElementCategory.Source)]
+public class ActivationFunctionBuilder : ModuleCombinatorBuilder, INamedElement
+{
+ internal override string BuilderName => "ActivationFunction";
+
+ ///
+ /// Initializes a new instance of the class.
+ ///
+ public ActivationFunctionBuilder()
+ {
+ Module = new ActivationFunction.Relu();
+ }
+
+ ///
+ /// Gets or sets the specific activation function to create.
+ ///
+ [DesignOnly(true)]
+ [DisplayName("Module")]
+ [Externalizable(false)]
+ [RefreshProperties(RefreshProperties.All)]
+ [Category(nameof(CategoryAttribute.Design))]
+ [Description("The specific activation function to create.")]
+ [TypeConverter(typeof(ModuleTypeConverter))]
+ public object ActivationFunction
+ {
+ get => Module;
+ set => Module = value;
+ }
+}
diff --git a/src/Bonsai.ML.Torch/NeuralNets/Architecture/AlexNet.cs b/src/Bonsai.ML.Torch/NeuralNets/Architecture/AlexNet.cs
new file mode 100644
index 00000000..07047937
--- /dev/null
+++ b/src/Bonsai.ML.Torch/NeuralNets/Architecture/AlexNet.cs
@@ -0,0 +1,72 @@
+using TorchSharp;
+using static TorchSharp.torch;
+using static TorchSharp.torch.nn;
+
+namespace Bonsai.ML.Torch.NeuralNets.Architecture;
+
+///
+/// Modified version of original AlexNet to fix CIFAR10 32x32 images.
+///
+internal class AlexNet : Module
+{
+ private readonly Module features;
+ private readonly Module avgPool;
+ private readonly Module classifier;
+
+ ///
+ /// Constructs a new AlexNet model.
+ ///
+ ///
+ ///
+ ///
+ public AlexNet(string name, int numClasses, Device device = null) : base(name)
+ {
+ features = Sequential(
+ ("c1", Conv2d(3, 64, kernel_size: 3, stride: 2, padding: 1)),
+ ("r1", ReLU(inplace: true)),
+ ("mp1", MaxPool2d(kernel_size: [ 2, 2 ])),
+ ("c2", Conv2d(64, 192, kernel_size: 3, padding: 1)),
+ ("r2", ReLU(inplace: true)),
+ ("mp2", MaxPool2d(kernel_size: [ 2, 2 ])),
+ ("c3", Conv2d(192, 384, kernel_size: 3, padding: 1)),
+ ("r3", ReLU(inplace: true)),
+ ("c4", Conv2d(384, 256, kernel_size: 3, padding: 1)),
+ ("r4", ReLU(inplace: true)),
+ ("c5", Conv2d(256, 256, kernel_size: 3, padding: 1)),
+ ("r5", ReLU(inplace: true)),
+ ("mp3", MaxPool2d(kernel_size: [ 2, 2 ])));
+
+ avgPool = AdaptiveAvgPool2d([ 2, 2 ]);
+
+ classifier = Sequential(
+ ("d1", nn.Dropout()),
+ ("l1", nn.Linear(256 * 2 * 2, 4096)),
+ ("r1", ReLU(inplace: true)),
+ ("d2", nn.Dropout()),
+ ("l2", nn.Linear(4096, 4096)),
+ ("r3", ReLU(inplace: true)),
+ ("d3", nn.Dropout()),
+ ("l3", nn.Linear(4096, numClasses))
+ );
+
+ RegisterComponents();
+
+ if (device != null && device.type != DeviceType.CPU)
+ this.to(device);
+ }
+
+ ///
+ /// Forward pass of the AlexNet model.
+ ///
+ ///
+ ///
+ public override Tensor forward(Tensor input)
+ {
+ var f = features.forward(input);
+ var avg = avgPool.forward(f);
+
+ var x = avg.view([ avg.shape[0], 256 * 2 * 2 ]);
+
+ return classifier.forward(x);
+ }
+}
\ No newline at end of file
diff --git a/src/Bonsai.ML.Torch/NeuralNets/Architecture/Mnist.cs b/src/Bonsai.ML.Torch/NeuralNets/Architecture/Mnist.cs
new file mode 100644
index 00000000..aca74443
--- /dev/null
+++ b/src/Bonsai.ML.Torch/NeuralNets/Architecture/Mnist.cs
@@ -0,0 +1,84 @@
+using TorchSharp;
+using static TorchSharp.torch;
+using static TorchSharp.torch.nn;
+
+namespace Bonsai.ML.Torch.NeuralNets.Architecture;
+ ///
+/// Represents a simple convolutional neural network for the MNIST dataset.
+///
+internal class Mnist : Module
+{
+ private readonly Module conv1;
+ private readonly Module conv2;
+ private readonly Module fc1;
+ private readonly Module fc2;
+
+ private readonly Module pool1;
+
+ private readonly Module relu1;
+ private readonly Module relu2;
+ private readonly Module relu3;
+
+ private readonly Module dropout1;
+ private readonly Module dropout2;
+
+ private readonly Module flatten;
+ private readonly Module logsm;
+
+ ///
+ /// Constructs a new Mnist model.
+ ///
+ ///
+ ///
+ ///
+ public Mnist(string name, int numClasses, Device device = null) : base(name)
+ {
+ conv1 = Conv2d(1, 32, 3);
+ conv2 = Conv2d(32, 64, 3);
+ fc1 = nn.Linear(9216, 128);
+ fc2 = nn.Linear(128, numClasses);
+
+ pool1 = MaxPool2d(kernel_size: [2, 2]);
+
+ relu1 = ReLU();
+ relu2 = ReLU();
+ relu3 = ReLU();
+
+ dropout1 = nn.Dropout(0.25);
+ dropout2 = nn.Dropout(0.5);
+
+ flatten = nn.Flatten();
+ logsm = LogSoftmax(1);
+
+ RegisterComponents();
+
+ if (device != null && device.type != DeviceType.CPU)
+ this.to(device);
+ }
+
+ ///
+ /// Forward pass of the Mnist model.
+ ///
+ ///
+ ///
+ public override Tensor forward(Tensor input)
+ {
+ var l11 = conv1.forward(input);
+ var l12 = relu1.forward(l11);
+
+ var l21 = conv2.forward(l12);
+ var l22 = relu2.forward(l21);
+ var l23 = pool1.forward(l22);
+ var l24 = dropout1.forward(l23);
+
+ var x = flatten.forward(l24);
+
+ var l31 = fc1.forward(x);
+ var l32 = relu3.forward(l31);
+ var l33 = dropout2.forward(l32);
+
+ var l41 = fc2.forward(l33);
+
+ return logsm.forward(l41);
+ }
+}
\ No newline at end of file
diff --git a/src/Bonsai.ML.Torch/NeuralNets/Architecture/MobileNet.cs b/src/Bonsai.ML.Torch/NeuralNets/Architecture/MobileNet.cs
new file mode 100644
index 00000000..60a67d49
--- /dev/null
+++ b/src/Bonsai.ML.Torch/NeuralNets/Architecture/MobileNet.cs
@@ -0,0 +1,76 @@
+using System;
+using System.Collections.Generic;
+using TorchSharp;
+using static TorchSharp.torch;
+using static TorchSharp.torch.nn;
+
+namespace Bonsai.ML.Torch.NeuralNets.Architecture;
+
+///
+/// MobileNet model.
+///
+internal class MobileNet : Module
+{
+ private readonly long[] planes = [ 64, 128, 128, 256, 256, 512, 512, 512, 512, 512, 512, 1024, 1024 ];
+ private readonly long[] strides = [ 1, 2, 1, 2, 1, 2, 1, 1, 1, 1, 1, 2, 1 ];
+
+ private readonly Module layers;
+
+ ///
+ /// Constructs a new MobileNet model.
+ ///
+ ///
+ ///
+ ///
+ ///
+ public MobileNet(string name, int numClasses, Device device = null) : base(name)
+ {
+ if (planes.Length != strides.Length) throw new ArgumentException("'planes' and 'strides' must have the same length.");
+
+ var modules = new List<(string, Module)>
+ {
+ ($"conv2d-first", Conv2d(3, 32, kernel_size: 3, stride: 1, padding: 1, bias: false)),
+ ($"bnrm2d-first", BatchNorm2d(32)),
+ ($"relu-first", ReLU())
+ };
+ MakeLayers(modules, 32);
+ modules.Add(("avgpool", AvgPool2d([2, 2])));
+ modules.Add(("flatten", nn.Flatten()));
+ modules.Add(($"linear", nn.Linear(planes[planes.Length-1], numClasses)));
+
+ layers = Sequential(modules);
+
+ RegisterComponents();
+
+ if (device != null && device.type != DeviceType.CPU)
+ this.to(device);
+ }
+
+ private void MakeLayers(List<(string, Module)> modules, long in_planes)
+ {
+
+ for (var i = 0; i < strides.Length; i++) {
+ var out_planes = planes[i];
+ var stride = strides[i];
+
+ modules.Add(($"conv2d-{i}a", Conv2d(in_planes, in_planes, kernel_size: 3, stride: stride, padding: 1, groups: in_planes, bias: false)));
+ modules.Add(($"bnrm2d-{i}a", BatchNorm2d(in_planes)));
+ modules.Add(($"relu-{i}a", ReLU()));
+ modules.Add(($"conv2d-{i}b", Conv2d(in_planes, out_planes, kernel_size: 1L, stride: 1L, padding: 0L, bias: false)));
+ modules.Add(($"bnrm2d-{i}b", BatchNorm2d(out_planes)));
+ modules.Add(($"relu-{i}b", ReLU()));
+
+ in_planes = out_planes;
+ }
+ }
+
+ ///
+ /// Forward pass of the MobileNet model.
+ ///
+ ///
+ ///
+ public override Tensor forward(Tensor input)
+ {
+ return layers.forward(input);
+ }
+}
\ No newline at end of file
diff --git a/src/Bonsai.ML.Torch/NeuralNets/Architecture/ModelArchitecture.cs b/src/Bonsai.ML.Torch/NeuralNets/Architecture/ModelArchitecture.cs
new file mode 100644
index 00000000..1a1b2552
--- /dev/null
+++ b/src/Bonsai.ML.Torch/NeuralNets/Architecture/ModelArchitecture.cs
@@ -0,0 +1,22 @@
+namespace Bonsai.ML.Torch.NeuralNets.Architecture;
+
+///
+/// Represents canonical model architectures.
+///
+public enum ModelArchitecture
+{
+ ///
+ /// The AlexNet model architecture.
+ ///
+ AlexNet,
+
+ ///
+ /// The MobileNet model architecture.
+ ///
+ MobileNet,
+
+ ///
+ /// The Mnist model architecture.
+ ///
+ Mnist
+}
\ No newline at end of file
diff --git a/src/Bonsai.ML.Torch/NeuralNets/Backward.cs b/src/Bonsai.ML.Torch/NeuralNets/Backward.cs
index 52257d68..7468a210 100644
--- a/src/Bonsai.ML.Torch/NeuralNets/Backward.cs
+++ b/src/Bonsai.ML.Torch/NeuralNets/Backward.cs
@@ -1,75 +1,25 @@
using System;
using System.ComponentModel;
using System.Reactive.Linq;
-using System.Xml.Serialization;
using static TorchSharp.torch;
-using static TorchSharp.torch.nn;
-using static TorchSharp.torch.optim;
-namespace Bonsai.ML.Torch.NeuralNets
+namespace Bonsai.ML.Torch.NeuralNets;
+
+///
+/// Represents an operator that computes backward on the input tensor.
+///
+[Combinator]
+[Description("Computes backward on the input tensor.")]
+[WorkflowElementCategory(ElementCategory.Sink)]
+public class Backward
{
///
- /// Trains a model using backpropagation.
+ /// Computes backward on the input tensor.
///
- [Combinator]
- [ResetCombinator]
- [Description("Trains a model using backpropagation.")]
- [WorkflowElementCategory(ElementCategory.Transform)]
- public class Backward
+ ///
+ ///
+ public IObservable Process(IObservable source)
{
- ///
- /// The optimizer to use for training.
- ///
- public Optimizer Optimizer { get; set; }
-
- ///
- /// The model to train.
- ///
- [XmlIgnore]
- public ITorchModule Model { get; set; }
-
- ///
- /// The loss function to use for training.
- ///
- public Loss Loss { get; set; }
-
- ///
- /// Trains the model using backpropagation.
- ///
- ///
- ///
- public IObservable Process(IObservable> source)
- {
- optim.Optimizer optimizer = Optimizer switch
- {
- Optimizer.Adam => Adam(Model.Module.parameters()),
- _ => throw new ArgumentException($"Selected optimizer, {Optimizer} is currently not supported.")
- };
-
- Module loss = Loss switch
- {
- Loss.NegativeLogLikelihood => NLLLoss(),
- _ => throw new ArgumentException($"Selected loss, {Loss} is currently not supported.")
- };
-
- var scheduler = lr_scheduler.StepLR(optimizer, 1, 0.7);
- Model.Module.train();
-
- return source.Select((input) => {
- var (data, target) = input;
- using (_ = NewDisposeScope())
- {
- optimizer.zero_grad();
-
- var prediction = Model.Forward(data);
- var output = loss.forward(prediction, target);
-
- output.backward();
-
- optimizer.step();
- return output.MoveToOuterDisposeScope();
- }
- });
- }
+ return source.Do(input => input.backward());
}
}
\ No newline at end of file
diff --git a/src/Bonsai.ML.Torch/NeuralNets/CollectParameters.cs b/src/Bonsai.ML.Torch/NeuralNets/CollectParameters.cs
new file mode 100644
index 00000000..3b33577f
--- /dev/null
+++ b/src/Bonsai.ML.Torch/NeuralNets/CollectParameters.cs
@@ -0,0 +1,34 @@
+using System;
+using System.ComponentModel;
+using System.Reactive.Linq;
+using System.Linq;
+using System.Collections.Generic;
+using TorchSharp.Modules;
+using static TorchSharp.torch.nn;
+
+namespace Bonsai.ML.Torch.NeuralNets;
+
+///
+/// Represents an operator that collects the parameters of torch modules into a collection.
+///
+[Combinator]
+[Description("Collects the parameters from torch modules into a collection.")]
+[WorkflowElementCategory(ElementCategory.Combinator)]
+public class CollectParameters
+{
+ ///
+ /// Collects the parameters from torch modules into a collection.
+ ///
+ ///
+ ///
+ public IObservable> Process(params IObservable[] sources)
+ {
+ return Observable
+ .Concat(sources.Select(source =>
+ source.Take(1)))
+ .SelectMany(module =>
+ {
+ return module.parameters(recurse: true);
+ }).ToList();
+ }
+}
\ No newline at end of file
diff --git a/src/Bonsai.ML.Torch/NeuralNets/Container/Sequential.cs b/src/Bonsai.ML.Torch/NeuralNets/Container/Sequential.cs
new file mode 100644
index 00000000..5bf4676c
--- /dev/null
+++ b/src/Bonsai.ML.Torch/NeuralNets/Container/Sequential.cs
@@ -0,0 +1,28 @@
+using System;
+using System.ComponentModel;
+using System.Reactive.Linq;
+using static TorchSharp.torch;
+using static TorchSharp.torch.nn;
+using System.Linq;
+using System.Collections.Generic;
+
+namespace Bonsai.ML.Torch.NeuralNets.Container;
+
+///
+/// Represents an operator that creates a sequential container.
+///
+///
+/// See for more information.
+///
+[Description("Creates a sequential container.")]
+public class Sequential
+{
+ ///
+ /// Creates a sequential container from the input modules.
+ ///
+ ///
+ public IObservable Process(IObservable source) where T : IEnumerable>
+ {
+ return source.Select(modules => Sequential(modules));
+ }
+}
diff --git a/src/Bonsai.ML.Torch/NeuralNets/ContainerBuilder.cs b/src/Bonsai.ML.Torch/NeuralNets/ContainerBuilder.cs
new file mode 100644
index 00000000..cbd23287
--- /dev/null
+++ b/src/Bonsai.ML.Torch/NeuralNets/ContainerBuilder.cs
@@ -0,0 +1,44 @@
+using System.ComponentModel;
+using System.Xml.Serialization;
+namespace Bonsai.ML.Torch.NeuralNets;
+
+///
+/// Represents an operator that creates a torch module for convolution operations.
+///
+[XmlInclude(typeof(Container.Sequential))]
+[DefaultProperty(nameof(ContainerModule))]
+[Combinator]
+[Description("Creates a sequential container for torch modules.")]
+[WorkflowElementCategory(ElementCategory.Transform)]
+public class ContainerBuilder : ModuleCombinatorBuilder, INamedElement
+{
+ ///
+ public override Range ArgumentRange => Range.Create(1, 1);
+
+ ///
+ internal override string BuilderName => "Container";
+
+ ///
+ /// Initializes a new instance of the class.
+ ///
+ public ContainerBuilder()
+ {
+ Module = new Container.Sequential();
+ }
+
+ ///
+ /// Gets or sets the specific container module to create.
+ ///
+ [DesignOnly(true)]
+ [DisplayName("Module")]
+ [Externalizable(false)]
+ [RefreshProperties(RefreshProperties.All)]
+ [Category(nameof(CategoryAttribute.Design))]
+ [Description("The specific container module to create.")]
+ [TypeConverter(typeof(ModuleTypeConverter))]
+ public object ContainerModule
+ {
+ get => Module;
+ set => Module = value;
+ }
+}
diff --git a/src/Bonsai.ML.Torch/NeuralNets/Convolution/Conv1d.cs b/src/Bonsai.ML.Torch/NeuralNets/Convolution/Conv1d.cs
new file mode 100644
index 00000000..fdf76173
--- /dev/null
+++ b/src/Bonsai.ML.Torch/NeuralNets/Convolution/Conv1d.cs
@@ -0,0 +1,106 @@
+using System;
+using System.ComponentModel;
+using System.Reactive.Linq;
+using System.Xml.Serialization;
+using TorchSharp;
+using static TorchSharp.torch;
+using static TorchSharp.torch.nn;
+
+namespace Bonsai.ML.Torch.NeuralNets.Convolution;
+
+///
+/// Represents an operator that creates a 1D convolution module.
+///
+///
+/// See for more information.
+///
+[Description("Creates a 1D convolution module.")]
+public class Conv1d
+{
+ ///
+ /// The number of input channels in the input tensor.
+ ///
+ [Description("The number of input channels in the input tensor.")]
+ public long InChannels { get; set; }
+
+ ///
+ /// The number of output channels produced by the convolution.
+ ///
+ [Description("The number of output channels produced by the convolution.")]
+ public long OutChannels { get; set; }
+
+ ///
+ /// The size of the convolution kernel.
+ ///
+ [Description("The size of the convolution kernel.")]
+ public long KernelSize { get; set; }
+
+ ///
+ /// The stride of the convolution.
+ ///
+ [Description("The stride of the convolution.")]
+ public long Stride { get; set; } = 1;
+
+ ///
+ /// The padding added to both sides of the input.
+ ///
+ [Description("The padding added to both sides of the input.")]
+ public long Padding { get; set; } = 0;
+
+ ///
+ /// The spacing between kernel elements.
+ ///
+ [Description("The spacing between kernel elements.")]
+ public long Dilation { get; set; } = 1;
+
+ ///
+ /// The mode of padding.
+ ///
+ [Description("The mode of padding.")]
+ public PaddingModes PaddingMode { get; set; } = PaddingModes.Zeros;
+
+ ///
+ /// The number of blocked connections from input channels to output channels.
+ ///
+ [Description("The number of blocked connections from input channels to output channels.")]
+ public long Groups { get; set; } = 1;
+
+ ///
+ /// If true, adds a learnable bias to the output.
+ ///
+ [Description("If true, adds a learnable bias to the output")]
+ public bool Bias { get; set; } = true;
+
+ ///
+ /// The desired device of the returned tensor.
+ ///
+ [XmlIgnore]
+ [Description("The desired device of the returned tensor")]
+ public Device Device { get; set; } = null;
+
+ ///
+ /// The desired data type of the returned tensor.
+ ///
+ [Description("The desired data type of the returned tensor")]
+ public ScalarType? Type { get; set; } = null;
+
+ ///
+ /// Creates a Conv1d module.
+ ///
+ ///
+ public IObservable Process()
+ {
+ return Observable.Return(Conv1d(InChannels, OutChannels, KernelSize, Stride, Padding, Dilation, PaddingMode, Groups, Bias, Device, Type));
+ }
+
+ ///
+ /// Creates a Conv1d module.
+ ///
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable source)
+ {
+ return source.Select(_ => Conv1d(InChannels, OutChannels, KernelSize, Stride, Padding, Dilation, PaddingMode, Groups, Bias, Device, Type));
+ }
+}
diff --git a/src/Bonsai.ML.Torch/NeuralNets/Convolution/Conv2d.cs b/src/Bonsai.ML.Torch/NeuralNets/Convolution/Conv2d.cs
new file mode 100644
index 00000000..c426b5e7
--- /dev/null
+++ b/src/Bonsai.ML.Torch/NeuralNets/Convolution/Conv2d.cs
@@ -0,0 +1,110 @@
+using System;
+using System.ComponentModel;
+using System.Reactive.Linq;
+using System.Xml.Serialization;
+using TorchSharp;
+using static TorchSharp.torch;
+using static TorchSharp.torch.nn;
+
+namespace Bonsai.ML.Torch.NeuralNets.Convolution;
+
+///
+/// Represents an operator that creates a 2D convolution module.
+///
+///
+/// See for more information.
+///
+[Description("Creates a 2D convolution module.")]
+public class Conv2d
+{
+ ///
+ /// The number of input channels in the input tensor.
+ ///
+ [Description("The number of input channels in the input tensor.")]
+ public long InChannels { get; set; }
+
+ ///
+ /// The number of output channels produced by the convolution.
+ ///
+ [Description("The number of output channels produced by the convolution.")]
+ public long OutChannels { get; set; }
+
+ ///
+ /// The size of the convolution kernel.
+ ///
+ [Description("The size of the convolution kernel.")]
+ [TypeConverter(typeof(ValueTupleConverter))]
+ public (long, long) KernelSize { get; set; }
+
+ ///
+ /// The stride of the convolution.
+ ///
+ [Description("The stride of the convolution.")]
+ [TypeConverter(typeof(NullableValueTupleConverter))]
+ public (long, long)? Stride { get; set; } = null;
+
+ ///
+ /// The padding added to all four sides of the input.
+ ///
+ [Description("The padding added to all four sides of the input.")]
+ [TypeConverter(typeof(NullableValueTupleConverter))]
+ public (long, long)? Padding { get; set; } = null;
+
+ ///
+ /// The spacing between kernel elements.
+ ///
+ [Description("The spacing between kernel elements.")]
+ [TypeConverter(typeof(NullableValueTupleConverter))]
+ public (long, long)? Dilation { get; set; } = null;
+
+ ///
+ /// The mode of padding.
+ ///
+ [Description("The mode of padding.")]
+ public PaddingModes PaddingMode { get; set; } = PaddingModes.Zeros;
+
+ ///
+ /// The number of blocked connections from input channels to output channels.
+ ///
+ [Description("The number of blocked connections from input channels to output channels.")]
+ public long Groups { get; set; } = 1;
+
+ ///
+ /// If true, adds a learnable bias to the output.
+ ///
+ [Description("If true, adds a learnable bias to the output")]
+ public bool Bias { get; set; } = true;
+
+ ///
+ /// The desired device of the returned tensor.
+ ///
+ [XmlIgnore]
+ [Description("The desired device of the returned tensor")]
+ public Device Device { get; set; } = null;
+
+ ///
+ /// The desired data type of the returned tensor.
+ ///
+ [Description("The desired data type of the returned tensor")]
+ public ScalarType? Type { get; set; } = null;
+
+ ///
+ /// Creates a Conv2d module.
+ ///
+ ///
+ public IObservable Process()
+ {
+ return Observable.Return(Conv2d(InChannels, OutChannels, KernelSize, Stride, Padding, Dilation, PaddingMode, Groups, Bias, Device, Type));
+ }
+
+ ///
+ /// Creates a Conv2d module.
+ ///
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable source)
+ {
+ return source.Select(_ => Conv2d(InChannels, OutChannels, KernelSize, Stride, Padding, Dilation, PaddingMode, Groups, Bias, Device, Type));
+ }
+}
diff --git a/src/Bonsai.ML.Torch/NeuralNets/Convolution/Conv3d.cs b/src/Bonsai.ML.Torch/NeuralNets/Convolution/Conv3d.cs
new file mode 100644
index 00000000..c02e30d6
--- /dev/null
+++ b/src/Bonsai.ML.Torch/NeuralNets/Convolution/Conv3d.cs
@@ -0,0 +1,110 @@
+using System;
+using System.ComponentModel;
+using System.Reactive.Linq;
+using System.Xml.Serialization;
+using TorchSharp;
+using static TorchSharp.torch;
+using static TorchSharp.torch.nn;
+
+namespace Bonsai.ML.Torch.NeuralNets.Convolution;
+
+///
+/// Represents an operator that creates a 3D convolution module.
+///
+///
+/// See for more information.
+///
+[Description("Creates a 3D convolution module.")]
+public class Conv3d
+{
+ ///
+ /// The number of input channels in the input tensor.
+ ///
+ [Description("The number of input channels in the input tensor.")]
+ public long InChannels { get; set; }
+
+ ///
+ /// The number of output channels produced by the convolution.
+ ///
+ [Description("The number of output channels produced by the convolution.")]
+ public long OutChannels { get; set; }
+
+ ///
+ /// The size of the convolution kernel.
+ ///
+ [Description("The size of the convolution kernel.")]
+ [TypeConverter(typeof(ValueTupleConverter))]
+ public (long, long, long) KernelSize { get; set; }
+
+ ///
+ /// The stride of the convolution.
+ ///
+ [Description("The stride of the convolution.")]
+ [TypeConverter(typeof(NullableValueTupleConverter))]
+ public (long, long, long)? Stride { get; set; } = null;
+
+ ///
+ /// The padding to add to all six sides of the input.
+ ///
+ [Description("The padding to add to all six sides of the input.")]
+ [TypeConverter(typeof(NullableValueTupleConverter))]
+ public (long, long, long)? Padding { get; set; } = null;
+
+ ///
+ /// The spacing between kernel elements.
+ ///
+ [Description("The spacing between kernel elements.")]
+ [TypeConverter(typeof(NullableValueTupleConverter))]
+ public (long, long, long)? Dilation { get; set; } = null;
+
+ ///
+ /// The mode of padding.
+ ///
+ [Description("The mode of padding.")]
+ public PaddingModes PaddingMode { get; set; } = PaddingModes.Zeros;
+
+ ///
+ /// The number of blocked connections from input channels to output channels.
+ ///
+ [Description("The number of blocked connections from input channels to output channels.")]
+ public long Groups { get; set; } = 1;
+
+ ///
+ /// If true, adds a learnable bias to the output.
+ ///
+ [Description("If true, adds a learnable bias to the output")]
+ public bool Bias { get; set; } = true;
+
+ ///
+ /// The desired device of the returned tensor.
+ ///
+ [XmlIgnore]
+ [Description("The desired device of the returned tensor")]
+ public Device Device { get; set; } = null;
+
+ ///
+ /// The desired data type of the returned tensor.
+ ///
+ [Description("The desired data type of the returned tensor")]
+ public ScalarType? Type { get; set; } = null;
+
+ ///
+ /// Creates a Conv3d module.
+ ///
+ ///
+ public IObservable Process()
+ {
+ return Observable.Return(Conv3d(InChannels, OutChannels, KernelSize, Stride, Padding, Dilation, PaddingMode, Groups, Bias, Device, Type));
+ }
+
+ ///
+ /// Creates a Conv3d module.
+ ///
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable source)
+ {
+ return source.Select(_ => Conv3d(InChannels, OutChannels, KernelSize, Stride, Padding, Dilation, PaddingMode, Groups, Bias, Device, Type));
+ }
+}
diff --git a/src/Bonsai.ML.Torch/NeuralNets/Convolution/ConvTranspose1d.cs b/src/Bonsai.ML.Torch/NeuralNets/Convolution/ConvTranspose1d.cs
new file mode 100644
index 00000000..0aee1947
--- /dev/null
+++ b/src/Bonsai.ML.Torch/NeuralNets/Convolution/ConvTranspose1d.cs
@@ -0,0 +1,112 @@
+using System;
+using System.ComponentModel;
+using System.Reactive.Linq;
+using System.Xml.Serialization;
+using TorchSharp;
+using static TorchSharp.torch;
+using static TorchSharp.torch.nn;
+
+namespace Bonsai.ML.Torch.NeuralNets.Convolution;
+
+///
+/// Represents an operator that creates a 1D transposed convolution module.
+///
+///
+/// See for more information.
+///
+[Description("Creates a 1D transposed convolution module.")]
+public class ConvTranspose1d
+{
+ ///
+ /// The number of input channels in the input tensor.
+ ///
+ [Description("The number of input channels in the input tensor.")]
+ public long InChannels { get; set; }
+
+ ///
+ /// The number of output channels produced by the convolution.
+ ///
+ [Description("The number of output channels produced by the convolution.")]
+ public long OutChannels { get; set; }
+
+ ///
+ /// The size of the convolution kernel.
+ ///
+ [Description("The size of the convolution kernel.")]
+ public long KernelSize { get; set; }
+
+ ///
+ /// The stride of the convolution.
+ ///
+ [Description("The stride of the convolution.")]
+ public long Stride { get; set; } = 1;
+
+ ///
+ /// The padding to add to both sides of the input.
+ ///
+ [Description("The padding to add to both sides of the input.")]
+ public long Padding { get; set; } = 0;
+
+ ///
+ /// The additional size added to one side of the output shape.
+ ///
+ [Description("The additional size added to one side of the output shape.")]
+ public long OutputPadding { get; set; } = 0;
+
+ ///
+ /// The spacing between kernel elements.
+ ///
+ [Description("The spacing between kernel elements.")]
+ public long Dilation { get; set; } = 1;
+
+ ///
+ /// The mode of padding.
+ ///
+ [Description("The mode of padding.")]
+ public PaddingModes PaddingMode { get; set; } = PaddingModes.Zeros;
+
+ ///
+ /// The number of blocked connections from input channels to output channels.
+ ///
+ [Description("The number of blocked connections from input channels to output channels.")]
+ public long Groups { get; set; } = 1;
+
+ ///
+ /// If true, adds a learnable bias to the output.
+ ///
+ [Description("If true, adds a learnable bias to the output")]
+ public bool Bias { get; set; } = true;
+
+ ///
+ /// The desired device of the returned tensor.
+ ///
+ [XmlIgnore]
+ [Description("The desired device of the returned tensor")]
+ public Device Device { get; set; } = null;
+
+ ///
+ /// The desired data type of the returned tensor.
+ ///
+ [Description("The desired data type of the returned tensor")]
+ public ScalarType? Type { get; set; } = null;
+
+ ///
+ /// Creates a ConvTranspose1d module.
+ ///
+ ///
+ public IObservable Process()
+ {
+ return Observable.Return(ConvTranspose1d(InChannels, OutChannels, KernelSize, Stride, Padding, OutputPadding, Dilation, PaddingMode, Groups, Bias, Device, Type));
+ }
+
+ ///
+ /// Creates a ConvTranspose1d module.
+ ///
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable source)
+ {
+ return source.Select(_ => ConvTranspose1d(InChannels, OutChannels, KernelSize, Stride, Padding, OutputPadding, Dilation, PaddingMode, Groups, Bias, Device, Type));
+ }
+}
diff --git a/src/Bonsai.ML.Torch/NeuralNets/Convolution/ConvTranspose2d.cs b/src/Bonsai.ML.Torch/NeuralNets/Convolution/ConvTranspose2d.cs
new file mode 100644
index 00000000..6307420a
--- /dev/null
+++ b/src/Bonsai.ML.Torch/NeuralNets/Convolution/ConvTranspose2d.cs
@@ -0,0 +1,117 @@
+using System;
+using System.ComponentModel;
+using System.Reactive.Linq;
+using System.Xml.Serialization;
+using TorchSharp;
+using static TorchSharp.torch;
+using static TorchSharp.torch.nn;
+
+namespace Bonsai.ML.Torch.NeuralNets.Convolution;
+
+///
+/// Represents an operator that creates a 2D transposed convolution module.
+///
+///
+/// See for more information.
+///
+[Description("Creates a 2D transposed convolution module.")]
+public class ConvTranspose2d
+{
+ ///
+ /// The number of input channels in the input tensor.
+ ///
+ [Description("The number of input channels in the input tensor.")]
+ public long InChannels { get; set; }
+
+ ///
+ /// The number of output channels produced by the convolution.
+ ///
+ [Description("The number of output channels produced by the convolution.")]
+ public long OutChannels { get; set; }
+
+ ///
+ /// The size of the convolution kernel.
+ ///
+ [Description("The size of the convolution kernel.")]
+ [TypeConverter(typeof(ValueTupleConverter))]
+ public (long, long) KernelSize { get; set; }
+
+ ///
+ /// The stride of the convolution.
+ ///
+ [Description("The stride of the convolution.")]
+ [TypeConverter(typeof(NullableValueTupleConverter))]
+ public (long, long)? Stride { get; set; } = null;
+
+ ///
+ /// The zero-padding added to both sides of each dimension in the input.
+ ///
+ [Description("Zero-padding added to both sides of each dimension in the input.")]
+ [TypeConverter(typeof(NullableValueTupleConverter))]
+ public (long, long)? Padding { get; set; } = null;
+
+ ///
+ /// The additional size added to one side of each dimension in the output shape.
+ ///
+ [Description("The additional size added to one side of each dimension in the output shape.")]
+ [TypeConverter(typeof(NullableValueTupleConverter))]
+ public (long, long)? OutputPadding { get; set; } = null;
+
+ ///
+ /// The spacing between kernel elements.
+ ///
+ [Description("The dilation parameter for the Conv1d module")]
+ [TypeConverter(typeof(NullableValueTupleConverter))]
+ public (long, long)? Dilation { get; set; } = null;
+
+ ///
+ /// The mode of padding.
+ ///
+ [Description("The mode of padding.")]
+ public PaddingModes PaddingMode { get; set; } = PaddingModes.Zeros;
+
+ ///
+ /// The number of blocked connections from input channels to output channels.
+ ///
+ [Description("The number of blocked connections from input channels to output channels.")]
+ public long Groups { get; set; } = 1;
+
+ ///
+ /// If true, adds a learnable bias to the output.
+ ///
+ [Description("If true, adds a learnable bias to the output")]
+ public bool Bias { get; set; } = true;
+
+ ///
+ /// The desired device of the returned tensor.
+ ///
+ [XmlIgnore]
+ [Description("The desired device of the returned tensor")]
+ public Device Device { get; set; } = null;
+
+ ///
+ /// The desired data type of the returned tensor.
+ ///
+ [Description("The desired data type of the returned tensor")]
+ public ScalarType? Type { get; set; } = null;
+
+ ///
+ /// Creates a ConvTranspose2d module.
+ ///
+ ///
+ public IObservable Process()
+ {
+ return Observable.Return(ConvTranspose2d(InChannels, OutChannels, KernelSize, Stride, Padding, OutputPadding, Dilation, PaddingMode, Groups, Bias, Device, Type));
+ }
+
+ ///
+ /// Creates a ConvTranspose2d module.
+ ///
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable source)
+ {
+ return source.Select(_ => ConvTranspose2d(InChannels, OutChannels, KernelSize, Stride, Padding, OutputPadding, Dilation, PaddingMode, Groups, Bias, Device, Type));
+ }
+}
diff --git a/src/Bonsai.ML.Torch/NeuralNets/Convolution/ConvTranspose3d.cs b/src/Bonsai.ML.Torch/NeuralNets/Convolution/ConvTranspose3d.cs
new file mode 100644
index 00000000..c8388a76
--- /dev/null
+++ b/src/Bonsai.ML.Torch/NeuralNets/Convolution/ConvTranspose3d.cs
@@ -0,0 +1,117 @@
+using System;
+using System.ComponentModel;
+using System.Reactive.Linq;
+using System.Xml.Serialization;
+using TorchSharp;
+using static TorchSharp.torch;
+using static TorchSharp.torch.nn;
+
+namespace Bonsai.ML.Torch.NeuralNets.Convolution;
+
+///
+/// Represents an operator that creates a 3D transposed convolution module.
+///
+///
+/// See for more information.
+///
+[Description("Creates a 3D transposed convolution module.")]
+public class ConvTranspose3d
+{
+ ///
+ /// The number of input channels in the input tensor.
+ ///
+ [Description("The number of input channels in the input tensor.")]
+ public long InChannels { get; set; }
+
+ ///
+ /// The number of output channels produced by the convolution.
+ ///
+ [Description("The number of output channels produced by the convolution.")]
+ public long OutChannels { get; set; }
+
+ ///
+ /// The size of the convolution kernel.
+ ///
+ [Description("The size of the convolution kernel.")]
+ [TypeConverter(typeof(ValueTupleConverter))]
+ public (long, long, long) KernelSize { get; set; }
+
+ ///
+ /// The stride of the convolution.
+ ///
+ [Description("The stride of the convolution.")]
+ [TypeConverter(typeof(NullableValueTupleConverter))]
+ public (long, long, long)? Stride { get; set; } = null;
+
+ ///
+ /// The zero-padding added to both sides of each dimension in the input.
+ ///
+ [Description("The zero-padding added to both sides of each dimension in the input.")]
+ [TypeConverter(typeof(NullableValueTupleConverter))]
+ public (long, long, long)? Padding { get; set; } = null;
+
+ ///
+ /// The additional size added to one side of each dimension in the output shape.
+ ///
+ [Description("The additional size added to one side of each dimension in the output shape.")]
+ [TypeConverter(typeof(NullableValueTupleConverter))]
+ public (long, long, long)? OutputPadding { get; set; } = null;
+
+ ///
+ /// The spacing between kernel elements.
+ ///
+ [Description("The spacing between kernel elements.")]
+ [TypeConverter(typeof(NullableValueTupleConverter))]
+ public (long, long, long)? Dilation { get; set; } = null;
+
+ ///
+ /// The mode of padding.
+ ///
+ [Description("The mode of padding.")]
+ public PaddingModes PaddingMode { get; set; } = PaddingModes.Zeros;
+
+ ///
+ /// The number of blocked connections from input channels to output channels.
+ ///
+ [Description("The number of blocked connections from input channels to output channels.")]
+ public long Groups { get; set; } = 1;
+
+ ///
+ /// If true, adds a learnable bias to the output.
+ ///
+ [Description("If true, adds a learnable bias to the output")]
+ public bool Bias { get; set; } = true;
+
+ ///
+ /// The desired device of the returned tensor.
+ ///
+ [XmlIgnore]
+ [Description("The desired device of the returned tensor")]
+ public Device Device { get; set; } = null;
+
+ ///
+ /// The desired data type of the returned tensor.
+ ///
+ [Description("The desired data type of the returned tensor")]
+ public ScalarType? Type { get; set; } = null;
+
+ ///
+ /// Creates a ConvTranspose3d module.
+ ///
+ ///
+ public IObservable Process()
+ {
+ return Observable.Return(ConvTranspose3d(InChannels, OutChannels, KernelSize, Stride, Padding, OutputPadding, Dilation, PaddingMode, Groups, Bias, Device, Type));
+ }
+
+ ///
+ /// Creates a ConvTranspose3d module.
+ ///
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable source)
+ {
+ return source.Select(_ => ConvTranspose3d(InChannels, OutChannels, KernelSize, Stride, Padding, OutputPadding, Dilation, PaddingMode, Groups, Bias, Device, Type));
+ }
+}
diff --git a/src/Bonsai.ML.Torch/NeuralNets/Convolution/Fold.cs b/src/Bonsai.ML.Torch/NeuralNets/Convolution/Fold.cs
new file mode 100644
index 00000000..113c8ccd
--- /dev/null
+++ b/src/Bonsai.ML.Torch/NeuralNets/Convolution/Fold.cs
@@ -0,0 +1,72 @@
+using System;
+using System.ComponentModel;
+using System.Reactive.Linq;
+using static TorchSharp.torch;
+using static TorchSharp.torch.nn;
+
+namespace Bonsai.ML.Torch.NeuralNets.Convolution;
+
+///
+/// Represents an operator that creates a fold module.
+///
+///
+/// See for more information.
+///
+[Description("Creates a fold module.")]
+public class Fold
+{
+ ///
+ /// The shape of the spatial dimensions of the output tensor.
+ ///
+ [Description("The shape of the spatial dimensions of the output tensor.")]
+ [TypeConverter(typeof(ValueTupleConverter))]
+ public (long, long) OutputSize { get; set; }
+
+ ///
+ /// The size of the sliding blocks.
+ ///
+ [Description("The size of the sliding blocks.")]
+ [TypeConverter(typeof(ValueTupleConverter))]
+ public (long, long) KernelSize { get; set; }
+
+ ///
+ /// The stride of elements within the neighborhood.
+ ///
+ [Description("The stride of elements within the neighborhood.")]
+ [TypeConverter(typeof(NullableValueTupleConverter))]
+ public (long, long)? Dilation { get; set; } = null;
+
+ ///
+ /// The implicit zero-padding to be added on both sides of input.
+ ///
+ [Description("The implicit zero-padding to be added on both sides of input.")]
+ [TypeConverter(typeof(NullableValueTupleConverter))]
+ public (long, long)? Padding { get; set; } = null;
+
+ ///
+ /// The stride of the sliding blocks in the input tensor.
+ ///
+ [Description("The stride of the sliding blocks in the input tensor.")]
+ [TypeConverter(typeof(NullableValueTupleConverter))]
+ public (long, long)? Stride { get; set; } = null;
+
+ ///
+ /// Creates a Fold module.
+ ///
+ ///
+ public IObservable Process()
+ {
+ return Observable.Return(Fold(OutputSize, KernelSize, Dilation, Padding, Stride));
+ }
+
+ ///
+ /// Creates a Fold module.
+ ///
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable source)
+ {
+ return source.Select(_ => Fold(OutputSize, KernelSize, Dilation, Padding, Stride));
+ }
+}
diff --git a/src/Bonsai.ML.Torch/NeuralNets/Convolution/Unfold.cs b/src/Bonsai.ML.Torch/NeuralNets/Convolution/Unfold.cs
new file mode 100644
index 00000000..d32aa810
--- /dev/null
+++ b/src/Bonsai.ML.Torch/NeuralNets/Convolution/Unfold.cs
@@ -0,0 +1,65 @@
+using System;
+using System.ComponentModel;
+using System.Reactive.Linq;
+using static TorchSharp.torch;
+using static TorchSharp.torch.nn;
+
+namespace Bonsai.ML.Torch.NeuralNets.Convolution;
+
+///
+/// Represents an operator that creates an unfold module.
+///
+///
+/// See for more information.
+///
+[Description("Creates an unfold module.")]
+public class Unfold
+{
+ ///
+ /// The size of the sliding blocks.
+ ///
+ [Description("The size of the sliding blocks.")]
+ [TypeConverter(typeof(ValueTupleConverter))]
+ public (long, long) KernelSize { get; set; }
+
+ ///
+ /// The stride of elements within the neighborhood.
+ ///
+ [Description("The stride of elements within the neighborhood.")]
+ [TypeConverter(typeof(NullableValueTupleConverter))]
+ public (long, long)? Dilation { get; set; } = null;
+
+ ///
+ /// The implicit zero-padding to be added on both sides of input.
+ ///
+ [Description("The implicit zero-padding to be added on both sides of input.")]
+ [TypeConverter(typeof(NullableValueTupleConverter))]
+ public (long, long)? Padding { get; set; } = null;
+
+ ///
+ /// The stride of the sliding blocks in the input tensor.
+ ///
+ [Description("The stride of the sliding blocks in the input tensor.")]
+ [TypeConverter(typeof(NullableValueTupleConverter))]
+ public (long, long)? Stride { get; set; } = null;
+
+ ///
+ /// Creates an Unfold module.
+ ///
+ ///
+ public IObservable Process()
+ {
+ return Observable.Return(Unfold(KernelSize, Dilation, Padding, Stride));
+ }
+
+ ///
+ /// Creates an Unfold module.
+ ///
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable source)
+ {
+ return source.Select(_ => Unfold(KernelSize, Dilation, Padding, Stride));
+ }
+}
diff --git a/src/Bonsai.ML.Torch/NeuralNets/ConvolutionModuleBuilder.cs b/src/Bonsai.ML.Torch/NeuralNets/ConvolutionModuleBuilder.cs
new file mode 100644
index 00000000..ec9e4ae5
--- /dev/null
+++ b/src/Bonsai.ML.Torch/NeuralNets/ConvolutionModuleBuilder.cs
@@ -0,0 +1,48 @@
+using System.ComponentModel;
+using System.Xml.Serialization;
+
+namespace Bonsai.ML.Torch.NeuralNets;
+
+///
+/// Represents an operator that creates a torch module for convolution operations.
+///
+[XmlInclude(typeof(Convolution.Conv1d))]
+[XmlInclude(typeof(Convolution.Conv2d))]
+[XmlInclude(typeof(Convolution.Conv3d))]
+[XmlInclude(typeof(Convolution.ConvTranspose1d))]
+[XmlInclude(typeof(Convolution.ConvTranspose2d))]
+[XmlInclude(typeof(Convolution.ConvTranspose3d))]
+[XmlInclude(typeof(Convolution.Fold))]
+[XmlInclude(typeof(Convolution.Unfold))]
+[DefaultProperty(nameof(ConvolutionModule))]
+[Combinator]
+[Description("Creates a torch module for convolution operations.")]
+[WorkflowElementCategory(ElementCategory.Source)]
+public class ConvolutionModuleBuilder : ModuleCombinatorBuilder, INamedElement
+{
+ internal override string BuilderName => "ConvolutionModule";
+
+ ///
+ /// Initializes a new instance of the class.
+ ///
+ public ConvolutionModuleBuilder()
+ {
+ Module = new Convolution.Conv1d();
+ }
+
+ ///
+ /// Gets or sets the specific convolution module to create.
+ ///
+ [DesignOnly(true)]
+ [DisplayName("Module")]
+ [Externalizable(false)]
+ [RefreshProperties(RefreshProperties.All)]
+ [Category(nameof(CategoryAttribute.Design))]
+ [Description("The specific convolution module to create.")]
+ [TypeConverter(typeof(ModuleTypeConverter))]
+ public object ConvolutionModule
+ {
+ get => Module;
+ set => Module = value;
+ }
+}
diff --git a/src/Bonsai.ML.Torch/NeuralNets/Distance/CosineSimilarity.cs b/src/Bonsai.ML.Torch/NeuralNets/Distance/CosineSimilarity.cs
new file mode 100644
index 00000000..4e1de349
--- /dev/null
+++ b/src/Bonsai.ML.Torch/NeuralNets/Distance/CosineSimilarity.cs
@@ -0,0 +1,49 @@
+using System;
+using System.ComponentModel;
+using System.Reactive.Linq;
+using static TorchSharp.torch;
+using static TorchSharp.torch.nn;
+
+namespace Bonsai.ML.Torch.NeuralNets.Distance;
+
+///
+/// Represents an operator that creates a cosine similarity module.
+///
+///
+/// See for more information.
+///
+[Description("Creates a cosine similarity module.")]
+public class CosineSimilarity
+{
+ ///
+ /// The dimension where cosine similarity is computed.
+ ///
+ [Description("The dimension where cosine similarity is computed.")]
+ public long Dim { get; set; } = 1;
+
+ ///
+ /// The value added to the denominator to avoid division by zero.
+ ///
+ [Description("The value added to the denominator to avoid division by zero.")]
+ public double Eps { get; set; } = 1E-08D;
+
+ ///
+ /// Creates a CosineSimilarity module.
+ ///
+ ///
+ public IObservable Process()
+ {
+ return Observable.Return(CosineSimilarity(Dim, Eps));
+ }
+
+ ///
+ /// Creates a CosineSimilarity module.
+ ///
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable source)
+ {
+ return source.Select(_ => CosineSimilarity(Dim, Eps));
+ }
+}
diff --git a/src/Bonsai.ML.Torch/NeuralNets/Distance/PairwiseDistance.cs b/src/Bonsai.ML.Torch/NeuralNets/Distance/PairwiseDistance.cs
new file mode 100644
index 00000000..b269e2e2
--- /dev/null
+++ b/src/Bonsai.ML.Torch/NeuralNets/Distance/PairwiseDistance.cs
@@ -0,0 +1,55 @@
+using System;
+using System.ComponentModel;
+using System.Reactive.Linq;
+using static TorchSharp.torch;
+using static TorchSharp.torch.nn;
+
+namespace Bonsai.ML.Torch.NeuralNets.Distance;
+
+///
+/// Represents an operator that creates a pairwise distance module.
+///
+///
+/// See for more information.
+///
+[Description("Creates a pairwise distance module.")]
+public class PairwiseDistance
+{
+ ///
+ /// The norm degree which can be positive or negative.
+ ///
+ [Description("The norm degree which can be positive or negative.")]
+ public double P { get; set; } = 2D;
+
+ ///
+ /// The value added to the denominator to avoid division by zero.
+ ///
+ [Description("The value added to the denominator to avoid division by zero.")]
+ public double Eps { get; set; } = 1E-06D;
+
+ ///
+ /// Determines whether or not to keep the vector dimension.
+ ///
+ [Description("Determines whether or not to keep the vector dimension.")]
+ public bool KeepDim { get; set; } = false;
+
+ ///
+ /// Creates a PairwiseDistance module.
+ ///
+ ///
+ public IObservable Process()
+ {
+ return Observable.Return(PairwiseDistance(P, Eps, KeepDim));
+ }
+
+ ///
+ /// Creates a PairwiseDistance module.
+ ///
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable source)
+ {
+ return source.Select(_ => PairwiseDistance(P, Eps, KeepDim));
+ }
+}
diff --git a/src/Bonsai.ML.Torch/NeuralNets/DistanceModuleBuilder.cs b/src/Bonsai.ML.Torch/NeuralNets/DistanceModuleBuilder.cs
new file mode 100644
index 00000000..aae8f651
--- /dev/null
+++ b/src/Bonsai.ML.Torch/NeuralNets/DistanceModuleBuilder.cs
@@ -0,0 +1,42 @@
+using System.ComponentModel;
+using System.Xml.Serialization;
+
+namespace Bonsai.ML.Torch.NeuralNets;
+
+///
+/// Represents an operator that creates a module for distance computations.
+///
+[XmlInclude(typeof(Distance.CosineSimilarity))]
+[XmlInclude(typeof(Distance.PairwiseDistance))]
+[DefaultProperty(nameof(DistanceModule))]
+[Combinator]
+[Description("Creates a module for distance computations.")]
+[WorkflowElementCategory(ElementCategory.Source)]
+public class DistanceModuleBuilder : ModuleCombinatorBuilder, INamedElement
+{
+ internal override string BuilderName => "DistanceModule";
+
+ ///
+ /// Initializes a new instance of the class.
+ ///
+ public DistanceModuleBuilder()
+ {
+ Module = new Distance.CosineSimilarity();
+ }
+
+ ///
+ /// Gets or sets the specific distance module to create.
+ ///
+ [DesignOnly(true)]
+ [DisplayName("Module")]
+ [Externalizable(false)]
+ [RefreshProperties(RefreshProperties.All)]
+ [Category(nameof(CategoryAttribute.Design))]
+ [Description("The specific distance module to create.")]
+ [TypeConverter(typeof(ModuleTypeConverter))]
+ public object DistanceModule
+ {
+ get => Module;
+ set => Module = value;
+ }
+}
diff --git a/src/Bonsai.ML.Torch/NeuralNets/Dropout/AlphaDropout.cs b/src/Bonsai.ML.Torch/NeuralNets/Dropout/AlphaDropout.cs
new file mode 100644
index 00000000..85b52e95
--- /dev/null
+++ b/src/Bonsai.ML.Torch/NeuralNets/Dropout/AlphaDropout.cs
@@ -0,0 +1,49 @@
+using System;
+using System.ComponentModel;
+using System.Reactive.Linq;
+using static TorchSharp.torch;
+using static TorchSharp.torch.nn;
+
+namespace Bonsai.ML.Torch.NeuralNets.Dropout;
+
+///
+/// Represents an operator that creates an alpha dropout module.
+///
+///
+/// See for more information.
+///
+[Description("Creates an alpha dropout module.")]
+public class AlphaDropout
+{
+ ///
+ /// The probability of an element to be dropped.
+ ///
+ [Description("The probability of an element to be dropped.")]
+ public double Probability { get; set; } = 0.5D;
+
+ ///
+ /// If set to true, will do this operation in-place.
+ ///
+ [Description("If set to true, will do this operation in-place")]
+ public bool Inplace { get; set; } = false;
+
+ ///
+ /// Creates an AlphaDropout module.
+ ///
+ ///
+ public IObservable Process()
+ {
+ return Observable.Return(AlphaDropout(Probability, Inplace));
+ }
+
+ ///
+ /// Creates an AlphaDropout module.
+ ///
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable source)
+ {
+ return source.Select(_ => AlphaDropout(Probability, Inplace));
+ }
+}
diff --git a/src/Bonsai.ML.Torch/NeuralNets/Dropout/Dropout.cs b/src/Bonsai.ML.Torch/NeuralNets/Dropout/Dropout.cs
new file mode 100644
index 00000000..cdda8cda
--- /dev/null
+++ b/src/Bonsai.ML.Torch/NeuralNets/Dropout/Dropout.cs
@@ -0,0 +1,49 @@
+using System;
+using System.ComponentModel;
+using System.Reactive.Linq;
+using static TorchSharp.torch;
+using static TorchSharp.torch.nn;
+
+namespace Bonsai.ML.Torch.NeuralNets.Dropout;
+
+///
+/// Represents an operator that creates a dropout module.
+///
+///
+/// See for more information.
+///
+[Description("Creates a dropout module.")]
+public class Dropout
+{
+ ///
+ /// The probability of an element to be zeroed.
+ ///
+ [Description("The probability of an element to be zeroed.")]
+ public double Probability { get; set; } = 0.5D;
+
+ ///
+ /// If set to true, will do this operation in-place.
+ ///
+ [Description("If set to true, will do this operation in-place.")]
+ public bool Inplace { get; set; } = false;
+
+ ///
+ /// Creates a Dropout module.
+ ///
+ ///
+ public IObservable Process()
+ {
+ return Observable.Return(nn.Dropout(Probability, Inplace));
+ }
+
+ ///
+ /// Creates a Dropout module.
+ ///
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable source)
+ {
+ return source.Select(_ => nn.Dropout(Probability, Inplace));
+ }
+}
diff --git a/src/Bonsai.ML.Torch/NeuralNets/Dropout/Dropout1d.cs b/src/Bonsai.ML.Torch/NeuralNets/Dropout/Dropout1d.cs
new file mode 100644
index 00000000..899affb6
--- /dev/null
+++ b/src/Bonsai.ML.Torch/NeuralNets/Dropout/Dropout1d.cs
@@ -0,0 +1,49 @@
+using System;
+using System.ComponentModel;
+using System.Reactive.Linq;
+using static TorchSharp.torch;
+using static TorchSharp.torch.nn;
+
+namespace Bonsai.ML.Torch.NeuralNets.Dropout;
+
+///
+/// Represents an operator that creates a 1D dropout module.
+///
+///
+/// See for more information.
+///
+[Description("Creates a 1D dropout module.")]
+public class Dropout1d
+{
+ ///
+ /// The probability of an element to be zeroed.
+ ///
+ [Description("The probability of an element to be zeroed.")]
+ public double Probability { get; set; } = 0.5D;
+
+ ///
+ /// If set to true, will do this operation in-place.
+ ///
+ [Description("If set to true, will do this operation in-place.")]
+ public bool Inplace { get; set; } = false;
+
+ ///
+ /// Creates a Dropout1d module.
+ ///
+ ///
+ public IObservable Process()
+ {
+ return Observable.Return(nn.Dropout1d(Probability, Inplace));
+ }
+
+ ///
+ /// Creates a Dropout1d module.
+ ///
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable source)
+ {
+ return source.Select(_ => nn.Dropout1d(Probability, Inplace));
+ }
+}
diff --git a/src/Bonsai.ML.Torch/NeuralNets/Dropout/Dropout2d.cs b/src/Bonsai.ML.Torch/NeuralNets/Dropout/Dropout2d.cs
new file mode 100644
index 00000000..4618e315
--- /dev/null
+++ b/src/Bonsai.ML.Torch/NeuralNets/Dropout/Dropout2d.cs
@@ -0,0 +1,49 @@
+using System;
+using System.ComponentModel;
+using System.Reactive.Linq;
+using static TorchSharp.torch;
+using static TorchSharp.torch.nn;
+
+namespace Bonsai.ML.Torch.NeuralNets.Dropout;
+
+///
+/// Represents an operator that creates a 2D dropout module.
+///
+///
+/// See for more information.
+///
+[Description("Creates a 2D dropout module.")]
+public class Dropout2d
+{
+ ///
+ /// The probability of an element to be zeroed.
+ ///
+ [Description("The probability of an element to be zeroed.")]
+ public double Probability { get; set; } = 0.5D;
+
+ ///
+ /// If set to true, will do this operation in-place.
+ ///
+ [Description("If set to true, will do this operation in-place.")]
+ public bool Inplace { get; set; } = false;
+
+ ///
+ /// Creates a Dropout2d module.
+ ///
+ ///
+ public IObservable Process()
+ {
+ return Observable.Return(nn.Dropout2d(Probability, Inplace));
+ }
+
+ ///
+ /// Creates a Dropout2d module.
+ ///
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable source)
+ {
+ return source.Select(_ => nn.Dropout2d(Probability, Inplace));
+ }
+}
diff --git a/src/Bonsai.ML.Torch/NeuralNets/Dropout/Dropout3d.cs b/src/Bonsai.ML.Torch/NeuralNets/Dropout/Dropout3d.cs
new file mode 100644
index 00000000..627a240c
--- /dev/null
+++ b/src/Bonsai.ML.Torch/NeuralNets/Dropout/Dropout3d.cs
@@ -0,0 +1,49 @@
+using System;
+using System.ComponentModel;
+using System.Reactive.Linq;
+using static TorchSharp.torch;
+using static TorchSharp.torch.nn;
+
+namespace Bonsai.ML.Torch.NeuralNets.Dropout;
+
+///
+/// Represents an operator that creates a 3D dropout module.
+///
+///
+/// See for more information.
+///
+[Description("Creates a 3D dropout module.")]
+public class Dropout3d
+{
+ ///
+ /// The probability of an element to be zeroed.
+ ///
+ [Description("The probability of an element to be zeroed.")]
+ public double Probability { get; set; } = 0.5D;
+
+ ///
+ /// If set to true, will do this operation in-place.
+ ///
+ [Description("If set to true, will do this operation in-place.")]
+ public bool Inplace { get; set; } = false;
+
+ ///
+ /// Creates a Dropout3d module.
+ ///
+ ///
+ public IObservable Process()
+ {
+ return Observable.Return(nn.Dropout3d(Probability, Inplace));
+ }
+
+ ///
+ /// Creates a Dropout3d module.
+ ///
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable source)
+ {
+ return source.Select(_ => nn.Dropout3d(Probability, Inplace));
+ }
+}
diff --git a/src/Bonsai.ML.Torch/NeuralNets/Dropout/FeatureAlphaDropout.cs b/src/Bonsai.ML.Torch/NeuralNets/Dropout/FeatureAlphaDropout.cs
new file mode 100644
index 00000000..dd827a83
--- /dev/null
+++ b/src/Bonsai.ML.Torch/NeuralNets/Dropout/FeatureAlphaDropout.cs
@@ -0,0 +1,43 @@
+using System;
+using System.ComponentModel;
+using System.Reactive.Linq;
+using static TorchSharp.torch;
+using static TorchSharp.torch.nn;
+
+namespace Bonsai.ML.Torch.NeuralNets.Dropout;
+
+///
+/// Represents an operator that creates a feature alpha dropout module.
+///
+///
+/// See for more information.
+///
+[Description("Creates a feature alpha dropout module.")]
+public class FeatureAlphaDropout
+{
+ ///
+ /// The probability of an element to be zeroed.
+ ///
+ [Description("The probability of an element to be zeroed.")]
+ public double Probability { get; set; } = 0.5D;
+
+ ///
+ /// Creates a FeatureAlphaDropout module.
+ ///
+ ///
+ public IObservable Process()
+ {
+ return Observable.Return(FeatureAlphaDropout(Probability));
+ }
+
+ ///
+ /// Creates a FeatureAlphaDropout module.
+ ///
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable source)
+ {
+ return source.Select(_ => FeatureAlphaDropout(Probability));
+ }
+}
diff --git a/src/Bonsai.ML.Torch/NeuralNets/DropoutModuleBuilder.cs b/src/Bonsai.ML.Torch/NeuralNets/DropoutModuleBuilder.cs
new file mode 100644
index 00000000..b111af08
--- /dev/null
+++ b/src/Bonsai.ML.Torch/NeuralNets/DropoutModuleBuilder.cs
@@ -0,0 +1,46 @@
+using System.ComponentModel;
+using System.Xml.Serialization;
+
+namespace Bonsai.ML.Torch.NeuralNets;
+
+///
+/// Represents an operator that creates a module for dropout computations.
+///
+[XmlInclude(typeof(Dropout.AlphaDropout))]
+[XmlInclude(typeof(Dropout.Dropout))]
+[XmlInclude(typeof(Dropout.Dropout1d))]
+[XmlInclude(typeof(Dropout.Dropout2d))]
+[XmlInclude(typeof(Dropout.Dropout3d))]
+[XmlInclude(typeof(Dropout.FeatureAlphaDropout))]
+[DefaultProperty(nameof(DropoutModule))]
+[Combinator]
+[Description("Creates a module for dropout computations.")]
+[WorkflowElementCategory(ElementCategory.Source)]
+public class DropoutModuleBuilder : ModuleCombinatorBuilder, INamedElement
+{
+ internal override string BuilderName => "DropoutModule";
+
+ ///
+ /// Initializes a new instance of the class.
+ ///
+ public DropoutModuleBuilder()
+ {
+ Module = new Dropout.AlphaDropout();
+ }
+
+ ///
+ /// Gets or sets the specific dropout module to create.
+ ///
+ [DesignOnly(true)]
+ [DisplayName("Module")]
+ [Externalizable(false)]
+ [RefreshProperties(RefreshProperties.All)]
+ [Category(nameof(CategoryAttribute.Design))]
+ [Description("The specific dropout module to create.")]
+ [TypeConverter(typeof(ModuleTypeConverter))]
+ public object DropoutModule
+ {
+ get => Module;
+ set => Module = value;
+ }
+}
diff --git a/src/Bonsai.ML.Torch/NeuralNets/Flatten/Flatten.cs b/src/Bonsai.ML.Torch/NeuralNets/Flatten/Flatten.cs
new file mode 100644
index 00000000..6f343414
--- /dev/null
+++ b/src/Bonsai.ML.Torch/NeuralNets/Flatten/Flatten.cs
@@ -0,0 +1,49 @@
+using System;
+using System.ComponentModel;
+using System.Reactive.Linq;
+using static TorchSharp.torch;
+using static TorchSharp.torch.nn;
+
+namespace Bonsai.ML.Torch.NeuralNets.Flatten;
+
+///
+/// Represents an operator that creates a Flatten module.
+///
+///
+/// See for more information.
+///
+[Description("Creates a Flatten module.")]
+public class Flatten
+{
+ ///
+ /// The first dimension to flatten.
+ ///
+ [Description("The first dimension to flatten.")]
+ public long StartDim { get; set; } = 1;
+
+ ///
+ /// The last dimension to flatten.
+ ///
+ [Description("The last dimension to flatten.")]
+ public long EndDim { get; set; } = -1;
+
+ ///
+ /// Creates a Flatten module.
+ ///
+ ///
+ public IObservable Process()
+ {
+ return Observable.Return(nn.Flatten(StartDim, EndDim));
+ }
+
+ ///
+ /// Creates a Flatten module.
+ ///
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable source)
+ {
+ return source.Select(_ => nn.Flatten(StartDim, EndDim));
+ }
+}
diff --git a/src/Bonsai.ML.Torch/NeuralNets/Flatten/Unflatten.cs b/src/Bonsai.ML.Torch/NeuralNets/Flatten/Unflatten.cs
new file mode 100644
index 00000000..b00683de
--- /dev/null
+++ b/src/Bonsai.ML.Torch/NeuralNets/Flatten/Unflatten.cs
@@ -0,0 +1,50 @@
+using System;
+using System.ComponentModel;
+using System.Reactive.Linq;
+using static TorchSharp.torch;
+using static TorchSharp.torch.nn;
+
+namespace Bonsai.ML.Torch.NeuralNets.Flatten;
+
+///
+/// Creates an Unflatten module.
+///
+///
+/// See for more information.
+///
+[Description("Creates an Unflatten module.")]
+public class Unflatten
+{
+ ///
+ /// The dimension to unflatten.
+ ///
+ [Description("The dimension to unflatten.")]
+ public long Dim { get; set; }
+
+ ///
+ /// The new shape of the unflattened dimension.
+ ///
+ [Description("The new shape of the unflattened dimension.")]
+ [TypeConverter(typeof(UnidimensionalArrayConverter))]
+ public long[] UnflattenedSize { get; set; }
+
+ ///
+ /// Creates an Unflatten module.
+ ///
+ ///
+ public IObservable Process()
+ {
+ return Observable.Return(Unflatten(Dim, UnflattenedSize));
+ }
+
+ ///
+ /// Creates an Unflatten module.
+ ///
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable source)
+ {
+ return source.Select(_ => Unflatten(Dim, UnflattenedSize));
+ }
+}
diff --git a/src/Bonsai.ML.Torch/NeuralNets/FlattenModuleBuilder.cs b/src/Bonsai.ML.Torch/NeuralNets/FlattenModuleBuilder.cs
new file mode 100644
index 00000000..d5ef301c
--- /dev/null
+++ b/src/Bonsai.ML.Torch/NeuralNets/FlattenModuleBuilder.cs
@@ -0,0 +1,42 @@
+using System.ComponentModel;
+using System.Xml.Serialization;
+
+namespace Bonsai.ML.Torch.NeuralNets;
+
+///
+/// Represents an operator that creates a module for flattening tensors.
+///
+[XmlInclude(typeof(Flatten.Flatten))]
+[XmlInclude(typeof(Flatten.Unflatten))]
+[DefaultProperty(nameof(FlattenModule))]
+[Combinator]
+[Description("Creates a module for flattening tensors.")]
+[WorkflowElementCategory(ElementCategory.Source)]
+public class FlattenModuleBuilder : ModuleCombinatorBuilder, INamedElement
+{
+ internal override string BuilderName => "FlattenModule";
+
+ ///
+ /// Initializes a new instance of the class.
+ ///
+ public FlattenModuleBuilder()
+ {
+ Module = new Flatten.Flatten();
+ }
+
+ ///
+ /// Gets or sets the specific flatten module to create.
+ ///
+ [DesignOnly(true)]
+ [DisplayName("Module")]
+ [Externalizable(false)]
+ [RefreshProperties(RefreshProperties.All)]
+ [Category(nameof(CategoryAttribute.Design))]
+ [Description("The specific flatten module to create.")]
+ [TypeConverter(typeof(ModuleTypeConverter))]
+ public object FlattenModule
+ {
+ get => Module;
+ set => Module = value;
+ }
+}
diff --git a/src/Bonsai.ML.Torch/NeuralNets/Forward.cs b/src/Bonsai.ML.Torch/NeuralNets/Forward.cs
deleted file mode 100644
index fdab2387..00000000
--- a/src/Bonsai.ML.Torch/NeuralNets/Forward.cs
+++ /dev/null
@@ -1,35 +0,0 @@
-using System;
-using System.ComponentModel;
-using System.Reactive.Linq;
-using static TorchSharp.torch;
-using System.Xml.Serialization;
-
-namespace Bonsai.ML.Torch.NeuralNets
-{
- ///
- /// Runs forward inference on the input tensor using the specified model.
- ///
- [Combinator]
- [ResetCombinator]
- [Description("Runs forward inference on the input tensor using the specified model.")]
- [WorkflowElementCategory(ElementCategory.Transform)]
- public class Forward
- {
- ///
- /// The model to use for inference.
- ///
- [XmlIgnore]
- public ITorchModule Model { get; set; }
-
- ///
- /// Runs forward inference on the input tensor using the specified model.
- ///
- ///
- ///
- public IObservable Process(IObservable source)
- {
- Model.Module.eval();
- return source.Select(Model.Forward);
- }
- }
-}
\ No newline at end of file
diff --git a/src/Bonsai.ML.Torch/NeuralNets/ForwardBuilder.cs b/src/Bonsai.ML.Torch/NeuralNets/ForwardBuilder.cs
new file mode 100644
index 00000000..79074611
--- /dev/null
+++ b/src/Bonsai.ML.Torch/NeuralNets/ForwardBuilder.cs
@@ -0,0 +1,122 @@
+using System;
+using System.Linq;
+using System.Linq.Expressions;
+using System.Collections.Generic;
+using System.ComponentModel;
+using System.Reactive.Linq;
+using System.Reflection;
+using static TorchSharp.torch;
+using static TorchSharp.torch.jit;
+using Bonsai.Expressions;
+using TorchSharp;
+
+namespace Bonsai.ML.Torch.NeuralNets;
+
+///
+/// Represents an operator that runs forward inference on the input using the specified module.
+///
+[Combinator]
+[Description("Runs forward inference on the input using the specified module.")]
+[WorkflowElementCategory(ElementCategory.Transform)]
+public class ForwardBuilder : SingleArgumentExpressionBuilder
+{
+ ///
+ public override Expression Build(IEnumerable arguments)
+ {
+ var argumentList = arguments.First().Type.GetGenericArguments()[0];
+ var inputArgs = argumentList.GetGenericArguments()[0];
+ var moduleArg = argumentList.GetGenericArguments()[1];
+
+ var sourceType = typeof(Tuple<,>).MakeGenericType(inputArgs, moduleArg);
+
+ var selectMethod = typeof(Observable)
+ .GetMethods()
+ .First(m =>
+ m.Name == nameof(Observable.Select) &&
+ m.IsGenericMethodDefinition &&
+ m.GetParameters().Length == 2);
+
+ System.Type? moduleType = null;
+ Expression? selectExpression = null;
+ System.Type? resultType = null;
+
+ for (var t = moduleArg; t != null && t != typeof(object); t = t.BaseType)
+ {
+ if (t.IsGenericType && t.GetGenericTypeDefinition().Name.StartsWith("Module`"))
+ {
+ moduleType = t;
+
+ resultType = typeof(Tuple<,>).MakeGenericType(inputArgs, moduleType);
+
+ var tuple1 = Expression.Parameter(sourceType, "t");
+ var tuple1Item1 = Expression.Property(tuple1, "Item1");
+ var tuple1Item2 = Expression.Property(tuple1, "Item2");
+
+ var moduleConversionExpression = Expression.Convert(tuple1Item2, moduleType);
+ var tupleCreateMethod = typeof(Tuple)
+ .GetMethods()
+ .Single(m =>
+ m.Name == nameof(Tuple.Create) &&
+ m.IsGenericMethodDefinition &&
+ m.GetGenericArguments().Length == 2 &&
+ m.GetParameters().Length == 2)
+ .MakeGenericMethod(inputArgs, moduleType);
+
+ var newTuple = Expression.Call(tupleCreateMethod, tuple1Item1, moduleConversionExpression);
+
+ var selector = Expression.Lambda(newTuple, tuple1);
+
+ selectExpression = Expression.Call(selectMethod.MakeGenericMethod(sourceType, resultType), arguments.First(), selector);
+ break;
+ }
+
+ else if (t.IsGenericType && t.GetGenericTypeDefinition().Name.StartsWith("ScriptModule`"))
+ {
+ moduleType = t;
+ resultType = sourceType;
+ selectExpression = arguments.First();
+ break;
+ }
+ }
+
+ if (moduleType is null)
+ throw new InvalidOperationException("The specified module type is not a valid TorchSharp module.");
+
+ var tuple = Expression.Parameter(resultType, "t");
+ var item1 = Expression.Property(tuple, "Item1");
+ var item2 = Expression.Property(tuple, "Item2");
+
+ List forwardCallArgs = [];
+
+ if (inputArgs.IsGenericType && inputArgs.GetGenericTypeDefinition().Name.StartsWith("Tuple`"))
+ {
+ var inputArgsTypes = inputArgs.GetGenericArguments();
+ for (int i = 0; i < inputArgsTypes.Length; i++)
+ {
+ var itemN = Expression.Property(item1, $"Item{i + 1}");
+ forwardCallArgs.Add(itemN);
+ }
+ }
+ else
+ {
+ forwardCallArgs.Add(item1);
+ }
+
+ var moduleMethods = moduleType
+ .GetMethods(BindingFlags.Instance | BindingFlags.Public)
+ .Where(m => m.Name == "forward" &&
+ m.IsPublic &&
+ m.GetParameters().Length == forwardCallArgs.Count &&
+ m.GetParameters().Select(p => p.ParameterType).SequenceEqual(forwardCallArgs.Select(a => a.Type)));
+
+ if (!moduleMethods.Any())
+ throw new InvalidOperationException("The module does not contain a matching forward method.");
+
+ var forwardMethod = moduleMethods.First();
+
+ var forwardCall = Expression.Call(item2, forwardMethod, forwardCallArgs);
+ var forwardLambda = Expression.Lambda(forwardCall, tuple);
+
+ return Expression.Call(selectMethod.MakeGenericMethod(resultType, forwardMethod.ReturnType), selectExpression, forwardLambda);
+ }
+}
diff --git a/src/Bonsai.ML.Torch/NeuralNets/ITorchModule.cs b/src/Bonsai.ML.Torch/NeuralNets/ITorchModule.cs
deleted file mode 100644
index 5cde6f73..00000000
--- a/src/Bonsai.ML.Torch/NeuralNets/ITorchModule.cs
+++ /dev/null
@@ -1,22 +0,0 @@
-using static TorchSharp.torch;
-
-namespace Bonsai.ML.Torch.NeuralNets
-{
- ///
- /// Represents an interface for a Torch module.
- ///
- public interface ITorchModule
- {
- ///
- /// The module.
- ///
- public nn.Module Module { get; }
-
- ///
- /// Runs forward inference on the input tensor using the specified model.
- ///
- ///
- ///
- public Tensor Forward(Tensor tensor);
- }
-}
diff --git a/src/Bonsai.ML.Torch/NeuralNets/LearningRateScheduler/Constant.cs b/src/Bonsai.ML.Torch/NeuralNets/LearningRateScheduler/Constant.cs
new file mode 100644
index 00000000..84fae13b
--- /dev/null
+++ b/src/Bonsai.ML.Torch/NeuralNets/LearningRateScheduler/Constant.cs
@@ -0,0 +1,52 @@
+using System;
+using System.ComponentModel;
+using System.Reactive.Linq;
+using static TorchSharp.torch;
+using static TorchSharp.torch.optim.lr_scheduler;
+
+namespace Bonsai.ML.Torch.NeuralNets.LearningRateScheduler;
+
+///
+/// Represents an operator that creates a constant learning rate scheduler.
+///
+///
+/// See for more information.
+///
+[Description("Creates a constant learning rate scheduler.")]
+public class Constant
+{
+ ///
+ /// The number that the learning rate will be multiplied by until the milestone.
+ ///
+ [Description("The number that the learning rate will be multiplied by until the milestone.")]
+ public double Factor { get; set; } = 0.3333333333333333D;
+
+ ///
+ /// The number of steps that the scheduler multiplies the learning rate by the factor.
+ ///
+ [Description("The number of steps that the scheduler multiplies the learning rate by the factor.")]
+ public int TotalIters { get; set; } = 5;
+
+ ///
+ /// The index of the last epoch.
+ ///
+ [Description("The index of the last epoch.")]
+ public int LastEpoch { get; set; } = -1;
+
+ ///
+ /// Determines whether to write a message to stdout for each update.
+ ///
+ [Description("Determines whether to write a message to stdout for each update.")]
+ public bool Verbose { get; set; } = false;
+
+ ///
+ /// Creates a ConstantLR scheduler for the input optimizer.
+ ///
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable