Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
f9ea88e
Added Cdf and InverseCdf classes for probability distributions
ncguilbeault Nov 1, 2025
08f73b7
Added a `Generator` class to create a reproducible RNG with specified…
ncguilbeault Jul 31, 2025
09b3488
Added a collection of probability distributions
ncguilbeault Jul 31, 2025
b966165
Added a class for seeding a torch compatible RNG
ncguilbeault Aug 1, 2025
f174c2c
Began adding a number of useful distributions and functionality to th…
ncguilbeault Aug 6, 2025
a16e6ed
Created a dedicated namespace for working with random numbers and add…
ncguilbeault Oct 8, 2025
d106ff6
Added overload to generator to support emitting generators based on e…
ncguilbeault Oct 8, 2025
4227dfc
Add support for using generator from input sequence and ignore in XML…
ncguilbeault Oct 8, 2025
832fe72
Updated `Distributions` with doc strings and container base
ncguilbeault Nov 1, 2025
50137cc
Added doc strings and made to inherit from tensor container base
ncguilbeault Nov 1, 2025
5d885b6
Refactored `Generator` documentation to suggest creating an individua…
ncguilbeault Nov 1, 2025
62dd824
Renamed distributions for better consistency
ncguilbeault Nov 3, 2025
0f223af
Refactored distribution classes with TensorOperatorConverter
ncguilbeault Dec 10, 2025
8c09274
Refactored classes in the random namespace
ncguilbeault Dec 10, 2025
53edb46
Updated XML documentation
ncguilbeault Dec 12, 2025
3fcc379
Renamed reparameterized sample to rsample for better alignment with t…
ncguilbeault Mar 11, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 74 additions & 0 deletions src/Bonsai.ML.Torch/Distributions/Bernoulli.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
using System;
using System.ComponentModel;
using System.Reactive.Linq;
using System.Xml.Serialization;
using static TorchSharp.torch;

namespace Bonsai.ML.Torch.Distributions;

/// <summary>
/// Represents an operator that creates a Bernoulli probability distribution parameterized by event probabilities.
/// </summary>
[Combinator]
[Description("Creates a Bernoulli distribution with event probabilities and emits a TorchSharp distribution module.")]
[WorkflowElementCategory(ElementCategory.Source)]
[TypeConverter(typeof(TensorOperatorConverter))]
public class Bernoulli : IScalarTypeProvider
{
/// <summary>
/// The event probabilities in [0, 1]. Can be a scalar or a tensor; the shape determines the batch/event shape.
/// </summary>
[XmlIgnore]
[TypeConverter(typeof(TensorConverter))]
[Description("The event probabilities in [0, 1]. Can be a scalar or a tensor; shape sets the batch/event shape of the distribution.")]
public Tensor Probabilities { get; set; } = null;

/// <summary>
/// The values of the probabilities in XML string format.
/// </summary>
[Browsable(false)]
[XmlElement(nameof(Probabilities))]
[EditorBrowsable(EditorBrowsableState.Never)]
public string ProbabilitiesXml
{
get => TensorConverter.ConvertToString(Probabilities, Type);
set => Probabilities = TensorConverter.ConvertFromString(value, Type);
}

/// <summary>
/// Gets or sets the data type of the tensor elements.
/// </summary>
[Description("The data type of the tensor elements.")]
[TypeConverter(typeof(ScalarTypeConverter))]
public ScalarType Type { get; set; } = ScalarType.Float32;

/// <summary>
/// Creates a Bernoulli distribution.
/// </summary>
/// <returns></returns>
public IObservable<TorchSharp.Modules.Bernoulli> Process()
{
return Observable.Return(distributions.Bernoulli(Probabilities));
}

/// <summary>
/// Creates a Bernoulli distribution using the incoming RNG Generator.
/// </summary>
/// <param name="source"></param>
/// <returns></returns>
public IObservable<TorchSharp.Modules.Bernoulli> Process(IObservable<Generator> source)
{
return source.Select(generator => distributions.Bernoulli(Probabilities, generator: generator));
}

/// <summary>
/// For each element of the source stream, emits a Bernoulli distribution.
/// </summary>
/// <typeparam name="T"></typeparam>
/// <param name="source"></param>
/// <returns></returns>
public IObservable<TorchSharp.Modules.Bernoulli> Process<T>(IObservable<T> source)
{
return source.Select(_ => distributions.Bernoulli(Probabilities));
}
}
95 changes: 95 additions & 0 deletions src/Bonsai.ML.Torch/Distributions/Beta.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
using System;
using System.ComponentModel;
using System.Reactive.Linq;
using System.Xml.Serialization;
using TorchSharp;
using static TorchSharp.torch;

namespace Bonsai.ML.Torch.Distributions;

