From f9ea88e5e8510959189d7348226018626fd10ff8 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Sat, 1 Nov 2025 21:42:19 +0000 Subject: [PATCH 01/16] Added Cdf and InverseCdf classes for probability distributions --- src/Bonsai.ML.Torch/Distributions/Cdf.cs | 101 ++++++++++++++++++ .../Distributions/InverseCdf.cs | 100 +++++++++++++++++ 2 files changed, 201 insertions(+) create mode 100644 src/Bonsai.ML.Torch/Distributions/Cdf.cs create mode 100644 src/Bonsai.ML.Torch/Distributions/InverseCdf.cs diff --git a/src/Bonsai.ML.Torch/Distributions/Cdf.cs b/src/Bonsai.ML.Torch/Distributions/Cdf.cs new file mode 100644 index 00000000..32678399 --- /dev/null +++ b/src/Bonsai.ML.Torch/Distributions/Cdf.cs @@ -0,0 +1,101 @@ +using Bonsai; +using System; +using System.Reactive.Linq; +using System.ComponentModel; +using static TorchSharp.torch; +using static TorchSharp.torch.distributions; +using System.Xml.Serialization; + +namespace Bonsai.ML.Torch.Distributions; + +/// +/// Creates a cumulative distribution function (CDF) from the input distribution. +/// +[Combinator] +[ResetCombinator] +[Description("Creates a cumulative distribution function (CDF) from the input distribution.")] +[WorkflowElementCategory(ElementCategory.Transform)] +public class Cdf : TensorContainerBase +{ + /// + /// Initializes a new instance of the class. + /// + public Cdf() + { + RegisterTensor( + () => _values, + v => _values = v); + } + + private Tensor _values; + /// + /// The values at which to evaluate the CDF. + /// + [XmlIgnore] + [TypeConverter(typeof(TensorConverter))] + [Description("The values at which to evaluate the CDF.")] + public Tensor Values + { + get => _values; + set => _values = value; + } + + /// + /// The values in XML string format. + /// + [Browsable(false)] + [XmlElement(nameof(Values))] + [EditorBrowsable(EditorBrowsableState.Never)] + public string ValuesXml + { + get => TensorConverter.ConvertToString(_values, Type); + set => _values = TensorConverter.ConvertFromString(value, Type); + } + + /// + /// The input distribution. + /// + [XmlIgnore] + [Description("The input distribution.")] + public Distribution Distribution { get; set; } + + /// + /// Processes the input distribution to compute the CDF at the specified values. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(distribution => distribution.cdf(Values)); + } + + /// + /// Processes the input values to compute the CDF using the specified distribution. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(Distribution.cdf); + } + + /// + /// Processes the input tuples of distribution and values to compute the CDF. + /// + /// + /// + public IObservable Process(IObservable> source) + { + return source.Select((input) => input.Item1.cdf(input.Item2)); + } + + /// + /// Processes the input tuples of values and distribution to compute the CDF. + /// + /// + /// + public IObservable Process(IObservable> source) + { + return source.Select((input) => input.Item2.cdf(input.Item1)); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Distributions/InverseCdf.cs b/src/Bonsai.ML.Torch/Distributions/InverseCdf.cs new file mode 100644 index 00000000..69ca0116 --- /dev/null +++ b/src/Bonsai.ML.Torch/Distributions/InverseCdf.cs @@ -0,0 +1,100 @@ +using Bonsai; +using System; +using System.Reactive.Linq; +using System.ComponentModel; +using static TorchSharp.torch; +using static TorchSharp.torch.distributions; +using System.Xml.Serialization; + +namespace Bonsai.ML.Torch.Distributions; + +/// +/// Creates an inverse cumulative distribution function (inverse CDF) from the input distribution. +/// +[Combinator] +[ResetCombinator] +[Description("Creates an inverse cumulative distribution function (inverse CDF) from the input distribution.")] +[WorkflowElementCategory(ElementCategory.Transform)] +public class InverseCdf : TensorContainerBase +{ + /// + /// Initializes a new instance of the class. + /// + public InverseCdf() + { + RegisterTensor( + () => _values, + v => _values = v); + } + + private Tensor _values; + /// + /// The values at which to evaluate the inverse CDF. + /// + [XmlIgnore] + [TypeConverter(typeof(TensorConverter))] + [Description("The values at which to evaluate the inverse CDF.")] + public Tensor Values + { + get => _values; + set => _values = value; + } + + /// + /// The values in XML string format. + /// + [Browsable(false)] + [XmlElement(nameof(Values))] + [EditorBrowsable(EditorBrowsableState.Never)] + public string ValuesXml + { + get => TensorConverter.ConvertToString(_values, Type); + set => _values = TensorConverter.ConvertFromString(value, Type); + } + + /// + /// The input distribution. + /// + [XmlIgnore] + public Distribution Distribution { get; set; } + + /// + /// Processes the input distribution to compute the inverse CDF at the specified values. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(distribution => distribution.icdf(Values)); + } + + /// + /// Processes the input values to compute the inverse CDF using the specified distribution. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(Distribution.icdf); + } + + /// + /// Processes the input tuples of distribution and values to compute the inverse CDF. + /// + /// + /// + public IObservable Process(IObservable> source) + { + return source.Select((input) => input.Item1.icdf(input.Item2)); + } + + /// + /// Processes the input tuples of values and distribution to compute the inverse CDF. + /// + /// + /// + public IObservable Process(IObservable> source) + { + return source.Select((input) => input.Item2.icdf(input.Item1)); + } +} \ No newline at end of file From 08f73b77ab1194e412c7b5806753986bb2e37527 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Thu, 31 Jul 2025 18:36:01 +0100 Subject: [PATCH 02/16] Added a `Generator` class to create a reproducible RNG with specified seed value --- src/Bonsai.ML.Torch/Generator.cs | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) create mode 100644 src/Bonsai.ML.Torch/Generator.cs diff --git a/src/Bonsai.ML.Torch/Generator.cs b/src/Bonsai.ML.Torch/Generator.cs new file mode 100644 index 00000000..fd6b9727 --- /dev/null +++ b/src/Bonsai.ML.Torch/Generator.cs @@ -0,0 +1,20 @@ +using Bonsai; +using static TorchSharp.torch; +using TorchSharp; +using System; +using System.Reactive.Linq; + +namespace Bonsai.ML.Torch; + +[Combinator] +public class Generator +{ + public Device Device { get; set; } + + public ulong Seed { get; set; } = 0; + + public IObservable Process() + { + return Observable.Return(new torch.Generator(Seed, Device)); + } +} From 09b3488ebc2e4f4fd11ca6de5355e43be752f4ec Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Thu, 31 Jul 2025 18:36:44 +0100 Subject: [PATCH 03/16] Added a collection of probability distributions --- .../Distributions/Bernoulli.cs | 44 +++++++++++++++++ src/Bonsai.ML.Torch/Distributions/Beta.cs | 47 +++++++++++++++++++ src/Bonsai.ML.Torch/Distributions/Binomial.cs | 47 +++++++++++++++++++ .../Distributions/MultivariateNormal.cs | 47 +++++++++++++++++++ src/Bonsai.ML.Torch/Distributions/Poisson.cs | 44 +++++++++++++++++ 5 files changed, 229 insertions(+) create mode 100644 src/Bonsai.ML.Torch/Distributions/Bernoulli.cs create mode 100644 src/Bonsai.ML.Torch/Distributions/Beta.cs create mode 100644 src/Bonsai.ML.Torch/Distributions/Binomial.cs create mode 100644 src/Bonsai.ML.Torch/Distributions/MultivariateNormal.cs create mode 100644 src/Bonsai.ML.Torch/Distributions/Poisson.cs diff --git a/src/Bonsai.ML.Torch/Distributions/Bernoulli.cs b/src/Bonsai.ML.Torch/Distributions/Bernoulli.cs new file mode 100644 index 00000000..4602b8d0 --- /dev/null +++ b/src/Bonsai.ML.Torch/Distributions/Bernoulli.cs @@ -0,0 +1,44 @@ +using static TorchSharp.torch; +using TorchSharp; +using Bonsai; +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using System.Xml.Serialization; + +namespace Bonsai.ML.Torch.Distributions; + +[Combinator] +[Description("")] +[WorkflowElementCategory(ElementCategory.Source)] +public class Bernoulli : IScalarTypeProvider +{ + [XmlIgnore] + [Browsable(false)] + public ScalarType Type => ScalarType.Float32; + + [TypeConverter(typeof(TensorConverter))] + public Tensor Probability { get; set; } + + [XmlIgnore] + public torch.Generator Generator { get; set; } = null; + + public IObservable Process() + { + return Observable.Return(distributions.Bernoulli(Probability, generator: Generator)); + } + + public IObservable Process(IObservable source) + { + return source.Select((generator) => + { + Generator = generator; + return distributions.Bernoulli(Probability, generator: Generator); + }); + } + + public IObservable Process(IObservable source) + { + return source.Select(_ => distributions.Bernoulli(Probability, generator: Generator)); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Distributions/Beta.cs b/src/Bonsai.ML.Torch/Distributions/Beta.cs new file mode 100644 index 00000000..a1868a89 --- /dev/null +++ b/src/Bonsai.ML.Torch/Distributions/Beta.cs @@ -0,0 +1,47 @@ +using static TorchSharp.torch; +using TorchSharp; +using Bonsai; +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using System.Xml.Serialization; + +namespace Bonsai.ML.Torch.Distributions; + +[Combinator] +[Description("")] +[WorkflowElementCategory(ElementCategory.Source)] +public class Beta : IScalarTypeProvider +{ + [XmlIgnore] + [Browsable(false)] + public ScalarType Type => ScalarType.Float32; + + [TypeConverter(typeof(TensorConverter))] + public Tensor Concentration1 { get; set; } + + [TypeConverter(typeof(TensorConverter))] + public Tensor Concentration0 { get; set; } + + [XmlIgnore] + public torch.Generator Generator { get; set; } = null; + + public IObservable Process() + { + return Observable.Return(distributions.Beta(Concentration1, Concentration0, generator: Generator)); + } + + public IObservable Process(IObservable source) + { + return source.Select((generator) => + { + Generator = generator; + return distributions.Beta(Concentration1, Concentration0, generator: Generator); + }); + } + + public IObservable Process(IObservable source) + { + return source.Select(_ => distributions.Beta(Concentration1, Concentration0, generator: Generator)); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Distributions/Binomial.cs b/src/Bonsai.ML.Torch/Distributions/Binomial.cs new file mode 100644 index 00000000..92f5544e --- /dev/null +++ b/src/Bonsai.ML.Torch/Distributions/Binomial.cs @@ -0,0 +1,47 @@ +using static TorchSharp.torch; +using TorchSharp; +using Bonsai; +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using System.Xml.Serialization; + +namespace Bonsai.ML.Torch.Distributions; + +[Combinator] +[Description("")] +[WorkflowElementCategory(ElementCategory.Source)] +public class Binomial : IScalarTypeProvider +{ + [XmlIgnore] + [Browsable(false)] + public ScalarType Type => ScalarType.Float32; + + [TypeConverter(typeof(TensorConverter))] + public Tensor Count { get; set; } + + [TypeConverter(typeof(TensorConverter))] + public Tensor Probabilities { get; set; } + + [XmlIgnore] + public torch.Generator Generator { get; set; } = null; + + public IObservable Process() + { + return Observable.Return(distributions.Binomial(Count, Probabilities, generator: Generator)); + } + + public IObservable Process(IObservable source) + { + return source.Select((generator) => + { + Generator = generator; + return distributions.Binomial(Count, Probabilities, generator: Generator); + }); + } + + public IObservable Process(IObservable source) + { + return source.Select(_ => distributions.Binomial(Count, Probabilities, generator: Generator)); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Distributions/MultivariateNormal.cs b/src/Bonsai.ML.Torch/Distributions/MultivariateNormal.cs new file mode 100644 index 00000000..a828356b --- /dev/null +++ b/src/Bonsai.ML.Torch/Distributions/MultivariateNormal.cs @@ -0,0 +1,47 @@ +using static TorchSharp.torch; +using TorchSharp; +using Bonsai; +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using System.Xml.Serialization; + +namespace Bonsai.ML.Torch.Distributions; + +[Combinator] +[Description("")] +[WorkflowElementCategory(ElementCategory.Source)] +public class MultivariateNormal : IScalarTypeProvider +{ + [XmlIgnore] + [Browsable(false)] + public ScalarType Type => ScalarType.Float32; + + [TypeConverter(typeof(TensorConverter))] + public Tensor Mean { get; set; } + + [TypeConverter(typeof(TensorConverter))] + public Tensor Covariance { get; set; } + + [XmlIgnore] + public torch.Generator Generator { get; set; } = null; + + public IObservable Process() + { + return Observable.Return(distributions.MultivariateNormal(Mean, Covariance, generator: Generator)); + } + + public IObservable Process(IObservable source) + { + return source.Select((generator) => + { + Generator = generator; + return distributions.MultivariateNormal(Mean, Covariance, generator: Generator); + }); + } + + public IObservable Process(IObservable source) + { + return source.Select(_ => distributions.MultivariateNormal(Mean, Covariance, generator: Generator)); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Distributions/Poisson.cs b/src/Bonsai.ML.Torch/Distributions/Poisson.cs new file mode 100644 index 00000000..39dcec81 --- /dev/null +++ b/src/Bonsai.ML.Torch/Distributions/Poisson.cs @@ -0,0 +1,44 @@ +using static TorchSharp.torch; +using TorchSharp; +using Bonsai; +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using System.Xml.Serialization; + +namespace Bonsai.ML.Torch.Distributions; + +[Combinator] +[Description("")] +[WorkflowElementCategory(ElementCategory.Source)] +public class Poisson : IScalarTypeProvider +{ + [XmlIgnore] + [Browsable(false)] + public ScalarType Type => ScalarType.Float32; + + [TypeConverter(typeof(TensorConverter))] + public Tensor Rate { get; set; } + + [XmlIgnore] + public torch.Generator Generator { get; set; } = null; + + public IObservable Process() + { + return Observable.Return(distributions.Poisson(Rate, generator: Generator)); + } + + public IObservable Process(IObservable source) + { + return source.Select((generator) => + { + Generator = generator; + return distributions.Poisson(Rate, generator: Generator); + }); + } + + public IObservable Process(IObservable source) + { + return source.Select(_ => distributions.Poisson(Rate, generator: Generator)); + } +} \ No newline at end of file From b9661655b23fae223c01db06872ae6085eef55c3 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 1 Aug 2025 18:37:51 +0100 Subject: [PATCH 04/16] Added a class for seeding a torch compatible RNG --- src/Bonsai.ML.Torch/Generator.cs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/Bonsai.ML.Torch/Generator.cs b/src/Bonsai.ML.Torch/Generator.cs index fd6b9727..01a051bd 100644 --- a/src/Bonsai.ML.Torch/Generator.cs +++ b/src/Bonsai.ML.Torch/Generator.cs @@ -3,12 +3,14 @@ using TorchSharp; using System; using System.Reactive.Linq; +using System.Xml.Serialization; namespace Bonsai.ML.Torch; [Combinator] public class Generator { + [XmlIgnore] public Device Device { get; set; } public ulong Seed { get; set; } = 0; From f174c2cf246f5a548d8d4f664488219501179f61 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 6 Aug 2025 17:13:42 +0100 Subject: [PATCH 05/16] Began adding a number of useful distributions and functionality to the distributions namespace --- .../Distributions/Bernoulli.cs | 23 ++++-- src/Bonsai.ML.Torch/Distributions/Beta.cs | 28 ++++++- src/Bonsai.ML.Torch/Distributions/Binomial.cs | 28 ++++++- .../Distributions/Categorical.cs | 57 ++++++++++++++ src/Bonsai.ML.Torch/Distributions/Cauchy.cs | 72 ++++++++++++++++++ .../Distributions/Dirichlet.cs | 57 ++++++++++++++ .../Distributions/Exponential.cs | 58 +++++++++++++++ src/Bonsai.ML.Torch/Distributions/Gamma.cs | 74 +++++++++++++++++++ .../Distributions/Geometric.cs | 58 +++++++++++++++ .../Distributions/LogProbability.cs | 33 +++++++++ .../Distributions/MultivariateNormal.cs | 27 +++++++ src/Bonsai.ML.Torch/Distributions/Poisson.cs | 14 ++++ .../Distributions/ReparametrizedSample.cs | 21 ++++++ src/Bonsai.ML.Torch/Distributions/Sample.cs | 21 ++++++ 14 files changed, 564 insertions(+), 7 deletions(-) create mode 100644 src/Bonsai.ML.Torch/Distributions/Categorical.cs create mode 100644 src/Bonsai.ML.Torch/Distributions/Cauchy.cs create mode 100644 src/Bonsai.ML.Torch/Distributions/Dirichlet.cs create mode 100644 src/Bonsai.ML.Torch/Distributions/Exponential.cs create mode 100644 src/Bonsai.ML.Torch/Distributions/Gamma.cs create mode 100644 src/Bonsai.ML.Torch/Distributions/Geometric.cs create mode 100644 src/Bonsai.ML.Torch/Distributions/LogProbability.cs create mode 100644 src/Bonsai.ML.Torch/Distributions/ReparametrizedSample.cs create mode 100644 src/Bonsai.ML.Torch/Distributions/Sample.cs diff --git a/src/Bonsai.ML.Torch/Distributions/Bernoulli.cs b/src/Bonsai.ML.Torch/Distributions/Bernoulli.cs index 4602b8d0..b3186318 100644 --- a/src/Bonsai.ML.Torch/Distributions/Bernoulli.cs +++ b/src/Bonsai.ML.Torch/Distributions/Bernoulli.cs @@ -9,23 +9,36 @@ namespace Bonsai.ML.Torch.Distributions; [Combinator] +[ResetCombinator] [Description("")] [WorkflowElementCategory(ElementCategory.Source)] public class Bernoulli : IScalarTypeProvider { - [XmlIgnore] [Browsable(false)] public ScalarType Type => ScalarType.Float32; + [XmlIgnore] [TypeConverter(typeof(TensorConverter))] - public Tensor Probability { get; set; } + public Tensor Probabilities { get; set; } + + /// + /// The values of the probabilities in XML string format. + /// + [Browsable(false)] + [XmlElement(nameof(Probabilities))] + [EditorBrowsable(EditorBrowsableState.Never)] + public string ProbabilitiesXml + { + get => TensorConverter.ConvertToString(Probabilities, Type); + set => Probabilities = TensorConverter.ConvertFromString(value, Type); + } [XmlIgnore] public torch.Generator Generator { get; set; } = null; public IObservable Process() { - return Observable.Return(distributions.Bernoulli(Probability, generator: Generator)); + return Observable.Return(distributions.Bernoulli(Probabilities, generator: Generator)); } public IObservable Process(IObservable source) @@ -33,12 +46,12 @@ public class Bernoulli : IScalarTypeProvider return source.Select((generator) => { Generator = generator; - return distributions.Bernoulli(Probability, generator: Generator); + return distributions.Bernoulli(Probabilities, generator: Generator); }); } public IObservable Process(IObservable source) { - return source.Select(_ => distributions.Bernoulli(Probability, generator: Generator)); + return source.Select(_ => distributions.Bernoulli(Probabilities, generator: Generator)); } } \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Distributions/Beta.cs b/src/Bonsai.ML.Torch/Distributions/Beta.cs index a1868a89..ba3acaf7 100644 --- a/src/Bonsai.ML.Torch/Distributions/Beta.cs +++ b/src/Bonsai.ML.Torch/Distributions/Beta.cs @@ -9,20 +9,46 @@ namespace Bonsai.ML.Torch.Distributions; [Combinator] +[ResetCombinator] [Description("")] [WorkflowElementCategory(ElementCategory.Source)] public class Beta : IScalarTypeProvider { - [XmlIgnore] [Browsable(false)] public ScalarType Type => ScalarType.Float32; + [XmlIgnore] [TypeConverter(typeof(TensorConverter))] public Tensor Concentration1 { get; set; } + /// + /// The values of concentration 1 in XML string format. + /// + [Browsable(false)] + [XmlElement(nameof(Concentration1))] + [EditorBrowsable(EditorBrowsableState.Never)] + public string Concentration1Xml + { + get => TensorConverter.ConvertToString(Concentration1, Type); + set => Concentration1 = TensorConverter.ConvertFromString(value, Type); + } + + [XmlIgnore] [TypeConverter(typeof(TensorConverter))] public Tensor Concentration0 { get; set; } + /// + /// The values of concentration 0 in XML string format. + /// + [Browsable(false)] + [XmlElement(nameof(Concentration0))] + [EditorBrowsable(EditorBrowsableState.Never)] + public string Concentration0Xml + { + get => TensorConverter.ConvertToString(Concentration0, Type); + set => Concentration0 = TensorConverter.ConvertFromString(value, Type); + } + [XmlIgnore] public torch.Generator Generator { get; set; } = null; diff --git a/src/Bonsai.ML.Torch/Distributions/Binomial.cs b/src/Bonsai.ML.Torch/Distributions/Binomial.cs index 92f5544e..2e7e46d4 100644 --- a/src/Bonsai.ML.Torch/Distributions/Binomial.cs +++ b/src/Bonsai.ML.Torch/Distributions/Binomial.cs @@ -9,20 +9,46 @@ namespace Bonsai.ML.Torch.Distributions; [Combinator] +[ResetCombinator] [Description("")] [WorkflowElementCategory(ElementCategory.Source)] public class Binomial : IScalarTypeProvider { - [XmlIgnore] [Browsable(false)] public ScalarType Type => ScalarType.Float32; + [XmlIgnore] [TypeConverter(typeof(TensorConverter))] public Tensor Count { get; set; } + /// + /// The values of count in XML string format. + /// + [Browsable(false)] + [XmlElement(nameof(Count))] + [EditorBrowsable(EditorBrowsableState.Never)] + public string CountXml + { + get => TensorConverter.ConvertToString(Count, Type); + set => Count = TensorConverter.ConvertFromString(value, Type); + } + + [XmlIgnore] [TypeConverter(typeof(TensorConverter))] public Tensor Probabilities { get; set; } + /// + /// The values of probabilities in XML string format. + /// + [Browsable(false)] + [XmlElement(nameof(Probabilities))] + [EditorBrowsable(EditorBrowsableState.Never)] + public string ProbabilitiesXml + { + get => TensorConverter.ConvertToString(Probabilities, Type); + set => Probabilities = TensorConverter.ConvertFromString(value, Type); + } + [XmlIgnore] public torch.Generator Generator { get; set; } = null; diff --git a/src/Bonsai.ML.Torch/Distributions/Categorical.cs b/src/Bonsai.ML.Torch/Distributions/Categorical.cs new file mode 100644 index 00000000..19a4d438 --- /dev/null +++ b/src/Bonsai.ML.Torch/Distributions/Categorical.cs @@ -0,0 +1,57 @@ +using static TorchSharp.torch; +using TorchSharp; +using Bonsai; +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using System.Xml.Serialization; + +namespace Bonsai.ML.Torch.Distributions; + +[Combinator] +[ResetCombinator] +[Description("")] +[WorkflowElementCategory(ElementCategory.Source)] +public class Categorical : IScalarTypeProvider +{ + [Browsable(false)] + public ScalarType Type => ScalarType.Float32; + + [XmlIgnore] + [TypeConverter(typeof(TensorConverter))] + public Tensor Probabilities { get; set; } + + /// + /// The values of probabilities in XML string format. + /// + [Browsable(false)] + [XmlElement(nameof(Probabilities))] + [EditorBrowsable(EditorBrowsableState.Never)] + public string ProbabilitiesXml + { + get => TensorConverter.ConvertToString(Probabilities, Type); + set => Probabilities = TensorConverter.ConvertFromString(value, Type); + } + + [XmlIgnore] + public torch.Generator Generator { get; set; } = null; + + public IObservable Process() + { + return Observable.Return(distributions.Categorical(Probabilities, generator: Generator)); + } + + public IObservable Process(IObservable source) + { + return source.Select((generator) => + { + Generator = generator; + return distributions.Categorical(Probabilities, generator: Generator); + }); + } + + public IObservable Process(IObservable source) + { + return source.Select(_ => distributions.Categorical(Probabilities, generator: Generator)); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Distributions/Cauchy.cs b/src/Bonsai.ML.Torch/Distributions/Cauchy.cs new file mode 100644 index 00000000..336a0586 --- /dev/null +++ b/src/Bonsai.ML.Torch/Distributions/Cauchy.cs @@ -0,0 +1,72 @@ +using static TorchSharp.torch; +using TorchSharp; +using Bonsai; +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using System.Xml.Serialization; + +namespace Bonsai.ML.Torch.Distributions; + +[Combinator] +[ResetCombinator] +[Description("")] +[WorkflowElementCategory(ElementCategory.Source)] +public class Cauchy : IScalarTypeProvider +{ + [Browsable(false)] + public ScalarType Type => ScalarType.Float32; + + [XmlIgnore] + [TypeConverter(typeof(TensorConverter))] + public Tensor Locations { get; set; } + + /// + /// The values of the locations in XML string format. + /// + [Browsable(false)] + [XmlElement(nameof(Locations))] + [EditorBrowsable(EditorBrowsableState.Never)] + public string LocationsXml + { + get => TensorConverter.ConvertToString(Locations, Type); + set => Locations = TensorConverter.ConvertFromString(value, Type); + } + + [TypeConverter(typeof(TensorConverter))] + public Tensor Scales { get; set; } + + /// + /// The values of the scales in XML string format. + /// + [Browsable(false)] + [XmlElement(nameof(Scales))] + [EditorBrowsable(EditorBrowsableState.Never)] + public string ScalesXml + { + get => TensorConverter.ConvertToString(Scales, Type); + set => Scales = TensorConverter.ConvertFromString(value, Type); + } + + [XmlIgnore] + public torch.Generator Generator { get; set; } = null; + + public IObservable Process() + { + return Observable.Return(distributions.Cauchy(Locations, Scales, generator: Generator)); + } + + public IObservable Process(IObservable source) + { + return source.Select((generator) => + { + Generator = generator; + return distributions.Cauchy(Locations, Scales, generator: Generator); + }); + } + + public IObservable Process(IObservable source) + { + return source.Select(_ => distributions.Cauchy(Locations, Scales, generator: Generator)); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Distributions/Dirichlet.cs b/src/Bonsai.ML.Torch/Distributions/Dirichlet.cs new file mode 100644 index 00000000..0a30fcda --- /dev/null +++ b/src/Bonsai.ML.Torch/Distributions/Dirichlet.cs @@ -0,0 +1,57 @@ +using static TorchSharp.torch; +using TorchSharp; +using Bonsai; +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using System.Xml.Serialization; + +namespace Bonsai.ML.Torch.Distributions; + +[Combinator] +[ResetCombinator] +[Description("")] +[WorkflowElementCategory(ElementCategory.Source)] +public class Dirichlet : IScalarTypeProvider +{ + [XmlIgnore] + [Browsable(false)] + public ScalarType Type => ScalarType.Float32; + + [TypeConverter(typeof(TensorConverter))] + public Tensor Concentration { get; set; } + + /// + /// The values of the concentration in XML string format. + /// + [Browsable(false)] + [XmlElement(nameof(Concentration))] + [EditorBrowsable(EditorBrowsableState.Never)] + public string ConcentrationXml + { + get => TensorConverter.ConvertToString(Concentration, Type); + set => Concentration = TensorConverter.ConvertFromString(value, Type); + } + + [XmlIgnore] + public torch.Generator Generator { get; set; } = null; + + public IObservable Process() + { + return Observable.Return(distributions.Dirichlet(Concentration, generator: Generator)); + } + + public IObservable Process(IObservable source) + { + return source.Select((generator) => + { + Generator = generator; + return distributions.Dirichlet(Concentration, generator: Generator); + }); + } + + public IObservable Process(IObservable source) + { + return source.Select(_ => distributions.Dirichlet(Concentration, generator: Generator)); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Distributions/Exponential.cs b/src/Bonsai.ML.Torch/Distributions/Exponential.cs new file mode 100644 index 00000000..4ecea185 --- /dev/null +++ b/src/Bonsai.ML.Torch/Distributions/Exponential.cs @@ -0,0 +1,58 @@ +using static TorchSharp.torch; +using TorchSharp; +using Bonsai; +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using System.Xml.Serialization; + +namespace Bonsai.ML.Torch.Distributions; + +[Combinator] +[ResetCombinator] +[Description("")] +[WorkflowElementCategory(ElementCategory.Source)] +public class Exponential : IScalarTypeProvider +{ + [XmlIgnore] + [Browsable(false)] + public ScalarType Type => ScalarType.Float32; + + [XmlIgnore] + [TypeConverter(typeof(TensorConverter))] + public Tensor Rate { get; set; } + + /// + /// The values of the rates in XML string format. + /// + [Browsable(false)] + [XmlElement(nameof(Rate))] + [EditorBrowsable(EditorBrowsableState.Never)] + public string RateXml + { + get => TensorConverter.ConvertToString(Rate, Type); + set => Rate = TensorConverter.ConvertFromString(value, Type); + } + + [XmlIgnore] + public torch.Generator Generator { get; set; } = null; + + public IObservable Process() + { + return Observable.Return(distributions.Exponential(Rate, generator: Generator)); + } + + public IObservable Process(IObservable source) + { + return source.Select((generator) => + { + Generator = generator; + return distributions.Exponential(Rate, generator: Generator); + }); + } + + public IObservable Process(IObservable source) + { + return source.Select(_ => distributions.Exponential(Rate, generator: Generator)); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Distributions/Gamma.cs b/src/Bonsai.ML.Torch/Distributions/Gamma.cs new file mode 100644 index 00000000..adcb8baf --- /dev/null +++ b/src/Bonsai.ML.Torch/Distributions/Gamma.cs @@ -0,0 +1,74 @@ +using static TorchSharp.torch; +using TorchSharp; +using Bonsai; +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using System.Xml.Serialization; + +namespace Bonsai.ML.Torch.Distributions; + +[Combinator] +[ResetCombinator] +[Description("")] +[WorkflowElementCategory(ElementCategory.Source)] +public class Gamma : IScalarTypeProvider +{ + [XmlIgnore] + [Browsable(false)] + public ScalarType Type => ScalarType.Float32; + + [XmlIgnore] + [TypeConverter(typeof(TensorConverter))] + public Tensor Concentration { get; set; } + + /// + /// The values of the concentration in XML string format. + /// + [Browsable(false)] + [XmlElement(nameof(Concentration))] + [EditorBrowsable(EditorBrowsableState.Never)] + public string ConcentrationXml + { + get => TensorConverter.ConvertToString(Concentration, Type); + set => Concentration = TensorConverter.ConvertFromString(value, Type); + } + + [XmlIgnore] + [TypeConverter(typeof(TensorConverter))] + public Tensor Rate { get; set; } + + /// + /// The values of the rate in XML string format. + /// + [Browsable(false)] + [XmlElement(nameof(Rate))] + [EditorBrowsable(EditorBrowsableState.Never)] + public string RateXml + { + get => TensorConverter.ConvertToString(Rate, Type); + set => Rate = TensorConverter.ConvertFromString(value, Type); + } + + [XmlIgnore] + public torch.Generator Generator { get; set; } = null; + + public IObservable Process() + { + return Observable.Return(distributions.Gamma(Concentration, Rate, generator: Generator)); + } + + public IObservable Process(IObservable source) + { + return source.Select((generator) => + { + Generator = generator; + return distributions.Gamma(Concentration, Rate, generator: Generator); + }); + } + + public IObservable Process(IObservable source) + { + return source.Select(_ => distributions.Gamma(Concentration, Rate, generator: Generator)); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Distributions/Geometric.cs b/src/Bonsai.ML.Torch/Distributions/Geometric.cs new file mode 100644 index 00000000..876211c1 --- /dev/null +++ b/src/Bonsai.ML.Torch/Distributions/Geometric.cs @@ -0,0 +1,58 @@ +using static TorchSharp.torch; +using TorchSharp; +using Bonsai; +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using System.Xml.Serialization; + +namespace Bonsai.ML.Torch.Distributions; + +[Combinator] +[ResetCombinator] +[Description("")] +[WorkflowElementCategory(ElementCategory.Source)] +public class Geometric : IScalarTypeProvider +{ + [XmlIgnore] + [Browsable(false)] + public ScalarType Type => ScalarType.Float32; + + [XmlIgnore] + [TypeConverter(typeof(TensorConverter))] + public Tensor Probabilities { get; set; } + + /// + /// The values of the probabilities in XML string format. + /// + [Browsable(false)] + [XmlElement(nameof(Probabilities))] + [EditorBrowsable(EditorBrowsableState.Never)] + public string ProbabilitiesXml + { + get => TensorConverter.ConvertToString(Probabilities, Type); + set => Probabilities = TensorConverter.ConvertFromString(value, Type); + } + + [XmlIgnore] + public torch.Generator Generator { get; set; } = null; + + public IObservable Process() + { + return Observable.Return(distributions.Geometric(Probabilities, generator: Generator)); + } + + public IObservable Process(IObservable source) + { + return source.Select((generator) => + { + Generator = generator; + return distributions.Geometric(Probabilities, generator: Generator); + }); + } + + public IObservable Process(IObservable source) + { + return source.Select(_ => distributions.Geometric(Probabilities, generator: Generator)); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Distributions/LogProbability.cs b/src/Bonsai.ML.Torch/Distributions/LogProbability.cs new file mode 100644 index 00000000..0b08fc72 --- /dev/null +++ b/src/Bonsai.ML.Torch/Distributions/LogProbability.cs @@ -0,0 +1,33 @@ +using Bonsai; +using System; +using System.Reactive.Linq; +using System.ComponentModel; +using static TorchSharp.torch; +using System.Xml.Serialization; + +namespace Bonsai.ML.Torch.Distributions; + +[Combinator] +[ResetCombinator] +[Description("")] +[WorkflowElementCategory(ElementCategory.Transform)] +public class LogProbability +{ + [XmlIgnore] + public Tensor Values { get; set; } + + public IObservable Process(IObservable source) + { + return source.Select(distribution => distribution.log_prob(Values)); + } + + public IObservable Process(IObservable> source) + { + return source.Select((input) => input.Item1.log_prob(input.Item2)); + } + + public IObservable Process(IObservable> source) + { + return source.Select((input) => input.Item2.log_prob(input.Item1)); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Distributions/MultivariateNormal.cs b/src/Bonsai.ML.Torch/Distributions/MultivariateNormal.cs index a828356b..83dac869 100644 --- a/src/Bonsai.ML.Torch/Distributions/MultivariateNormal.cs +++ b/src/Bonsai.ML.Torch/Distributions/MultivariateNormal.cs @@ -9,6 +9,7 @@ namespace Bonsai.ML.Torch.Distributions; [Combinator] +[ResetCombinator] [Description("")] [WorkflowElementCategory(ElementCategory.Source)] public class MultivariateNormal : IScalarTypeProvider @@ -17,12 +18,38 @@ public class MultivariateNormal : IScalarTypeProvider [Browsable(false)] public ScalarType Type => ScalarType.Float32; + [XmlIgnore] [TypeConverter(typeof(TensorConverter))] public Tensor Mean { get; set; } + /// + /// The values of the means in XML string format. + /// + [Browsable(false)] + [XmlElement(nameof(Mean))] + [EditorBrowsable(EditorBrowsableState.Never)] + public string MeanXml + { + get => TensorConverter.ConvertToString(Mean, Type); + set => Mean = TensorConverter.ConvertFromString(value, Type); + } + + [XmlIgnore] [TypeConverter(typeof(TensorConverter))] public Tensor Covariance { get; set; } + /// + /// The values of the covariance matrix in XML string format. + /// + [Browsable(false)] + [XmlElement(nameof(Covariance))] + [EditorBrowsable(EditorBrowsableState.Never)] + public string CovarianceXml + { + get => TensorConverter.ConvertToString(Covariance, Type); + set => Covariance = TensorConverter.ConvertFromString(value, Type); + } + [XmlIgnore] public torch.Generator Generator { get; set; } = null; diff --git a/src/Bonsai.ML.Torch/Distributions/Poisson.cs b/src/Bonsai.ML.Torch/Distributions/Poisson.cs index 39dcec81..57302905 100644 --- a/src/Bonsai.ML.Torch/Distributions/Poisson.cs +++ b/src/Bonsai.ML.Torch/Distributions/Poisson.cs @@ -9,6 +9,7 @@ namespace Bonsai.ML.Torch.Distributions; [Combinator] +[ResetCombinator] [Description("")] [WorkflowElementCategory(ElementCategory.Source)] public class Poisson : IScalarTypeProvider @@ -17,9 +18,22 @@ public class Poisson : IScalarTypeProvider [Browsable(false)] public ScalarType Type => ScalarType.Float32; + [XmlIgnore] [TypeConverter(typeof(TensorConverter))] public Tensor Rate { get; set; } + /// + /// The values of the rates in XML string format. + /// + [Browsable(false)] + [XmlElement(nameof(Rate))] + [EditorBrowsable(EditorBrowsableState.Never)] + public string RateXml + { + get => TensorConverter.ConvertToString(Rate, Type); + set => Rate = TensorConverter.ConvertFromString(value, Type); + } + [XmlIgnore] public torch.Generator Generator { get; set; } = null; diff --git a/src/Bonsai.ML.Torch/Distributions/ReparametrizedSample.cs b/src/Bonsai.ML.Torch/Distributions/ReparametrizedSample.cs new file mode 100644 index 00000000..b3e14972 --- /dev/null +++ b/src/Bonsai.ML.Torch/Distributions/ReparametrizedSample.cs @@ -0,0 +1,21 @@ +using Bonsai; +using System; +using System.Reactive.Linq; +using System.ComponentModel; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.Distributions; + +[Combinator] +[Description("")] +[WorkflowElementCategory(ElementCategory.Transform)] +public class ReparametrizedSample +{ + [TypeConverter(typeof(UnidimensionalArrayConverter))] + public long[] SampleShape { get; set; } + + public IObservable Process(IObservable source) + { + return source.Select(distribution => distribution.rsample(SampleShape)); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Distributions/Sample.cs b/src/Bonsai.ML.Torch/Distributions/Sample.cs new file mode 100644 index 00000000..dee288f1 --- /dev/null +++ b/src/Bonsai.ML.Torch/Distributions/Sample.cs @@ -0,0 +1,21 @@ +using Bonsai; +using System; +using System.Reactive.Linq; +using System.ComponentModel; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.Distributions; + +[Combinator] +[Description("")] +[WorkflowElementCategory(ElementCategory.Transform)] +public class Sample +{ + [TypeConverter(typeof(UnidimensionalArrayConverter))] + public long[] SampleShape { get; set; } + + public IObservable Process(IObservable source) + { + return source.Select(distribution => distribution.sample(SampleShape)); + } +} \ No newline at end of file From a16e6eda68febbec16d1d67da6d58613be26d3f2 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 8 Oct 2025 18:28:18 +0100 Subject: [PATCH 06/16] Created a dedicated namespace for working with random numbers and added seed generator, as well as other random number factory methods --- src/Bonsai.ML.Torch/Generator.cs | 22 ----- src/Bonsai.ML.Torch/Random/Generator.cs | 38 ++++++++ src/Bonsai.ML.Torch/Random/RandomIntegers.cs | 86 +++++++++++++++++++ src/Bonsai.ML.Torch/Random/RandomNormal.cs | 71 +++++++++++++++ .../Random/RandomPermutation.cs | 60 +++++++++++++ src/Bonsai.ML.Torch/Random/RandomUniform.cs | 72 ++++++++++++++++ 6 files changed, 327 insertions(+), 22 deletions(-) delete mode 100644 src/Bonsai.ML.Torch/Generator.cs create mode 100644 src/Bonsai.ML.Torch/Random/Generator.cs create mode 100644 src/Bonsai.ML.Torch/Random/RandomIntegers.cs create mode 100644 src/Bonsai.ML.Torch/Random/RandomNormal.cs create mode 100644 src/Bonsai.ML.Torch/Random/RandomPermutation.cs create mode 100644 src/Bonsai.ML.Torch/Random/RandomUniform.cs diff --git a/src/Bonsai.ML.Torch/Generator.cs b/src/Bonsai.ML.Torch/Generator.cs deleted file mode 100644 index 01a051bd..00000000 --- a/src/Bonsai.ML.Torch/Generator.cs +++ /dev/null @@ -1,22 +0,0 @@ -using Bonsai; -using static TorchSharp.torch; -using TorchSharp; -using System; -using System.Reactive.Linq; -using System.Xml.Serialization; - -namespace Bonsai.ML.Torch; - -[Combinator] -public class Generator -{ - [XmlIgnore] - public Device Device { get; set; } - - public ulong Seed { get; set; } = 0; - - public IObservable Process() - { - return Observable.Return(new torch.Generator(Seed, Device)); - } -} diff --git a/src/Bonsai.ML.Torch/Random/Generator.cs b/src/Bonsai.ML.Torch/Random/Generator.cs new file mode 100644 index 00000000..444c36cb --- /dev/null +++ b/src/Bonsai.ML.Torch/Random/Generator.cs @@ -0,0 +1,38 @@ +using Bonsai; +using static TorchSharp.torch; +using TorchSharp; +using System; +using System.Reactive.Linq; +using System.Xml.Serialization; +using System.ComponentModel; + +namespace Bonsai.ML.Torch.Random; + +/// +/// Creates a random number generator with the specified seed and device. +/// +[Combinator] +[Description("Creates a random number generator with the specified seed and device.")] +[WorkflowElementCategory(ElementCategory.Source)] +public class Generator +{ + /// + /// The device on which to create the generator. + /// + [XmlIgnore] + public Device Device { get; set; } + + /// + /// The seed for the random number generator. + /// + public ulong Seed { get; set; } = 0; + + /// + /// Creates a random number generator with the specified seed and device. + /// + /// + public IObservable Process() + { + return Observable.Return(new torch.Generator(Seed, Device)); + } +} diff --git a/src/Bonsai.ML.Torch/Random/RandomIntegers.cs b/src/Bonsai.ML.Torch/Random/RandomIntegers.cs new file mode 100644 index 00000000..d4ee04c8 --- /dev/null +++ b/src/Bonsai.ML.Torch/Random/RandomIntegers.cs @@ -0,0 +1,86 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using System.Xml.Serialization; +using TorchSharp; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.Random; + +/// +/// Creates a tensor filled with random integers sampled from a uniform distribution over the +/// interval [MinValue, MaxValue). +/// +[Combinator] +[ResetCombinator] +[Description("Creates a tensor filled with random integers.")] +[WorkflowElementCategory(ElementCategory.Source)] +public class RandomIntegers +{ + /// + /// The size of the tensor. + /// + [Description("The size of the tensor.")] + [TypeConverter(typeof(UnidimensionalArrayConverter))] + public long[] Size { get; set; } = new long[0]; + + /// + /// The minimum value of the random integers inclusive. + /// + [Description("The minimum value of the random integers.")] + public int MinValue { get; set; } = 0; + + /// + /// The maximum value of the random integers exclusive. + /// + [Description("The maximum value of the random integers.")] + public int MaxValue { get; set; } = 100; + + /// + /// The data type of the tensor elements. + /// + [Description("The data type of the tensor elements.")] + public ScalarType? Type { get; set; } = null; + + /// + /// The device on which to create the tensor. + /// + [Description("The device on which to create the tensor.")] + [XmlIgnore] + public Device Device { get; set; } = null; + + /// + /// The random number generator to use. + /// + public torch.Generator Generator { get; set; } = null; + + /// + /// Creates a tensor filled with random integers sampled from a uniform distribution over the + /// interval [MinValue, MaxValue). + /// + public IObservable Process() + { + return Observable.Return(randint(MinValue, MaxValue, Size, dtype: Type, device: Device, generator: Generator)); + } + + + /// + /// Generates an observable sequence of tensors filled with random integers for each element of the input sequence. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => randint_like(value, MinValue, MaxValue, dtype: Type, device: Device)); + } + + /// + /// Generates an observable sequence of tensors filled with random integers for each element of the input sequence. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => randint(MinValue, MaxValue, Size, dtype: Type, device: Device, generator: Generator)); + } +} diff --git a/src/Bonsai.ML.Torch/Random/RandomNormal.cs b/src/Bonsai.ML.Torch/Random/RandomNormal.cs new file mode 100644 index 00000000..a7c45691 --- /dev/null +++ b/src/Bonsai.ML.Torch/Random/RandomNormal.cs @@ -0,0 +1,71 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using System.Xml.Serialization; +using TorchSharp; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.Random; + +/// +/// Creates a tensor filled with random floats sampled from a normal distribution with mean 0 and variance 1. +/// +[Combinator] +[ResetCombinator] +[Description("Creates a tensor filled with random floats.")] +[WorkflowElementCategory(ElementCategory.Source)] +public class RandomFloats +{ + /// + /// The size of the tensor. + /// + [Description("The size of the tensor.")] + [TypeConverter(typeof(UnidimensionalArrayConverter))] + public long[] Size { get; set; } = new long[0]; + + /// + /// The data type of the tensor elements. + /// + [Description("The data type of the tensor elements.")] + public ScalarType? Type { get; set; } = null; + + /// + /// The device on which to create the tensor. + /// + [Description("The device on which to create the tensor.")] + [XmlIgnore] + public Device Device { get; set; } = null; + + /// + /// The random number generator to use. + /// + public torch.Generator Generator { get; set; } = null; + + /// + /// Creates a tensor filled with random values sampled from a normal distribution with mean 0 and variance 1. + /// + public IObservable Process() + { + return Observable.Return(randn(Size, dtype: Type, device: Device, generator: Generator)); + } + + /// + /// Generates an observable sequence of tensors filled with random values for each element of the input sequence. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => randn_like(value, dtype: Type, device: Device)); + } + + /// + /// Generates an observable sequence of tensors filled with random values for each element of the input sequence. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => randn(Size, dtype: Type, device: Device, generator: Generator)); + } +} diff --git a/src/Bonsai.ML.Torch/Random/RandomPermutation.cs b/src/Bonsai.ML.Torch/Random/RandomPermutation.cs new file mode 100644 index 00000000..cef31b8d --- /dev/null +++ b/src/Bonsai.ML.Torch/Random/RandomPermutation.cs @@ -0,0 +1,60 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using System.Xml.Serialization; +using TorchSharp; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.Random; + +/// +/// Creates a 1D tensor of a given size with a random permutation of integers from 0 to size - 1. +/// +[Combinator] +[ResetCombinator] +[Description("Creates a 1D tensor of a given size with a random permutation of integers from 0 to size - 1.")] +[WorkflowElementCategory(ElementCategory.Source)] +public class RandomPermutation +{ + /// + /// The size of the tensor. + /// + [Description("The size of the tensor.")] + public long Size { get; set; } = 0; + + /// + /// The data type of the tensor elements. + /// + [Description("The data type of the tensor elements.")] + public ScalarType? Type { get; set; } = null; + + /// + /// The device on which to create the tensor. + /// + [Description("The device on which to create the tensor.")] + [XmlIgnore] + public Device Device { get; set; } = null; + + /// + /// The random number generator to use. + /// + public torch.Generator Generator { get; set; } = null; + + /// + /// Creates a tensor of a given size with a random permutation of integers from 0 to size - 1. + /// + public IObservable Process() + { + return Observable.Return(randperm(Size, dtype: Type, device: Device, generator: Generator)); + } + + /// + /// Generates an observable sequence of tensors filled with random values for each element of the input sequence. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => randperm(Size, dtype: Type, device: Device, generator: Generator)); + } +} diff --git a/src/Bonsai.ML.Torch/Random/RandomUniform.cs b/src/Bonsai.ML.Torch/Random/RandomUniform.cs new file mode 100644 index 00000000..865fc091 --- /dev/null +++ b/src/Bonsai.ML.Torch/Random/RandomUniform.cs @@ -0,0 +1,72 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using System.Xml.Serialization; +using TorchSharp; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.Random; + +/// +/// Creates a tensor filled with random values sampled from a uniform distribution over the interval [0, 1). +/// +[Combinator] +[ResetCombinator] +[Description("Creates a tensor filled with random floats.")] +[WorkflowElementCategory(ElementCategory.Source)] +public class RandomUniform +{ + /// + /// The size of the tensor. + /// + [Description("The size of the tensor.")] + [TypeConverter(typeof(UnidimensionalArrayConverter))] + public long[] Size { get; set; } = new long[0]; + + /// + /// The data type of the tensor elements. + /// + [Description("The data type of the tensor elements.")] + public ScalarType? Type { get; set; } = null; + + /// + /// The device on which to create the tensor. + /// + [Description("The device on which to create the tensor.")] + [XmlIgnore] + public Device Device { get; set; } = null; + + /// + /// The random number generator to use. + /// + public torch.Generator Generator { get; set; } = null; + + /// + /// Creates a tensor filled with random values sampled from a uniform distribution over the interval [0, 1). + /// + public IObservable Process() + { + return Observable.Return(rand(Size, dtype: Type, device: Device, generator: Generator)); + } + + /// + /// Generates an observable sequence of tensors filled with random values for each element of the input sequence. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => rand_like(value, dtype: Type, device: Device)); + } + + /// + /// Generates an observable sequence of tensors filled with random values for each element of the input sequence. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => rand(Size, dtype: Type, device: Device, generator: Generator)); + } + +} From d106ff67e3047185ff9f8c8fc90b1396515adb21 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 8 Oct 2025 18:36:07 +0100 Subject: [PATCH 07/16] Added overload to generator to support emitting generators based on elements arriving from input sequence --- src/Bonsai.ML.Torch/Random/Generator.cs | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/Bonsai.ML.Torch/Random/Generator.cs b/src/Bonsai.ML.Torch/Random/Generator.cs index 444c36cb..1a86f047 100644 --- a/src/Bonsai.ML.Torch/Random/Generator.cs +++ b/src/Bonsai.ML.Torch/Random/Generator.cs @@ -35,4 +35,12 @@ public class Generator { return Observable.Return(new torch.Generator(Seed, Device)); } + + /// + /// Generates an observable sequence of random number generators for each element of the input sequence. + /// + public IObservable Process(IObservable source) + { + return source.Select(value => new torch.Generator(Seed, Device)); + } } From 4227dfc413e44046cc299419a9095306280fcfd1 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 8 Oct 2025 18:37:06 +0100 Subject: [PATCH 08/16] Add support for using generator from input sequence and ignore in XML serialization --- src/Bonsai.ML.Torch/Random/RandomIntegers.cs | 14 ++++++++++++++ src/Bonsai.ML.Torch/Random/RandomNormal.cs | 15 +++++++++++++++ src/Bonsai.ML.Torch/Random/RandomPermutation.cs | 15 +++++++++++++++ src/Bonsai.ML.Torch/Random/RandomUniform.cs | 15 +++++++++++++++ 4 files changed, 59 insertions(+) diff --git a/src/Bonsai.ML.Torch/Random/RandomIntegers.cs b/src/Bonsai.ML.Torch/Random/RandomIntegers.cs index d4ee04c8..6bf36da9 100644 --- a/src/Bonsai.ML.Torch/Random/RandomIntegers.cs +++ b/src/Bonsai.ML.Torch/Random/RandomIntegers.cs @@ -52,6 +52,7 @@ public class RandomIntegers /// /// The random number generator to use. /// + [XmlIgnore] public torch.Generator Generator { get; set; } = null; /// @@ -63,6 +64,19 @@ public IObservable Process() return Observable.Return(randint(MinValue, MaxValue, Size, dtype: Type, device: Device, generator: Generator)); } + /// + /// Generates an observable sequence of tensors filled with random integers and uses the input generator. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => + { + Generator = value; + return randint(MinValue, MaxValue, Size, dtype: Type, device: Device, generator: Generator); + }); + } /// /// Generates an observable sequence of tensors filled with random integers for each element of the input sequence. diff --git a/src/Bonsai.ML.Torch/Random/RandomNormal.cs b/src/Bonsai.ML.Torch/Random/RandomNormal.cs index a7c45691..d26f5ee3 100644 --- a/src/Bonsai.ML.Torch/Random/RandomNormal.cs +++ b/src/Bonsai.ML.Torch/Random/RandomNormal.cs @@ -39,6 +39,7 @@ public class RandomFloats /// /// The random number generator to use. /// + [XmlIgnore] public torch.Generator Generator { get; set; } = null; /// @@ -49,6 +50,20 @@ public IObservable Process() return Observable.Return(randn(Size, dtype: Type, device: Device, generator: Generator)); } + /// + /// Generates an observable sequence of tensors filled with random values and uses the input generator. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => + { + Generator = value; + return randn(Size, dtype: Type, device: Device, generator: Generator); + }); + } + /// /// Generates an observable sequence of tensors filled with random values for each element of the input sequence. /// diff --git a/src/Bonsai.ML.Torch/Random/RandomPermutation.cs b/src/Bonsai.ML.Torch/Random/RandomPermutation.cs index cef31b8d..3b097722 100644 --- a/src/Bonsai.ML.Torch/Random/RandomPermutation.cs +++ b/src/Bonsai.ML.Torch/Random/RandomPermutation.cs @@ -38,6 +38,7 @@ public class RandomPermutation /// /// The random number generator to use. /// + [XmlIgnore] public torch.Generator Generator { get; set; } = null; /// @@ -48,6 +49,20 @@ public IObservable Process() return Observable.Return(randperm(Size, dtype: Type, device: Device, generator: Generator)); } + /// + /// Generates an observable sequence of tensors filled with random values and uses the input generator. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => + { + Generator = value; + return randperm(Size, dtype: Type, device: Device, generator: Generator); + }); + } + /// /// Generates an observable sequence of tensors filled with random values for each element of the input sequence. /// diff --git a/src/Bonsai.ML.Torch/Random/RandomUniform.cs b/src/Bonsai.ML.Torch/Random/RandomUniform.cs index 865fc091..9eb1e83e 100644 --- a/src/Bonsai.ML.Torch/Random/RandomUniform.cs +++ b/src/Bonsai.ML.Torch/Random/RandomUniform.cs @@ -39,6 +39,7 @@ public class RandomUniform /// /// The random number generator to use. /// + [XmlIgnore] public torch.Generator Generator { get; set; } = null; /// @@ -49,6 +50,20 @@ public IObservable Process() return Observable.Return(rand(Size, dtype: Type, device: Device, generator: Generator)); } + /// + /// Generates an observable sequence of tensors filled with random values and uses the input generator. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => + { + Generator = value; + return rand(Size, dtype: Type, device: Device, generator: Generator); + }); + } + /// /// Generates an observable sequence of tensors filled with random values for each element of the input sequence. /// From 832fe72b093254fedc3e90ef0f64ed48bd8a334f Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Sat, 1 Nov 2025 21:40:44 +0000 Subject: [PATCH 09/16] Updated `Distributions` with doc strings and container base --- .../Distributions/Bernoulli.cs | 60 ++++++++++++--- src/Bonsai.ML.Torch/Distributions/Beta.cs | 76 +++++++++++++++--- src/Bonsai.ML.Torch/Distributions/Binomial.cs | 71 ++++++++++++++--- .../Distributions/Categorical.cs | 56 ++++++++++++-- src/Bonsai.ML.Torch/Distributions/Cauchy.cs | 76 +++++++++++++++--- .../Distributions/Dirichlet.cs | 58 +++++++++++--- .../Distributions/Exponential.cs | 58 +++++++++++--- src/Bonsai.ML.Torch/Distributions/Gamma.cs | 77 +++++++++++++++---- .../Distributions/Geometric.cs | 58 +++++++++++--- .../Distributions/MultivariateNormal.cs | 77 +++++++++++++++---- src/Bonsai.ML.Torch/Distributions/Poisson.cs | 58 +++++++++++--- 11 files changed, 607 insertions(+), 118 deletions(-) diff --git a/src/Bonsai.ML.Torch/Distributions/Bernoulli.cs b/src/Bonsai.ML.Torch/Distributions/Bernoulli.cs index b3186318..dff6f7ee 100644 --- a/src/Bonsai.ML.Torch/Distributions/Bernoulli.cs +++ b/src/Bonsai.ML.Torch/Distributions/Bernoulli.cs @@ -1,6 +1,5 @@ using static TorchSharp.torch; using TorchSharp; -using Bonsai; using System; using System.ComponentModel; using System.Reactive.Linq; @@ -8,18 +7,38 @@ namespace Bonsai.ML.Torch.Distributions; +/// +/// Creates a Bernoulli probability distribution parameterized by event probabilities. +/// Emits a TorchSharp distribution module that can be sampled or queried for log-probabilities. +/// [Combinator] [ResetCombinator] -[Description("")] +[Description("Creates a Bernoulli distribution with event probabilities and emits a TorchSharp distribution module.")] [WorkflowElementCategory(ElementCategory.Source)] -public class Bernoulli : IScalarTypeProvider +public class Bernoulli : TensorContainerBase { - [Browsable(false)] - public ScalarType Type => ScalarType.Float32; + /// + /// Initializes a new instance of the class. + /// + public Bernoulli() + { + RegisterTensor( + () => _probabilities, + value => _probabilities = value); + } + private Tensor _probabilities; + /// + /// Event probabilities p in [0, 1]. Can be a scalar or a tensor; the shape determines the batch/event shape. + /// [XmlIgnore] [TypeConverter(typeof(TensorConverter))] - public Tensor Probabilities { get; set; } + [Description("Event probabilities p in [0, 1]. Can be a scalar or a tensor; shape sets the batch/event shape of the distribution.")] + public Tensor Probabilities + { + get => _probabilities; + set => _probabilities = value; + } /// /// The values of the probabilities in XML string format. @@ -29,19 +48,32 @@ public class Bernoulli : IScalarTypeProvider [EditorBrowsable(EditorBrowsableState.Never)] public string ProbabilitiesXml { - get => TensorConverter.ConvertToString(Probabilities, Type); - set => Probabilities = TensorConverter.ConvertFromString(value, Type); + get => TensorConverter.ConvertToString(_probabilities, Type); + set => _probabilities = TensorConverter.ConvertFromString(value, Type); } + /// + /// Optional random number generator to use when sampling. If null, TorchSharp's global RNG is used. + /// [XmlIgnore] - public torch.Generator Generator { get; set; } = null; + public Generator Generator { get; set; } = null; + /// + /// Creates a distribution using the configured and optional . + /// + /// An observable that emits the constructed Bernoulli distribution. public IObservable Process() { return Observable.Return(distributions.Bernoulli(Probabilities, generator: Generator)); } - public IObservable Process(IObservable source) + /// + /// Creates a distribution for each incoming RNG , + /// updating and passing it to TorchSharp. + /// + /// Observable sequence of random generators to use. + /// An observable sequence of Bernoulli distributions. + public IObservable Process(IObservable source) { return source.Select((generator) => { @@ -50,6 +82,14 @@ public string ProbabilitiesXml }); } + /// + /// For each element of the source stream, emits a distribution + /// constructed from the configured and current . + /// The source values are ignored and used only for timing. + /// + /// The type of the triggering source sequence. + /// Trigger sequence; each element causes a new distribution to be emitted. + /// An observable sequence of Bernoulli distributions. public IObservable Process(IObservable source) { return source.Select(_ => distributions.Bernoulli(Probabilities, generator: Generator)); diff --git a/src/Bonsai.ML.Torch/Distributions/Beta.cs b/src/Bonsai.ML.Torch/Distributions/Beta.cs index ba3acaf7..a9aeda50 100644 --- a/src/Bonsai.ML.Torch/Distributions/Beta.cs +++ b/src/Bonsai.ML.Torch/Distributions/Beta.cs @@ -8,18 +8,42 @@ namespace Bonsai.ML.Torch.Distributions; +/// +/// Creates a Beta probability distribution parameterized by two concentration parameters (alpha, beta). +/// Emits a TorchSharp distribution module that can be sampled or queried for probabilities. +/// [Combinator] [ResetCombinator] -[Description("")] +[Description("Creates a Beta distribution with concentration parameters (alpha, beta).")] [WorkflowElementCategory(ElementCategory.Source)] -public class Beta : IScalarTypeProvider +public class Beta : TensorContainerBase { - [Browsable(false)] - public ScalarType Type => ScalarType.Float32; + /// + /// Initializes a new instance of the class. + /// + public Beta() + { + RegisterTensor( + () => _concentration1, + value => _concentration1 = value); + + RegisterTensor( + () => _concentration0, + value => _concentration0 = value); + } + private Tensor _concentration1; + /// + /// Concentration parameter alpha (> 0). Can be a scalar or tensor; the shape determines the batch/event shape. + /// [XmlIgnore] [TypeConverter(typeof(TensorConverter))] - public Tensor Concentration1 { get; set; } + [Description("Concentration alpha (> 0). Can be a scalar or tensor; shape sets the batch/event shape of the distribution.")] + public Tensor Concentration1 + { + get => _concentration1; + set => _concentration1 = value; + } /// /// The values of concentration 1 in XML string format. @@ -29,13 +53,22 @@ public class Beta : IScalarTypeProvider [EditorBrowsable(EditorBrowsableState.Never)] public string Concentration1Xml { - get => TensorConverter.ConvertToString(Concentration1, Type); - set => Concentration1 = TensorConverter.ConvertFromString(value, Type); + get => TensorConverter.ConvertToString(_concentration1, Type); + set => _concentration1 = TensorConverter.ConvertFromString(value, Type); } + private Tensor _concentration0; + /// + /// Concentration parameter beta (> 0). Can be a scalar or tensor; the shape determines the batch/event shape. + /// [XmlIgnore] [TypeConverter(typeof(TensorConverter))] - public Tensor Concentration0 { get; set; } + [Description("Concentration beta (> 0). Can be a scalar or tensor; shape sets the batch/event shape of the distribution.")] + public Tensor Concentration0 + { + get => _concentration0; + set => _concentration0 = value; + } /// /// The values of concentration 0 in XML string format. @@ -45,19 +78,31 @@ public string Concentration1Xml [EditorBrowsable(EditorBrowsableState.Never)] public string Concentration0Xml { - get => TensorConverter.ConvertToString(Concentration0, Type); - set => Concentration0 = TensorConverter.ConvertFromString(value, Type); + get => TensorConverter.ConvertToString(_concentration0, Type); + set => _concentration0 = TensorConverter.ConvertFromString(value, Type); } + /// + /// Optional random number generator to use when sampling. If null, TorchSharp's global RNG is used. + /// [XmlIgnore] - public torch.Generator Generator { get; set; } = null; + public Generator Generator { get; set; } = null; + /// + /// Creates a distribution using the configured parameters and optional . + /// + /// An observable that emits the constructed Beta distribution. public IObservable Process() { return Observable.Return(distributions.Beta(Concentration1, Concentration0, generator: Generator)); } - public IObservable Process(IObservable source) + /// + /// Creates a distribution for each incoming RNG . + /// + /// Observable sequence of random generators to use. + /// An observable sequence of Beta distributions. + public IObservable Process(IObservable source) { return source.Select((generator) => { @@ -66,6 +111,13 @@ public string Concentration0Xml }); } + /// + /// For each element of the source stream, emits a distribution + /// constructed from the configured parameters and current . + /// + /// The type of the triggering source sequence. + /// Trigger sequence; each element causes a new distribution to be emitted. + /// An observable sequence of Beta distributions. public IObservable Process(IObservable source) { return source.Select(_ => distributions.Beta(Concentration1, Concentration0, generator: Generator)); diff --git a/src/Bonsai.ML.Torch/Distributions/Binomial.cs b/src/Bonsai.ML.Torch/Distributions/Binomial.cs index 2e7e46d4..13c281bc 100644 --- a/src/Bonsai.ML.Torch/Distributions/Binomial.cs +++ b/src/Bonsai.ML.Torch/Distributions/Binomial.cs @@ -8,18 +8,41 @@ namespace Bonsai.ML.Torch.Distributions; +/// +/// Creates a Binomial probability distribution with a given number of trials and success probability. +/// Emits a TorchSharp distribution module that can be sampled or queried for log-probabilities. +/// [Combinator] [ResetCombinator] -[Description("")] +[Description("Creates a Binomial distribution with count (number of trials) and probability of success p.")] [WorkflowElementCategory(ElementCategory.Source)] -public class Binomial : IScalarTypeProvider +public class Binomial : TensorContainerBase { - [Browsable(false)] - public ScalarType Type => ScalarType.Float32; + /// + /// Initializes a new instance of the class. + /// + public Binomial() + { + RegisterTensor( + () => _count, + value => _count = value); + RegisterTensor( + () => _probabilities, + value => _probabilities = value); + } + private Tensor _count; + /// + /// Number of trials (non-negative). Can be a scalar or tensor. If tensor, values should be non-negative integers. + /// [XmlIgnore] [TypeConverter(typeof(TensorConverter))] - public Tensor Count { get; set; } + [Description("Number of trials (non-negative). Can be a scalar or tensor.")] + public Tensor Count + { + get => _count; + set => _count = value; + } /// /// The values of count in XML string format. @@ -33,9 +56,18 @@ public string CountXml set => Count = TensorConverter.ConvertFromString(value, Type); } + private Tensor _probabilities; + /// + /// Probability of success p in [0, 1]. Can be a scalar or tensor; the shape should be broadcastable to . + /// [XmlIgnore] [TypeConverter(typeof(TensorConverter))] - public Tensor Probabilities { get; set; } + [Description("Probability of success p in [0, 1]. Can be a scalar or tensor; shape should broadcast with Count.")] + public Tensor Probabilities + { + get => _probabilities; + set => _probabilities = value; + } /// /// The values of probabilities in XML string format. @@ -45,19 +77,31 @@ public string CountXml [EditorBrowsable(EditorBrowsableState.Never)] public string ProbabilitiesXml { - get => TensorConverter.ConvertToString(Probabilities, Type); - set => Probabilities = TensorConverter.ConvertFromString(value, Type); + get => TensorConverter.ConvertToString(_probabilities, Type); + set => _probabilities = TensorConverter.ConvertFromString(value, Type); } + /// + /// Optional random number generator to use when sampling. If null, TorchSharp's global RNG is used. + /// [XmlIgnore] - public torch.Generator Generator { get; set; } = null; + public Generator Generator { get; set; } = null; + /// + /// Creates a distribution using the configured parameters and optional . + /// + /// An observable that emits the constructed Binomial distribution. public IObservable Process() { return Observable.Return(distributions.Binomial(Count, Probabilities, generator: Generator)); } - public IObservable Process(IObservable source) + /// + /// Creates a distribution for each incoming RNG . + /// + /// Observable sequence of random generators to use. + /// An observable sequence of Binomial distributions. + public IObservable Process(IObservable source) { return source.Select((generator) => { @@ -66,6 +110,13 @@ public string ProbabilitiesXml }); } + /// + /// For each element of the source stream, emits a distribution + /// constructed from the configured parameters and current . + /// + /// The type of the triggering source sequence. + /// Trigger sequence; each element causes a new distribution to be emitted. + /// An observable sequence of Binomial distributions. public IObservable Process(IObservable source) { return source.Select(_ => distributions.Binomial(Count, Probabilities, generator: Generator)); diff --git a/src/Bonsai.ML.Torch/Distributions/Categorical.cs b/src/Bonsai.ML.Torch/Distributions/Categorical.cs index 19a4d438..382fa765 100644 --- a/src/Bonsai.ML.Torch/Distributions/Categorical.cs +++ b/src/Bonsai.ML.Torch/Distributions/Categorical.cs @@ -8,18 +8,39 @@ namespace Bonsai.ML.Torch.Distributions; +/// +/// Creates a Categorical (discrete) distribution over classes given event probabilities. +/// Emits a TorchSharp distribution module that can be sampled or queried for probabilities. +/// [Combinator] [ResetCombinator] -[Description("")] +[Description("Creates a Categorical distribution with class probabilities and emits a TorchSharp distribution module.")] [WorkflowElementCategory(ElementCategory.Source)] -public class Categorical : IScalarTypeProvider +public class Categorical : TensorContainerBase { - [Browsable(false)] - public ScalarType Type => ScalarType.Float32; + /// + /// Initializes a new instance of the class. + /// + public Categorical() + { + RegisterTensor( + () => _probabilities, + value => _probabilities = value); + } + private Tensor _probabilities; + /// + /// Class probabilities along the last dimension. Values must be non-negative and typically sum to 1 per row. + /// Can be a 1D vector or higher-rank tensor for batched distributions. + /// [XmlIgnore] [TypeConverter(typeof(TensorConverter))] - public Tensor Probabilities { get; set; } + [Description("Class probabilities along the last dimension (non-negative, typically sum to 1 per row). Supports batching.")] + public Tensor Probabilities + { + get => _probabilities; + set => _probabilities = value; + } /// /// The values of probabilities in XML string format. @@ -29,18 +50,30 @@ public class Categorical : IScalarTypeProvider [EditorBrowsable(EditorBrowsableState.Never)] public string ProbabilitiesXml { - get => TensorConverter.ConvertToString(Probabilities, Type); - set => Probabilities = TensorConverter.ConvertFromString(value, Type); + get => TensorConverter.ConvertToString(_probabilities, Type); + set => _probabilities = TensorConverter.ConvertFromString(value, Type); } + /// + /// Optional random number generator to use when sampling. If null, TorchSharp's global RNG is used. + /// [XmlIgnore] - public torch.Generator Generator { get; set; } = null; + public Generator Generator { get; set; } = null; + /// + /// Creates a distribution using the configured . + /// + /// An observable that emits the constructed Categorical distribution. public IObservable Process() { return Observable.Return(distributions.Categorical(Probabilities, generator: Generator)); } + /// + /// Creates a distribution for each incoming RNG . + /// + /// Observable sequence of random generators to use. + /// An observable sequence of Categorical distributions. public IObservable Process(IObservable source) { return source.Select((generator) => @@ -50,6 +83,13 @@ public string ProbabilitiesXml }); } + /// + /// For each element of the source stream, emits a distribution + /// constructed from the configured and current . + /// + /// The type of the triggering source sequence. + /// Trigger sequence; each element causes a new distribution to be emitted. + /// An observable sequence of Categorical distributions. public IObservable Process(IObservable source) { return source.Select(_ => distributions.Categorical(Probabilities, generator: Generator)); diff --git a/src/Bonsai.ML.Torch/Distributions/Cauchy.cs b/src/Bonsai.ML.Torch/Distributions/Cauchy.cs index 336a0586..0eeb8c95 100644 --- a/src/Bonsai.ML.Torch/Distributions/Cauchy.cs +++ b/src/Bonsai.ML.Torch/Distributions/Cauchy.cs @@ -8,18 +8,42 @@ namespace Bonsai.ML.Torch.Distributions; +/// +/// Creates a Cauchy (Lorentz) distribution parameterized by location and scale. +/// Emits a TorchSharp distribution module that can be sampled or queried for probabilities. +/// [Combinator] [ResetCombinator] -[Description("")] +[Description("Creates a Cauchy distribution with the specified location and scale parameters.")] [WorkflowElementCategory(ElementCategory.Source)] -public class Cauchy : IScalarTypeProvider +public class Cauchy : TensorContainerBase { - [Browsable(false)] - public ScalarType Type => ScalarType.Float32; + /// + /// Initializes a new instance of the class. + /// + public Cauchy() + { + RegisterTensor( + () => _locations, + value => _locations = value); + + RegisterTensor( + () => _scales, + value => _scales = value); + } + private Tensor _locations; + /// + /// Location parameter. Can be a scalar or tensor; shape determines the batch/event shape. + /// [XmlIgnore] [TypeConverter(typeof(TensorConverter))] - public Tensor Locations { get; set; } + [Description("Location parameter. Can be a scalar or tensor; supports batching.")] + public Tensor Locations + { + get => _locations; + set => _locations = value; + } /// /// The values of the locations in XML string format. @@ -29,12 +53,21 @@ public class Cauchy : IScalarTypeProvider [EditorBrowsable(EditorBrowsableState.Never)] public string LocationsXml { - get => TensorConverter.ConvertToString(Locations, Type); - set => Locations = TensorConverter.ConvertFromString(value, Type); + get => TensorConverter.ConvertToString(_locations, Type); + set => _locations = TensorConverter.ConvertFromString(value, Type); } + private Tensor _scales; + /// + /// Scale parameter (> 0). Can be a scalar or tensor; must be broadcastable with . + /// [TypeConverter(typeof(TensorConverter))] - public Tensor Scales { get; set; } + [Description("Scale parameter (> 0). Can be a scalar or tensor; must broadcast with Locations.")] + public Tensor Scales + { + get => _scales; + set => _scales = value; + } /// /// The values of the scales in XML string format. @@ -44,19 +77,31 @@ public string LocationsXml [EditorBrowsable(EditorBrowsableState.Never)] public string ScalesXml { - get => TensorConverter.ConvertToString(Scales, Type); - set => Scales = TensorConverter.ConvertFromString(value, Type); + get => TensorConverter.ConvertToString(_scales, Type); + set => _scales = TensorConverter.ConvertFromString(value, Type); } + /// + /// Optional random number generator to use when sampling. If null, TorchSharp's global RNG is used. + /// [XmlIgnore] - public torch.Generator Generator { get; set; } = null; + public Generator Generator { get; set; } = null; + /// + /// Creates a distribution using the configured parameters. + /// + /// An observable that emits the constructed Cauchy distribution. public IObservable Process() { return Observable.Return(distributions.Cauchy(Locations, Scales, generator: Generator)); } - public IObservable Process(IObservable source) + /// + /// Creates a distribution for each incoming RNG . + /// + /// Observable sequence of random generators to use. + /// An observable sequence of Cauchy distributions. + public IObservable Process(IObservable source) { return source.Select((generator) => { @@ -65,6 +110,13 @@ public string ScalesXml }); } + /// + /// For each element of the source stream, emits a distribution + /// constructed from the configured parameters and current . + /// + /// The type of the triggering source sequence. + /// Trigger sequence; each element causes a new distribution to be emitted. + /// An observable sequence of Cauchy distributions. public IObservable Process(IObservable source) { return source.Select(_ => distributions.Cauchy(Locations, Scales, generator: Generator)); diff --git a/src/Bonsai.ML.Torch/Distributions/Dirichlet.cs b/src/Bonsai.ML.Torch/Distributions/Dirichlet.cs index 0a30fcda..e095869f 100644 --- a/src/Bonsai.ML.Torch/Distributions/Dirichlet.cs +++ b/src/Bonsai.ML.Torch/Distributions/Dirichlet.cs @@ -8,18 +8,37 @@ namespace Bonsai.ML.Torch.Distributions; +/// +/// Creates a Dirichlet probability distribution parameterized by concentration parameters. +/// Emits a TorchSharp distribution module that can be sampled or queried for probabilities. +/// [Combinator] [ResetCombinator] -[Description("")] +[Description("Creates a Dirichlet distribution with concentration parameters.")] [WorkflowElementCategory(ElementCategory.Source)] -public class Dirichlet : IScalarTypeProvider +public class Dirichlet : TensorContainerBase { - [XmlIgnore] - [Browsable(false)] - public ScalarType Type => ScalarType.Float32; + /// + /// Initializes a new instance of the class. + /// + public Dirichlet() + { + RegisterTensor( + () => _concentration, + value => _concentration = value); + } + private Tensor _concentration; + /// + /// Concentration parameters (> 0). Can be a scalar or tensor; the shape determines the batch/event shape. + /// [TypeConverter(typeof(TensorConverter))] - public Tensor Concentration { get; set; } + [Description("Concentration parameters (> 0). Can be a scalar or tensor; shape sets the batch/event shape of the distribution.")] + public Tensor Concentration + { + get => _concentration; + set => _concentration = value; + } /// /// The values of the concentration in XML string format. @@ -29,19 +48,31 @@ public class Dirichlet : IScalarTypeProvider [EditorBrowsable(EditorBrowsableState.Never)] public string ConcentrationXml { - get => TensorConverter.ConvertToString(Concentration, Type); - set => Concentration = TensorConverter.ConvertFromString(value, Type); + get => TensorConverter.ConvertToString(_concentration, Type); + set => _concentration = TensorConverter.ConvertFromString(value, Type); } + /// + /// Optional random number generator to use when sampling. If null, TorchSharp's global RNG is used. + /// [XmlIgnore] - public torch.Generator Generator { get; set; } = null; + public Generator Generator { get; set; } = null; + /// + /// Creates a distribution using the configured parameters and optional . + /// + /// An observable that emits the constructed Dirichlet distribution. public IObservable Process() { return Observable.Return(distributions.Dirichlet(Concentration, generator: Generator)); } - public IObservable Process(IObservable source) + /// + /// Creates a distribution for each incoming RNG . + /// + /// Observable sequence of random generators to use. + /// An observable sequence of Dirichlet distributions. + public IObservable Process(IObservable source) { return source.Select((generator) => { @@ -50,6 +81,13 @@ public string ConcentrationXml }); } + /// + /// For each element of the source stream, emits a distribution + /// constructed from the configured parameters and current . + /// + /// The type of the triggering source sequence. + /// Trigger sequence; each element causes a new distribution to be emitted. + /// An observable sequence of Dirichlet distributions. public IObservable Process(IObservable source) { return source.Select(_ => distributions.Dirichlet(Concentration, generator: Generator)); diff --git a/src/Bonsai.ML.Torch/Distributions/Exponential.cs b/src/Bonsai.ML.Torch/Distributions/Exponential.cs index 4ecea185..4dc4cc16 100644 --- a/src/Bonsai.ML.Torch/Distributions/Exponential.cs +++ b/src/Bonsai.ML.Torch/Distributions/Exponential.cs @@ -8,19 +8,38 @@ namespace Bonsai.ML.Torch.Distributions; +/// +/// Creates an Exponential probability distribution parameterized by rate. +/// Emits a TorchSharp distribution module that can be sampled or queried for probabilities. +/// [Combinator] [ResetCombinator] -[Description("")] +[Description("Creates an Exponential distribution with the specified rate parameter.")] [WorkflowElementCategory(ElementCategory.Source)] -public class Exponential : IScalarTypeProvider +public class Exponential : TensorContainerBase { - [XmlIgnore] - [Browsable(false)] - public ScalarType Type => ScalarType.Float32; + /// + /// Initializes a new instance of the class. + /// + public Exponential() + { + RegisterTensor( + () => _rate, + value => _rate = value); + } + private Tensor _rate; + /// + /// Rate parameter (> 0). Can be a scalar or tensor; the shape determines the batch/event shape. + /// [XmlIgnore] [TypeConverter(typeof(TensorConverter))] - public Tensor Rate { get; set; } + [Description("Rate parameter (> 0). Can be a scalar or tensor; shape sets the batch/event shape of the distribution.")] + public Tensor Rate + { + get => _rate; + set => _rate = value; + } /// /// The values of the rates in XML string format. @@ -30,19 +49,31 @@ public class Exponential : IScalarTypeProvider [EditorBrowsable(EditorBrowsableState.Never)] public string RateXml { - get => TensorConverter.ConvertToString(Rate, Type); - set => Rate = TensorConverter.ConvertFromString(value, Type); + get => TensorConverter.ConvertToString(_rate, Type); + set => _rate = TensorConverter.ConvertFromString(value, Type); } + /// + /// Optional random number generator to use when sampling. If null, TorchSharp's global RNG is used. + /// [XmlIgnore] - public torch.Generator Generator { get; set; } = null; + public Generator Generator { get; set; } = null; + /// + /// Creates a distribution using the configured parameters and optional . + /// + /// An observable that emits the constructed Exponential distribution. public IObservable Process() { return Observable.Return(distributions.Exponential(Rate, generator: Generator)); } - public IObservable Process(IObservable source) + /// + /// Creates a distribution for each incoming RNG . + /// + /// Observable sequence of random generators to use. + /// An observable sequence of Exponential distributions. + public IObservable Process(IObservable source) { return source.Select((generator) => { @@ -51,6 +82,13 @@ public string RateXml }); } + /// + /// For each element of the source stream, emits a distribution + /// constructed from the configured parameters and current . + /// + /// The type of the triggering source sequence. + /// Trigger sequence; each element causes a new distribution to be emitted. + /// An observable sequence of Exponential distributions. public IObservable Process(IObservable source) { return source.Select(_ => distributions.Exponential(Rate, generator: Generator)); diff --git a/src/Bonsai.ML.Torch/Distributions/Gamma.cs b/src/Bonsai.ML.Torch/Distributions/Gamma.cs index adcb8baf..f89d3f27 100644 --- a/src/Bonsai.ML.Torch/Distributions/Gamma.cs +++ b/src/Bonsai.ML.Torch/Distributions/Gamma.cs @@ -8,19 +8,42 @@ namespace Bonsai.ML.Torch.Distributions; +/// +/// Creates a Gamma probability distribution parameterized by concentration and rate. +/// Emits a TorchSharp distribution module that can be sampled or queried for probabilities. +/// [Combinator] [ResetCombinator] -[Description("")] +[Description("Creates a Gamma distribution with concentration and rate parameters.")] [WorkflowElementCategory(ElementCategory.Source)] -public class Gamma : IScalarTypeProvider +public class Gamma : TensorContainerBase { - [XmlIgnore] - [Browsable(false)] - public ScalarType Type => ScalarType.Float32; + /// + /// Initializes a new instance of the class. + /// + public Gamma() + { + RegisterTensor( + () => _concentration, + value => _concentration = value); + + RegisterTensor( + () => _rate, + value => _rate = value); + } + private Tensor _concentration; + /// + /// Concentration parameter (> 0). Can be a scalar or tensor; the shape determines the batch/event shape. + /// [XmlIgnore] [TypeConverter(typeof(TensorConverter))] - public Tensor Concentration { get; set; } + [Description("Concentration parameter (> 0). Can be a scalar or tensor; shape sets the batch/event shape of the distribution.")] + public Tensor Concentration + { + get => _concentration; + set => _concentration = value; + } /// /// The values of the concentration in XML string format. @@ -30,13 +53,22 @@ public class Gamma : IScalarTypeProvider [EditorBrowsable(EditorBrowsableState.Never)] public string ConcentrationXml { - get => TensorConverter.ConvertToString(Concentration, Type); - set => Concentration = TensorConverter.ConvertFromString(value, Type); + get => TensorConverter.ConvertToString(_concentration, Type); + set => _concentration = TensorConverter.ConvertFromString(value, Type); } + private Tensor _rate; + /// + /// Rate parameter (> 0). Can be a scalar or tensor; must be broadcastable with . + /// [XmlIgnore] [TypeConverter(typeof(TensorConverter))] - public Tensor Rate { get; set; } + [Description("Rate parameter (> 0). Can be a scalar or tensor; must broadcast with Concentration.")] + public Tensor Rate + { + get => _rate; + set => _rate = value; + } /// /// The values of the rate in XML string format. @@ -46,19 +78,31 @@ public string ConcentrationXml [EditorBrowsable(EditorBrowsableState.Never)] public string RateXml { - get => TensorConverter.ConvertToString(Rate, Type); - set => Rate = TensorConverter.ConvertFromString(value, Type); + get => TensorConverter.ConvertToString(_rate, Type); + set => _rate = TensorConverter.ConvertFromString(value, Type); } + /// + /// Optional random number generator to use when sampling. If null, TorchSharp's global RNG is used. + /// [XmlIgnore] - public torch.Generator Generator { get; set; } = null; + public Generator Generator { get; set; } = null; + /// + /// Creates a distribution using the configured parameters and optional . + /// + /// An observable that emits the constructed Gamma distribution. public IObservable Process() { return Observable.Return(distributions.Gamma(Concentration, Rate, generator: Generator)); } - public IObservable Process(IObservable source) + /// + /// Creates a distribution for each incoming RNG . + /// + /// Observable sequence of random generators to use. + /// An observable sequence of Gamma distributions. + public IObservable Process(IObservable source) { return source.Select((generator) => { @@ -67,6 +111,13 @@ public string RateXml }); } + /// + /// For each element of the source stream, emits a distribution + /// constructed from the configured parameters and current . + /// + /// The type of the triggering source sequence. + /// Trigger sequence; each element causes a new distribution to be emitted. + /// An observable sequence of Gamma distributions. public IObservable Process(IObservable source) { return source.Select(_ => distributions.Gamma(Concentration, Rate, generator: Generator)); diff --git a/src/Bonsai.ML.Torch/Distributions/Geometric.cs b/src/Bonsai.ML.Torch/Distributions/Geometric.cs index 876211c1..42cbbf87 100644 --- a/src/Bonsai.ML.Torch/Distributions/Geometric.cs +++ b/src/Bonsai.ML.Torch/Distributions/Geometric.cs @@ -8,19 +8,38 @@ namespace Bonsai.ML.Torch.Distributions; +/// +/// Creates a Geometric probability distribution parameterized by success probability. +/// Emits a TorchSharp distribution module that can be sampled or queried for probabilities. +/// [Combinator] [ResetCombinator] -[Description("")] +[Description("Creates a Geometric distribution with the specified success probability.")] [WorkflowElementCategory(ElementCategory.Source)] -public class Geometric : IScalarTypeProvider +public class Geometric : TensorContainerBase { - [XmlIgnore] - [Browsable(false)] - public ScalarType Type => ScalarType.Float32; + /// + /// Initializes a new instance of the class. + /// + public Geometric() + { + RegisterTensor( + () => _probabilities, + value => _probabilities = value); + } + private Tensor _probabilities; + /// + /// Success probability p in [0, 1]. Can be a scalar or tensor; the shape determines the batch/event shape. + /// [XmlIgnore] [TypeConverter(typeof(TensorConverter))] - public Tensor Probabilities { get; set; } + [Description("Success probability p in [0, 1]. Can be a scalar or tensor; shape sets the batch/event shape of the distribution.")] + public Tensor Probabilities + { + get => _probabilities; + set => _probabilities = value; + } /// /// The values of the probabilities in XML string format. @@ -30,19 +49,31 @@ public class Geometric : IScalarTypeProvider [EditorBrowsable(EditorBrowsableState.Never)] public string ProbabilitiesXml { - get => TensorConverter.ConvertToString(Probabilities, Type); - set => Probabilities = TensorConverter.ConvertFromString(value, Type); + get => TensorConverter.ConvertToString(_probabilities, Type); + set => _probabilities = TensorConverter.ConvertFromString(value, Type); } + /// + /// Optional random number generator to use when sampling. If null, TorchSharp's global RNG is used. + /// [XmlIgnore] - public torch.Generator Generator { get; set; } = null; + public Generator Generator { get; set; } = null; + /// + /// Creates a distribution using the configured parameters and optional . + /// + /// An observable that emits the constructed Geometric distribution. public IObservable Process() { return Observable.Return(distributions.Geometric(Probabilities, generator: Generator)); } - public IObservable Process(IObservable source) + /// + /// Creates a distribution for each incoming RNG . + /// + /// Observable sequence of random generators to use. + /// An observable sequence of Geometric distributions. + public IObservable Process(IObservable source) { return source.Select((generator) => { @@ -51,6 +82,13 @@ public string ProbabilitiesXml }); } + /// + /// For each element of the source stream, emits a distribution + /// constructed from the configured parameters and current . + /// + /// The type of the triggering source sequence. + /// Trigger sequence; each element causes a new distribution to be emitted. + /// An observable sequence of Geometric distributions. public IObservable Process(IObservable source) { return source.Select(_ => distributions.Geometric(Probabilities, generator: Generator)); diff --git a/src/Bonsai.ML.Torch/Distributions/MultivariateNormal.cs b/src/Bonsai.ML.Torch/Distributions/MultivariateNormal.cs index 83dac869..105033c9 100644 --- a/src/Bonsai.ML.Torch/Distributions/MultivariateNormal.cs +++ b/src/Bonsai.ML.Torch/Distributions/MultivariateNormal.cs @@ -8,19 +8,42 @@ namespace Bonsai.ML.Torch.Distributions; +/// +/// Creates a Multivariate Normal (Gaussian) distribution parameterized by mean and covariance matrix. +/// Emits a TorchSharp distribution module that can be sampled or queried for probabilities. +/// [Combinator] [ResetCombinator] -[Description("")] +[Description("Creates a Multivariate Normal distribution with mean vector and covariance matrix.")] [WorkflowElementCategory(ElementCategory.Source)] -public class MultivariateNormal : IScalarTypeProvider +public class MultivariateNormal : TensorContainerBase { - [XmlIgnore] - [Browsable(false)] - public ScalarType Type => ScalarType.Float32; + /// + /// Initializes a new instance of the class. + /// + public MultivariateNormal() + { + RegisterTensor( + () => _mean, + value => _mean = value); + + RegisterTensor( + () => _covariance, + value => _covariance = value); + } + private Tensor _mean; + /// + /// Mean vector of the distribution. Can be a 1D vector or higher-rank tensor for batched distributions. + /// [XmlIgnore] [TypeConverter(typeof(TensorConverter))] - public Tensor Mean { get; set; } + [Description("Mean vector of the distribution. Can be a 1D vector or higher-rank tensor for batched distributions.")] + public Tensor Mean + { + get => _mean; + set => _mean = value; + } /// /// The values of the means in XML string format. @@ -30,13 +53,22 @@ public class MultivariateNormal : IScalarTypeProvider [EditorBrowsable(EditorBrowsableState.Never)] public string MeanXml { - get => TensorConverter.ConvertToString(Mean, Type); - set => Mean = TensorConverter.ConvertFromString(value, Type); + get => TensorConverter.ConvertToString(_mean, Type); + set => _mean = TensorConverter.ConvertFromString(value, Type); } + private Tensor _covariance; + /// + /// Covariance matrix of the distribution. Must be positive-definite and square with dimension matching . + /// [XmlIgnore] [TypeConverter(typeof(TensorConverter))] - public Tensor Covariance { get; set; } + [Description("Covariance matrix of the distribution. Must be positive-definite and square with dimension matching Mean.")] + public Tensor Covariance + { + get => _covariance; + set => _covariance = value; + } /// /// The values of the covariance matrix in XML string format. @@ -46,19 +78,31 @@ public string MeanXml [EditorBrowsable(EditorBrowsableState.Never)] public string CovarianceXml { - get => TensorConverter.ConvertToString(Covariance, Type); - set => Covariance = TensorConverter.ConvertFromString(value, Type); + get => TensorConverter.ConvertToString(_covariance, Type); + set => _covariance = TensorConverter.ConvertFromString(value, Type); } + /// + /// Optional random number generator to use when sampling. If null, TorchSharp's global RNG is used. + /// [XmlIgnore] - public torch.Generator Generator { get; set; } = null; + public Generator Generator { get; set; } = null; + /// + /// Creates a distribution using the configured parameters and optional . + /// + /// An observable that emits the constructed Multivariate Normal distribution. public IObservable Process() { return Observable.Return(distributions.MultivariateNormal(Mean, Covariance, generator: Generator)); } - public IObservable Process(IObservable source) + /// + /// Creates a distribution for each incoming RNG . + /// + /// Observable sequence of random generators to use. + /// An observable sequence of Multivariate Normal distributions. + public IObservable Process(IObservable source) { return source.Select((generator) => { @@ -67,6 +111,13 @@ public string CovarianceXml }); } + /// + /// For each element of the source stream, emits a distribution + /// constructed from the configured parameters and current . + /// + /// The type of the triggering source sequence. + /// Trigger sequence; each element causes a new distribution to be emitted. + /// An observable sequence of Multivariate Normal distributions. public IObservable Process(IObservable source) { return source.Select(_ => distributions.MultivariateNormal(Mean, Covariance, generator: Generator)); diff --git a/src/Bonsai.ML.Torch/Distributions/Poisson.cs b/src/Bonsai.ML.Torch/Distributions/Poisson.cs index 57302905..6c2e52b1 100644 --- a/src/Bonsai.ML.Torch/Distributions/Poisson.cs +++ b/src/Bonsai.ML.Torch/Distributions/Poisson.cs @@ -8,19 +8,38 @@ namespace Bonsai.ML.Torch.Distributions; +/// +/// Creates a Poisson probability distribution parameterized by rate (expected number of events). +/// Emits a TorchSharp distribution module that can be sampled or queried for probabilities. +/// [Combinator] [ResetCombinator] -[Description("")] +[Description("Creates a Poisson distribution with the specified rate parameter.")] [WorkflowElementCategory(ElementCategory.Source)] -public class Poisson : IScalarTypeProvider +public class Poisson : TensorContainerBase { - [XmlIgnore] - [Browsable(false)] - public ScalarType Type => ScalarType.Float32; + /// + /// Initializes a new instance of the class. + /// + public Poisson() + { + RegisterTensor( + () => _rate, + value => _rate = value); + } + private Tensor _rate; + /// + /// Rate parameter (> 0), representing the expected number of events. Can be a scalar or tensor; the shape determines the batch/event shape. + /// [XmlIgnore] [TypeConverter(typeof(TensorConverter))] - public Tensor Rate { get; set; } + [Description("Rate parameter (> 0), expected number of events. Can be a scalar or tensor; shape sets the batch/event shape of the distribution.")] + public Tensor Rate + { + get => _rate; + set => _rate = value; + } /// /// The values of the rates in XML string format. @@ -30,19 +49,31 @@ public class Poisson : IScalarTypeProvider [EditorBrowsable(EditorBrowsableState.Never)] public string RateXml { - get => TensorConverter.ConvertToString(Rate, Type); - set => Rate = TensorConverter.ConvertFromString(value, Type); + get => TensorConverter.ConvertToString(_rate, Type); + set => _rate = TensorConverter.ConvertFromString(value, Type); } + /// + /// Optional random number generator to use when sampling. If null, TorchSharp's global RNG is used. + /// [XmlIgnore] - public torch.Generator Generator { get; set; } = null; + public Generator Generator { get; set; } = null; + /// + /// Creates a distribution using the configured parameters and optional . + /// + /// An observable that emits the constructed Poisson distribution. public IObservable Process() { return Observable.Return(distributions.Poisson(Rate, generator: Generator)); } - public IObservable Process(IObservable source) + /// + /// Creates a distribution for each incoming RNG . + /// + /// Observable sequence of random generators to use. + /// An observable sequence of Poisson distributions. + public IObservable Process(IObservable source) { return source.Select((generator) => { @@ -51,6 +82,13 @@ public string RateXml }); } + /// + /// For each element of the source stream, emits a distribution + /// constructed from the configured parameters and current . + /// + /// The type of the triggering source sequence. + /// Trigger sequence; each element causes a new distribution to be emitted. + /// An observable sequence of Poisson distributions. public IObservable Process(IObservable source) { return source.Select(_ => distributions.Poisson(Rate, generator: Generator)); From 50137cc52d82507ec45a85330f28c47abe9c5825 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Sat, 1 Nov 2025 21:42:55 +0000 Subject: [PATCH 10/16] Added doc strings and made to inherit from tensor container base --- .../Distributions/LogProbability.cs | 79 +++++++++++++++++-- .../Distributions/ReparametrizedSample.cs | 17 +++- src/Bonsai.ML.Torch/Distributions/Sample.cs | 17 +++- 3 files changed, 103 insertions(+), 10 deletions(-) diff --git a/src/Bonsai.ML.Torch/Distributions/LogProbability.cs b/src/Bonsai.ML.Torch/Distributions/LogProbability.cs index 0b08fc72..993c83c0 100644 --- a/src/Bonsai.ML.Torch/Distributions/LogProbability.cs +++ b/src/Bonsai.ML.Torch/Distributions/LogProbability.cs @@ -3,30 +3,97 @@ using System.Reactive.Linq; using System.ComponentModel; using static TorchSharp.torch; +using static TorchSharp.torch.distributions; using System.Xml.Serialization; namespace Bonsai.ML.Torch.Distributions; +/// +/// Computes the log probability of the given values under the specified distribution. +/// [Combinator] [ResetCombinator] -[Description("")] +[Description("Computes the log probability of the given values under the specified distribution.")] [WorkflowElementCategory(ElementCategory.Transform)] -public class LogProbability +public class LogProbability : TensorContainerBase { + /// + /// Initializes a new instance of the class. + /// + public LogProbability() + { + RegisterTensor( + () => _values, + v => _values = v); + } + + private Tensor _values; + /// + /// The values at which to evaluate the inverse CDF. + /// + [XmlIgnore] + [TypeConverter(typeof(TensorConverter))] + [Description("The values at which to evaluate the inverse CDF.")] + public Tensor Values + { + get => _values; + set => _values = value; + } + + /// + /// The input distribution. + /// [XmlIgnore] - public Tensor Values { get; set; } + public Distribution Distribution { get; set; } - public IObservable Process(IObservable source) + /// + /// The values in XML string format. + /// + [Browsable(false)] + [XmlElement(nameof(Values))] + [EditorBrowsable(EditorBrowsableState.Never)] + public string ValuesXml + { + get => TensorConverter.ConvertToString(_values, Type); + set => _values = TensorConverter.ConvertFromString(value, Type); + } + + /// + /// Processes the input distribution to compute the log probability at the specified values. + /// + /// + /// + public IObservable Process(IObservable source) { return source.Select(distribution => distribution.log_prob(Values)); } - public IObservable Process(IObservable> source) + /// + /// Processes the input values to compute the log probability using the specified distribution. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(Distribution.log_prob); + } + + /// + /// Processes the input tuples of distribution and values to compute the log probability. + /// + /// + /// + public IObservable Process(IObservable> source) { return source.Select((input) => input.Item1.log_prob(input.Item2)); } - public IObservable Process(IObservable> source) + /// + /// Processes the input tuples of values and distribution to compute the log probability. + /// + /// + /// + public IObservable Process(IObservable> source) { return source.Select((input) => input.Item2.log_prob(input.Item1)); } diff --git a/src/Bonsai.ML.Torch/Distributions/ReparametrizedSample.cs b/src/Bonsai.ML.Torch/Distributions/ReparametrizedSample.cs index b3e14972..c98a4142 100644 --- a/src/Bonsai.ML.Torch/Distributions/ReparametrizedSample.cs +++ b/src/Bonsai.ML.Torch/Distributions/ReparametrizedSample.cs @@ -3,18 +3,31 @@ using System.Reactive.Linq; using System.ComponentModel; using static TorchSharp.torch; +using static TorchSharp.torch.distributions; namespace Bonsai.ML.Torch.Distributions; +/// +/// Generates reparameterized samples from the input distribution. +/// Reparameterized samples allow gradients to flow through the sampling process. +/// [Combinator] -[Description("")] +[Description("Generates reparameterized samples from the input distribution.")] [WorkflowElementCategory(ElementCategory.Transform)] public class ReparametrizedSample { + /// + /// The shape of the samples to generate. + /// [TypeConverter(typeof(UnidimensionalArrayConverter))] public long[] SampleShape { get; set; } - public IObservable Process(IObservable source) + /// + /// Processes the input distribution to generate reparameterized samples. + /// + /// + /// + public IObservable Process(IObservable source) { return source.Select(distribution => distribution.rsample(SampleShape)); } diff --git a/src/Bonsai.ML.Torch/Distributions/Sample.cs b/src/Bonsai.ML.Torch/Distributions/Sample.cs index dee288f1..fa5f44d7 100644 --- a/src/Bonsai.ML.Torch/Distributions/Sample.cs +++ b/src/Bonsai.ML.Torch/Distributions/Sample.cs @@ -3,18 +3,31 @@ using System.Reactive.Linq; using System.ComponentModel; using static TorchSharp.torch; +using static TorchSharp.torch.distributions; namespace Bonsai.ML.Torch.Distributions; +/// +/// Generates samples from the input distribution. +/// Gradients do not flow through the sampling process. +/// [Combinator] -[Description("")] +[Description("Generates samples from the input distribution.")] [WorkflowElementCategory(ElementCategory.Transform)] public class Sample { + /// + /// The shape of the samples to generate. + /// [TypeConverter(typeof(UnidimensionalArrayConverter))] public long[] SampleShape { get; set; } - public IObservable Process(IObservable source) + /// + /// Processes the input distribution to generate samples. + /// + /// + /// + public IObservable Process(IObservable source) { return source.Select(distribution => distribution.sample(SampleShape)); } From 5d885b6ca55318a8b3af571c43d541d0b2aeddc9 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Sat, 1 Nov 2025 21:44:32 +0000 Subject: [PATCH 11/16] Refactored `Generator` documentation to suggest creating an individual RNG and added `ManualSeed` to set global state and return --- .../{Generator.cs => CreateGenerator.cs} | 14 ++++---- src/Bonsai.ML.Torch/Random/ManualSeed.cs | 32 +++++++++++++++++++ 2 files changed, 39 insertions(+), 7 deletions(-) rename src/Bonsai.ML.Torch/Random/{Generator.cs => CreateGenerator.cs} (64%) create mode 100644 src/Bonsai.ML.Torch/Random/ManualSeed.cs diff --git a/src/Bonsai.ML.Torch/Random/Generator.cs b/src/Bonsai.ML.Torch/Random/CreateGenerator.cs similarity index 64% rename from src/Bonsai.ML.Torch/Random/Generator.cs rename to src/Bonsai.ML.Torch/Random/CreateGenerator.cs index 1a86f047..4af2318a 100644 --- a/src/Bonsai.ML.Torch/Random/Generator.cs +++ b/src/Bonsai.ML.Torch/Random/CreateGenerator.cs @@ -9,12 +9,12 @@ namespace Bonsai.ML.Torch.Random; /// -/// Creates a random number generator with the specified seed and device. +/// Creates a specific instance of a random number generator with the specified seed and device. /// [Combinator] -[Description("Creates a random number generator with the specified seed and device.")] +[Description("Creates a specific instance of a random number generator with the specified seed and device.")] [WorkflowElementCategory(ElementCategory.Source)] -public class Generator +public class CreateGenerator { /// /// The device on which to create the generator. @@ -31,16 +31,16 @@ public class Generator /// Creates a random number generator with the specified seed and device. /// /// - public IObservable Process() + public IObservable Process() { - return Observable.Return(new torch.Generator(Seed, Device)); + return Observable.Return(new Generator(Seed, Device)); } /// /// Generates an observable sequence of random number generators for each element of the input sequence. /// - public IObservable Process(IObservable source) + public IObservable Process(IObservable source) { - return source.Select(value => new torch.Generator(Seed, Device)); + return source.Select(value => new Generator(Seed, Device)); } } diff --git a/src/Bonsai.ML.Torch/Random/ManualSeed.cs b/src/Bonsai.ML.Torch/Random/ManualSeed.cs new file mode 100644 index 00000000..3745f019 --- /dev/null +++ b/src/Bonsai.ML.Torch/Random/ManualSeed.cs @@ -0,0 +1,32 @@ +using Bonsai; +using static TorchSharp.torch; +using TorchSharp; +using System; +using System.Reactive.Linq; +using System.Xml.Serialization; +using System.ComponentModel; + +namespace Bonsai.ML.Torch.Random; + +/// +/// Sets the global random seed for TorchSharp and creates a random number generator with the specified seed. +/// +[Combinator] +[Description("Sets the global random seed for TorchSharp and creates a random number generator with the specified seed.")] +[WorkflowElementCategory(ElementCategory.Source)] +public class ManualSeed +{ + /// + /// The seed for the random number generator. + /// + public long Seed { get; set; } = 0; + + /// + /// Creates a random number generator with the specified seed and device. + /// + /// + public IObservable Process() + { + return Observable.Return(manual_seed(Seed)); + } +} From 62dd824591d81be7d340c27957fe5935e082edb1 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Mon, 3 Nov 2025 13:26:52 +0000 Subject: [PATCH 12/16] Renamed distributions for better consistency --- .../Random/{RandomNormal.cs => Normal.cs} | 31 ++++++++++++----- .../{RandomPermutation.cs => Permutation.cs} | 27 ++++++++++++--- .../Random/{RandomUniform.cs => Uniform.cs} | 34 +++++++++++++------ .../{RandomIntegers.cs => UniformIntegers.cs} | 8 ++--- 4 files changed, 72 insertions(+), 28 deletions(-) rename src/Bonsai.ML.Torch/Random/{RandomNormal.cs => Normal.cs} (73%) rename src/Bonsai.ML.Torch/Random/{RandomPermutation.cs => Permutation.cs} (71%) rename src/Bonsai.ML.Torch/Random/{RandomUniform.cs => Uniform.cs} (67%) rename src/Bonsai.ML.Torch/Random/{RandomIntegers.cs => UniformIntegers.cs} (93%) diff --git a/src/Bonsai.ML.Torch/Random/RandomNormal.cs b/src/Bonsai.ML.Torch/Random/Normal.cs similarity index 73% rename from src/Bonsai.ML.Torch/Random/RandomNormal.cs rename to src/Bonsai.ML.Torch/Random/Normal.cs index d26f5ee3..f4329e29 100644 --- a/src/Bonsai.ML.Torch/Random/RandomNormal.cs +++ b/src/Bonsai.ML.Torch/Random/Normal.cs @@ -12,16 +12,16 @@ namespace Bonsai.ML.Torch.Random; /// [Combinator] [ResetCombinator] -[Description("Creates a tensor filled with random floats.")] +[Description("Creates a tensor filled with random floats sampled from a normal distribution.")] [WorkflowElementCategory(ElementCategory.Source)] -public class RandomFloats +public class Normal { /// /// The size of the tensor. /// [Description("The size of the tensor.")] [TypeConverter(typeof(UnidimensionalArrayConverter))] - public long[] Size { get; set; } = new long[0]; + public long[] Size { get; set; } /// /// The data type of the tensor elements. @@ -40,14 +40,27 @@ public class RandomFloats /// The random number generator to use. /// [XmlIgnore] - public torch.Generator Generator { get; set; } = null; + [Description("The random number generator to use.")] + public Generator Generator { get; set; } = null; + + /// + /// The mean of the normal distribution. + /// + [Description("The mean of the normal distribution.")] + public double Mean { get; set; } = 0; + + /// + /// The variance of the normal distribution. + /// + [Description("The variance of the normal distribution.")] + public double Variance { get; set; } = 1; /// /// Creates a tensor filled with random values sampled from a normal distribution with mean 0 and variance 1. /// public IObservable Process() { - return Observable.Return(randn(Size, dtype: Type, device: Device, generator: Generator)); + return Observable.Return(randn(Size, dtype: Type, device: Device, generator: Generator) * Variance + Mean); } /// @@ -55,12 +68,12 @@ public IObservable Process() /// /// /// - public IObservable Process(IObservable source) + public IObservable Process(IObservable source) { return source.Select(value => { Generator = value; - return randn(Size, dtype: Type, device: Device, generator: Generator); + return randn(Size, dtype: Type, device: Device, generator: Generator) * Variance + Mean; }); } @@ -71,7 +84,7 @@ public IObservable Process(IObservable source) /// public IObservable Process(IObservable source) { - return source.Select(value => randn_like(value, dtype: Type, device: Device)); + return source.Select(value => randn_like(value, dtype: Type, device: Device) * Variance + Mean); } /// @@ -81,6 +94,6 @@ public IObservable Process(IObservable source) /// public IObservable Process(IObservable source) { - return source.Select(value => randn(Size, dtype: Type, device: Device, generator: Generator)); + return source.Select(value => randn(Size, dtype: Type, device: Device, generator: Generator) * Variance + Mean); } } diff --git a/src/Bonsai.ML.Torch/Random/RandomPermutation.cs b/src/Bonsai.ML.Torch/Random/Permutation.cs similarity index 71% rename from src/Bonsai.ML.Torch/Random/RandomPermutation.cs rename to src/Bonsai.ML.Torch/Random/Permutation.cs index 3b097722..315b69fe 100644 --- a/src/Bonsai.ML.Torch/Random/RandomPermutation.cs +++ b/src/Bonsai.ML.Torch/Random/Permutation.cs @@ -14,7 +14,7 @@ namespace Bonsai.ML.Torch.Random; [ResetCombinator] [Description("Creates a 1D tensor of a given size with a random permutation of integers from 0 to size - 1.")] [WorkflowElementCategory(ElementCategory.Source)] -public class RandomPermutation +public class Permutation { /// /// The size of the tensor. @@ -39,10 +39,10 @@ public class RandomPermutation /// The random number generator to use. /// [XmlIgnore] - public torch.Generator Generator { get; set; } = null; + public Generator Generator { get; set; } = null; /// - /// Creates a tensor of a given size with a random permutation of integers from 0 to size - 1. + /// Creates a tensor of a given size with a random permutation of integers from [0, size). /// public IObservable Process() { @@ -50,11 +50,11 @@ public IObservable Process() } /// - /// Generates an observable sequence of tensors filled with random values and uses the input generator. + /// Generates an observable sequence of tensors with a random permutation of integers from [0, size) and uses the input generator. /// /// /// - public IObservable Process(IObservable source) + public IObservable Process(IObservable source) { return source.Select(value => { @@ -63,6 +63,23 @@ public IObservable Process(IObservable source) }); } + /// + /// Randomly permutates tensors from the input sequence. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => + { + var size = value.numel(); + var shape = value.shape; + var idxs = randperm(size, dtype: Type, device: Device, generator: Generator); + return value.flatten().index_select(0, idxs).reshape(shape); + }); + } + + /// /// Generates an observable sequence of tensors filled with random values for each element of the input sequence. /// diff --git a/src/Bonsai.ML.Torch/Random/RandomUniform.cs b/src/Bonsai.ML.Torch/Random/Uniform.cs similarity index 67% rename from src/Bonsai.ML.Torch/Random/RandomUniform.cs rename to src/Bonsai.ML.Torch/Random/Uniform.cs index 9eb1e83e..8d3f0047 100644 --- a/src/Bonsai.ML.Torch/Random/RandomUniform.cs +++ b/src/Bonsai.ML.Torch/Random/Uniform.cs @@ -8,20 +8,20 @@ namespace Bonsai.ML.Torch.Random; /// -/// Creates a tensor filled with random values sampled from a uniform distribution over the interval [0, 1). +/// Creates a tensor filled with random numbers sampled from a uniform distribution over the interval [MinSize, MaxSize). /// [Combinator] [ResetCombinator] -[Description("Creates a tensor filled with random floats.")] +[Description("Creates a tensor filled with random numbers from a uniform distribution over the interval [MinSize, MaxSize).")] [WorkflowElementCategory(ElementCategory.Source)] -public class RandomUniform +public class Uniform { /// /// The size of the tensor. /// [Description("The size of the tensor.")] [TypeConverter(typeof(UnidimensionalArrayConverter))] - public long[] Size { get; set; } = new long[0]; + public long[] Size { get; set; } = []; /// /// The data type of the tensor elements. @@ -40,14 +40,28 @@ public class RandomUniform /// The random number generator to use. /// [XmlIgnore] - public torch.Generator Generator { get; set; } = null; + [Description("The random number generator to use.")] + public Generator Generator { get; set; } = null; /// - /// Creates a tensor filled with random values sampled from a uniform distribution over the interval [0, 1). + /// The minimum value of the random numbers inclusive. + /// + [Description("The minimum value of the random numbers.")] + public double MinValue { get; set; } = 0; + + /// + /// The maximum value of the random numbers exclusive. + /// + [Description("The maximum value of the random numbers.")] + public double MaxValue { get; set; } = 1; + + + /// + /// Creates a tensor filled with random values sampled from a uniform distribution over the interval [MinValue, MaxValue). /// public IObservable Process() { - return Observable.Return(rand(Size, dtype: Type, device: Device, generator: Generator)); + return Observable.Return(rand(Size, dtype: Type, device: Device, generator: Generator) * (MaxValue - MinValue) + MinValue); } /// @@ -60,7 +74,7 @@ public IObservable Process(IObservable source) return source.Select(value => { Generator = value; - return rand(Size, dtype: Type, device: Device, generator: Generator); + return rand(Size, dtype: Type, device: Device, generator: Generator) * (MaxValue - MinValue) + MinValue; }); } @@ -71,7 +85,7 @@ public IObservable Process(IObservable source) /// public IObservable Process(IObservable source) { - return source.Select(value => rand_like(value, dtype: Type, device: Device)); + return source.Select(value => rand_like(value, dtype: Type, device: Device) * (MaxValue - MinValue) + MinValue); } /// @@ -81,7 +95,7 @@ public IObservable Process(IObservable source) /// public IObservable Process(IObservable source) { - return source.Select(value => rand(Size, dtype: Type, device: Device, generator: Generator)); + return source.Select(value => rand(Size, dtype: Type, device: Device, generator: Generator) * (MaxValue - MinValue) + MinValue); } } diff --git a/src/Bonsai.ML.Torch/Random/RandomIntegers.cs b/src/Bonsai.ML.Torch/Random/UniformIntegers.cs similarity index 93% rename from src/Bonsai.ML.Torch/Random/RandomIntegers.cs rename to src/Bonsai.ML.Torch/Random/UniformIntegers.cs index 6bf36da9..89d6e4b4 100644 --- a/src/Bonsai.ML.Torch/Random/RandomIntegers.cs +++ b/src/Bonsai.ML.Torch/Random/UniformIntegers.cs @@ -15,14 +15,14 @@ namespace Bonsai.ML.Torch.Random; [ResetCombinator] [Description("Creates a tensor filled with random integers.")] [WorkflowElementCategory(ElementCategory.Source)] -public class RandomIntegers +public class UniformIntegers { /// /// The size of the tensor. /// [Description("The size of the tensor.")] [TypeConverter(typeof(UnidimensionalArrayConverter))] - public long[] Size { get; set; } = new long[0]; + public long[] Size { get; set; } = []; /// /// The minimum value of the random integers inclusive. @@ -53,7 +53,7 @@ public class RandomIntegers /// The random number generator to use. /// [XmlIgnore] - public torch.Generator Generator { get; set; } = null; + public Generator Generator { get; set; } = null; /// /// Creates a tensor filled with random integers sampled from a uniform distribution over the @@ -69,7 +69,7 @@ public IObservable Process() /// /// /// - public IObservable Process(IObservable source) + public IObservable Process(IObservable source) { return source.Select(value => { From 0f223afde86a545081f4c8ac0b07e7199a0827b0 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 10 Dec 2025 16:28:37 +0000 Subject: [PATCH 13/16] Refactored distribution classes with TensorOperatorConverter --- .../Distributions/Bernoulli.cs | 69 ++++++---------- src/Bonsai.ML.Torch/Distributions/Beta.cs | 82 ++++++------------- src/Bonsai.ML.Torch/Distributions/Binomial.cs | 80 ++++++------------ .../Distributions/Categorical.cs | 72 ++++++---------- src/Bonsai.ML.Torch/Distributions/Cauchy.cs | 70 +++++----------- ...f.cs => CumulativeDistributionFunction.cs} | 57 ++----------- .../Distributions/Dirichlet.cs | 57 ++++--------- .../Distributions/Exponential.cs | 59 +++++-------- src/Bonsai.ML.Torch/Distributions/Gamma.cs | 73 +++++------------ .../Distributions/Geometric.cs | 62 +++++--------- ... InverseCumulativeDistributionFunction.cs} | 53 +----------- .../Distributions/LogProbability.cs | 57 ++----------- .../Distributions/MultivariateNormal.cs | 73 +++++------------ src/Bonsai.ML.Torch/Distributions/Poisson.cs | 58 ++++--------- .../Distributions/ReparametrizedSample.cs | 4 +- src/Bonsai.ML.Torch/Distributions/Sample.cs | 4 +- 16 files changed, 254 insertions(+), 676 deletions(-) rename src/Bonsai.ML.Torch/Distributions/{Cdf.cs => CumulativeDistributionFunction.cs} (52%) rename src/Bonsai.ML.Torch/Distributions/{InverseCdf.cs => InverseCumulativeDistributionFunction.cs} (55%) diff --git a/src/Bonsai.ML.Torch/Distributions/Bernoulli.cs b/src/Bonsai.ML.Torch/Distributions/Bernoulli.cs index dff6f7ee..b81dec04 100644 --- a/src/Bonsai.ML.Torch/Distributions/Bernoulli.cs +++ b/src/Bonsai.ML.Torch/Distributions/Bernoulli.cs @@ -1,44 +1,27 @@ -using static TorchSharp.torch; -using TorchSharp; using System; using System.ComponentModel; using System.Reactive.Linq; using System.Xml.Serialization; +using static TorchSharp.torch; namespace Bonsai.ML.Torch.Distributions; /// -/// Creates a Bernoulli probability distribution parameterized by event probabilities. -/// Emits a TorchSharp distribution module that can be sampled or queried for log-probabilities. +/// Represents an operator that creates a Bernoulli probability distribution parameterized by event probabilities. /// [Combinator] -[ResetCombinator] [Description("Creates a Bernoulli distribution with event probabilities and emits a TorchSharp distribution module.")] [WorkflowElementCategory(ElementCategory.Source)] -public class Bernoulli : TensorContainerBase +[TypeConverter(typeof(TensorOperatorConverter))] +public class Bernoulli : IScalarTypeProvider { - /// - /// Initializes a new instance of the class. - /// - public Bernoulli() - { - RegisterTensor( - () => _probabilities, - value => _probabilities = value); - } - - private Tensor _probabilities; /// /// Event probabilities p in [0, 1]. Can be a scalar or a tensor; the shape determines the batch/event shape. /// [XmlIgnore] [TypeConverter(typeof(TensorConverter))] [Description("Event probabilities p in [0, 1]. Can be a scalar or a tensor; shape sets the batch/event shape of the distribution.")] - public Tensor Probabilities - { - get => _probabilities; - set => _probabilities = value; - } + public Tensor Probabilities { get; set; } = null; /// /// The values of the probabilities in XML string format. @@ -48,50 +31,44 @@ public Tensor Probabilities [EditorBrowsable(EditorBrowsableState.Never)] public string ProbabilitiesXml { - get => TensorConverter.ConvertToString(_probabilities, Type); - set => _probabilities = TensorConverter.ConvertFromString(value, Type); + get => TensorConverter.ConvertToString(Probabilities, Type); + set => Probabilities = TensorConverter.ConvertFromString(value, Type); } /// - /// Optional random number generator to use when sampling. If null, TorchSharp's global RNG is used. + /// Gets or sets the data type of the tensor elements. /// - [XmlIgnore] - public Generator Generator { get; set; } = null; + [Description("The data type of the tensor elements.")] + [TypeConverter(typeof(ScalarTypeConverter))] + public ScalarType Type { get; set; } = ScalarType.Float32; /// - /// Creates a distribution using the configured and optional . + /// Creates a Bernoulli distribution. /// - /// An observable that emits the constructed Bernoulli distribution. + /// public IObservable Process() { - return Observable.Return(distributions.Bernoulli(Probabilities, generator: Generator)); + return Observable.Return(distributions.Bernoulli(Probabilities)); } /// - /// Creates a distribution for each incoming RNG , - /// updating and passing it to TorchSharp. + /// Creates a Bernoulli distribution using the incoming RNG Generator. /// - /// Observable sequence of random generators to use. - /// An observable sequence of Bernoulli distributions. + /// + /// public IObservable Process(IObservable source) { - return source.Select((generator) => - { - Generator = generator; - return distributions.Bernoulli(Probabilities, generator: Generator); - }); + return source.Select(generator => distributions.Bernoulli(Probabilities, generator: generator)); } /// - /// For each element of the source stream, emits a distribution - /// constructed from the configured and current . - /// The source values are ignored and used only for timing. + /// For each element of the source stream, emits a Bernoulli distribution. /// - /// The type of the triggering source sequence. - /// Trigger sequence; each element causes a new distribution to be emitted. - /// An observable sequence of Bernoulli distributions. + /// + /// + /// public IObservable Process(IObservable source) { - return source.Select(_ => distributions.Bernoulli(Probabilities, generator: Generator)); + return source.Select(_ => distributions.Bernoulli(Probabilities)); } } \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Distributions/Beta.cs b/src/Bonsai.ML.Torch/Distributions/Beta.cs index a9aeda50..9abfcdb8 100644 --- a/src/Bonsai.ML.Torch/Distributions/Beta.cs +++ b/src/Bonsai.ML.Torch/Distributions/Beta.cs @@ -1,49 +1,28 @@ -using static TorchSharp.torch; -using TorchSharp; -using Bonsai; using System; using System.ComponentModel; using System.Reactive.Linq; using System.Xml.Serialization; +using TorchSharp; +using static TorchSharp.torch; namespace Bonsai.ML.Torch.Distributions; /// /// Creates a Beta probability distribution parameterized by two concentration parameters (alpha, beta). -/// Emits a TorchSharp distribution module that can be sampled or queried for probabilities. /// [Combinator] -[ResetCombinator] [Description("Creates a Beta distribution with concentration parameters (alpha, beta).")] [WorkflowElementCategory(ElementCategory.Source)] -public class Beta : TensorContainerBase +[TypeConverter(typeof(TensorOperatorConverter))] +public class Beta : IScalarTypeProvider { - /// - /// Initializes a new instance of the class. - /// - public Beta() - { - RegisterTensor( - () => _concentration1, - value => _concentration1 = value); - - RegisterTensor( - () => _concentration0, - value => _concentration0 = value); - } - - private Tensor _concentration1; /// /// Concentration parameter alpha (> 0). Can be a scalar or tensor; the shape determines the batch/event shape. /// [XmlIgnore] [TypeConverter(typeof(TensorConverter))] [Description("Concentration alpha (> 0). Can be a scalar or tensor; shape sets the batch/event shape of the distribution.")] - public Tensor Concentration1 - { - get => _concentration1; - set => _concentration1 = value; - } + public Tensor Concentration1 { get; set; } = null; /// /// The values of concentration 1 in XML string format. @@ -53,22 +32,17 @@ public Tensor Concentration1 [EditorBrowsable(EditorBrowsableState.Never)] public string Concentration1Xml { - get => TensorConverter.ConvertToString(_concentration1, Type); - set => _concentration1 = TensorConverter.ConvertFromString(value, Type); + get => TensorConverter.ConvertToString(Concentration1, Type); + set => Concentration1 = TensorConverter.ConvertFromString(value, Type); } - private Tensor _concentration0; /// /// Concentration parameter beta (> 0). Can be a scalar or tensor; the shape determines the batch/event shape. /// [XmlIgnore] [TypeConverter(typeof(TensorConverter))] [Description("Concentration beta (> 0). Can be a scalar or tensor; shape sets the batch/event shape of the distribution.")] - public Tensor Concentration0 - { - get => _concentration0; - set => _concentration0 = value; - } + public Tensor Concentration0 { get; set; } = null; /// /// The values of concentration 0 in XML string format. @@ -78,48 +52,44 @@ public Tensor Concentration0 [EditorBrowsable(EditorBrowsableState.Never)] public string Concentration0Xml { - get => TensorConverter.ConvertToString(_concentration0, Type); - set => _concentration0 = TensorConverter.ConvertFromString(value, Type); + get => TensorConverter.ConvertToString(Concentration0, Type); + set => Concentration0 = TensorConverter.ConvertFromString(value, Type); } /// - /// Optional random number generator to use when sampling. If null, TorchSharp's global RNG is used. + /// Gets or sets the data type of the tensor elements. /// - [XmlIgnore] - public Generator Generator { get; set; } = null; + [Description("The data type of the tensor elements.")] + [TypeConverter(typeof(ScalarTypeConverter))] + public ScalarType Type { get; set; } = ScalarType.Float32; /// - /// Creates a distribution using the configured parameters and optional . + /// Creates a Beta distribution. /// - /// An observable that emits the constructed Beta distribution. + /// public IObservable Process() { - return Observable.Return(distributions.Beta(Concentration1, Concentration0, generator: Generator)); + return Observable.Return(distributions.Beta(Concentration1, Concentration0)); } /// - /// Creates a distribution for each incoming RNG . + /// Creates a Beta distribution using the incoming RNG generator. /// - /// Observable sequence of random generators to use. - /// An observable sequence of Beta distributions. + /// + /// public IObservable Process(IObservable source) { - return source.Select((generator) => - { - Generator = generator; - return distributions.Beta(Concentration1, Concentration0, generator: Generator); - }); + return source.Select(generator => distributions.Beta(Concentration1, Concentration0, generator: generator)); } /// - /// For each element of the source stream, emits a distribution - /// constructed from the configured parameters and current . + /// For each element of the source stream, emits a Beta distribution. /// - /// The type of the triggering source sequence. - /// Trigger sequence; each element causes a new distribution to be emitted. - /// An observable sequence of Beta distributions. + /// + /// + /// public IObservable Process(IObservable source) { - return source.Select(_ => distributions.Beta(Concentration1, Concentration0, generator: Generator)); + return source.Select(_ => distributions.Beta(Concentration1, Concentration0)); } } \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Distributions/Binomial.cs b/src/Bonsai.ML.Torch/Distributions/Binomial.cs index 13c281bc..3a23ad5b 100644 --- a/src/Bonsai.ML.Torch/Distributions/Binomial.cs +++ b/src/Bonsai.ML.Torch/Distributions/Binomial.cs @@ -1,48 +1,27 @@ -using static TorchSharp.torch; -using TorchSharp; -using Bonsai; using System; using System.ComponentModel; using System.Reactive.Linq; using System.Xml.Serialization; +using static TorchSharp.torch; namespace Bonsai.ML.Torch.Distributions; /// /// Creates a Binomial probability distribution with a given number of trials and success probability. -/// Emits a TorchSharp distribution module that can be sampled or queried for log-probabilities. /// [Combinator] -[ResetCombinator] -[Description("Creates a Binomial distribution with count (number of trials) and probability of success p.")] +[Description("Creates a Binomial distribution with count (number of trials) and probability of success.")] [WorkflowElementCategory(ElementCategory.Source)] -public class Binomial : TensorContainerBase +[TypeConverter(typeof(TensorOperatorConverter))] +public class Binomial : IScalarTypeProvider { - /// - /// Initializes a new instance of the class. - /// - public Binomial() - { - RegisterTensor( - () => _count, - value => _count = value); - RegisterTensor( - () => _probabilities, - value => _probabilities = value); - } - - private Tensor _count; /// /// Number of trials (non-negative). Can be a scalar or tensor. If tensor, values should be non-negative integers. /// [XmlIgnore] [TypeConverter(typeof(TensorConverter))] [Description("Number of trials (non-negative). Can be a scalar or tensor.")] - public Tensor Count - { - get => _count; - set => _count = value; - } + public Tensor Count { get; set; } = null; /// /// The values of count in XML string format. @@ -56,18 +35,13 @@ public string CountXml set => Count = TensorConverter.ConvertFromString(value, Type); } - private Tensor _probabilities; /// /// Probability of success p in [0, 1]. Can be a scalar or tensor; the shape should be broadcastable to . /// [XmlIgnore] [TypeConverter(typeof(TensorConverter))] - [Description("Probability of success p in [0, 1]. Can be a scalar or tensor; shape should broadcast with Count.")] - public Tensor Probabilities - { - get => _probabilities; - set => _probabilities = value; - } + [Description("Probability of success in [0, 1]. Can be a scalar or tensor; shape should broadcastable with Count.")] + public Tensor Probabilities { get; set; } = null; /// /// The values of probabilities in XML string format. @@ -77,48 +51,44 @@ public Tensor Probabilities [EditorBrowsable(EditorBrowsableState.Never)] public string ProbabilitiesXml { - get => TensorConverter.ConvertToString(_probabilities, Type); - set => _probabilities = TensorConverter.ConvertFromString(value, Type); + get => TensorConverter.ConvertToString(Probabilities, Type); + set => Probabilities = TensorConverter.ConvertFromString(value, Type); } /// - /// Optional random number generator to use when sampling. If null, TorchSharp's global RNG is used. + /// Gets or sets the data type of the tensor elements. /// - [XmlIgnore] - public Generator Generator { get; set; } = null; + [Description("The data type of the tensor elements.")] + [TypeConverter(typeof(ScalarTypeConverter))] + public ScalarType Type { get; set; } = ScalarType.Float32; /// - /// Creates a distribution using the configured parameters and optional . + /// Creates a Binomial distribution. /// - /// An observable that emits the constructed Binomial distribution. + /// public IObservable Process() { - return Observable.Return(distributions.Binomial(Count, Probabilities, generator: Generator)); + return Observable.Return(distributions.Binomial(Count, Probabilities)); } /// - /// Creates a distribution for each incoming RNG . + /// Creates a Binomial distribution for each incoming RNG generator. /// - /// Observable sequence of random generators to use. - /// An observable sequence of Binomial distributions. + /// + /// public IObservable Process(IObservable source) { - return source.Select((generator) => - { - Generator = generator; - return distributions.Binomial(Count, Probabilities, generator: Generator); - }); + return source.Select(generator => distributions.Binomial(Count, Probabilities, generator: generator)); } /// - /// For each element of the source stream, emits a distribution - /// constructed from the configured parameters and current . + /// For each element of the source stream, emits a Binomial distribution. /// - /// The type of the triggering source sequence. - /// Trigger sequence; each element causes a new distribution to be emitted. - /// An observable sequence of Binomial distributions. + /// + /// + /// public IObservable Process(IObservable source) { - return source.Select(_ => distributions.Binomial(Count, Probabilities, generator: Generator)); + return source.Select(_ => distributions.Binomial(Count, Probabilities)); } } \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Distributions/Categorical.cs b/src/Bonsai.ML.Torch/Distributions/Categorical.cs index 382fa765..f25875d8 100644 --- a/src/Bonsai.ML.Torch/Distributions/Categorical.cs +++ b/src/Bonsai.ML.Torch/Distributions/Categorical.cs @@ -1,46 +1,28 @@ -using static TorchSharp.torch; -using TorchSharp; -using Bonsai; using System; using System.ComponentModel; using System.Reactive.Linq; using System.Xml.Serialization; +using TorchSharp; +using static TorchSharp.torch; namespace Bonsai.ML.Torch.Distributions; /// /// Creates a Categorical (discrete) distribution over classes given event probabilities. -/// Emits a TorchSharp distribution module that can be sampled or queried for probabilities. /// [Combinator] -[ResetCombinator] [Description("Creates a Categorical distribution with class probabilities and emits a TorchSharp distribution module.")] [WorkflowElementCategory(ElementCategory.Source)] -public class Categorical : TensorContainerBase +[TypeConverter(typeof(TensorOperatorConverter))] +public class Categorical : IScalarTypeProvider { /// - /// Initializes a new instance of the class. - /// - public Categorical() - { - RegisterTensor( - () => _probabilities, - value => _probabilities = value); - } - - private Tensor _probabilities; - /// - /// Class probabilities along the last dimension. Values must be non-negative and typically sum to 1 per row. - /// Can be a 1D vector or higher-rank tensor for batched distributions. + /// Class probabilities along the last dimension. Values must be non-negative and typically sum to 1 per row. Can be a 1D vector or higher-rank tensor for batched distributions. /// [XmlIgnore] [TypeConverter(typeof(TensorConverter))] [Description("Class probabilities along the last dimension (non-negative, typically sum to 1 per row). Supports batching.")] - public Tensor Probabilities - { - get => _probabilities; - set => _probabilities = value; - } + public Tensor Probabilities { get; set; } = null; /// /// The values of probabilities in XML string format. @@ -50,48 +32,44 @@ public Tensor Probabilities [EditorBrowsable(EditorBrowsableState.Never)] public string ProbabilitiesXml { - get => TensorConverter.ConvertToString(_probabilities, Type); - set => _probabilities = TensorConverter.ConvertFromString(value, Type); + get => TensorConverter.ConvertToString(Probabilities, Type); + set => Probabilities = TensorConverter.ConvertFromString(value, Type); } /// - /// Optional random number generator to use when sampling. If null, TorchSharp's global RNG is used. + /// Gets or sets the data type of the tensor elements. /// - [XmlIgnore] - public Generator Generator { get; set; } = null; + [Description("The data type of the tensor elements.")] + [TypeConverter(typeof(ScalarTypeConverter))] + public ScalarType Type { get; set; } = ScalarType.Float32; /// - /// Creates a distribution using the configured . + /// Creates a categorical distribution. /// - /// An observable that emits the constructed Categorical distribution. + /// public IObservable Process() { - return Observable.Return(distributions.Categorical(Probabilities, generator: Generator)); + return Observable.Return(distributions.Categorical(Probabilities)); } /// - /// Creates a distribution for each incoming RNG . + /// Creates a categorical distribution for each incoming RNG generator. /// - /// Observable sequence of random generators to use. - /// An observable sequence of Categorical distributions. - public IObservable Process(IObservable source) + /// + /// + public IObservable Process(IObservable source) { - return source.Select((generator) => - { - Generator = generator; - return distributions.Categorical(Probabilities, generator: Generator); - }); + return source.Select(generator => distributions.Categorical(Probabilities, generator: generator)); } /// - /// For each element of the source stream, emits a distribution - /// constructed from the configured and current . + /// For each element of the source stream, emits a categorical distribution. /// - /// The type of the triggering source sequence. - /// Trigger sequence; each element causes a new distribution to be emitted. - /// An observable sequence of Categorical distributions. + /// + /// + /// public IObservable Process(IObservable source) { - return source.Select(_ => distributions.Categorical(Probabilities, generator: Generator)); + return source.Select(_ => distributions.Categorical(Probabilities)); } } \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Distributions/Cauchy.cs b/src/Bonsai.ML.Torch/Distributions/Cauchy.cs index 0eeb8c95..47426854 100644 --- a/src/Bonsai.ML.Torch/Distributions/Cauchy.cs +++ b/src/Bonsai.ML.Torch/Distributions/Cauchy.cs @@ -1,49 +1,28 @@ -using static TorchSharp.torch; -using TorchSharp; -using Bonsai; using System; using System.ComponentModel; using System.Reactive.Linq; using System.Xml.Serialization; +using TorchSharp; +using static TorchSharp.torch; namespace Bonsai.ML.Torch.Distributions; /// -/// Creates a Cauchy (Lorentz) distribution parameterized by location and scale. -/// Emits a TorchSharp distribution module that can be sampled or queried for probabilities. +/// Represents an operator that creates a Cauchy (Lorentz) distribution parameterized by location and scale. /// [Combinator] -[ResetCombinator] [Description("Creates a Cauchy distribution with the specified location and scale parameters.")] [WorkflowElementCategory(ElementCategory.Source)] -public class Cauchy : TensorContainerBase +[TypeConverter(typeof(TensorOperatorConverter))] +public class Cauchy : IScalarTypeProvider { - /// - /// Initializes a new instance of the class. - /// - public Cauchy() - { - RegisterTensor( - () => _locations, - value => _locations = value); - - RegisterTensor( - () => _scales, - value => _scales = value); - } - - private Tensor _locations; /// /// Location parameter. Can be a scalar or tensor; shape determines the batch/event shape. /// [XmlIgnore] [TypeConverter(typeof(TensorConverter))] [Description("Location parameter. Can be a scalar or tensor; supports batching.")] - public Tensor Locations - { - get => _locations; - set => _locations = value; - } + public Tensor Locations { get; set; } = null; /// /// The values of the locations in XML string format. @@ -53,21 +32,17 @@ public Tensor Locations [EditorBrowsable(EditorBrowsableState.Never)] public string LocationsXml { - get => TensorConverter.ConvertToString(_locations, Type); - set => _locations = TensorConverter.ConvertFromString(value, Type); + get => TensorConverter.ConvertToString(Locations, Type); + set => Locations = TensorConverter.ConvertFromString(value, Type); } - private Tensor _scales; /// /// Scale parameter (> 0). Can be a scalar or tensor; must be broadcastable with . /// + [XmlIgnore] [TypeConverter(typeof(TensorConverter))] [Description("Scale parameter (> 0). Can be a scalar or tensor; must broadcast with Locations.")] - public Tensor Scales - { - get => _scales; - set => _scales = value; - } + public Tensor Scales { get; set; } = null; /// /// The values of the scales in XML string format. @@ -77,15 +52,16 @@ public Tensor Scales [EditorBrowsable(EditorBrowsableState.Never)] public string ScalesXml { - get => TensorConverter.ConvertToString(_scales, Type); - set => _scales = TensorConverter.ConvertFromString(value, Type); + get => TensorConverter.ConvertToString(Scales, Type); + set => Scales = TensorConverter.ConvertFromString(value, Type); } /// - /// Optional random number generator to use when sampling. If null, TorchSharp's global RNG is used. + /// Gets or sets the data type of the tensor elements. /// - [XmlIgnore] - public Generator Generator { get; set; } = null; + [Description("The data type of the tensor elements.")] + [TypeConverter(typeof(ScalarTypeConverter))] + public ScalarType Type { get; set; } = ScalarType.Float32; /// /// Creates a distribution using the configured parameters. @@ -93,32 +69,28 @@ public string ScalesXml /// An observable that emits the constructed Cauchy distribution. public IObservable Process() { - return Observable.Return(distributions.Cauchy(Locations, Scales, generator: Generator)); + return Observable.Return(distributions.Cauchy(Locations, Scales)); } /// - /// Creates a distribution for each incoming RNG . + /// Creates a distribution for each incoming RNG . /// /// Observable sequence of random generators to use. /// An observable sequence of Cauchy distributions. public IObservable Process(IObservable source) { - return source.Select((generator) => - { - Generator = generator; - return distributions.Cauchy(Locations, Scales, generator: Generator); - }); + return source.Select(generator => distributions.Cauchy(Locations, Scales, generator: generator)); } /// /// For each element of the source stream, emits a distribution - /// constructed from the configured parameters and current . + /// constructed from the configured parameters. /// /// The type of the triggering source sequence. /// Trigger sequence; each element causes a new distribution to be emitted. /// An observable sequence of Cauchy distributions. public IObservable Process(IObservable source) { - return source.Select(_ => distributions.Cauchy(Locations, Scales, generator: Generator)); + return source.Select(_ => distributions.Cauchy(Locations, Scales)); } } \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Distributions/Cdf.cs b/src/Bonsai.ML.Torch/Distributions/CumulativeDistributionFunction.cs similarity index 52% rename from src/Bonsai.ML.Torch/Distributions/Cdf.cs rename to src/Bonsai.ML.Torch/Distributions/CumulativeDistributionFunction.cs index 32678399..2831dfab 100644 --- a/src/Bonsai.ML.Torch/Distributions/Cdf.cs +++ b/src/Bonsai.ML.Torch/Distributions/CumulativeDistributionFunction.cs @@ -1,57 +1,20 @@ -using Bonsai; using System; -using System.Reactive.Linq; using System.ComponentModel; +using System.Reactive.Linq; +using System.Xml.Serialization; using static TorchSharp.torch; using static TorchSharp.torch.distributions; -using System.Xml.Serialization; namespace Bonsai.ML.Torch.Distributions; /// -/// Creates a cumulative distribution function (CDF) from the input distribution. +/// Represents an operator that creates a cumulative distribution function (CDF) from the given distribution. /// [Combinator] -[ResetCombinator] -[Description("Creates a cumulative distribution function (CDF) from the input distribution.")] +[Description("Creates a cumulative distribution function (CDF) from the given distribution.")] [WorkflowElementCategory(ElementCategory.Transform)] -public class Cdf : TensorContainerBase +public class CumulativeDistributionFunction { - /// - /// Initializes a new instance of the class. - /// - public Cdf() - { - RegisterTensor( - () => _values, - v => _values = v); - } - - private Tensor _values; - /// - /// The values at which to evaluate the CDF. - /// - [XmlIgnore] - [TypeConverter(typeof(TensorConverter))] - [Description("The values at which to evaluate the CDF.")] - public Tensor Values - { - get => _values; - set => _values = value; - } - - /// - /// The values in XML string format. - /// - [Browsable(false)] - [XmlElement(nameof(Values))] - [EditorBrowsable(EditorBrowsableState.Never)] - public string ValuesXml - { - get => TensorConverter.ConvertToString(_values, Type); - set => _values = TensorConverter.ConvertFromString(value, Type); - } - /// /// The input distribution. /// @@ -59,16 +22,6 @@ public string ValuesXml [Description("The input distribution.")] public Distribution Distribution { get; set; } - /// - /// Processes the input distribution to compute the CDF at the specified values. - /// - /// - /// - public IObservable Process(IObservable source) - { - return source.Select(distribution => distribution.cdf(Values)); - } - /// /// Processes the input values to compute the CDF using the specified distribution. /// diff --git a/src/Bonsai.ML.Torch/Distributions/Dirichlet.cs b/src/Bonsai.ML.Torch/Distributions/Dirichlet.cs index e095869f..7ad3e131 100644 --- a/src/Bonsai.ML.Torch/Distributions/Dirichlet.cs +++ b/src/Bonsai.ML.Torch/Distributions/Dirichlet.cs @@ -1,44 +1,27 @@ -using static TorchSharp.torch; -using TorchSharp; -using Bonsai; using System; using System.ComponentModel; using System.Reactive.Linq; using System.Xml.Serialization; +using static TorchSharp.torch; namespace Bonsai.ML.Torch.Distributions; /// -/// Creates a Dirichlet probability distribution parameterized by concentration parameters. -/// Emits a TorchSharp distribution module that can be sampled or queried for probabilities. +/// Represents an operator that creates a Dirichlet probability distribution parameterized by concentration parameters. /// [Combinator] -[ResetCombinator] [Description("Creates a Dirichlet distribution with concentration parameters.")] [WorkflowElementCategory(ElementCategory.Source)] -public class Dirichlet : TensorContainerBase +[TypeConverter(typeof(TensorOperatorConverter))] +public class Dirichlet : IScalarTypeProvider { - /// - /// Initializes a new instance of the class. - /// - public Dirichlet() - { - RegisterTensor( - () => _concentration, - value => _concentration = value); - } - - private Tensor _concentration; /// /// Concentration parameters (> 0). Can be a scalar or tensor; the shape determines the batch/event shape. /// + [XmlIgnore] [TypeConverter(typeof(TensorConverter))] [Description("Concentration parameters (> 0). Can be a scalar or tensor; shape sets the batch/event shape of the distribution.")] - public Tensor Concentration - { - get => _concentration; - set => _concentration = value; - } + public Tensor Concentration { get; set; } = null; /// /// The values of the concentration in XML string format. @@ -48,48 +31,44 @@ public Tensor Concentration [EditorBrowsable(EditorBrowsableState.Never)] public string ConcentrationXml { - get => TensorConverter.ConvertToString(_concentration, Type); - set => _concentration = TensorConverter.ConvertFromString(value, Type); + get => TensorConverter.ConvertToString(Concentration, Type); + set => Concentration = TensorConverter.ConvertFromString(value, Type); } /// - /// Optional random number generator to use when sampling. If null, TorchSharp's global RNG is used. + /// Gets or sets the data type of the tensor elements. /// - [XmlIgnore] - public Generator Generator { get; set; } = null; + [Description("The data type of the tensor elements.")] + [TypeConverter(typeof(ScalarTypeConverter))] + public ScalarType Type { get; set; } = ScalarType.Float32; /// - /// Creates a distribution using the configured parameters and optional . + /// Creates a distribution using the configured parameters. /// /// An observable that emits the constructed Dirichlet distribution. public IObservable Process() { - return Observable.Return(distributions.Dirichlet(Concentration, generator: Generator)); + return Observable.Return(distributions.Dirichlet(Concentration)); } /// - /// Creates a distribution for each incoming RNG . + /// Creates a distribution for each incoming RNG . /// /// Observable sequence of random generators to use. /// An observable sequence of Dirichlet distributions. public IObservable Process(IObservable source) { - return source.Select((generator) => - { - Generator = generator; - return distributions.Dirichlet(Concentration, generator: Generator); - }); + return source.Select(generator => distributions.Dirichlet(Concentration, generator: generator)); } /// - /// For each element of the source stream, emits a distribution - /// constructed from the configured parameters and current . + /// For each element of the source stream, emits a distribution constructed from the configured parameters. /// /// The type of the triggering source sequence. /// Trigger sequence; each element causes a new distribution to be emitted. /// An observable sequence of Dirichlet distributions. public IObservable Process(IObservable source) { - return source.Select(_ => distributions.Dirichlet(Concentration, generator: Generator)); + return source.Select(_ => distributions.Dirichlet(Concentration)); } } \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Distributions/Exponential.cs b/src/Bonsai.ML.Torch/Distributions/Exponential.cs index 4dc4cc16..09f6f3fd 100644 --- a/src/Bonsai.ML.Torch/Distributions/Exponential.cs +++ b/src/Bonsai.ML.Torch/Distributions/Exponential.cs @@ -1,45 +1,28 @@ -using static TorchSharp.torch; -using TorchSharp; -using Bonsai; using System; using System.ComponentModel; using System.Reactive.Linq; using System.Xml.Serialization; +using TorchSharp; +using static TorchSharp.torch; namespace Bonsai.ML.Torch.Distributions; /// -/// Creates an Exponential probability distribution parameterized by rate. -/// Emits a TorchSharp distribution module that can be sampled or queried for probabilities. +/// Represents an operator that creates an exponential probability distribution parameterized by rate. /// [Combinator] -[ResetCombinator] -[Description("Creates an Exponential distribution with the specified rate parameter.")] +[Description("Creates an exponential distribution with the specified rate parameter.")] [WorkflowElementCategory(ElementCategory.Source)] -public class Exponential : TensorContainerBase +[TypeConverter(typeof(TensorOperatorConverter))] +public class Exponential : IScalarTypeProvider { - /// - /// Initializes a new instance of the class. - /// - public Exponential() - { - RegisterTensor( - () => _rate, - value => _rate = value); - } - - private Tensor _rate; /// /// Rate parameter (> 0). Can be a scalar or tensor; the shape determines the batch/event shape. /// [XmlIgnore] [TypeConverter(typeof(TensorConverter))] [Description("Rate parameter (> 0). Can be a scalar or tensor; shape sets the batch/event shape of the distribution.")] - public Tensor Rate - { - get => _rate; - set => _rate = value; - } + public Tensor Rate { get; set; } = null; /// /// The values of the rates in XML string format. @@ -49,48 +32,44 @@ public Tensor Rate [EditorBrowsable(EditorBrowsableState.Never)] public string RateXml { - get => TensorConverter.ConvertToString(_rate, Type); - set => _rate = TensorConverter.ConvertFromString(value, Type); + get => TensorConverter.ConvertToString(Rate, Type); + set => Rate = TensorConverter.ConvertFromString(value, Type); } /// - /// Optional random number generator to use when sampling. If null, TorchSharp's global RNG is used. + /// Gets or sets the data type of the tensor elements. /// - [XmlIgnore] - public Generator Generator { get; set; } = null; + [Description("The data type of the tensor elements.")] + [TypeConverter(typeof(ScalarTypeConverter))] + public ScalarType Type { get; set; } = ScalarType.Float32; /// - /// Creates a distribution using the configured parameters and optional . + /// Creates a distribution using the configured parameters. /// /// An observable that emits the constructed Exponential distribution. public IObservable Process() { - return Observable.Return(distributions.Exponential(Rate, generator: Generator)); + return Observable.Return(distributions.Exponential(Rate)); } /// - /// Creates a distribution for each incoming RNG . + /// Creates a distribution for each incoming RNG . /// /// Observable sequence of random generators to use. /// An observable sequence of Exponential distributions. public IObservable Process(IObservable source) { - return source.Select((generator) => - { - Generator = generator; - return distributions.Exponential(Rate, generator: Generator); - }); + return source.Select(generator => distributions.Exponential(Rate, generator: generator)); } /// - /// For each element of the source stream, emits a distribution - /// constructed from the configured parameters and current . + /// For each element of the source stream, emits a distribution constructed from the configured parameters. /// /// The type of the triggering source sequence. /// Trigger sequence; each element causes a new distribution to be emitted. /// An observable sequence of Exponential distributions. public IObservable Process(IObservable source) { - return source.Select(_ => distributions.Exponential(Rate, generator: Generator)); + return source.Select(_ => distributions.Exponential(Rate)); } } \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Distributions/Gamma.cs b/src/Bonsai.ML.Torch/Distributions/Gamma.cs index f89d3f27..f3142eeb 100644 --- a/src/Bonsai.ML.Torch/Distributions/Gamma.cs +++ b/src/Bonsai.ML.Torch/Distributions/Gamma.cs @@ -1,49 +1,28 @@ -using static TorchSharp.torch; -using TorchSharp; -using Bonsai; using System; using System.ComponentModel; using System.Reactive.Linq; using System.Xml.Serialization; +using TorchSharp; +using static TorchSharp.torch; namespace Bonsai.ML.Torch.Distributions; /// -/// Creates a Gamma probability distribution parameterized by concentration and rate. -/// Emits a TorchSharp distribution module that can be sampled or queried for probabilities. +/// Represents an operator that creates a gamma probability distribution parameterized by concentration and rate. /// [Combinator] -[ResetCombinator] -[Description("Creates a Gamma distribution with concentration and rate parameters.")] +[Description("Creates a gamma distribution with concentration and rate parameters.")] [WorkflowElementCategory(ElementCategory.Source)] -public class Gamma : TensorContainerBase +[TypeConverter(typeof(TensorOperatorConverter))] +public class Gamma : IScalarTypeProvider { - /// - /// Initializes a new instance of the class. - /// - public Gamma() - { - RegisterTensor( - () => _concentration, - value => _concentration = value); - - RegisterTensor( - () => _rate, - value => _rate = value); - } - - private Tensor _concentration; /// /// Concentration parameter (> 0). Can be a scalar or tensor; the shape determines the batch/event shape. /// [XmlIgnore] [TypeConverter(typeof(TensorConverter))] [Description("Concentration parameter (> 0). Can be a scalar or tensor; shape sets the batch/event shape of the distribution.")] - public Tensor Concentration - { - get => _concentration; - set => _concentration = value; - } + public Tensor Concentration { get; set; } = null; /// /// The values of the concentration in XML string format. @@ -53,22 +32,17 @@ public Tensor Concentration [EditorBrowsable(EditorBrowsableState.Never)] public string ConcentrationXml { - get => TensorConverter.ConvertToString(_concentration, Type); - set => _concentration = TensorConverter.ConvertFromString(value, Type); + get => TensorConverter.ConvertToString(Concentration, Type); + set => Concentration = TensorConverter.ConvertFromString(value, Type); } - private Tensor _rate; /// /// Rate parameter (> 0). Can be a scalar or tensor; must be broadcastable with . /// [XmlIgnore] [TypeConverter(typeof(TensorConverter))] [Description("Rate parameter (> 0). Can be a scalar or tensor; must broadcast with Concentration.")] - public Tensor Rate - { - get => _rate; - set => _rate = value; - } + public Tensor Rate { get; set; } = null; /// /// The values of the rate in XML string format. @@ -78,48 +52,45 @@ public Tensor Rate [EditorBrowsable(EditorBrowsableState.Never)] public string RateXml { - get => TensorConverter.ConvertToString(_rate, Type); - set => _rate = TensorConverter.ConvertFromString(value, Type); + get => TensorConverter.ConvertToString(Rate, Type); + set => Rate = TensorConverter.ConvertFromString(value, Type); } /// - /// Optional random number generator to use when sampling. If null, TorchSharp's global RNG is used. + /// Gets or sets the data type of the tensor elements. /// - [XmlIgnore] - public Generator Generator { get; set; } = null; + [Description("The data type of the tensor elements.")] + [TypeConverter(typeof(ScalarTypeConverter))] + public ScalarType Type { get; set; } = ScalarType.Float32; /// - /// Creates a distribution using the configured parameters and optional . + /// Creates a distribution using the configured parameters. /// /// An observable that emits the constructed Gamma distribution. public IObservable Process() { - return Observable.Return(distributions.Gamma(Concentration, Rate, generator: Generator)); + return Observable.Return(distributions.Gamma(Concentration, Rate)); } /// - /// Creates a distribution for each incoming RNG . + /// Creates a distribution for each incoming RNG . /// /// Observable sequence of random generators to use. /// An observable sequence of Gamma distributions. public IObservable Process(IObservable source) { - return source.Select((generator) => - { - Generator = generator; - return distributions.Gamma(Concentration, Rate, generator: Generator); - }); + return source.Select(generator => distributions.Gamma(Concentration, Rate, generator: generator)); } /// /// For each element of the source stream, emits a distribution - /// constructed from the configured parameters and current . + /// constructed from the configured parameters. /// /// The type of the triggering source sequence. /// Trigger sequence; each element causes a new distribution to be emitted. /// An observable sequence of Gamma distributions. public IObservable Process(IObservable source) { - return source.Select(_ => distributions.Gamma(Concentration, Rate, generator: Generator)); + return source.Select(_ => distributions.Gamma(Concentration, Rate)); } } \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Distributions/Geometric.cs b/src/Bonsai.ML.Torch/Distributions/Geometric.cs index 42cbbf87..9a910e90 100644 --- a/src/Bonsai.ML.Torch/Distributions/Geometric.cs +++ b/src/Bonsai.ML.Torch/Distributions/Geometric.cs @@ -1,45 +1,28 @@ -using static TorchSharp.torch; -using TorchSharp; -using Bonsai; using System; using System.ComponentModel; using System.Reactive.Linq; using System.Xml.Serialization; +using TorchSharp; +using static TorchSharp.torch; namespace Bonsai.ML.Torch.Distributions; /// -/// Creates a Geometric probability distribution parameterized by success probability. -/// Emits a TorchSharp distribution module that can be sampled or queried for probabilities. +/// Represents an operator that creates a geometric probability distribution parameterized by success probability. /// [Combinator] -[ResetCombinator] -[Description("Creates a Geometric distribution with the specified success probability.")] +[Description("Creates a geometric distribution with the specified success probability.")] [WorkflowElementCategory(ElementCategory.Source)] -public class Geometric : TensorContainerBase +[TypeConverter(typeof(TensorOperatorConverter))] +public class Geometric : IScalarTypeProvider { /// - /// Initializes a new instance of the class. - /// - public Geometric() - { - RegisterTensor( - () => _probabilities, - value => _probabilities = value); - } - - private Tensor _probabilities; - /// - /// Success probability p in [0, 1]. Can be a scalar or tensor; the shape determines the batch/event shape. + /// Success probability in [0, 1]. Can be a scalar or tensor; the shape determines the batch/event shape. /// [XmlIgnore] [TypeConverter(typeof(TensorConverter))] - [Description("Success probability p in [0, 1]. Can be a scalar or tensor; shape sets the batch/event shape of the distribution.")] - public Tensor Probabilities - { - get => _probabilities; - set => _probabilities = value; - } + [Description("Success probability in [0, 1]. Can be a scalar or tensor; shape sets the batch/event shape of the distribution.")] + public Tensor Probabilities { get; set; } = null; /// /// The values of the probabilities in XML string format. @@ -49,48 +32,45 @@ public Tensor Probabilities [EditorBrowsable(EditorBrowsableState.Never)] public string ProbabilitiesXml { - get => TensorConverter.ConvertToString(_probabilities, Type); - set => _probabilities = TensorConverter.ConvertFromString(value, Type); + get => TensorConverter.ConvertToString(Probabilities, Type); + set => Probabilities = TensorConverter.ConvertFromString(value, Type); } /// - /// Optional random number generator to use when sampling. If null, TorchSharp's global RNG is used. + /// Gets or sets the data type of the tensor elements. /// - [XmlIgnore] - public Generator Generator { get; set; } = null; + [Description("The data type of the tensor elements.")] + [TypeConverter(typeof(ScalarTypeConverter))] + public ScalarType Type { get; set; } = ScalarType.Float32; /// - /// Creates a distribution using the configured parameters and optional . + /// Creates a distribution using the configured parameters. /// /// An observable that emits the constructed Geometric distribution. public IObservable Process() { - return Observable.Return(distributions.Geometric(Probabilities, generator: Generator)); + return Observable.Return(distributions.Geometric(Probabilities)); } /// - /// Creates a distribution for each incoming RNG . + /// Creates a distribution for each incoming RNG . /// /// Observable sequence of random generators to use. /// An observable sequence of Geometric distributions. public IObservable Process(IObservable source) { - return source.Select((generator) => - { - Generator = generator; - return distributions.Geometric(Probabilities, generator: Generator); - }); + return source.Select(generator => distributions.Geometric(Probabilities, generator: generator)); } /// /// For each element of the source stream, emits a distribution - /// constructed from the configured parameters and current . + /// constructed from the configured parameters. /// /// The type of the triggering source sequence. /// Trigger sequence; each element causes a new distribution to be emitted. /// An observable sequence of Geometric distributions. public IObservable Process(IObservable source) { - return source.Select(_ => distributions.Geometric(Probabilities, generator: Generator)); + return source.Select(_ => distributions.Geometric(Probabilities)); } } \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Distributions/InverseCdf.cs b/src/Bonsai.ML.Torch/Distributions/InverseCumulativeDistributionFunction.cs similarity index 55% rename from src/Bonsai.ML.Torch/Distributions/InverseCdf.cs rename to src/Bonsai.ML.Torch/Distributions/InverseCumulativeDistributionFunction.cs index 69ca0116..259336ed 100644 --- a/src/Bonsai.ML.Torch/Distributions/InverseCdf.cs +++ b/src/Bonsai.ML.Torch/Distributions/InverseCumulativeDistributionFunction.cs @@ -1,10 +1,9 @@ -using Bonsai; using System; -using System.Reactive.Linq; using System.ComponentModel; +using System.Reactive.Linq; +using System.Xml.Serialization; using static TorchSharp.torch; using static TorchSharp.torch.distributions; -using System.Xml.Serialization; namespace Bonsai.ML.Torch.Distributions; @@ -12,62 +11,16 @@ namespace Bonsai.ML.Torch.Distributions; /// Creates an inverse cumulative distribution function (inverse CDF) from the input distribution. /// [Combinator] -[ResetCombinator] [Description("Creates an inverse cumulative distribution function (inverse CDF) from the input distribution.")] [WorkflowElementCategory(ElementCategory.Transform)] -public class InverseCdf : TensorContainerBase +public class InverseCumulativeDistributionFunction { - /// - /// Initializes a new instance of the class. - /// - public InverseCdf() - { - RegisterTensor( - () => _values, - v => _values = v); - } - - private Tensor _values; - /// - /// The values at which to evaluate the inverse CDF. - /// - [XmlIgnore] - [TypeConverter(typeof(TensorConverter))] - [Description("The values at which to evaluate the inverse CDF.")] - public Tensor Values - { - get => _values; - set => _values = value; - } - - /// - /// The values in XML string format. - /// - [Browsable(false)] - [XmlElement(nameof(Values))] - [EditorBrowsable(EditorBrowsableState.Never)] - public string ValuesXml - { - get => TensorConverter.ConvertToString(_values, Type); - set => _values = TensorConverter.ConvertFromString(value, Type); - } - /// /// The input distribution. /// [XmlIgnore] public Distribution Distribution { get; set; } - /// - /// Processes the input distribution to compute the inverse CDF at the specified values. - /// - /// - /// - public IObservable Process(IObservable source) - { - return source.Select(distribution => distribution.icdf(Values)); - } - /// /// Processes the input values to compute the inverse CDF using the specified distribution. /// diff --git a/src/Bonsai.ML.Torch/Distributions/LogProbability.cs b/src/Bonsai.ML.Torch/Distributions/LogProbability.cs index 993c83c0..272df0d0 100644 --- a/src/Bonsai.ML.Torch/Distributions/LogProbability.cs +++ b/src/Bonsai.ML.Torch/Distributions/LogProbability.cs @@ -1,73 +1,26 @@ -using Bonsai; using System; -using System.Reactive.Linq; using System.ComponentModel; +using System.Reactive.Linq; +using System.Xml.Serialization; using static TorchSharp.torch; using static TorchSharp.torch.distributions; -using System.Xml.Serialization; namespace Bonsai.ML.Torch.Distributions; /// -/// Computes the log probability of the given values under the specified distribution. +/// Represents an operator that computes the log probability of the input values under the specified distribution. /// [Combinator] -[ResetCombinator] -[Description("Computes the log probability of the given values under the specified distribution.")] +[Description("Computes the log probability of the input values under the specified distribution.")] [WorkflowElementCategory(ElementCategory.Transform)] -public class LogProbability : TensorContainerBase +public class LogProbability { - /// - /// Initializes a new instance of the class. - /// - public LogProbability() - { - RegisterTensor( - () => _values, - v => _values = v); - } - - private Tensor _values; - /// - /// The values at which to evaluate the inverse CDF. - /// - [XmlIgnore] - [TypeConverter(typeof(TensorConverter))] - [Description("The values at which to evaluate the inverse CDF.")] - public Tensor Values - { - get => _values; - set => _values = value; - } - /// /// The input distribution. /// [XmlIgnore] public Distribution Distribution { get; set; } - /// - /// The values in XML string format. - /// - [Browsable(false)] - [XmlElement(nameof(Values))] - [EditorBrowsable(EditorBrowsableState.Never)] - public string ValuesXml - { - get => TensorConverter.ConvertToString(_values, Type); - set => _values = TensorConverter.ConvertFromString(value, Type); - } - - /// - /// Processes the input distribution to compute the log probability at the specified values. - /// - /// - /// - public IObservable Process(IObservable source) - { - return source.Select(distribution => distribution.log_prob(Values)); - } - /// /// Processes the input values to compute the log probability using the specified distribution. /// diff --git a/src/Bonsai.ML.Torch/Distributions/MultivariateNormal.cs b/src/Bonsai.ML.Torch/Distributions/MultivariateNormal.cs index 105033c9..8bc97f45 100644 --- a/src/Bonsai.ML.Torch/Distributions/MultivariateNormal.cs +++ b/src/Bonsai.ML.Torch/Distributions/MultivariateNormal.cs @@ -1,49 +1,27 @@ -using static TorchSharp.torch; -using TorchSharp; -using Bonsai; using System; using System.ComponentModel; using System.Reactive.Linq; using System.Xml.Serialization; +using static TorchSharp.torch; namespace Bonsai.ML.Torch.Distributions; /// -/// Creates a Multivariate Normal (Gaussian) distribution parameterized by mean and covariance matrix. -/// Emits a TorchSharp distribution module that can be sampled or queried for probabilities. +/// Represents an operator that creates a multivariate normal (Gaussian) distribution parameterized by mean vector and covariance matrix. /// [Combinator] -[ResetCombinator] -[Description("Creates a Multivariate Normal distribution with mean vector and covariance matrix.")] +[Description("Creates a multivariate normal (Gaussian) distribution with mean vector and covariance matrix.")] [WorkflowElementCategory(ElementCategory.Source)] -public class MultivariateNormal : TensorContainerBase +[TypeConverter(typeof(TensorOperatorConverter))] +public class MultivariateNormal : IScalarTypeProvider { - /// - /// Initializes a new instance of the class. - /// - public MultivariateNormal() - { - RegisterTensor( - () => _mean, - value => _mean = value); - - RegisterTensor( - () => _covariance, - value => _covariance = value); - } - - private Tensor _mean; /// /// Mean vector of the distribution. Can be a 1D vector or higher-rank tensor for batched distributions. /// [XmlIgnore] [TypeConverter(typeof(TensorConverter))] [Description("Mean vector of the distribution. Can be a 1D vector or higher-rank tensor for batched distributions.")] - public Tensor Mean - { - get => _mean; - set => _mean = value; - } + public Tensor Mean { get; set; } = null; /// /// The values of the means in XML string format. @@ -53,22 +31,17 @@ public Tensor Mean [EditorBrowsable(EditorBrowsableState.Never)] public string MeanXml { - get => TensorConverter.ConvertToString(_mean, Type); - set => _mean = TensorConverter.ConvertFromString(value, Type); + get => TensorConverter.ConvertToString(Mean, Type); + set => Mean = TensorConverter.ConvertFromString(value, Type); } - private Tensor _covariance; /// /// Covariance matrix of the distribution. Must be positive-definite and square with dimension matching . /// [XmlIgnore] [TypeConverter(typeof(TensorConverter))] [Description("Covariance matrix of the distribution. Must be positive-definite and square with dimension matching Mean.")] - public Tensor Covariance - { - get => _covariance; - set => _covariance = value; - } + public Tensor Covariance { get; set; } = null; /// /// The values of the covariance matrix in XML string format. @@ -78,48 +51,44 @@ public Tensor Covariance [EditorBrowsable(EditorBrowsableState.Never)] public string CovarianceXml { - get => TensorConverter.ConvertToString(_covariance, Type); - set => _covariance = TensorConverter.ConvertFromString(value, Type); + get => TensorConverter.ConvertToString(Covariance, Type); + set => Covariance = TensorConverter.ConvertFromString(value, Type); } /// - /// Optional random number generator to use when sampling. If null, TorchSharp's global RNG is used. + /// Gets or sets the data type of the tensor elements. /// - [XmlIgnore] - public Generator Generator { get; set; } = null; + [Description("The data type of the tensor elements.")] + [TypeConverter(typeof(ScalarTypeConverter))] + public ScalarType Type { get; set; } = ScalarType.Float32; /// - /// Creates a distribution using the configured parameters and optional . + /// Creates a distribution using the configured parameters. /// /// An observable that emits the constructed Multivariate Normal distribution. public IObservable Process() { - return Observable.Return(distributions.MultivariateNormal(Mean, Covariance, generator: Generator)); + return Observable.Return(distributions.MultivariateNormal(Mean, Covariance)); } /// - /// Creates a distribution for each incoming RNG . + /// Creates a distribution for each incoming RNG . /// /// Observable sequence of random generators to use. /// An observable sequence of Multivariate Normal distributions. public IObservable Process(IObservable source) { - return source.Select((generator) => - { - Generator = generator; - return distributions.MultivariateNormal(Mean, Covariance, generator: Generator); - }); + return source.Select(generator => distributions.MultivariateNormal(Mean, Covariance, generator: generator)); } /// - /// For each element of the source stream, emits a distribution - /// constructed from the configured parameters and current . + /// For each element of the source stream, emits a distribution constructed from the configured parameters. /// /// The type of the triggering source sequence. /// Trigger sequence; each element causes a new distribution to be emitted. /// An observable sequence of Multivariate Normal distributions. public IObservable Process(IObservable source) { - return source.Select(_ => distributions.MultivariateNormal(Mean, Covariance, generator: Generator)); + return source.Select(_ => distributions.MultivariateNormal(Mean, Covariance)); } } \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Distributions/Poisson.cs b/src/Bonsai.ML.Torch/Distributions/Poisson.cs index 6c2e52b1..467e4195 100644 --- a/src/Bonsai.ML.Torch/Distributions/Poisson.cs +++ b/src/Bonsai.ML.Torch/Distributions/Poisson.cs @@ -1,45 +1,27 @@ -using static TorchSharp.torch; -using TorchSharp; -using Bonsai; using System; using System.ComponentModel; using System.Reactive.Linq; using System.Xml.Serialization; +using static TorchSharp.torch; namespace Bonsai.ML.Torch.Distributions; /// -/// Creates a Poisson probability distribution parameterized by rate (expected number of events). -/// Emits a TorchSharp distribution module that can be sampled or queried for probabilities. +/// Represents an operator that creates a poisson probability distribution parameterized by rate (expected number of events). /// [Combinator] -[ResetCombinator] -[Description("Creates a Poisson distribution with the specified rate parameter.")] +[Description("Creates a poisson distribution with the specified rate parameter.")] [WorkflowElementCategory(ElementCategory.Source)] -public class Poisson : TensorContainerBase +[TypeConverter(typeof(TensorOperatorConverter))] +public class Poisson : IScalarTypeProvider { - /// - /// Initializes a new instance of the class. - /// - public Poisson() - { - RegisterTensor( - () => _rate, - value => _rate = value); - } - - private Tensor _rate; /// /// Rate parameter (> 0), representing the expected number of events. Can be a scalar or tensor; the shape determines the batch/event shape. /// [XmlIgnore] [TypeConverter(typeof(TensorConverter))] [Description("Rate parameter (> 0), expected number of events. Can be a scalar or tensor; shape sets the batch/event shape of the distribution.")] - public Tensor Rate - { - get => _rate; - set => _rate = value; - } + public Tensor Rate { get; set; } = null; /// /// The values of the rates in XML string format. @@ -49,48 +31,44 @@ public Tensor Rate [EditorBrowsable(EditorBrowsableState.Never)] public string RateXml { - get => TensorConverter.ConvertToString(_rate, Type); - set => _rate = TensorConverter.ConvertFromString(value, Type); + get => TensorConverter.ConvertToString(Rate, Type); + set => Rate = TensorConverter.ConvertFromString(value, Type); } /// - /// Optional random number generator to use when sampling. If null, TorchSharp's global RNG is used. + /// Gets or sets the data type of the tensor elements. /// - [XmlIgnore] - public Generator Generator { get; set; } = null; + [Description("The data type of the tensor elements.")] + [TypeConverter(typeof(ScalarTypeConverter))] + public ScalarType Type { get; set; } = ScalarType.Float32; /// - /// Creates a distribution using the configured parameters and optional . + /// Creates a distribution using the configured parameters. /// /// An observable that emits the constructed Poisson distribution. public IObservable Process() { - return Observable.Return(distributions.Poisson(Rate, generator: Generator)); + return Observable.Return(distributions.Poisson(Rate)); } /// - /// Creates a distribution for each incoming RNG . + /// Creates a distribution for each incoming RNG . /// /// Observable sequence of random generators to use. /// An observable sequence of Poisson distributions. public IObservable Process(IObservable source) { - return source.Select((generator) => - { - Generator = generator; - return distributions.Poisson(Rate, generator: Generator); - }); + return source.Select(generator => distributions.Poisson(Rate, generator: generator)); } /// - /// For each element of the source stream, emits a distribution - /// constructed from the configured parameters and current . + /// For each element of the source stream, emits a distribution constructed from the configured parameters. /// /// The type of the triggering source sequence. /// Trigger sequence; each element causes a new distribution to be emitted. /// An observable sequence of Poisson distributions. public IObservable Process(IObservable source) { - return source.Select(_ => distributions.Poisson(Rate, generator: Generator)); + return source.Select(_ => distributions.Poisson(Rate)); } } \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Distributions/ReparametrizedSample.cs b/src/Bonsai.ML.Torch/Distributions/ReparametrizedSample.cs index c98a4142..faff9cbb 100644 --- a/src/Bonsai.ML.Torch/Distributions/ReparametrizedSample.cs +++ b/src/Bonsai.ML.Torch/Distributions/ReparametrizedSample.cs @@ -1,4 +1,3 @@ -using Bonsai; using System; using System.Reactive.Linq; using System.ComponentModel; @@ -8,8 +7,7 @@ namespace Bonsai.ML.Torch.Distributions; /// -/// Generates reparameterized samples from the input distribution. -/// Reparameterized samples allow gradients to flow through the sampling process. +/// Represents an operator that generates reparameterized samples from the input distribution. Reparameterized samples allow gradients to flow through the sampling process. /// [Combinator] [Description("Generates reparameterized samples from the input distribution.")] diff --git a/src/Bonsai.ML.Torch/Distributions/Sample.cs b/src/Bonsai.ML.Torch/Distributions/Sample.cs index fa5f44d7..32ed0f8f 100644 --- a/src/Bonsai.ML.Torch/Distributions/Sample.cs +++ b/src/Bonsai.ML.Torch/Distributions/Sample.cs @@ -1,4 +1,3 @@ -using Bonsai; using System; using System.Reactive.Linq; using System.ComponentModel; @@ -8,8 +7,7 @@ namespace Bonsai.ML.Torch.Distributions; /// -/// Generates samples from the input distribution. -/// Gradients do not flow through the sampling process. +/// Represents an operator that generates samples from the input distribution. Gradients do not flow through the sampling process. /// [Combinator] [Description("Generates samples from the input distribution.")] From 8c09274cfb0f25eff5a14b1b495bbf8715372274 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 10 Dec 2025 17:33:19 +0000 Subject: [PATCH 14/16] Refactored classes in the random namespace --- src/Bonsai.ML.Torch/Random/CreateGenerator.cs | 11 +++++--- src/Bonsai.ML.Torch/Random/ManualSeed.cs | 25 +++++++++++++------ src/Bonsai.ML.Torch/Random/Normal.cs | 5 ++-- src/Bonsai.ML.Torch/Random/Permutation.cs | 12 ++++----- src/Bonsai.ML.Torch/Random/Uniform.cs | 12 ++++----- src/Bonsai.ML.Torch/Random/UniformIntegers.cs | 18 ++++++------- 6 files changed, 47 insertions(+), 36 deletions(-) diff --git a/src/Bonsai.ML.Torch/Random/CreateGenerator.cs b/src/Bonsai.ML.Torch/Random/CreateGenerator.cs index 4af2318a..a6aace8d 100644 --- a/src/Bonsai.ML.Torch/Random/CreateGenerator.cs +++ b/src/Bonsai.ML.Torch/Random/CreateGenerator.cs @@ -9,10 +9,10 @@ namespace Bonsai.ML.Torch.Random; /// -/// Creates a specific instance of a random number generator with the specified seed and device. +/// Represents an operator that creates a specific instance of a random number generator (RNG) with the specified seed and device. /// [Combinator] -[Description("Creates a specific instance of a random number generator with the specified seed and device.")] +[Description("Creates a specific instance of a random number generator (RNG) with the specified seed and device.")] [WorkflowElementCategory(ElementCategory.Source)] public class CreateGenerator { @@ -20,11 +20,13 @@ public class CreateGenerator /// The device on which to create the generator. /// [XmlIgnore] + [Description("The device on which to create the generator.")] public Device Device { get; set; } /// /// The seed for the random number generator. /// + [Description("The seed for the random number generator.")] public ulong Seed { get; set; } = 0; /// @@ -39,8 +41,11 @@ public IObservable Process() /// /// Generates an observable sequence of random number generators for each element of the input sequence. /// + /// The type of the elements in the input sequence. + /// The input observable sequence. + /// An observable sequence of random number generators. public IObservable Process(IObservable source) { - return source.Select(value => new Generator(Seed, Device)); + return source.Select(_ => new Generator(Seed, Device)); } } diff --git a/src/Bonsai.ML.Torch/Random/ManualSeed.cs b/src/Bonsai.ML.Torch/Random/ManualSeed.cs index 3745f019..24eaa930 100644 --- a/src/Bonsai.ML.Torch/Random/ManualSeed.cs +++ b/src/Bonsai.ML.Torch/Random/ManualSeed.cs @@ -1,18 +1,16 @@ -using Bonsai; -using static TorchSharp.torch; -using TorchSharp; using System; using System.Reactive.Linq; -using System.Xml.Serialization; using System.ComponentModel; +using TorchSharp; +using static TorchSharp.torch; namespace Bonsai.ML.Torch.Random; /// -/// Sets the global random seed for TorchSharp and creates a random number generator with the specified seed. +/// Represents an operator that sets the global random seed for TorchSharp and creates a random number generator (RNG) with the specified seed. /// [Combinator] -[Description("Sets the global random seed for TorchSharp and creates a random number generator with the specified seed.")] +[Description("Sets the global random seed for TorchSharp and creates a random number generator (RNG) with the specified seed.")] [WorkflowElementCategory(ElementCategory.Source)] public class ManualSeed { @@ -22,11 +20,22 @@ public class ManualSeed public long Seed { get; set; } = 0; /// - /// Creates a random number generator with the specified seed and device. + /// Sets the global random seed and creates a random number generator (RNG). /// /// - public IObservable Process() + public IObservable Process() { return Observable.Return(manual_seed(Seed)); } + + /// + /// Generates an observable sequence where each element sets the global random seed and creates a random number generator (RNG). + /// + /// The type of the elements in the input sequence. + /// The input observable sequence. + /// An observable sequence of random number generators. + public IObservable Process(IObservable source) + { + return source.Select(_ => manual_seed(Seed)); + } } diff --git a/src/Bonsai.ML.Torch/Random/Normal.cs b/src/Bonsai.ML.Torch/Random/Normal.cs index f4329e29..389eebee 100644 --- a/src/Bonsai.ML.Torch/Random/Normal.cs +++ b/src/Bonsai.ML.Torch/Random/Normal.cs @@ -2,13 +2,12 @@ using System.ComponentModel; using System.Reactive.Linq; using System.Xml.Serialization; -using TorchSharp; using static TorchSharp.torch; namespace Bonsai.ML.Torch.Random; /// -/// Creates a tensor filled with random floats sampled from a normal distribution with mean 0 and variance 1. +/// Represents an operator that creates a tensor filled with random floats sampled from a normal distribution with the specified mean and variance. /// [Combinator] [ResetCombinator] @@ -94,6 +93,6 @@ public IObservable Process(IObservable source) /// public IObservable Process(IObservable source) { - return source.Select(value => randn(Size, dtype: Type, device: Device, generator: Generator) * Variance + Mean); + return source.Select(_ => randn(Size, dtype: Type, device: Device, generator: Generator) * Variance + Mean); } } diff --git a/src/Bonsai.ML.Torch/Random/Permutation.cs b/src/Bonsai.ML.Torch/Random/Permutation.cs index 315b69fe..afc4c6e1 100644 --- a/src/Bonsai.ML.Torch/Random/Permutation.cs +++ b/src/Bonsai.ML.Torch/Random/Permutation.cs @@ -2,17 +2,16 @@ using System.ComponentModel; using System.Reactive.Linq; using System.Xml.Serialization; -using TorchSharp; using static TorchSharp.torch; namespace Bonsai.ML.Torch.Random; /// -/// Creates a 1D tensor of a given size with a random permutation of integers from 0 to size - 1. +/// Represents an operator that creates a 1D tensor of a given size with a random permutation of integers in [0, size). /// [Combinator] [ResetCombinator] -[Description("Creates a 1D tensor of a given size with a random permutation of integers from 0 to size - 1.")] +[Description("Creates a 1D tensor of a given size with a random permutation of integers in [0, size).")] [WorkflowElementCategory(ElementCategory.Source)] public class Permutation { @@ -39,6 +38,7 @@ public class Permutation /// The random number generator to use. /// [XmlIgnore] + [Description("The random number generator to use.")] public Generator Generator { get; set; } = null; /// @@ -50,7 +50,7 @@ public IObservable Process() } /// - /// Generates an observable sequence of tensors with a random permutation of integers from [0, size) and uses the input generator. + /// Creates a tensor with a random permutation of integers in [0, size) and uses the input generator. /// /// /// @@ -81,12 +81,12 @@ public IObservable Process(IObservable source) /// - /// Generates an observable sequence of tensors filled with random values for each element of the input sequence. + /// Generates random permutations for each element of the input sequence. /// /// /// public IObservable Process(IObservable source) { - return source.Select(value => randperm(Size, dtype: Type, device: Device, generator: Generator)); + return source.Select(_ => randperm(Size, dtype: Type, device: Device, generator: Generator)); } } diff --git a/src/Bonsai.ML.Torch/Random/Uniform.cs b/src/Bonsai.ML.Torch/Random/Uniform.cs index 8d3f0047..e837bd7d 100644 --- a/src/Bonsai.ML.Torch/Random/Uniform.cs +++ b/src/Bonsai.ML.Torch/Random/Uniform.cs @@ -8,11 +8,11 @@ namespace Bonsai.ML.Torch.Random; /// -/// Creates a tensor filled with random numbers sampled from a uniform distribution over the interval [MinSize, MaxSize). +/// Represents an operator that creates a tensor filled with random numbers sampled from a uniform distribution over the interval [MinSize, MaxSize). /// [Combinator] [ResetCombinator] -[Description("Creates a tensor filled with random numbers from a uniform distribution over the interval [MinSize, MaxSize).")] +[Description("Creates a tensor filled with random numbers sampled from a uniform distribution over the interval [MinValue, MaxValue).")] [WorkflowElementCategory(ElementCategory.Source)] public class Uniform { @@ -65,11 +65,11 @@ public IObservable Process() } /// - /// Generates an observable sequence of tensors filled with random values and uses the input generator. + /// Creates tensors filled with random values and uses the input generator. /// /// /// - public IObservable Process(IObservable source) + public IObservable Process(IObservable source) { return source.Select(value => { @@ -89,13 +89,13 @@ public IObservable Process(IObservable source) } /// - /// Generates an observable sequence of tensors filled with random values for each element of the input sequence. + /// Creates tensors filled with random values for each element of the input sequence. /// /// /// public IObservable Process(IObservable source) { - return source.Select(value => rand(Size, dtype: Type, device: Device, generator: Generator) * (MaxValue - MinValue) + MinValue); + return source.Select(_ => rand(Size, dtype: Type, device: Device, generator: Generator) * (MaxValue - MinValue) + MinValue); } } diff --git a/src/Bonsai.ML.Torch/Random/UniformIntegers.cs b/src/Bonsai.ML.Torch/Random/UniformIntegers.cs index 89d6e4b4..450ce499 100644 --- a/src/Bonsai.ML.Torch/Random/UniformIntegers.cs +++ b/src/Bonsai.ML.Torch/Random/UniformIntegers.cs @@ -2,18 +2,16 @@ using System.ComponentModel; using System.Reactive.Linq; using System.Xml.Serialization; -using TorchSharp; using static TorchSharp.torch; namespace Bonsai.ML.Torch.Random; /// -/// Creates a tensor filled with random integers sampled from a uniform distribution over the -/// interval [MinValue, MaxValue). +/// Creates a tensor filled with random integers sampled from a uniform distribution over the interval [MinValue, MaxValue). /// [Combinator] [ResetCombinator] -[Description("Creates a tensor filled with random integers.")] +[Description("Creates a tensor filled with random integers sampled from a uniform distribution over the interval [MinValue, MaxValue).")] [WorkflowElementCategory(ElementCategory.Source)] public class UniformIntegers { @@ -53,11 +51,11 @@ public class UniformIntegers /// The random number generator to use. /// [XmlIgnore] + [Description("The random number generator to use.")] public Generator Generator { get; set; } = null; /// - /// Creates a tensor filled with random integers sampled from a uniform distribution over the - /// interval [MinValue, MaxValue). + /// Creates a tensor filled with random integers sampled from a uniform distribution over the interval [MinValue, MaxValue). /// public IObservable Process() { @@ -65,7 +63,7 @@ public IObservable Process() } /// - /// Generates an observable sequence of tensors filled with random integers and uses the input generator. + /// Creates tensors filled with random integers and uses the input generator. /// /// /// @@ -79,7 +77,7 @@ public IObservable Process(IObservable source) } /// - /// Generates an observable sequence of tensors filled with random integers for each element of the input sequence. + /// Creates tensors filled with random integers for each element of the input sequence. /// /// /// @@ -89,12 +87,12 @@ public IObservable Process(IObservable source) } /// - /// Generates an observable sequence of tensors filled with random integers for each element of the input sequence. + /// Creates tensors filled with random integers for each element of the input sequence. /// /// /// public IObservable Process(IObservable source) { - return source.Select(value => randint(MinValue, MaxValue, Size, dtype: Type, device: Device, generator: Generator)); + return source.Select(_ => randint(MinValue, MaxValue, Size, dtype: Type, device: Device, generator: Generator)); } } From 53edb46841ea96e70fe372b7f19bfddaed2cf993 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 12 Dec 2025 14:27:10 +0000 Subject: [PATCH 15/16] Updated XML documentation --- src/Bonsai.ML.Torch/Distributions/Bernoulli.cs | 6 +++--- src/Bonsai.ML.Torch/Distributions/Beta.cs | 6 +++--- src/Bonsai.ML.Torch/Distributions/Binomial.cs | 8 ++++---- src/Bonsai.ML.Torch/Distributions/Categorical.cs | 10 +++++----- src/Bonsai.ML.Torch/Distributions/Cauchy.cs | 6 +++--- 5 files changed, 18 insertions(+), 18 deletions(-) diff --git a/src/Bonsai.ML.Torch/Distributions/Bernoulli.cs b/src/Bonsai.ML.Torch/Distributions/Bernoulli.cs index b81dec04..e048aeea 100644 --- a/src/Bonsai.ML.Torch/Distributions/Bernoulli.cs +++ b/src/Bonsai.ML.Torch/Distributions/Bernoulli.cs @@ -16,11 +16,11 @@ namespace Bonsai.ML.Torch.Distributions; public class Bernoulli : IScalarTypeProvider { /// - /// Event probabilities p in [0, 1]. Can be a scalar or a tensor; the shape determines the batch/event shape. + /// The event probabilities in [0, 1]. Can be a scalar or a tensor; the shape determines the batch/event shape. /// [XmlIgnore] [TypeConverter(typeof(TensorConverter))] - [Description("Event probabilities p in [0, 1]. Can be a scalar or a tensor; shape sets the batch/event shape of the distribution.")] + [Description("The event probabilities in [0, 1]. Can be a scalar or a tensor; shape sets the batch/event shape of the distribution.")] public Tensor Probabilities { get; set; } = null; /// @@ -71,4 +71,4 @@ public string ProbabilitiesXml { return source.Select(_ => distributions.Bernoulli(Probabilities)); } -} \ No newline at end of file +} diff --git a/src/Bonsai.ML.Torch/Distributions/Beta.cs b/src/Bonsai.ML.Torch/Distributions/Beta.cs index 9abfcdb8..00e2e99f 100644 --- a/src/Bonsai.ML.Torch/Distributions/Beta.cs +++ b/src/Bonsai.ML.Torch/Distributions/Beta.cs @@ -8,7 +8,7 @@ namespace Bonsai.ML.Torch.Distributions; /// -/// Creates a Beta probability distribution parameterized by two concentration parameters (alpha, beta). +/// Represents an operator that creates a beta probability distribution parameterized by two concentration parameters (alpha, beta). /// [Combinator] [Description("Creates a Beta distribution with concentration parameters (alpha, beta).")] @@ -17,7 +17,7 @@ namespace Bonsai.ML.Torch.Distributions; public class Beta : IScalarTypeProvider { /// - /// Concentration parameter alpha (> 0). Can be a scalar or tensor; the shape determines the batch/event shape. + /// The first concentration parameter alpha (> 0). Can be a scalar or tensor; the shape determines the batch/event shape. /// [XmlIgnore] [TypeConverter(typeof(TensorConverter))] @@ -92,4 +92,4 @@ public string Concentration0Xml { return source.Select(_ => distributions.Beta(Concentration1, Concentration0)); } -} \ No newline at end of file +} diff --git a/src/Bonsai.ML.Torch/Distributions/Binomial.cs b/src/Bonsai.ML.Torch/Distributions/Binomial.cs index 3a23ad5b..f6456762 100644 --- a/src/Bonsai.ML.Torch/Distributions/Binomial.cs +++ b/src/Bonsai.ML.Torch/Distributions/Binomial.cs @@ -16,11 +16,11 @@ namespace Bonsai.ML.Torch.Distributions; public class Binomial : IScalarTypeProvider { /// - /// Number of trials (non-negative). Can be a scalar or tensor. If tensor, values should be non-negative integers. + /// The number of trials (non-negative). Can be a scalar or tensor. If it is a tensor, values should be non-negative integers. /// [XmlIgnore] [TypeConverter(typeof(TensorConverter))] - [Description("Number of trials (non-negative). Can be a scalar or tensor.")] + [Description("The number of trials (non-negative). Can be a scalar or tensor. If it is a tensor, values should be non-negative integers.")] public Tensor Count { get; set; } = null; /// @@ -40,7 +40,7 @@ public string CountXml /// [XmlIgnore] [TypeConverter(typeof(TensorConverter))] - [Description("Probability of success in [0, 1]. Can be a scalar or tensor; shape should broadcastable with Count.")] + [Description("Probability of success in [0, 1]. Can be a scalar or tensor; the shape should be broadcastable to Count.")] public Tensor Probabilities { get; set; } = null; /// @@ -91,4 +91,4 @@ public string ProbabilitiesXml { return source.Select(_ => distributions.Binomial(Count, Probabilities)); } -} \ No newline at end of file +} diff --git a/src/Bonsai.ML.Torch/Distributions/Categorical.cs b/src/Bonsai.ML.Torch/Distributions/Categorical.cs index f25875d8..167b8fec 100644 --- a/src/Bonsai.ML.Torch/Distributions/Categorical.cs +++ b/src/Bonsai.ML.Torch/Distributions/Categorical.cs @@ -8,20 +8,20 @@ namespace Bonsai.ML.Torch.Distributions; /// -/// Creates a Categorical (discrete) distribution over classes given event probabilities. +/// Creates a categorical (discrete) distribution over classes given event probabilities. /// [Combinator] -[Description("Creates a Categorical distribution with class probabilities and emits a TorchSharp distribution module.")] +[Description("Creates a categorical (discrete) distribution over classes given event probabilities.")] [WorkflowElementCategory(ElementCategory.Source)] [TypeConverter(typeof(TensorOperatorConverter))] public class Categorical : IScalarTypeProvider { /// - /// Class probabilities along the last dimension. Values must be non-negative and typically sum to 1 per row. Can be a 1D vector or higher-rank tensor for batched distributions. + /// The class probabilities. Values must be non-negative and typically sum to 1 per row. Can be a 1D vector or higher-rank tensor for batched distributions. /// [XmlIgnore] [TypeConverter(typeof(TensorConverter))] - [Description("Class probabilities along the last dimension (non-negative, typically sum to 1 per row). Supports batching.")] + [Description("The class probabilities. Values must be non-negative and typically sum to 1 per row. Can be a 1D vector or higher-rank tensor for batched distributions.")] public Tensor Probabilities { get; set; } = null; /// @@ -72,4 +72,4 @@ public string ProbabilitiesXml { return source.Select(_ => distributions.Categorical(Probabilities)); } -} \ No newline at end of file +} diff --git a/src/Bonsai.ML.Torch/Distributions/Cauchy.cs b/src/Bonsai.ML.Torch/Distributions/Cauchy.cs index 47426854..76074042 100644 --- a/src/Bonsai.ML.Torch/Distributions/Cauchy.cs +++ b/src/Bonsai.ML.Torch/Distributions/Cauchy.cs @@ -17,11 +17,11 @@ namespace Bonsai.ML.Torch.Distributions; public class Cauchy : IScalarTypeProvider { /// - /// Location parameter. Can be a scalar or tensor; shape determines the batch/event shape. + /// The location parameter. Can be a scalar or tensor; shape determines the batch/event shape. /// [XmlIgnore] [TypeConverter(typeof(TensorConverter))] - [Description("Location parameter. Can be a scalar or tensor; supports batching.")] + [Description("The location parameter. Can be a scalar or tensor; supports batching.")] public Tensor Locations { get; set; } = null; /// @@ -93,4 +93,4 @@ public string ScalesXml { return source.Select(_ => distributions.Cauchy(Locations, Scales)); } -} \ No newline at end of file +} From 3fcc379b9310c97dac24ad656fc128851c6b012f Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 11 Mar 2026 14:33:59 +0000 Subject: [PATCH 16/16] Renamed reparameterized sample to rsample for better alignment with torch --- .../Distributions/{ReparametrizedSample.cs => RSample.cs} | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) rename src/Bonsai.ML.Torch/Distributions/{ReparametrizedSample.cs => RSample.cs} (95%) diff --git a/src/Bonsai.ML.Torch/Distributions/ReparametrizedSample.cs b/src/Bonsai.ML.Torch/Distributions/RSample.cs similarity index 95% rename from src/Bonsai.ML.Torch/Distributions/ReparametrizedSample.cs rename to src/Bonsai.ML.Torch/Distributions/RSample.cs index faff9cbb..1e067164 100644 --- a/src/Bonsai.ML.Torch/Distributions/ReparametrizedSample.cs +++ b/src/Bonsai.ML.Torch/Distributions/RSample.cs @@ -1,4 +1,4 @@ -using System; +using System; using System.Reactive.Linq; using System.ComponentModel; using static TorchSharp.torch; @@ -12,7 +12,7 @@ namespace Bonsai.ML.Torch.Distributions; [Combinator] [Description("Generates reparameterized samples from the input distribution.")] [WorkflowElementCategory(ElementCategory.Transform)] -public class ReparametrizedSample +public class RSample { /// /// The shape of the samples to generate. @@ -29,4 +29,4 @@ public IObservable Process(IObservable source) { return source.Select(distribution => distribution.rsample(SampleShape)); } -} \ No newline at end of file +}