diff --git a/src/Bonsai.ML.Torch/Distributions/Bernoulli.cs b/src/Bonsai.ML.Torch/Distributions/Bernoulli.cs new file mode 100644 index 00000000..e048aeea --- /dev/null +++ b/src/Bonsai.ML.Torch/Distributions/Bernoulli.cs @@ -0,0 +1,74 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using System.Xml.Serialization; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.Distributions; + +/// +/// Represents an operator that creates a Bernoulli probability distribution parameterized by event probabilities. +/// +[Combinator] +[Description("Creates a Bernoulli distribution with event probabilities and emits a TorchSharp distribution module.")] +[WorkflowElementCategory(ElementCategory.Source)] +[TypeConverter(typeof(TensorOperatorConverter))] +public class Bernoulli : IScalarTypeProvider +{ + /// + /// The event probabilities in [0, 1]. Can be a scalar or a tensor; the shape determines the batch/event shape. + /// + [XmlIgnore] + [TypeConverter(typeof(TensorConverter))] + [Description("The event probabilities in [0, 1]. Can be a scalar or a tensor; shape sets the batch/event shape of the distribution.")] + public Tensor Probabilities { get; set; } = null; + + /// + /// The values of the probabilities in XML string format. + /// + [Browsable(false)] + [XmlElement(nameof(Probabilities))] + [EditorBrowsable(EditorBrowsableState.Never)] + public string ProbabilitiesXml + { + get => TensorConverter.ConvertToString(Probabilities, Type); + set => Probabilities = TensorConverter.ConvertFromString(value, Type); + } + + /// + /// Gets or sets the data type of the tensor elements. + /// + [Description("The data type of the tensor elements.")] + [TypeConverter(typeof(ScalarTypeConverter))] + public ScalarType Type { get; set; } = ScalarType.Float32; + + /// + /// Creates a Bernoulli distribution. + /// + /// + public IObservable Process() + { + return Observable.Return(distributions.Bernoulli(Probabilities)); + } + + /// + /// Creates a Bernoulli distribution using the incoming RNG Generator. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(generator => distributions.Bernoulli(Probabilities, generator: generator)); + } + + /// + /// For each element of the source stream, emits a Bernoulli distribution. + /// + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(_ => distributions.Bernoulli(Probabilities)); + } +} diff --git a/src/Bonsai.ML.Torch/Distributions/Beta.cs b/src/Bonsai.ML.Torch/Distributions/Beta.cs new file mode 100644 index 00000000..00e2e99f --- /dev/null +++ b/src/Bonsai.ML.Torch/Distributions/Beta.cs @@ -0,0 +1,95 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using System.Xml.Serialization; +using TorchSharp; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.Distributions; + +/// +/// Represents an operator that creates a beta probability distribution parameterized by two concentration parameters (alpha, beta). +/// +[Combinator] +[Description("Creates a Beta distribution with concentration parameters (alpha, beta).")] +[WorkflowElementCategory(ElementCategory.Source)] +[TypeConverter(typeof(TensorOperatorConverter))] +public class Beta : IScalarTypeProvider +{ + /// + /// The first concentration parameter alpha (> 0). Can be a scalar or tensor; the shape determines the batch/event shape. + /// + [XmlIgnore] + [TypeConverter(typeof(TensorConverter))] + [Description("Concentration alpha (> 0). Can be a scalar or tensor; shape sets the batch/event shape of the distribution.")] + public Tensor Concentration1 { get; set; } = null; + + /// + /// The values of concentration 1 in XML string format. + /// + [Browsable(false)] + [XmlElement(nameof(Concentration1))] + [EditorBrowsable(EditorBrowsableState.Never)] + public string Concentration1Xml + { + get => TensorConverter.ConvertToString(Concentration1, Type); + set => Concentration1 = TensorConverter.ConvertFromString(value, Type); + } + + /// + /// Concentration parameter beta (> 0). Can be a scalar or tensor; the shape determines the batch/event shape. + /// + [XmlIgnore] + [TypeConverter(typeof(TensorConverter))] + [Description("Concentration beta (> 0). Can be a scalar or tensor; shape sets the batch/event shape of the distribution.")] + public Tensor Concentration0 { get; set; } = null; + + /// + /// The values of concentration 0 in XML string format. + /// + [Browsable(false)] + [XmlElement(nameof(Concentration0))] + [EditorBrowsable(EditorBrowsableState.Never)] + public string Concentration0Xml + { + get => TensorConverter.ConvertToString(Concentration0, Type); + set => Concentration0 = TensorConverter.ConvertFromString(value, Type); + } + + /// + /// Gets or sets the data type of the tensor elements. + /// + [Description("The data type of the tensor elements.")] + [TypeConverter(typeof(ScalarTypeConverter))] + public ScalarType Type { get; set; } = ScalarType.Float32; + + /// + /// Creates a Beta distribution. + /// + /// + public IObservable Process() + { + return Observable.Return(distributions.Beta(Concentration1, Concentration0)); + } + + /// + /// Creates a Beta distribution using the incoming RNG generator. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(generator => distributions.Beta(Concentration1, Concentration0, generator: generator)); + } + + /// + /// For each element of the source stream, emits a Beta distribution. + /// + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(_ => distributions.Beta(Concentration1, Concentration0)); + } +} diff --git a/src/Bonsai.ML.Torch/Distributions/Binomial.cs b/src/Bonsai.ML.Torch/Distributions/Binomial.cs new file mode 100644 index 00000000..f6456762 --- /dev/null +++ b/src/Bonsai.ML.Torch/Distributions/Binomial.cs @@ -0,0 +1,94 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using System.Xml.Serialization; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.Distributions; + +/// +/// Creates a Binomial probability distribution with a given number of trials and success probability. +/// +[Combinator] +[Description("Creates a Binomial distribution with count (number of trials) and probability of success.")] +[WorkflowElementCategory(ElementCategory.Source)] +[TypeConverter(typeof(TensorOperatorConverter))] +public class Binomial : IScalarTypeProvider +{ + /// + /// The number of trials (non-negative). Can be a scalar or tensor. If it is a tensor, values should be non-negative integers. + /// + [XmlIgnore] + [TypeConverter(typeof(TensorConverter))] + [Description("The number of trials (non-negative). Can be a scalar or tensor. If it is a tensor, values should be non-negative integers.")] + public Tensor Count { get; set; } = null; + + /// + /// The values of count in XML string format. + /// + [Browsable(false)] + [XmlElement(nameof(Count))] + [EditorBrowsable(EditorBrowsableState.Never)] + public string CountXml + { + get => TensorConverter.ConvertToString(Count, Type); + set => Count = TensorConverter.ConvertFromString(value, Type); + } + + /// + /// Probability of success p in [0, 1]. Can be a scalar or tensor; the shape should be broadcastable to . + /// + [XmlIgnore] + [TypeConverter(typeof(TensorConverter))] + [Description("Probability of success in [0, 1]. Can be a scalar or tensor; the shape should be broadcastable to Count.")] + public Tensor Probabilities { get; set; } = null; + + /// + /// The values of probabilities in XML string format. + /// + [Browsable(false)] + [XmlElement(nameof(Probabilities))] + [EditorBrowsable(EditorBrowsableState.Never)] + public string ProbabilitiesXml + { + get => TensorConverter.ConvertToString(Probabilities, Type); + set => Probabilities = TensorConverter.ConvertFromString(value, Type); + } + + /// + /// Gets or sets the data type of the tensor elements. + /// + [Description("The data type of the tensor elements.")] + [TypeConverter(typeof(ScalarTypeConverter))] + public ScalarType Type { get; set; } = ScalarType.Float32; + + /// + /// Creates a Binomial distribution. + /// + /// + public IObservable Process() + { + return Observable.Return(distributions.Binomial(Count, Probabilities)); + } + + /// + /// Creates a Binomial distribution for each incoming RNG generator. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(generator => distributions.Binomial(Count, Probabilities, generator: generator)); + } + + /// + /// For each element of the source stream, emits a Binomial distribution. + /// + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(_ => distributions.Binomial(Count, Probabilities)); + } +} diff --git a/src/Bonsai.ML.Torch/Distributions/Categorical.cs b/src/Bonsai.ML.Torch/Distributions/Categorical.cs new file mode 100644 index 00000000..167b8fec --- /dev/null +++ b/src/Bonsai.ML.Torch/Distributions/Categorical.cs @@ -0,0 +1,75 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using System.Xml.Serialization; +using TorchSharp; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.Distributions; + +/// +/// Creates a categorical (discrete) distribution over classes given event probabilities. +/// +[Combinator] +[Description("Creates a categorical (discrete) distribution over classes given event probabilities.")] +[WorkflowElementCategory(ElementCategory.Source)] +[TypeConverter(typeof(TensorOperatorConverter))] +public class Categorical : IScalarTypeProvider +{ + /// + /// The class probabilities. Values must be non-negative and typically sum to 1 per row. Can be a 1D vector or higher-rank tensor for batched distributions. + /// + [XmlIgnore] + [TypeConverter(typeof(TensorConverter))] + [Description("The class probabilities. Values must be non-negative and typically sum to 1 per row. Can be a 1D vector or higher-rank tensor for batched distributions.")] + public Tensor Probabilities { get; set; } = null; + + /// + /// The values of probabilities in XML string format. + /// + [Browsable(false)] + [XmlElement(nameof(Probabilities))] + [EditorBrowsable(EditorBrowsableState.Never)] + public string ProbabilitiesXml + { + get => TensorConverter.ConvertToString(Probabilities, Type); + set => Probabilities = TensorConverter.ConvertFromString(value, Type); + } + + /// + /// Gets or sets the data type of the tensor elements. + /// + [Description("The data type of the tensor elements.")] + [TypeConverter(typeof(ScalarTypeConverter))] + public ScalarType Type { get; set; } = ScalarType.Float32; + + /// + /// Creates a categorical distribution. + /// + /// + public IObservable Process() + { + return Observable.Return(distributions.Categorical(Probabilities)); + } + + /// + /// Creates a categorical distribution for each incoming RNG generator. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(generator => distributions.Categorical(Probabilities, generator: generator)); + } + + /// + /// For each element of the source stream, emits a categorical distribution. + /// + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(_ => distributions.Categorical(Probabilities)); + } +} diff --git a/src/Bonsai.ML.Torch/Distributions/Cauchy.cs b/src/Bonsai.ML.Torch/Distributions/Cauchy.cs new file mode 100644 index 00000000..76074042 --- /dev/null +++ b/src/Bonsai.ML.Torch/Distributions/Cauchy.cs @@ -0,0 +1,96 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using System.Xml.Serialization; +using TorchSharp; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.Distributions; + +/// +/// Represents an operator that creates a Cauchy (Lorentz) distribution parameterized by location and scale. +/// +[Combinator] +[Description("Creates a Cauchy distribution with the specified location and scale parameters.")] +[WorkflowElementCategory(ElementCategory.Source)] +[TypeConverter(typeof(TensorOperatorConverter))] +public class Cauchy : IScalarTypeProvider +{ + /// + /// The location parameter. Can be a scalar or tensor; shape determines the batch/event shape. + /// + [XmlIgnore] + [TypeConverter(typeof(TensorConverter))] + [Description("The location parameter. Can be a scalar or tensor; supports batching.")] + public Tensor Locations { get; set; } = null; + + /// + /// The values of the locations in XML string format. + /// + [Browsable(false)] + [XmlElement(nameof(Locations))] + [EditorBrowsable(EditorBrowsableState.Never)] + public string LocationsXml + { + get => TensorConverter.ConvertToString(Locations, Type); + set => Locations = TensorConverter.ConvertFromString(value, Type); + } + + /// + /// Scale parameter (> 0). Can be a scalar or tensor; must be broadcastable with . + /// + [XmlIgnore] + [TypeConverter(typeof(TensorConverter))] + [Description("Scale parameter (> 0). Can be a scalar or tensor; must broadcast with Locations.")] + public Tensor Scales { get; set; } = null; + + /// + /// The values of the scales in XML string format. + /// + [Browsable(false)] + [XmlElement(nameof(Scales))] + [EditorBrowsable(EditorBrowsableState.Never)] + public string ScalesXml + { + get => TensorConverter.ConvertToString(Scales, Type); + set => Scales = TensorConverter.ConvertFromString(value, Type); + } + + /// + /// Gets or sets the data type of the tensor elements. + /// + [Description("The data type of the tensor elements.")] + [TypeConverter(typeof(ScalarTypeConverter))] + public ScalarType Type { get; set; } = ScalarType.Float32; + + /// + /// Creates a distribution using the configured parameters. + /// + /// An observable that emits the constructed Cauchy distribution. + public IObservable Process() + { + return Observable.Return(distributions.Cauchy(Locations, Scales)); + } + + /// + /// Creates a distribution for each incoming RNG . + /// + /// Observable sequence of random generators to use. + /// An observable sequence of Cauchy distributions. + public IObservable Process(IObservable source) + { + return source.Select(generator => distributions.Cauchy(Locations, Scales, generator: generator)); + } + + /// + /// For each element of the source stream, emits a distribution + /// constructed from the configured parameters. + /// + /// The type of the triggering source sequence. + /// Trigger sequence; each element causes a new distribution to be emitted. + /// An observable sequence of Cauchy distributions. + public IObservable Process(IObservable source) + { + return source.Select(_ => distributions.Cauchy(Locations, Scales)); + } +} diff --git a/src/Bonsai.ML.Torch/Distributions/CumulativeDistributionFunction.cs b/src/Bonsai.ML.Torch/Distributions/CumulativeDistributionFunction.cs new file mode 100644 index 00000000..2831dfab --- /dev/null +++ b/src/Bonsai.ML.Torch/Distributions/CumulativeDistributionFunction.cs @@ -0,0 +1,54 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using System.Xml.Serialization; +using static TorchSharp.torch; +using static TorchSharp.torch.distributions; + +namespace Bonsai.ML.Torch.Distributions; + +/// +/// Represents an operator that creates a cumulative distribution function (CDF) from the given distribution. +/// +[Combinator] +[Description("Creates a cumulative distribution function (CDF) from the given distribution.")] +[WorkflowElementCategory(ElementCategory.Transform)] +public class CumulativeDistributionFunction +{ + /// + /// The input distribution. + /// + [XmlIgnore] + [Description("The input distribution.")] + public Distribution Distribution { get; set; } + + /// + /// Processes the input values to compute the CDF using the specified distribution. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(Distribution.cdf); + } + + /// + /// Processes the input tuples of distribution and values to compute the CDF. + /// + /// + /// + public IObservable Process(IObservable> source) + { + return source.Select((input) => input.Item1.cdf(input.Item2)); + } + + /// + /// Processes the input tuples of values and distribution to compute the CDF. + /// + /// + /// + public IObservable Process(IObservable> source) + { + return source.Select((input) => input.Item2.cdf(input.Item1)); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Distributions/Dirichlet.cs b/src/Bonsai.ML.Torch/Distributions/Dirichlet.cs new file mode 100644 index 00000000..7ad3e131 --- /dev/null +++ b/src/Bonsai.ML.Torch/Distributions/Dirichlet.cs @@ -0,0 +1,74 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using System.Xml.Serialization; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.Distributions; + +/// +/// Represents an operator that creates a Dirichlet probability distribution parameterized by concentration parameters. +/// +[Combinator] +[Description("Creates a Dirichlet distribution with concentration parameters.")] +[WorkflowElementCategory(ElementCategory.Source)] +[TypeConverter(typeof(TensorOperatorConverter))] +public class Dirichlet : IScalarTypeProvider +{ + /// + /// Concentration parameters (> 0). Can be a scalar or tensor; the shape determines the batch/event shape. + /// + [XmlIgnore] + [TypeConverter(typeof(TensorConverter))] + [Description("Concentration parameters (> 0). Can be a scalar or tensor; shape sets the batch/event shape of the distribution.")] + public Tensor Concentration { get; set; } = null; + + /// + /// The values of the concentration in XML string format. + /// + [Browsable(false)] + [XmlElement(nameof(Concentration))] + [EditorBrowsable(EditorBrowsableState.Never)] + public string ConcentrationXml + { + get => TensorConverter.ConvertToString(Concentration, Type); + set => Concentration = TensorConverter.ConvertFromString(value, Type); + } + + /// + /// Gets or sets the data type of the tensor elements. + /// + [Description("The data type of the tensor elements.")] + [TypeConverter(typeof(ScalarTypeConverter))] + public ScalarType Type { get; set; } = ScalarType.Float32; + + /// + /// Creates a distribution using the configured parameters. + /// + /// An observable that emits the constructed Dirichlet distribution. + public IObservable Process() + { + return Observable.Return(distributions.Dirichlet(Concentration)); + } + + /// + /// Creates a distribution for each incoming RNG . + /// + /// Observable sequence of random generators to use. + /// An observable sequence of Dirichlet distributions. + public IObservable Process(IObservable source) + { + return source.Select(generator => distributions.Dirichlet(Concentration, generator: generator)); + } + + /// + /// For each element of the source stream, emits a distribution constructed from the configured parameters. + /// + /// The type of the triggering source sequence. + /// Trigger sequence; each element causes a new distribution to be emitted. + /// An observable sequence of Dirichlet distributions. + public IObservable Process(IObservable source) + { + return source.Select(_ => distributions.Dirichlet(Concentration)); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Distributions/Exponential.cs b/src/Bonsai.ML.Torch/Distributions/Exponential.cs new file mode 100644 index 00000000..09f6f3fd --- /dev/null +++ b/src/Bonsai.ML.Torch/Distributions/Exponential.cs @@ -0,0 +1,75 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using System.Xml.Serialization; +using TorchSharp; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.Distributions; + +/// +/// Represents an operator that creates an exponential probability distribution parameterized by rate. +/// +[Combinator] +[Description("Creates an exponential distribution with the specified rate parameter.")] +[WorkflowElementCategory(ElementCategory.Source)] +[TypeConverter(typeof(TensorOperatorConverter))] +public class Exponential : IScalarTypeProvider +{ + /// + /// Rate parameter (> 0). Can be a scalar or tensor; the shape determines the batch/event shape. + /// + [XmlIgnore] + [TypeConverter(typeof(TensorConverter))] + [Description("Rate parameter (> 0). Can be a scalar or tensor; shape sets the batch/event shape of the distribution.")] + public Tensor Rate { get; set; } = null; + + /// + /// The values of the rates in XML string format. + /// + [Browsable(false)] + [XmlElement(nameof(Rate))] + [EditorBrowsable(EditorBrowsableState.Never)] + public string RateXml + { + get => TensorConverter.ConvertToString(Rate, Type); + set => Rate = TensorConverter.ConvertFromString(value, Type); + } + + /// + /// Gets or sets the data type of the tensor elements. + /// + [Description("The data type of the tensor elements.")] + [TypeConverter(typeof(ScalarTypeConverter))] + public ScalarType Type { get; set; } = ScalarType.Float32; + + /// + /// Creates a distribution using the configured parameters. + /// + /// An observable that emits the constructed Exponential distribution. + public IObservable Process() + { + return Observable.Return(distributions.Exponential(Rate)); + } + + /// + /// Creates a distribution for each incoming RNG . + /// + /// Observable sequence of random generators to use. + /// An observable sequence of Exponential distributions. + public IObservable Process(IObservable source) + { + return source.Select(generator => distributions.Exponential(Rate, generator: generator)); + } + + /// + /// For each element of the source stream, emits a distribution constructed from the configured parameters. + /// + /// The type of the triggering source sequence. + /// Trigger sequence; each element causes a new distribution to be emitted. + /// An observable sequence of Exponential distributions. + public IObservable Process(IObservable source) + { + return source.Select(_ => distributions.Exponential(Rate)); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Distributions/Gamma.cs b/src/Bonsai.ML.Torch/Distributions/Gamma.cs new file mode 100644 index 00000000..f3142eeb --- /dev/null +++ b/src/Bonsai.ML.Torch/Distributions/Gamma.cs @@ -0,0 +1,96 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using System.Xml.Serialization; +using TorchSharp; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.Distributions; + +/// +/// Represents an operator that creates a gamma probability distribution parameterized by concentration and rate. +/// +[Combinator] +[Description("Creates a gamma distribution with concentration and rate parameters.")] +[WorkflowElementCategory(ElementCategory.Source)] +[TypeConverter(typeof(TensorOperatorConverter))] +public class Gamma : IScalarTypeProvider +{ + /// + /// Concentration parameter (> 0). Can be a scalar or tensor; the shape determines the batch/event shape. + /// + [XmlIgnore] + [TypeConverter(typeof(TensorConverter))] + [Description("Concentration parameter (> 0). Can be a scalar or tensor; shape sets the batch/event shape of the distribution.")] + public Tensor Concentration { get; set; } = null; + + /// + /// The values of the concentration in XML string format. + /// + [Browsable(false)] + [XmlElement(nameof(Concentration))] + [EditorBrowsable(EditorBrowsableState.Never)] + public string ConcentrationXml + { + get => TensorConverter.ConvertToString(Concentration, Type); + set => Concentration = TensorConverter.ConvertFromString(value, Type); + } + + /// + /// Rate parameter (> 0). Can be a scalar or tensor; must be broadcastable with . + /// + [XmlIgnore] + [TypeConverter(typeof(TensorConverter))] + [Description("Rate parameter (> 0). Can be a scalar or tensor; must broadcast with Concentration.")] + public Tensor Rate { get; set; } = null; + + /// + /// The values of the rate in XML string format. + /// + [Browsable(false)] + [XmlElement(nameof(Rate))] + [EditorBrowsable(EditorBrowsableState.Never)] + public string RateXml + { + get => TensorConverter.ConvertToString(Rate, Type); + set => Rate = TensorConverter.ConvertFromString(value, Type); + } + + /// + /// Gets or sets the data type of the tensor elements. + /// + [Description("The data type of the tensor elements.")] + [TypeConverter(typeof(ScalarTypeConverter))] + public ScalarType Type { get; set; } = ScalarType.Float32; + + /// + /// Creates a distribution using the configured parameters. + /// + /// An observable that emits the constructed Gamma distribution. + public IObservable Process() + { + return Observable.Return(distributions.Gamma(Concentration, Rate)); + } + + /// + /// Creates a distribution for each incoming RNG . + /// + /// Observable sequence of random generators to use. + /// An observable sequence of Gamma distributions. + public IObservable Process(IObservable source) + { + return source.Select(generator => distributions.Gamma(Concentration, Rate, generator: generator)); + } + + /// + /// For each element of the source stream, emits a distribution + /// constructed from the configured parameters. + /// + /// The type of the triggering source sequence. + /// Trigger sequence; each element causes a new distribution to be emitted. + /// An observable sequence of Gamma distributions. + public IObservable Process(IObservable source) + { + return source.Select(_ => distributions.Gamma(Concentration, Rate)); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Distributions/Geometric.cs b/src/Bonsai.ML.Torch/Distributions/Geometric.cs new file mode 100644 index 00000000..9a910e90 --- /dev/null +++ b/src/Bonsai.ML.Torch/Distributions/Geometric.cs @@ -0,0 +1,76 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using System.Xml.Serialization; +using TorchSharp; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.Distributions; + +/// +/// Represents an operator that creates a geometric probability distribution parameterized by success probability. +/// +[Combinator] +[Description("Creates a geometric distribution with the specified success probability.")] +[WorkflowElementCategory(ElementCategory.Source)] +[TypeConverter(typeof(TensorOperatorConverter))] +public class Geometric : IScalarTypeProvider +{ + /// + /// Success probability in [0, 1]. Can be a scalar or tensor; the shape determines the batch/event shape. + /// + [XmlIgnore] + [TypeConverter(typeof(TensorConverter))] + [Description("Success probability in [0, 1]. Can be a scalar or tensor; shape sets the batch/event shape of the distribution.")] + public Tensor Probabilities { get; set; } = null; + + /// + /// The values of the probabilities in XML string format. + /// + [Browsable(false)] + [XmlElement(nameof(Probabilities))] + [EditorBrowsable(EditorBrowsableState.Never)] + public string ProbabilitiesXml + { + get => TensorConverter.ConvertToString(Probabilities, Type); + set => Probabilities = TensorConverter.ConvertFromString(value, Type); + } + + /// + /// Gets or sets the data type of the tensor elements. + /// + [Description("The data type of the tensor elements.")] + [TypeConverter(typeof(ScalarTypeConverter))] + public ScalarType Type { get; set; } = ScalarType.Float32; + + /// + /// Creates a distribution using the configured parameters. + /// + /// An observable that emits the constructed Geometric distribution. + public IObservable Process() + { + return Observable.Return(distributions.Geometric(Probabilities)); + } + + /// + /// Creates a distribution for each incoming RNG . + /// + /// Observable sequence of random generators to use. + /// An observable sequence of Geometric distributions. + public IObservable Process(IObservable source) + { + return source.Select(generator => distributions.Geometric(Probabilities, generator: generator)); + } + + /// + /// For each element of the source stream, emits a distribution + /// constructed from the configured parameters. + /// + /// The type of the triggering source sequence. + /// Trigger sequence; each element causes a new distribution to be emitted. + /// An observable sequence of Geometric distributions. + public IObservable Process(IObservable source) + { + return source.Select(_ => distributions.Geometric(Probabilities)); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Distributions/InverseCumulativeDistributionFunction.cs b/src/Bonsai.ML.Torch/Distributions/InverseCumulativeDistributionFunction.cs new file mode 100644 index 00000000..259336ed --- /dev/null +++ b/src/Bonsai.ML.Torch/Distributions/InverseCumulativeDistributionFunction.cs @@ -0,0 +1,53 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using System.Xml.Serialization; +using static TorchSharp.torch; +using static TorchSharp.torch.distributions; + +namespace Bonsai.ML.Torch.Distributions; + +/// +/// Creates an inverse cumulative distribution function (inverse CDF) from the input distribution. +/// +[Combinator] +[Description("Creates an inverse cumulative distribution function (inverse CDF) from the input distribution.")] +[WorkflowElementCategory(ElementCategory.Transform)] +public class InverseCumulativeDistributionFunction +{ + /// + /// The input distribution. + /// + [XmlIgnore] + public Distribution Distribution { get; set; } + + /// + /// Processes the input values to compute the inverse CDF using the specified distribution. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(Distribution.icdf); + } + + /// + /// Processes the input tuples of distribution and values to compute the inverse CDF. + /// + /// + /// + public IObservable Process(IObservable> source) + { + return source.Select((input) => input.Item1.icdf(input.Item2)); + } + + /// + /// Processes the input tuples of values and distribution to compute the inverse CDF. + /// + /// + /// + public IObservable Process(IObservable> source) + { + return source.Select((input) => input.Item2.icdf(input.Item1)); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Distributions/LogProbability.cs b/src/Bonsai.ML.Torch/Distributions/LogProbability.cs new file mode 100644 index 00000000..272df0d0 --- /dev/null +++ b/src/Bonsai.ML.Torch/Distributions/LogProbability.cs @@ -0,0 +1,53 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using System.Xml.Serialization; +using static TorchSharp.torch; +using static TorchSharp.torch.distributions; + +namespace Bonsai.ML.Torch.Distributions; + +/// +/// Represents an operator that computes the log probability of the input values under the specified distribution. +/// +[Combinator] +[Description("Computes the log probability of the input values under the specified distribution.")] +[WorkflowElementCategory(ElementCategory.Transform)] +public class LogProbability +{ + /// + /// The input distribution. + /// + [XmlIgnore] + public Distribution Distribution { get; set; } + + /// + /// Processes the input values to compute the log probability using the specified distribution. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(Distribution.log_prob); + } + + /// + /// Processes the input tuples of distribution and values to compute the log probability. + /// + /// + /// + public IObservable Process(IObservable> source) + { + return source.Select((input) => input.Item1.log_prob(input.Item2)); + } + + /// + /// Processes the input tuples of values and distribution to compute the log probability. + /// + /// + /// + public IObservable Process(IObservable> source) + { + return source.Select((input) => input.Item2.log_prob(input.Item1)); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Distributions/MultivariateNormal.cs b/src/Bonsai.ML.Torch/Distributions/MultivariateNormal.cs new file mode 100644 index 00000000..8bc97f45 --- /dev/null +++ b/src/Bonsai.ML.Torch/Distributions/MultivariateNormal.cs @@ -0,0 +1,94 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using System.Xml.Serialization; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.Distributions; + +/// +/// Represents an operator that creates a multivariate normal (Gaussian) distribution parameterized by mean vector and covariance matrix. +/// +[Combinator] +[Description("Creates a multivariate normal (Gaussian) distribution with mean vector and covariance matrix.")] +[WorkflowElementCategory(ElementCategory.Source)] +[TypeConverter(typeof(TensorOperatorConverter))] +public class MultivariateNormal : IScalarTypeProvider +{ + /// + /// Mean vector of the distribution. Can be a 1D vector or higher-rank tensor for batched distributions. + /// + [XmlIgnore] + [TypeConverter(typeof(TensorConverter))] + [Description("Mean vector of the distribution. Can be a 1D vector or higher-rank tensor for batched distributions.")] + public Tensor Mean { get; set; } = null; + + /// + /// The values of the means in XML string format. + /// + [Browsable(false)] + [XmlElement(nameof(Mean))] + [EditorBrowsable(EditorBrowsableState.Never)] + public string MeanXml + { + get => TensorConverter.ConvertToString(Mean, Type); + set => Mean = TensorConverter.ConvertFromString(value, Type); + } + + /// + /// Covariance matrix of the distribution. Must be positive-definite and square with dimension matching . + /// + [XmlIgnore] + [TypeConverter(typeof(TensorConverter))] + [Description("Covariance matrix of the distribution. Must be positive-definite and square with dimension matching Mean.")] + public Tensor Covariance { get; set; } = null; + + /// + /// The values of the covariance matrix in XML string format. + /// + [Browsable(false)] + [XmlElement(nameof(Covariance))] + [EditorBrowsable(EditorBrowsableState.Never)] + public string CovarianceXml + { + get => TensorConverter.ConvertToString(Covariance, Type); + set => Covariance = TensorConverter.ConvertFromString(value, Type); + } + + /// + /// Gets or sets the data type of the tensor elements. + /// + [Description("The data type of the tensor elements.")] + [TypeConverter(typeof(ScalarTypeConverter))] + public ScalarType Type { get; set; } = ScalarType.Float32; + + /// + /// Creates a distribution using the configured parameters. + /// + /// An observable that emits the constructed Multivariate Normal distribution. + public IObservable Process() + { + return Observable.Return(distributions.MultivariateNormal(Mean, Covariance)); + } + + /// + /// Creates a distribution for each incoming RNG . + /// + /// Observable sequence of random generators to use. + /// An observable sequence of Multivariate Normal distributions. + public IObservable Process(IObservable source) + { + return source.Select(generator => distributions.MultivariateNormal(Mean, Covariance, generator: generator)); + } + + /// + /// For each element of the source stream, emits a distribution constructed from the configured parameters. + /// + /// The type of the triggering source sequence. + /// Trigger sequence; each element causes a new distribution to be emitted. + /// An observable sequence of Multivariate Normal distributions. + public IObservable Process(IObservable source) + { + return source.Select(_ => distributions.MultivariateNormal(Mean, Covariance)); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Distributions/Poisson.cs b/src/Bonsai.ML.Torch/Distributions/Poisson.cs new file mode 100644 index 00000000..467e4195 --- /dev/null +++ b/src/Bonsai.ML.Torch/Distributions/Poisson.cs @@ -0,0 +1,74 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using System.Xml.Serialization; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.Distributions; + +/// +/// Represents an operator that creates a poisson probability distribution parameterized by rate (expected number of events). +/// +[Combinator] +[Description("Creates a poisson distribution with the specified rate parameter.")] +[WorkflowElementCategory(ElementCategory.Source)] +[TypeConverter(typeof(TensorOperatorConverter))] +public class Poisson : IScalarTypeProvider +{ + /// + /// Rate parameter (> 0), representing the expected number of events. Can be a scalar or tensor; the shape determines the batch/event shape. + /// + [XmlIgnore] + [TypeConverter(typeof(TensorConverter))] + [Description("Rate parameter (> 0), expected number of events. Can be a scalar or tensor; shape sets the batch/event shape of the distribution.")] + public Tensor Rate { get; set; } = null; + + /// + /// The values of the rates in XML string format. + /// + [Browsable(false)] + [XmlElement(nameof(Rate))] + [EditorBrowsable(EditorBrowsableState.Never)] + public string RateXml + { + get => TensorConverter.ConvertToString(Rate, Type); + set => Rate = TensorConverter.ConvertFromString(value, Type); + } + + /// + /// Gets or sets the data type of the tensor elements. + /// + [Description("The data type of the tensor elements.")] + [TypeConverter(typeof(ScalarTypeConverter))] + public ScalarType Type { get; set; } = ScalarType.Float32; + + /// + /// Creates a distribution using the configured parameters. + /// + /// An observable that emits the constructed Poisson distribution. + public IObservable Process() + { + return Observable.Return(distributions.Poisson(Rate)); + } + + /// + /// Creates a distribution for each incoming RNG . + /// + /// Observable sequence of random generators to use. + /// An observable sequence of Poisson distributions. + public IObservable Process(IObservable source) + { + return source.Select(generator => distributions.Poisson(Rate, generator: generator)); + } + + /// + /// For each element of the source stream, emits a distribution constructed from the configured parameters. + /// + /// The type of the triggering source sequence. + /// Trigger sequence; each element causes a new distribution to be emitted. + /// An observable sequence of Poisson distributions. + public IObservable Process(IObservable source) + { + return source.Select(_ => distributions.Poisson(Rate)); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Distributions/RSample.cs b/src/Bonsai.ML.Torch/Distributions/RSample.cs new file mode 100644 index 00000000..1e067164 --- /dev/null +++ b/src/Bonsai.ML.Torch/Distributions/RSample.cs @@ -0,0 +1,32 @@ +using System; +using System.Reactive.Linq; +using System.ComponentModel; +using static TorchSharp.torch; +using static TorchSharp.torch.distributions; + +namespace Bonsai.ML.Torch.Distributions; + +/// +/// Represents an operator that generates reparameterized samples from the input distribution. Reparameterized samples allow gradients to flow through the sampling process. +/// +[Combinator] +[Description("Generates reparameterized samples from the input distribution.")] +[WorkflowElementCategory(ElementCategory.Transform)] +public class RSample +{ + /// + /// The shape of the samples to generate. + /// + [TypeConverter(typeof(UnidimensionalArrayConverter))] + public long[] SampleShape { get; set; } + + /// + /// Processes the input distribution to generate reparameterized samples. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(distribution => distribution.rsample(SampleShape)); + } +} diff --git a/src/Bonsai.ML.Torch/Distributions/Sample.cs b/src/Bonsai.ML.Torch/Distributions/Sample.cs new file mode 100644 index 00000000..32ed0f8f --- /dev/null +++ b/src/Bonsai.ML.Torch/Distributions/Sample.cs @@ -0,0 +1,32 @@ +using System; +using System.Reactive.Linq; +using System.ComponentModel; +using static TorchSharp.torch; +using static TorchSharp.torch.distributions; + +namespace Bonsai.ML.Torch.Distributions; + +/// +/// Represents an operator that generates samples from the input distribution. Gradients do not flow through the sampling process. +/// +[Combinator] +[Description("Generates samples from the input distribution.")] +[WorkflowElementCategory(ElementCategory.Transform)] +public class Sample +{ + /// + /// The shape of the samples to generate. + /// + [TypeConverter(typeof(UnidimensionalArrayConverter))] + public long[] SampleShape { get; set; } + + /// + /// Processes the input distribution to generate samples. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(distribution => distribution.sample(SampleShape)); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch/Random/CreateGenerator.cs b/src/Bonsai.ML.Torch/Random/CreateGenerator.cs new file mode 100644 index 00000000..a6aace8d --- /dev/null +++ b/src/Bonsai.ML.Torch/Random/CreateGenerator.cs @@ -0,0 +1,51 @@ +using Bonsai; +using static TorchSharp.torch; +using TorchSharp; +using System; +using System.Reactive.Linq; +using System.Xml.Serialization; +using System.ComponentModel; + +namespace Bonsai.ML.Torch.Random; + +/// +/// Represents an operator that creates a specific instance of a random number generator (RNG) with the specified seed and device. +/// +[Combinator] +[Description("Creates a specific instance of a random number generator (RNG) with the specified seed and device.")] +[WorkflowElementCategory(ElementCategory.Source)] +public class CreateGenerator +{ + /// + /// The device on which to create the generator. + /// + [XmlIgnore] + [Description("The device on which to create the generator.")] + public Device Device { get; set; } + + /// + /// The seed for the random number generator. + /// + [Description("The seed for the random number generator.")] + public ulong Seed { get; set; } = 0; + + /// + /// Creates a random number generator with the specified seed and device. + /// + /// + public IObservable Process() + { + return Observable.Return(new Generator(Seed, Device)); + } + + /// + /// Generates an observable sequence of random number generators for each element of the input sequence. + /// + /// The type of the elements in the input sequence. + /// The input observable sequence. + /// An observable sequence of random number generators. + public IObservable Process(IObservable source) + { + return source.Select(_ => new Generator(Seed, Device)); + } +} diff --git a/src/Bonsai.ML.Torch/Random/ManualSeed.cs b/src/Bonsai.ML.Torch/Random/ManualSeed.cs new file mode 100644 index 00000000..24eaa930 --- /dev/null +++ b/src/Bonsai.ML.Torch/Random/ManualSeed.cs @@ -0,0 +1,41 @@ +using System; +using System.Reactive.Linq; +using System.ComponentModel; +using TorchSharp; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.Random; + +/// +/// Represents an operator that sets the global random seed for TorchSharp and creates a random number generator (RNG) with the specified seed. +/// +[Combinator] +[Description("Sets the global random seed for TorchSharp and creates a random number generator (RNG) with the specified seed.")] +[WorkflowElementCategory(ElementCategory.Source)] +public class ManualSeed +{ + /// + /// The seed for the random number generator. + /// + public long Seed { get; set; } = 0; + + /// + /// Sets the global random seed and creates a random number generator (RNG). + /// + /// + public IObservable Process() + { + return Observable.Return(manual_seed(Seed)); + } + + /// + /// Generates an observable sequence where each element sets the global random seed and creates a random number generator (RNG). + /// + /// The type of the elements in the input sequence. + /// The input observable sequence. + /// An observable sequence of random number generators. + public IObservable Process(IObservable source) + { + return source.Select(_ => manual_seed(Seed)); + } +} diff --git a/src/Bonsai.ML.Torch/Random/Normal.cs b/src/Bonsai.ML.Torch/Random/Normal.cs new file mode 100644 index 00000000..389eebee --- /dev/null +++ b/src/Bonsai.ML.Torch/Random/Normal.cs @@ -0,0 +1,98 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using System.Xml.Serialization; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.Random; + +/// +/// Represents an operator that creates a tensor filled with random floats sampled from a normal distribution with the specified mean and variance. +/// +[Combinator] +[ResetCombinator] +[Description("Creates a tensor filled with random floats sampled from a normal distribution.")] +[WorkflowElementCategory(ElementCategory.Source)] +public class Normal +{ + /// + /// The size of the tensor. + /// + [Description("The size of the tensor.")] + [TypeConverter(typeof(UnidimensionalArrayConverter))] + public long[] Size { get; set; } + + /// + /// The data type of the tensor elements. + /// + [Description("The data type of the tensor elements.")] + public ScalarType? Type { get; set; } = null; + + /// + /// The device on which to create the tensor. + /// + [Description("The device on which to create the tensor.")] + [XmlIgnore] + public Device Device { get; set; } = null; + + /// + /// The random number generator to use. + /// + [XmlIgnore] + [Description("The random number generator to use.")] + public Generator Generator { get; set; } = null; + + /// + /// The mean of the normal distribution. + /// + [Description("The mean of the normal distribution.")] + public double Mean { get; set; } = 0; + + /// + /// The variance of the normal distribution. + /// + [Description("The variance of the normal distribution.")] + public double Variance { get; set; } = 1; + + /// + /// Creates a tensor filled with random values sampled from a normal distribution with mean 0 and variance 1. + /// + public IObservable Process() + { + return Observable.Return(randn(Size, dtype: Type, device: Device, generator: Generator) * Variance + Mean); + } + + /// + /// Generates an observable sequence of tensors filled with random values and uses the input generator. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => + { + Generator = value; + return randn(Size, dtype: Type, device: Device, generator: Generator) * Variance + Mean; + }); + } + + /// + /// Generates an observable sequence of tensors filled with random values for each element of the input sequence. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => randn_like(value, dtype: Type, device: Device) * Variance + Mean); + } + + /// + /// Generates an observable sequence of tensors filled with random values for each element of the input sequence. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(_ => randn(Size, dtype: Type, device: Device, generator: Generator) * Variance + Mean); + } +} diff --git a/src/Bonsai.ML.Torch/Random/Permutation.cs b/src/Bonsai.ML.Torch/Random/Permutation.cs new file mode 100644 index 00000000..afc4c6e1 --- /dev/null +++ b/src/Bonsai.ML.Torch/Random/Permutation.cs @@ -0,0 +1,92 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using System.Xml.Serialization; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.Random; + +/// +/// Represents an operator that creates a 1D tensor of a given size with a random permutation of integers in [0, size). +/// +[Combinator] +[ResetCombinator] +[Description("Creates a 1D tensor of a given size with a random permutation of integers in [0, size).")] +[WorkflowElementCategory(ElementCategory.Source)] +public class Permutation +{ + /// + /// The size of the tensor. + /// + [Description("The size of the tensor.")] + public long Size { get; set; } = 0; + + /// + /// The data type of the tensor elements. + /// + [Description("The data type of the tensor elements.")] + public ScalarType? Type { get; set; } = null; + + /// + /// The device on which to create the tensor. + /// + [Description("The device on which to create the tensor.")] + [XmlIgnore] + public Device Device { get; set; } = null; + + /// + /// The random number generator to use. + /// + [XmlIgnore] + [Description("The random number generator to use.")] + public Generator Generator { get; set; } = null; + + /// + /// Creates a tensor of a given size with a random permutation of integers from [0, size). + /// + public IObservable Process() + { + return Observable.Return(randperm(Size, dtype: Type, device: Device, generator: Generator)); + } + + /// + /// Creates a tensor with a random permutation of integers in [0, size) and uses the input generator. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => + { + Generator = value; + return randperm(Size, dtype: Type, device: Device, generator: Generator); + }); + } + + /// + /// Randomly permutates tensors from the input sequence. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => + { + var size = value.numel(); + var shape = value.shape; + var idxs = randperm(size, dtype: Type, device: Device, generator: Generator); + return value.flatten().index_select(0, idxs).reshape(shape); + }); + } + + + /// + /// Generates random permutations for each element of the input sequence. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(_ => randperm(Size, dtype: Type, device: Device, generator: Generator)); + } +} diff --git a/src/Bonsai.ML.Torch/Random/Uniform.cs b/src/Bonsai.ML.Torch/Random/Uniform.cs new file mode 100644 index 00000000..e837bd7d --- /dev/null +++ b/src/Bonsai.ML.Torch/Random/Uniform.cs @@ -0,0 +1,101 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using System.Xml.Serialization; +using TorchSharp; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.Random; + +/// +/// Represents an operator that creates a tensor filled with random numbers sampled from a uniform distribution over the interval [MinSize, MaxSize). +/// +[Combinator] +[ResetCombinator] +[Description("Creates a tensor filled with random numbers sampled from a uniform distribution over the interval [MinValue, MaxValue).")] +[WorkflowElementCategory(ElementCategory.Source)] +public class Uniform +{ + /// + /// The size of the tensor. + /// + [Description("The size of the tensor.")] + [TypeConverter(typeof(UnidimensionalArrayConverter))] + public long[] Size { get; set; } = []; + + /// + /// The data type of the tensor elements. + /// + [Description("The data type of the tensor elements.")] + public ScalarType? Type { get; set; } = null; + + /// + /// The device on which to create the tensor. + /// + [Description("The device on which to create the tensor.")] + [XmlIgnore] + public Device Device { get; set; } = null; + + /// + /// The random number generator to use. + /// + [XmlIgnore] + [Description("The random number generator to use.")] + public Generator Generator { get; set; } = null; + + /// + /// The minimum value of the random numbers inclusive. + /// + [Description("The minimum value of the random numbers.")] + public double MinValue { get; set; } = 0; + + /// + /// The maximum value of the random numbers exclusive. + /// + [Description("The maximum value of the random numbers.")] + public double MaxValue { get; set; } = 1; + + + /// + /// Creates a tensor filled with random values sampled from a uniform distribution over the interval [MinValue, MaxValue). + /// + public IObservable Process() + { + return Observable.Return(rand(Size, dtype: Type, device: Device, generator: Generator) * (MaxValue - MinValue) + MinValue); + } + + /// + /// Creates tensors filled with random values and uses the input generator. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => + { + Generator = value; + return rand(Size, dtype: Type, device: Device, generator: Generator) * (MaxValue - MinValue) + MinValue; + }); + } + + /// + /// Generates an observable sequence of tensors filled with random values for each element of the input sequence. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => rand_like(value, dtype: Type, device: Device) * (MaxValue - MinValue) + MinValue); + } + + /// + /// Creates tensors filled with random values for each element of the input sequence. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(_ => rand(Size, dtype: Type, device: Device, generator: Generator) * (MaxValue - MinValue) + MinValue); + } + +} diff --git a/src/Bonsai.ML.Torch/Random/UniformIntegers.cs b/src/Bonsai.ML.Torch/Random/UniformIntegers.cs new file mode 100644 index 00000000..450ce499 --- /dev/null +++ b/src/Bonsai.ML.Torch/Random/UniformIntegers.cs @@ -0,0 +1,98 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using System.Xml.Serialization; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.Random; + +/// +/// Creates a tensor filled with random integers sampled from a uniform distribution over the interval [MinValue, MaxValue). +/// +[Combinator] +[ResetCombinator] +[Description("Creates a tensor filled with random integers sampled from a uniform distribution over the interval [MinValue, MaxValue).")] +[WorkflowElementCategory(ElementCategory.Source)] +public class UniformIntegers +{ + /// + /// The size of the tensor. + /// + [Description("The size of the tensor.")] + [TypeConverter(typeof(UnidimensionalArrayConverter))] + public long[] Size { get; set; } = []; + + /// + /// The minimum value of the random integers inclusive. + /// + [Description("The minimum value of the random integers.")] + public int MinValue { get; set; } = 0; + + /// + /// The maximum value of the random integers exclusive. + /// + [Description("The maximum value of the random integers.")] + public int MaxValue { get; set; } = 100; + + /// + /// The data type of the tensor elements. + /// + [Description("The data type of the tensor elements.")] + public ScalarType? Type { get; set; } = null; + + /// + /// The device on which to create the tensor. + /// + [Description("The device on which to create the tensor.")] + [XmlIgnore] + public Device Device { get; set; } = null; + + /// + /// The random number generator to use. + /// + [XmlIgnore] + [Description("The random number generator to use.")] + public Generator Generator { get; set; } = null; + + /// + /// Creates a tensor filled with random integers sampled from a uniform distribution over the interval [MinValue, MaxValue). + /// + public IObservable Process() + { + return Observable.Return(randint(MinValue, MaxValue, Size, dtype: Type, device: Device, generator: Generator)); + } + + /// + /// Creates tensors filled with random integers and uses the input generator. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => + { + Generator = value; + return randint(MinValue, MaxValue, Size, dtype: Type, device: Device, generator: Generator); + }); + } + + /// + /// Creates tensors filled with random integers for each element of the input sequence. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(value => randint_like(value, MinValue, MaxValue, dtype: Type, device: Device)); + } + + /// + /// Creates tensors filled with random integers for each element of the input sequence. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(_ => randint(MinValue, MaxValue, Size, dtype: Type, device: Device, generator: Generator)); + } +}