From a26b727fddf20d4c95898351884adcece8e89a68 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 6 Aug 2025 17:21:44 +0100 Subject: [PATCH 01/20] Added new PCA project --- Bonsai.ML.sln | 11 +++++++++++ src/Bonsai.ML.PCA/Bonsai.ML.PCA.csproj | 15 +++++++++++++++ src/Bonsai.ML.PCA/Properties/launchSettings.json | 10 ++++++++++ 3 files changed, 36 insertions(+) create mode 100644 src/Bonsai.ML.PCA/Bonsai.ML.PCA.csproj create mode 100644 src/Bonsai.ML.PCA/Properties/launchSettings.json diff --git a/Bonsai.ML.sln b/Bonsai.ML.sln index fc4e963f..4640a1a3 100644 --- a/Bonsai.ML.sln +++ b/Bonsai.ML.sln @@ -45,6 +45,9 @@ EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Bonsai.ML.Lds.Torch.Tests", "tests\Bonsai.ML.Lds.Torch.Tests\Bonsai.ML.Lds.Torch.Tests.csproj", "{0B258929-0B07-4CE7-BE8D-A86BBC46AAD4}" EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Bonsai.ML.Lds.Torch.Design", "src\Bonsai.ML.Lds.Torch.Design\Bonsai.ML.Lds.Torch.Design.csproj", "{1F52DECD-1B2C-4F6C-996C-14C715283B80}" +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Bonsai.ML.PCA", "src\Bonsai.ML.PCA\Bonsai.ML.PCA.csproj", "{1A2DEED3-795E-4C28-9C5E-BA3D76B2A485}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Bonsai.ML.PCA.Tests", "tests\Bonsai.ML.PCA.Tests\Bonsai.ML.PCA.Tests.csproj", "{4ABCC6B2-024A-450F-85CB-2A9B2D2D2A10}" EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution @@ -120,6 +123,14 @@ Global {1F52DECD-1B2C-4F6C-996C-14C715283B80}.Debug|Any CPU.Build.0 = Debug|Any CPU {1F52DECD-1B2C-4F6C-996C-14C715283B80}.Release|Any CPU.ActiveCfg = Release|Any CPU {1F52DECD-1B2C-4F6C-996C-14C715283B80}.Release|Any CPU.Build.0 = Release|Any CPU + {1A2DEED3-795E-4C28-9C5E-BA3D76B2A485}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {1A2DEED3-795E-4C28-9C5E-BA3D76B2A485}.Debug|Any CPU.Build.0 = Debug|Any CPU + {1A2DEED3-795E-4C28-9C5E-BA3D76B2A485}.Release|Any CPU.ActiveCfg = Release|Any CPU + 
{1A2DEED3-795E-4C28-9C5E-BA3D76B2A485}.Release|Any CPU.Build.0 = Release|Any CPU + {4ABCC6B2-024A-450F-85CB-2A9B2D2D2A10}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {4ABCC6B2-024A-450F-85CB-2A9B2D2D2A10}.Debug|Any CPU.Build.0 = Debug|Any CPU + {4ABCC6B2-024A-450F-85CB-2A9B2D2D2A10}.Release|Any CPU.ActiveCfg = Release|Any CPU + {4ABCC6B2-024A-450F-85CB-2A9B2D2D2A10}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE diff --git a/src/Bonsai.ML.PCA/Bonsai.ML.PCA.csproj b/src/Bonsai.ML.PCA/Bonsai.ML.PCA.csproj new file mode 100644 index 00000000..f1e876d2 --- /dev/null +++ b/src/Bonsai.ML.PCA/Bonsai.ML.PCA.csproj @@ -0,0 +1,15 @@ + + + + Bonsai.ML.PCA Bonsai library. + $(PackageTags) Point Process Neural Decoder + net472;netstandard2.0 + enable + + + + + + + + diff --git a/src/Bonsai.ML.PCA/Properties/launchSettings.json b/src/Bonsai.ML.PCA/Properties/launchSettings.json new file mode 100644 index 00000000..4af4f468 --- /dev/null +++ b/src/Bonsai.ML.PCA/Properties/launchSettings.json @@ -0,0 +1,10 @@ +{ + "profiles": { + "Bonsai": { + "commandName": "Executable", + "executablePath": "$(BonsaiExecutablePath)", + "commandLineArgs": "--lib:\"$(TargetDir).\"", + "nativeDebugging": true + } + } +} \ No newline at end of file From 4bfcb4ad990df1f177fc28ac2456b1e972f6aa29 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Thu, 7 Aug 2025 14:41:31 +0100 Subject: [PATCH 02/20] Added main components of PCA package --- src/Bonsai.ML.PCA/CreatePCA.cs | 66 ++++++++++++++++++++++ src/Bonsai.ML.PCA/Fit.cs | 34 +++++++++++ src/Bonsai.ML.PCA/IPCABaseModel.cs | 19 +++++++ src/Bonsai.ML.PCA/PCA.cs | 52 +++++++++++++++++ src/Bonsai.ML.PCA/PCABaseModel.cs | 35 ++++++++++++ src/Bonsai.ML.PCA/PCADesciptionProvider.cs | 26 +++++++++ src/Bonsai.ML.PCA/PCADescriptor.cs | 39 +++++++++++++ src/Bonsai.ML.PCA/PCAModelType.cs | 18 ++++++ src/Bonsai.ML.PCA/Transform.cs | 35 ++++++++++++ 9 files changed, 324 insertions(+) 
create mode 100644 src/Bonsai.ML.PCA/CreatePCA.cs create mode 100644 src/Bonsai.ML.PCA/Fit.cs create mode 100644 src/Bonsai.ML.PCA/IPCABaseModel.cs create mode 100644 src/Bonsai.ML.PCA/PCA.cs create mode 100644 src/Bonsai.ML.PCA/PCABaseModel.cs create mode 100644 src/Bonsai.ML.PCA/PCADesciptionProvider.cs create mode 100644 src/Bonsai.ML.PCA/PCADescriptor.cs create mode 100644 src/Bonsai.ML.PCA/PCAModelType.cs create mode 100644 src/Bonsai.ML.PCA/Transform.cs diff --git a/src/Bonsai.ML.PCA/CreatePCA.cs b/src/Bonsai.ML.PCA/CreatePCA.cs new file mode 100644 index 00000000..34a0232c --- /dev/null +++ b/src/Bonsai.ML.PCA/CreatePCA.cs @@ -0,0 +1,66 @@ +using System; +using System.ComponentModel; +using System.Collections.Generic; +using System.Reactive.Linq; +using System.Linq.Expressions; +using Bonsai.Expressions; +using System.Linq; +using System.Reflection; + +namespace Bonsai.ML.PCA +{ + [Combinator] + [WorkflowElementCategory(ElementCategory.Source)] + [TypeDescriptionProvider(typeof(PCADescriptionProvider))] + public class CreatePCA : ZeroArgumentExpressionBuilder + { + public int NumComponents { get; set; } = 2; + + [RefreshProperties(RefreshProperties.All)] + public PCAModelType ModelType { get; set; } = PCAModelType.PCA; + public double Variance { get; set; } = 1.0; + + internal IEnumerable GetModelProperties() + { + yield return nameof(NumComponents); + yield return nameof(ModelType); + + if (ModelType == PCAModelType.ProbabilisticPCA) + { + yield return "Variance"; + } + } + + private static PCABaseModel CreateModel(CreatePCA instance) + { + return instance.ModelType switch + { + PCAModelType.PCA => new PCA(instance.NumComponents), + _ => throw new NotSupportedException($"Model type {instance.ModelType} is not supported."), + }; + } + + private static Type GetModelType(PCAModelType modelType) + { + return modelType switch + { + PCAModelType.PCA => typeof(PCA), + _ => throw new NotSupportedException($"Model type {modelType} is not supported."), + }; + } + + 
public override Expression Build(IEnumerable arguments) + { + var processMethod = GetType().GetMethod( + nameof(Process), + BindingFlags.NonPublic | BindingFlags.Static); + processMethod = processMethod.MakeGenericMethod(GetModelType(ModelType)); + return Expression.Call(processMethod, Expression.Constant(this)); + } + + private static IObservable Process(CreatePCA instance) where T : PCABaseModel + { + return Observable.Return((T)CreateModel(instance)); + } + } +} diff --git a/src/Bonsai.ML.PCA/Fit.cs b/src/Bonsai.ML.PCA/Fit.cs new file mode 100644 index 00000000..0e7ca80e --- /dev/null +++ b/src/Bonsai.ML.PCA/Fit.cs @@ -0,0 +1,34 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using Bonsai; +using static TorchSharp.torch; + +namespace Bonsai.ML.PCA +{ + [Combinator] + [Description] + [WorkflowElementCategory(ElementCategory.Sink)] + public class Fit + { + private void FitModel(IPCABaseModel model, Tensor data) + { + model.Fit(data); + } + + public IObservable> Process(IObservable> source) + { + return source.Do((value) => + { + FitModel(value.Item1, value.Item2); + }); + } + public IObservable> Process(IObservable> source) + { + return source.Do((value) => + { + FitModel(value.Item2, value.Item1); + }); + } + } +} diff --git a/src/Bonsai.ML.PCA/IPCABaseModel.cs b/src/Bonsai.ML.PCA/IPCABaseModel.cs new file mode 100644 index 00000000..17d5d603 --- /dev/null +++ b/src/Bonsai.ML.PCA/IPCABaseModel.cs @@ -0,0 +1,19 @@ +using Bonsai; +using System; +using System.ComponentModel; +using System.Collections.Generic; +using System.Linq; +using System.Reactive.Linq; +using TorchSharp; +using static TorchSharp.torch; +using static TorchSharp.torch.linalg; + +namespace Bonsai.ML.PCA +{ + public interface IPCABaseModel + { + public abstract void Fit(Tensor data); + public abstract Tensor Transform(Tensor data); + public abstract Tensor FitAndTransform(Tensor data); + } +} diff --git a/src/Bonsai.ML.PCA/PCA.cs b/src/Bonsai.ML.PCA/PCA.cs new file 
mode 100644 index 00000000..7c31d43e --- /dev/null +++ b/src/Bonsai.ML.PCA/PCA.cs @@ -0,0 +1,52 @@ +using Bonsai; +using System; +using System.ComponentModel; +using System.Collections.Generic; +using System.Linq; +using System.Reactive.Linq; +using TorchSharp; +using static TorchSharp.torch; +using static TorchSharp.torch.linalg; + +namespace Bonsai.ML.PCA +{ + public class PCA : PCABaseModel + { + public Tensor Covariance { get; private set; } = empty(0); + public Tensor EigenValues { get; private set; } = empty(0); + public Tensor EigenVectors { get; private set; } = empty(0); + public Tensor Components { get; private set; } = empty(0); + + public PCA(int numComponents) : base(numComponents) { } + + public override void Fit(Tensor data) + { + if (data.NumberOfElements == 0 || data.dim() < 2) + { + throw new ArgumentException("Data must be a non-empty 2D tensor.", nameof(data)); + } + + Covariance = cov(data); + var eigen = eigh(Covariance); + var sortedIndices = argsort(eigen.Item1, dim: -1, descending: true); + EigenValues = eigen.Item1[sortedIndices]; + EigenVectors = eigen.Item2.index_select(1, sortedIndices); + Components = EigenVectors.slice(1, 0, NumComponents, 1); + } + + public override Tensor Transform(Tensor data) + { + if (data.NumberOfElements == 0 || data.dim() < 2) + { + throw new ArgumentException("Data must be a non-empty 2D tensor.", nameof(data)); + } + + if (Components.NumberOfElements == 0) + { + throw new InvalidOperationException("Model has not been fit to data. 
Call the Fit() method first."); + } + + return data.T.matmul(Components); + } + } +} diff --git a/src/Bonsai.ML.PCA/PCABaseModel.cs b/src/Bonsai.ML.PCA/PCABaseModel.cs new file mode 100644 index 00000000..044a8eb6 --- /dev/null +++ b/src/Bonsai.ML.PCA/PCABaseModel.cs @@ -0,0 +1,35 @@ +using Bonsai; +using System; +using System.ComponentModel; +using System.Collections.Generic; +using System.Linq; +using System.Reactive.Linq; +using TorchSharp; +using static TorchSharp.torch; +using static TorchSharp.torch.linalg; + +namespace Bonsai.ML.PCA +{ + public abstract class PCABaseModel : IPCABaseModel + { + public int NumComponents { get; private set; } + + public PCABaseModel(int numComponents) + { + if (numComponents <= 0) + { + throw new ArgumentException("Number of components must be greater than zero.", nameof(numComponents)); + } + + NumComponents = numComponents; + } + + public abstract void Fit(Tensor data); + public abstract Tensor Transform(Tensor data); + public virtual Tensor FitAndTransform(Tensor data) + { + Fit(data); + return Transform(data); + } + } +} diff --git a/src/Bonsai.ML.PCA/PCADesciptionProvider.cs b/src/Bonsai.ML.PCA/PCADesciptionProvider.cs new file mode 100644 index 00000000..1f8886af --- /dev/null +++ b/src/Bonsai.ML.PCA/PCADesciptionProvider.cs @@ -0,0 +1,26 @@ +using System.ComponentModel; +using System; +using Bonsai; +using Bonsai.Expressions; + +namespace Bonsai.ML.PCA +{ + class PCADescriptionProvider : TypeDescriptionProvider + { + private readonly TypeDescriptionProvider _baseProvider; + + public PCADescriptionProvider() : this(TypeDescriptor.GetProvider(typeof(object))) { } + + public PCADescriptionProvider(TypeDescriptionProvider baseProvider) + : base(baseProvider) + { + _baseProvider = baseProvider; + } + + public override ICustomTypeDescriptor GetTypeDescriptor(Type objectType, object instance) + { + var defaultDescriptor = base.GetTypeDescriptor(objectType, instance); + return new PCADescriptor(defaultDescriptor, instance); + } 
+ } +} \ No newline at end of file diff --git a/src/Bonsai.ML.PCA/PCADescriptor.cs b/src/Bonsai.ML.PCA/PCADescriptor.cs new file mode 100644 index 00000000..1c19d254 --- /dev/null +++ b/src/Bonsai.ML.PCA/PCADescriptor.cs @@ -0,0 +1,39 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using System.Collections.Generic; +using System.Linq; +using Bonsai; +using static TorchSharp.torch; + +namespace Bonsai.ML.PCA +{ + public class PCADescriptor : CustomTypeDescriptor + { + private readonly object _instance; + + public PCADescriptor(ICustomTypeDescriptor parent, object instance) : base(parent) + { + _instance = instance; + } + + public override PropertyDescriptorCollection GetProperties(Attribute[] attributes) + { + var allProperties = base.GetProperties(attributes); + + if (_instance is CreatePCA createPCA) + { + var modelProperties = new HashSet(createPCA.GetModelProperties()); + var filtered = allProperties.Cast() + .Where(p => modelProperties.Contains(p.Name)) + .ToArray(); + return new PropertyDescriptorCollection(filtered); + } + + return allProperties; + } + + public override PropertyDescriptorCollection GetProperties() + => GetProperties(null); + } +} diff --git a/src/Bonsai.ML.PCA/PCAModelType.cs b/src/Bonsai.ML.PCA/PCAModelType.cs new file mode 100644 index 00000000..726b9081 --- /dev/null +++ b/src/Bonsai.ML.PCA/PCAModelType.cs @@ -0,0 +1,18 @@ +using Bonsai; +using System; +using System.ComponentModel; +using System.Collections.Generic; +using System.Linq; +using System.Reactive.Linq; +using TorchSharp; +using static TorchSharp.torch; +using static TorchSharp.torch.linalg; + +namespace Bonsai.ML.PCA +{ + public enum PCAModelType + { + PCA, + ProbabilisticPCA + } +} diff --git a/src/Bonsai.ML.PCA/Transform.cs b/src/Bonsai.ML.PCA/Transform.cs new file mode 100644 index 00000000..3016f462 --- /dev/null +++ b/src/Bonsai.ML.PCA/Transform.cs @@ -0,0 +1,35 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; 
+using Bonsai; +using static TorchSharp.torch; + +namespace Bonsai.ML.PCA +{ + [Combinator] + [Description] + [WorkflowElementCategory(ElementCategory.Transform)] + public class Transform + { + private Tensor TransformData(IPCABaseModel model, Tensor data) + { + return model.Transform(data); + } + + public IObservable Process(IObservable> source) + { + return source.Select(value => + { + return TransformData(value.Item1, value.Item2); + }); + } + + public IObservable Process(IObservable> source) + { + return source.Select(value => + { + return TransformData(value.Item2, value.Item1); + }); + } + } +} From 2a941e8c3fc10fcafdb5c2111dc3837d7c8ce1c4 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Mon, 11 Aug 2025 10:28:03 +0100 Subject: [PATCH 03/20] Added probabilistic PCA method to package --- src/Bonsai.ML.PCA/CreatePCA.cs | 21 +++- src/Bonsai.ML.PCA/Fit.cs | 17 ++++ src/Bonsai.ML.PCA/PCA.cs | 18 +++- src/Bonsai.ML.PCA/PPCA.cs | 173 +++++++++++++++++++++++++++++++++ src/Bonsai.ML.PCA/Transform.cs | 16 +++ 5 files changed, 239 insertions(+), 6 deletions(-) create mode 100644 src/Bonsai.ML.PCA/PPCA.cs diff --git a/src/Bonsai.ML.PCA/CreatePCA.cs b/src/Bonsai.ML.PCA/CreatePCA.cs index 34a0232c..3524056a 100644 --- a/src/Bonsai.ML.PCA/CreatePCA.cs +++ b/src/Bonsai.ML.PCA/CreatePCA.cs @@ -18,7 +18,13 @@ public class CreatePCA : ZeroArgumentExpressionBuilder [RefreshProperties(RefreshProperties.All)] public PCAModelType ModelType { get; set; } = PCAModelType.PCA; - public double Variance { get; set; } = 1.0; + + public double InitialVariance { get; set; } = 1.0; + public int Iterations { get; set; } = 100; + public double Tolerance { get; set; } = 1e-5; + + [XmlIgnore] + public Generator? 
Generator { get; set; } = null; internal IEnumerable GetModelProperties() { @@ -27,7 +33,11 @@ internal IEnumerable GetModelProperties() if (ModelType == PCAModelType.ProbabilisticPCA) { - yield return "Variance"; + yield return nameof(InitialVariance); + yield return nameof(Iterations); + yield return nameof(Tolerance); + yield return nameof(Generator); + } } } @@ -36,6 +46,12 @@ private static PCABaseModel CreateModel(CreatePCA instance) return instance.ModelType switch { PCAModelType.PCA => new PCA(instance.NumComponents), + PCAModelType.ProbabilisticPCA => new PPCA( + instance.NumComponents, + instance.InitialVariance, + instance.Generator, + instance.Iterations, + instance.Tolerance), _ => throw new NotSupportedException($"Model type {instance.ModelType} is not supported."), }; } @@ -45,6 +61,7 @@ private static Type GetModelType(PCAModelType modelType) return modelType switch { PCAModelType.PCA => typeof(PCA), + PCAModelType.ProbabilisticPCA => typeof(PPCA), _ => throw new NotSupportedException($"Model type {modelType} is not supported."), }; } diff --git a/src/Bonsai.ML.PCA/Fit.cs b/src/Bonsai.ML.PCA/Fit.cs index 0e7ca80e..4cab50af 100644 --- a/src/Bonsai.ML.PCA/Fit.cs +++ b/src/Bonsai.ML.PCA/Fit.cs @@ -23,6 +23,7 @@ public IObservable> Process(IObservable> s FitModel(value.Item1, value.Item2); }); } + public IObservable> Process(IObservable> source) { return source.Do((value) => @@ -30,5 +31,21 @@ public IObservable> Process(IObservable> s FitModel(value.Item2, value.Item1); }); } + + public IObservable> Process(IObservable> source) + { + return source.Do((value) => + { + FitModel(value.Item1, value.Item2); + }); + } + + public IObservable> Process(IObservable> source) + { + return source.Do((value) => + { + FitModel(value.Item2, value.Item1); + }); + } } } diff --git a/src/Bonsai.ML.PCA/PCA.cs b/src/Bonsai.ML.PCA/PCA.cs index 7c31d43e..734306ae 100644 --- a/src/Bonsai.ML.PCA/PCA.cs +++ b/src/Bonsai.ML.PCA/PCA.cs @@ -16,6 +16,7 @@ public class PCA : 
PCABaseModel public Tensor EigenValues { get; private set; } = empty(0); public Tensor EigenVectors { get; private set; } = empty(0); public Tensor Components { get; private set; } = empty(0); + private bool _isFitted = false; public PCA(int numComponents) : base(numComponents) { } @@ -23,7 +24,15 @@ public override void Fit(Tensor data) { if (data.NumberOfElements == 0 || data.dim() < 2) { - throw new ArgumentException("Data must be a non-empty 2D tensor.", nameof(data)); + throw new ArgumentException("Data must be a non-empty 2D tensor with shape (samples x features).", nameof(data)); + } + + var n = data.size(0); + var d = data.size(1); + + if (NumComponents > d) + { + throw new ArgumentException("Number of components cannot be greater than the number of features.", nameof(data)); } Covariance = cov(data); @@ -32,18 +41,19 @@ public override void Fit(Tensor data) EigenValues = eigen.Item1[sortedIndices]; EigenVectors = eigen.Item2.index_select(1, sortedIndices); Components = EigenVectors.slice(1, 0, NumComponents, 1); + _isFitted = true; } public override Tensor Transform(Tensor data) { if (data.NumberOfElements == 0 || data.dim() < 2) { - throw new ArgumentException("Data must be a non-empty 2D tensor.", nameof(data)); + throw new ArgumentException("Data must be a non-empty 2D tensor with shape (samples x features).", nameof(data)); } - if (Components.NumberOfElements == 0) + if (!_isFitted) { - throw new InvalidOperationException("Model has not been fit to data. Call the Fit() method first."); + throw new InvalidOperationException("Model has not yet been fitted. 
You should call the Fit() or the FitAndTransform() methods first."); } return data.T.matmul(Components); diff --git a/src/Bonsai.ML.PCA/PPCA.cs b/src/Bonsai.ML.PCA/PPCA.cs new file mode 100644 index 00000000..1f713b3c --- /dev/null +++ b/src/Bonsai.ML.PCA/PPCA.cs @@ -0,0 +1,173 @@ +using Bonsai; +using System; +using System.ComponentModel; +using System.Collections.Generic; +using System.Linq; +using System.Reactive.Linq; +using TorchSharp; +using static TorchSharp.torch; +using static TorchSharp.torch.linalg; + +namespace Bonsai.ML.PCA +{ + public class PPCA : PCABaseModel + { + public double Variance { get; private set; } + public Tensor LogLikelihood { get; private set; } = empty(0); + public Tensor Components { get; private set; } = empty(0); + public Generator Generator { get; private set; } + private int _iterations; + private double _tolerance; + private bool _isFitted = false; + + public PPCA(int numComponents, + double initialVariance = 1.0, + Generator? generator = null, + int iterations = 100, + double tolerance = 1e-5 + ) : base(numComponents) + { + if (initialVariance < 0) + { + throw new ArgumentException("Starting variance must be greater than or equal to zero.", nameof(initialVariance)); + } + + if (iterations <= 0) + { + throw new ArgumentException("Number of iterations must be greater than zero.", nameof(iterations)); + } + + if (tolerance <= 0) + { + throw new ArgumentException("Tolerance must be greater than zero.", nameof(tolerance)); + } + + Variance = initialVariance; + Generator = generator ?? 
manual_seed(0); + _iterations = iterations; + _tolerance = tolerance; + } + + public override void Fit(Tensor data) + { + if (data.NumberOfElements == 0 || data.dim() != 2) + { + throw new ArgumentException("Data must be a non-empty 2D tensor with shape (samples x features).", nameof(data)); + } + + var Xt = data.T; // n x d + + // Initialize variance + var variance = Variance; + + // Initialize log likelihood + LogLikelihood = ones(_iterations) * double.NegativeInfinity; + + // Initialize dimensions for components + var q = NumComponents; + var n = Xt.size(0); + var d = Xt.size(1); + + if (q > d) + { + throw new ArgumentException("Number of components cannot be greater than the number of features.", nameof(data)); + } + + // Initialize W and I + var W = randn(d, q, generator: Generator); // d x q + var MI = eye(q); // q x q + var CI = eye(d); // d x d + + // Calculate the sample mean + var mean = Xt.mean([0], keepdim: true); // 1 x d + + // Center the data and transpose + var X = Xt - mean; // n x d + + // Calculate the sample covariance + var XTX = X.T.matmul(X); // d x d + var sampleCov = XTX / n; // d x d + + // Calculate term 1 for variance update + var term1 = trace(XTX); + + // Compute log likelihood constant + var logLikelihoodConst = d * log(2 * Math.PI); + + double diffW; + double diffVariance; + + // Repeat until convergence + for (int i = 0; i < _iterations; i++) + { + using (var _ = NewDisposeScope()) + { + // E-step: Compute the posterior distribution of the latent variables + var M = W.T.matmul(W) + MI * variance; // q x q + var MInv = inv(M); // q x q + var mu = MInv.matmul(W.T).matmul(X.T).T; // n x q + var SSum = n * MInv * variance; // q x q + var cov = mu.T.matmul(mu) + SSum; // q x q + + // M-step: Compute new W and new variance + var XMu = X.T.matmul(mu); // d x q + var WNew = XMu.matmul(inv(cov)); // d x q + + var term2 = 2 * XMu.mul(WNew).sum(); + var mumu = mu.T.matmul(mu); + var WNewWNew = WNew.T.matmul(WNew); + var term3 = 
trace(WNewWNew.matmul(mumu + SSum)); + var varianceNew = (term1 - term2 + term3) / (n * d); // scalar + + // Compute the log likelihood + var C = W.matmul(W.T) + CI * variance; // d x d + var CInv = inv(C); // d x d + var logLikelihood = -0.5 * n * (logLikelihoodConst + logdet(C) + trace(CInv.matmul(sampleCov))); // scalar + + // Check for convergence + diffW = linalg.norm(WNew - W).to_type(ScalarType.Float64).cpu().ReadCpuDouble(0); + diffVariance = abs(varianceNew - variance).to_type(ScalarType.Float64).cpu().ReadCpuDouble(0); + + // Update loglikelihood, W and variance + LogLikelihood[i] = logLikelihood.MoveToOuterDisposeScope(); + W = WNew.MoveToOuterDisposeScope(); + variance = varianceNew.to_type(ScalarType.Float64).cpu().ReadCpuDouble(0); + } + + + if (diffW < _tolerance && diffVariance < _tolerance) + { + LogLikelihood = LogLikelihood.slice(0, 0, i + 1, 1); + break; + } + } + + // Finalize model parameters + LogLikelihood = LogLikelihood.DetachFromDisposeScope(); + Components = W.DetachFromDisposeScope(); + Variance = variance; + _isFitted = true; + } + + public override Tensor Transform(Tensor data) + { + if (data.NumberOfElements == 0 || data.dim() < 2) + { + throw new ArgumentException("Data must be a non-empty 2D tensor with shape (samples x features).", nameof(data)); + } + + if (!_isFitted) + { + throw new InvalidOperationException("Model has not yet been fitted. 
You should call the Fit() or the FitAndTransform() methods first."); + } + + var Xt = data.T; + var mean = Xt.mean([ 0 ], keepdim: true); // 1 x d + var X = Xt - mean; // n x d + var W = Components; // d x q + var M = W.T.matmul(W) + eye(NumComponents) * Variance; // q x q + var MInv = inverse(M); // q x q + return X.matmul(W).matmul(MInv); // n x q + } + } +} diff --git a/src/Bonsai.ML.PCA/Transform.cs b/src/Bonsai.ML.PCA/Transform.cs index 3016f462..744c1948 100644 --- a/src/Bonsai.ML.PCA/Transform.cs +++ b/src/Bonsai.ML.PCA/Transform.cs @@ -31,5 +31,21 @@ public IObservable Process(IObservable> source) return TransformData(value.Item2, value.Item1); }); } + + public IObservable Process(IObservable> source) + { + return source.Select(value => + { + return TransformData(value.Item1, value.Item2); + }); + } + + public IObservable Process(IObservable> source) + { + return source.Select(value => + { + return TransformData(value.Item2, value.Item1); + }); + } } } From 5be19998ebaf63ec832836b6c8171c675f51f00c Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Mon, 11 Aug 2025 10:29:03 +0100 Subject: [PATCH 04/20] Added online probabilistic PCA method --- src/Bonsai.ML.PCA/CreatePCA.cs | 23 ++++ src/Bonsai.ML.PCA/Fit.cs | 16 +++ src/Bonsai.ML.PCA/OnlinePPCA.cs | 191 ++++++++++++++++++++++++++++++ src/Bonsai.ML.PCA/PCAModelType.cs | 3 +- src/Bonsai.ML.PCA/Transform.cs | 16 +++ 5 files changed, 248 insertions(+), 1 deletion(-) create mode 100644 src/Bonsai.ML.PCA/OnlinePPCA.cs diff --git a/src/Bonsai.ML.PCA/CreatePCA.cs b/src/Bonsai.ML.PCA/CreatePCA.cs index 3524056a..e729fcfc 100644 --- a/src/Bonsai.ML.PCA/CreatePCA.cs +++ b/src/Bonsai.ML.PCA/CreatePCA.cs @@ -6,10 +6,13 @@ using Bonsai.Expressions; using System.Linq; using System.Reflection; +using static TorchSharp.torch; +using System.Xml.Serialization; namespace Bonsai.ML.PCA { [Combinator] + [ResetCombinator] [WorkflowElementCategory(ElementCategory.Source)] [TypeDescriptionProvider(typeof(PCADescriptionProvider))] 
public class CreatePCA : ZeroArgumentExpressionBuilder @@ -23,6 +26,10 @@ public class CreatePCA : ZeroArgumentExpressionBuilder public int Iterations { get; set; } = 100; public double Tolerance { get; set; } = 1e-5; + public double? Rho { get; set; } = 0.1; + public double? Kappa { get; set; } = 0.9; + public int? BurnInCount { get; set; } = null; + [XmlIgnore] public Generator? Generator { get; set; } = null; @@ -38,6 +45,14 @@ internal IEnumerable GetModelProperties() yield return nameof(Tolerance); yield return nameof(Generator); } + + if (ModelType == PCAModelType.OnlinePPCA) + { + yield return nameof(InitialVariance); + yield return nameof(Rho); + yield return nameof(Kappa); + yield return nameof(BurnInCount); + yield return nameof(Generator); } } @@ -52,6 +67,13 @@ private static PCABaseModel CreateModel(CreatePCA instance) instance.Generator, instance.Iterations, instance.Tolerance), + PCAModelType.OnlinePPCA => new OnlinePPCA( + instance.NumComponents, + instance.InitialVariance, + instance.Generator, + instance.Rho, + instance.Kappa, + instance.BurnInCount), _ => throw new NotSupportedException($"Model type {instance.ModelType} is not supported."), }; } @@ -62,6 +84,7 @@ private static Type GetModelType(PCAModelType modelType) { PCAModelType.PCA => typeof(PCA), PCAModelType.ProbabilisticPCA => typeof(PPCA), + PCAModelType.OnlinePPCA => typeof(OnlinePPCA), _ => throw new NotSupportedException($"Model type {modelType} is not supported."), }; } diff --git a/src/Bonsai.ML.PCA/Fit.cs b/src/Bonsai.ML.PCA/Fit.cs index 4cab50af..27eb0d60 100644 --- a/src/Bonsai.ML.PCA/Fit.cs +++ b/src/Bonsai.ML.PCA/Fit.cs @@ -47,5 +47,21 @@ public IObservable> Process(IObservable> FitModel(value.Item2, value.Item1); }); } + + public IObservable> Process(IObservable> source) + { + return source.Do((value) => + { + FitModel(value.Item1, value.Item2); + }); + } + + public IObservable> Process(IObservable> source) + { + return source.Do((value) => + { + FitModel(value.Item2, 
value.Item1); + }); + } } } diff --git a/src/Bonsai.ML.PCA/OnlinePPCA.cs b/src/Bonsai.ML.PCA/OnlinePPCA.cs new file mode 100644 index 00000000..7c1daea6 --- /dev/null +++ b/src/Bonsai.ML.PCA/OnlinePPCA.cs @@ -0,0 +1,191 @@ +using Bonsai; +using System; +using System.ComponentModel; +using System.Collections.Generic; +using System.Linq; +using System.Reactive.Linq; +using TorchSharp; +using static TorchSharp.torch; +using static TorchSharp.torch.linalg; + +namespace Bonsai.ML.PCA +{ + public class OnlinePPCA : PCABaseModel + { + public double? Rho { get; private set; } + public double? Kappa { get; private set; } + public int? BurnInCount { get; private set; } + public double Variance { get; private set; } + public Tensor Components => _W; + public Generator Generator { get; private set; } + private Tensor _mu; + private Tensor _W; + private Tensor _Iq; + private Tensor _Id; + private Tensor _Sx; + private double _Sxxtrace; + private Tensor _Sxz; + private Tensor _Szz; + private bool _initializedParameters = false; + private int _count = 0; + private int _burnInCount = 0; + private Func UpdateRho; + + public OnlinePPCA(int numComponents, + double initialVariance = 1.0, + Generator? generator = null, + double? rho = 0.1, + double? kappa = null, + int? 
burnInCount = null + ) : base(numComponents) + { + if (initialVariance <= 0) + { + throw new ArgumentException("Starting variance must be greater than or equal to zero.", nameof(initialVariance)); + } + + if (kappa.HasValue && rho.HasValue) + { + throw new ArgumentException("Only one of rho or kappa should be specified, not both.", nameof(rho)); + } + + if (!kappa.HasValue && !rho.HasValue) + { + throw new ArgumentException("Either rho or kappa must be specified.", nameof(rho)); + } + + if (rho.HasValue) + { + UpdateRho = () => rho.Value; + if (rho <= 0 || rho >= 1) + { + throw new ArgumentException("Rho must be in the range (0, 1).", nameof(rho)); + } + } + + if (kappa.HasValue) + { + UpdateRho = () => Math.Pow(_count + 10, -kappa.Value); + if (kappa <= 0.5 || kappa > 1) + { + throw new ArgumentException("Kappa must be in the range (0.5, 1].", nameof(kappa)); + } + } + + if (burnInCount.HasValue && burnInCount <= 0) + { + throw new ArgumentException("Warmup iterations must be greater than zero.", nameof(burnInCount)); + } + + + Variance = initialVariance; + Generator = generator ?? manual_seed(0); + Rho = rho; + Kappa = kappa; + BurnInCount = burnInCount; + _burnInCount = burnInCount ?? 
0; + } + + public override void Fit(Tensor data) + { + // throw new NotImplementedException(); + if (data.NumberOfElements == 0 || data.dim() != 2) + { + throw new ArgumentException("Input data must be a 2D tensor."); + } + + var Xt = data.T; // n x d + + // Initialize dimensions + var q = NumComponents; + var n = Xt.size(0); + var d = Xt.size(1); + + // Initialize parameters + if (!_initializedParameters) + { + _mu = zeros(1, d); // 1 x d + var wScale = Math.Sqrt(Variance / d); + _W = randn(d, q, generator: Generator) * wScale; // d x q + _Iq = eye(q); // q x q + _Id = eye(d); // d x d + _Sx = zeros(1, d); // d + _Sxz = zeros(d, q); // d x q + _Szz = zeros(q, q); // q x q + _Sxxtrace = 0.0; // scalar + + _burnInCount = Math.Max(_burnInCount, (int)d); + _initializedParameters = true; + } + + // Update rho + var rho = UpdateRho(); + + using (var _ = NewDisposeScope()) + { + // E-step using current parameters + var M = _W.T.matmul(_W) + _Iq * Variance; // q x q + M = 0.5 * (M + M.T); // Ensure symmetry + var Xc = Xt - _mu; // n x d + var Lm = linalg.cholesky(M); // q x q + var XcW = Xc.matmul(_W); // n x q + var Ez = cholesky_solve(XcW.T, Lm).T; // n x q + var EzInv = cholesky_inverse(Lm); // q x q + var Ezzmu = (Variance * EzInv) + Ez.T.matmul(Ez); + + // Batch statistics + var Xmu = Xt.mean([0], keepdim: true); // 1 x d + var Xzmu = Xc.T.matmul(Ez); // d x q + var Xcmu = Xc.pow(2).sum(1).mean(); + + // Update parameters with new statistics + _Sx = (1 - rho) * _Sx + rho * Xmu; + _Sxz = (1 - rho) * _Sxz + rho * Xzmu; + + var SzzNew = (1 - rho) * _Szz + rho * Ezzmu; + _Szz = 0.5 * (SzzNew + SzzNew.T); // Ensure symmetry + _Sxxtrace = (1 - rho) * _Sxxtrace + rho * Xcmu.to_type(ScalarType.Float64).cpu().ReadCpuDouble(0); + + _Sx = _Sx.MoveToOuterDisposeScope(); + _Sxz = _Sxz.MoveToOuterDisposeScope(); + _Szz = _Szz.MoveToOuterDisposeScope(); + _Iq.MoveToOuterDisposeScope(); + _Id.MoveToOuterDisposeScope(); + + // During burn-in, we do not update W or variance + if 
(_count <= _burnInCount) + { + _count++; + return; + } + + // M-step: Update W and variance + _mu = _Sx.MoveToOuterDisposeScope(); + var Lzz = linalg.cholesky(_Szz); // q x q + var WNew = cholesky_solve(_Sxz.T, Lzz).T; // d x q + // _W = WNew.MoveToOuterDisposeScope(); + var WUpdated = (1 - rho) * _W + rho * WNew; + _W = WUpdated.MoveToOuterDisposeScope(); + + var trWSt = WNew.mul(_Sxz).sum().to_type(ScalarType.Float64).cpu().ReadCpuDouble(0); + var newVar = (_Sxxtrace - trWSt) / d; + Variance = !double.IsNaN(newVar) && newVar > 0 ? newVar : Variance; + } + } + + public override Tensor Transform(Tensor data) + { + if (data.NumberOfElements == 0 || data.dim() < 2) + { + throw new ArgumentException("Data must be a non-empty 2D tensor with shape (samples x features).", nameof(data)); + } + + var Xt = data.T; // n x d + var X = Xt - _mu; // n x d + var W = Components; // d x q + var M = W.T.matmul(W) + eye(NumComponents) * Variance; // q x q + var MInv = inverse(M); // q x q + return X.matmul(W).matmul(MInv); // n x q + } + } +} diff --git a/src/Bonsai.ML.PCA/PCAModelType.cs b/src/Bonsai.ML.PCA/PCAModelType.cs index 726b9081..30c6ed1b 100644 --- a/src/Bonsai.ML.PCA/PCAModelType.cs +++ b/src/Bonsai.ML.PCA/PCAModelType.cs @@ -13,6 +13,7 @@ namespace Bonsai.ML.PCA public enum PCAModelType { PCA, - ProbabilisticPCA + ProbabilisticPCA, + OnlinePPCA } } diff --git a/src/Bonsai.ML.PCA/Transform.cs b/src/Bonsai.ML.PCA/Transform.cs index 744c1948..2217c74c 100644 --- a/src/Bonsai.ML.PCA/Transform.cs +++ b/src/Bonsai.ML.PCA/Transform.cs @@ -47,5 +47,21 @@ public IObservable Process(IObservable> source) return TransformData(value.Item2, value.Item1); }); } + + public IObservable Process(IObservable> source) + { + return source.Select(value => + { + return TransformData(value.Item1, value.Item2); + }); + } + + public IObservable Process(IObservable> source) + { + return source.Select(value => + { + return TransformData(value.Item2, value.Item1); + }); + } } } From 
1afd121ca29b5940bbd077bd6512e3435aac9a75 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Mon, 11 Aug 2025 10:29:37 +0100 Subject: [PATCH 05/20] Added synthetic data tests for new PCA package --- .../Bonsai.ML.PCA.Tests.csproj | 21 + .../Bonsai.ML.PCA.Tests/CreatePCATest.bonsai | 31 + tests/Bonsai.ML.PCA.Tests/FitPCATest.bonsai | 235 ++++ .../TransformPCATest.bonsai | 1015 +++++++++++++++++ 4 files changed, 1302 insertions(+) create mode 100644 tests/Bonsai.ML.PCA.Tests/Bonsai.ML.PCA.Tests.csproj create mode 100644 tests/Bonsai.ML.PCA.Tests/CreatePCATest.bonsai create mode 100644 tests/Bonsai.ML.PCA.Tests/FitPCATest.bonsai create mode 100644 tests/Bonsai.ML.PCA.Tests/TransformPCATest.bonsai diff --git a/tests/Bonsai.ML.PCA.Tests/Bonsai.ML.PCA.Tests.csproj b/tests/Bonsai.ML.PCA.Tests/Bonsai.ML.PCA.Tests.csproj new file mode 100644 index 00000000..e628dd82 --- /dev/null +++ b/tests/Bonsai.ML.PCA.Tests/Bonsai.ML.PCA.Tests.csproj @@ -0,0 +1,21 @@ + + + net8.0 + enable + enable + false + true + + + + + + + + + + + + + + \ No newline at end of file diff --git a/tests/Bonsai.ML.PCA.Tests/CreatePCATest.bonsai b/tests/Bonsai.ML.PCA.Tests/CreatePCATest.bonsai new file mode 100644 index 00000000..ca948fc8 --- /dev/null +++ b/tests/Bonsai.ML.PCA.Tests/CreatePCATest.bonsai @@ -0,0 +1,31 @@ + + + + + + + 2 + PCA + 1 + 100 + 1E-05 + + + + PCA + + + + + + + + + + + \ No newline at end of file diff --git a/tests/Bonsai.ML.PCA.Tests/FitPCATest.bonsai b/tests/Bonsai.ML.PCA.Tests/FitPCATest.bonsai new file mode 100644 index 00000000..ca915de8 --- /dev/null +++ b/tests/Bonsai.ML.PCA.Tests/FitPCATest.bonsai @@ -0,0 +1,235 @@ + + + + + + LoadTrainingData + + + + + + + + Z:/home/nicholas/Downloads/xfads-2/workflows/data/y_train.bin + 0 + 0 + 50 + 1 + F64 + RowMajor + + + + + + 1000 + + + + + + + + + + + + + + Float32 + + + + + + + + + + + + + + + + + + + + 1 + + + + it.squeeze(null).T + + + + + 0 + + true + + + + + + + + + it.T + + + + + + + + + + + + + + + + + YTrain + + + 
LoadValidationData + + + + + + + + Z:/home/nicholas/Downloads/xfads-2/workflows/data/y_valid.bin + 0 + 0 + 50 + 1 + F64 + RowMajor + + + + + + 1000 + + + + + + + + + + + + + + Float32 + + + + + + + + + + + + + + + + + + + + 1 + + + + it.squeeze(null).T + + + + + 0 + + true + + + + + + + + + it.T + + + + + + + + + + + + + + + + + YValid + + + YTrain + + + + 2 + ProbabilisticPCA + 1 + 100 + 1E-05 + + + + + + + + + + Item2 + + + Components + + + + + + + + + + + + + \ No newline at end of file diff --git a/tests/Bonsai.ML.PCA.Tests/TransformPCATest.bonsai b/tests/Bonsai.ML.PCA.Tests/TransformPCATest.bonsai new file mode 100644 index 00000000..bf402169 --- /dev/null +++ b/tests/Bonsai.ML.PCA.Tests/TransformPCATest.bonsai @@ -0,0 +1,1015 @@ + + + + + + LoadTrainingData + + + + + + + + Z:/home/nicholas/Downloads/xfads-2/workflows/data/y_train.bin + 0 + 0 + 50 + 1 + F64 + RowMajor + + + + + + 1000 + + + + + + + + + + + + + + Float32 + + + + + + + + + + + + + + + + + + + + 1 + + + + it.squeeze(null).T + + + + + 0 + + true + + + + + + + + + it.T + + + + + + + + + + + + + + + + + YTrain + + + LoadValidationData + + + + + + + + Z:/home/nicholas/Downloads/xfads-2/workflows/data/y_valid.bin + 0 + 0 + 50 + 1 + F64 + RowMajor + + + + + + 1000 + + + + + + + + + + + + + + Float32 + + + + + + + + + + + + + + + + + + + + 1 + + + + it.squeeze(null).T + + + + + 0 + + true + + + + + + + + + it.T + + + + + + + + + + + + + + + + + YValid + + + + 0 + 50 + 1 + + + + + Float32 + + + YValid + + + YTrain + + + + 2 + PCA + 1 + 100 + 1E-05 + 0.1 + 0.9 + + + + + + + + + + + Item2 + + + + + + + + + + + 333 + 50 + 2 + + + + + + 0,:,0 + + + + Float32 + + + + + + DataStruct + + + + Source1 + + + Item1 + + + + + + Item2 + + + + + + + + + new (Item1 as X, Item2 as Y) + + + + + + + + + + + + + + + + + Circle + 1 + + + + + + + + + Components + + + + 0 + + + + Float32 + + + + + + + + + + 0 + 50 + 1 + + + + + Float32 + + + YValid + + + YTrain + + + + 2 + ProbabilisticPCA + 1 + 100 + 1E-05 + 0.1 + 0.9 + + + 
+ + + + + + + + Item2 + + + + + + + + + + + 333 + 50 + 2 + + + + + + 0,:,0 + + + + Float32 + + + + + + DataStruct + + + + Source1 + + + Item1 + + + + + + Item2 + + + + + + + + + new (Item1 as X, Item2 as Y) + + + + + + + + + + + + + + + + + Circle + 1 + + + + + + + + + Components + + + + 0 + + + + Float32 + + + + + + + + + + 0 + 50 + 1 + + + + + Float32 + + + YValid + + + + 1 + + + + it.shape[1] + + + + + + + + + 0 + 16650 + 1 + + + + + Int32 + + + + + + Source1 + + + + + + + + + YValid + + + + 1 + + + + + + + Source1 + + + + + + + + + + + + + + + + it.unsqueeze(1) + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + YTrain + + + + 1 + + + + it.shape[1] + + + + + + + + + 0 + 33350 + 1 + + + + + Int32 + + + + + + Source1 + + + + + + + + + YTrain + + + + 1 + + + + + + + Source1 + + + + + + + + + + + + + + + + it.unsqueeze(1) + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 2 + OnlinePPCA + 1 + 100 + 1E-05 + 0.1 + + + + + + + + + + + + + 1 + + + + Item2 + + + + + + + + + + + + + + + Source1 + + + + 50 + + + + + + + + 0 + + + + + + + + + + + + + + Float32 + + + + + + DataStruct + + + + Source1 + + + Item1 + + + + + + Item2 + + + + + + + + + new (Item1 as X, Item2 as Y) + + + + + + + + + + + + + + + + + Circle + 1 + + + + + + + + + + YTrain + + + + + + 2 + ProbabilisticPCA + 1 + 10 + 0.0001 + 0.1 + 0.9 + + + + + + + + + + + + + + + + + Item2 + + + + + LogLikelihood + + + + + + + + + + + + + 0 + 55 + 1 + + + + + + + Float32 + + + + + Float32 + + + + + + + + + + + + + Source1 + + + Item1 + + + + + + Item2 + + + + + + + + + new (Item1 as X, Item2 as Y) + + + + + + + + + + + + + + + + + + + Circle + 1 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file From e9b7587a47fad65aa1fcc5878bb87977a2fc1ec5 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: 
Mon, 11 Aug 2025 15:50:44 +0100 Subject: [PATCH 06/20] Updated online ppca method to return closely what is expected from batch pca and ppca --- src/Bonsai.ML.PCA/CreatePCA.cs | 44 ++- src/Bonsai.ML.PCA/IPCABaseModel.cs | 2 + src/Bonsai.ML.PCA/OnlinePPCA.cs | 257 +++++++++------ src/Bonsai.ML.PCA/PCA.cs | 8 +- src/Bonsai.ML.PCA/PCABaseModel.cs | 9 +- src/Bonsai.ML.PCA/PPCA.cs | 6 +- .../TransformPCATest.bonsai | 312 ++++++------------ 7 files changed, 303 insertions(+), 335 deletions(-) diff --git a/src/Bonsai.ML.PCA/CreatePCA.cs b/src/Bonsai.ML.PCA/CreatePCA.cs index e729fcfc..4c3ddf82 100644 --- a/src/Bonsai.ML.PCA/CreatePCA.cs +++ b/src/Bonsai.ML.PCA/CreatePCA.cs @@ -19,6 +19,10 @@ public class CreatePCA : ZeroArgumentExpressionBuilder { public int NumComponents { get; set; } = 2; + [XmlIgnore] + public Device Device { get; set; } + public ScalarType? ScalarType { get; set; } + [RefreshProperties(RefreshProperties.All)] public PCAModelType ModelType { get; set; } = PCAModelType.PCA; @@ -28,7 +32,8 @@ public class CreatePCA : ZeroArgumentExpressionBuilder public double? Rho { get; set; } = 0.1; public double? Kappa { get; set; } = 0.9; - public int? BurnInCount { get; set; } = null; + public int? TimeOffset { get; set; } = null; + public int? ReorthogonalizePeriod { get; set; } = null; [XmlIgnore] public Generator? 
Generator { get; set; } = null; @@ -36,6 +41,8 @@ public class CreatePCA : ZeroArgumentExpressionBuilder internal IEnumerable GetModelProperties() { yield return nameof(NumComponents); + yield return nameof(Device); + yield return nameof(ScalarType); yield return nameof(ModelType); if (ModelType == PCAModelType.ProbabilisticPCA) @@ -51,7 +58,8 @@ internal IEnumerable GetModelProperties() yield return nameof(InitialVariance); yield return nameof(Rho); yield return nameof(Kappa); - yield return nameof(BurnInCount); + yield return nameof(TimeOffset); + yield return nameof(ReorthogonalizePeriod); yield return nameof(Generator); } } @@ -60,20 +68,28 @@ private static PCABaseModel CreateModel(CreatePCA instance) { return instance.ModelType switch { - PCAModelType.PCA => new PCA(instance.NumComponents), + PCAModelType.PCA => new PCA( + numComponents: instance.NumComponents, + device: instance.Device, + scalarType: instance.ScalarType), PCAModelType.ProbabilisticPCA => new PPCA( - instance.NumComponents, - instance.InitialVariance, - instance.Generator, - instance.Iterations, - instance.Tolerance), + numComponents: instance.NumComponents, + device: instance.Device, + scalarType: instance.ScalarType, + initialVariance: instance.InitialVariance, + generator: instance.Generator, + iterations: instance.Iterations, + tolerance: instance.Tolerance), PCAModelType.OnlinePPCA => new OnlinePPCA( - instance.NumComponents, - instance.InitialVariance, - instance.Generator, - instance.Rho, - instance.Kappa, - instance.BurnInCount), + numComponents: instance.NumComponents, + device: instance.Device, + scalarType: instance.ScalarType, + initialVariance: instance.InitialVariance, + generator: instance.Generator, + rho: instance.Rho, + kappa: instance.Kappa, + timeOffset: instance.TimeOffset, + reorthogonalizePeriod: instance.ReorthogonalizePeriod), _ => throw new NotSupportedException($"Model type {instance.ModelType} is not supported."), }; } diff --git 
a/src/Bonsai.ML.PCA/IPCABaseModel.cs b/src/Bonsai.ML.PCA/IPCABaseModel.cs index 17d5d603..84d7364c 100644 --- a/src/Bonsai.ML.PCA/IPCABaseModel.cs +++ b/src/Bonsai.ML.PCA/IPCABaseModel.cs @@ -12,6 +12,8 @@ namespace Bonsai.ML.PCA { public interface IPCABaseModel { + public Device Device { get; } + public ScalarType ScalarType { get; } public abstract void Fit(Tensor data); public abstract Tensor Transform(Tensor data); public abstract Tensor FitAndTransform(Tensor data); diff --git a/src/Bonsai.ML.PCA/OnlinePPCA.cs b/src/Bonsai.ML.PCA/OnlinePPCA.cs index 7c1daea6..a8193465 100644 --- a/src/Bonsai.ML.PCA/OnlinePPCA.cs +++ b/src/Bonsai.ML.PCA/OnlinePPCA.cs @@ -14,30 +14,39 @@ public class OnlinePPCA : PCABaseModel { public double? Rho { get; private set; } public double? Kappa { get; private set; } - public int? BurnInCount { get; private set; } - public double Variance { get; private set; } + public double Variance => _sigma2.to_type(ScalarType.Float64).item(); + public int? ReorthogonalizePeriod { get; private set; } + public int? TimeOffset { get; private set; } public Tensor Components => _W; public Generator Generator { get; private set; } + private Tensor _mu; private Tensor _W; private Tensor _Iq; - private Tensor _Id; - private Tensor _Sx; - private double _Sxxtrace; - private Tensor _Sxz; - private Tensor _Szz; + private Tensor _mx; // E[x] + private Tensor _Cxz; // E[xz^T] + private Tensor _mz; // E[z] + private Tensor _Czz; // E[zz^T] + private Tensor _sxx; // E[||x||^2] + private Tensor _sigma2; // Variance + private bool _initializedParameters = false; - private int _count = 0; - private int _burnInCount = 0; - private Func UpdateRho; + private readonly Func UpdateSchedule; + private readonly Action Reorthogonalize; + private int _stepCount = 0; public OnlinePPCA(int numComponents, + Device? device = null, + ScalarType? scalarType = ScalarType.Float32, double initialVariance = 1.0, Generator? generator = null, double? rho = 0.1, double? 
kappa = null, - int? burnInCount = null - ) : base(numComponents) + int? timeOffset = 3000, + int? reorthogonalizePeriod = null + ) : base(numComponents, + device, + scalarType) { if (initialVariance <= 0) { @@ -56,7 +65,7 @@ public OnlinePPCA(int numComponents, if (rho.HasValue) { - UpdateRho = () => rho.Value; + UpdateSchedule = () => rho.Value; if (rho <= 0 || rho >= 1) { throw new ArgumentException("Rho must be in the range (0, 1).", nameof(rho)); @@ -65,25 +74,60 @@ public OnlinePPCA(int numComponents, if (kappa.HasValue) { - UpdateRho = () => Math.Pow(_count + 10, -kappa.Value); + if (timeOffset is null or <= 0) + { + throw new ArgumentException("Time offset must be a positive integer.", nameof(timeOffset)); + } + + UpdateSchedule = () => Math.Pow(_stepCount + timeOffset.Value, -kappa.Value); if (kappa <= 0.5 || kappa > 1) { throw new ArgumentException("Kappa must be in the range (0.5, 1].", nameof(kappa)); } } - if (burnInCount.HasValue && burnInCount <= 0) + if (reorthogonalizePeriod.HasValue) { - throw new ArgumentException("Warmup iterations must be greater than zero.", nameof(burnInCount)); + Reorthogonalize = () => + { + if (_initializedParameters && _stepCount % ReorthogonalizePeriod == 0) + { + var (U, S, Vh) = svd(_W, fullMatrices: false); // W = U S V^T + var R = Vh.T; + _W = U.matmul(diag(S)); // keep per-component scale + _Cxz = _Cxz.matmul(R.T); + _Czz = R.matmul(_Czz).matmul(R.T); + _mz = R.matmul(_mz); + } + }; + } + else + { + Reorthogonalize = () => { }; } - - Variance = initialVariance; Generator = generator ?? manual_seed(0); Rho = rho; Kappa = kappa; - BurnInCount = burnInCount; - _burnInCount = burnInCount ?? 
0; + ReorthogonalizePeriod = reorthogonalizePeriod; + TimeOffset = timeOffset; + _sigma2 = initialVariance; + } + + private Tensor InvertSPD(Tensor spdMatrix, Tensor rhs, double regularization = 1e-6) + { + var diagShape = spdMatrix.size(-1); + Tensor L; + try + { + L = linalg.cholesky(spdMatrix); + } + catch (Exception) + { + var regularizer = eye(diagShape, device: Device, dtype: ScalarType) * regularization; + L = linalg.cholesky(spdMatrix + regularizer); + } + return cholesky_solve(rhs, L); } public override void Fit(Tensor data) @@ -94,84 +138,100 @@ public override void Fit(Tensor data) throw new ArgumentException("Input data must be a 2D tensor."); } - var Xt = data.T; // n x d - - // Initialize dimensions - var q = NumComponents; - var n = Xt.size(0); - var d = Xt.size(1); - - // Initialize parameters - if (!_initializedParameters) + using (no_grad()) { - _mu = zeros(1, d); // 1 x d - var wScale = Math.Sqrt(Variance / d); - _W = randn(d, q, generator: Generator) * wScale; // d x q - _Iq = eye(q); // q x q - _Id = eye(d); // d x d - _Sx = zeros(1, d); // d - _Sxz = zeros(d, q); // d x q - _Szz = zeros(q, q); // q x q - _Sxxtrace = 0.0; // scalar - - _burnInCount = Math.Max(_burnInCount, (int)d); - _initializedParameters = true; - } - - // Update rho - var rho = UpdateRho(); - - using (var _ = NewDisposeScope()) - { - // E-step using current parameters - var M = _W.T.matmul(_W) + _Iq * Variance; // q x q - M = 0.5 * (M + M.T); // Ensure symmetry - var Xc = Xt - _mu; // n x d - var Lm = linalg.cholesky(M); // q x q - var XcW = Xc.matmul(_W); // n x q - var Ez = cholesky_solve(XcW.T, Lm).T; // n x q - var EzInv = cholesky_inverse(Lm); // q x q - var Ezzmu = (Variance * EzInv) + Ez.T.matmul(Ez); - - // Batch statistics - var Xmu = Xt.mean([0], keepdim: true); // 1 x d - var Xzmu = Xc.T.matmul(Ez); // d x q - var Xcmu = Xc.pow(2).sum(1).mean(); - - // Update parameters with new statistics - _Sx = (1 - rho) * _Sx + rho * Xmu; - _Sxz = (1 - rho) * _Sxz + rho * 
Xzmu; - - var SzzNew = (1 - rho) * _Szz + rho * Ezzmu; - _Szz = 0.5 * (SzzNew + SzzNew.T); // Ensure symmetry - _Sxxtrace = (1 - rho) * _Sxxtrace + rho * Xcmu.to_type(ScalarType.Float64).cpu().ReadCpuDouble(0); - - _Sx = _Sx.MoveToOuterDisposeScope(); - _Sxz = _Sxz.MoveToOuterDisposeScope(); - _Szz = _Szz.MoveToOuterDisposeScope(); - _Iq.MoveToOuterDisposeScope(); - _Id.MoveToOuterDisposeScope(); - - // During burn-in, we do not update W or variance - if (_count <= _burnInCount) + using (NewDisposeScope()) { - _count++; - return; - } - // M-step: Update W and variance - _mu = _Sx.MoveToOuterDisposeScope(); - var Lzz = linalg.cholesky(_Szz); // q x q - var WNew = cholesky_solve(_Sxz.T, Lzz).T; // d x q - // _W = WNew.MoveToOuterDisposeScope(); - var WUpdated = (1 - rho) * _W + rho * WNew; - _W = WUpdated.MoveToOuterDisposeScope(); - - var trWSt = WNew.mul(_Sxz).sum().to_type(ScalarType.Float64).cpu().ReadCpuDouble(0); - var newVar = (_Sxxtrace - trWSt) / d; - Variance = !double.IsNaN(newVar) && newVar > 0 ? 
newVar : Variance; + _stepCount++; + var rho = UpdateSchedule(); + + var Xt = data.T; // n x d + + // Initialize dimensions + var q = NumComponents; + var n = Xt.size(0); + var d = Xt.size(1); + + // Initialize parameters + if (!_initializedParameters) + { + _mu = zeros(d, device: Device, dtype: ScalarType).MoveToOuterDisposeScope(); + var randW = randn(d, q, generator: Generator, device: Device, dtype: ScalarType); + var orthonormalBases = linalg.qr(randW).Q; + _W = (orthonormalBases * _sigma2).MoveToOuterDisposeScope(); // d x q + _Iq = eye(q, device: Device, dtype: ScalarType).MoveToOuterDisposeScope(); // q x q + + _mx = zeros(d, device: Device, dtype: ScalarType).MoveToOuterDisposeScope(); // d + _Cxz = zeros(d, q, device: Device, dtype: ScalarType).MoveToOuterDisposeScope(); // d x q + _mz = zeros(q, device: Device, dtype: ScalarType).MoveToOuterDisposeScope(); // q + _Czz = zeros(q, q, device: Device, dtype: ScalarType).MoveToOuterDisposeScope(); // q x q + _sxx = zeros(1, device: Device, dtype: ScalarType).MoveToOuterDisposeScope(); // scalar + + _initializedParameters = true; + } + + // Covariance matrix + var cov = _Iq * _sigma2; + + // Center data using current mean + var Xc = Xt - _mu; + + // E-step + var M = _W.T.matmul(_W) + cov; + var MInv = InvertSPD(M, _Iq); + + var XcW = Xc.matmul(_W); + var EzT = InvertSPD(M, XcW.T); + var Ez = EzT.T; + + // Update statistics + var mx = Xt.mean([0]); + var sxx = Xt.pow(2).sum(1).mean(); + var Cxz = Xt.T.matmul(Ez) / n; + var mz = Ez.mean([0]); + var Czz = EzT.matmul(Ez) / n + _sigma2 * MInv; + + // Update parameters + var rhoFactor = 1 - rho; + _mx = (rhoFactor * _mx + rho * mx).MoveToOuterDisposeScope(); + _Cxz = rhoFactor * _Cxz + rho * Cxz; + _mz = rhoFactor * _mz + rho * mz; + _sxx = (rhoFactor * _sxx + rho * sxx).MoveToOuterDisposeScope(); + _Czz = rhoFactor * _Czz + rho * Czz; + + // Update mean + _mu = _mx.MoveToOuterDisposeScope(); + + // Centered statistics + var Sxz = _Cxz - outer(_mu, _mz); + var Szz = 
_Czz; + var Sxx = _sxx - _mu.dot(_mu); + + // M-step + _W = InvertSPD(Szz, Sxz.T).T; + + // Reorthogonalize W + Reorthogonalize(); + + // Reorder components based on the strength of the components + var strength = sum(_W * _W, dim: 0); + var indices = argsort(strength, descending: true); + _W = _W.index_select(1, indices).MoveToOuterDisposeScope(); + _Cxz = _Cxz.index_select(1, indices).MoveToOuterDisposeScope(); + _mz = _mz.index_select(0, indices).MoveToOuterDisposeScope(); + _Czz = _Czz.index_select(0, indices).index_select(1, indices).MoveToOuterDisposeScope(); + + Sxz = _Cxz - outer(_mu, _mz); + Szz = _Czz; + + // Update variance + _sigma2 = ((Sxx - 2 * trace(_W.T.matmul(Sxz)) + trace(_W.T.matmul(_W).matmul(Szz))) / (double)d) + .clamp_min(0.0) + .MoveToOuterDisposeScope(); + } } - } + } public override Tensor Transform(Tensor data) { @@ -181,11 +241,10 @@ public override Tensor Transform(Tensor data) } var Xt = data.T; // n x d - var X = Xt - _mu; // n x d - var W = Components; // d x q - var M = W.T.matmul(W) + eye(NumComponents) * Variance; // q x q - var MInv = inverse(M); // q x q - return X.matmul(W).matmul(MInv); // n x q + var Xc = Xt - _mu; // n x d + var M = _W.T.matmul(_W) + _Iq * _sigma2; // q x q + var XcW = Xc.matmul(_W); + return InvertSPD(M, XcW.T).T; // n x q } } } diff --git a/src/Bonsai.ML.PCA/PCA.cs b/src/Bonsai.ML.PCA/PCA.cs index 734306ae..7b8883b4 100644 --- a/src/Bonsai.ML.PCA/PCA.cs +++ b/src/Bonsai.ML.PCA/PCA.cs @@ -18,7 +18,13 @@ public class PCA : PCABaseModel public Tensor Components { get; private set; } = empty(0); private bool _isFitted = false; - public PCA(int numComponents) : base(numComponents) { } + public PCA(int numComponents, + Device? device = null, + ScalarType? 
scalarType = ScalarType.Float32) + : base(numComponents, + device, + scalarType) + { } public override void Fit(Tensor data) { diff --git a/src/Bonsai.ML.PCA/PCABaseModel.cs b/src/Bonsai.ML.PCA/PCABaseModel.cs index 044a8eb6..ed77dd41 100644 --- a/src/Bonsai.ML.PCA/PCABaseModel.cs +++ b/src/Bonsai.ML.PCA/PCABaseModel.cs @@ -14,7 +14,12 @@ public abstract class PCABaseModel : IPCABaseModel { public int NumComponents { get; private set; } - public PCABaseModel(int numComponents) + public Device Device { get; private set; } + public ScalarType ScalarType { get; private set; } + + public PCABaseModel(int numComponents, + Device? device = null, + ScalarType? scalarType = null) { if (numComponents <= 0) { @@ -22,6 +27,8 @@ public PCABaseModel(int numComponents) } NumComponents = numComponents; + Device = device ?? CPU; + ScalarType = scalarType ?? ScalarType.Float32; } public abstract void Fit(Tensor data); diff --git a/src/Bonsai.ML.PCA/PPCA.cs b/src/Bonsai.ML.PCA/PPCA.cs index 1f713b3c..5c1e2109 100644 --- a/src/Bonsai.ML.PCA/PPCA.cs +++ b/src/Bonsai.ML.PCA/PPCA.cs @@ -21,11 +21,15 @@ public class PPCA : PCABaseModel private bool _isFitted = false; public PPCA(int numComponents, + Device? device = null, + ScalarType? scalarType = ScalarType.Float32, double initialVariance = 1.0, Generator? 
generator = null, int iterations = 100, double tolerance = 1e-5 - ) : base(numComponents) + ) : base(numComponents, + device, + scalarType) { if (initialVariance < 0) { diff --git a/tests/Bonsai.ML.PCA.Tests/TransformPCATest.bonsai b/tests/Bonsai.ML.PCA.Tests/TransformPCATest.bonsai index bf402169..e6fa55b9 100644 --- a/tests/Bonsai.ML.PCA.Tests/TransformPCATest.bonsai +++ b/tests/Bonsai.ML.PCA.Tests/TransformPCATest.bonsai @@ -219,13 +219,15 @@ 2 + PCA 1 100 1E-05 0.1 0.9 - + + @@ -312,23 +314,6 @@ - - Components - - - - 0 - - - - Float32 - - - - - - - 0 @@ -349,13 +334,15 @@ 2 + ProbabilisticPCA 1 100 1E-05 0.1 0.9 - + + @@ -442,34 +429,6 @@ - - Components - - - - 0 - - - - Float32 - - - - - - - - - - 0 - 50 - 1 - - - - - Float32 - YValid @@ -575,6 +534,10 @@ + + YValidStream + + YTrain @@ -683,13 +646,15 @@ 2 + OnlinePPCA - 1 + 0.1 100 1E-05 - 0.1 - - + + 0.7 + 3000 + 500 @@ -706,20 +671,54 @@ Item2 + + FittedOnlinePPCA + + + Components + - + + 0 + + + + Float32 - + - + + + + + 0 + 50 + 1 + + + + + Float32 + + + FittedOnlinePPCA - - Source1 + + YValidStream + + + FittedOnlinePPCA + + + + + + @@ -737,13 +736,30 @@ - - + + + + + + + + + 1 + 50 + 2 + + + + + + 0,:,0 + + Float32 @@ -799,129 +815,6 @@ - - - YTrain - - - - - - 2 - ProbabilisticPCA - 1 - 10 - 0.0001 - 0.1 - 0.9 - - - - - - - - - - - - - - - - - Item2 - - - - - LogLikelihood - - - - - - - - - - - - - 0 - 55 - 1 - - - - - - - Float32 - - - - - Float32 - - - - - - - - - - - - - Source1 - - - Item1 - - - - - - Item2 - - - - - - - - - new (Item1 as X, Item2 as Y) - - - - - - - - - - - - - - - - - - - Circle - 1 - - - - - - - - @@ -934,7 +827,6 @@ - @@ -943,39 +835,39 @@ - - - + + + + - - - - + + + + - - - + + - + - + - - + + - + @@ -983,33 +875,15 @@ - - - + + - - + - + - - - - - - - - - - - - - - - - \ No newline at end of file From 84d9a94ef8be3725e7700f77f4373c98b241f19f Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Mon, 11 Aug 2025 16:34:42 +0100 Subject: [PATCH 07/20] Refactored `OnlinePPCA` to 
enforce non-nullable ReorthogonalizePeriod and move reorthogonalization logic directly inside the `Fit` method --- src/Bonsai.ML.PCA/OnlinePPCA.cs | 207 +++++++++++++++----------------- 1 file changed, 99 insertions(+), 108 deletions(-) diff --git a/src/Bonsai.ML.PCA/OnlinePPCA.cs b/src/Bonsai.ML.PCA/OnlinePPCA.cs index a8193465..954ae382 100644 --- a/src/Bonsai.ML.PCA/OnlinePPCA.cs +++ b/src/Bonsai.ML.PCA/OnlinePPCA.cs @@ -15,7 +15,7 @@ public class OnlinePPCA : PCABaseModel public double? Rho { get; private set; } public double? Kappa { get; private set; } public double Variance => _sigma2.to_type(ScalarType.Float64).item(); - public int? ReorthogonalizePeriod { get; private set; } + public int ReorthogonalizePeriod { get; private set; } public int? TimeOffset { get; private set; } public Tensor Components => _W; public Generator Generator { get; private set; } @@ -32,9 +32,9 @@ public class OnlinePPCA : PCABaseModel private bool _initializedParameters = false; private readonly Func UpdateSchedule; - private readonly Action Reorthogonalize; private int _stepCount = 0; - + private readonly bool _reorthogonalize = false; + public OnlinePPCA(int numComponents, Device? device = null, ScalarType? scalarType = ScalarType.Float32, @@ -88,28 +88,13 @@ public OnlinePPCA(int numComponents, if (reorthogonalizePeriod.HasValue) { - Reorthogonalize = () => - { - if (_initializedParameters && _stepCount % ReorthogonalizePeriod == 0) - { - var (U, S, Vh) = svd(_W, fullMatrices: false); // W = U S V^T - var R = Vh.T; - _W = U.matmul(diag(S)); // keep per-component scale - _Cxz = _Cxz.matmul(R.T); - _Czz = R.matmul(_Czz).matmul(R.T); - _mz = R.matmul(_mz); - } - }; - } - else - { - Reorthogonalize = () => { }; + _reorthogonalize = true; + ReorthogonalizePeriod = reorthogonalizePeriod.Value; } Generator = generator ?? 
manual_seed(0); Rho = rho; Kappa = kappa; - ReorthogonalizePeriod = reorthogonalizePeriod; TimeOffset = timeOffset; _sigma2 = initialVariance; } @@ -139,99 +124,105 @@ public override void Fit(Tensor data) } using (no_grad()) + using (NewDisposeScope()) { - using (NewDisposeScope()) + + _stepCount++; + var rho = UpdateSchedule(); + + var Xt = data.T; // n x d + + // Initialize dimensions + var q = NumComponents; + var n = Xt.size(0); + var d = Xt.size(1); + + // Initialize parameters + if (!_initializedParameters) { + _mu = zeros(d, device: Device, dtype: ScalarType).MoveToOuterDisposeScope(); + var randW = randn(d, q, generator: Generator, device: Device, dtype: ScalarType); + var orthonormalBases = linalg.qr(randW).Q; + _W = (orthonormalBases * _sigma2).MoveToOuterDisposeScope(); // d x q + _Iq = eye(q, device: Device, dtype: ScalarType).MoveToOuterDisposeScope(); // q x q + + _mx = zeros(d, device: Device, dtype: ScalarType).MoveToOuterDisposeScope(); // d + _Cxz = zeros(d, q, device: Device, dtype: ScalarType).MoveToOuterDisposeScope(); // d x q + _mz = zeros(q, device: Device, dtype: ScalarType).MoveToOuterDisposeScope(); // q + _Czz = zeros(q, q, device: Device, dtype: ScalarType).MoveToOuterDisposeScope(); // q x q + _sxx = zeros(1, device: Device, dtype: ScalarType).MoveToOuterDisposeScope(); // scalar + + _initializedParameters = true; + } + + // Covariance matrix + var cov = _Iq * _sigma2; + + // Center data using current mean + var Xc = Xt - _mu; - _stepCount++; - var rho = UpdateSchedule(); - - var Xt = data.T; // n x d - - // Initialize dimensions - var q = NumComponents; - var n = Xt.size(0); - var d = Xt.size(1); - - // Initialize parameters - if (!_initializedParameters) - { - _mu = zeros(d, device: Device, dtype: ScalarType).MoveToOuterDisposeScope(); - var randW = randn(d, q, generator: Generator, device: Device, dtype: ScalarType); - var orthonormalBases = linalg.qr(randW).Q; - _W = (orthonormalBases * _sigma2).MoveToOuterDisposeScope(); // d x q 
- _Iq = eye(q, device: Device, dtype: ScalarType).MoveToOuterDisposeScope(); // q x q - - _mx = zeros(d, device: Device, dtype: ScalarType).MoveToOuterDisposeScope(); // d - _Cxz = zeros(d, q, device: Device, dtype: ScalarType).MoveToOuterDisposeScope(); // d x q - _mz = zeros(q, device: Device, dtype: ScalarType).MoveToOuterDisposeScope(); // q - _Czz = zeros(q, q, device: Device, dtype: ScalarType).MoveToOuterDisposeScope(); // q x q - _sxx = zeros(1, device: Device, dtype: ScalarType).MoveToOuterDisposeScope(); // scalar - - _initializedParameters = true; - } - - // Covariance matrix - var cov = _Iq * _sigma2; - - // Center data using current mean - var Xc = Xt - _mu; - - // E-step - var M = _W.T.matmul(_W) + cov; - var MInv = InvertSPD(M, _Iq); - - var XcW = Xc.matmul(_W); - var EzT = InvertSPD(M, XcW.T); - var Ez = EzT.T; - - // Update statistics - var mx = Xt.mean([0]); - var sxx = Xt.pow(2).sum(1).mean(); - var Cxz = Xt.T.matmul(Ez) / n; - var mz = Ez.mean([0]); - var Czz = EzT.matmul(Ez) / n + _sigma2 * MInv; - - // Update parameters - var rhoFactor = 1 - rho; - _mx = (rhoFactor * _mx + rho * mx).MoveToOuterDisposeScope(); - _Cxz = rhoFactor * _Cxz + rho * Cxz; - _mz = rhoFactor * _mz + rho * mz; - _sxx = (rhoFactor * _sxx + rho * sxx).MoveToOuterDisposeScope(); - _Czz = rhoFactor * _Czz + rho * Czz; - - // Update mean - _mu = _mx.MoveToOuterDisposeScope(); - - // Centered statistics - var Sxz = _Cxz - outer(_mu, _mz); - var Szz = _Czz; - var Sxx = _sxx - _mu.dot(_mu); - - // M-step - _W = InvertSPD(Szz, Sxz.T).T; - - // Reorthogonalize W - Reorthogonalize(); - - // Reorder components based on the strength of the components - var strength = sum(_W * _W, dim: 0); - var indices = argsort(strength, descending: true); - _W = _W.index_select(1, indices).MoveToOuterDisposeScope(); - _Cxz = _Cxz.index_select(1, indices).MoveToOuterDisposeScope(); - _mz = _mz.index_select(0, indices).MoveToOuterDisposeScope(); - _Czz = _Czz.index_select(0, indices).index_select(1, 
indices).MoveToOuterDisposeScope(); - - Sxz = _Cxz - outer(_mu, _mz); - Szz = _Czz; - - // Update variance - _sigma2 = ((Sxx - 2 * trace(_W.T.matmul(Sxz)) + trace(_W.T.matmul(_W).matmul(Szz))) / (double)d) - .clamp_min(0.0) - .MoveToOuterDisposeScope(); + // E-step + var M = _W.T.matmul(_W) + cov; + var MInv = InvertSPD(M, _Iq); + + var XcW = Xc.matmul(_W); + var EzT = InvertSPD(M, XcW.T); + var Ez = EzT.T; + + // Update statistics + var mx = Xt.mean([0]); + var sxx = Xt.pow(2).sum(dim: 1).mean(); + var Cxz = Xt.T.matmul(Ez) / n; + var mz = Ez.mean([0]); + var Czz = EzT.matmul(Ez) / n + _sigma2 * MInv; + + // Update parameters + var rhoFactor = 1 - rho; + _mx = (rhoFactor * _mx + rho * mx).MoveToOuterDisposeScope(); + _Cxz = (rhoFactor * _Cxz + rho * Cxz).MoveToOuterDisposeScope(); + _mz = (rhoFactor * _mz + rho * mz).MoveToOuterDisposeScope(); + _sxx = (rhoFactor * _sxx + rho * sxx).MoveToOuterDisposeScope(); + _Czz = (rhoFactor * _Czz + rho * Czz).MoveToOuterDisposeScope(); + + // Update mean + _mu = _mx.MoveToOuterDisposeScope(); + + // Centered statistics + var Sxz = _Cxz - outer(_mu, _mz); + var Szz = _Czz; + var Sxx = _sxx - _mu.dot(_mu); + + // M-step + var WNew = InvertSPD(Szz, Sxz.T).T; + + if (_reorthogonalize && + _stepCount % ReorthogonalizePeriod == 0) + { + var (U, S, Vh) = svd(WNew, fullMatrices: false); + var R = Vh.T; + WNew = U.matmul(diag(S)); + _Cxz = _Cxz.matmul(R.T); + _Czz = R.matmul(_Czz).matmul(R.T); + _mz = R.matmul(_mz); } + + // Reorder components based on the strength of the components + var strength = sum(WNew * WNew, dim: 0); + var indices = argsort(strength, descending: true); + _W = WNew.index_select(1, indices).MoveToOuterDisposeScope(); + _Cxz = _Cxz.index_select(1, indices).MoveToOuterDisposeScope(); + _mz = _mz.index_select(0, indices).MoveToOuterDisposeScope(); + _Czz = _Czz.index_select(0, indices).index_select(1, indices).MoveToOuterDisposeScope(); + + Sxz = _Cxz - outer(_mu, _mz); + Szz = _Czz; + + // Update variance + 
_sigma2 = ((Sxx - 2 * trace(_W.T.matmul(Sxz)) + trace(_W.T.matmul(_W).matmul(Szz))) / (double)d) + .clamp_min(0.0) + .MoveToOuterDisposeScope(); } - } + } public override Tensor Transform(Tensor data) { From bfd8857f9e1f586486cb567c85957ddc7c7c352a Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Mon, 11 Aug 2025 18:22:05 +0100 Subject: [PATCH 08/20] Added PCA data reconstruction functionality --- src/Bonsai.ML.PCA/IPCABaseModel.cs | 1 + src/Bonsai.ML.PCA/OnlinePPCA.cs | 24 ++++++++++- src/Bonsai.ML.PCA/PCA.cs | 24 ++++++++++- src/Bonsai.ML.PCA/PCABaseModel.cs | 2 + src/Bonsai.ML.PCA/PPCA.cs | 19 +++++++++ src/Bonsai.ML.PCA/Reconstruct.cs | 67 ++++++++++++++++++++++++++++++ 6 files changed, 135 insertions(+), 2 deletions(-) create mode 100644 src/Bonsai.ML.PCA/Reconstruct.cs diff --git a/src/Bonsai.ML.PCA/IPCABaseModel.cs b/src/Bonsai.ML.PCA/IPCABaseModel.cs index 84d7364c..3e7bbad5 100644 --- a/src/Bonsai.ML.PCA/IPCABaseModel.cs +++ b/src/Bonsai.ML.PCA/IPCABaseModel.cs @@ -17,5 +17,6 @@ public interface IPCABaseModel public abstract void Fit(Tensor data); public abstract Tensor Transform(Tensor data); public abstract Tensor FitAndTransform(Tensor data); + public abstract Tensor Reconstruct(Tensor data); } } diff --git a/src/Bonsai.ML.PCA/OnlinePPCA.cs b/src/Bonsai.ML.PCA/OnlinePPCA.cs index 954ae382..dc27d2f9 100644 --- a/src/Bonsai.ML.PCA/OnlinePPCA.cs +++ b/src/Bonsai.ML.PCA/OnlinePPCA.cs @@ -235,7 +235,29 @@ public override Tensor Transform(Tensor data) var Xc = Xt - _mu; // n x d var M = _W.T.matmul(_W) + _Iq * _sigma2; // q x q var XcW = Xc.matmul(_W); - return InvertSPD(M, XcW.T).T; // n x q + return Utils.InvertSPD(M, XcW.T).T; // n x q + } + + public override Tensor Reconstruct(Tensor data) + { + if (data.NumberOfElements == 0 || data.dim() < 2) + { + throw new ArgumentException("Data must be a non-empty 2D tensor with shape (samples x features).", nameof(data)); + } + + if (!_initializedParameters) + { + throw new InvalidOperationException("Model has not 
yet been fitted. You should call the Fit() or the FitAndTransform() methods first."); + } + + var Xt = data.T; // n x d + var Xc = Xt - _mu; // n x d + var M = _W.T.matmul(_W) + _Iq * _sigma2; // q x q + var XcW = Xc.matmul(_W); + var EzT = Utils.InvertSPD(M, XcW.T); + var Ez = EzT.T; + + return Ez.matmul(_W.T) + _mu.T; // n x d } } } diff --git a/src/Bonsai.ML.PCA/PCA.cs b/src/Bonsai.ML.PCA/PCA.cs index 7b8883b4..12799f08 100644 --- a/src/Bonsai.ML.PCA/PCA.cs +++ b/src/Bonsai.ML.PCA/PCA.cs @@ -62,7 +62,29 @@ public override Tensor Transform(Tensor data) throw new InvalidOperationException("Model has not yet been fitted. You should call the Fit() or the FitAndTransform() methods first."); } - return data.T.matmul(Components); + var X = data.T; + var mean = X.mean([0], keepdim: true); // 1 x d + var Xc = X - mean; + return Xc.matmul(Components); // n x q + } + + public override Tensor Reconstruct(Tensor data) + { + if (data.NumberOfElements == 0 || data.dim() < 2) + { + throw new ArgumentException("Data must be a non-empty 2D tensor with shape (samples x features).", nameof(data)); + } + + if (!_isFitted) + { + throw new InvalidOperationException("Model has not yet been fitted. 
You should call the Fit() or the FitAndTransform() methods first."); + } + + var X = data.T; + var mean = X.mean([0], keepdim: true); // 1 x d + var Xc = X - mean; + var reconstructed = Transform(Xc); + return reconstructed.matmul(Components.T) + mean.T; } } } diff --git a/src/Bonsai.ML.PCA/PCABaseModel.cs b/src/Bonsai.ML.PCA/PCABaseModel.cs index ed77dd41..f6617ef8 100644 --- a/src/Bonsai.ML.PCA/PCABaseModel.cs +++ b/src/Bonsai.ML.PCA/PCABaseModel.cs @@ -38,5 +38,7 @@ public virtual Tensor FitAndTransform(Tensor data) Fit(data); return Transform(data); } + + public abstract Tensor Reconstruct(Tensor data); } } diff --git a/src/Bonsai.ML.PCA/PPCA.cs b/src/Bonsai.ML.PCA/PPCA.cs index 5c1e2109..8d096af0 100644 --- a/src/Bonsai.ML.PCA/PPCA.cs +++ b/src/Bonsai.ML.PCA/PPCA.cs @@ -173,5 +173,24 @@ public override Tensor Transform(Tensor data) var MInv = inverse(M); // q x q return X.matmul(W).matmul(MInv); // n x q } + + public override Tensor Reconstruct(Tensor data) + { + if (data.NumberOfElements == 0 || data.dim() < 2) + { + throw new ArgumentException("Data must be a non-empty 2D tensor with shape (samples x features).", nameof(data)); + } + + if (!_isFitted) + { + throw new InvalidOperationException("Model has not yet been fitted. 
You should call the Fit() or the FitAndTransform() methods first."); + } + + var Xt = data.T; + var mean = Xt.mean([0], keepdim: true); // 1 x d + var Xc = Xt - mean; // n x d + var W = Components; // d x q + return Xc.matmul(W).matmul(W.T) + mean.T; // n x d + } } } diff --git a/src/Bonsai.ML.PCA/Reconstruct.cs b/src/Bonsai.ML.PCA/Reconstruct.cs new file mode 100644 index 00000000..f098cf30 --- /dev/null +++ b/src/Bonsai.ML.PCA/Reconstruct.cs @@ -0,0 +1,67 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using Bonsai; +using static TorchSharp.torch; + +namespace Bonsai.ML.PCA +{ + [Combinator] + [Description] + [WorkflowElementCategory(ElementCategory.Transform)] + public class Reconstruct + { + private Tensor ReconstructData(IPCABaseModel model, Tensor data) + { + return model.Reconstruct(data); + } + + public IObservable Process(IObservable> source) + { + return source.Select(value => + { + return ReconstructData(value.Item1, value.Item2); + }); + } + + public IObservable Process(IObservable> source) + { + return source.Select(value => + { + return ReconstructData(value.Item2, value.Item1); + }); + } + + public IObservable Process(IObservable> source) + { + return source.Select(value => + { + return ReconstructData(value.Item1, value.Item2); + }); + } + + public IObservable Process(IObservable> source) + { + return source.Select(value => + { + return ReconstructData(value.Item2, value.Item1); + }); + } + + public IObservable Process(IObservable> source) + { + return source.Select(value => + { + return ReconstructData(value.Item1, value.Item2); + }); + } + + public IObservable Process(IObservable> source) + { + return source.Select(value => + { + return ReconstructData(value.Item2, value.Item1); + }); + } + } +} From 3d4756f0824d641b23f4a69b63d6f74067ec0d1b Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Mon, 11 Aug 2025 18:22:22 +0100 Subject: [PATCH 09/20] Added method to fit model and transform data --- 
src/Bonsai.ML.PCA/FitAndTransform.cs | 67 ++++++++++++++++++++++++++++ 1 file changed, 67 insertions(+) create mode 100644 src/Bonsai.ML.PCA/FitAndTransform.cs diff --git a/src/Bonsai.ML.PCA/FitAndTransform.cs b/src/Bonsai.ML.PCA/FitAndTransform.cs new file mode 100644 index 00000000..6dcc20dd --- /dev/null +++ b/src/Bonsai.ML.PCA/FitAndTransform.cs @@ -0,0 +1,67 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using Bonsai; +using static TorchSharp.torch; + +namespace Bonsai.ML.PCA +{ + [Combinator] + [Description] + [WorkflowElementCategory(ElementCategory.Transform)] + public class FitAndTransform + { + private void FitModelAndTransformData(IPCABaseModel model, Tensor data) + { + model.FitAndTransform(data); + } + + public IObservable> Process(IObservable> source) + { + return source.Do((value) => + { + FitModelAndTransformData(value.Item1, value.Item2); + }); + } + + public IObservable> Process(IObservable> source) + { + return source.Do((value) => + { + FitModelAndTransformData(value.Item2, value.Item1); + }); + } + + public IObservable> Process(IObservable> source) + { + return source.Do((value) => + { + FitModelAndTransformData(value.Item1, value.Item2); + }); + } + + public IObservable> Process(IObservable> source) + { + return source.Do((value) => + { + FitModelAndTransformData(value.Item2, value.Item1); + }); + } + + public IObservable> Process(IObservable> source) + { + return source.Do((value) => + { + FitModelAndTransformData(value.Item1, value.Item2); + }); + } + + public IObservable> Process(IObservable> source) + { + return source.Do((value) => + { + FitModelAndTransformData(value.Item2, value.Item1); + }); + } + } +} From dd03327d5b66c523d626554a5e5b7945e6451dc3 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Mon, 11 Aug 2025 18:23:49 +0100 Subject: [PATCH 10/20] Moved `InvertSPD` method to seperate static `Utils` class and updated PCA models to use this --- src/Bonsai.ML.PCA/OnlinePPCA.cs | 24 
++++-------------------- src/Bonsai.ML.PCA/PPCA.cs | 2 +- src/Bonsai.ML.PCA/Utils.cs | 30 ++++++++++++++++++++++++++++++ 3 files changed, 35 insertions(+), 21 deletions(-) create mode 100644 src/Bonsai.ML.PCA/Utils.cs diff --git a/src/Bonsai.ML.PCA/OnlinePPCA.cs b/src/Bonsai.ML.PCA/OnlinePPCA.cs index dc27d2f9..a3065da6 100644 --- a/src/Bonsai.ML.PCA/OnlinePPCA.cs +++ b/src/Bonsai.ML.PCA/OnlinePPCA.cs @@ -34,7 +34,7 @@ public class OnlinePPCA : PCABaseModel private readonly Func UpdateSchedule; private int _stepCount = 0; private readonly bool _reorthogonalize = false; - + public OnlinePPCA(int numComponents, Device? device = null, ScalarType? scalarType = ScalarType.Float32, @@ -99,22 +99,6 @@ public OnlinePPCA(int numComponents, _sigma2 = initialVariance; } - private Tensor InvertSPD(Tensor spdMatrix, Tensor rhs, double regularization = 1e-6) - { - var diagShape = spdMatrix.size(-1); - Tensor L; - try - { - L = linalg.cholesky(spdMatrix); - } - catch (Exception) - { - var regularizer = eye(diagShape, device: Device, dtype: ScalarType) * regularization; - L = linalg.cholesky(spdMatrix + regularizer); - } - return cholesky_solve(rhs, L); - } - public override void Fit(Tensor data) { // throw new NotImplementedException(); @@ -163,10 +147,10 @@ public override void Fit(Tensor data) // E-step var M = _W.T.matmul(_W) + cov; - var MInv = InvertSPD(M, _Iq); + var MInv = Utils.InvertSPD(M, _Iq); var XcW = Xc.matmul(_W); - var EzT = InvertSPD(M, XcW.T); + var EzT = Utils.InvertSPD(M, XcW.T); var Ez = EzT.T; // Update statistics @@ -193,7 +177,7 @@ public override void Fit(Tensor data) var Sxx = _sxx - _mu.dot(_mu); // M-step - var WNew = InvertSPD(Szz, Sxz.T).T; + var WNew = Utils.InvertSPD(Szz, Sxz.T).T; if (_reorthogonalize && _stepCount % ReorthogonalizePeriod == 0) diff --git a/src/Bonsai.ML.PCA/PPCA.cs b/src/Bonsai.ML.PCA/PPCA.cs index 8d096af0..53eea527 100644 --- a/src/Bonsai.ML.PCA/PPCA.cs +++ b/src/Bonsai.ML.PCA/PPCA.cs @@ -170,7 +170,7 @@ public override Tensor 
Transform(Tensor data) var X = Xt - mean; // n x d var W = Components; // d x q var M = W.T.matmul(W) + eye(NumComponents) * Variance; // q x q - var MInv = inverse(M); // q x q + var MInv = Utils.InvertSPD(M, eye(NumComponents)); // q x q return X.matmul(W).matmul(MInv); // n x q } diff --git a/src/Bonsai.ML.PCA/Utils.cs b/src/Bonsai.ML.PCA/Utils.cs new file mode 100644 index 00000000..bc3ab7e5 --- /dev/null +++ b/src/Bonsai.ML.PCA/Utils.cs @@ -0,0 +1,30 @@ +using System; +using System.Reactive.Linq; +using static TorchSharp.torch; + +namespace Bonsai.ML.PCA; + +internal static class Utils +{ + internal static Tensor InvertSPD( + Tensor spdMatrix, + Tensor rhs, + double regularization = 1e-6, + Device? device = null, + ScalarType? scalarType = null + ) + { + var diagShape = spdMatrix.size(-1); + Tensor L; + try + { + L = linalg.cholesky(spdMatrix); + } + catch (Exception) + { + var regularizer = eye(diagShape, device: device, dtype: scalarType) * regularization; + L = linalg.cholesky(spdMatrix + regularizer); + } + return cholesky_solve(rhs, L); + } +} \ No newline at end of file From fb6c741ebb191a79788caec50dc3358d9cae706b Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Mon, 11 Aug 2025 18:24:22 +0100 Subject: [PATCH 11/20] Added implementation of device and data type properties of base class in model subclasses --- src/Bonsai.ML.PCA/PCA.cs | 5 ++++- src/Bonsai.ML.PCA/PPCA.cs | 18 +++++++++--------- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/src/Bonsai.ML.PCA/PCA.cs b/src/Bonsai.ML.PCA/PCA.cs index 12799f08..162a58c4 100644 --- a/src/Bonsai.ML.PCA/PCA.cs +++ b/src/Bonsai.ML.PCA/PCA.cs @@ -41,7 +41,10 @@ public override void Fit(Tensor data) throw new ArgumentException("Number of components cannot be greater than the number of features.", nameof(data)); } - Covariance = cov(data); + var Xt = data.T; + var mean = Xt.mean([0], keepdim: true); + var Xc = Xt - mean; + Covariance = cov(Xc.T); var eigen = eigh(Covariance); var sortedIndices = 
argsort(eigen.Item1, dim: -1, descending: true); EigenValues = eigen.Item1[sortedIndices]; diff --git a/src/Bonsai.ML.PCA/PPCA.cs b/src/Bonsai.ML.PCA/PPCA.cs index 53eea527..ce2532cc 100644 --- a/src/Bonsai.ML.PCA/PPCA.cs +++ b/src/Bonsai.ML.PCA/PPCA.cs @@ -65,7 +65,7 @@ public override void Fit(Tensor data) var variance = Variance; // Initialize log likelihood - LogLikelihood = ones(_iterations) * double.NegativeInfinity; + LogLikelihood = ones(_iterations, device: Device, dtype: ScalarType) * double.NegativeInfinity; // Initialize dimensions for components var q = NumComponents; @@ -78,9 +78,9 @@ public override void Fit(Tensor data) } // Initialize W and I - var W = randn(d, q, generator: Generator); // d x q - var MI = eye(q); // q x q - var CI = eye(d); // d x d + var W = randn(d, q, generator: Generator, device: Device, dtype: ScalarType); // d x q + var Iq = eye(q, device: Device, dtype: ScalarType); // q x q + var Id = eye(d, device: Device, dtype: ScalarType); // d x d // Calculate the sample mean var mean = Xt.mean([0], keepdim: true); // 1 x d @@ -96,7 +96,7 @@ public override void Fit(Tensor data) var term1 = trace(XTX); // Compute log likelihood constant - var logLikelihoodConst = d * log(2 * Math.PI); + var logLikelihoodConst = d * log(2 * Math.PI).to(Device).to_type(ScalarType); double diffW; double diffVariance; @@ -104,10 +104,10 @@ public override void Fit(Tensor data) // Repeat until convergence for (int i = 0; i < _iterations; i++) { - using (var _ = NewDisposeScope()) + using (NewDisposeScope()) { // E-step: Compute the posterior distribution of the latent variables - var M = W.T.matmul(W) + MI * variance; // q x q + var M = W.T.matmul(W) + Iq * variance; // q x q var MInv = inv(M); // q x q var mu = MInv.matmul(W.T).matmul(X.T).T; // n x q var SSum = n * MInv * variance; // q x q @@ -124,7 +124,7 @@ public override void Fit(Tensor data) var varianceNew = (term1 - term2 + term3) / (n * d); // scalar // Compute the log likelihood - var C = 
W.matmul(W.T) + CI * variance; // d x d + var C = W.matmul(W.T) + Id * variance; // d x d var CInv = inv(C); // d x d var logLikelihood = -0.5 * n * (logLikelihoodConst + logdet(C) + trace(CInv.matmul(sampleCov))); // scalar @@ -166,7 +166,7 @@ public override Tensor Transform(Tensor data) } var Xt = data.T; - var mean = Xt.mean([ 0 ], keepdim: true); // 1 x d + var mean = Xt.mean([0], keepdim: true); // 1 x d var X = Xt - mean; // n x d var W = Components; // d x q var M = W.T.matmul(W) + eye(NumComponents) * Variance; // q x q From 885c7a0557777e1c20865204c53bd8a519809ce1 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Mon, 3 Nov 2025 15:04:19 +0000 Subject: [PATCH 12/20] Refactored to use improved naming conventions --- Bonsai.ML.sln | 5 +- src/Bonsai.ML.PCA/CreatePCA.cs | 122 --------- src/Bonsai.ML.PCA/Fit.cs | 67 ----- src/Bonsai.ML.PCA/FitAndTransform.cs | 67 ----- src/Bonsai.ML.PCA/IPCABaseModel.cs | 22 -- src/Bonsai.ML.PCA/OnlinePPCA.cs | 247 ------------------ src/Bonsai.ML.PCA/PCA.cs | 93 ------- src/Bonsai.ML.PCA/PCABaseModel.cs | 44 ---- src/Bonsai.ML.PCA/PCADesciptionProvider.cs | 26 -- src/Bonsai.ML.PCA/PCADescriptor.cs | 39 --- src/Bonsai.ML.PCA/PCAModelType.cs | 19 -- src/Bonsai.ML.PCA/PPCA.cs | 196 -------------- src/Bonsai.ML.PCA/Reconstruct.cs | 67 ----- src/Bonsai.ML.PCA/Transform.cs | 67 ----- .../Bonsai.ML.Pca.Torch.csproj} | 4 +- src/Bonsai.ML.Pca.Torch/CreatePca.cs | 123 +++++++++ src/Bonsai.ML.Pca.Torch/Fit.cs | 66 +++++ src/Bonsai.ML.Pca.Torch/FitAndTransform.cs | 66 +++++ src/Bonsai.ML.Pca.Torch/IPcaBaseModel.cs | 13 + .../OnlineProbabilisticPca.cs | 246 +++++++++++++++++ src/Bonsai.ML.Pca.Torch/Pca.cs | 92 +++++++ src/Bonsai.ML.Pca.Torch/PcaBaseModel.cs | 43 +++ .../PcaDesciptionProvider.cs | 23 ++ src/Bonsai.ML.Pca.Torch/PcaDescriptor.cs | 36 +++ src/Bonsai.ML.Pca.Torch/PcaModelType.cs | 8 + src/Bonsai.ML.Pca.Torch/ProbabalisticPca.cs | 195 ++++++++++++++ .../Properties/launchSettings.json | 0 src/Bonsai.ML.Pca.Torch/Reconstruct.cs | 
66 +++++ src/Bonsai.ML.Pca.Torch/Transform.cs | 66 +++++ .../Utils.cs | 2 +- .../Bonsai.ML.Pca.Torch.Tests.csproj} | 2 +- .../CreatePCATest.bonsai | 0 .../FitPCATest.bonsai | 0 .../TransformPCATest.bonsai | 0 34 files changed, 1050 insertions(+), 1082 deletions(-) delete mode 100644 src/Bonsai.ML.PCA/CreatePCA.cs delete mode 100644 src/Bonsai.ML.PCA/Fit.cs delete mode 100644 src/Bonsai.ML.PCA/FitAndTransform.cs delete mode 100644 src/Bonsai.ML.PCA/IPCABaseModel.cs delete mode 100644 src/Bonsai.ML.PCA/OnlinePPCA.cs delete mode 100644 src/Bonsai.ML.PCA/PCA.cs delete mode 100644 src/Bonsai.ML.PCA/PCABaseModel.cs delete mode 100644 src/Bonsai.ML.PCA/PCADesciptionProvider.cs delete mode 100644 src/Bonsai.ML.PCA/PCADescriptor.cs delete mode 100644 src/Bonsai.ML.PCA/PCAModelType.cs delete mode 100644 src/Bonsai.ML.PCA/PPCA.cs delete mode 100644 src/Bonsai.ML.PCA/Reconstruct.cs delete mode 100644 src/Bonsai.ML.PCA/Transform.cs rename src/{Bonsai.ML.PCA/Bonsai.ML.PCA.csproj => Bonsai.ML.Pca.Torch/Bonsai.ML.Pca.Torch.csproj} (70%) create mode 100644 src/Bonsai.ML.Pca.Torch/CreatePca.cs create mode 100644 src/Bonsai.ML.Pca.Torch/Fit.cs create mode 100644 src/Bonsai.ML.Pca.Torch/FitAndTransform.cs create mode 100644 src/Bonsai.ML.Pca.Torch/IPcaBaseModel.cs create mode 100644 src/Bonsai.ML.Pca.Torch/OnlineProbabilisticPca.cs create mode 100644 src/Bonsai.ML.Pca.Torch/Pca.cs create mode 100644 src/Bonsai.ML.Pca.Torch/PcaBaseModel.cs create mode 100644 src/Bonsai.ML.Pca.Torch/PcaDesciptionProvider.cs create mode 100644 src/Bonsai.ML.Pca.Torch/PcaDescriptor.cs create mode 100644 src/Bonsai.ML.Pca.Torch/PcaModelType.cs create mode 100644 src/Bonsai.ML.Pca.Torch/ProbabalisticPca.cs rename src/{Bonsai.ML.PCA => Bonsai.ML.Pca.Torch}/Properties/launchSettings.json (100%) create mode 100644 src/Bonsai.ML.Pca.Torch/Reconstruct.cs create mode 100644 src/Bonsai.ML.Pca.Torch/Transform.cs rename src/{Bonsai.ML.PCA => Bonsai.ML.Pca.Torch}/Utils.cs (95%) rename 
tests/{Bonsai.ML.PCA.Tests/Bonsai.ML.PCA.Tests.csproj => Bonsai.ML.Pca.Torch.Tests/Bonsai.ML.Pca.Torch.Tests.csproj} (88%) rename tests/{Bonsai.ML.PCA.Tests => Bonsai.ML.Pca.Torch.Tests}/CreatePCATest.bonsai (100%) rename tests/{Bonsai.ML.PCA.Tests => Bonsai.ML.Pca.Torch.Tests}/FitPCATest.bonsai (100%) rename tests/{Bonsai.ML.PCA.Tests => Bonsai.ML.Pca.Torch.Tests}/TransformPCATest.bonsai (100%) diff --git a/Bonsai.ML.sln b/Bonsai.ML.sln index 4640a1a3..56e66f13 100644 --- a/Bonsai.ML.sln +++ b/Bonsai.ML.sln @@ -45,9 +45,10 @@ EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Bonsai.ML.Lds.Torch.Tests", "tests\Bonsai.ML.Lds.Torch.Tests\Bonsai.ML.Lds.Torch.Tests.csproj", "{0B258929-0B07-4CE7-BE8D-A86BBC46AAD4}" EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Bonsai.ML.Lds.Torch.Design", "src\Bonsai.ML.Lds.Torch.Design\Bonsai.ML.Lds.Torch.Design.csproj", "{1F52DECD-1B2C-4F6C-996C-14C715283B80}" -Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Bonsai.ML.PCA", "src\Bonsai.ML.PCA\Bonsai.ML.PCA.csproj", "{1A2DEED3-795E-4C28-9C5E-BA3D76B2A485}" EndProject -Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Bonsai.ML.PCA.Tests", "tests\Bonsai.ML.PCA.Tests\Bonsai.ML.PCA.Tests.csproj", "{4ABCC6B2-024A-450F-85CB-2A9B2D2D2A10}" +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Bonsai.ML.Pca.Torch", "src\Bonsai.ML.Pca.Torch\Bonsai.ML.Pca.Torch.csproj", "{1A2DEED3-795E-4C28-9C5E-BA3D76B2A485}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Bonsai.ML.Pca.Torch.Tests", "tests\Bonsai.ML.Pca.Torch.Tests\Bonsai.ML.Pca.Torch.Tests.csproj", "{4ABCC6B2-024A-450F-85CB-2A9B2D2D2A10}" EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution diff --git a/src/Bonsai.ML.PCA/CreatePCA.cs b/src/Bonsai.ML.PCA/CreatePCA.cs deleted file mode 100644 index 4c3ddf82..00000000 --- a/src/Bonsai.ML.PCA/CreatePCA.cs +++ /dev/null @@ -1,122 +0,0 @@ -using System; -using System.ComponentModel; -using 
System.Collections.Generic; -using System.Reactive.Linq; -using System.Linq.Expressions; -using Bonsai.Expressions; -using System.Linq; -using System.Reflection; -using static TorchSharp.torch; -using System.Xml.Serialization; - -namespace Bonsai.ML.PCA -{ - [Combinator] - [ResetCombinator] - [WorkflowElementCategory(ElementCategory.Source)] - [TypeDescriptionProvider(typeof(PCADescriptionProvider))] - public class CreatePCA : ZeroArgumentExpressionBuilder - { - public int NumComponents { get; set; } = 2; - - [XmlIgnore] - public Device Device { get; set; } - public ScalarType? ScalarType { get; set; } - - [RefreshProperties(RefreshProperties.All)] - public PCAModelType ModelType { get; set; } = PCAModelType.PCA; - - public double InitialVariance { get; set; } = 1.0; - public int Iterations { get; set; } = 100; - public double Tolerance { get; set; } = 1e-5; - - public double? Rho { get; set; } = 0.1; - public double? Kappa { get; set; } = 0.9; - public int? TimeOffset { get; set; } = null; - public int? ReorthogonalizePeriod { get; set; } = null; - - [XmlIgnore] - public Generator? 
Generator { get; set; } = null; - - internal IEnumerable GetModelProperties() - { - yield return nameof(NumComponents); - yield return nameof(Device); - yield return nameof(ScalarType); - yield return nameof(ModelType); - - if (ModelType == PCAModelType.ProbabilisticPCA) - { - yield return nameof(InitialVariance); - yield return nameof(Iterations); - yield return nameof(Tolerance); - yield return nameof(Generator); - } - - if (ModelType == PCAModelType.OnlinePPCA) - { - yield return nameof(InitialVariance); - yield return nameof(Rho); - yield return nameof(Kappa); - yield return nameof(TimeOffset); - yield return nameof(ReorthogonalizePeriod); - yield return nameof(Generator); - } - } - - private static PCABaseModel CreateModel(CreatePCA instance) - { - return instance.ModelType switch - { - PCAModelType.PCA => new PCA( - numComponents: instance.NumComponents, - device: instance.Device, - scalarType: instance.ScalarType), - PCAModelType.ProbabilisticPCA => new PPCA( - numComponents: instance.NumComponents, - device: instance.Device, - scalarType: instance.ScalarType, - initialVariance: instance.InitialVariance, - generator: instance.Generator, - iterations: instance.Iterations, - tolerance: instance.Tolerance), - PCAModelType.OnlinePPCA => new OnlinePPCA( - numComponents: instance.NumComponents, - device: instance.Device, - scalarType: instance.ScalarType, - initialVariance: instance.InitialVariance, - generator: instance.Generator, - rho: instance.Rho, - kappa: instance.Kappa, - timeOffset: instance.TimeOffset, - reorthogonalizePeriod: instance.ReorthogonalizePeriod), - _ => throw new NotSupportedException($"Model type {instance.ModelType} is not supported."), - }; - } - - private static Type GetModelType(PCAModelType modelType) - { - return modelType switch - { - PCAModelType.PCA => typeof(PCA), - PCAModelType.ProbabilisticPCA => typeof(PPCA), - PCAModelType.OnlinePPCA => typeof(OnlinePPCA), - _ => throw new NotSupportedException($"Model type {modelType} is not 
supported."), - }; - } - - public override Expression Build(IEnumerable arguments) - { - var processMethod = GetType().GetMethod( - nameof(Process), - BindingFlags.NonPublic | BindingFlags.Static); - processMethod = processMethod.MakeGenericMethod(GetModelType(ModelType)); - return Expression.Call(processMethod, Expression.Constant(this)); - } - - private static IObservable Process(CreatePCA instance) where T : PCABaseModel - { - return Observable.Return((T)CreateModel(instance)); - } - } -} diff --git a/src/Bonsai.ML.PCA/Fit.cs b/src/Bonsai.ML.PCA/Fit.cs deleted file mode 100644 index 27eb0d60..00000000 --- a/src/Bonsai.ML.PCA/Fit.cs +++ /dev/null @@ -1,67 +0,0 @@ -using System; -using System.ComponentModel; -using System.Reactive.Linq; -using Bonsai; -using static TorchSharp.torch; - -namespace Bonsai.ML.PCA -{ - [Combinator] - [Description] - [WorkflowElementCategory(ElementCategory.Sink)] - public class Fit - { - private void FitModel(IPCABaseModel model, Tensor data) - { - model.Fit(data); - } - - public IObservable> Process(IObservable> source) - { - return source.Do((value) => - { - FitModel(value.Item1, value.Item2); - }); - } - - public IObservable> Process(IObservable> source) - { - return source.Do((value) => - { - FitModel(value.Item2, value.Item1); - }); - } - - public IObservable> Process(IObservable> source) - { - return source.Do((value) => - { - FitModel(value.Item1, value.Item2); - }); - } - - public IObservable> Process(IObservable> source) - { - return source.Do((value) => - { - FitModel(value.Item2, value.Item1); - }); - } - - public IObservable> Process(IObservable> source) - { - return source.Do((value) => - { - FitModel(value.Item1, value.Item2); - }); - } - - public IObservable> Process(IObservable> source) - { - return source.Do((value) => - { - FitModel(value.Item2, value.Item1); - }); - } - } -} diff --git a/src/Bonsai.ML.PCA/FitAndTransform.cs b/src/Bonsai.ML.PCA/FitAndTransform.cs deleted file mode 100644 index 6dcc20dd..00000000 --- 
a/src/Bonsai.ML.PCA/FitAndTransform.cs +++ /dev/null @@ -1,67 +0,0 @@ -using System; -using System.ComponentModel; -using System.Reactive.Linq; -using Bonsai; -using static TorchSharp.torch; - -namespace Bonsai.ML.PCA -{ - [Combinator] - [Description] - [WorkflowElementCategory(ElementCategory.Transform)] - public class FitAndTransform - { - private void FitModelAndTransformData(IPCABaseModel model, Tensor data) - { - model.FitAndTransform(data); - } - - public IObservable> Process(IObservable> source) - { - return source.Do((value) => - { - FitModelAndTransformData(value.Item1, value.Item2); - }); - } - - public IObservable> Process(IObservable> source) - { - return source.Do((value) => - { - FitModelAndTransformData(value.Item2, value.Item1); - }); - } - - public IObservable> Process(IObservable> source) - { - return source.Do((value) => - { - FitModelAndTransformData(value.Item1, value.Item2); - }); - } - - public IObservable> Process(IObservable> source) - { - return source.Do((value) => - { - FitModelAndTransformData(value.Item2, value.Item1); - }); - } - - public IObservable> Process(IObservable> source) - { - return source.Do((value) => - { - FitModelAndTransformData(value.Item1, value.Item2); - }); - } - - public IObservable> Process(IObservable> source) - { - return source.Do((value) => - { - FitModelAndTransformData(value.Item2, value.Item1); - }); - } - } -} diff --git a/src/Bonsai.ML.PCA/IPCABaseModel.cs b/src/Bonsai.ML.PCA/IPCABaseModel.cs deleted file mode 100644 index 3e7bbad5..00000000 --- a/src/Bonsai.ML.PCA/IPCABaseModel.cs +++ /dev/null @@ -1,22 +0,0 @@ -using Bonsai; -using System; -using System.ComponentModel; -using System.Collections.Generic; -using System.Linq; -using System.Reactive.Linq; -using TorchSharp; -using static TorchSharp.torch; -using static TorchSharp.torch.linalg; - -namespace Bonsai.ML.PCA -{ - public interface IPCABaseModel - { - public Device Device { get; } - public ScalarType ScalarType { get; } - public abstract void 
Fit(Tensor data); - public abstract Tensor Transform(Tensor data); - public abstract Tensor FitAndTransform(Tensor data); - public abstract Tensor Reconstruct(Tensor data); - } -} diff --git a/src/Bonsai.ML.PCA/OnlinePPCA.cs b/src/Bonsai.ML.PCA/OnlinePPCA.cs deleted file mode 100644 index a3065da6..00000000 --- a/src/Bonsai.ML.PCA/OnlinePPCA.cs +++ /dev/null @@ -1,247 +0,0 @@ -using Bonsai; -using System; -using System.ComponentModel; -using System.Collections.Generic; -using System.Linq; -using System.Reactive.Linq; -using TorchSharp; -using static TorchSharp.torch; -using static TorchSharp.torch.linalg; - -namespace Bonsai.ML.PCA -{ - public class OnlinePPCA : PCABaseModel - { - public double? Rho { get; private set; } - public double? Kappa { get; private set; } - public double Variance => _sigma2.to_type(ScalarType.Float64).item(); - public int ReorthogonalizePeriod { get; private set; } - public int? TimeOffset { get; private set; } - public Tensor Components => _W; - public Generator Generator { get; private set; } - - private Tensor _mu; - private Tensor _W; - private Tensor _Iq; - private Tensor _mx; // E[x] - private Tensor _Cxz; // E[xz^T] - private Tensor _mz; // E[z] - private Tensor _Czz; // E[zz^T] - private Tensor _sxx; // E[||x||^2] - private Tensor _sigma2; // Variance - - private bool _initializedParameters = false; - private readonly Func UpdateSchedule; - private int _stepCount = 0; - private readonly bool _reorthogonalize = false; - - public OnlinePPCA(int numComponents, - Device? device = null, - ScalarType? scalarType = ScalarType.Float32, - double initialVariance = 1.0, - Generator? generator = null, - double? rho = 0.1, - double? kappa = null, - int? timeOffset = 3000, - int? 
reorthogonalizePeriod = null - ) : base(numComponents, - device, - scalarType) - { - if (initialVariance <= 0) - { - throw new ArgumentException("Starting variance must be greater than or equal to zero.", nameof(initialVariance)); - } - - if (kappa.HasValue && rho.HasValue) - { - throw new ArgumentException("Only one of rho or kappa should be specified, not both.", nameof(rho)); - } - - if (!kappa.HasValue && !rho.HasValue) - { - throw new ArgumentException("Either rho or kappa must be specified.", nameof(rho)); - } - - if (rho.HasValue) - { - UpdateSchedule = () => rho.Value; - if (rho <= 0 || rho >= 1) - { - throw new ArgumentException("Rho must be in the range (0, 1).", nameof(rho)); - } - } - - if (kappa.HasValue) - { - if (timeOffset is null or <= 0) - { - throw new ArgumentException("Time offset must be a positive integer.", nameof(timeOffset)); - } - - UpdateSchedule = () => Math.Pow(_stepCount + timeOffset.Value, -kappa.Value); - if (kappa <= 0.5 || kappa > 1) - { - throw new ArgumentException("Kappa must be in the range (0.5, 1].", nameof(kappa)); - } - } - - if (reorthogonalizePeriod.HasValue) - { - _reorthogonalize = true; - ReorthogonalizePeriod = reorthogonalizePeriod.Value; - } - - Generator = generator ?? 
manual_seed(0); - Rho = rho; - Kappa = kappa; - TimeOffset = timeOffset; - _sigma2 = initialVariance; - } - - public override void Fit(Tensor data) - { - // throw new NotImplementedException(); - if (data.NumberOfElements == 0 || data.dim() != 2) - { - throw new ArgumentException("Input data must be a 2D tensor."); - } - - using (no_grad()) - using (NewDisposeScope()) - { - - _stepCount++; - var rho = UpdateSchedule(); - - var Xt = data.T; // n x d - - // Initialize dimensions - var q = NumComponents; - var n = Xt.size(0); - var d = Xt.size(1); - - // Initialize parameters - if (!_initializedParameters) - { - _mu = zeros(d, device: Device, dtype: ScalarType).MoveToOuterDisposeScope(); - var randW = randn(d, q, generator: Generator, device: Device, dtype: ScalarType); - var orthonormalBases = linalg.qr(randW).Q; - _W = (orthonormalBases * _sigma2).MoveToOuterDisposeScope(); // d x q - _Iq = eye(q, device: Device, dtype: ScalarType).MoveToOuterDisposeScope(); // q x q - - _mx = zeros(d, device: Device, dtype: ScalarType).MoveToOuterDisposeScope(); // d - _Cxz = zeros(d, q, device: Device, dtype: ScalarType).MoveToOuterDisposeScope(); // d x q - _mz = zeros(q, device: Device, dtype: ScalarType).MoveToOuterDisposeScope(); // q - _Czz = zeros(q, q, device: Device, dtype: ScalarType).MoveToOuterDisposeScope(); // q x q - _sxx = zeros(1, device: Device, dtype: ScalarType).MoveToOuterDisposeScope(); // scalar - - _initializedParameters = true; - } - - // Covariance matrix - var cov = _Iq * _sigma2; - - // Center data using current mean - var Xc = Xt - _mu; - - // E-step - var M = _W.T.matmul(_W) + cov; - var MInv = Utils.InvertSPD(M, _Iq); - - var XcW = Xc.matmul(_W); - var EzT = Utils.InvertSPD(M, XcW.T); - var Ez = EzT.T; - - // Update statistics - var mx = Xt.mean([0]); - var sxx = Xt.pow(2).sum(dim: 1).mean(); - var Cxz = Xt.T.matmul(Ez) / n; - var mz = Ez.mean([0]); - var Czz = EzT.matmul(Ez) / n + _sigma2 * MInv; - - // Update parameters - var rhoFactor = 1 - rho; - 
_mx = (rhoFactor * _mx + rho * mx).MoveToOuterDisposeScope(); - _Cxz = (rhoFactor * _Cxz + rho * Cxz).MoveToOuterDisposeScope(); - _mz = (rhoFactor * _mz + rho * mz).MoveToOuterDisposeScope(); - _sxx = (rhoFactor * _sxx + rho * sxx).MoveToOuterDisposeScope(); - _Czz = (rhoFactor * _Czz + rho * Czz).MoveToOuterDisposeScope(); - - // Update mean - _mu = _mx.MoveToOuterDisposeScope(); - - // Centered statistics - var Sxz = _Cxz - outer(_mu, _mz); - var Szz = _Czz; - var Sxx = _sxx - _mu.dot(_mu); - - // M-step - var WNew = Utils.InvertSPD(Szz, Sxz.T).T; - - if (_reorthogonalize && - _stepCount % ReorthogonalizePeriod == 0) - { - var (U, S, Vh) = svd(WNew, fullMatrices: false); - var R = Vh.T; - WNew = U.matmul(diag(S)); - _Cxz = _Cxz.matmul(R.T); - _Czz = R.matmul(_Czz).matmul(R.T); - _mz = R.matmul(_mz); - } - - // Reorder components based on the strength of the components - var strength = sum(WNew * WNew, dim: 0); - var indices = argsort(strength, descending: true); - _W = WNew.index_select(1, indices).MoveToOuterDisposeScope(); - _Cxz = _Cxz.index_select(1, indices).MoveToOuterDisposeScope(); - _mz = _mz.index_select(0, indices).MoveToOuterDisposeScope(); - _Czz = _Czz.index_select(0, indices).index_select(1, indices).MoveToOuterDisposeScope(); - - Sxz = _Cxz - outer(_mu, _mz); - Szz = _Czz; - - // Update variance - _sigma2 = ((Sxx - 2 * trace(_W.T.matmul(Sxz)) + trace(_W.T.matmul(_W).matmul(Szz))) / (double)d) - .clamp_min(0.0) - .MoveToOuterDisposeScope(); - } - } - - public override Tensor Transform(Tensor data) - { - if (data.NumberOfElements == 0 || data.dim() < 2) - { - throw new ArgumentException("Data must be a non-empty 2D tensor with shape (samples x features).", nameof(data)); - } - - var Xt = data.T; // n x d - var Xc = Xt - _mu; // n x d - var M = _W.T.matmul(_W) + _Iq * _sigma2; // q x q - var XcW = Xc.matmul(_W); - return Utils.InvertSPD(M, XcW.T).T; // n x q - } - - public override Tensor Reconstruct(Tensor data) - { - if (data.NumberOfElements == 0 
|| data.dim() < 2) - { - throw new ArgumentException("Data must be a non-empty 2D tensor with shape (samples x features).", nameof(data)); - } - - if (!_initializedParameters) - { - throw new InvalidOperationException("Model has not yet been fitted. You should call the Fit() or the FitAndTransform() methods first."); - } - - var Xt = data.T; // n x d - var Xc = Xt - _mu; // n x d - var M = _W.T.matmul(_W) + _Iq * _sigma2; // q x q - var XcW = Xc.matmul(_W); - var EzT = Utils.InvertSPD(M, XcW.T); - var Ez = EzT.T; - - return Ez.matmul(_W.T) + _mu.T; // n x d - } - } -} diff --git a/src/Bonsai.ML.PCA/PCA.cs b/src/Bonsai.ML.PCA/PCA.cs deleted file mode 100644 index 162a58c4..00000000 --- a/src/Bonsai.ML.PCA/PCA.cs +++ /dev/null @@ -1,93 +0,0 @@ -using Bonsai; -using System; -using System.ComponentModel; -using System.Collections.Generic; -using System.Linq; -using System.Reactive.Linq; -using TorchSharp; -using static TorchSharp.torch; -using static TorchSharp.torch.linalg; - -namespace Bonsai.ML.PCA -{ - public class PCA : PCABaseModel - { - public Tensor Covariance { get; private set; } = empty(0); - public Tensor EigenValues { get; private set; } = empty(0); - public Tensor EigenVectors { get; private set; } = empty(0); - public Tensor Components { get; private set; } = empty(0); - private bool _isFitted = false; - - public PCA(int numComponents, - Device? device = null, - ScalarType? 
scalarType = ScalarType.Float32) - : base(numComponents, - device, - scalarType) - { } - - public override void Fit(Tensor data) - { - if (data.NumberOfElements == 0 || data.dim() < 2) - { - throw new ArgumentException("Data must be a non-empty 2D tensor with shape (samples x features).", nameof(data)); - } - - var n = data.size(0); - var d = data.size(1); - - if (NumComponents > d) - { - throw new ArgumentException("Number of components cannot be greater than the number of features.", nameof(data)); - } - - var Xt = data.T; - var mean = Xt.mean([0], keepdim: true); - var Xc = Xt - mean; - Covariance = cov(Xc.T); - var eigen = eigh(Covariance); - var sortedIndices = argsort(eigen.Item1, dim: -1, descending: true); - EigenValues = eigen.Item1[sortedIndices]; - EigenVectors = eigen.Item2.index_select(1, sortedIndices); - Components = EigenVectors.slice(1, 0, NumComponents, 1); - _isFitted = true; - } - - public override Tensor Transform(Tensor data) - { - if (data.NumberOfElements == 0 || data.dim() < 2) - { - throw new ArgumentException("Data must be a non-empty 2D tensor with shape (samples x features).", nameof(data)); - } - - if (!_isFitted) - { - throw new InvalidOperationException("Model has not yet been fitted. You should call the Fit() or the FitAndTransform() methods first."); - } - - var X = data.T; - var mean = X.mean([0], keepdim: true); // 1 x d - var Xc = X - mean; - return Xc.matmul(Components); // n x q - } - - public override Tensor Reconstruct(Tensor data) - { - if (data.NumberOfElements == 0 || data.dim() < 2) - { - throw new ArgumentException("Data must be a non-empty 2D tensor with shape (samples x features).", nameof(data)); - } - - if (!_isFitted) - { - throw new InvalidOperationException("Model has not yet been fitted. 
You should call the Fit() or the FitAndTransform() methods first."); - } - - var X = data.T; - var mean = X.mean([0], keepdim: true); // 1 x d - var Xc = X - mean; - var reconstructed = Transform(Xc); - return reconstructed.matmul(Components.T) + mean.T; - } - } -} diff --git a/src/Bonsai.ML.PCA/PCABaseModel.cs b/src/Bonsai.ML.PCA/PCABaseModel.cs deleted file mode 100644 index f6617ef8..00000000 --- a/src/Bonsai.ML.PCA/PCABaseModel.cs +++ /dev/null @@ -1,44 +0,0 @@ -using Bonsai; -using System; -using System.ComponentModel; -using System.Collections.Generic; -using System.Linq; -using System.Reactive.Linq; -using TorchSharp; -using static TorchSharp.torch; -using static TorchSharp.torch.linalg; - -namespace Bonsai.ML.PCA -{ - public abstract class PCABaseModel : IPCABaseModel - { - public int NumComponents { get; private set; } - - public Device Device { get; private set; } - public ScalarType ScalarType { get; private set; } - - public PCABaseModel(int numComponents, - Device? device = null, - ScalarType? scalarType = null) - { - if (numComponents <= 0) - { - throw new ArgumentException("Number of components must be greater than zero.", nameof(numComponents)); - } - - NumComponents = numComponents; - Device = device ?? CPU; - ScalarType = scalarType ?? 
ScalarType.Float32; - } - - public abstract void Fit(Tensor data); - public abstract Tensor Transform(Tensor data); - public virtual Tensor FitAndTransform(Tensor data) - { - Fit(data); - return Transform(data); - } - - public abstract Tensor Reconstruct(Tensor data); - } -} diff --git a/src/Bonsai.ML.PCA/PCADesciptionProvider.cs b/src/Bonsai.ML.PCA/PCADesciptionProvider.cs deleted file mode 100644 index 1f8886af..00000000 --- a/src/Bonsai.ML.PCA/PCADesciptionProvider.cs +++ /dev/null @@ -1,26 +0,0 @@ -using System.ComponentModel; -using System; -using Bonsai; -using Bonsai.Expressions; - -namespace Bonsai.ML.PCA -{ - class PCADescriptionProvider : TypeDescriptionProvider - { - private readonly TypeDescriptionProvider _baseProvider; - - public PCADescriptionProvider() : this(TypeDescriptor.GetProvider(typeof(object))) { } - - public PCADescriptionProvider(TypeDescriptionProvider baseProvider) - : base(baseProvider) - { - _baseProvider = baseProvider; - } - - public override ICustomTypeDescriptor GetTypeDescriptor(Type objectType, object instance) - { - var defaultDescriptor = base.GetTypeDescriptor(objectType, instance); - return new PCADescriptor(defaultDescriptor, instance); - } - } -} \ No newline at end of file diff --git a/src/Bonsai.ML.PCA/PCADescriptor.cs b/src/Bonsai.ML.PCA/PCADescriptor.cs deleted file mode 100644 index 1c19d254..00000000 --- a/src/Bonsai.ML.PCA/PCADescriptor.cs +++ /dev/null @@ -1,39 +0,0 @@ -using System; -using System.ComponentModel; -using System.Reactive.Linq; -using System.Collections.Generic; -using System.Linq; -using Bonsai; -using static TorchSharp.torch; - -namespace Bonsai.ML.PCA -{ - public class PCADescriptor : CustomTypeDescriptor - { - private readonly object _instance; - - public PCADescriptor(ICustomTypeDescriptor parent, object instance) : base(parent) - { - _instance = instance; - } - - public override PropertyDescriptorCollection GetProperties(Attribute[] attributes) - { - var allProperties = 
base.GetProperties(attributes); - - if (_instance is CreatePCA createPCA) - { - var modelProperties = new HashSet(createPCA.GetModelProperties()); - var filtered = allProperties.Cast() - .Where(p => modelProperties.Contains(p.Name)) - .ToArray(); - return new PropertyDescriptorCollection(filtered); - } - - return allProperties; - } - - public override PropertyDescriptorCollection GetProperties() - => GetProperties(null); - } -} diff --git a/src/Bonsai.ML.PCA/PCAModelType.cs b/src/Bonsai.ML.PCA/PCAModelType.cs deleted file mode 100644 index 30c6ed1b..00000000 --- a/src/Bonsai.ML.PCA/PCAModelType.cs +++ /dev/null @@ -1,19 +0,0 @@ -using Bonsai; -using System; -using System.ComponentModel; -using System.Collections.Generic; -using System.Linq; -using System.Reactive.Linq; -using TorchSharp; -using static TorchSharp.torch; -using static TorchSharp.torch.linalg; - -namespace Bonsai.ML.PCA -{ - public enum PCAModelType - { - PCA, - ProbabilisticPCA, - OnlinePPCA - } -} diff --git a/src/Bonsai.ML.PCA/PPCA.cs b/src/Bonsai.ML.PCA/PPCA.cs deleted file mode 100644 index ce2532cc..00000000 --- a/src/Bonsai.ML.PCA/PPCA.cs +++ /dev/null @@ -1,196 +0,0 @@ -using Bonsai; -using System; -using System.ComponentModel; -using System.Collections.Generic; -using System.Linq; -using System.Reactive.Linq; -using TorchSharp; -using static TorchSharp.torch; -using static TorchSharp.torch.linalg; - -namespace Bonsai.ML.PCA -{ - public class PPCA : PCABaseModel - { - public double Variance { get; private set; } - public Tensor LogLikelihood { get; private set; } = empty(0); - public Tensor Components { get; private set; } = empty(0); - public Generator Generator { get; private set; } - private int _iterations; - private double _tolerance; - private bool _isFitted = false; - - public PPCA(int numComponents, - Device? device = null, - ScalarType? scalarType = ScalarType.Float32, - double initialVariance = 1.0, - Generator? 
generator = null, - int iterations = 100, - double tolerance = 1e-5 - ) : base(numComponents, - device, - scalarType) - { - if (initialVariance < 0) - { - throw new ArgumentException("Starting variance must be greater than or equal to zero.", nameof(initialVariance)); - } - - if (iterations <= 0) - { - throw new ArgumentException("Number of iterations must be greater than zero.", nameof(iterations)); - } - - if (tolerance <= 0) - { - throw new ArgumentException("Tolerance must be greater than zero.", nameof(tolerance)); - } - - Variance = initialVariance; - Generator = generator ?? manual_seed(0); - _iterations = iterations; - _tolerance = tolerance; - } - - public override void Fit(Tensor data) - { - if (data.NumberOfElements == 0 || data.dim() != 2) - { - throw new ArgumentException("Data must be a non-empty 2D tensor with shape (samples x features).", nameof(data)); - } - - var Xt = data.T; // n x d - - // Initialize variance - var variance = Variance; - - // Initialize log likelihood - LogLikelihood = ones(_iterations, device: Device, dtype: ScalarType) * double.NegativeInfinity; - - // Initialize dimensions for components - var q = NumComponents; - var n = Xt.size(0); - var d = Xt.size(1); - - if (q > d) - { - throw new ArgumentException("Number of components cannot be greater than the number of features.", nameof(data)); - } - - // Initialize W and I - var W = randn(d, q, generator: Generator, device: Device, dtype: ScalarType); // d x q - var Iq = eye(q, device: Device, dtype: ScalarType); // q x q - var Id = eye(d, device: Device, dtype: ScalarType); // d x d - - // Calculate the sample mean - var mean = Xt.mean([0], keepdim: true); // 1 x d - - // Center the data and transpose - var X = Xt - mean; // n x d - - // Calculate the sample covariance - var XTX = X.T.matmul(X); // d x d - var sampleCov = XTX / n; // d x d - - // Calculate term 1 for variance update - var term1 = trace(XTX); - - // Compute log likelihood constant - var logLikelihoodConst = d * 
log(2 * Math.PI).to(Device).to_type(ScalarType); - - double diffW; - double diffVariance; - - // Repeat until convergence - for (int i = 0; i < _iterations; i++) - { - using (NewDisposeScope()) - { - // E-step: Compute the posterior distribution of the latent variables - var M = W.T.matmul(W) + Iq * variance; // q x q - var MInv = inv(M); // q x q - var mu = MInv.matmul(W.T).matmul(X.T).T; // n x q - var SSum = n * MInv * variance; // q x q - var cov = mu.T.matmul(mu) + SSum; // q x q - - // M-step: Compute new W and new variance - var XMu = X.T.matmul(mu); // d x q - var WNew = XMu.matmul(inv(cov)); // d x q - - var term2 = 2 * XMu.mul(WNew).sum(); - var mumu = mu.T.matmul(mu); - var WNewWNew = WNew.T.matmul(WNew); - var term3 = trace(WNewWNew.matmul(mumu + SSum)); - var varianceNew = (term1 - term2 + term3) / (n * d); // scalar - - // Compute the log likelihood - var C = W.matmul(W.T) + Id * variance; // d x d - var CInv = inv(C); // d x d - var logLikelihood = -0.5 * n * (logLikelihoodConst + logdet(C) + trace(CInv.matmul(sampleCov))); // scalar - - // Check for convergence - diffW = linalg.norm(WNew - W).to_type(ScalarType.Float64).cpu().ReadCpuDouble(0); - diffVariance = abs(varianceNew - variance).to_type(ScalarType.Float64).cpu().ReadCpuDouble(0); - - // Update loglikelihood, W and variance - LogLikelihood[i] = logLikelihood.MoveToOuterDisposeScope(); - W = WNew.MoveToOuterDisposeScope(); - variance = varianceNew.to_type(ScalarType.Float64).cpu().ReadCpuDouble(0); - } - - - if (diffW < _tolerance && diffVariance < _tolerance) - { - LogLikelihood = LogLikelihood.slice(0, 0, i + 1, 1); - break; - } - } - - // Finalize model parameters - LogLikelihood = LogLikelihood.DetachFromDisposeScope(); - Components = W.DetachFromDisposeScope(); - Variance = variance; - _isFitted = true; - } - - public override Tensor Transform(Tensor data) - { - if (data.NumberOfElements == 0 || data.dim() < 2) - { - throw new ArgumentException("Data must be a non-empty 2D tensor with 
shape (samples x features).", nameof(data)); - } - - if (!_isFitted) - { - throw new InvalidOperationException("Model has not yet been fitted. You should call the Fit() or the FitAndTransform() methods first."); - } - - var Xt = data.T; - var mean = Xt.mean([0], keepdim: true); // 1 x d - var X = Xt - mean; // n x d - var W = Components; // d x q - var M = W.T.matmul(W) + eye(NumComponents) * Variance; // q x q - var MInv = Utils.InvertSPD(M, eye(NumComponents)); // q x q - return X.matmul(W).matmul(MInv); // n x q - } - - public override Tensor Reconstruct(Tensor data) - { - if (data.NumberOfElements == 0 || data.dim() < 2) - { - throw new ArgumentException("Data must be a non-empty 2D tensor with shape (samples x features).", nameof(data)); - } - - if (!_isFitted) - { - throw new InvalidOperationException("Model has not yet been fitted. You should call the Fit() or the FitAndTransform() methods first."); - } - - var Xt = data.T; - var mean = Xt.mean([0], keepdim: true); // 1 x d - var Xc = Xt - mean; // n x d - var W = Components; // d x q - return Xc.matmul(W).matmul(W.T) + mean.T; // n x d - } - } -} diff --git a/src/Bonsai.ML.PCA/Reconstruct.cs b/src/Bonsai.ML.PCA/Reconstruct.cs deleted file mode 100644 index f098cf30..00000000 --- a/src/Bonsai.ML.PCA/Reconstruct.cs +++ /dev/null @@ -1,67 +0,0 @@ -using System; -using System.ComponentModel; -using System.Reactive.Linq; -using Bonsai; -using static TorchSharp.torch; - -namespace Bonsai.ML.PCA -{ - [Combinator] - [Description] - [WorkflowElementCategory(ElementCategory.Transform)] - public class Reconstruct - { - private Tensor ReconstructData(IPCABaseModel model, Tensor data) - { - return model.Reconstruct(data); - } - - public IObservable Process(IObservable> source) - { - return source.Select(value => - { - return ReconstructData(value.Item1, value.Item2); - }); - } - - public IObservable Process(IObservable> source) - { - return source.Select(value => - { - return ReconstructData(value.Item2, value.Item1); - 
}); - } - - public IObservable Process(IObservable> source) - { - return source.Select(value => - { - return ReconstructData(value.Item1, value.Item2); - }); - } - - public IObservable Process(IObservable> source) - { - return source.Select(value => - { - return ReconstructData(value.Item2, value.Item1); - }); - } - - public IObservable Process(IObservable> source) - { - return source.Select(value => - { - return ReconstructData(value.Item1, value.Item2); - }); - } - - public IObservable Process(IObservable> source) - { - return source.Select(value => - { - return ReconstructData(value.Item2, value.Item1); - }); - } - } -} diff --git a/src/Bonsai.ML.PCA/Transform.cs b/src/Bonsai.ML.PCA/Transform.cs deleted file mode 100644 index 2217c74c..00000000 --- a/src/Bonsai.ML.PCA/Transform.cs +++ /dev/null @@ -1,67 +0,0 @@ -using System; -using System.ComponentModel; -using System.Reactive.Linq; -using Bonsai; -using static TorchSharp.torch; - -namespace Bonsai.ML.PCA -{ - [Combinator] - [Description] - [WorkflowElementCategory(ElementCategory.Transform)] - public class Transform - { - private Tensor TransformData(IPCABaseModel model, Tensor data) - { - return model.Transform(data); - } - - public IObservable Process(IObservable> source) - { - return source.Select(value => - { - return TransformData(value.Item1, value.Item2); - }); - } - - public IObservable Process(IObservable> source) - { - return source.Select(value => - { - return TransformData(value.Item2, value.Item1); - }); - } - - public IObservable Process(IObservable> source) - { - return source.Select(value => - { - return TransformData(value.Item1, value.Item2); - }); - } - - public IObservable Process(IObservable> source) - { - return source.Select(value => - { - return TransformData(value.Item2, value.Item1); - }); - } - - public IObservable Process(IObservable> source) - { - return source.Select(value => - { - return TransformData(value.Item1, value.Item2); - }); - } - - public IObservable 
Process(IObservable> source) - { - return source.Select(value => - { - return TransformData(value.Item2, value.Item1); - }); - } - } -} diff --git a/src/Bonsai.ML.PCA/Bonsai.ML.PCA.csproj b/src/Bonsai.ML.Pca.Torch/Bonsai.ML.Pca.Torch.csproj similarity index 70% rename from src/Bonsai.ML.PCA/Bonsai.ML.PCA.csproj rename to src/Bonsai.ML.Pca.Torch/Bonsai.ML.Pca.Torch.csproj index f1e876d2..b97c06fd 100644 --- a/src/Bonsai.ML.PCA/Bonsai.ML.PCA.csproj +++ b/src/Bonsai.ML.Pca.Torch/Bonsai.ML.Pca.Torch.csproj @@ -1,8 +1,8 @@  - Bonsai.ML.PCA Bonsai library. - $(PackageTags) Point Process Neural Decoder + Bonsai.ML.Pca.Torch Bonsai library. + $(PackageTags) PCA Principal Component Analysis net472;netstandard2.0 enable diff --git a/src/Bonsai.ML.Pca.Torch/CreatePca.cs b/src/Bonsai.ML.Pca.Torch/CreatePca.cs new file mode 100644 index 00000000..1914184b --- /dev/null +++ b/src/Bonsai.ML.Pca.Torch/CreatePca.cs @@ -0,0 +1,123 @@ +using System; +using System.ComponentModel; +using System.Collections.Generic; +using System.Reactive.Linq; +using System.Linq.Expressions; +using Bonsai.Expressions; +using System.Linq; +using System.Reflection; +using static TorchSharp.torch; +using System.Xml.Serialization; + +namespace Bonsai.ML.Pca.Torch; + +[Combinator] +[ResetCombinator] +[WorkflowElementCategory(ElementCategory.Source)] +[TypeDescriptionProvider(typeof(PcaDescriptionProvider))] +public class CreatePca : ZeroArgumentExpressionBuilder, INamedElement +{ + public string Name => ModelType.ToString(); + + public int NumComponents { get; set; } = 2; + + [XmlIgnore] + public Device Device { get; set; } + public ScalarType? ScalarType { get; set; } + + [RefreshProperties(RefreshProperties.All)] + public PcaModelType ModelType { get; set; } = PcaModelType.Pca; + + public double InitialVariance { get; set; } = 1.0; + public int Iterations { get; set; } = 100; + public double Tolerance { get; set; } = 1e-5; + + public double? Rho { get; set; } = 0.1; + public double? 
Kappa { get; set; } = 0.9; + public int? TimeOffset { get; set; } = null; + public int? ReorthogonalizePeriod { get; set; } = null; + + [XmlIgnore] + public Generator? Generator { get; set; } = null; + + internal IEnumerable GetModelProperties() + { + yield return nameof(NumComponents); + yield return nameof(Device); + yield return nameof(ScalarType); + yield return nameof(ModelType); + + if (ModelType == PcaModelType.ProbabilisticPca) + { + yield return nameof(InitialVariance); + yield return nameof(Iterations); + yield return nameof(Tolerance); + yield return nameof(Generator); + } + + if (ModelType == PcaModelType.OnlineProbabilisticPca) + { + yield return nameof(InitialVariance); + yield return nameof(Rho); + yield return nameof(Kappa); + yield return nameof(TimeOffset); + yield return nameof(ReorthogonalizePeriod); + yield return nameof(Generator); + } + } + + private static PcaBaseModel CreateModel(CreatePca instance) + { + return instance.ModelType switch + { + PcaModelType.Pca => new Pca( + numComponents: instance.NumComponents, + device: instance.Device, + scalarType: instance.ScalarType), + PcaModelType.ProbabilisticPca => new ProbabilisticPca( + numComponents: instance.NumComponents, + device: instance.Device, + scalarType: instance.ScalarType, + initialVariance: instance.InitialVariance, + generator: instance.Generator, + iterations: instance.Iterations, + tolerance: instance.Tolerance), + PcaModelType.OnlineProbabilisticPca => new OnlineProbabilisticPca( + numComponents: instance.NumComponents, + device: instance.Device, + scalarType: instance.ScalarType, + initialVariance: instance.InitialVariance, + generator: instance.Generator, + rho: instance.Rho, + kappa: instance.Kappa, + timeOffset: instance.TimeOffset, + reorthogonalizePeriod: instance.ReorthogonalizePeriod), + _ => throw new NotSupportedException($"Model type {instance.ModelType} is not supported."), + }; + } + + private static Type GetModelType(PcaModelType modelType) + { + return modelType 
switch + { + PcaModelType.Pca => typeof(Pca), + PcaModelType.ProbabilisticPca => typeof(ProbabilisticPca), + PcaModelType.OnlineProbabilisticPca => typeof(OnlineProbabilisticPca), + _ => throw new NotSupportedException($"Model type {modelType} is not supported."), + }; + } + + public override Expression Build(IEnumerable arguments) + { + var processMethod = GetType().GetMethod( + nameof(Process), + BindingFlags.NonPublic | BindingFlags.Static); + processMethod = processMethod.MakeGenericMethod(GetModelType(ModelType)); + return Expression.Call(processMethod, Expression.Constant(this)); + } + + private static IObservable Process(CreatePca instance) where T : PcaBaseModel + { + return Observable.Return((T)CreateModel(instance)); + } +} diff --git a/src/Bonsai.ML.Pca.Torch/Fit.cs b/src/Bonsai.ML.Pca.Torch/Fit.cs new file mode 100644 index 00000000..57fee736 --- /dev/null +++ b/src/Bonsai.ML.Pca.Torch/Fit.cs @@ -0,0 +1,66 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using Bonsai; +using static TorchSharp.torch; + +namespace Bonsai.ML.Pca.Torch; + +[Combinator] +[Description] +[WorkflowElementCategory(ElementCategory.Sink)] +public class Fit +{ + private void FitModel(IPcaBaseModel model, Tensor data) + { + model.Fit(data); + } + + public IObservable> Process(IObservable> source) + { + return source.Do((value) => + { + FitModel(value.Item1, value.Item2); + }); + } + + public IObservable> Process(IObservable> source) + { + return source.Do((value) => + { + FitModel(value.Item2, value.Item1); + }); + } + + public IObservable> Process(IObservable> source) + { + return source.Do((value) => + { + FitModel(value.Item1, value.Item2); + }); + } + + public IObservable> Process(IObservable> source) + { + return source.Do((value) => + { + FitModel(value.Item2, value.Item1); + }); + } + + public IObservable> Process(IObservable> source) + { + return source.Do((value) => + { + FitModel(value.Item1, value.Item2); + }); + } + + public IObservable> 
Process(IObservable> source) + { + return source.Do((value) => + { + FitModel(value.Item2, value.Item1); + }); + } +} diff --git a/src/Bonsai.ML.Pca.Torch/FitAndTransform.cs b/src/Bonsai.ML.Pca.Torch/FitAndTransform.cs new file mode 100644 index 00000000..5d4d3290 --- /dev/null +++ b/src/Bonsai.ML.Pca.Torch/FitAndTransform.cs @@ -0,0 +1,66 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using Bonsai; +using static TorchSharp.torch; + +namespace Bonsai.ML.Pca.Torch; + +[Combinator] +[Description] +[WorkflowElementCategory(ElementCategory.Transform)] +public class FitAndTransform +{ + private void FitModelAndTransformData(IPcaBaseModel model, Tensor data) + { + model.FitAndTransform(data); + } + + public IObservable> Process(IObservable> source) + { + return source.Do((value) => + { + FitModelAndTransformData(value.Item1, value.Item2); + }); + } + + public IObservable> Process(IObservable> source) + { + return source.Do((value) => + { + FitModelAndTransformData(value.Item2, value.Item1); + }); + } + + public IObservable> Process(IObservable> source) + { + return source.Do((value) => + { + FitModelAndTransformData(value.Item1, value.Item2); + }); + } + + public IObservable> Process(IObservable> source) + { + return source.Do((value) => + { + FitModelAndTransformData(value.Item2, value.Item1); + }); + } + + public IObservable> Process(IObservable> source) + { + return source.Do((value) => + { + FitModelAndTransformData(value.Item1, value.Item2); + }); + } + + public IObservable> Process(IObservable> source) + { + return source.Do((value) => + { + FitModelAndTransformData(value.Item2, value.Item1); + }); + } +} diff --git a/src/Bonsai.ML.Pca.Torch/IPcaBaseModel.cs b/src/Bonsai.ML.Pca.Torch/IPcaBaseModel.cs new file mode 100644 index 00000000..495e4b30 --- /dev/null +++ b/src/Bonsai.ML.Pca.Torch/IPcaBaseModel.cs @@ -0,0 +1,13 @@ +using static TorchSharp.torch; + +namespace Bonsai.ML.Pca.Torch; + +public interface IPcaBaseModel +{ + public 
Device Device { get; } + public ScalarType ScalarType { get; } + public abstract void Fit(Tensor data); + public abstract Tensor Transform(Tensor data); + public abstract Tensor FitAndTransform(Tensor data); + public abstract Tensor Reconstruct(Tensor data); +} diff --git a/src/Bonsai.ML.Pca.Torch/OnlineProbabilisticPca.cs b/src/Bonsai.ML.Pca.Torch/OnlineProbabilisticPca.cs new file mode 100644 index 00000000..bafa2fcb --- /dev/null +++ b/src/Bonsai.ML.Pca.Torch/OnlineProbabilisticPca.cs @@ -0,0 +1,246 @@ +using Bonsai; +using System; +using System.ComponentModel; +using System.Collections.Generic; +using System.Linq; +using System.Reactive.Linq; +using TorchSharp; +using static TorchSharp.torch; +using static TorchSharp.torch.linalg; + +namespace Bonsai.ML.Pca.Torch; + +public class OnlineProbabilisticPca : PcaBaseModel +{ + public double? Rho { get; private set; } + public double? Kappa { get; private set; } + public double Variance => _sigma2.to_type(ScalarType.Float64).item(); + public int ReorthogonalizePeriod { get; private set; } + public int? TimeOffset { get; private set; } + public Tensor Components => _W; + public Generator Generator { get; private set; } + + private Tensor _mu; + private Tensor _W; + private Tensor _Iq; + private Tensor _mx; // E[x] + private Tensor _Cxz; // E[xz^T] + private Tensor _mz; // E[z] + private Tensor _Czz; // E[zz^T] + private Tensor _sxx; // E[||x||^2] + private Tensor _sigma2; // Variance + + private bool _initializedParameters = false; + private readonly Func UpdateSchedule; + private int _stepCount = 0; + private readonly bool _reorthogonalize = false; + + public OnlineProbabilisticPca(int numComponents, + Device? device = null, + ScalarType? scalarType = ScalarType.Float32, + double initialVariance = 1.0, + Generator? generator = null, + double? rho = 0.1, + double? kappa = null, + int? timeOffset = 3000, + int? 
reorthogonalizePeriod = null + ) : base(numComponents, + device, + scalarType) + { + if (initialVariance <= 0) + { + throw new ArgumentException("Starting variance must be greater than or equal to zero.", nameof(initialVariance)); + } + + if (kappa.HasValue && rho.HasValue) + { + throw new ArgumentException("Only one of rho or kappa should be specified, not both.", nameof(rho)); + } + + if (!kappa.HasValue && !rho.HasValue) + { + throw new ArgumentException("Either rho or kappa must be specified.", nameof(rho)); + } + + if (rho.HasValue) + { + UpdateSchedule = () => rho.Value; + if (rho <= 0 || rho >= 1) + { + throw new ArgumentException("Rho must be in the range (0, 1).", nameof(rho)); + } + } + + if (kappa.HasValue) + { + if (timeOffset is null or <= 0) + { + throw new ArgumentException("Time offset must be a positive integer.", nameof(timeOffset)); + } + + UpdateSchedule = () => Math.Pow(_stepCount + timeOffset.Value, -kappa.Value); + if (kappa <= 0.5 || kappa > 1) + { + throw new ArgumentException("Kappa must be in the range (0.5, 1].", nameof(kappa)); + } + } + + if (reorthogonalizePeriod.HasValue) + { + _reorthogonalize = true; + ReorthogonalizePeriod = reorthogonalizePeriod.Value; + } + + Generator = generator ?? 
manual_seed(0); + Rho = rho; + Kappa = kappa; + TimeOffset = timeOffset; + _sigma2 = initialVariance; + } + + public override void Fit(Tensor data) + { + // throw new NotImplementedException(); + if (data.NumberOfElements == 0 || data.dim() != 2) + { + throw new ArgumentException("Input data must be a 2D tensor."); + } + + using (no_grad()) + using (NewDisposeScope()) + { + + _stepCount++; + var rho = UpdateSchedule(); + + var Xt = data.T; // n x d + + // Initialize dimensions + var q = NumComponents; + var n = Xt.size(0); + var d = Xt.size(1); + + // Initialize parameters + if (!_initializedParameters) + { + _mu = zeros(d, device: Device, dtype: ScalarType).MoveToOuterDisposeScope(); + var randW = randn(d, q, generator: Generator, device: Device, dtype: ScalarType); + var orthonormalBases = linalg.qr(randW).Q; + _W = (orthonormalBases * _sigma2).MoveToOuterDisposeScope(); // d x q + _Iq = eye(q, device: Device, dtype: ScalarType).MoveToOuterDisposeScope(); // q x q + + _mx = zeros(d, device: Device, dtype: ScalarType).MoveToOuterDisposeScope(); // d + _Cxz = zeros(d, q, device: Device, dtype: ScalarType).MoveToOuterDisposeScope(); // d x q + _mz = zeros(q, device: Device, dtype: ScalarType).MoveToOuterDisposeScope(); // q + _Czz = zeros(q, q, device: Device, dtype: ScalarType).MoveToOuterDisposeScope(); // q x q + _sxx = zeros(1, device: Device, dtype: ScalarType).MoveToOuterDisposeScope(); // scalar + + _initializedParameters = true; + } + + // Covariance matrix + var cov = _Iq * _sigma2; + + // Center data using current mean + var Xc = Xt - _mu; + + // E-step + var M = _W.T.matmul(_W) + cov; + var MInv = Utils.InvertSPD(M, _Iq); + + var XcW = Xc.matmul(_W); + var EzT = Utils.InvertSPD(M, XcW.T); + var Ez = EzT.T; + + // Update statistics + var mx = Xt.mean([0]); + var sxx = Xt.pow(2).sum(dim: 1).mean(); + var Cxz = Xt.T.matmul(Ez) / n; + var mz = Ez.mean([0]); + var Czz = EzT.matmul(Ez) / n + _sigma2 * MInv; + + // Update parameters + var rhoFactor = 1 - rho; + 
_mx = (rhoFactor * _mx + rho * mx).MoveToOuterDisposeScope(); + _Cxz = (rhoFactor * _Cxz + rho * Cxz).MoveToOuterDisposeScope(); + _mz = (rhoFactor * _mz + rho * mz).MoveToOuterDisposeScope(); + _sxx = (rhoFactor * _sxx + rho * sxx).MoveToOuterDisposeScope(); + _Czz = (rhoFactor * _Czz + rho * Czz).MoveToOuterDisposeScope(); + + // Update mean + _mu = _mx.MoveToOuterDisposeScope(); + + // Centered statistics + var Sxz = _Cxz - outer(_mu, _mz); + var Szz = _Czz; + var Sxx = _sxx - _mu.dot(_mu); + + // M-step + var WNew = Utils.InvertSPD(Szz, Sxz.T).T; + + if (_reorthogonalize && + _stepCount % ReorthogonalizePeriod == 0) + { + var (U, S, Vh) = svd(WNew, fullMatrices: false); + var R = Vh.T; + WNew = U.matmul(diag(S)); + _Cxz = _Cxz.matmul(R.T); + _Czz = R.matmul(_Czz).matmul(R.T); + _mz = R.matmul(_mz); + } + + // Reorder components based on the strength of the components + var strength = sum(WNew * WNew, dim: 0); + var indices = argsort(strength, descending: true); + _W = WNew.index_select(1, indices).MoveToOuterDisposeScope(); + _Cxz = _Cxz.index_select(1, indices).MoveToOuterDisposeScope(); + _mz = _mz.index_select(0, indices).MoveToOuterDisposeScope(); + _Czz = _Czz.index_select(0, indices).index_select(1, indices).MoveToOuterDisposeScope(); + + Sxz = _Cxz - outer(_mu, _mz); + Szz = _Czz; + + // Update variance + _sigma2 = ((Sxx - 2 * trace(_W.T.matmul(Sxz)) + trace(_W.T.matmul(_W).matmul(Szz))) / (double)d) + .clamp_min(0.0) + .MoveToOuterDisposeScope(); + } + } + + public override Tensor Transform(Tensor data) + { + if (data.NumberOfElements == 0 || data.dim() < 2) + { + throw new ArgumentException("Data must be a non-empty 2D tensor with shape (samples x features).", nameof(data)); + } + + var Xt = data.T; // n x d + var Xc = Xt - _mu; // n x d + var M = _W.T.matmul(_W) + _Iq * _sigma2; // q x q + var XcW = Xc.matmul(_W); + return Utils.InvertSPD(M, XcW.T).T; // n x q + } + + public override Tensor Reconstruct(Tensor data) + { + if (data.NumberOfElements == 0 
|| data.dim() < 2) + { + throw new ArgumentException("Data must be a non-empty 2D tensor with shape (samples x features).", nameof(data)); + } + + if (!_initializedParameters) + { + throw new InvalidOperationException("Model has not yet been fitted. You should call the Fit() or the FitAndTransform() methods first."); + } + + var Xt = data.T; // n x d + var Xc = Xt - _mu; // n x d + var M = _W.T.matmul(_W) + _Iq * _sigma2; // q x q + var XcW = Xc.matmul(_W); + var EzT = Utils.InvertSPD(M, XcW.T); + var Ez = EzT.T; + + return Ez.matmul(_W.T) + _mu.T; // n x d + } +} diff --git a/src/Bonsai.ML.Pca.Torch/Pca.cs b/src/Bonsai.ML.Pca.Torch/Pca.cs new file mode 100644 index 00000000..3eb75868 --- /dev/null +++ b/src/Bonsai.ML.Pca.Torch/Pca.cs @@ -0,0 +1,92 @@ +using Bonsai; +using System; +using System.ComponentModel; +using System.Collections.Generic; +using System.Linq; +using System.Reactive.Linq; +using TorchSharp; +using static TorchSharp.torch; +using static TorchSharp.torch.linalg; + +namespace Bonsai.ML.Pca.Torch; + +public class Pca : PcaBaseModel +{ + public Tensor Covariance { get; private set; } = empty(0); + public Tensor EigenValues { get; private set; } = empty(0); + public Tensor EigenVectors { get; private set; } = empty(0); + public Tensor Components { get; private set; } = empty(0); + private bool _isFitted = false; + + public Pca(int numComponents, + Device? device = null, + ScalarType? 
scalarType = ScalarType.Float32) + : base(numComponents, + device, + scalarType) + { } + + public override void Fit(Tensor data) + { + if (data.NumberOfElements == 0 || data.dim() < 2) + { + throw new ArgumentException("Data must be a non-empty 2D tensor with shape (samples x features).", nameof(data)); + } + + var n = data.size(0); + var d = data.size(1); + + if (NumComponents > d) + { + throw new ArgumentException("Number of components cannot be greater than the number of features.", nameof(data)); + } + + var Xt = data.T; + var mean = Xt.mean([0], keepdim: true); + var Xc = Xt - mean; + Covariance = cov(Xc.T); + var eigen = eigh(Covariance); + var sortedIndices = argsort(eigen.Item1, dim: -1, descending: true); + EigenValues = eigen.Item1[sortedIndices]; + EigenVectors = eigen.Item2.index_select(1, sortedIndices); + Components = EigenVectors.slice(1, 0, NumComponents, 1); + _isFitted = true; + } + + public override Tensor Transform(Tensor data) + { + if (data.NumberOfElements == 0 || data.dim() < 2) + { + throw new ArgumentException("Data must be a non-empty 2D tensor with shape (samples x features).", nameof(data)); + } + + if (!_isFitted) + { + throw new InvalidOperationException("Model has not yet been fitted. You should call the Fit() or the FitAndTransform() methods first."); + } + + var X = data.T; + var mean = X.mean([0], keepdim: true); // 1 x d + var Xc = X - mean; + return Xc.matmul(Components); // n x q + } + + public override Tensor Reconstruct(Tensor data) + { + if (data.NumberOfElements == 0 || data.dim() < 2) + { + throw new ArgumentException("Data must be a non-empty 2D tensor with shape (samples x features).", nameof(data)); + } + + if (!_isFitted) + { + throw new InvalidOperationException("Model has not yet been fitted. 
You should call the Fit() or the FitAndTransform() methods first."); + } + + var X = data.T; + var mean = X.mean([0], keepdim: true); // 1 x d + var Xc = X - mean; + var reconstructed = Transform(Xc); + return reconstructed.matmul(Components.T) + mean.T; + } +} diff --git a/src/Bonsai.ML.Pca.Torch/PcaBaseModel.cs b/src/Bonsai.ML.Pca.Torch/PcaBaseModel.cs new file mode 100644 index 00000000..c24d35eb --- /dev/null +++ b/src/Bonsai.ML.Pca.Torch/PcaBaseModel.cs @@ -0,0 +1,43 @@ +using Bonsai; +using System; +using System.ComponentModel; +using System.Collections.Generic; +using System.Linq; +using System.Reactive.Linq; +using TorchSharp; +using static TorchSharp.torch; +using static TorchSharp.torch.linalg; + +namespace Bonsai.ML.Pca.Torch; + +public abstract class PcaBaseModel : IPcaBaseModel +{ + public int NumComponents { get; private set; } + + public Device Device { get; private set; } + public ScalarType ScalarType { get; private set; } + + public PcaBaseModel(int numComponents, + Device? device = null, + ScalarType? scalarType = null) + { + if (numComponents <= 0) + { + throw new ArgumentException("Number of components must be greater than zero.", nameof(numComponents)); + } + + NumComponents = numComponents; + Device = device ?? CPU; + ScalarType = scalarType ?? 
ScalarType.Float32; + } + + public abstract void Fit(Tensor data); + public abstract Tensor Transform(Tensor data); + public virtual Tensor FitAndTransform(Tensor data) + { + Fit(data); + return Transform(data); + } + + public abstract Tensor Reconstruct(Tensor data); +} diff --git a/src/Bonsai.ML.Pca.Torch/PcaDesciptionProvider.cs b/src/Bonsai.ML.Pca.Torch/PcaDesciptionProvider.cs new file mode 100644 index 00000000..7ddcaba2 --- /dev/null +++ b/src/Bonsai.ML.Pca.Torch/PcaDesciptionProvider.cs @@ -0,0 +1,23 @@ +using System; +using System.ComponentModel; + +namespace Bonsai.ML.Pca.Torch; + +class PcaDescriptionProvider : TypeDescriptionProvider +{ + private readonly TypeDescriptionProvider _baseProvider; + + public PcaDescriptionProvider() : this(TypeDescriptor.GetProvider(typeof(object))) { } + + public PcaDescriptionProvider(TypeDescriptionProvider baseProvider) + : base(baseProvider) + { + _baseProvider = baseProvider; + } + + public override ICustomTypeDescriptor GetTypeDescriptor(Type objectType, object instance) + { + var defaultDescriptor = base.GetTypeDescriptor(objectType, instance); + return new PcaDescriptor(defaultDescriptor, instance); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Pca.Torch/PcaDescriptor.cs b/src/Bonsai.ML.Pca.Torch/PcaDescriptor.cs new file mode 100644 index 00000000..4758008e --- /dev/null +++ b/src/Bonsai.ML.Pca.Torch/PcaDescriptor.cs @@ -0,0 +1,36 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using System.Collections.Generic; +using System.Linq; + +namespace Bonsai.ML.Pca.Torch; + +public class PcaDescriptor : CustomTypeDescriptor +{ + private readonly object _instance; + + public PcaDescriptor(ICustomTypeDescriptor parent, object instance) : base(parent) + { + _instance = instance; + } + + public override PropertyDescriptorCollection GetProperties(Attribute[] attributes) + { + var allProperties = base.GetProperties(attributes); + + if (_instance is CreatePca createPca) + { + var 
modelProperties = new HashSet(createPca.GetModelProperties()); + var filtered = allProperties.Cast() + .Where(p => modelProperties.Contains(p.Name)) + .ToArray(); + return new PropertyDescriptorCollection(filtered); + } + + return allProperties; + } + + public override PropertyDescriptorCollection GetProperties() + => GetProperties(null); +} diff --git a/src/Bonsai.ML.Pca.Torch/PcaModelType.cs b/src/Bonsai.ML.Pca.Torch/PcaModelType.cs new file mode 100644 index 00000000..1d422a36 --- /dev/null +++ b/src/Bonsai.ML.Pca.Torch/PcaModelType.cs @@ -0,0 +1,8 @@ +namespace Bonsai.ML.Pca.Torch; + +public enum PcaModelType +{ + Pca, + ProbabilisticPca, + OnlineProbabilisticPca +} diff --git a/src/Bonsai.ML.Pca.Torch/ProbabalisticPca.cs b/src/Bonsai.ML.Pca.Torch/ProbabalisticPca.cs new file mode 100644 index 00000000..8f1029c4 --- /dev/null +++ b/src/Bonsai.ML.Pca.Torch/ProbabalisticPca.cs @@ -0,0 +1,195 @@ +using Bonsai; +using System; +using System.ComponentModel; +using System.Collections.Generic; +using System.Linq; +using System.Reactive.Linq; +using TorchSharp; +using static TorchSharp.torch; +using static TorchSharp.torch.linalg; + +namespace Bonsai.ML.Pca.Torch; + +public class ProbabilisticPca : PcaBaseModel +{ + public double Variance { get; private set; } + public Tensor LogLikelihood { get; private set; } = empty(0); + public Tensor Components { get; private set; } = empty(0); + public Generator Generator { get; private set; } + private int _iterations; + private double _tolerance; + private bool _isFitted = false; + + public ProbabilisticPca(int numComponents, + Device? device = null, + ScalarType? scalarType = ScalarType.Float32, + double initialVariance = 1.0, + Generator? 
generator = null, + int iterations = 100, + double tolerance = 1e-5 + ) : base(numComponents, + device, + scalarType) + { + if (initialVariance < 0) + { + throw new ArgumentException("Starting variance must be greater than or equal to zero.", nameof(initialVariance)); + } + + if (iterations <= 0) + { + throw new ArgumentException("Number of iterations must be greater than zero.", nameof(iterations)); + } + + if (tolerance <= 0) + { + throw new ArgumentException("Tolerance must be greater than zero.", nameof(tolerance)); + } + + Variance = initialVariance; + Generator = generator ?? manual_seed(0); + _iterations = iterations; + _tolerance = tolerance; + } + + public override void Fit(Tensor data) + { + if (data.NumberOfElements == 0 || data.dim() != 2) + { + throw new ArgumentException("Data must be a non-empty 2D tensor with shape (samples x features).", nameof(data)); + } + + var Xt = data.T; // n x d + + // Initialize variance + var variance = Variance; + + // Initialize log likelihood + LogLikelihood = ones(_iterations, device: Device, dtype: ScalarType) * double.NegativeInfinity; + + // Initialize dimensions for components + var q = NumComponents; + var n = Xt.size(0); + var d = Xt.size(1); + + if (q > d) + { + throw new ArgumentException("Number of components cannot be greater than the number of features.", nameof(data)); + } + + // Initialize W and I + var W = randn(d, q, generator: Generator, device: Device, dtype: ScalarType); // d x q + var Iq = eye(q, device: Device, dtype: ScalarType); // q x q + var Id = eye(d, device: Device, dtype: ScalarType); // d x d + + // Calculate the sample mean + var mean = Xt.mean([0], keepdim: true); // 1 x d + + // Center the data and transpose + var X = Xt - mean; // n x d + + // Calculate the sample covariance + var XTX = X.T.matmul(X); // d x d + var sampleCov = XTX / n; // d x d + + // Calculate term 1 for variance update + var term1 = trace(XTX); + + // Compute log likelihood constant + var logLikelihoodConst = d * 
log(2 * Math.PI).to(Device).to_type(ScalarType); + + double diffW; + double diffVariance; + + // Repeat until convergence + for (int i = 0; i < _iterations; i++) + { + using (NewDisposeScope()) + { + // E-step: Compute the posterior distribution of the latent variables + var M = W.T.matmul(W) + Iq * variance; // q x q + var MInv = inv(M); // q x q + var mu = MInv.matmul(W.T).matmul(X.T).T; // n x q + var SSum = n * MInv * variance; // q x q + var cov = mu.T.matmul(mu) + SSum; // q x q + + // M-step: Compute new W and new variance + var XMu = X.T.matmul(mu); // d x q + var WNew = XMu.matmul(inv(cov)); // d x q + + var term2 = 2 * XMu.mul(WNew).sum(); + var mumu = mu.T.matmul(mu); + var WNewWNew = WNew.T.matmul(WNew); + var term3 = trace(WNewWNew.matmul(mumu + SSum)); + var varianceNew = (term1 - term2 + term3) / (n * d); // scalar + + // Compute the log likelihood + var C = W.matmul(W.T) + Id * variance; // d x d + var CInv = inv(C); // d x d + var logLikelihood = -0.5 * n * (logLikelihoodConst + logdet(C) + trace(CInv.matmul(sampleCov))); // scalar + + // Check for convergence + diffW = linalg.norm(WNew - W).to_type(ScalarType.Float64).cpu().ReadCpuDouble(0); + diffVariance = abs(varianceNew - variance).to_type(ScalarType.Float64).cpu().ReadCpuDouble(0); + + // Update loglikelihood, W and variance + LogLikelihood[i] = logLikelihood.MoveToOuterDisposeScope(); + W = WNew.MoveToOuterDisposeScope(); + variance = varianceNew.to_type(ScalarType.Float64).cpu().ReadCpuDouble(0); + } + + + if (diffW < _tolerance && diffVariance < _tolerance) + { + LogLikelihood = LogLikelihood.slice(0, 0, i + 1, 1); + break; + } + } + + // Finalize model parameters + LogLikelihood = LogLikelihood.DetachFromDisposeScope(); + Components = W.DetachFromDisposeScope(); + Variance = variance; + _isFitted = true; + } + + public override Tensor Transform(Tensor data) + { + if (data.NumberOfElements == 0 || data.dim() < 2) + { + throw new ArgumentException("Data must be a non-empty 2D tensor with 
shape (samples x features).", nameof(data)); + } + + if (!_isFitted) + { + throw new InvalidOperationException("Model has not yet been fitted. You should call the Fit() or the FitAndTransform() methods first."); + } + + var Xt = data.T; + var mean = Xt.mean([0], keepdim: true); // 1 x d + var X = Xt - mean; // n x d + var W = Components; // d x q + var M = W.T.matmul(W) + eye(NumComponents) * Variance; // q x q + var MInv = Utils.InvertSPD(M, eye(NumComponents)); // q x q + return X.matmul(W).matmul(MInv); // n x q + } + + public override Tensor Reconstruct(Tensor data) + { + if (data.NumberOfElements == 0 || data.dim() < 2) + { + throw new ArgumentException("Data must be a non-empty 2D tensor with shape (samples x features).", nameof(data)); + } + + if (!_isFitted) + { + throw new InvalidOperationException("Model has not yet been fitted. You should call the Fit() or the FitAndTransform() methods first."); + } + + var Xt = data.T; + var mean = Xt.mean([0], keepdim: true); // 1 x d + var Xc = Xt - mean; // n x d + var W = Components; // d x q + return Xc.matmul(W).matmul(W.T) + mean.T; // n x d + } +} diff --git a/src/Bonsai.ML.PCA/Properties/launchSettings.json b/src/Bonsai.ML.Pca.Torch/Properties/launchSettings.json similarity index 100% rename from src/Bonsai.ML.PCA/Properties/launchSettings.json rename to src/Bonsai.ML.Pca.Torch/Properties/launchSettings.json diff --git a/src/Bonsai.ML.Pca.Torch/Reconstruct.cs b/src/Bonsai.ML.Pca.Torch/Reconstruct.cs new file mode 100644 index 00000000..64f0d242 --- /dev/null +++ b/src/Bonsai.ML.Pca.Torch/Reconstruct.cs @@ -0,0 +1,66 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using Bonsai; +using static TorchSharp.torch; + +namespace Bonsai.ML.Pca.Torch; + +[Combinator] +[Description] +[WorkflowElementCategory(ElementCategory.Transform)] +public class Reconstruct +{ + private Tensor ReconstructData(IPcaBaseModel model, Tensor data) + { + return model.Reconstruct(data); + } + + public 
IObservable Process(IObservable> source) + { + return source.Select(value => + { + return ReconstructData(value.Item1, value.Item2); + }); + } + + public IObservable Process(IObservable> source) + { + return source.Select(value => + { + return ReconstructData(value.Item2, value.Item1); + }); + } + + public IObservable Process(IObservable> source) + { + return source.Select(value => + { + return ReconstructData(value.Item1, value.Item2); + }); + } + + public IObservable Process(IObservable> source) + { + return source.Select(value => + { + return ReconstructData(value.Item2, value.Item1); + }); + } + + public IObservable Process(IObservable> source) + { + return source.Select(value => + { + return ReconstructData(value.Item1, value.Item2); + }); + } + + public IObservable Process(IObservable> source) + { + return source.Select(value => + { + return ReconstructData(value.Item2, value.Item1); + }); + } +} diff --git a/src/Bonsai.ML.Pca.Torch/Transform.cs b/src/Bonsai.ML.Pca.Torch/Transform.cs new file mode 100644 index 00000000..5cc94769 --- /dev/null +++ b/src/Bonsai.ML.Pca.Torch/Transform.cs @@ -0,0 +1,66 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using Bonsai; +using static TorchSharp.torch; + +namespace Bonsai.ML.Pca.Torch; + +[Combinator] +[Description] +[WorkflowElementCategory(ElementCategory.Transform)] +public class Transform +{ + private Tensor TransformData(IPcaBaseModel model, Tensor data) + { + return model.Transform(data); + } + + public IObservable Process(IObservable> source) + { + return source.Select(value => + { + return TransformData(value.Item1, value.Item2); + }); + } + + public IObservable Process(IObservable> source) + { + return source.Select(value => + { + return TransformData(value.Item2, value.Item1); + }); + } + + public IObservable Process(IObservable> source) + { + return source.Select(value => + { + return TransformData(value.Item1, value.Item2); + }); + } + + public IObservable Process(IObservable> 
source) + { + return source.Select(value => + { + return TransformData(value.Item2, value.Item1); + }); + } + + public IObservable Process(IObservable> source) + { + return source.Select(value => + { + return TransformData(value.Item1, value.Item2); + }); + } + + public IObservable Process(IObservable> source) + { + return source.Select(value => + { + return TransformData(value.Item2, value.Item1); + }); + } +} diff --git a/src/Bonsai.ML.PCA/Utils.cs b/src/Bonsai.ML.Pca.Torch/Utils.cs similarity index 95% rename from src/Bonsai.ML.PCA/Utils.cs rename to src/Bonsai.ML.Pca.Torch/Utils.cs index bc3ab7e5..c22bc775 100644 --- a/src/Bonsai.ML.PCA/Utils.cs +++ b/src/Bonsai.ML.Pca.Torch/Utils.cs @@ -2,7 +2,7 @@ using System.Reactive.Linq; using static TorchSharp.torch; -namespace Bonsai.ML.PCA; +namespace Bonsai.ML.Pca.Torch; internal static class Utils { diff --git a/tests/Bonsai.ML.PCA.Tests/Bonsai.ML.PCA.Tests.csproj b/tests/Bonsai.ML.Pca.Torch.Tests/Bonsai.ML.Pca.Torch.Tests.csproj similarity index 88% rename from tests/Bonsai.ML.PCA.Tests/Bonsai.ML.PCA.Tests.csproj rename to tests/Bonsai.ML.Pca.Torch.Tests/Bonsai.ML.Pca.Torch.Tests.csproj index e628dd82..5c8c20ab 100644 --- a/tests/Bonsai.ML.PCA.Tests/Bonsai.ML.PCA.Tests.csproj +++ b/tests/Bonsai.ML.Pca.Torch.Tests/Bonsai.ML.Pca.Torch.Tests.csproj @@ -16,6 +16,6 @@ - + \ No newline at end of file diff --git a/tests/Bonsai.ML.PCA.Tests/CreatePCATest.bonsai b/tests/Bonsai.ML.Pca.Torch.Tests/CreatePCATest.bonsai similarity index 100% rename from tests/Bonsai.ML.PCA.Tests/CreatePCATest.bonsai rename to tests/Bonsai.ML.Pca.Torch.Tests/CreatePCATest.bonsai diff --git a/tests/Bonsai.ML.PCA.Tests/FitPCATest.bonsai b/tests/Bonsai.ML.Pca.Torch.Tests/FitPCATest.bonsai similarity index 100% rename from tests/Bonsai.ML.PCA.Tests/FitPCATest.bonsai rename to tests/Bonsai.ML.Pca.Torch.Tests/FitPCATest.bonsai diff --git a/tests/Bonsai.ML.PCA.Tests/TransformPCATest.bonsai b/tests/Bonsai.ML.Pca.Torch.Tests/TransformPCATest.bonsai 
similarity index 100% rename from tests/Bonsai.ML.PCA.Tests/TransformPCATest.bonsai rename to tests/Bonsai.ML.Pca.Torch.Tests/TransformPCATest.bonsai From 7268686a6b6f6bbe4b50232d4865cd0976cd8479 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Mon, 10 Nov 2025 17:10:45 +0000 Subject: [PATCH 13/20] Added XML documentation to code --- src/Bonsai.ML.Pca.Torch/CreatePca.cs | 104 ++++++++++++++---- src/Bonsai.ML.Pca.Torch/Fit.cs | 65 ++++++++++- src/Bonsai.ML.Pca.Torch/FitAndTransform.cs | 63 ++++++++++- src/Bonsai.ML.Pca.Torch/IPcaBaseModel.cs | 55 +++++++++ .../OnlineProbabilisticPca.cs | 102 +++++++++++++---- src/Bonsai.ML.Pca.Torch/Pca.cs | 37 +++++-- src/Bonsai.ML.Pca.Torch/PcaBaseModel.cs | 23 ++++ .../PcaDesciptionProvider.cs | 21 ++-- src/Bonsai.ML.Pca.Torch/PcaDescriptor.cs | 21 ++-- src/Bonsai.ML.Pca.Torch/PcaModelType.cs | 14 +++ src/Bonsai.ML.Pca.Torch/ProbabalisticPca.cs | 55 ++++++++- src/Bonsai.ML.Pca.Torch/Reconstruct.cs | 68 +++++++++++- src/Bonsai.ML.Pca.Torch/Transform.cs | 68 +++++++++++- 13 files changed, 610 insertions(+), 86 deletions(-) diff --git a/src/Bonsai.ML.Pca.Torch/CreatePca.cs b/src/Bonsai.ML.Pca.Torch/CreatePca.cs index 1914184b..394c35b4 100644 --- a/src/Bonsai.ML.Pca.Torch/CreatePca.cs +++ b/src/Bonsai.ML.Pca.Torch/CreatePca.cs @@ -11,32 +11,89 @@ namespace Bonsai.ML.Pca.Torch; +/// +/// Creates a PCA model. +/// [Combinator] [ResetCombinator] [WorkflowElementCategory(ElementCategory.Source)] [TypeDescriptionProvider(typeof(PcaDescriptionProvider))] +[Description("Creates a PCA model.")] public class CreatePca : ZeroArgumentExpressionBuilder, INamedElement { + /// public string Name => ModelType.ToString(); + /// + /// The number of principal components to compute. + /// public int NumComponents { get; set; } = 2; + /// + /// The device on which to create the PCA model. + /// [XmlIgnore] - public Device Device { get; set; } + [Description("The device on which to create the PCA model.")] + public Device? 
Device { get; set; } + + /// + /// The scalar type of the PCA model. + /// + [Description("The scalar type of the PCA model.")] public ScalarType? ScalarType { get; set; } + /// + /// The type of PCA model to create. + /// [RefreshProperties(RefreshProperties.All)] + [Description("The type of PCA model to create.")] public PcaModelType ModelType { get; set; } = PcaModelType.Pca; + /// + /// The initial variance for probabilistic PCA models. + /// + [Description("The initial variance for probabilistic PCA models.")] public double InitialVariance { get; set; } = 1.0; + + /// + /// The number of iterations for fitting probabilistic PCA models. + /// + [Description("The number of iterations for fitting probabilistic PCA models.")] public int Iterations { get; set; } = 100; + + /// + /// The tolerance for convergence in probabilistic PCA models. + /// + [Description("The tolerance for convergence in probabilistic PCA models.")] public double Tolerance { get; set; } = 1e-5; + /// + /// The constant learning rate parameter for the online probabilistic PCA model. + /// + [Description("The constant learning rate parameter for the online probabilistic PCA model. Only one of Rho or Kappa must be specified.")] public double? Rho { get; set; } = 0.1; + + /// + /// The forgetting factor for the online probabilistic PCA model. + /// + [Description("The forgetting factor for the online probabilistic PCA model. Only one of Rho or Kappa must be specified.")] public double? Kappa { get; set; } = 0.9; + + /// + /// The time offset for the online probabilistic PCA model. + /// + [Description("The time offset for the online probabilistic PCA model. If null, decaying learning rate starts from the first sample.")] public int? TimeOffset { get; set; } = null; + + /// + /// The period for reorthogonalizing the components in the online probabilistic PCA model. + /// + [Description("The period for reorthogonalizing the components in the online probabilistic PCA model. 
If null, reorthogonalization is not performed.")] public int? ReorthogonalizePeriod { get; set; } = null; + /// + /// The random number generator used for initializing probabilistic PCA models. + /// [XmlIgnore] public Generator? Generator { get; set; } = null; @@ -66,33 +123,33 @@ internal IEnumerable GetModelProperties() } } - private static PcaBaseModel CreateModel(CreatePca instance) + private static PcaBaseModel CreateModel(CreatePca pcaBuilder) { - return instance.ModelType switch + return pcaBuilder.ModelType switch { PcaModelType.Pca => new Pca( - numComponents: instance.NumComponents, - device: instance.Device, - scalarType: instance.ScalarType), + numComponents: pcaBuilder.NumComponents, + device: pcaBuilder.Device, + scalarType: pcaBuilder.ScalarType), PcaModelType.ProbabilisticPca => new ProbabilisticPca( - numComponents: instance.NumComponents, - device: instance.Device, - scalarType: instance.ScalarType, - initialVariance: instance.InitialVariance, - generator: instance.Generator, - iterations: instance.Iterations, - tolerance: instance.Tolerance), + numComponents: pcaBuilder.NumComponents, + device: pcaBuilder.Device, + scalarType: pcaBuilder.ScalarType, + initialVariance: pcaBuilder.InitialVariance, + generator: pcaBuilder.Generator, + iterations: pcaBuilder.Iterations, + tolerance: pcaBuilder.Tolerance), PcaModelType.OnlineProbabilisticPca => new OnlineProbabilisticPca( - numComponents: instance.NumComponents, - device: instance.Device, - scalarType: instance.ScalarType, - initialVariance: instance.InitialVariance, - generator: instance.Generator, - rho: instance.Rho, - kappa: instance.Kappa, - timeOffset: instance.TimeOffset, - reorthogonalizePeriod: instance.ReorthogonalizePeriod), - _ => throw new NotSupportedException($"Model type {instance.ModelType} is not supported."), + numComponents: pcaBuilder.NumComponents, + device: pcaBuilder.Device, + scalarType: pcaBuilder.ScalarType, + initialVariance: pcaBuilder.InitialVariance, + generator: 
pcaBuilder.Generator, + rho: pcaBuilder.Rho, + kappa: pcaBuilder.Kappa, + timeOffset: pcaBuilder.TimeOffset, + reorthogonalizePeriod: pcaBuilder.ReorthogonalizePeriod), + _ => throw new NotSupportedException($"Model type {pcaBuilder.ModelType} is not supported."), }; } @@ -107,6 +164,7 @@ private static Type GetModelType(PcaModelType modelType) }; } + /// public override Expression Build(IEnumerable arguments) { var processMethod = GetType().GetMethod( diff --git a/src/Bonsai.ML.Pca.Torch/Fit.cs b/src/Bonsai.ML.Pca.Torch/Fit.cs index 57fee736..a566e659 100644 --- a/src/Bonsai.ML.Pca.Torch/Fit.cs +++ b/src/Bonsai.ML.Pca.Torch/Fit.cs @@ -1,21 +1,55 @@ using System; using System.ComponentModel; using System.Reactive.Linq; +using System.Xml.Serialization; using Bonsai; using static TorchSharp.torch; namespace Bonsai.ML.Pca.Torch; +/// +/// Fits a PCA model to the input data. +/// [Combinator] -[Description] +[Description("Fits a PCA model to the input data.")] [WorkflowElementCategory(ElementCategory.Sink)] public class Fit { + /// + /// The PCA model used to fit the input data. + /// + [Description("The PCA model used to fit the input data.")] + [XmlIgnore] + public IPcaBaseModel? Model { get; set; } + private void FitModel(IPcaBaseModel model, Tensor data) { model.Fit(data); } + /// + /// Fits the PCA model to the input data. + /// + /// + /// + /// + public IObservable Process(IObservable source) + { + if (Model == null) + { + throw new InvalidOperationException("The PCA model has not been specified."); + } + return source.Do(value => + { + FitModel(Model, value); + }); + } + + /// + /// Fits the PCA model to the input data. + /// + /// + /// public IObservable> Process(IObservable> source) { return source.Do((value) => @@ -24,7 +58,12 @@ public IObservable> Process(IObservable> s }); } - public IObservable> Process(IObservable> source) + /// + /// Fits the PCA model to the input data. 
+ /// + /// + /// + public IObservable> Process(IObservable> source) { return source.Do((value) => { @@ -32,6 +71,11 @@ public IObservable> Process(IObservable> sou }); } + /// + /// Fits the PCA model to the input data. + /// + /// + /// public IObservable> Process(IObservable> source) { return source.Do((value) => @@ -40,6 +84,11 @@ public IObservable> Process(IObservable + /// Fits the PCA model to the input data. + /// + /// + /// public IObservable> Process(IObservable> source) { return source.Do((value) => @@ -48,6 +97,11 @@ public IObservable> Process(IObservable + /// Fits the PCA model to the input data. + /// + /// + /// public IObservable> Process(IObservable> source) { return source.Do((value) => @@ -55,7 +109,12 @@ public IObservable> Process(IObservable + /// Fits the PCA model to the input data. + /// + /// + /// public IObservable> Process(IObservable> source) { return source.Do((value) => diff --git a/src/Bonsai.ML.Pca.Torch/FitAndTransform.cs b/src/Bonsai.ML.Pca.Torch/FitAndTransform.cs index 5d4d3290..d8ecbe3a 100644 --- a/src/Bonsai.ML.Pca.Torch/FitAndTransform.cs +++ b/src/Bonsai.ML.Pca.Torch/FitAndTransform.cs @@ -1,21 +1,55 @@ using System; using System.ComponentModel; using System.Reactive.Linq; +using System.Xml.Serialization; using Bonsai; using static TorchSharp.torch; namespace Bonsai.ML.Pca.Torch; +/// +/// Fits the PCA model to the input data and transforms it. +/// [Combinator] -[Description] +[Description("Fits the PCA model to the input data and transforms it.")] [WorkflowElementCategory(ElementCategory.Transform)] public class FitAndTransform { - private void FitModelAndTransformData(IPcaBaseModel model, Tensor data) + /// + /// The PCA model used to fit and transform the input data. + /// + [Description("The PCA model used to fit and transform the input data.")] + [XmlIgnore] + public IPcaBaseModel? 
Model { get; set; } + + private static void FitModelAndTransformData(IPcaBaseModel model, Tensor data) { model.FitAndTransform(data); } + + /// + /// Fits the PCA model to the input data and transforms it. + /// + /// + /// + public IObservable Process(IObservable source) + { + if (Model == null) + { + throw new InvalidOperationException("The PCA model has not been specified."); + } + return source.Do(value => + { + FitModelAndTransformData(Model, value); + }); + } + + /// + /// Fits the PCA model to the input data and transforms it. + /// + /// + /// public IObservable> Process(IObservable> source) { return source.Do((value) => @@ -24,6 +58,11 @@ public IObservable> Process(IObservable> s }); } + /// + /// Fits the PCA model to the input data and transforms it. + /// + /// + /// public IObservable> Process(IObservable> source) { return source.Do((value) => @@ -32,6 +71,11 @@ public IObservable> Process(IObservable> s }); } + /// + /// Fits the PCA model to the input data and transforms it. + /// + /// + /// public IObservable> Process(IObservable> source) { return source.Do((value) => @@ -40,6 +84,11 @@ public IObservable> Process(IObservable + /// Fits the PCA model to the input data and transforms it. + /// + /// + /// public IObservable> Process(IObservable> source) { return source.Do((value) => @@ -48,6 +97,11 @@ public IObservable> Process(IObservable + /// Fits the PCA model to the input data and transforms it. + /// + /// + /// public IObservable> Process(IObservable> source) { return source.Do((value) => @@ -56,6 +110,11 @@ public IObservable> Process(IObservable + /// Fits the PCA model to the input data and transforms it. 
+ /// + /// + /// public IObservable> Process(IObservable> source) { return source.Do((value) => diff --git a/src/Bonsai.ML.Pca.Torch/IPcaBaseModel.cs b/src/Bonsai.ML.Pca.Torch/IPcaBaseModel.cs index 495e4b30..7d70bb46 100644 --- a/src/Bonsai.ML.Pca.Torch/IPcaBaseModel.cs +++ b/src/Bonsai.ML.Pca.Torch/IPcaBaseModel.cs @@ -2,12 +2,67 @@ namespace Bonsai.ML.Pca.Torch; +/// +/// Defines the interface for PCA models. +/// public interface IPcaBaseModel { + /// + /// Gets the principal components of the model. + /// + public abstract Tensor Components { get; } + + /// + /// Gets the number of principal components kept by the model. + /// + public int NumComponents { get; } + + /// + /// Gets the device on which the model operates. + /// public Device Device { get; } + + /// + /// Gets the data type used by the model. + /// public ScalarType ScalarType { get; } + + /// + /// Fits the PCA model to the given data. + /// + /// + /// The input data should be a 2D tensor with shape (samples x features). + /// + /// public abstract void Fit(Tensor data); + + /// + /// Transforms the input data using the PCA model. + /// + /// + /// The input data should be a 2D tensor with shape (samples x features). + /// + /// + /// public abstract Tensor Transform(Tensor data); + + /// + /// Fits the PCA model to the given data and transforms it. + /// + /// + /// The input data should be a 2D tensor with shape (samples x features). + /// + /// + /// public abstract Tensor FitAndTransform(Tensor data); + + /// + /// Reconstructs the input data using the PCA model. + /// + /// + /// The input data should be a 2D tensor with shape (samples x features). 
+ /// + /// + /// public abstract Tensor Reconstruct(Tensor data); } diff --git a/src/Bonsai.ML.Pca.Torch/OnlineProbabilisticPca.cs b/src/Bonsai.ML.Pca.Torch/OnlineProbabilisticPca.cs index bafa2fcb..073805c8 100644 --- a/src/Bonsai.ML.Pca.Torch/OnlineProbabilisticPca.cs +++ b/src/Bonsai.ML.Pca.Torch/OnlineProbabilisticPca.cs @@ -10,24 +10,62 @@ namespace Bonsai.ML.Pca.Torch; +/// +/// Implements an online Probabilistic Principal Component Analysis (PPCA) model using stochastic online EM. +/// public class OnlineProbabilisticPca : PcaBaseModel { + /// + /// Rho is a constant learning rate parameter. + /// + /// + /// Rho must be in the range (0, 1). Only one of Rho or Kappa should be specified. + /// public double? Rho { get; private set; } + + /// + /// Kappa is the exponent in the learning rate schedule. + /// + /// + /// Kappa must be in the range (0.5, 1]. Only one + /// of Rho or Kappa should be specified. + /// public double? Kappa { get; private set; } + + /// + /// Gets the variance of the isotropic Gaussian noise model. + /// public double Variance => _sigma2.to_type(ScalarType.Float64).item(); + + /// + /// Gets the period for reorthogonalizing the principal components. + /// + /// + /// Represented as the number of update steps between reorthogonalization operations. + /// If not specified, reorthogonalization is not performed. + /// public int ReorthogonalizePeriod { get; private set; } + + /// + /// Gets the time offset used in the learning rate schedule when Kappa is specified. + /// public int? TimeOffset { get; private set; } - public Tensor Components => _W; + + /// + public override Tensor Components { get; protected set; } = empty(0); + + /// + /// Gets the random number generator used for initializing the model. 
+ /// public Generator Generator { get; private set; } - private Tensor _mu; - private Tensor _W; - private Tensor _Iq; - private Tensor _mx; // E[x] - private Tensor _Cxz; // E[xz^T] - private Tensor _mz; // E[z] - private Tensor _Czz; // E[zz^T] - private Tensor _sxx; // E[||x||^2] + private Tensor _mu = empty(0); + private Tensor _Iq = empty(0); + private Tensor _mx = empty(0); // E[x] + private Tensor _Cxz = empty(0); // E[xz^T] + private Tensor _mz = empty(0); // E[z] + private Tensor _Czz = empty(0); // E[zz^T] + private Tensor _sxx = empty(0); // E[||x||^2] private Tensor _sigma2; // Variance private bool _initializedParameters = false; @@ -35,6 +73,19 @@ public class OnlineProbabilisticPca : PcaBaseModel private int _stepCount = 0; private readonly bool _reorthogonalize = false; + /// + /// Initializes a new instance of the class. + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// public OnlineProbabilisticPca(int numComponents, Device? device = null, ScalarType? scalarType = ScalarType.Float32, @@ -71,14 +122,18 @@ public OnlineProbabilisticPca(int numComponents, throw new ArgumentException("Rho must be in the range (0, 1).", nameof(rho)); } } - - if (kappa.HasValue) + else { if (timeOffset is null or <= 0) { throw new ArgumentException("Time offset must be a positive integer.", nameof(timeOffset)); } + if (!kappa.HasValue) + { + throw new ArgumentException("Kappa must be specified when using a learning rate schedule.", nameof(kappa)); + } + UpdateSchedule = () => Math.Pow(_stepCount + timeOffset.Value, -kappa.Value); if (kappa <= 0.5 || kappa > 1) { @@ -99,6 +154,7 @@ public OnlineProbabilisticPca(int numComponents, _sigma2 = initialVariance; } + /// public override void Fit(Tensor data) { // throw new NotImplementedException(); @@ -127,7 +183,7 @@ public override void Fit(Tensor data) _mu = zeros(d, device: Device, dtype: ScalarType).MoveToOuterDisposeScope(); var randW = randn(d, q, generator: Generator, device: Device, dtype: 
ScalarType); var orthonormalBases = linalg.qr(randW).Q; - _W = (orthonormalBases * _sigma2).MoveToOuterDisposeScope(); // d x q + Components = (orthonormalBases * _sigma2).MoveToOuterDisposeScope(); // d x q _Iq = eye(q, device: Device, dtype: ScalarType).MoveToOuterDisposeScope(); // q x q _mx = zeros(d, device: Device, dtype: ScalarType).MoveToOuterDisposeScope(); // d @@ -146,10 +202,10 @@ public override void Fit(Tensor data) var Xc = Xt - _mu; // E-step - var M = _W.T.matmul(_W) + cov; + var M = Components.T.matmul(Components) + cov; var MInv = Utils.InvertSPD(M, _Iq); - var XcW = Xc.matmul(_W); + var XcW = Xc.matmul(Components); var EzT = Utils.InvertSPD(M, XcW.T); var Ez = EzT.T; @@ -193,7 +249,7 @@ public override void Fit(Tensor data) // Reorder components based on the strength of the components var strength = sum(WNew * WNew, dim: 0); var indices = argsort(strength, descending: true); - _W = WNew.index_select(1, indices).MoveToOuterDisposeScope(); + Components = WNew.index_select(1, indices).MoveToOuterDisposeScope(); _Cxz = _Cxz.index_select(1, indices).MoveToOuterDisposeScope(); _mz = _mz.index_select(0, indices).MoveToOuterDisposeScope(); _Czz = _Czz.index_select(0, indices).index_select(1, indices).MoveToOuterDisposeScope(); @@ -202,12 +258,13 @@ public override void Fit(Tensor data) Szz = _Czz; // Update variance - _sigma2 = ((Sxx - 2 * trace(_W.T.matmul(Sxz)) + trace(_W.T.matmul(_W).matmul(Szz))) / (double)d) + _sigma2 = ((Sxx - 2 * trace(Components.T.matmul(Sxz)) + trace(Components.T.matmul(Components).matmul(Szz))) / (double)d) .clamp_min(0.0) .MoveToOuterDisposeScope(); } } + /// public override Tensor Transform(Tensor data) { if (data.NumberOfElements == 0 || data.dim() < 2) @@ -217,11 +274,12 @@ public override Tensor Transform(Tensor data) var Xt = data.T; // n x d var Xc = Xt - _mu; // n x d - var M = _W.T.matmul(_W) + _Iq * _sigma2; // q x q - var XcW = Xc.matmul(_W); + var M = Components.T.matmul(Components) + _Iq * _sigma2; // q x q + var 
XcW = Xc.matmul(Components); return Utils.InvertSPD(M, XcW.T).T; // n x q } - + + /// public override Tensor Reconstruct(Tensor data) { if (data.NumberOfElements == 0 || data.dim() < 2) @@ -236,11 +294,11 @@ public override Tensor Reconstruct(Tensor data) var Xt = data.T; // n x d var Xc = Xt - _mu; // n x d - var M = _W.T.matmul(_W) + _Iq * _sigma2; // q x q - var XcW = Xc.matmul(_W); + var M = Components.T.matmul(Components) + _Iq * _sigma2; // q x q + var XcW = Xc.matmul(Components); var EzT = Utils.InvertSPD(M, XcW.T); var Ez = EzT.T; - return Ez.matmul(_W.T) + _mu.T; // n x d + return Ez.matmul(Components.T) + _mu.T; // n x d } } diff --git a/src/Bonsai.ML.Pca.Torch/Pca.cs b/src/Bonsai.ML.Pca.Torch/Pca.cs index 3eb75868..e6400673 100644 --- a/src/Bonsai.ML.Pca.Torch/Pca.cs +++ b/src/Bonsai.ML.Pca.Torch/Pca.cs @@ -10,22 +10,36 @@ namespace Bonsai.ML.Pca.Torch; -public class Pca : PcaBaseModel +/// +/// Implements a standard Principal Component Analysis (PCA) model. +/// +public class Pca(int numComponents, + Device? device = null, + ScalarType? scalarType = ScalarType.Float32) : PcaBaseModel(numComponents, + device, + scalarType) { + /// + /// Gets the covariance matrix of the fitted data. + /// public Tensor Covariance { get; private set; } = empty(0); + + /// + /// Gets the eigenvalues of the covariance matrix. + /// public Tensor EigenValues { get; private set; } = empty(0); + + /// + /// Gets the eigenvectors of the covariance matrix. + /// public Tensor EigenVectors { get; private set; } = empty(0); - public Tensor Components { get; private set; } = empty(0); - private bool _isFitted = false; - public Pca(int numComponents, - Device? device = null, - ScalarType? 
scalarType = ScalarType.Float32) - : base(numComponents, - device, - scalarType) - { } + /// + public override Tensor Components { get; protected set; } = empty(0); + + private bool _isFitted = false; + /// public override void Fit(Tensor data) { if (data.NumberOfElements == 0 || data.dim() < 2) @@ -33,7 +47,6 @@ public override void Fit(Tensor data) throw new ArgumentException("Data must be a non-empty 2D tensor with shape (samples x features).", nameof(data)); } - var n = data.size(0); var d = data.size(1); if (NumComponents > d) @@ -53,6 +66,7 @@ public override void Fit(Tensor data) _isFitted = true; } + /// public override Tensor Transform(Tensor data) { if (data.NumberOfElements == 0 || data.dim() < 2) @@ -71,6 +85,7 @@ public override Tensor Transform(Tensor data) return Xc.matmul(Components); // n x q } + /// public override Tensor Reconstruct(Tensor data) { if (data.NumberOfElements == 0 || data.dim() < 2) diff --git a/src/Bonsai.ML.Pca.Torch/PcaBaseModel.cs b/src/Bonsai.ML.Pca.Torch/PcaBaseModel.cs index c24d35eb..262ad066 100644 --- a/src/Bonsai.ML.Pca.Torch/PcaBaseModel.cs +++ b/src/Bonsai.ML.Pca.Torch/PcaBaseModel.cs @@ -10,13 +10,30 @@ namespace Bonsai.ML.Pca.Torch; +/// +/// Provides an abstract base class for PCA models. +/// public abstract class PcaBaseModel : IPcaBaseModel { + /// + public abstract Tensor Components { get; protected set; } + + /// public int NumComponents { get; private set; } + /// public Device Device { get; private set; } + + /// public ScalarType ScalarType { get; private set; } + /// + /// Initializes a new instance of the class. + /// + /// + /// + /// + /// public PcaBaseModel(int numComponents, Device? device = null, ScalarType? scalarType = null) @@ -31,13 +48,19 @@ public PcaBaseModel(int numComponents, ScalarType = scalarType ?? 
ScalarType.Float32; } + /// public abstract void Fit(Tensor data); + + /// public abstract Tensor Transform(Tensor data); + + /// public virtual Tensor FitAndTransform(Tensor data) { Fit(data); return Transform(data); } + /// public abstract Tensor Reconstruct(Tensor data); } diff --git a/src/Bonsai.ML.Pca.Torch/PcaDesciptionProvider.cs b/src/Bonsai.ML.Pca.Torch/PcaDesciptionProvider.cs index 7ddcaba2..39b50eb0 100644 --- a/src/Bonsai.ML.Pca.Torch/PcaDesciptionProvider.cs +++ b/src/Bonsai.ML.Pca.Torch/PcaDesciptionProvider.cs @@ -3,18 +3,21 @@ namespace Bonsai.ML.Pca.Torch; -class PcaDescriptionProvider : TypeDescriptionProvider +/// +/// Provides a custom type description provider for PCA models. +/// +/// +/// Initializes a new instance of the class. +/// +/// +public class PcaDescriptionProvider(TypeDescriptionProvider baseProvider) : TypeDescriptionProvider(baseProvider) { - private readonly TypeDescriptionProvider _baseProvider; - + /// + /// Initializes a new instance of the class. + /// public PcaDescriptionProvider() : this(TypeDescriptor.GetProvider(typeof(object))) { } - public PcaDescriptionProvider(TypeDescriptionProvider baseProvider) - : base(baseProvider) - { - _baseProvider = baseProvider; - } - + /// public override ICustomTypeDescriptor GetTypeDescriptor(Type objectType, object instance) { var defaultDescriptor = base.GetTypeDescriptor(objectType, instance); diff --git a/src/Bonsai.ML.Pca.Torch/PcaDescriptor.cs b/src/Bonsai.ML.Pca.Torch/PcaDescriptor.cs index 4758008e..b66fd47d 100644 --- a/src/Bonsai.ML.Pca.Torch/PcaDescriptor.cs +++ b/src/Bonsai.ML.Pca.Torch/PcaDescriptor.cs @@ -6,15 +6,19 @@ namespace Bonsai.ML.Pca.Torch; -public class PcaDescriptor : CustomTypeDescriptor +/// +/// Provides a custom type descriptor for PCA model creation. +/// +/// +/// Initializes a new instance of the class. 
+/// +/// +/// +public class PcaDescriptor(ICustomTypeDescriptor parent, object instance) : CustomTypeDescriptor(parent) { - private readonly object _instance; - - public PcaDescriptor(ICustomTypeDescriptor parent, object instance) : base(parent) - { - _instance = instance; - } + private readonly object _instance = instance; + /// public override PropertyDescriptorCollection GetProperties(Attribute[] attributes) { var allProperties = base.GetProperties(attributes); @@ -31,6 +35,7 @@ public override PropertyDescriptorCollection GetProperties(Attribute[] attribute return allProperties; } + /// public override PropertyDescriptorCollection GetProperties() - => GetProperties(null); + => GetProperties([]); } diff --git a/src/Bonsai.ML.Pca.Torch/PcaModelType.cs b/src/Bonsai.ML.Pca.Torch/PcaModelType.cs index 1d422a36..83a4fabd 100644 --- a/src/Bonsai.ML.Pca.Torch/PcaModelType.cs +++ b/src/Bonsai.ML.Pca.Torch/PcaModelType.cs @@ -1,8 +1,22 @@ namespace Bonsai.ML.Pca.Torch; +/// +/// Specifies the type of PCA model. +/// public enum PcaModelType { + /// + /// Standard PCA model. + /// Pca, + + /// + /// Probabilistic PCA model. + /// ProbabilisticPca, + + /// + /// Online Probabilistic PCA model. + /// OnlineProbabilisticPca } diff --git a/src/Bonsai.ML.Pca.Torch/ProbabalisticPca.cs b/src/Bonsai.ML.Pca.Torch/ProbabalisticPca.cs index 8f1029c4..e9f3b8bc 100644 --- a/src/Bonsai.ML.Pca.Torch/ProbabalisticPca.cs +++ b/src/Bonsai.ML.Pca.Torch/ProbabalisticPca.cs @@ -10,16 +10,44 @@ namespace Bonsai.ML.Pca.Torch; +/// +/// Probabilistic Principal Component Analysis (PPCA) model. +/// public class ProbabilisticPca : PcaBaseModel { + /// + /// Gets the variance of the isotropic Gaussian noise model. + /// public double Variance { get; private set; } + + /// + /// Gets the log likelihood of the fitted model. 
+ /// public Tensor LogLikelihood { get; private set; } = empty(0); - public Tensor Components { get; private set; } = empty(0); + + /// + public override Tensor Components { get; protected set; } = empty(0); + + /// + /// Gets the random number generator used for initializing the model. + /// public Generator Generator { get; private set; } - private int _iterations; - private double _tolerance; + + private readonly int _iterations; + private readonly double _tolerance; private bool _isFitted = false; + /// + /// Initializes a new instance of the class. + /// + /// + /// + /// + /// + /// + /// + /// + /// public ProbabilisticPca(int numComponents, Device? device = null, ScalarType? scalarType = ScalarType.Float32, @@ -52,6 +80,11 @@ public ProbabilisticPca(int numComponents, _tolerance = tolerance; } + /// + /// Fits the PPCA model to the input data. + /// + /// + /// public override void Fit(Tensor data) { if (data.NumberOfElements == 0 || data.dim() != 2) @@ -153,6 +186,13 @@ public override void Fit(Tensor data) _isFitted = true; } + /// + /// Transforms the input data using the fitted PPCA model. + /// + /// + /// + /// + /// public override Tensor Transform(Tensor data) { if (data.NumberOfElements == 0 || data.dim() < 2) @@ -173,7 +213,14 @@ public override Tensor Transform(Tensor data) var MInv = Utils.InvertSPD(M, eye(NumComponents)); // q x q return X.matmul(W).matmul(MInv); // n x q } - + + /// + /// Reconstructs the input data using the fitted PPCA model. 
+ /// + /// + /// + /// + /// public override Tensor Reconstruct(Tensor data) { if (data.NumberOfElements == 0 || data.dim() < 2) diff --git a/src/Bonsai.ML.Pca.Torch/Reconstruct.cs b/src/Bonsai.ML.Pca.Torch/Reconstruct.cs index 64f0d242..bbeb57a9 100644 --- a/src/Bonsai.ML.Pca.Torch/Reconstruct.cs +++ b/src/Bonsai.ML.Pca.Torch/Reconstruct.cs @@ -1,21 +1,60 @@ using System; using System.ComponentModel; using System.Reactive.Linq; +using System.Xml.Serialization; using Bonsai; using static TorchSharp.torch; namespace Bonsai.ML.Pca.Torch; +/// +/// Reconstructs the input data using a PCA model. +/// [Combinator] -[Description] +[Description("Reconstructs the input data using a PCA model.")] [WorkflowElementCategory(ElementCategory.Transform)] public class Reconstruct { - private Tensor ReconstructData(IPcaBaseModel model, Tensor data) + /// + /// The PCA model used to reconstruct the input data. + /// + [Description("The PCA model used to reconstruct the input data.")] + [XmlIgnore] + public IPcaBaseModel? Model { get; set; } + + /// + /// Reconstructs the input data using the specified PCA model. + /// + /// + /// + /// + private static Tensor ReconstructData(IPcaBaseModel model, Tensor data) { return model.Reconstruct(data); } + /// + /// Reconstructs the input data using the specified PCA model. + /// + /// + /// + public IObservable Process(IObservable source) + { + if (Model == null) + { + throw new InvalidOperationException("The PCA model has not been specified."); + } + return source.Select(value => + { + return ReconstructData(Model, value); + }); + } + + /// + /// Reconstructs the input data using the specified PCA model. + /// + /// + /// public IObservable Process(IObservable> source) { return source.Select(value => @@ -24,6 +63,11 @@ public IObservable Process(IObservable> source) }); } + /// + /// Reconstructs the input data using the specified PCA model. 
+ /// + /// + /// public IObservable Process(IObservable> source) { return source.Select(value => @@ -32,6 +76,11 @@ public IObservable Process(IObservable> source) }); } + /// + /// Reconstructs the input data using the specified PCA model. + /// + /// + /// public IObservable Process(IObservable> source) { return source.Select(value => @@ -40,6 +89,11 @@ public IObservable Process(IObservable> }); } + /// + /// Reconstructs the input data using the specified PCA model. + /// + /// + /// public IObservable Process(IObservable> source) { return source.Select(value => @@ -48,6 +102,11 @@ public IObservable Process(IObservable> }); } + /// + /// Reconstructs the input data using the specified PCA model. + /// + /// + /// public IObservable Process(IObservable> source) { return source.Select(value => @@ -56,6 +115,11 @@ public IObservable Process(IObservable + /// Reconstructs the input data using the specified PCA model. + /// + /// + /// public IObservable Process(IObservable> source) { return source.Select(value => diff --git a/src/Bonsai.ML.Pca.Torch/Transform.cs b/src/Bonsai.ML.Pca.Torch/Transform.cs index 5cc94769..39867793 100644 --- a/src/Bonsai.ML.Pca.Torch/Transform.cs +++ b/src/Bonsai.ML.Pca.Torch/Transform.cs @@ -1,21 +1,60 @@ using System; using System.ComponentModel; using System.Reactive.Linq; +using System.Xml.Serialization; using Bonsai; using static TorchSharp.torch; namespace Bonsai.ML.Pca.Torch; +/// +/// Transforms the input data using a PCA model. +/// [Combinator] -[Description] +[Description("Transforms the input data using a PCA model.")] [WorkflowElementCategory(ElementCategory.Transform)] public class Transform { - private Tensor TransformData(IPcaBaseModel model, Tensor data) + /// + /// The PCA model used to transform the input data. + /// + [Description("The PCA model used to transform the input data.")] + [XmlIgnore] + public IPcaBaseModel? Model { get; set; } + + /// + /// Transforms the input data using the specified PCA model. 
+ /// + /// + /// + /// + private static Tensor TransformData(IPcaBaseModel model, Tensor data) { return model.Transform(data); } + /// + /// Transforms the input data. + /// + /// + /// + public IObservable Process(IObservable source) + { + if (Model == null) + { + throw new InvalidOperationException("The PCA model has not been specified."); + } + return source.Select(value => + { + return TransformData(Model, value); + }); + } + + /// + /// Transforms the input data using the specified PCA model. + /// + /// + /// public IObservable Process(IObservable> source) { return source.Select(value => @@ -24,6 +63,11 @@ public IObservable Process(IObservable> source) }); } + /// + /// Transforms the input data using the specified PCA model. + /// + /// + /// public IObservable Process(IObservable> source) { return source.Select(value => @@ -32,6 +76,11 @@ public IObservable Process(IObservable> source) }); } + /// + /// Transforms the input data using the specified PCA model. + /// + /// + /// public IObservable Process(IObservable> source) { return source.Select(value => @@ -40,6 +89,11 @@ public IObservable Process(IObservable> }); } + /// + /// Transforms the input data using the specified PCA model. + /// + /// + /// public IObservable Process(IObservable> source) { return source.Select(value => @@ -48,6 +102,11 @@ public IObservable Process(IObservable> }); } + /// + /// Transforms the input data using the specified PCA model. + /// + /// + /// public IObservable Process(IObservable> source) { return source.Select(value => @@ -56,6 +115,11 @@ public IObservable Process(IObservable + /// Transforms the input data using the specified PCA model. 
+ /// + /// + /// public IObservable Process(IObservable> source) { return source.Select(value => From 4d3d5dcdd5b1e4d5934d6129055df579d66b6394 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 13 Jan 2026 16:02:42 +0000 Subject: [PATCH 14/20] Added method for online PCA using the generalized hebbian rule --- src/Bonsai.ML.Pca.Torch/CreatePca.cs | 19 +++- src/Bonsai.ML.Pca.Torch/Fit.cs | 39 ++++++-- src/Bonsai.ML.Pca.Torch/FitAndTransform.cs | 39 ++++++-- src/Bonsai.ML.Pca.Torch/OnlinePcaGha.cs | 98 +++++++++++++++++++ .../OnlineProbabilisticPca.cs | 22 ++--- src/Bonsai.ML.Pca.Torch/Pca.cs | 15 +-- src/Bonsai.ML.Pca.Torch/PcaBaseModel.cs | 9 +- src/Bonsai.ML.Pca.Torch/PcaModelType.cs | 7 +- src/Bonsai.ML.Pca.Torch/ProbabalisticPca.cs | 24 ++--- src/Bonsai.ML.Pca.Torch/Reconstruct.cs | 38 +++++-- src/Bonsai.ML.Pca.Torch/Transform.cs | 38 +++++-- 11 files changed, 271 insertions(+), 77 deletions(-) create mode 100644 src/Bonsai.ML.Pca.Torch/OnlinePcaGha.cs diff --git a/src/Bonsai.ML.Pca.Torch/CreatePca.cs b/src/Bonsai.ML.Pca.Torch/CreatePca.cs index 394c35b4..2dbaebc7 100644 --- a/src/Bonsai.ML.Pca.Torch/CreatePca.cs +++ b/src/Bonsai.ML.Pca.Torch/CreatePca.cs @@ -4,7 +4,6 @@ using System.Reactive.Linq; using System.Linq.Expressions; using Bonsai.Expressions; -using System.Linq; using System.Reflection; using static TorchSharp.torch; using System.Xml.Serialization; @@ -97,6 +96,11 @@ public class CreatePca : ZeroArgumentExpressionBuilder, INamedElement [XmlIgnore] public Generator? Generator { get; set; } = null; + /// + /// The learning rate for the Online PCA GHA model. 
+ /// + public double LearningRate { get; set; } = 0.1; + internal IEnumerable GetModelProperties() { yield return nameof(NumComponents); @@ -121,6 +125,12 @@ internal IEnumerable GetModelProperties() yield return nameof(ReorthogonalizePeriod); yield return nameof(Generator); } + + if (ModelType == PcaModelType.OnlinePcaGha) + { + yield return nameof(LearningRate); + yield return nameof(Generator); + } } private static PcaBaseModel CreateModel(CreatePca pcaBuilder) @@ -149,6 +159,12 @@ private static PcaBaseModel CreateModel(CreatePca pcaBuilder) kappa: pcaBuilder.Kappa, timeOffset: pcaBuilder.TimeOffset, reorthogonalizePeriod: pcaBuilder.ReorthogonalizePeriod), + PcaModelType.OnlinePcaGha => new OnlinePcaGha( + numComponents: pcaBuilder.NumComponents, + learningRate: pcaBuilder.LearningRate, + device: pcaBuilder.Device, + scalarType: pcaBuilder.ScalarType, + generator: pcaBuilder.Generator), _ => throw new NotSupportedException($"Model type {pcaBuilder.ModelType} is not supported."), }; } @@ -160,6 +176,7 @@ private static Type GetModelType(PcaModelType modelType) PcaModelType.Pca => typeof(Pca), PcaModelType.ProbabilisticPca => typeof(ProbabilisticPca), PcaModelType.OnlineProbabilisticPca => typeof(OnlineProbabilisticPca), + PcaModelType.OnlinePcaGha => typeof(OnlinePcaGha), _ => throw new NotSupportedException($"Model type {modelType} is not supported."), }; } diff --git a/src/Bonsai.ML.Pca.Torch/Fit.cs b/src/Bonsai.ML.Pca.Torch/Fit.cs index a566e659..b5f758f0 100644 --- a/src/Bonsai.ML.Pca.Torch/Fit.cs +++ b/src/Bonsai.ML.Pca.Torch/Fit.cs @@ -2,7 +2,6 @@ using System.ComponentModel; using System.Reactive.Linq; using System.Xml.Serialization; -using Bonsai; using static TorchSharp.torch; namespace Bonsai.ML.Pca.Torch; @@ -46,7 +45,7 @@ public IObservable Process(IObservable source) } /// - /// Fits the PCA model to the input data. + /// Fits a standard PCA model to the input data. 
/// /// /// @@ -59,7 +58,7 @@ public IObservable> Process(IObservable> s } /// - /// Fits the PCA model to the input data. + /// Fits a standard PCA model to the input data. /// /// /// @@ -72,7 +71,7 @@ public IObservable> Process(IObservable> s } /// - /// Fits the PCA model to the input data. + /// Fits a probabilistic PCA model to the input data. /// /// /// @@ -85,7 +84,7 @@ public IObservable> Process(IObservable - /// Fits the PCA model to the input data. + /// Fits a probabilistic PCA model to the input data. /// /// /// @@ -98,7 +97,7 @@ public IObservable> Process(IObservable - /// Fits the PCA model to the input data. + /// Fits an online probabilistic PCA model to the input data. /// /// /// @@ -111,7 +110,7 @@ public IObservable> Process(IObservable - /// Fits the PCA model to the input data. + /// Fits an online probabilistic PCA model to the input data. /// /// /// @@ -122,4 +121,30 @@ public IObservable> Process(IObservable + /// Fits an online PCA model using the Generalized Hebbian Algorithm to the input data. + /// + /// + /// + public IObservable> Process(IObservable> source) + { + return source.Do((value) => + { + FitModel(value.Item1, value.Item2); + }); + } + + /// + /// Fits an online PCA model using the Generalized Hebbian Algorithm to the input data. + /// + /// + /// + public IObservable> Process(IObservable> source) + { + return source.Do((value) => + { + FitModel(value.Item2, value.Item1); + }); + } } diff --git a/src/Bonsai.ML.Pca.Torch/FitAndTransform.cs b/src/Bonsai.ML.Pca.Torch/FitAndTransform.cs index d8ecbe3a..87c8aefb 100644 --- a/src/Bonsai.ML.Pca.Torch/FitAndTransform.cs +++ b/src/Bonsai.ML.Pca.Torch/FitAndTransform.cs @@ -2,7 +2,6 @@ using System.ComponentModel; using System.Reactive.Linq; using System.Xml.Serialization; -using Bonsai; using static TorchSharp.torch; namespace Bonsai.ML.Pca.Torch; @@ -46,7 +45,7 @@ public IObservable Process(IObservable source) } /// - /// Fits the PCA model to the input data and transforms it. 
+ /// Fits a standard PCA model to the input data and transforms it. /// /// /// @@ -59,7 +58,7 @@ public IObservable> Process(IObservable> s } /// - /// Fits the PCA model to the input data and transforms it. + /// Fits a standard PCA model to the input data and transforms it. /// /// /// @@ -72,7 +71,7 @@ public IObservable> Process(IObservable> s } /// - /// Fits the PCA model to the input data and transforms it. + /// Fits a probabilistic PCA model to the input data and transforms it. /// /// /// @@ -85,7 +84,7 @@ public IObservable> Process(IObservable - /// Fits the PCA model to the input data and transforms it. + /// Fits a probabilistic PCA model to the input data and transforms it. /// /// /// @@ -98,7 +97,7 @@ public IObservable> Process(IObservable - /// Fits the PCA model to the input data and transforms it. + /// Fits an online probabilistic PCA model to the input data and transforms it. /// /// /// @@ -111,7 +110,7 @@ public IObservable> Process(IObservable - /// Fits the PCA model to the input data and transforms it. + /// Fits an online probabilistic PCA model to the input data and transforms it. /// /// /// @@ -122,4 +121,30 @@ public IObservable> Process(IObservable + /// Fits an online PCA model using the Generalized Hebbian Algorithm to the input data and transforms it. + /// + /// + /// + public IObservable> Process(IObservable> source) + { + return source.Do((value) => + { + FitModelAndTransformData(value.Item1, value.Item2); + }); + } + + /// + /// Fits an online PCA model using the Generalized Hebbian Algorithm to the input data and transforms it. 
+ /// + /// + /// + public IObservable> Process(IObservable> source) + { + return source.Do((value) => + { + FitModelAndTransformData(value.Item2, value.Item1); + }); + } } diff --git a/src/Bonsai.ML.Pca.Torch/OnlinePcaGha.cs b/src/Bonsai.ML.Pca.Torch/OnlinePcaGha.cs new file mode 100644 index 00000000..79c92447 --- /dev/null +++ b/src/Bonsai.ML.Pca.Torch/OnlinePcaGha.cs @@ -0,0 +1,98 @@ +using System; +using static TorchSharp.torch; + +namespace Bonsai.ML.Pca.Torch; + +/// +/// Implements Online PCA using the Generalized Hebbian Algorithm (GHA). +/// +/// +/// +/// +/// +/// +public class OnlinePcaGha( + int numComponents, + double learningRate = 0.1, + Device? device = null, + ScalarType? scalarType = null, + Generator? generator = null +) : PcaBaseModel(numComponents, device, scalarType) +{ + + private Tensor _mean = empty(0); + private int _sampleCount = 0; + + /// + /// Gets or sets the learning rate for the GHA algorithm. + /// + public double LearningRate { get; set; } = learningRate; + + /// + public override Tensor Components { get; protected set; } = empty(0); + + /// + /// Gets the random number generator used for initializing the model. + /// + public Generator Generator { get; private set; } = generator ?? 
manual_seed(0); + + /// + public override void Fit(Tensor data) + { + // throw new NotImplementedException(); + if (data.NumberOfElements == 0 || data.dim() != 2) + { + throw new ArgumentException("Input data must be a 2D tensor."); + } + + var q = NumComponents; + + // Data is shaped (number of samples x number of features) + var n = data.shape[0]; + var d = data.shape[1]; + + if (q > d) + { + throw new ArgumentException("Number of components cannot be greater than the number of features.", nameof(data)); + } + + // Initialize components randomly + if (Components.numel() == 0) + Components = linalg.qr(randn([d, q], ScalarType, Device), mode: linalg.QRMode.Reduced).Q; + + if (_mean.numel() == 0) + _mean = data.mean([0], keepdim: true); + else + _mean.mul_(_sampleCount / (double)(_sampleCount + n)).add_(data.mean([0], keepdim: true), alpha: n / (double)(_sampleCount + n)); + + _sampleCount += (int)n; + var dataCentered = data - _mean; + + var Y = dataCentered.matmul(Components); // n x q + var hebbianTerm = dataCentered.T.matmul(Y); // d x q + var crossTerm = Y.T.matmul(Y); // q x q + var lowerTriangular = crossTerm.tril(0); // q x q + var correlation = lowerTriangular.matmul(Components.T); // q x d + + // Update components + Components.add_(hebbianTerm - correlation.T, alpha: LearningRate); + } + + /// + public override Tensor Transform(Tensor data) + { + if (data.NumberOfElements == 0 || data.dim() != 2) + { + throw new ArgumentException("Input data must be a 2D tensor."); + } + var dataCentered = data - _mean; + return dataCentered.matmul(Components.T); + } + + /// + public override Tensor Reconstruct(Tensor data) + { + var transformed = Transform(data); + return transformed.matmul(Components.T) + _mean; + } +} diff --git a/src/Bonsai.ML.Pca.Torch/OnlineProbabilisticPca.cs b/src/Bonsai.ML.Pca.Torch/OnlineProbabilisticPca.cs index 073805c8..b3c6fa5a 100644 --- a/src/Bonsai.ML.Pca.Torch/OnlineProbabilisticPca.cs +++ 
b/src/Bonsai.ML.Pca.Torch/OnlineProbabilisticPca.cs @@ -1,10 +1,4 @@ -using Bonsai; -using System; -using System.ComponentModel; -using System.Collections.Generic; -using System.Linq; -using System.Reactive.Linq; -using TorchSharp; +using System; using static TorchSharp.torch; using static TorchSharp.torch.linalg; @@ -170,12 +164,10 @@ public override void Fit(Tensor data) _stepCount++; var rho = UpdateSchedule(); - var Xt = data.T; // n x d - // Initialize dimensions var q = NumComponents; - var n = Xt.size(0); - var d = Xt.size(1); + var n = data.size(0); + var d = data.size(1); // Initialize parameters if (!_initializedParameters) @@ -199,7 +191,7 @@ public override void Fit(Tensor data) var cov = _Iq * _sigma2; // Center data using current mean - var Xc = Xt - _mu; + var Xc = data - _mu; // E-step var M = Components.T.matmul(Components) + cov; @@ -210,9 +202,9 @@ public override void Fit(Tensor data) var Ez = EzT.T; // Update statistics - var mx = Xt.mean([0]); - var sxx = Xt.pow(2).sum(dim: 1).mean(); - var Cxz = Xt.T.matmul(Ez) / n; + var mx = data.mean([0]); + var sxx = data.pow(2).sum(dim: 1).mean(); + var Cxz = data.T.matmul(Ez) / n; var mz = Ez.mean([0]); var Czz = EzT.matmul(Ez) / n + _sigma2 * MInv; diff --git a/src/Bonsai.ML.Pca.Torch/Pca.cs b/src/Bonsai.ML.Pca.Torch/Pca.cs index e6400673..d69e598e 100644 --- a/src/Bonsai.ML.Pca.Torch/Pca.cs +++ b/src/Bonsai.ML.Pca.Torch/Pca.cs @@ -1,10 +1,4 @@ -using Bonsai; -using System; -using System.ComponentModel; -using System.Collections.Generic; -using System.Linq; -using System.Reactive.Linq; -using TorchSharp; +using System; using static TorchSharp.torch; using static TorchSharp.torch.linalg; @@ -15,10 +9,13 @@ namespace Bonsai.ML.Pca.Torch; /// public class Pca(int numComponents, Device? device = null, - ScalarType? scalarType = ScalarType.Float32) : PcaBaseModel(numComponents, + ScalarType? 
scalarType = ScalarType.Float32 +) : PcaBaseModel(numComponents, device, scalarType) { + private bool _isFitted = false; + /// /// Gets the covariance matrix of the fitted data. /// @@ -37,8 +34,6 @@ public class Pca(int numComponents, /// public override Tensor Components { get; protected set; } = empty(0); - private bool _isFitted = false; - /// public override void Fit(Tensor data) { diff --git a/src/Bonsai.ML.Pca.Torch/PcaBaseModel.cs b/src/Bonsai.ML.Pca.Torch/PcaBaseModel.cs index 262ad066..dd041abd 100644 --- a/src/Bonsai.ML.Pca.Torch/PcaBaseModel.cs +++ b/src/Bonsai.ML.Pca.Torch/PcaBaseModel.cs @@ -1,12 +1,5 @@ -using Bonsai; -using System; -using System.ComponentModel; -using System.Collections.Generic; -using System.Linq; -using System.Reactive.Linq; -using TorchSharp; +using System; using static TorchSharp.torch; -using static TorchSharp.torch.linalg; namespace Bonsai.ML.Pca.Torch; diff --git a/src/Bonsai.ML.Pca.Torch/PcaModelType.cs b/src/Bonsai.ML.Pca.Torch/PcaModelType.cs index 83a4fabd..d8194720 100644 --- a/src/Bonsai.ML.Pca.Torch/PcaModelType.cs +++ b/src/Bonsai.ML.Pca.Torch/PcaModelType.cs @@ -18,5 +18,10 @@ public enum PcaModelType /// /// Online Probabilistic PCA model. /// - OnlineProbabilisticPca + OnlineProbabilisticPca, + + /// + /// Online PCA model using the Generalized Hebbian Algorithm. 
+ /// + OnlinePcaGha } diff --git a/src/Bonsai.ML.Pca.Torch/ProbabalisticPca.cs b/src/Bonsai.ML.Pca.Torch/ProbabalisticPca.cs index e9f3b8bc..452a63a8 100644 --- a/src/Bonsai.ML.Pca.Torch/ProbabalisticPca.cs +++ b/src/Bonsai.ML.Pca.Torch/ProbabalisticPca.cs @@ -1,10 +1,4 @@ -using Bonsai; -using System; -using System.ComponentModel; -using System.Collections.Generic; -using System.Linq; -using System.Reactive.Linq; -using TorchSharp; +using System; using static TorchSharp.torch; using static TorchSharp.torch.linalg; @@ -92,8 +86,6 @@ public override void Fit(Tensor data) throw new ArgumentException("Data must be a non-empty 2D tensor with shape (samples x features).", nameof(data)); } - var Xt = data.T; // n x d - // Initialize variance var variance = Variance; @@ -102,8 +94,8 @@ public override void Fit(Tensor data) // Initialize dimensions for components var q = NumComponents; - var n = Xt.size(0); - var d = Xt.size(1); + var n = data.size(0); + var d = data.size(1); if (q > d) { @@ -116,13 +108,13 @@ public override void Fit(Tensor data) var Id = eye(d, device: Device, dtype: ScalarType); // d x d // Calculate the sample mean - var mean = Xt.mean([0], keepdim: true); // 1 x d + var mean = data.mean([0], keepdim: true); // 1 x d // Center the data and transpose - var X = Xt - mean; // n x d + var dataCentered = data - mean; // n x d // Calculate the sample covariance - var XTX = X.T.matmul(X); // d x d + var XTX = dataCentered.T.matmul(dataCentered); // d x d var sampleCov = XTX / n; // d x d // Calculate term 1 for variance update @@ -142,12 +134,12 @@ public override void Fit(Tensor data) // E-step: Compute the posterior distribution of the latent variables var M = W.T.matmul(W) + Iq * variance; // q x q var MInv = inv(M); // q x q - var mu = MInv.matmul(W.T).matmul(X.T).T; // n x q + var mu = MInv.matmul(W.T).matmul(dataCentered.T).T; // n x q var SSum = n * MInv * variance; // q x q var cov = mu.T.matmul(mu) + SSum; // q x q // M-step: Compute new W and new 
variance - var XMu = X.T.matmul(mu); // d x q + var XMu = dataCentered.T.matmul(mu); // d x q var WNew = XMu.matmul(inv(cov)); // d x q var term2 = 2 * XMu.mul(WNew).sum(); diff --git a/src/Bonsai.ML.Pca.Torch/Reconstruct.cs b/src/Bonsai.ML.Pca.Torch/Reconstruct.cs index bbeb57a9..043f631a 100644 --- a/src/Bonsai.ML.Pca.Torch/Reconstruct.cs +++ b/src/Bonsai.ML.Pca.Torch/Reconstruct.cs @@ -51,7 +51,7 @@ public IObservable Process(IObservable source) } /// - /// Reconstructs the input data using the specified PCA model. + /// Reconstructs the input data using a standard PCA model. /// /// /// @@ -64,7 +64,7 @@ public IObservable Process(IObservable> source) } /// - /// Reconstructs the input data using the specified PCA model. + /// Reconstructs the input data using a standard PCA model. /// /// /// @@ -77,7 +77,7 @@ public IObservable Process(IObservable> source) } /// - /// Reconstructs the input data using the specified PCA model. + /// Reconstructs the input data using a probabilistic PCA model. /// /// /// @@ -90,7 +90,7 @@ public IObservable Process(IObservable> } /// - /// Reconstructs the input data using the specified PCA model. + /// Reconstructs the input data using a probabilistic PCA model. /// /// /// @@ -103,7 +103,7 @@ public IObservable Process(IObservable> } /// - /// Reconstructs the input data using the specified PCA model. + /// Reconstructs the input data using an online probabilistic PCA model based on stochastic online EM. /// /// /// @@ -116,7 +116,7 @@ public IObservable Process(IObservable - /// Reconstructs the input data using the specified PCA model. + /// Reconstructs the input data using an online probabilistic PCA model based on stochastic online EM. /// /// /// @@ -127,4 +127,30 @@ public IObservable Process(IObservable + /// Reconstructs the input data using an online PCA model based on the Generalized Hebbian Algorithm. 
+ /// + /// + /// + public IObservable Process(IObservable> source) + { + return source.Select(value => + { + return ReconstructData(value.Item1, value.Item2); + }); + } + + /// + /// Reconstructs the input data using an online PCA model based on the Generalized Hebbian Algorithm. + /// + /// + /// + public IObservable Process(IObservable> source) + { + return source.Select(value => + { + return ReconstructData(value.Item2, value.Item1); + }); + } } diff --git a/src/Bonsai.ML.Pca.Torch/Transform.cs b/src/Bonsai.ML.Pca.Torch/Transform.cs index 39867793..95c39826 100644 --- a/src/Bonsai.ML.Pca.Torch/Transform.cs +++ b/src/Bonsai.ML.Pca.Torch/Transform.cs @@ -51,7 +51,7 @@ public IObservable Process(IObservable source) } /// - /// Transforms the input data using the specified PCA model. + /// Transforms the input data using a standard PCA model. /// /// /// @@ -64,7 +64,7 @@ public IObservable Process(IObservable> source) } /// - /// Transforms the input data using the specified PCA model. + /// Transforms the input data using a standard PCA model. /// /// /// @@ -77,7 +77,7 @@ public IObservable Process(IObservable> source) } /// - /// Transforms the input data using the specified PCA model. + /// Transforms the input data using a probabilistic PCA model. /// /// /// @@ -90,7 +90,7 @@ public IObservable Process(IObservable> } /// - /// Transforms the input data using the specified PCA model. + /// Transforms the input data using a probabilistic PCA model. /// /// /// @@ -103,7 +103,7 @@ public IObservable Process(IObservable> } /// - /// Transforms the input data using the specified PCA model. + /// Transforms the input data using an online probabilistic PCA model. /// /// /// @@ -116,7 +116,7 @@ public IObservable Process(IObservable - /// Transforms the input data using the specified PCA model. + /// Transforms the input data using an online probabilistic PCA model.
/// /// /// @@ -127,4 +127,30 @@ public IObservable Process(IObservable + /// Transforms the input data using an online PCA model based on the Generalized Hebbian Algorithm. + /// + /// + /// + public IObservable Process(IObservable> source) + { + return source.Select(value => + { + return TransformData(value.Item1, value.Item2); + }); + } + + /// + /// Transforms the input data using an online PCA model based on the Generalized Hebbian Algorithm. + /// + /// + /// + public IObservable Process(IObservable> source) + { + return source.Select(value => + { + return TransformData(value.Item2, value.Item1); + }); + } } From d7fe674e6fa14b9bba308264f5855b674f1e4b1c Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 14 Jan 2026 14:18:52 +0000 Subject: [PATCH 15/20] Refactored PCA model to use more convenient shape (samples x features) and improved support for validation --- src/Bonsai.ML.Pca.Torch/IPcaBaseModel.cs | 22 +- .../OnlineProbabilisticPca.cs | 49 ++-- src/Bonsai.ML.Pca.Torch/Pca.cs | 96 +++----- src/Bonsai.ML.Pca.Torch/PcaBaseModel.cs | 65 +++++- src/Bonsai.ML.Pca.Torch/ProbabalisticPca.cs | 216 +++++++----------- 5 files changed, 216 insertions(+), 232 deletions(-) diff --git a/src/Bonsai.ML.Pca.Torch/IPcaBaseModel.cs b/src/Bonsai.ML.Pca.Torch/IPcaBaseModel.cs index 7d70bb46..7759f109 100644 --- a/src/Bonsai.ML.Pca.Torch/IPcaBaseModel.cs +++ b/src/Bonsai.ML.Pca.Torch/IPcaBaseModel.cs @@ -7,10 +7,20 @@ namespace Bonsai.ML.Pca.Torch; /// public interface IPcaBaseModel { + /// + /// Gets a value indicating whether the model has been fitted to data. + /// + public bool IsFitted { get; } + + /// + /// Gets the number of features in the fitted data. + /// + public int NumFeatures { get; } + /// /// Gets the principal components of the model. /// - public abstract Tensor Components { get; } + public Tensor Components { get; } /// /// Gets the number of principal components kept by the model. 
@@ -25,7 +35,7 @@ public interface IPcaBaseModel /// /// Gets the data type used by the model. /// - public ScalarType ScalarType { get; } + public ScalarType? ScalarType { get; } /// /// Fits the PCA model to the given data. @@ -34,7 +44,7 @@ public interface IPcaBaseModel /// The input data should be a 2D tensor with shape (samples x features). /// /// - public abstract void Fit(Tensor data); + public void Fit(Tensor data); /// /// Transforms the input data using the PCA model. @@ -44,7 +54,7 @@ public interface IPcaBaseModel /// /// /// - public abstract Tensor Transform(Tensor data); + public Tensor Transform(Tensor data); /// /// Fits the PCA model to the given data and transforms it. @@ -54,7 +64,7 @@ public interface IPcaBaseModel /// /// /// - public abstract Tensor FitAndTransform(Tensor data); + public Tensor FitAndTransform(Tensor data); /// /// Reconstructs the input data using the PCA model. @@ -64,5 +74,5 @@ public interface IPcaBaseModel /// /// /// - public abstract Tensor Reconstruct(Tensor data); + public Tensor Reconstruct(Tensor data); } diff --git a/src/Bonsai.ML.Pca.Torch/OnlineProbabilisticPca.cs b/src/Bonsai.ML.Pca.Torch/OnlineProbabilisticPca.cs index b3c6fa5a..ee40d716 100644 --- a/src/Bonsai.ML.Pca.Torch/OnlineProbabilisticPca.cs +++ b/src/Bonsai.ML.Pca.Torch/OnlineProbabilisticPca.cs @@ -5,10 +5,22 @@ namespace Bonsai.ML.Pca.Torch; /// -/// Implements an online Probabilistic Principal Component Analysis (PPCA) model using stochastic online EM. +/// Implements an online probabilistic PCA model using the stochastic online EM algorithm. 
/// public class OnlineProbabilisticPca : PcaBaseModel { + private Tensor _mu = empty(0); + private Tensor _Iq = empty(0); + private Tensor _mx = empty(0); // E[x] + private Tensor _Cxz = empty(0); // E[xz^T] + private Tensor _mz = empty(0); // E[z] + private Tensor _Czz = empty(0); // E[zz^T] + private Tensor _sxx = empty(0); // E[||x||^2] + private bool _initializedParameters = false; + private readonly Func UpdateSchedule; + private int _stepCount = 0; + private readonly bool _reorthogonalize = false; + /// /// Rho is a constant learning rate parameter. /// @@ -29,7 +41,7 @@ public class OnlineProbabilisticPca : PcaBaseModel /// /// Gets the variance of the isotropic Gaussian noise model. /// - public double Variance => _sigma2.to_type(ScalarType.Float64).item(); + public double Variance { get; private set; } /// /// Gets the period for reorthogonalizing the principal components. @@ -53,20 +65,6 @@ public class OnlineProbabilisticPca : PcaBaseModel /// public Generator Generator { get; private set; } - private Tensor _mu = empty(0); - private Tensor _Iq = empty(0); - private Tensor _mx = empty(0); // E[x] - private Tensor _Cxz = empty(0); // E[xz^T] - private Tensor _mz = empty(0); // E[z] - private Tensor _Czz = empty(0); // E[zz^T] - private Tensor _sxx = empty(0); // E[||x||^2] - private Tensor _sigma2; // Variance - - private bool _initializedParameters = false; - private readonly Func UpdateSchedule; - private int _stepCount = 0; - private readonly bool _reorthogonalize = false; - /// /// Initializes a new instance of the class. /// @@ -82,7 +80,7 @@ public class OnlineProbabilisticPca : PcaBaseModel /// public OnlineProbabilisticPca(int numComponents, Device? device = null, - ScalarType? scalarType = ScalarType.Float32, + ScalarType? scalarType = null, double initialVariance = 1.0, Generator? generator = null, double? 
rho = 0.1, @@ -145,7 +143,7 @@ public OnlineProbabilisticPca(int numComponents, Rho = rho; Kappa = kappa; TimeOffset = timeOffset; - _sigma2 = initialVariance; + Variance = initialVariance; } /// @@ -175,7 +173,7 @@ public override void Fit(Tensor data) _mu = zeros(d, device: Device, dtype: ScalarType).MoveToOuterDisposeScope(); var randW = randn(d, q, generator: Generator, device: Device, dtype: ScalarType); var orthonormalBases = linalg.qr(randW).Q; - Components = (orthonormalBases * _sigma2).MoveToOuterDisposeScope(); // d x q + Components = (orthonormalBases * Variance).MoveToOuterDisposeScope(); // d x q _Iq = eye(q, device: Device, dtype: ScalarType).MoveToOuterDisposeScope(); // q x q _mx = zeros(d, device: Device, dtype: ScalarType).MoveToOuterDisposeScope(); // d @@ -188,7 +186,7 @@ public override void Fit(Tensor data) } // Covariance matrix - var cov = _Iq * _sigma2; + var cov = _Iq * Variance; // Center data using current mean var Xc = data - _mu; @@ -206,7 +204,7 @@ public override void Fit(Tensor data) var sxx = data.pow(2).sum(dim: 1).mean(); var Cxz = data.T.matmul(Ez) / n; var mz = Ez.mean([0]); - var Czz = EzT.matmul(Ez) / n + _sigma2 * MInv; + var Czz = EzT.matmul(Ez) / n + Variance * MInv; // Update parameters var rhoFactor = 1 - rho; @@ -250,9 +248,10 @@ public override void Fit(Tensor data) Szz = _Czz; // Update variance - _sigma2 = ((Sxx - 2 * trace(Components.T.matmul(Sxz)) + trace(Components.T.matmul(Components).matmul(Szz))) / (double)d) + Variance = ((Sxx - 2 * trace(Components.T.matmul(Sxz)) + trace(Components.T.matmul(Components).matmul(Szz))) / (double)d) .clamp_min(0.0) - .MoveToOuterDisposeScope(); + .to_type(TorchSharp.torch.ScalarType.Float64) + .item(); } } @@ -266,7 +265,7 @@ public override Tensor Transform(Tensor data) var Xt = data.T; // n x d var Xc = Xt - _mu; // n x d - var M = Components.T.matmul(Components) + _Iq * _sigma2; // q x q + var M = Components.T.matmul(Components) + _Iq * Variance; // q x q var XcW = 
Xc.matmul(Components); return Utils.InvertSPD(M, XcW.T).T; // n x q } @@ -286,7 +285,7 @@ public override Tensor Reconstruct(Tensor data) var Xt = data.T; // n x d var Xc = Xt - _mu; // n x d - var M = Components.T.matmul(Components) + _Iq * _sigma2; // q x q + var M = Components.T.matmul(Components) + _Iq * Variance; // q x q var XcW = Xc.matmul(Components); var EzT = Utils.InvertSPD(M, XcW.T); var Ez = EzT.T; diff --git a/src/Bonsai.ML.Pca.Torch/Pca.cs b/src/Bonsai.ML.Pca.Torch/Pca.cs index d69e598e..5ee84e95 100644 --- a/src/Bonsai.ML.Pca.Torch/Pca.cs +++ b/src/Bonsai.ML.Pca.Torch/Pca.cs @@ -5,98 +5,70 @@ namespace Bonsai.ML.Pca.Torch; /// -/// Implements a standard Principal Component Analysis (PCA) model. +/// Represents a standard Principal Component Analysis (PCA) model. /// public class Pca(int numComponents, Device? device = null, - ScalarType? scalarType = ScalarType.Float32 + ScalarType? scalarType = null ) : PcaBaseModel(numComponents, device, scalarType) { - private bool _isFitted = false; - /// - /// Gets the covariance matrix of the fitted data. + /// Gets the mean of the fitted data. /// - public Tensor Covariance { get; private set; } = empty(0); + public Tensor Mean { get; private set; } = empty(0); - /// - /// Gets the eigenvalues of the covariance matrix. - /// - public Tensor EigenValues { get; private set; } = empty(0); + /// + public override Tensor Components { get; protected set; } = empty(0); /// - /// Gets the eigenvectors of the covariance matrix. + /// The singular values of the fitted data. 
/// - public Tensor EigenVectors { get; private set; } = empty(0); - - /// - public override Tensor Components { get; protected set; } = empty(0); + public Tensor SingularValues { get; private set; } = empty(0); /// public override void Fit(Tensor data) { - if (data.NumberOfElements == 0 || data.dim() < 2) - { - throw new ArgumentException("Data must be a non-empty 2D tensor with shape (samples x features).", nameof(data)); - } - - var d = data.size(1); + base.Fit(data); - if (NumComponents > d) + using (no_grad()) + using (NewDisposeScope()) { - throw new ArgumentException("Number of components cannot be greater than the number of features.", nameof(data)); + var mean = data.mean([0], keepdim: true); + var dataCentered = data - mean; + var (U, S, Vh) = svd(dataCentered, fullMatrices: false); + var components = Vh.slice(0, 0, NumComponents, 1).T; + var singularValues = S.slice(0, 0, NumComponents, 1); + + if (ScalarType is not null && ScalarType != data.dtype) + { + var scalarType = ScalarType.Value; + mean = mean.to(scalarType); + components = components.to(scalarType); + singularValues = singularValues.to(scalarType); + } + + Mean = mean.MoveToOuterDisposeScope(); + Components = components.MoveToOuterDisposeScope(); + SingularValues = singularValues.MoveToOuterDisposeScope(); } - var Xt = data.T; - var mean = Xt.mean([0], keepdim: true); - var Xc = Xt - mean; - Covariance = cov(Xc.T); - var eigen = eigh(Covariance); - var sortedIndices = argsort(eigen.Item1, dim: -1, descending: true); - EigenValues = eigen.Item1[sortedIndices]; - EigenVectors = eigen.Item2.index_select(1, sortedIndices); - Components = EigenVectors.slice(1, 0, NumComponents, 1); - _isFitted = true; + IsFitted = true; } /// public override Tensor Transform(Tensor data) { - if (data.NumberOfElements == 0 || data.dim() < 2) - { - throw new ArgumentException("Data must be a non-empty 2D tensor with shape (samples x features).", nameof(data)); - } - - if (!_isFitted) - { - throw new 
InvalidOperationException("Model has not yet been fitted. You should call the Fit() or the FitAndTransform() methods first."); - } - - var X = data.T; - var mean = X.mean([0], keepdim: true); // 1 x d - var Xc = X - mean; - return Xc.matmul(Components); // n x q + base.Transform(data); + var dataCentered = data - Mean; + return dataCentered.matmul(Components); } /// public override Tensor Reconstruct(Tensor data) { - if (data.NumberOfElements == 0 || data.dim() < 2) - { - throw new ArgumentException("Data must be a non-empty 2D tensor with shape (samples x features).", nameof(data)); - } - - if (!_isFitted) - { - throw new InvalidOperationException("Model has not yet been fitted. You should call the Fit() or the FitAndTransform() methods first."); - } - - var X = data.T; - var mean = X.mean([0], keepdim: true); // 1 x d - var Xc = X - mean; - var reconstructed = Transform(Xc); - return reconstructed.matmul(Components.T) + mean.T; + base.Reconstruct(data); + return data.matmul(Components.T) + Mean; } } diff --git a/src/Bonsai.ML.Pca.Torch/PcaBaseModel.cs b/src/Bonsai.ML.Pca.Torch/PcaBaseModel.cs index dd041abd..8e7c3ee3 100644 --- a/src/Bonsai.ML.Pca.Torch/PcaBaseModel.cs +++ b/src/Bonsai.ML.Pca.Torch/PcaBaseModel.cs @@ -8,6 +8,12 @@ namespace Bonsai.ML.Pca.Torch; /// public abstract class PcaBaseModel : IPcaBaseModel { + /// + public bool IsFitted { get; protected set; } + + /// + public int NumFeatures { get; protected set; } = -1; + /// public abstract Tensor Components { get; protected set; } @@ -18,7 +24,7 @@ public abstract class PcaBaseModel : IPcaBaseModel public Device Device { get; private set; } /// - public ScalarType ScalarType { get; private set; } + public ScalarType? ScalarType { get; private set; } /// /// Initializes a new instance of the class. @@ -38,14 +44,34 @@ public PcaBaseModel(int numComponents, NumComponents = numComponents; Device = device ?? CPU; - ScalarType = scalarType ?? 
ScalarType.Float32; + ScalarType = scalarType; } /// - public abstract void Fit(Tensor data); + public virtual void Fit(Tensor data) + { + CheckDataCompatibility(data); + + var n = data.size(0); + var d = data.size(1); + + if (NumComponents > d) + throw new ArgumentException($"Number of components cannot be greater than the number of features. Number of components: {NumComponents}, number of features: {d}.", nameof(data)); + + if (n < 2) + throw new ArgumentException($"Need at least 2 samples to fit PCA. Number of samples: {n}.", nameof(data)); + + NumFeatures = (int)d; + } /// - public abstract Tensor Transform(Tensor data); + public virtual Tensor Transform(Tensor data) + { + CheckFitted(); + CheckDataCompatibility(data); + CheckDataFeatures(data); + return data; + } /// public virtual Tensor FitAndTransform(Tensor data) @@ -55,5 +81,34 @@ public virtual Tensor FitAndTransform(Tensor data) } /// - public abstract Tensor Reconstruct(Tensor data); + public virtual Tensor Reconstruct(Tensor data) + { + CheckFitted(); + CheckDataCompatibility(data); + CheckDataFeatures(data); + return data; + } + + private void CheckFitted() + { + if (!IsFitted) + throw new InvalidOperationException("Model has not yet been fitted. You should call one of the Fit() or the FitAndTransform() methods first."); + } + + private void CheckDataCompatibility(Tensor data) + { + if (data.NumberOfElements == 0) + throw new ArgumentException("Data must be a non-empty 2D tensor with shape (samples x features).", nameof(data)); + + if (data.dim() != 2) + throw new ArgumentException($"Data must be a 2D tensor with shape (samples x features). 
Data shape: {data.shape}.", nameof(data)); + } + + private void CheckDataFeatures(Tensor data) + { + var d = data.size(1); + + if (d != NumFeatures) + throw new ArgumentException("The number of features in the data does not match the number of features in the fitted model.", nameof(data)); + } } diff --git a/src/Bonsai.ML.Pca.Torch/ProbabalisticPca.cs b/src/Bonsai.ML.Pca.Torch/ProbabalisticPca.cs index 452a63a8..cfddc321 100644 --- a/src/Bonsai.ML.Pca.Torch/ProbabalisticPca.cs +++ b/src/Bonsai.ML.Pca.Torch/ProbabalisticPca.cs @@ -1,14 +1,23 @@ using System; +using Bonsai.ML.Torch; using static TorchSharp.torch; using static TorchSharp.torch.linalg; namespace Bonsai.ML.Pca.Torch; /// -/// Probabilistic Principal Component Analysis (PPCA) model. +/// Represents a probabilistic PCA model. /// public class ProbabilisticPca : PcaBaseModel { + private readonly int _iterations; + private readonly double _tolerance; + + /// + /// Gets the mean of the fitted data. + /// + public Tensor Mean { get; private set; } = empty(0); + /// /// Gets the variance of the isotropic Gaussian noise model. /// @@ -25,11 +34,7 @@ public class ProbabilisticPca : PcaBaseModel /// /// Gets the random number generator used for initializing the model. /// - public Generator Generator { get; private set; } - - private readonly int _iterations; - private readonly double _tolerance; - private bool _isFitted = false; + public Generator? Generator { get; private set; } /// /// Initializes a new instance of the class. @@ -44,7 +49,7 @@ public class ProbabilisticPca : PcaBaseModel /// public ProbabilisticPca(int numComponents, Device? device = null, - ScalarType? scalarType = ScalarType.Float32, + ScalarType? scalarType = null, double initialVariance = 1.0, Generator? generator = null, int iterations = 100, @@ -69,166 +74,109 @@ public ProbabilisticPca(int numComponents, } Variance = initialVariance; - Generator = generator ?? 
manual_seed(0); + Generator = generator; _iterations = iterations; _tolerance = tolerance; } - /// - /// Fits the PPCA model to the input data. - /// - /// - /// + /// public override void Fit(Tensor data) { - if (data.NumberOfElements == 0 || data.dim() != 2) - { - throw new ArgumentException("Data must be a non-empty 2D tensor with shape (samples x features).", nameof(data)); - } - - // Initialize variance - var variance = Variance; - - // Initialize log likelihood - LogLikelihood = ones(_iterations, device: Device, dtype: ScalarType) * double.NegativeInfinity; - - // Initialize dimensions for components - var q = NumComponents; - var n = data.size(0); - var d = data.size(1); + base.Fit(data); - if (q > d) + using (no_grad()) + using (NewDisposeScope()) { - throw new ArgumentException("Number of components cannot be greater than the number of features.", nameof(data)); - } + var numSamples = data.size(0); - // Initialize W and I - var W = randn(d, q, generator: Generator, device: Device, dtype: ScalarType); // d x q - var Iq = eye(q, device: Device, dtype: ScalarType); // q x q - var Id = eye(d, device: Device, dtype: ScalarType); // d x d + // Initialize log likelihood + LogLikelihood = ones(_iterations, device: Device, dtype: ScalarType) * double.NegativeInfinity; - // Calculate the sample mean - var mean = data.mean([0], keepdim: true); // 1 x d + var weights = randn(NumFeatures, NumComponents, generator: Generator, device: Device, dtype: ScalarType); + var identityComponents = eye(NumComponents, device: Device, dtype: ScalarType); + var identityFeatures = eye(NumFeatures, device: Device, dtype: ScalarType); - // Center the data and transpose - var dataCentered = data - mean; // n x d + var mean = data.mean([0], keepdim: true); + var dataCentered = data - mean; - // Calculate the sample covariance - var XTX = dataCentered.T.matmul(dataCentered); // d x d - var sampleCov = XTX / n; // d x d + // Calculate the sample covariance + var covarianceTerm = 
dataCentered.T.matmul(dataCentered); + var sampleCov = covarianceTerm / numSamples; - // Calculate term 1 for variance update - var term1 = trace(XTX); + // Calculate term 1 for variance update + var term1 = trace(covarianceTerm); - // Compute log likelihood constant - var logLikelihoodConst = d * log(2 * Math.PI).to(Device).to_type(ScalarType); + // Compute log likelihood constant + var logLikelihoodConst = NumFeatures * log(2 * Math.PI).to(Device); - double diffW; - double diffVariance; + double diffWeights; + double diffVariance; - // Repeat until convergence - for (int i = 0; i < _iterations; i++) - { - using (NewDisposeScope()) + // Repeat until convergence + for (int i = 0; i < _iterations; i++) { // E-step: Compute the posterior distribution of the latent variables - var M = W.T.matmul(W) + Iq * variance; // q x q - var MInv = inv(M); // q x q - var mu = MInv.matmul(W.T).matmul(dataCentered.T).T; // n x q - var SSum = n * MInv * variance; // q x q - var cov = mu.T.matmul(mu) + SSum; // q x q - - // M-step: Compute new W and new variance - var XMu = dataCentered.T.matmul(mu); // d x q - var WNew = XMu.matmul(inv(cov)); // d x q - - var term2 = 2 * XMu.mul(WNew).sum(); - var mumu = mu.T.matmul(mu); - var WNewWNew = WNew.T.matmul(WNew); - var term3 = trace(WNewWNew.matmul(mumu + SSum)); - var varianceNew = (term1 - term2 + term3) / (n * d); // scalar + var M = weights.T.matmul(weights) + identityComponents * Variance; + var MInv = inv(M); + var mu = MInv.matmul(weights.T).matmul(dataCentered.T).T; + var SSum = numSamples * MInv * Variance; + var cov = mu.T.matmul(mu) + SSum; + + // M-step: Compute new weights and new variance + var dataMu = dataCentered.T.matmul(mu); + var weightsNew = dataMu.matmul(inv(cov)); + + var term2 = 2 * dataMu.mul(weightsNew).sum(); + var mu2 = mu.T.matmul(mu); + var weightsNew2 = weightsNew.T.matmul(weightsNew); + var term3 = trace(weightsNew2.matmul(mu2 + SSum)); + var varianceNew = (term1 - term2 + term3) / (numSamples * 
NumFeatures); // Compute the log likelihood - var C = W.matmul(W.T) + Id * variance; // d x d - var CInv = inv(C); // d x d - var logLikelihood = -0.5 * n * (logLikelihoodConst + logdet(C) + trace(CInv.matmul(sampleCov))); // scalar + var logLikelihoodTerm = weightsNew.matmul(weightsNew.T) + identityFeatures * varianceNew; // identity shares Device/ScalarType with the weights + var logLikelihoodTermInv = inv(logLikelihoodTerm); + var logLikelihood = -0.5 * numSamples * (logLikelihoodConst + logdet(logLikelihoodTerm) + trace(logLikelihoodTermInv.matmul(sampleCov))); - // Check for convergence - diffW = linalg.norm(WNew - W).to_type(ScalarType.Float64).cpu().ReadCpuDouble(0); - diffVariance = abs(varianceNew - variance).to_type(ScalarType.Float64).cpu().ReadCpuDouble(0); - - // Update loglikelihood, W and variance - LogLikelihood[i] = logLikelihood.MoveToOuterDisposeScope(); - W = WNew.MoveToOuterDisposeScope(); - variance = varianceNew.to_type(ScalarType.Float64).cpu().ReadCpuDouble(0); - } + // Compare previous and new parameters for convergence + diffWeights = linalg.norm(weightsNew - weights).to_type(TorchSharp.torch.ScalarType.Float64).item(); + diffVariance = abs(varianceNew - Variance).to_type(TorchSharp.torch.ScalarType.Float64).item(); + // Update loglikelihood, weights and variance + LogLikelihood[i] = logLikelihood; + weights = weightsNew; + Variance = varianceNew.to_type(TorchSharp.torch.ScalarType.Float64).item(); - if (diffW < _tolerance && diffVariance < _tolerance) - { - LogLikelihood = LogLikelihood.slice(0, 0, i + 1, 1); - break; + // Check for convergence + if (diffWeights < _tolerance && diffVariance < _tolerance) + { + LogLikelihood = LogLikelihood.slice(0, 0, i + 1, 1); + break; + } } + + // Finalize model parameters + LogLikelihood = LogLikelihood.MoveToOuterDisposeScope(); + Components = weights.MoveToOuterDisposeScope(); + Mean = mean.MoveToOuterDisposeScope(); } - // Finalize model parameters - LogLikelihood = LogLikelihood.DetachFromDisposeScope(); - Components = W.DetachFromDisposeScope(); -
Variance = variance; - _isFitted = true; + IsFitted = true; } - /// - /// Transforms the input data using the fitted PPCA model. - /// - /// - /// - /// - /// + /// public override Tensor Transform(Tensor data) { - if (data.NumberOfElements == 0 || data.dim() < 2) - { - throw new ArgumentException("Data must be a non-empty 2D tensor with shape (samples x features).", nameof(data)); - } - - if (!_isFitted) - { - throw new InvalidOperationException("Model has not yet been fitted. You should call the Fit() or the FitAndTransform() methods first."); - } - - var Xt = data.T; - var mean = Xt.mean([0], keepdim: true); // 1 x d - var X = Xt - mean; // n x d - var W = Components; // d x q - var M = W.T.matmul(W) + eye(NumComponents) * Variance; // q x q - var MInv = Utils.InvertSPD(M, eye(NumComponents)); // q x q - return X.matmul(W).matmul(MInv); // n x q + base.Transform(data); + var dataCentered = data - Mean; + var M = Components.T.matmul(Components) + eye(NumComponents, device: Device, dtype: ScalarType) * Variance; // identity must share device/dtype with Components + var MInv = Utils.InvertSPD(M, eye(NumComponents, device: Device, dtype: ScalarType)); + return dataCentered.matmul(Components).matmul(MInv); } - /// - /// Reconstructs the input data using the fitted PPCA model. - /// - /// - /// - /// - /// + /// public override Tensor Reconstruct(Tensor data) { - if (data.NumberOfElements == 0 || data.dim() < 2) - { - throw new ArgumentException("Data must be a non-empty 2D tensor with shape (samples x features).", nameof(data)); - } - - if (!_isFitted) - { - throw new InvalidOperationException("Model has not yet been fitted.
You should call the Fit() or the FitAndTransform() methods first."); - } - - var Xt = data.T; - var mean = Xt.mean([0], keepdim: true); // 1 x d - var Xc = Xt - mean; // n x d - var W = Components; // d x q - return Xc.matmul(W).matmul(W.T) + mean.T; // n x d + base.Reconstruct(data); + return data.matmul(Components.T) + Mean; } } From 16b8f057520c72e5c57972754265da4482e9d515 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Thu, 15 Jan 2026 13:55:33 +0000 Subject: [PATCH 16/20] Updated online PCA models --- src/Bonsai.ML.Pca.Torch/OnlinePcaGha.cs | 93 ++++++------- .../OnlineProbabilisticPca.cs | 125 +++++++----------- src/Bonsai.ML.Pca.Torch/PcaBaseModel.cs | 10 +- 3 files changed, 103 insertions(+), 125 deletions(-) diff --git a/src/Bonsai.ML.Pca.Torch/OnlinePcaGha.cs b/src/Bonsai.ML.Pca.Torch/OnlinePcaGha.cs index 79c92447..6c73813a 100644 --- a/src/Bonsai.ML.Pca.Torch/OnlinePcaGha.cs +++ b/src/Bonsai.ML.Pca.Torch/OnlinePcaGha.cs @@ -1,10 +1,11 @@ using System; +using System.ComponentModel; using static TorchSharp.torch; namespace Bonsai.ML.Pca.Torch; /// -/// Implements Online PCA using the Generalized Hebbian Algorithm (GHA). +/// Implements online PCA using the Generalized Hebbian Algorithm (GHA). /// /// /// @@ -19,12 +20,18 @@ public class OnlinePcaGha( Generator? generator = null ) : PcaBaseModel(numComponents, device, scalarType) { + /// + /// Gets the number of samples that have been used to fit the model. + /// + public int SampleCount { get; private set; } = 0; - private Tensor _mean = empty(0); - private int _sampleCount = 0; + /// + /// Gets the mean of the fitted data. + /// + public Tensor Mean { get; private set; } = empty(0); /// - /// Gets or sets the learning rate for the GHA algorithm. + /// Gets or sets the learning rate. /// public double LearningRate { get; set; } = learningRate; @@ -34,65 +41,61 @@ public class OnlinePcaGha( /// /// Gets the random number generator used for initializing the model. 
/// - public Generator Generator { get; private set; } = generator ?? manual_seed(0); + public Generator? Generator { get; private set; } = generator; /// public override void Fit(Tensor data) { - // throw new NotImplementedException(); - if (data.NumberOfElements == 0 || data.dim() != 2) - { - throw new ArgumentException("Input data must be a 2D tensor."); - } - - var q = NumComponents; + base.Fit(data); - // Data is shaped (number of samples x number of features) - var n = data.shape[0]; - var d = data.shape[1]; + var numSamples = data.size(0); - if (q > d) + using (no_grad()) + using (NewDisposeScope()) { - throw new ArgumentException("Number of components cannot be greater than the number of features.", nameof(data)); + // Initialize components randomly + if (Components.numel() == 0) + Components = randn([NumFeatures, NumComponents], dtype: ScalarType, device: Device, generator: Generator); + + if (Mean.numel() == 0) + Mean = data.mean([0], keepdim: true); + else + { + Mean *= SampleCount / (double)(SampleCount + numSamples); // floating-point weight: integer division would truncate to 0 and discard the running mean + Mean += data.mean([0], keepdim: true) * (numSamples / (double)(SampleCount + numSamples)); + } + + SampleCount += (int)numSamples; + var dataCentered = data - Mean; + + var projection = dataCentered.matmul(Components); + var hebbianTerm = dataCentered.T.matmul(projection); + var crossTerm = projection.T.matmul(projection); + var lowerTriangular = crossTerm.tril(0); + var correlation = Components.matmul(lowerTriangular.T); // Sanger's rule for (features x components) weights: W * LT^T, so component j is deflated only by components k <= j + var componentsUpdate = (hebbianTerm - correlation) * (LearningRate / numSamples); + var weights = Components + componentsUpdate; + var norms = weights.norm(dim: 0, keepdim: true, p: 2).clamp_min(1e-12); + + Components = linalg.qr(weights / norms, mode: linalg.QRMode.Reduced).Q.MoveToOuterDisposeScope(); + Mean = Mean.MoveToOuterDisposeScope(); } - // Initialize components randomly - if (Components.numel() == 0) - Components = linalg.qr(randn([d, q], ScalarType, Device), mode: linalg.QRMode.Reduced).Q; - - if (_mean.numel() == 0) - _mean
= data.mean([0], keepdim: true); - else - _mean.mul_(_sampleCount / (double)(_sampleCount + n)).add_(data.mean([0], keepdim: true), alpha: n / (double)(_sampleCount + n)); - - _sampleCount += (int)n; - var dataCentered = data - _mean; - - var Y = dataCentered.matmul(Components); // n x q - var hebbianTerm = dataCentered.T.matmul(Y); // d x q - var crossTerm = Y.T.matmul(Y); // q x q - var lowerTriangular = crossTerm.tril(0); // q x q - var correlation = lowerTriangular.matmul(Components.T); // q x d - - // Update components - Components.add_(hebbianTerm - correlation.T, alpha: LearningRate); + IsFitted = true; } /// public override Tensor Transform(Tensor data) { - if (data.NumberOfElements == 0 || data.dim() != 2) - { - throw new ArgumentException("Input data must be a 2D tensor."); - } - var dataCentered = data - _mean; - return dataCentered.matmul(Components.T); + base.Transform(data); + var dataCentered = data - Mean; + return dataCentered.matmul(Components); } /// public override Tensor Reconstruct(Tensor data) { - var transformed = Transform(data); - return transformed.matmul(Components.T) + _mean; + base.Reconstruct(data); + return data.matmul(Components.T) + Mean; } } diff --git a/src/Bonsai.ML.Pca.Torch/OnlineProbabilisticPca.cs b/src/Bonsai.ML.Pca.Torch/OnlineProbabilisticPca.cs index ee40d716..15a3afd8 100644 --- a/src/Bonsai.ML.Pca.Torch/OnlineProbabilisticPca.cs +++ b/src/Bonsai.ML.Pca.Torch/OnlineProbabilisticPca.cs @@ -9,14 +9,12 @@ namespace Bonsai.ML.Pca.Torch; /// public class OnlineProbabilisticPca : PcaBaseModel { - private Tensor _mu = empty(0); - private Tensor _Iq = empty(0); + private Tensor _identityComponents = empty(0); private Tensor _mx = empty(0); // E[x] private Tensor _Cxz = empty(0); // E[xz^T] private Tensor _mz = empty(0); // E[z] private Tensor _Czz = empty(0); // E[zz^T] private Tensor _sxx = empty(0); // E[||x||^2] - private bool _initializedParameters = false; private readonly Func UpdateSchedule; private int _stepCount = 0; 
private readonly bool _reorthogonalize = false; @@ -38,6 +36,11 @@ public class OnlineProbabilisticPca : PcaBaseModel /// public double? Kappa { get; private set; } + /// + /// Gets the mean of the fitted data. + /// + public Tensor Means { get; private set; } = empty(0); + /// /// Gets the variance of the isotropic Gaussian noise model. /// @@ -63,7 +66,7 @@ public class OnlineProbabilisticPca : PcaBaseModel /// /// Gets the random number generator used for initializing the model. /// - public Generator Generator { get; private set; } + public Generator? Generator { get; private set; } /// /// Initializes a new instance of the class. @@ -108,11 +111,12 @@ public OnlineProbabilisticPca(int numComponents, if (rho.HasValue) { - UpdateSchedule = () => rho.Value; - if (rho <= 0 || rho >= 1) + if (rho.Value <= 0 || rho.Value >= 1) { throw new ArgumentException("Rho must be in the range (0, 1).", nameof(rho)); } + + UpdateSchedule = () => rho.Value; } else { @@ -126,11 +130,12 @@ public OnlineProbabilisticPca(int numComponents, throw new ArgumentException("Kappa must be specified when using a learning rate schedule.", nameof(kappa)); } - UpdateSchedule = () => Math.Pow(_stepCount + timeOffset.Value, -kappa.Value); if (kappa <= 0.5 || kappa > 1) { throw new ArgumentException("Kappa must be in the range (0.5, 1].", nameof(kappa)); } + + UpdateSchedule = () => Math.Pow(_stepCount + timeOffset.Value, -kappa.Value); } if (reorthogonalizePeriod.HasValue) @@ -139,7 +144,7 @@ public OnlineProbabilisticPca(int numComponents, ReorthogonalizePeriod = reorthogonalizePeriod.Value; } - Generator = generator ?? 
manual_seed(0); + Generator = generator; Rho = rho; Kappa = kappa; TimeOffset = timeOffset; @@ -149,11 +154,7 @@ public OnlineProbabilisticPca(int numComponents, /// public override void Fit(Tensor data) { - // throw new NotImplementedException(); - if (data.NumberOfElements == 0 || data.dim() != 2) - { - throw new ArgumentException("Input data must be a 2D tensor."); - } + base.Fit(data); using (no_grad()) using (NewDisposeScope()) @@ -163,48 +164,41 @@ public override void Fit(Tensor data) var rho = UpdateSchedule(); // Initialize dimensions - var q = NumComponents; - var n = data.size(0); - var d = data.size(1); + var numSamples = data.size(0); // Initialize parameters - if (!_initializedParameters) + if (Means.numel() == 0) { - _mu = zeros(d, device: Device, dtype: ScalarType).MoveToOuterDisposeScope(); - var randW = randn(d, q, generator: Generator, device: Device, dtype: ScalarType); - var orthonormalBases = linalg.qr(randW).Q; - Components = (orthonormalBases * Variance).MoveToOuterDisposeScope(); // d x q - _Iq = eye(q, device: Device, dtype: ScalarType).MoveToOuterDisposeScope(); // q x q - - _mx = zeros(d, device: Device, dtype: ScalarType).MoveToOuterDisposeScope(); // d - _Cxz = zeros(d, q, device: Device, dtype: ScalarType).MoveToOuterDisposeScope(); // d x q - _mz = zeros(q, device: Device, dtype: ScalarType).MoveToOuterDisposeScope(); // q - _Czz = zeros(q, q, device: Device, dtype: ScalarType).MoveToOuterDisposeScope(); // q x q + Means = zeros(NumFeatures, device: Device, dtype: ScalarType).MoveToOuterDisposeScope(); + var weights = qr(randn(NumFeatures, NumComponents, generator: Generator, device: Device, dtype: ScalarType), mode: QRMode.Reduced).Q; + Components = (weights * Variance).MoveToOuterDisposeScope(); + _identityComponents = eye(NumComponents, device: Device, dtype: ScalarType).MoveToOuterDisposeScope(); + _mx = zeros(NumFeatures, device: Device, dtype: ScalarType).MoveToOuterDisposeScope(); // d + _Cxz = zeros(NumFeatures, 
NumComponents, device: Device, dtype: ScalarType).MoveToOuterDisposeScope(); // d x q + _mz = zeros(NumComponents, device: Device, dtype: ScalarType).MoveToOuterDisposeScope(); // q + _Czz = zeros(NumComponents, NumComponents, device: Device, dtype: ScalarType).MoveToOuterDisposeScope(); // q x q _sxx = zeros(1, device: Device, dtype: ScalarType).MoveToOuterDisposeScope(); // scalar - - _initializedParameters = true; } // Covariance matrix - var cov = _Iq * Variance; + var cov = _identityComponents * Variance; // Center data using current mean - var Xc = data - _mu; + var dataCentered = data - Means; // E-step var M = Components.T.matmul(Components) + cov; - var MInv = Utils.InvertSPD(M, _Iq); - - var XcW = Xc.matmul(Components); - var EzT = Utils.InvertSPD(M, XcW.T); + var MInv = Utils.InvertSPD(M, _identityComponents); + var projection = dataCentered.matmul(Components); + var EzT = Utils.InvertSPD(M, projection.T); var Ez = EzT.T; // Update statistics var mx = data.mean([0]); var sxx = data.pow(2).sum(dim: 1).mean(); - var Cxz = data.T.matmul(Ez) / n; + var Cxz = data.T.matmul(Ez) / numSamples; var mz = Ez.mean([0]); - var Czz = EzT.matmul(Ez) / n + Variance * MInv; + var Czz = EzT.matmul(Ez) / numSamples + Variance * MInv; // Update parameters var rhoFactor = 1 - rho; @@ -215,81 +209,62 @@ public override void Fit(Tensor data) _Czz = (rhoFactor * _Czz + rho * Czz).MoveToOuterDisposeScope(); // Update mean - _mu = _mx.MoveToOuterDisposeScope(); + Means = _mx.MoveToOuterDisposeScope(); // Centered statistics - var Sxz = _Cxz - outer(_mu, _mz); + var Sxz = _Cxz - outer(Means, _mz); var Szz = _Czz; - var Sxx = _sxx - _mu.dot(_mu); + var Sxx = _sxx - Means.dot(Means); // M-step - var WNew = Utils.InvertSPD(Szz, Sxz.T).T; + var weightsUpdated = Utils.InvertSPD(Szz, Sxz.T).T; if (_reorthogonalize && _stepCount % ReorthogonalizePeriod == 0) { - var (U, S, Vh) = svd(WNew, fullMatrices: false); + var (U, S, Vh) = svd(weightsUpdated, fullMatrices: false); var R = Vh.T; - 
WNew = U.matmul(diag(S)); + weightsUpdated = U.matmul(diag(S)); _Cxz = _Cxz.matmul(R.T); _Czz = R.matmul(_Czz).matmul(R.T); _mz = R.matmul(_mz); } // Reorder components based on the strength of the components - var strength = sum(WNew * WNew, dim: 0); + var strength = sum(weightsUpdated * weightsUpdated, dim: 0); var indices = argsort(strength, descending: true); - Components = WNew.index_select(1, indices).MoveToOuterDisposeScope(); + Components = weightsUpdated.index_select(1, indices).MoveToOuterDisposeScope(); _Cxz = _Cxz.index_select(1, indices).MoveToOuterDisposeScope(); _mz = _mz.index_select(0, indices).MoveToOuterDisposeScope(); _Czz = _Czz.index_select(0, indices).index_select(1, indices).MoveToOuterDisposeScope(); - Sxz = _Cxz - outer(_mu, _mz); + Sxz = _Cxz - outer(Means, _mz); Szz = _Czz; // Update variance - Variance = ((Sxx - 2 * trace(Components.T.matmul(Sxz)) + trace(Components.T.matmul(Components).matmul(Szz))) / (double)d) + Variance = ((Sxx - 2 * trace(Components.T.matmul(Sxz)) + trace(Components.T.matmul(Components).matmul(Szz))) / (double)NumFeatures) .clamp_min(0.0) .to_type(TorchSharp.torch.ScalarType.Float64) .item(); } + + IsFitted = true; } /// public override Tensor Transform(Tensor data) { - if (data.NumberOfElements == 0 || data.dim() < 2) - { - throw new ArgumentException("Data must be a non-empty 2D tensor with shape (samples x features).", nameof(data)); - } - - var Xt = data.T; // n x d - var Xc = Xt - _mu; // n x d - var M = Components.T.matmul(Components) + _Iq * Variance; // q x q - var XcW = Xc.matmul(Components); - return Utils.InvertSPD(M, XcW.T).T; // n x q + base.Transform(data); + var dataCentered = data - Means; + var M = Components.T.matmul(Components) + _identityComponents * Variance; + var projection = dataCentered.matmul(Components); + return Utils.InvertSPD(M, projection.T).T; } /// public override Tensor Reconstruct(Tensor data) { - if (data.NumberOfElements == 0 || data.dim() < 2) - { - throw new 
ArgumentException("Data must be a non-empty 2D tensor with shape (samples x features).", nameof(data)); - } - - if (!_initializedParameters) - { - throw new InvalidOperationException("Model has not yet been fitted. You should call the Fit() or the FitAndTransform() methods first."); - } - - var Xt = data.T; // n x d - var Xc = Xt - _mu; // n x d - var M = Components.T.matmul(Components) + _Iq * Variance; // q x q - var XcW = Xc.matmul(Components); - var EzT = Utils.InvertSPD(M, XcW.T); - var Ez = EzT.T; - - return Ez.matmul(Components.T) + _mu.T; // n x d + base.Reconstruct(data); + return data.matmul(Components.T) + Means; } } diff --git a/src/Bonsai.ML.Pca.Torch/PcaBaseModel.cs b/src/Bonsai.ML.Pca.Torch/PcaBaseModel.cs index 8e7c3ee3..99ba852a 100644 --- a/src/Bonsai.ML.Pca.Torch/PcaBaseModel.cs +++ b/src/Bonsai.ML.Pca.Torch/PcaBaseModel.cs @@ -1,4 +1,5 @@ using System; +using System.Linq; using static TorchSharp.torch; namespace Bonsai.ML.Pca.Torch; @@ -52,15 +53,11 @@ public virtual void Fit(Tensor data) { CheckDataCompatibility(data); - var n = data.size(0); var d = data.size(1); if (NumComponents > d) throw new ArgumentException($"Number of components cannot be greater than the number of features. Number of components: {NumComponents}, number of features: {d}.", nameof(data)); - if (n < 2) - throw new ArgumentException($"Need at least 2 samples to fit PCA. Number of samples: {n}.", nameof(data)); - NumFeatures = (int)d; } @@ -101,7 +98,10 @@ private void CheckDataCompatibility(Tensor data) throw new ArgumentException("Data must be a non-empty 2D tensor with shape (samples x features).", nameof(data)); if (data.dim() != 2) - throw new ArgumentException($"Data must be a 2D tensor with shape (samples x features). Data shape: {data.shape}.", nameof(data)); + { + var shapeStr = string.Join(",", data.shape.Select(x => x.ToString()).ToArray()); + throw new ArgumentException($"Data must be a 2D tensor with shape (samples x features). 
Data shape: {shapeStr}.", nameof(data)); + } } private void CheckDataFeatures(Tensor data) From 5621afd182a26817a26e050feacce1a19728d873 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Mon, 26 Jan 2026 18:42:23 +0000 Subject: [PATCH 17/20] Updated to support "hidding" model property if provided by the input --- src/Bonsai.ML.Pca.Torch/CreatePca.cs | 4 +- src/Bonsai.ML.Pca.Torch/Fit.cs | 19 ++-- src/Bonsai.ML.Pca.Torch/FitAndTransform.cs | 13 +-- .../FitAndTransformBuilder.cs | 12 +++ src/Bonsai.ML.Pca.Torch/FitBuilder.cs | 12 +++ src/Bonsai.ML.Pca.Torch/IPcaModelProvider.cs | 12 +++ src/Bonsai.ML.Pca.Torch/OnlinePcaGha.cs | 6 +- src/Bonsai.ML.Pca.Torch/Pca.cs | 3 +- src/Bonsai.ML.Pca.Torch/PcaDescriptor.cs | 5 +- src/Bonsai.ML.Pca.Torch/PcaModelBuilder.cs | 94 +++++++++++++++++++ src/Bonsai.ML.Pca.Torch/ProbabalisticPca.cs | 1 - src/Bonsai.ML.Pca.Torch/Reconstruct.cs | 20 +--- src/Bonsai.ML.Pca.Torch/ReconstructBuilder.cs | 12 +++ src/Bonsai.ML.Pca.Torch/Transform.cs | 24 +---- src/Bonsai.ML.Pca.Torch/TransformBuilder.cs | 12 +++ src/Bonsai.ML.Pca.Torch/Utils.cs | 5 +- 16 files changed, 181 insertions(+), 73 deletions(-) create mode 100644 src/Bonsai.ML.Pca.Torch/FitAndTransformBuilder.cs create mode 100644 src/Bonsai.ML.Pca.Torch/FitBuilder.cs create mode 100644 src/Bonsai.ML.Pca.Torch/IPcaModelProvider.cs create mode 100644 src/Bonsai.ML.Pca.Torch/PcaModelBuilder.cs create mode 100644 src/Bonsai.ML.Pca.Torch/ReconstructBuilder.cs create mode 100644 src/Bonsai.ML.Pca.Torch/TransformBuilder.cs diff --git a/src/Bonsai.ML.Pca.Torch/CreatePca.cs b/src/Bonsai.ML.Pca.Torch/CreatePca.cs index 2dbaebc7..e77bf986 100644 --- a/src/Bonsai.ML.Pca.Torch/CreatePca.cs +++ b/src/Bonsai.ML.Pca.Torch/CreatePca.cs @@ -1,4 +1,4 @@ -using System; +using System; using System.ComponentModel; using System.Collections.Generic; using System.Reactive.Linq; @@ -21,7 +21,7 @@ namespace Bonsai.ML.Pca.Torch; public class CreatePca : ZeroArgumentExpressionBuilder, INamedElement { /// - public 
string Name => ModelType.ToString(); + public string Name => $"CreatePca.{ModelType}"; /// /// The number of principal components to compute. diff --git a/src/Bonsai.ML.Pca.Torch/Fit.cs b/src/Bonsai.ML.Pca.Torch/Fit.cs index b5f758f0..eb762f93 100644 --- a/src/Bonsai.ML.Pca.Torch/Fit.cs +++ b/src/Bonsai.ML.Pca.Torch/Fit.cs @@ -1,4 +1,4 @@ -using System; +using System; using System.ComponentModel; using System.Reactive.Linq; using System.Xml.Serialization; @@ -7,19 +7,12 @@ namespace Bonsai.ML.Pca.Torch; /// -/// Fits a PCA model to the input data. +/// Fits the PCA model to the input data. /// -[Combinator] -[Description("Fits a PCA model to the input data.")] -[WorkflowElementCategory(ElementCategory.Sink)] -public class Fit +public class Fit : IPcaModelProvider { - /// - /// The PCA model used to fit the input data. - /// - [Description("The PCA model used to fit the input data.")] - [XmlIgnore] - public IPcaBaseModel? Model { get; set; } + /// + public IPcaBaseModel? Model { get; set; } = null; private void FitModel(IPcaBaseModel model, Tensor data) { @@ -34,7 +27,7 @@ private void FitModel(IPcaBaseModel model, Tensor data) /// public IObservable Process(IObservable source) { - if (Model == null) + if (Model is null) { throw new InvalidOperationException("The PCA model has not been specified."); } diff --git a/src/Bonsai.ML.Pca.Torch/FitAndTransform.cs b/src/Bonsai.ML.Pca.Torch/FitAndTransform.cs index 87c8aefb..b7634711 100644 --- a/src/Bonsai.ML.Pca.Torch/FitAndTransform.cs +++ b/src/Bonsai.ML.Pca.Torch/FitAndTransform.cs @@ -1,4 +1,4 @@ -using System; +using System; using System.ComponentModel; using System.Reactive.Linq; using System.Xml.Serialization; @@ -9,16 +9,9 @@ namespace Bonsai.ML.Pca.Torch; /// /// Fits the PCA model to the input data and transforms it. 
/// -[Combinator] -[Description("Fits the PCA model to the input data and transforms it.")] -[WorkflowElementCategory(ElementCategory.Transform)] -public class FitAndTransform +public class FitAndTransform : IPcaModelProvider { - /// - /// The PCA model used to fit and transform the input data. - /// - [Description("The PCA model used to fit and transform the input data.")] - [XmlIgnore] + /// public IPcaBaseModel? Model { get; set; } private static void FitModelAndTransformData(IPcaBaseModel model, Tensor data) diff --git a/src/Bonsai.ML.Pca.Torch/FitAndTransformBuilder.cs b/src/Bonsai.ML.Pca.Torch/FitAndTransformBuilder.cs new file mode 100644 index 00000000..2cb35f18 --- /dev/null +++ b/src/Bonsai.ML.Pca.Torch/FitAndTransformBuilder.cs @@ -0,0 +1,12 @@ +using System.ComponentModel; + +namespace Bonsai.ML.Pca.Torch; + +/// +/// Represents an operator that fits a PCA model and transforms the input data. +/// +[ResetCombinator] +[Combinator] +[Description("Fits a PCA model and transforms the input data.")] +[WorkflowElementCategory(ElementCategory.Transform)] +public class FitAndTransformBuilder() : PcaModelBuilder(new FitAndTransform()) { } diff --git a/src/Bonsai.ML.Pca.Torch/FitBuilder.cs b/src/Bonsai.ML.Pca.Torch/FitBuilder.cs new file mode 100644 index 00000000..5955fb2a --- /dev/null +++ b/src/Bonsai.ML.Pca.Torch/FitBuilder.cs @@ -0,0 +1,12 @@ +using System.ComponentModel; + +namespace Bonsai.ML.Pca.Torch; + +/// +/// Represents an operator that fits a PCA model to the input data. 
+/// +[ResetCombinator] +[Combinator] +[Description("Fits a PCA model to the input data.")] +[WorkflowElementCategory(ElementCategory.Sink)] +public class FitBuilder() : PcaModelBuilder(new Fit()) { } diff --git a/src/Bonsai.ML.Pca.Torch/IPcaModelProvider.cs b/src/Bonsai.ML.Pca.Torch/IPcaModelProvider.cs new file mode 100644 index 00000000..512e15da --- /dev/null +++ b/src/Bonsai.ML.Pca.Torch/IPcaModelProvider.cs @@ -0,0 +1,12 @@ +namespace Bonsai.ML.Pca.Torch; + +/// +/// Defines an interface for PCA model providers. +/// +public interface IPcaModelProvider +{ + /// + /// Gets or sets the PCA model. + /// + public IPcaBaseModel? Model { get; set; } +} diff --git a/src/Bonsai.ML.Pca.Torch/OnlinePcaGha.cs b/src/Bonsai.ML.Pca.Torch/OnlinePcaGha.cs index 6c73813a..2240701a 100644 --- a/src/Bonsai.ML.Pca.Torch/OnlinePcaGha.cs +++ b/src/Bonsai.ML.Pca.Torch/OnlinePcaGha.cs @@ -1,11 +1,9 @@ -using System; -using System.ComponentModel; -using static TorchSharp.torch; +using static TorchSharp.torch; namespace Bonsai.ML.Pca.Torch; /// -/// Implements online PCA using the Generalized Hebbian Algorithm (GHA). +/// Implements streaming/online PCA using the Generalized Hebbian Algorithm (GHA). 
/// /// /// diff --git a/src/Bonsai.ML.Pca.Torch/Pca.cs b/src/Bonsai.ML.Pca.Torch/Pca.cs index 5ee84e95..a33c8ce4 100644 --- a/src/Bonsai.ML.Pca.Torch/Pca.cs +++ b/src/Bonsai.ML.Pca.Torch/Pca.cs @@ -1,5 +1,4 @@ -using System; -using static TorchSharp.torch; +using static TorchSharp.torch; using static TorchSharp.torch.linalg; namespace Bonsai.ML.Pca.Torch; diff --git a/src/Bonsai.ML.Pca.Torch/PcaDescriptor.cs b/src/Bonsai.ML.Pca.Torch/PcaDescriptor.cs index b66fd47d..307c4d8f 100644 --- a/src/Bonsai.ML.Pca.Torch/PcaDescriptor.cs +++ b/src/Bonsai.ML.Pca.Torch/PcaDescriptor.cs @@ -1,4 +1,4 @@ -using System; +using System; using System.ComponentModel; using System.Reactive.Linq; using System.Collections.Generic; @@ -38,4 +38,7 @@ public override PropertyDescriptorCollection GetProperties(Attribute[] attribute /// public override PropertyDescriptorCollection GetProperties() => GetProperties([]); + + /// + public override object GetPropertyOwner(PropertyDescriptor pd) => _instance; } diff --git a/src/Bonsai.ML.Pca.Torch/PcaModelBuilder.cs b/src/Bonsai.ML.Pca.Torch/PcaModelBuilder.cs new file mode 100644 index 00000000..8ef69b25 --- /dev/null +++ b/src/Bonsai.ML.Pca.Torch/PcaModelBuilder.cs @@ -0,0 +1,94 @@ +using System; +using System.Collections.Generic; +using System.ComponentModel; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using System.Xml.Serialization; +using Bonsai.Expressions; +using static TorchSharp.torch; + +namespace Bonsai.ML.Pca.Torch; + +/// +/// Represents an abstract PCA model builder. +/// +public abstract class PcaModelBuilder(IPcaModelProvider operatorInstance) : SingleArgumentExpressionBuilder, ICustomTypeDescriptor +{ + private readonly IPcaModelProvider _operator = operatorInstance; + + /// + /// The PCA model. + /// + [XmlIgnore] + [Description("The PCA model.")] + public IPcaBaseModel? 
Model + { + get => _operator.Model; + set => _operator.Model = value; + } + + /// + /// Gets a value indicating whether the model must be set explicitly through the Model property of the operator (true), or is supplied by the input tuple, in which case the Model property is hidden (false). + /// + [Browsable(false)] + [XmlIgnore] + [RefreshProperties(RefreshProperties.All)] + public bool HasModel { get; private set; } = true; + + /// + public override Expression Build(IEnumerable arguments) + { + var input = arguments?.FirstOrDefault(); + MethodInfo processMethod; + if (input is null) + { + HasModel = true; + processMethod = typeof(T).GetMethod("Process", [typeof(IObservable)]); + return Expression.Call(Expression.Constant(_operator), processMethod!, [input]); + } + + var obsType = input.Type; + var t = obsType.GetGenericArguments().Single(); + + HasModel = !(t.IsGenericType && t.FullName?.StartsWith("System.Tuple`") == true); + + if (!HasModel) + { + var args = t.GetGenericArguments(); + if (args.Length != 2 || (!typeof(IPcaBaseModel).IsAssignableFrom(args[0]) && !typeof(IPcaBaseModel).IsAssignableFrom(args[1]))) + throw new InvalidOperationException("The input type is not valid. Expected an observable sequence of tuples containing a PCA model and a tensor."); + } + + processMethod = HasModel + ? typeof(T).GetMethod("Process", [typeof(IObservable)]) + : typeof(T).GetMethod("Process", [typeof(IObservable<>).MakeGenericType(t)]); + + return Expression.Call(Expression.Constant(_operator), processMethod!, [input]); + } + + + PropertyDescriptorCollection ICustomTypeDescriptor.GetProperties(Attribute[]? attributes) + { + var props = TypeDescriptor.GetProperties(this, attributes, true); + if (HasModel) return props; + + var filtered = props.Cast() + .Where(p => p.Name != nameof(Model)) + .ToArray(); + + return new PropertyDescriptorCollection(filtered); + } + + AttributeCollection ICustomTypeDescriptor.GetAttributes() => TypeDescriptor.GetAttributes(this, true); + string?
ICustomTypeDescriptor.GetClassName() => TypeDescriptor.GetClassName(this, true); + string? ICustomTypeDescriptor.GetComponentName() => TypeDescriptor.GetComponentName(this, true); + TypeConverter ICustomTypeDescriptor.GetConverter() => TypeDescriptor.GetConverter(this, true); + EventDescriptor? ICustomTypeDescriptor.GetDefaultEvent() => TypeDescriptor.GetDefaultEvent(this, true); + PropertyDescriptor? ICustomTypeDescriptor.GetDefaultProperty() => TypeDescriptor.GetDefaultProperty(this, true); + object? ICustomTypeDescriptor.GetEditor(Type editorBaseType) => TypeDescriptor.GetEditor(this, editorBaseType, true); + EventDescriptorCollection ICustomTypeDescriptor.GetEvents(Attribute[]? attributes) => TypeDescriptor.GetEvents(this, attributes, true); + EventDescriptorCollection ICustomTypeDescriptor.GetEvents() => TypeDescriptor.GetEvents(this, true); + PropertyDescriptorCollection ICustomTypeDescriptor.GetProperties() => ((ICustomTypeDescriptor)this).GetProperties(Array.Empty()); + object ICustomTypeDescriptor.GetPropertyOwner(PropertyDescriptor? pd) => this; +} diff --git a/src/Bonsai.ML.Pca.Torch/ProbabalisticPca.cs b/src/Bonsai.ML.Pca.Torch/ProbabalisticPca.cs index cfddc321..843298f2 100644 --- a/src/Bonsai.ML.Pca.Torch/ProbabalisticPca.cs +++ b/src/Bonsai.ML.Pca.Torch/ProbabalisticPca.cs @@ -1,5 +1,4 @@ using System; -using Bonsai.ML.Torch; using static TorchSharp.torch; using static TorchSharp.torch.linalg; diff --git a/src/Bonsai.ML.Pca.Torch/Reconstruct.cs b/src/Bonsai.ML.Pca.Torch/Reconstruct.cs index 043f631a..2701da28 100644 --- a/src/Bonsai.ML.Pca.Torch/Reconstruct.cs +++ b/src/Bonsai.ML.Pca.Torch/Reconstruct.cs @@ -1,8 +1,7 @@ -using System; +using System; using System.ComponentModel; using System.Reactive.Linq; using System.Xml.Serialization; -using Bonsai; using static TorchSharp.torch; namespace Bonsai.ML.Pca.Torch; @@ -10,24 +9,11 @@ namespace Bonsai.ML.Pca.Torch; /// /// Reconstructs the input data using a PCA model. 
/// -[Combinator] -[Description("Reconstructs the input data using a PCA model.")] -[WorkflowElementCategory(ElementCategory.Transform)] -public class Reconstruct +public class Reconstruct : IPcaModelProvider { - /// - /// The PCA model used to reconstruct the input data. - /// - [Description("The PCA model used to reconstruct the input data.")] - [XmlIgnore] + /// public IPcaBaseModel? Model { get; set; } - /// - /// Reconstructs the input data using the specified PCA model. - /// - /// - /// - /// private static Tensor ReconstructData(IPcaBaseModel model, Tensor data) { return model.Reconstruct(data); diff --git a/src/Bonsai.ML.Pca.Torch/ReconstructBuilder.cs b/src/Bonsai.ML.Pca.Torch/ReconstructBuilder.cs new file mode 100644 index 00000000..4938864f --- /dev/null +++ b/src/Bonsai.ML.Pca.Torch/ReconstructBuilder.cs @@ -0,0 +1,12 @@ +using System.ComponentModel; + +namespace Bonsai.ML.Pca.Torch; + +/// +/// Represents an operator that reconstructs the input data using a PCA model. +/// +[ResetCombinator] +[Combinator] +[Description("Reconstructs the input data using a PCA model.")] +[WorkflowElementCategory(ElementCategory.Transform)] +public class ReconstructBuilder() : PcaModelBuilder(new Reconstruct()) { } diff --git a/src/Bonsai.ML.Pca.Torch/Transform.cs b/src/Bonsai.ML.Pca.Torch/Transform.cs index 95c39826..016b1346 100644 --- a/src/Bonsai.ML.Pca.Torch/Transform.cs +++ b/src/Bonsai.ML.Pca.Torch/Transform.cs @@ -1,8 +1,5 @@ -using System; -using System.ComponentModel; +using System; using System.Reactive.Linq; -using System.Xml.Serialization; -using Bonsai; using static TorchSharp.torch; namespace Bonsai.ML.Pca.Torch; @@ -10,24 +7,11 @@ namespace Bonsai.ML.Pca.Torch; /// /// Transforms the input data using a PCA model. 
/// -[Combinator] -[Description("Transforms the input data using a PCA model.")] -[WorkflowElementCategory(ElementCategory.Transform)] -public class Transform +public class Transform : IPcaModelProvider { - /// - /// The PCA model used to transform the input data. - /// - [Description("The PCA model used to transform the input data.")] - [XmlIgnore] - public IPcaBaseModel? Model { get; set; } + /// + public IPcaBaseModel? Model { get; set; } = null; - /// - /// Transforms the input data using the specified PCA model. - /// - /// - /// - /// private static Tensor TransformData(IPcaBaseModel model, Tensor data) { return model.Transform(data); diff --git a/src/Bonsai.ML.Pca.Torch/TransformBuilder.cs b/src/Bonsai.ML.Pca.Torch/TransformBuilder.cs new file mode 100644 index 00000000..ee4ca8ed --- /dev/null +++ b/src/Bonsai.ML.Pca.Torch/TransformBuilder.cs @@ -0,0 +1,12 @@ +using System.ComponentModel; + +namespace Bonsai.ML.Pca.Torch; + +/// +/// Represents an operator that transforms the input data using a PCA model. 
+/// +[ResetCombinator] +[Combinator] +[Description("Transforms the input data using a PCA model.")] +[WorkflowElementCategory(ElementCategory.Transform)] +public class TransformBuilder() : PcaModelBuilder(new Transform()) { } diff --git a/src/Bonsai.ML.Pca.Torch/Utils.cs b/src/Bonsai.ML.Pca.Torch/Utils.cs index c22bc775..b89b73e5 100644 --- a/src/Bonsai.ML.Pca.Torch/Utils.cs +++ b/src/Bonsai.ML.Pca.Torch/Utils.cs @@ -1,5 +1,4 @@ -using System; -using System.Reactive.Linq; +using System; using static TorchSharp.torch; namespace Bonsai.ML.Pca.Torch; @@ -27,4 +26,4 @@ internal static Tensor InvertSPD( } return cholesky_solve(rhs, L); } -} \ No newline at end of file +} From 5b60a4f8c272e70b75545802b01501d51b4d86ec Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 27 Jan 2026 11:45:11 +0000 Subject: [PATCH 18/20] Changed time offset to sample offset for better consistency --- src/Bonsai.ML.Pca.Torch/CreatePca.cs | 10 +++++----- .../OnlineProbabilisticPca.cs | 17 +++++++++-------- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/src/Bonsai.ML.Pca.Torch/CreatePca.cs b/src/Bonsai.ML.Pca.Torch/CreatePca.cs index e77bf986..bf7f6567 100644 --- a/src/Bonsai.ML.Pca.Torch/CreatePca.cs +++ b/src/Bonsai.ML.Pca.Torch/CreatePca.cs @@ -79,10 +79,10 @@ public class CreatePca : ZeroArgumentExpressionBuilder, INamedElement public double? Kappa { get; set; } = 0.9; /// - /// The time offset for the online probabilistic PCA model. + /// The sample offset for the online probabilistic PCA model. /// - [Description("The time offset for the online probabilistic PCA model. If null, decaying learning rate starts from the first sample.")] - public int? TimeOffset { get; set; } = null; + [Description("The sample offset for the online probabilistic PCA model. If null, decaying learning rate starts from the first sample.")] + public int? SampleOffset { get; set; } = null; /// /// The period for reorthogonalizing the components in the online probabilistic PCA model. 
@@ -121,7 +121,7 @@ internal IEnumerable GetModelProperties() yield return nameof(InitialVariance); yield return nameof(Rho); yield return nameof(Kappa); - yield return nameof(TimeOffset); + yield return nameof(SampleOffset); yield return nameof(ReorthogonalizePeriod); yield return nameof(Generator); } @@ -157,7 +157,7 @@ private static PcaBaseModel CreateModel(CreatePca pcaBuilder) generator: pcaBuilder.Generator, rho: pcaBuilder.Rho, kappa: pcaBuilder.Kappa, - timeOffset: pcaBuilder.TimeOffset, + sampleOffset: pcaBuilder.SampleOffset, reorthogonalizePeriod: pcaBuilder.ReorthogonalizePeriod), PcaModelType.OnlinePcaGha => new OnlinePcaGha( numComponents: pcaBuilder.NumComponents, diff --git a/src/Bonsai.ML.Pca.Torch/OnlineProbabilisticPca.cs b/src/Bonsai.ML.Pca.Torch/OnlineProbabilisticPca.cs index 15a3afd8..b5e251bb 100644 --- a/src/Bonsai.ML.Pca.Torch/OnlineProbabilisticPca.cs +++ b/src/Bonsai.ML.Pca.Torch/OnlineProbabilisticPca.cs @@ -56,9 +56,9 @@ public class OnlineProbabilisticPca : PcaBaseModel public int ReorthogonalizePeriod { get; private set; } /// - /// Gets the time offset used in the learning rate schedule when Kappa is specified. + /// Gets the sample offset used in the learning rate schedule when Kappa is specified. /// - public int? TimeOffset { get; private set; } + public int SampleOffset { get; private set; } /// public override Tensor Components { get; protected set; } = empty(0); @@ -78,7 +78,7 @@ public class OnlineProbabilisticPca : PcaBaseModel /// /// /// - /// + /// /// /// public OnlineProbabilisticPca(int numComponents, @@ -88,7 +88,7 @@ public OnlineProbabilisticPca(int numComponents, Generator? generator = null, double? rho = 0.1, double? kappa = null, - int? timeOffset = 3000, + int? sampleOffset = null, int? 
reorthogonalizePeriod = null ) : base(numComponents, device, @@ -120,9 +120,10 @@ public OnlineProbabilisticPca(int numComponents, } else { - if (timeOffset is null or <= 0) + sampleOffset ??= 0; + if (sampleOffset < 0) { - throw new ArgumentException("Time offset must be a positive integer.", nameof(timeOffset)); + throw new ArgumentException("Sample offset must be a positive integer.", nameof(sampleOffset)); } if (!kappa.HasValue) @@ -135,7 +136,7 @@ public OnlineProbabilisticPca(int numComponents, throw new ArgumentException("Kappa must be in the range (0.5, 1].", nameof(kappa)); } - UpdateSchedule = () => Math.Pow(_stepCount + timeOffset.Value, -kappa.Value); + UpdateSchedule = () => Math.Pow(_stepCount + SampleOffset, -kappa.Value); } if (reorthogonalizePeriod.HasValue) @@ -147,7 +148,7 @@ public OnlineProbabilisticPca(int numComponents, Generator = generator; Rho = rho; Kappa = kappa; - TimeOffset = timeOffset; + SampleOffset = sampleOffset ?? 0; Variance = initialVariance; } From 7aa2efd22cc632fe2cb591c201cc6c11a5443940 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Thu, 29 Jan 2026 12:55:46 +0000 Subject: [PATCH 19/20] Added standard PCA unit tests --- .../Bonsai.ML.Pca.Torch.Tests.csproj | 1 + .../CreatePCATest.bonsai | 31 --- .../FitPCATest.bonsai | 235 ------------------ .../StandardPcaTests.cs | 130 ++++++++++ 4 files changed, 131 insertions(+), 266 deletions(-) delete mode 100644 tests/Bonsai.ML.Pca.Torch.Tests/CreatePCATest.bonsai delete mode 100644 tests/Bonsai.ML.Pca.Torch.Tests/FitPCATest.bonsai create mode 100644 tests/Bonsai.ML.Pca.Torch.Tests/StandardPcaTests.cs diff --git a/tests/Bonsai.ML.Pca.Torch.Tests/Bonsai.ML.Pca.Torch.Tests.csproj b/tests/Bonsai.ML.Pca.Torch.Tests/Bonsai.ML.Pca.Torch.Tests.csproj index 5c8c20ab..e480f9b3 100644 --- a/tests/Bonsai.ML.Pca.Torch.Tests/Bonsai.ML.Pca.Torch.Tests.csproj +++ b/tests/Bonsai.ML.Pca.Torch.Tests/Bonsai.ML.Pca.Torch.Tests.csproj @@ -17,5 +17,6 @@ + \ No newline at end of file diff --git 
a/tests/Bonsai.ML.Pca.Torch.Tests/CreatePCATest.bonsai b/tests/Bonsai.ML.Pca.Torch.Tests/CreatePCATest.bonsai deleted file mode 100644 index ca948fc8..00000000 --- a/tests/Bonsai.ML.Pca.Torch.Tests/CreatePCATest.bonsai +++ /dev/null @@ -1,31 +0,0 @@ - - - - - - - 2 - PCA - 1 - 100 - 1E-05 - - - - PCA - - - - - - - - - - - \ No newline at end of file diff --git a/tests/Bonsai.ML.Pca.Torch.Tests/FitPCATest.bonsai b/tests/Bonsai.ML.Pca.Torch.Tests/FitPCATest.bonsai deleted file mode 100644 index ca915de8..00000000 --- a/tests/Bonsai.ML.Pca.Torch.Tests/FitPCATest.bonsai +++ /dev/null @@ -1,235 +0,0 @@ - - - - - - LoadTrainingData - - - - - - - - Z:/home/nicholas/Downloads/xfads-2/workflows/data/y_train.bin - 0 - 0 - 50 - 1 - F64 - RowMajor - - - - - - 1000 - - - - - - - - - - - - - - Float32 - - - - - - - - - - - - - - - - - - - - 1 - - - - it.squeeze(null).T - - - - - 0 - - true - - - - - - - - - it.T - - - - - - - - - - - - - - - - - YTrain - - - LoadValidationData - - - - - - - - Z:/home/nicholas/Downloads/xfads-2/workflows/data/y_valid.bin - 0 - 0 - 50 - 1 - F64 - RowMajor - - - - - - 1000 - - - - - - - - - - - - - - Float32 - - - - - - - - - - - - - - - - - - - - 1 - - - - it.squeeze(null).T - - - - - 0 - - true - - - - - - - - - it.T - - - - - - - - - - - - - - - - - YValid - - - YTrain - - - - 2 - ProbabilisticPCA - 1 - 100 - 1E-05 - - - - - - - - - - Item2 - - - Components - - - - - - - - - - - - - \ No newline at end of file diff --git a/tests/Bonsai.ML.Pca.Torch.Tests/StandardPcaTests.cs b/tests/Bonsai.ML.Pca.Torch.Tests/StandardPcaTests.cs new file mode 100644 index 00000000..03af21e5 --- /dev/null +++ b/tests/Bonsai.ML.Pca.Torch.Tests/StandardPcaTests.cs @@ -0,0 +1,130 @@ +using System.Diagnostics; +using Bonsai.ML.Pca.Torch; +using TorchSharp; +using static TorchSharp.torch; + +namespace Bonsai.ML.Pca.Torch.Tests; + +/// +/// Tests for the Bonsai.ML.Pca.Torch package. 
+/// +[TestClass] +public class StandardPcaTests +{ + public StandardPcaTests() + { + manual_seed(0); + set_printoptions(style: TorchSharp.TensorStringStyle.Numpy); + } + + private static readonly Tensor _expectedFirstComponent = tensor(new float[] { 1f, 0f }); + + private static Tensor Generate2dRotationMatrix(double angleDegrees) + { + var angleRad = angleDegrees * Math.PI / 180.0; + var cosA = Math.Cos(angleRad); + var sinA = Math.Sin(angleRad); + return tensor(new float[,] + { + { (float)cosA, (float)-sinA }, + { (float)sinA, (float)cosA } + }); + } + + private static Tensor Generate2dDataset(int numSamples, double scaleX = 3.0, double scaleY = 0.1, double offsetX = 0.0, double offsetY = 0.0, double rotationAngle = 0.0) + { + var x = randn(numSamples) * scaleX + offsetX; + var y = randn(numSamples) * scaleY + offsetY; + var data = stack([x, y], 1); + if (rotationAngle == 0) + return data; + var rotationMatrix = Generate2dRotationMatrix(rotationAngle); + return data.matmul(rotationMatrix); + } + + private static float Similarity(Tensor a, Tensor b) + { + if (a.dim() != 1 || b.dim() != 1) + throw new ArgumentException("Input tensors must be 1-dimensional. Instead got dimension: " + a.dim()); + if (a.size(0) != b.size(0)) + throw new ArgumentException($"Input tensors must have the same shape. Instead got {string.Join(", ", a.shape)} and {string.Join(", ", b.shape)}."); + var aNorm = a.norm(-1, keepdim: true); + var bNorm = b.norm(-1, keepdim: true); + var dotProduct = a.dot(b); + return abs(dotProduct / (aNorm * bNorm)).item(); + } + + private static void TestBasic(PcaBaseModel model) + { + // Generate a simple dataset. + var data = Generate2dDataset(1000); + Debug.WriteLine($"Data shape: {string.Join(", ", data.shape)}"); + + // Fit the model. + model.Fit(data); + + // Verify components. 
+ Debug.WriteLine($"Components: {model.Components.str()}"); + Assert.IsTrue(model.Components.shape[0] == 2 && model.Components.shape[1] == 2); + + // Verify similarity with expected first component. + var similarity = Similarity(model.Components[0], _expectedFirstComponent); + Debug.WriteLine($"Similarity with expected first component: {similarity}"); + Assert.IsTrue(similarity > 0.99); + + // Compare reconstructed data. + var transformed = model.Transform(data); + var reconstructed = model.Reconstruct(transformed); + + var reconstructionError = mean((data - reconstructed).pow(2)).item(); + Debug.WriteLine($"Reconstruction error: {reconstructionError}"); + Assert.IsTrue(reconstructionError < 1e-10); + } + + private static void TestRotation(PcaBaseModel model) + { + // Compare rotated dataset + var data = Generate2dDataset(1000, rotationAngle: 30.0); + model.Fit(data); + + var rotationMatrix = Generate2dRotationMatrix(30.0); + var rotatedExpectedFirstComponent = _expectedFirstComponent.matmul(rotationMatrix); + Debug.WriteLine($"Rotated expected first component: {rotatedExpectedFirstComponent.str()}"); + var rotatedSimilarity = Similarity(model.Components[0], rotatedExpectedFirstComponent); + Debug.WriteLine($"Similarity with expected first component (rotated data): {rotatedSimilarity}"); + Assert.IsTrue(rotatedSimilarity > 0.99); + } + + private static void TestOffset(PcaBaseModel model) + { + // Test offset centering + var dataOffset = Generate2dDataset(1000, offsetX: 5.0, offsetY: -3.0); + model.Fit(dataOffset); + + var offsetExpectedFirstComponent = _expectedFirstComponent; + var offsetSimilarity = Similarity(model.Components[0], offsetExpectedFirstComponent); + Debug.WriteLine($"Similarity with expected first component (offset data): {offsetSimilarity}"); + Assert.IsTrue(offsetSimilarity > 0.99); + + var transformed = model.Transform(dataOffset); + var reconstructed = model.Reconstruct(transformed); + + var reconstructedMeans = reconstructed.mean([0]); + 
Debug.WriteLine($"Reconstructed means (offset data): {reconstructedMeans.str()}"); + Assert.IsTrue(abs(reconstructedMeans[0] - 5.0).item() < 0.1); + Assert.IsTrue(abs(reconstructedMeans[1] + 3.0).item() < 0.1); + } + + [TestMethod] + public static void TestStandardPca() + { + var pca = new Pca(numComponents: 2); + TestBasic(pca); + + pca = new Pca(numComponents: 2); + TestRotation(pca); + + pca = new Pca(numComponents: 2); + TestOffset(pca); + } +} From d76c2c3870635d41da6082d9c18820383842f9c5 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 25 Mar 2026 14:05:10 +0000 Subject: [PATCH 20/20] Removed `static` from test method to fix ms test warning --- tests/Bonsai.ML.Pca.Torch.Tests/StandardPcaTests.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/Bonsai.ML.Pca.Torch.Tests/StandardPcaTests.cs b/tests/Bonsai.ML.Pca.Torch.Tests/StandardPcaTests.cs index 03af21e5..3a323bc1 100644 --- a/tests/Bonsai.ML.Pca.Torch.Tests/StandardPcaTests.cs +++ b/tests/Bonsai.ML.Pca.Torch.Tests/StandardPcaTests.cs @@ -116,7 +116,7 @@ private static void TestOffset(PcaBaseModel model) } [TestMethod] - public static void TestStandardPca() + public void TestStandardPca() { var pca = new Pca(numComponents: 2); TestBasic(pca);