/// <summary>
/// Represents an operator that creates a beta probability distribution parameterized by two concentration parameters (alpha, beta).
/// </summary>
[Combinator]
[Description("Creates a Beta distribution with concentration parameters (alpha, beta).")]
[WorkflowElementCategory(ElementCategory.Source)]
[TypeConverter(typeof(TensorOperatorConverter))]
public class Beta : IScalarTypeProvider
{
/// <summary>
/// The first concentration parameter alpha (> 0). Can be a scalar or tensor; the shape determines the batch/event shape.
/// </summary>
[XmlIgnore]
[TypeConverter(typeof(TensorConverter))]
[Description("Concentration alpha (> 0). Can be a scalar or tensor; shape sets the batch/event shape of the distribution.")]
public Tensor Concentration1 { get; set; } = null;

/// <summary>
/// The values of concentration 1 in XML string format.
/// </summary>
[Browsable(false)]
[XmlElement(nameof(Concentration1))]
[EditorBrowsable(EditorBrowsableState.Never)]
public string Concentration1Xml
{
get => TensorConverter.ConvertToString(Concentration1, Type);
set => Concentration1 = TensorConverter.ConvertFromString(value, Type);
}

/// <summary>
/// Concentration parameter beta (> 0). Can be a scalar or tensor; the shape determines the batch/event shape.
/// </summary>
[XmlIgnore]
[TypeConverter(typeof(TensorConverter))]
[Description("Concentration beta (> 0). Can be a scalar or tensor; shape sets the batch/event shape of the distribution.")]
public Tensor Concentration0 { get; set; } = null;

/// <summary>
/// The values of concentration 0 in XML string format.
/// </summary>
[Browsable(false)]
[XmlElement(nameof(Concentration0))]
[EditorBrowsable(EditorBrowsableState.Never)]
public string Concentration0Xml
{
get => TensorConverter.ConvertToString(Concentration0, Type);
set => Concentration0 = TensorConverter.ConvertFromString(value, Type);
}

/// <summary>
/// Gets or sets the data type of the tensor elements.
/// </summary>
[Description("The data type of the tensor elements.")]
[TypeConverter(typeof(ScalarTypeConverter))]
public ScalarType Type { get; set; } = ScalarType.Float32;

/// <summary>
/// Creates a Beta distribution.
/// </summary>
/// <returns></returns>
public IObservable<TorchSharp.Modules.Beta> Process()
{
return Observable.Return(distributions.Beta(Concentration1, Concentration0));
}

/// <summary>
/// Creates a Beta distribution using the incoming RNG generator.
/// </summary>
/// <param name="source"></param>
/// <returns></returns>
public IObservable<TorchSharp.Modules.Beta> Process(IObservable<Generator> source)
{
return source.Select(generator => distributions.Beta(Concentration1, Concentration0, generator: generator));
}

/// <summary>
/// For each element of the source stream, emits a Beta distribution.
/// </summary>
/// <typeparam name="T"></typeparam>
/// <param name="source"></param>
/// <returns></returns>
public IObservable<TorchSharp.Modules.Beta> Process<T>(IObservable<T> source)
{
return source.Select(_ => distributions.Beta(Concentration1, Concentration0));
}
}
94 changes: 94 additions & 0 deletions src/Bonsai.ML.Torch/Distributions/Binomial.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
using System;
using System.ComponentModel;
using System.Reactive.Linq;
using System.Xml.Serialization;
using static TorchSharp.torch;

namespace Bonsai.ML.Torch.Distributions;

/// <summary>
/// Creates a Binomial probability distribution with a given number of trials and success probability.
/// </summary>
[Combinator]
[Description("Creates a Binomial distribution with count (number of trials) and probability of success.")]
[WorkflowElementCategory(ElementCategory.Source)]
[TypeConverter(typeof(TensorOperatorConverter))]
public class Binomial : IScalarTypeProvider
{
/// <summary>
/// The number of trials (non-negative). Can be a scalar or tensor. If it is a tensor, values should be non-negative integers.
/// </summary>
[XmlIgnore]
[TypeConverter(typeof(TensorConverter))]
[Description("The number of trials (non-negative). Can be a scalar or tensor. If it is a tensor, values should be non-negative integers.")]
public Tensor Count { get; set; } = null;

/// <summary>
/// The values of count in XML string format.
/// </summary>
[Browsable(false)]
[XmlElement(nameof(Count))]
[EditorBrowsable(EditorBrowsableState.Never)]
public string CountXml
{
get => TensorConverter.ConvertToString(Count, Type);
set => Count = TensorConverter.ConvertFromString(value, Type);
}

/// <summary>
/// Probability of success p in [0, 1]. Can be a scalar or tensor; the shape should be broadcastable to <see cref="Count"/>.
/// </summary>
[XmlIgnore]
[TypeConverter(typeof(TensorConverter))]
[Description("Probability of success in [0, 1]. Can be a scalar or tensor; the shape should be broadcastable to Count.")]
public Tensor Probabilities { get; set; } = null;

/// <summary>
/// The values of probabilities in XML string format.
/// </summary>
[Browsable(false)]
[XmlElement(nameof(Probabilities))]
[EditorBrowsable(EditorBrowsableState.Never)]
public string ProbabilitiesXml
{
get => TensorConverter.ConvertToString(Probabilities, Type);
set => Probabilities = TensorConverter.ConvertFromString(value, Type);
}

/// <summary>
/// Gets or sets the data type of the tensor elements.
/// </summary>
[Description("The data type of the tensor elements.")]
[TypeConverter(typeof(ScalarTypeConverter))]
public ScalarType Type { get; set; } = ScalarType.Float32;

/// <summary>
/// Creates a Binomial distribution.
/// </summary>
/// <returns></returns>
public IObservable<TorchSharp.Modules.Binomial> Process()
{
return Observable.Return(distributions.Binomial(Count, Probabilities));
}

/// <summary>
/// Creates a Binomial distribution for each incoming RNG generator.
/// </summary>
/// <param name="source"></param>
/// <returns></returns>
public IObservable<TorchSharp.Modules.Binomial> Process(IObservable<Generator> source)
{
return source.Select(generator => distributions.Binomial(Count, Probabilities, generator: generator));
}

/// <summary>
/// For each element of the source stream, emits a Binomial distribution.
/// </summary>
/// <typeparam name="T"></typeparam>
/// <param name="source"></param>
/// <returns></returns>
public IObservable<TorchSharp.Modules.Binomial> Process<T>(IObservable<T> source)
{
return source.Select(_ => distributions.Binomial(Count, Probabilities));
}
}
75 changes: 75 additions & 0 deletions src/Bonsai.ML.Torch/Distributions/Categorical.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
using System;
using System.ComponentModel;
using System.Reactive.Linq;
using System.Xml.Serialization;
using TorchSharp;
using static TorchSharp.torch;

namespace Bonsai.ML.Torch.Distributions;

/// <summary>
/// Creates a categorical (discrete) distribution over classes given event probabilities.
/// </summary>
[Combinator]
[Description("Creates a categorical (discrete) distribution over classes given event probabilities.")]
[WorkflowElementCategory(ElementCategory.Source)]
[TypeConverter(typeof(TensorOperatorConverter))]
public class Categorical : IScalarTypeProvider
{
/// <summary>
/// The class probabilities. Values must be non-negative and typically sum to 1 per row. Can be a 1D vector or higher-rank tensor for batched distributions.
/// </summary>
[XmlIgnore]
[TypeConverter(typeof(TensorConverter))]
[Description("The class probabilities. Values must be non-negative and typically sum to 1 per row. Can be a 1D vector or higher-rank tensor for batched distributions.")]
public Tensor Probabilities { get; set; } = null;

/// <summary>
/// The values of probabilities in XML string format.
/// </summary>
[Browsable(false)]
[XmlElement(nameof(Probabilities))]
[EditorBrowsable(EditorBrowsableState.Never)]
public string ProbabilitiesXml
{
get => TensorConverter.ConvertToString(Probabilities, Type);
set => Probabilities = TensorConverter.ConvertFromString(value, Type);
}

/// <summary>
/// Gets or sets the data type of the tensor elements.
/// </summary>
[Description("The data type of the tensor elements.")]
[TypeConverter(typeof(ScalarTypeConverter))]
public ScalarType Type { get; set; } = ScalarType.Float32;

/// <summary>
/// Creates a categorical distribution.
/// </summary>
/// <returns></returns>
public IObservable<TorchSharp.Modules.Categorical> Process()
{
return Observable.Return(distributions.Categorical(Probabilities));
}

/// <summary>
/// Creates a categorical distribution for each incoming RNG generator.
/// </summary>
/// <param name="source"></param>
/// <returns></returns>
public IObservable<TorchSharp.Modules.Categorical> Process(IObservable<Generator> source)
{
return source.Select(generator => distributions.Categorical(Probabilities, generator: generator));
}

/// <summary>
/// For each element of the source stream, emits a categorical distribution.
/// </summary>
/// <typeparam name="T"></typeparam>
/// <param name="source"></param>
/// <returns></returns>
public IObservable<TorchSharp.Modules.Categorical> Process<T>(IObservable<T> source)
{
return source.Select(_ => distributions.Categorical(Probabilities));
}
}
Loading
Loading