diff --git a/Bonsai.ML.sln b/Bonsai.ML.sln index fc4e963f..56e66f13 100644 --- a/Bonsai.ML.sln +++ b/Bonsai.ML.sln @@ -46,6 +46,10 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Bonsai.ML.Lds.Torch.Tests", EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Bonsai.ML.Lds.Torch.Design", "src\Bonsai.ML.Lds.Torch.Design\Bonsai.ML.Lds.Torch.Design.csproj", "{1F52DECD-1B2C-4F6C-996C-14C715283B80}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Bonsai.ML.Pca.Torch", "src\Bonsai.ML.Pca.Torch\Bonsai.ML.Pca.Torch.csproj", "{1A2DEED3-795E-4C28-9C5E-BA3D76B2A485}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Bonsai.ML.Pca.Torch.Tests", "tests\Bonsai.ML.Pca.Torch.Tests\Bonsai.ML.Pca.Torch.Tests.csproj", "{4ABCC6B2-024A-450F-85CB-2A9B2D2D2A10}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -120,6 +124,14 @@ Global {1F52DECD-1B2C-4F6C-996C-14C715283B80}.Debug|Any CPU.Build.0 = Debug|Any CPU {1F52DECD-1B2C-4F6C-996C-14C715283B80}.Release|Any CPU.ActiveCfg = Release|Any CPU {1F52DECD-1B2C-4F6C-996C-14C715283B80}.Release|Any CPU.Build.0 = Release|Any CPU + {1A2DEED3-795E-4C28-9C5E-BA3D76B2A485}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {1A2DEED3-795E-4C28-9C5E-BA3D76B2A485}.Debug|Any CPU.Build.0 = Debug|Any CPU + {1A2DEED3-795E-4C28-9C5E-BA3D76B2A485}.Release|Any CPU.ActiveCfg = Release|Any CPU + {1A2DEED3-795E-4C28-9C5E-BA3D76B2A485}.Release|Any CPU.Build.0 = Release|Any CPU + {4ABCC6B2-024A-450F-85CB-2A9B2D2D2A10}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {4ABCC6B2-024A-450F-85CB-2A9B2D2D2A10}.Debug|Any CPU.Build.0 = Debug|Any CPU + {4ABCC6B2-024A-450F-85CB-2A9B2D2D2A10}.Release|Any CPU.ActiveCfg = Release|Any CPU + {4ABCC6B2-024A-450F-85CB-2A9B2D2D2A10}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE diff --git a/src/Bonsai.ML.Pca.Torch/Bonsai.ML.Pca.Torch.csproj 
b/src/Bonsai.ML.Pca.Torch/Bonsai.ML.Pca.Torch.csproj new file mode 100644 index 00000000..b97c06fd --- /dev/null +++ b/src/Bonsai.ML.Pca.Torch/Bonsai.ML.Pca.Torch.csproj @@ -0,0 +1,15 @@ + + + + Bonsai.ML.Pca.Torch Bonsai library. + $(PackageTags) PCA Principal Component Analysis + net472;netstandard2.0 + enable + + + + + + + + diff --git a/src/Bonsai.ML.Pca.Torch/CreatePca.cs b/src/Bonsai.ML.Pca.Torch/CreatePca.cs new file mode 100644 index 00000000..bf7f6567 --- /dev/null +++ b/src/Bonsai.ML.Pca.Torch/CreatePca.cs @@ -0,0 +1,198 @@ +using System; +using System.ComponentModel; +using System.Collections.Generic; +using System.Reactive.Linq; +using System.Linq.Expressions; +using Bonsai.Expressions; +using System.Reflection; +using static TorchSharp.torch; +using System.Xml.Serialization; + +namespace Bonsai.ML.Pca.Torch; + +/// +/// Creates a PCA model. +/// +[Combinator] +[ResetCombinator] +[WorkflowElementCategory(ElementCategory.Source)] +[TypeDescriptionProvider(typeof(PcaDescriptionProvider))] +[Description("Creates a PCA model.")] +public class CreatePca : ZeroArgumentExpressionBuilder, INamedElement +{ + /// + public string Name => $"CreatePca.{ModelType}"; + + /// + /// The number of principal components to compute. + /// + public int NumComponents { get; set; } = 2; + + /// + /// The device on which to create the PCA model. + /// + [XmlIgnore] + [Description("The device on which to create the PCA model.")] + public Device? Device { get; set; } + + /// + /// The scalar type of the PCA model. + /// + [Description("The scalar type of the PCA model.")] + public ScalarType? ScalarType { get; set; } + + /// + /// The type of PCA model to create. + /// + [RefreshProperties(RefreshProperties.All)] + [Description("The type of PCA model to create.")] + public PcaModelType ModelType { get; set; } = PcaModelType.Pca; + + /// + /// The initial variance for probabilistic PCA models. 
+ /// + [Description("The initial variance for probabilistic PCA models.")] + public double InitialVariance { get; set; } = 1.0; + + /// + /// The number of iterations for fitting probabilistic PCA models. + /// + [Description("The number of iterations for fitting probabilistic PCA models.")] + public int Iterations { get; set; } = 100; + + /// + /// The tolerance for convergence in probabilistic PCA models. + /// + [Description("The tolerance for convergence in probabilistic PCA models.")] + public double Tolerance { get; set; } = 1e-5; + + /// + /// The constant learning rate parameter for the online probabilistic PCA model. + /// + [Description("The constant learning rate parameter for the online probabilistic PCA model. Only one of Rho or Kappa must be specified.")] + public double? Rho { get; set; } = 0.1; + + /// + /// The forgetting factor for the online probabilistic PCA model. + /// + [Description("The forgetting factor for the online probabilistic PCA model. Only one of Rho or Kappa must be specified.")] + public double? Kappa { get; set; } = null; // null by default: OnlineProbabilisticPca rejects Rho and Kappa both being set, and Rho already defaults to 0.1 + + /// + /// The sample offset for the online probabilistic PCA model. + /// + [Description("The sample offset for the online probabilistic PCA model. If null, decaying learning rate starts from the first sample.")] + public int? SampleOffset { get; set; } = null; + + /// + /// The period for reorthogonalizing the components in the online probabilistic PCA model. + /// + [Description("The period for reorthogonalizing the components in the online probabilistic PCA model. If null, reorthogonalization is not performed.")] + public int? ReorthogonalizePeriod { get; set; } = null; + + /// + /// The random number generator used for initializing probabilistic PCA models. + /// + [XmlIgnore] + public Generator? Generator { get; set; } = null; + + /// + /// The learning rate for the Online PCA GHA model. 
+ /// + public double LearningRate { get; set; } = 0.1; + + internal IEnumerable GetModelProperties() + { + yield return nameof(NumComponents); + yield return nameof(Device); + yield return nameof(ScalarType); + yield return nameof(ModelType); + + if (ModelType == PcaModelType.ProbabilisticPca) + { + yield return nameof(InitialVariance); + yield return nameof(Iterations); + yield return nameof(Tolerance); + yield return nameof(Generator); + } + + if (ModelType == PcaModelType.OnlineProbabilisticPca) + { + yield return nameof(InitialVariance); + yield return nameof(Rho); + yield return nameof(Kappa); + yield return nameof(SampleOffset); + yield return nameof(ReorthogonalizePeriod); + yield return nameof(Generator); + } + + if (ModelType == PcaModelType.OnlinePcaGha) + { + yield return nameof(LearningRate); + yield return nameof(Generator); + } + } + + private static PcaBaseModel CreateModel(CreatePca pcaBuilder) + { + return pcaBuilder.ModelType switch + { + PcaModelType.Pca => new Pca( + numComponents: pcaBuilder.NumComponents, + device: pcaBuilder.Device, + scalarType: pcaBuilder.ScalarType), + PcaModelType.ProbabilisticPca => new ProbabilisticPca( + numComponents: pcaBuilder.NumComponents, + device: pcaBuilder.Device, + scalarType: pcaBuilder.ScalarType, + initialVariance: pcaBuilder.InitialVariance, + generator: pcaBuilder.Generator, + iterations: pcaBuilder.Iterations, + tolerance: pcaBuilder.Tolerance), + PcaModelType.OnlineProbabilisticPca => new OnlineProbabilisticPca( + numComponents: pcaBuilder.NumComponents, + device: pcaBuilder.Device, + scalarType: pcaBuilder.ScalarType, + initialVariance: pcaBuilder.InitialVariance, + generator: pcaBuilder.Generator, + rho: pcaBuilder.Rho, + kappa: pcaBuilder.Kappa, + sampleOffset: pcaBuilder.SampleOffset, + reorthogonalizePeriod: pcaBuilder.ReorthogonalizePeriod), + PcaModelType.OnlinePcaGha => new OnlinePcaGha( + numComponents: pcaBuilder.NumComponents, + learningRate: pcaBuilder.LearningRate, + device: 
pcaBuilder.Device, + scalarType: pcaBuilder.ScalarType, + generator: pcaBuilder.Generator), + _ => throw new NotSupportedException($"Model type {pcaBuilder.ModelType} is not supported."), + }; + } + + private static Type GetModelType(PcaModelType modelType) + { + return modelType switch + { + PcaModelType.Pca => typeof(Pca), + PcaModelType.ProbabilisticPca => typeof(ProbabilisticPca), + PcaModelType.OnlineProbabilisticPca => typeof(OnlineProbabilisticPca), + PcaModelType.OnlinePcaGha => typeof(OnlinePcaGha), + _ => throw new NotSupportedException($"Model type {modelType} is not supported."), + }; + } + + /// + public override Expression Build(IEnumerable arguments) + { + var processMethod = GetType().GetMethod( + nameof(Process), + BindingFlags.NonPublic | BindingFlags.Static); + processMethod = processMethod.MakeGenericMethod(GetModelType(ModelType)); + return Expression.Call(processMethod, Expression.Constant(this)); + } + + private static IObservable Process(CreatePca instance) where T : PcaBaseModel + { + return Observable.Return((T)CreateModel(instance)); + } +} diff --git a/src/Bonsai.ML.Pca.Torch/Fit.cs b/src/Bonsai.ML.Pca.Torch/Fit.cs new file mode 100644 index 00000000..eb762f93 --- /dev/null +++ b/src/Bonsai.ML.Pca.Torch/Fit.cs @@ -0,0 +1,143 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using System.Xml.Serialization; +using static TorchSharp.torch; + +namespace Bonsai.ML.Pca.Torch; + +/// +/// Fits the PCA model to the input data. +/// +public class Fit : IPcaModelProvider +{ + /// + public IPcaBaseModel? Model { get; set; } = null; + + private void FitModel(IPcaBaseModel model, Tensor data) + { + model.Fit(data); + } + + /// + /// Fits the PCA model to the input data. 
+ /// + /// + /// + /// + public IObservable Process(IObservable source) + { + if (Model is null) + { + throw new InvalidOperationException("The PCA model has not been specified."); + } + return source.Do(value => + { + FitModel(Model, value); + }); + } + + /// + /// Fits a standard PCA model to the input data. + /// + /// + /// + public IObservable> Process(IObservable> source) + { + return source.Do((value) => + { + FitModel(value.Item1, value.Item2); + }); + } + + /// + /// Fits a standard PCA model to the input data. + /// + /// + /// + public IObservable> Process(IObservable> source) + { + return source.Do((value) => + { + FitModel(value.Item2, value.Item1); + }); + } + + /// + /// Fits a probabilistic PCA model to the input data. + /// + /// + /// + public IObservable> Process(IObservable> source) + { + return source.Do((value) => + { + FitModel(value.Item1, value.Item2); + }); + } + + /// + /// Fits a probabilistic PCA model to the input data. + /// + /// + /// + public IObservable> Process(IObservable> source) + { + return source.Do((value) => + { + FitModel(value.Item2, value.Item1); + }); + } + + /// + /// Fits an online probabilistic PCA model to the input data. + /// + /// + /// + public IObservable> Process(IObservable> source) + { + return source.Do((value) => + { + FitModel(value.Item1, value.Item2); + }); + } + + /// + /// Fits an online probabilistic PCA model to the input data. + /// + /// + /// + public IObservable> Process(IObservable> source) + { + return source.Do((value) => + { + FitModel(value.Item2, value.Item1); + }); + } + + /// + /// Fits an online PCA model using the Generalized Hebbian Algorithm to the input data. + /// + /// + /// + public IObservable> Process(IObservable> source) + { + return source.Do((value) => + { + FitModel(value.Item1, value.Item2); + }); + } + + /// + /// Fits an online PCA model using the Generalized Hebbian Algorithm to the input data. 
+ /// + /// + /// + public IObservable> Process(IObservable> source) + { + return source.Do((value) => + { + FitModel(value.Item2, value.Item1); + }); + } +} diff --git a/src/Bonsai.ML.Pca.Torch/FitAndTransform.cs b/src/Bonsai.ML.Pca.Torch/FitAndTransform.cs new file mode 100644 index 00000000..b7634711 --- /dev/null +++ b/src/Bonsai.ML.Pca.Torch/FitAndTransform.cs @@ -0,0 +1,143 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using System.Xml.Serialization; +using static TorchSharp.torch; + +namespace Bonsai.ML.Pca.Torch; + +/// +/// Fits the PCA model to the input data and transforms it. +/// +public class FitAndTransform : IPcaModelProvider +{ + /// + public IPcaBaseModel? Model { get; set; } + + private static void FitModelAndTransformData(IPcaBaseModel model, Tensor data) + { + model.FitAndTransform(data); + } + + + /// + /// Fits the PCA model to the input data and transforms it. + /// + /// + /// + public IObservable Process(IObservable source) + { + if (Model == null) + { + throw new InvalidOperationException("The PCA model has not been specified."); + } + return source.Do(value => + { + FitModelAndTransformData(Model, value); + }); + } + + /// + /// Fits a standard PCA model to the input data and transforms it. + /// + /// + /// + public IObservable> Process(IObservable> source) + { + return source.Do((value) => + { + FitModelAndTransformData(value.Item1, value.Item2); + }); + } + + /// + /// Fits a standard PCA model to the input data and transforms it. + /// + /// + /// + public IObservable> Process(IObservable> source) + { + return source.Do((value) => + { + FitModelAndTransformData(value.Item2, value.Item1); + }); + } + + /// + /// Fits a probabilistic PCA model to the input data and transforms it. 
+ /// + /// + /// + public IObservable> Process(IObservable> source) + { + return source.Do((value) => + { + FitModelAndTransformData(value.Item1, value.Item2); + }); + } + + /// + /// Fits a probabilistic PCA model to the input data and transforms it. + /// + /// + /// + public IObservable> Process(IObservable> source) + { + return source.Do((value) => + { + FitModelAndTransformData(value.Item2, value.Item1); + }); + } + + /// + /// Fits an online probabilistic PCA model to the input data and transforms it. + /// + /// + /// + public IObservable> Process(IObservable> source) + { + return source.Do((value) => + { + FitModelAndTransformData(value.Item1, value.Item2); + }); + } + + /// + /// Fits an online probabilistic PCA model to the input data and transforms it. + /// + /// + /// + public IObservable> Process(IObservable> source) + { + return source.Do((value) => + { + FitModelAndTransformData(value.Item2, value.Item1); + }); + } + + /// + /// Fits an online PCA model using the Generalized Hebbian Algorithm to the input data and transforms it. + /// + /// + /// + public IObservable> Process(IObservable> source) + { + return source.Do((value) => + { + FitModelAndTransformData(value.Item1, value.Item2); + }); + } + + /// + /// Fits an online PCA model using the Generalized Hebbian Algorithm to the input data and transforms it. + /// + /// + /// + public IObservable> Process(IObservable> source) + { + return source.Do((value) => + { + FitModelAndTransformData(value.Item2, value.Item1); + }); + } +} diff --git a/src/Bonsai.ML.Pca.Torch/FitAndTransformBuilder.cs b/src/Bonsai.ML.Pca.Torch/FitAndTransformBuilder.cs new file mode 100644 index 00000000..2cb35f18 --- /dev/null +++ b/src/Bonsai.ML.Pca.Torch/FitAndTransformBuilder.cs @@ -0,0 +1,12 @@ +using System.ComponentModel; + +namespace Bonsai.ML.Pca.Torch; + +/// +/// Represents an operator that fits a PCA model and transforms the input data. 
+/// +[ResetCombinator] +[Combinator] +[Description("Fits a PCA model and transforms the input data.")] +[WorkflowElementCategory(ElementCategory.Transform)] +public class FitAndTransformBuilder() : PcaModelBuilder(new FitAndTransform()) { } diff --git a/src/Bonsai.ML.Pca.Torch/FitBuilder.cs b/src/Bonsai.ML.Pca.Torch/FitBuilder.cs new file mode 100644 index 00000000..5955fb2a --- /dev/null +++ b/src/Bonsai.ML.Pca.Torch/FitBuilder.cs @@ -0,0 +1,12 @@ +using System.ComponentModel; + +namespace Bonsai.ML.Pca.Torch; + +/// +/// Represents an operator that fits a PCA model to the input data. +/// +[ResetCombinator] +[Combinator] +[Description("Fits a PCA model to the input data.")] +[WorkflowElementCategory(ElementCategory.Sink)] +public class FitBuilder() : PcaModelBuilder(new Fit()) { } diff --git a/src/Bonsai.ML.Pca.Torch/IPcaBaseModel.cs b/src/Bonsai.ML.Pca.Torch/IPcaBaseModel.cs new file mode 100644 index 00000000..7759f109 --- /dev/null +++ b/src/Bonsai.ML.Pca.Torch/IPcaBaseModel.cs @@ -0,0 +1,78 @@ +using static TorchSharp.torch; + +namespace Bonsai.ML.Pca.Torch; + +/// +/// Defines the interface for PCA models. +/// +public interface IPcaBaseModel +{ + /// + /// Gets a value indicating whether the model has been fitted to data. + /// + public bool IsFitted { get; } + + /// + /// Gets the number of features in the fitted data. + /// + public int NumFeatures { get; } + + /// + /// Gets the principal components of the model. + /// + public Tensor Components { get; } + + /// + /// Gets the number of principal components kept by the model. + /// + public int NumComponents { get; } + + /// + /// Gets the device on which the model operates. + /// + public Device Device { get; } + + /// + /// Gets the data type used by the model. + /// + public ScalarType? ScalarType { get; } + + /// + /// Fits the PCA model to the given data. + /// + /// + /// The input data should be a 2D tensor with shape (samples x features). 
+ /// + /// + public void Fit(Tensor data); + + /// + /// Transforms the input data using the PCA model. + /// + /// + /// The input data should be a 2D tensor with shape (samples x features). + /// + /// + /// + public Tensor Transform(Tensor data); + + /// + /// Fits the PCA model to the given data and transforms it. + /// + /// + /// The input data should be a 2D tensor with shape (samples x features). + /// + /// + /// + public Tensor FitAndTransform(Tensor data); + + /// + /// Reconstructs the input data using the PCA model. + /// + /// + /// The input data should be a 2D tensor with shape (samples x features). + /// + /// + /// + public Tensor Reconstruct(Tensor data); +} diff --git a/src/Bonsai.ML.Pca.Torch/IPcaModelProvider.cs b/src/Bonsai.ML.Pca.Torch/IPcaModelProvider.cs new file mode 100644 index 00000000..512e15da --- /dev/null +++ b/src/Bonsai.ML.Pca.Torch/IPcaModelProvider.cs @@ -0,0 +1,12 @@ +namespace Bonsai.ML.Pca.Torch; + +/// +/// Defines an interface for PCA model providers. +/// +public interface IPcaModelProvider +{ + /// + /// Gets or sets the PCA model. + /// + public IPcaBaseModel? Model { get; set; } +} diff --git a/src/Bonsai.ML.Pca.Torch/OnlinePcaGha.cs b/src/Bonsai.ML.Pca.Torch/OnlinePcaGha.cs new file mode 100644 index 00000000..2240701a --- /dev/null +++ b/src/Bonsai.ML.Pca.Torch/OnlinePcaGha.cs @@ -0,0 +1,99 @@ +using static TorchSharp.torch; + +namespace Bonsai.ML.Pca.Torch; + +/// +/// Implements streaming/online PCA using the Generalized Hebbian Algorithm (GHA). +/// +/// +/// +/// +/// +/// +public class OnlinePcaGha( + int numComponents, + double learningRate = 0.1, + Device? device = null, + ScalarType? scalarType = null, + Generator? generator = null +) : PcaBaseModel(numComponents, device, scalarType) +{ + /// + /// Gets the number of samples that have been used to fit the model. + /// + public int SampleCount { get; private set; } = 0; + + /// + /// Gets the mean of the fitted data. 
+ /// + public Tensor Mean { get; private set; } = empty(0); + + /// + /// Gets or sets the learning rate. + /// + public double LearningRate { get; set; } = learningRate; + + /// + public override Tensor Components { get; protected set; } = empty(0); + + /// + /// Gets the random number generator used for initializing the model. + /// + public Generator? Generator { get; private set; } = generator; + + /// + public override void Fit(Tensor data) + { + base.Fit(data); + + var numSamples = data.size(0); + + using (no_grad()) + using (NewDisposeScope()) + { + // Initialize components randomly + if (Components.numel() == 0) + Components = randn([NumFeatures, NumComponents], dtype: ScalarType, device: Device, generator: Generator); + + if (Mean.numel() == 0) + Mean = data.mean([0], keepdim: true); + else + { + Mean *= (double)SampleCount / (SampleCount + numSamples); // cast to double: integer division truncated this weight to 0 and discarded the previous running mean + Mean += data.mean([0], keepdim: true) * numSamples / (SampleCount + numSamples); + } + + SampleCount += (int)numSamples; + var dataCentered = data - Mean; + + var projection = dataCentered.matmul(Components); + var hebbianTerm = dataCentered.T.matmul(projection); + var crossTerm = projection.T.matmul(projection); + var lowerTriangular = crossTerm.tril(0); + var correlation = Components.matmul(lowerTriangular); + var componentsUpdate = (hebbianTerm - correlation) * (LearningRate / numSamples); + var weights = Components + componentsUpdate; + var norms = weights.norm(dim: 0, keepdim: true, p: 2).clamp_min(1e-12); + + Components = linalg.qr(weights / norms, mode: linalg.QRMode.Reduced).Q.MoveToOuterDisposeScope(); + Mean = Mean.MoveToOuterDisposeScope(); + } + + IsFitted = true; + } + + /// + public override Tensor Transform(Tensor data) + { + base.Transform(data); + var dataCentered = data - Mean; + return dataCentered.matmul(Components); + } + + /// + public override Tensor Reconstruct(Tensor data) + { + base.Reconstruct(data); + return data.matmul(Components.T) + Mean; + } +} diff --git 
a/src/Bonsai.ML.Pca.Torch/OnlineProbabilisticPca.cs b/src/Bonsai.ML.Pca.Torch/OnlineProbabilisticPca.cs new file mode 100644 index 00000000..b5e251bb --- /dev/null +++ b/src/Bonsai.ML.Pca.Torch/OnlineProbabilisticPca.cs @@ -0,0 +1,271 @@ +using System; +using static TorchSharp.torch; +using static TorchSharp.torch.linalg; + +namespace Bonsai.ML.Pca.Torch; + +/// +/// Implements an online probabilistic PCA model using the stochastic online EM algorithm. +/// +public class OnlineProbabilisticPca : PcaBaseModel +{ + private Tensor _identityComponents = empty(0); + private Tensor _mx = empty(0); // E[x] + private Tensor _Cxz = empty(0); // E[xz^T] + private Tensor _mz = empty(0); // E[z] + private Tensor _Czz = empty(0); // E[zz^T] + private Tensor _sxx = empty(0); // E[||x||^2] + private readonly Func UpdateSchedule; + private int _stepCount = 0; + private readonly bool _reorthogonalize = false; + + /// + /// Rho is a constant learning rate parameter. + /// + /// + /// Rho must be in the range (0, 1). Only one of Rho or Kappa should be specified. + /// + public double? Rho { get; private set; } + + /// + /// Kappa is the exponent in the learning rate schedule. + /// + /// + /// Kappa must be in the range (0.5, 1]. Only one + /// of Rho or Kappa should be specified. + /// + public double? Kappa { get; private set; } + + /// + /// Gets the mean of the fitted data. + /// + public Tensor Means { get; private set; } = empty(0); + + /// + /// Gets the variance of the isotropic Gaussian noise model. + /// + public double Variance { get; private set; } + + /// + /// Gets the period for reorthogonalizing the principal components. + /// + /// + /// Represented as the number of update steps between reorthogonalization operations. + /// If not specified, reorthogonalization is not performed. + /// + public int ReorthogonalizePeriod { get; private set; } + + /// + /// Gets the sample offset used in the learning rate schedule when Kappa is specified. 
+ /// + public int SampleOffset { get; private set; } + + /// + public override Tensor Components { get; protected set; } = empty(0); + + /// + /// Gets the random number generator used for initializing the model. + /// + public Generator? Generator { get; private set; } + + /// + /// Initializes a new instance of the class. + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + public OnlineProbabilisticPca(int numComponents, + Device? device = null, + ScalarType? scalarType = null, + double initialVariance = 1.0, + Generator? generator = null, + double? rho = 0.1, + double? kappa = null, + int? sampleOffset = null, + int? reorthogonalizePeriod = null + ) : base(numComponents, + device, + scalarType) + { + if (initialVariance <= 0) + { + throw new ArgumentException("Starting variance must be greater than zero.", nameof(initialVariance)); + } + + if (kappa.HasValue && rho.HasValue) + { + throw new ArgumentException("Only one of rho or kappa should be specified, not both.", nameof(rho)); + } + + if (!kappa.HasValue && !rho.HasValue) + { + throw new ArgumentException("Either rho or kappa must be specified.", nameof(rho)); + } + + if (rho.HasValue) + { + if (rho.Value <= 0 || rho.Value >= 1) + { + throw new ArgumentException("Rho must be in the range (0, 1).", nameof(rho)); + } + + UpdateSchedule = () => rho.Value; + } + else + { + sampleOffset ??= 0; + if (sampleOffset < 0) + { + throw new ArgumentException("Sample offset must be a non-negative integer.", nameof(sampleOffset)); + } + + if (!kappa.HasValue) + { + throw new ArgumentException("Kappa must be specified when using a learning rate schedule.", nameof(kappa)); + } + + if (kappa <= 0.5 || kappa > 1) + { + throw new ArgumentException("Kappa must be in the range (0.5, 1].", nameof(kappa)); + } + + UpdateSchedule = () => Math.Pow(_stepCount + SampleOffset, -kappa.Value); + } + + if (reorthogonalizePeriod.HasValue) + { + _reorthogonalize = true; + ReorthogonalizePeriod = 
reorthogonalizePeriod.Value; + } + + Generator = generator; + Rho = rho; + Kappa = kappa; + SampleOffset = sampleOffset ?? 0; + Variance = initialVariance; + } + + /// + public override void Fit(Tensor data) + { + base.Fit(data); + + using (no_grad()) + using (NewDisposeScope()) + { + + _stepCount++; + var rho = UpdateSchedule(); + + // Initialize dimensions + var numSamples = data.size(0); + + // Initialize parameters + if (Means.numel() == 0) + { + Means = zeros(NumFeatures, device: Device, dtype: ScalarType).MoveToOuterDisposeScope(); + var weights = qr(randn(NumFeatures, NumComponents, generator: Generator, device: Device, dtype: ScalarType), mode: QRMode.Reduced).Q; + Components = (weights * Variance).MoveToOuterDisposeScope(); + _identityComponents = eye(NumComponents, device: Device, dtype: ScalarType).MoveToOuterDisposeScope(); + _mx = zeros(NumFeatures, device: Device, dtype: ScalarType).MoveToOuterDisposeScope(); // d + _Cxz = zeros(NumFeatures, NumComponents, device: Device, dtype: ScalarType).MoveToOuterDisposeScope(); // d x q + _mz = zeros(NumComponents, device: Device, dtype: ScalarType).MoveToOuterDisposeScope(); // q + _Czz = zeros(NumComponents, NumComponents, device: Device, dtype: ScalarType).MoveToOuterDisposeScope(); // q x q + _sxx = zeros(1, device: Device, dtype: ScalarType).MoveToOuterDisposeScope(); // scalar + } + + // Covariance matrix + var cov = _identityComponents * Variance; + + // Center data using current mean + var dataCentered = data - Means; + + // E-step + var M = Components.T.matmul(Components) + cov; + var MInv = Utils.InvertSPD(M, _identityComponents); + var projection = dataCentered.matmul(Components); + var EzT = Utils.InvertSPD(M, projection.T); + var Ez = EzT.T; + + // Update statistics + var mx = data.mean([0]); + var sxx = data.pow(2).sum(dim: 1).mean(); + var Cxz = data.T.matmul(Ez) / numSamples; + var mz = Ez.mean([0]); + var Czz = EzT.matmul(Ez) / numSamples + Variance * MInv; + + // Update parameters + var 
rhoFactor = 1 - rho; + _mx = (rhoFactor * _mx + rho * mx).MoveToOuterDisposeScope(); + _Cxz = (rhoFactor * _Cxz + rho * Cxz).MoveToOuterDisposeScope(); + _mz = (rhoFactor * _mz + rho * mz).MoveToOuterDisposeScope(); + _sxx = (rhoFactor * _sxx + rho * sxx).MoveToOuterDisposeScope(); + _Czz = (rhoFactor * _Czz + rho * Czz).MoveToOuterDisposeScope(); + + // Update mean + Means = _mx.MoveToOuterDisposeScope(); + + // Centered statistics + var Sxz = _Cxz - outer(Means, _mz); + var Szz = _Czz; + var Sxx = _sxx - Means.dot(Means); + + // M-step + var weightsUpdated = Utils.InvertSPD(Szz, Sxz.T).T; + + if (_reorthogonalize && + _stepCount % ReorthogonalizePeriod == 0) + { + var (U, S, Vh) = svd(weightsUpdated, fullMatrices: false); + var R = Vh.T; + weightsUpdated = U.matmul(diag(S)); + _Cxz = _Cxz.matmul(R.T); + _Czz = R.matmul(_Czz).matmul(R.T); + _mz = R.matmul(_mz); + } + + // Reorder components based on the strength of the components + var strength = sum(weightsUpdated * weightsUpdated, dim: 0); + var indices = argsort(strength, descending: true); + Components = weightsUpdated.index_select(1, indices).MoveToOuterDisposeScope(); + _Cxz = _Cxz.index_select(1, indices).MoveToOuterDisposeScope(); + _mz = _mz.index_select(0, indices).MoveToOuterDisposeScope(); + _Czz = _Czz.index_select(0, indices).index_select(1, indices).MoveToOuterDisposeScope(); + + Sxz = _Cxz - outer(Means, _mz); + Szz = _Czz; + + // Update variance + Variance = ((Sxx - 2 * trace(Components.T.matmul(Sxz)) + trace(Components.T.matmul(Components).matmul(Szz))) / (double)NumFeatures) + .clamp_min(0.0) + .to_type(TorchSharp.torch.ScalarType.Float64) + .item(); + } + + IsFitted = true; + } + + /// + public override Tensor Transform(Tensor data) + { + base.Transform(data); + var dataCentered = data - Means; + var M = Components.T.matmul(Components) + _identityComponents * Variance; + var projection = dataCentered.matmul(Components); + return Utils.InvertSPD(M, projection.T).T; + } + + /// + public override 
Tensor Reconstruct(Tensor data)
+    {
+        base.Reconstruct(data);
+        return data.matmul(Components.T) + Means;
+    }
+}
diff --git a/src/Bonsai.ML.Pca.Torch/Pca.cs b/src/Bonsai.ML.Pca.Torch/Pca.cs
new file mode 100644
index 00000000..a33c8ce4
--- /dev/null
+++ b/src/Bonsai.ML.Pca.Torch/Pca.cs
@@ -0,0 +1,73 @@
+using static TorchSharp.torch;
+using static TorchSharp.torch.linalg;
+
+namespace Bonsai.ML.Pca.Torch;
+
+/// <summary>
+/// Represents a standard Principal Component Analysis (PCA) model.
+/// </summary>
+/// <param name="numComponents">The number of principal components to keep.</param>
+/// <param name="device">The device recorded on the model; the base class falls back to the CPU when this is null.</param>
+/// <param name="scalarType">When set, the fitted tensors are converted to this scalar type.</param>
+public class Pca(int numComponents,
+    Device? device = null,
+    ScalarType? scalarType = null
+) : PcaBaseModel(numComponents,
+    device,
+    scalarType)
+{
+    /// <summary>
+    /// Gets the mean of the fitted data.
+    /// </summary>
+    public Tensor Mean { get; private set; } = empty(0);
+
+    /// <inheritdoc/>
+    public override Tensor Components { get; protected set; } = empty(0);
+
+    /// <summary>
+    /// The singular values of the fitted data.
+    /// </summary>
+    public Tensor SingularValues { get; private set; } = empty(0);
+
+    /// <inheritdoc/>
+    public override void Fit(Tensor data)
+    {
+        base.Fit(data);
+
+        using (no_grad())
+        using (NewDisposeScope())
+        {
+            // Center the data, then keep the leading right singular vectors as components.
+            var mean = data.mean([0], keepdim: true);
+            var dataCentered = data - mean;
+            var (_, S, Vh) = svd(dataCentered, fullMatrices: false);
+
+            // Transposed so that Components is (features x components).
+            var components = Vh.slice(0, 0, NumComponents, 1).T;
+            var singularValues = S.slice(0, 0, NumComponents, 1);
+
+            if (ScalarType is not null && ScalarType != data.dtype)
+            {
+                var scalarType = ScalarType.Value;
+                mean = mean.to(scalarType);
+                components = components.to(scalarType);
+                singularValues = singularValues.to(scalarType);
+            }
+
+            // Only the fitted parameters outlive the scope; every other temporary is released with it.
+            Mean = mean.MoveToOuterDisposeScope();
+            Components = components.MoveToOuterDisposeScope();
+            SingularValues = singularValues.MoveToOuterDisposeScope();
+        }
+
+        IsFitted = true;
+    }
+
+    /// <inheritdoc/>
+    public override Tensor Transform(Tensor data)
+    {
+        base.Transform(data);
+        var dataCentered = data - Mean;
+        return dataCentered.matmul(Components);
+    }
+
+    /// <inheritdoc/>
+    public override Tensor Reconstruct(Tensor data)
+    {
+        base.Reconstruct(data);
+        return data.matmul(Components.T) + Mean;
+    }
+}
diff --git a/src/Bonsai.ML.Pca.Torch/PcaBaseModel.cs b/src/Bonsai.ML.Pca.Torch/PcaBaseModel.cs
new file mode 100644
index 00000000..99ba852a
--- /dev/null
+++ b/src/Bonsai.ML.Pca.Torch/PcaBaseModel.cs
@@ -0,0 +1,114 @@
+using System;
+using System.Linq;
+using static TorchSharp.torch;
+
+namespace Bonsai.ML.Pca.Torch;
+
+/// <summary>
+/// Provides an abstract base class for PCA models.
+/// </summary>
+/// <remarks>
+/// The base implementations of <see cref="Fit"/>, <see cref="Transform"/> and <see cref="Reconstruct"/>
+/// only validate their input; derived classes call them first and then do the actual computation.
+/// </remarks>
+public abstract class PcaBaseModel : IPcaBaseModel
+{
+    /// <inheritdoc/>
+    public bool IsFitted { get; protected set; }
+
+    /// <inheritdoc/>
+    public int NumFeatures { get; protected set; } = -1;
+
+    /// <inheritdoc/>
+    public abstract Tensor Components { get; protected set; }
+
+    /// <inheritdoc/>
+    public int NumComponents { get; private set; }
+
+    /// <inheritdoc/>
+    public Device Device { get; private set; }
+
+    /// <inheritdoc/>
+    public ScalarType? ScalarType { get; private set; }
+
+    /// <summary>
+    /// Initializes a new instance of the <see cref="PcaBaseModel"/> class.
+    /// </summary>
+    /// <param name="numComponents">The number of components; must be greater than zero.</param>
+    /// <param name="device">The device of the model; the CPU is used when null.</param>
+    /// <param name="scalarType">The optional scalar type of the model.</param>
+    /// <exception cref="ArgumentException"><paramref name="numComponents"/> is zero or negative.</exception>
+    public PcaBaseModel(int numComponents,
+        Device? device = null,
+        ScalarType? scalarType = null)
+    {
+        if (numComponents <= 0)
+        {
+            throw new ArgumentException("Number of components must be greater than zero.", nameof(numComponents));
+        }
+
+        NumComponents = numComponents;
+        Device = device ?? CPU;
+        ScalarType = scalarType;
+    }
+
+    /// <inheritdoc/>
+    /// <remarks>
+    /// Validates the data and records its number of columns in <see cref="NumFeatures"/>.
+    /// </remarks>
+    public virtual void Fit(Tensor data)
+    {
+        CheckDataCompatibility(data);
+
+        var d = data.size(1);
+
+        if (NumComponents > d)
+            throw new ArgumentException($"Number of components cannot be greater than the number of features. Number of components: {NumComponents}, number of features: {d}.", nameof(data));
+
+        NumFeatures = (int)d;
+    }
+
+    /// <inheritdoc/>
+    /// <remarks>
+    /// Validates that the model is fitted and that the data has <see cref="NumFeatures"/> columns.
+    /// The base implementation returns the input unchanged.
+    /// </remarks>
+    public virtual Tensor Transform(Tensor data)
+    {
+        CheckFitted();
+        CheckDataCompatibility(data);
+        CheckDataFeatures(data);
+        return data;
+    }
+
+    /// <inheritdoc/>
+    public virtual Tensor FitAndTransform(Tensor data)
+    {
+        Fit(data);
+        return Transform(data);
+    }
+
+    /// <inheritdoc/>
+    /// <remarks>
+    /// Validates that the model is fitted and that the data has <see cref="NumComponents"/> columns,
+    /// i.e. that it lives in the space produced by <see cref="Transform"/>.
+    /// The base implementation returns the input unchanged.
+    /// </remarks>
+    public virtual Tensor Reconstruct(Tensor data)
+    {
+        CheckFitted();
+        CheckDataCompatibility(data);
+        // Derived classes compute data.matmul(Components.T), so the input has one column per
+        // component. Checking against NumFeatures rejected every valid input whenever
+        // NumComponents was smaller than NumFeatures.
+        CheckDataComponents(data);
+        return data;
+    }
+
+    // Throws when Fit has not completed yet.
+    private void CheckFitted()
+    {
+        if (!IsFitted)
+            throw new InvalidOperationException("Model has not yet been fitted. You should call one of the Fit() or the FitAndTransform() methods first.");
+    }
+
+    // Throws unless data is a non-null, non-empty, two-dimensional tensor.
+    private void CheckDataCompatibility(Tensor data)
+    {
+        if (data is null)
+            throw new ArgumentNullException(nameof(data));
+
+        if (data.NumberOfElements == 0)
+            throw new ArgumentException("Data must be a non-empty 2D tensor with shape (samples x features).", nameof(data));
+
+        if (data.dim() != 2)
+        {
+            var shapeStr = string.Join(",", data.shape.Select(x => x.ToString()).ToArray());
+            throw new ArgumentException($"Data must be a 2D tensor with shape (samples x features). Data shape: {shapeStr}.", nameof(data));
+        }
+    }
+
+    // Throws when the number of columns differs from the number of features seen by Fit.
+    private void CheckDataFeatures(Tensor data)
+    {
+        var d = data.size(1);
+
+        if (d != NumFeatures)
+            throw new ArgumentException("The number of features in the data does not match the number of features in the fitted model.", nameof(data));
+    }
+
+    // Throws when the number of columns differs from the number of components of the model.
+    private void CheckDataComponents(Tensor data)
+    {
+        var k = data.size(1);
+
+        if (k != NumComponents)
+            throw new ArgumentException($"The number of columns in the data ({k}) does not match the number of components in the fitted model ({NumComponents}).", nameof(data));
+    }
+}
diff --git a/src/Bonsai.ML.Pca.Torch/PcaDesciptionProvider.cs b/src/Bonsai.ML.Pca.Torch/PcaDesciptionProvider.cs
new file mode 100644
index 00000000..39b50eb0
--- /dev/null
+++ b/src/Bonsai.ML.Pca.Torch/PcaDesciptionProvider.cs
@@ -0,0 +1,26 @@
+using System;
+using System.ComponentModel;
+
+namespace Bonsai.ML.Pca.Torch;
+
+/// <summary>
+/// Provides a custom type description provider for PCA models.
+/// </summary>
+///
+/// Initializes a new instance of the class.
+///
+/// <param name="baseProvider">The provider whose type descriptors are wrapped.</param>
+public class PcaDescriptionProvider(TypeDescriptionProvider baseProvider) : TypeDescriptionProvider(baseProvider)
+{
+    /// <summary>
+    /// Initializes a new instance of the <see cref="PcaDescriptionProvider"/> class on top of the
+    /// provider registered for <see cref="object"/>.
+    /// </summary>
+    public PcaDescriptionProvider() : this(TypeDescriptor.GetProvider(typeof(object))) { }
+
+    /// <inheritdoc/>
+    /// <remarks>Wraps the descriptor of the base provider in a <see cref="PcaDescriptor"/>.</remarks>
+    public override ICustomTypeDescriptor GetTypeDescriptor(Type objectType, object instance)
+        => new PcaDescriptor(base.GetTypeDescriptor(objectType, instance), instance);
+}
\ No newline at end of file
diff --git a/src/Bonsai.ML.Pca.Torch/PcaDescriptor.cs b/src/Bonsai.ML.Pca.Torch/PcaDescriptor.cs
new file mode 100644
index 00000000..307c4d8f
--- /dev/null
+++ b/src/Bonsai.ML.Pca.Torch/PcaDescriptor.cs
@@ -0,0 +1,44 @@
+using System;
+using System.ComponentModel;
+using System.Reactive.Linq;
+using System.Collections.Generic;
+using System.Linq;
+
+namespace Bonsai.ML.Pca.Torch;
+
+/// <summary>
+/// Provides a custom type descriptor for PCA model creation.
+/// </summary>
+///
+/// Initializes a new instance of the class.
+///
+/// <param name="parent">The descriptor that supplies the full property list.</param>
+/// <param name="instance">The object being described; returned as the owner of every property.</param>
+public class PcaDescriptor(ICustomTypeDescriptor parent, object instance) : CustomTypeDescriptor(parent)
+{
+    private readonly object _instance = instance;
+
+    /// <inheritdoc/>
+    /// <remarks>
+    /// When the described object is a <see cref="CreatePca"/>, only the properties whose names are
+    /// returned by its <c>GetModelProperties()</c> are kept; otherwise the list is returned unfiltered.
+    /// </remarks>
+    public override PropertyDescriptorCollection GetProperties(Attribute[] attributes)
+    {
+        var allProperties = base.GetProperties(attributes);
+
+        if (_instance is CreatePca createPca)
+        {
+            // Hash set so that the per-property membership test stays O(1).
+            var modelProperties = new HashSet<string>(createPca.GetModelProperties());
+            var filtered = allProperties.Cast<PropertyDescriptor>()
+                .Where(p => modelProperties.Contains(p.Name))
+                .ToArray();
+            return new PropertyDescriptorCollection(filtered);
+        }
+
+        return allProperties;
+    }
+
+    /// <inheritdoc/>
+    public override PropertyDescriptorCollection GetProperties()
+        => GetProperties([]);
+
+    /// <inheritdoc/>
+    public override object GetPropertyOwner(PropertyDescriptor pd) => _instance;
+}
diff --git a/src/Bonsai.ML.Pca.Torch/PcaModelBuilder.cs b/src/Bonsai.ML.Pca.Torch/PcaModelBuilder.cs
new file mode 100644
index 00000000..8ef69b25
--- /dev/null
+++ b/src/Bonsai.ML.Pca.Torch/PcaModelBuilder.cs
@@ -0,0 +1,94 @@
+using System;
+using System.Collections.Generic;
+using System.ComponentModel;
+using System.Linq;
+using System.Linq.Expressions;
+using System.Reflection;
+using System.Xml.Serialization;
+using Bonsai.Expressions;
+using static TorchSharp.torch;
+
+namespace Bonsai.ML.Pca.Torch;
+
+/// <summary>
+/// Represents an abstract PCA model builder.
+/// </summary>
+/// <typeparam name="T">The operator type whose <c>Process</c> overloads are bound by <see cref="Build"/>.</typeparam>
+/// <param name="operatorInstance">The operator instance on which the selected <c>Process</c> overload is invoked.</param>
+public abstract class PcaModelBuilder<T>(IPcaModelProvider operatorInstance) : SingleArgumentExpressionBuilder, ICustomTypeDescriptor
+{
+    private readonly IPcaModelProvider _operator = operatorInstance;
+
+    /// <summary>
+    /// The PCA model.
+    /// </summary>
+    [XmlIgnore]
+    [Description("The PCA model.")]
+    public IPcaBaseModel? Model
+    {
+        get => _operator.Model;
+        set => _operator.Model = value;
+    }
+
+    /// <summary>
+    /// Gets a value indicating whether the model has to be supplied through the <see cref="Model"/>
+    /// property. It is <c>true</c> when the input is not a tuple, in which case the property is
+    /// exposed, and <c>false</c> when the input is a sequence of tuples that already carry the
+    /// model, in which case the property is hidden.
+    /// </summary>
+    [Browsable(false)]
+    [XmlIgnore]
+    [RefreshProperties(RefreshProperties.All)]
+    public bool HasModel { get; private set; } = true;
+
+    /// <inheritdoc/>
+    /// <exception cref="ArgumentException">No input expression was supplied.</exception>
+    /// <exception cref="InvalidOperationException">
+    /// The input is a tuple that does not contain a PCA model, or the operator has no
+    /// <c>Process</c> overload for the input type.
+    /// </exception>
+    public override Expression Build(IEnumerable<Expression> arguments)
+    {
+        var input = arguments?.FirstOrDefault();
+        if (input is null)
+        {
+            // Previously a null expression was handed to Expression.Call, which failed with an
+            // ArgumentNullException that said nothing about the missing input.
+            HasModel = true;
+            throw new ArgumentException("The operator requires an input sequence of tensors, or of tuples containing a PCA model and a tensor.", nameof(arguments));
+        }
+
+        var obsType = input.Type;
+        var t = obsType.GetGenericArguments().Single();
+
+        // Any System.Tuple arity is treated as "model arrives with the input"; the exact shape is
+        // validated just below. Ordinal comparison: this is a type name, not linguistic text.
+        HasModel = !(t.IsGenericType && t.FullName?.StartsWith("System.Tuple`", StringComparison.Ordinal) == true);
+
+        if (!HasModel)
+        {
+            var args = t.GetGenericArguments();
+            if (args.Length != 2 || (!typeof(IPcaBaseModel).IsAssignableFrom(args[0]) && !typeof(IPcaBaseModel).IsAssignableFrom(args[1])))
+                throw new InvalidOperationException("The input type is not valid. Expected an observable sequence of tuples containing a PCA model and a tensor.");
+        }
+
+        var parameterType = HasModel
+            ? typeof(IObservable<Tensor>)
+            : typeof(IObservable<>).MakeGenericType(t);
+
+        // GetMethod returns null when no overload matches; report that instead of passing null on.
+        var processMethod = typeof(T).GetMethod("Process", [parameterType])
+            ?? throw new InvalidOperationException($"The operator does not define a Process overload for an input of type {obsType}.");
+
+        return Expression.Call(Expression.Constant(_operator), processMethod, [input]);
+    }
+
+
+    // Hides the Model property whenever the model is taken from the input tuples.
+    PropertyDescriptorCollection ICustomTypeDescriptor.GetProperties(Attribute[]? attributes)
+    {
+        var props = TypeDescriptor.GetProperties(this, attributes, true);
+        if (HasModel) return props;
+
+        var filtered = props.Cast<PropertyDescriptor>()
+            .Where(p => p.Name != nameof(Model))
+            .ToArray();
+
+        return new PropertyDescriptorCollection(filtered);
+    }
+
+    // The remaining members forward to the default TypeDescriptor implementation. The final
+    // 'true' argument (noCustomTypeDesc) skips this custom descriptor and so avoids recursion.
+    AttributeCollection ICustomTypeDescriptor.GetAttributes() => TypeDescriptor.GetAttributes(this, true);
+    string? ICustomTypeDescriptor.GetClassName() => TypeDescriptor.GetClassName(this, true);
+    string? ICustomTypeDescriptor.GetComponentName() => TypeDescriptor.GetComponentName(this, true);
+    TypeConverter ICustomTypeDescriptor.GetConverter() => TypeDescriptor.GetConverter(this, true);
+    EventDescriptor? ICustomTypeDescriptor.GetDefaultEvent() => TypeDescriptor.GetDefaultEvent(this, true);
+    PropertyDescriptor? ICustomTypeDescriptor.GetDefaultProperty() => TypeDescriptor.GetDefaultProperty(this, true);
+    object? ICustomTypeDescriptor.GetEditor(Type editorBaseType) => TypeDescriptor.GetEditor(this, editorBaseType, true);
+    EventDescriptorCollection ICustomTypeDescriptor.GetEvents(Attribute[]? attributes) => TypeDescriptor.GetEvents(this, attributes, true);
+    EventDescriptorCollection ICustomTypeDescriptor.GetEvents() => TypeDescriptor.GetEvents(this, true);
+    PropertyDescriptorCollection ICustomTypeDescriptor.GetProperties() => ((ICustomTypeDescriptor)this).GetProperties(Array.Empty<Attribute>());
+    object ICustomTypeDescriptor.GetPropertyOwner(PropertyDescriptor? pd) => this;
+}
diff --git a/src/Bonsai.ML.Pca.Torch/PcaModelType.cs b/src/Bonsai.ML.Pca.Torch/PcaModelType.cs
new file mode 100644
index 00000000..d8194720
--- /dev/null
+++ b/src/Bonsai.ML.Pca.Torch/PcaModelType.cs
@@ -0,0 +1,27 @@
+namespace Bonsai.ML.Pca.Torch;
+
+/// <summary>
+/// Specifies the type of PCA model.
+/// </summary>
+public enum PcaModelType
+{
+    /// <summary>
+    /// Standard PCA model.
+    /// </summary>
+    Pca,
+
+    /// <summary>
+    /// Probabilistic PCA model.
+    /// </summary>
+    ProbabilisticPca,
+
+    /// <summary>
+    /// Online Probabilistic PCA model.
+    /// </summary>
+    OnlineProbabilisticPca,
+
+    /// <summary>
+    /// Online PCA model using the Generalized Hebbian Algorithm.
+    /// </summary>
+    OnlinePcaGha
+}
diff --git a/src/Bonsai.ML.Pca.Torch/ProbabalisticPca.cs b/src/Bonsai.ML.Pca.Torch/ProbabalisticPca.cs
new file mode 100644
index 00000000..843298f2
--- /dev/null
+++ b/src/Bonsai.ML.Pca.Torch/ProbabalisticPca.cs
@@ -0,0 +1,181 @@
+using System;
+using static TorchSharp.torch;
+using static TorchSharp.torch.linalg;
+
+namespace Bonsai.ML.Pca.Torch;
+
+///
+/// Represents a probabilistic PCA model.
+///
+public class ProbabilisticPca : PcaBaseModel
+{
+    private readonly int _iterations;
+    private readonly double _tolerance;
+
+    /// <summary>
+    /// Gets the mean of the fitted data.
+    /// </summary>
+    public Tensor Mean { get; private set; } = empty(0);
+
+    /// <summary>
+    /// Gets the variance of the isotropic Gaussian noise model.
+    /// </summary>
+    public double Variance { get; private set; }
+
+    /// <summary>
+    /// Gets the log likelihood of the fitted model, one value per executed iteration.
+    /// </summary>
+    public Tensor LogLikelihood { get; private set; } = empty(0);
+
+    /// <inheritdoc/>
+    public override Tensor Components { get; protected set; } = empty(0);
+
+    /// <summary>
+    /// Gets the random number generator used for initializing the model.
+    /// </summary>
+    public Generator? Generator { get; private set; }
+
+    /// <summary>
+    /// Initializes a new instance of the <see cref="ProbabilisticPca"/> class.
+    /// </summary>
+    /// <param name="numComponents">The number of components; must be greater than zero.</param>
+    /// <param name="device">The device on which the model tensors are created.</param>
+    /// <param name="scalarType">The scalar type with which the model tensors are created.</param>
+    /// <param name="initialVariance">The noise variance used in the first iteration; must not be negative.</param>
+    /// <param name="generator">The optional random number generator used to draw the initial weights.</param>
+    /// <param name="iterations">The maximum number of iterations; must be greater than zero.</param>
+    /// <param name="tolerance">The convergence threshold for the weights and the variance; must be greater than zero.</param>
+    /// <exception cref="ArgumentException">One of the arguments is outside its valid range.</exception>
+    public ProbabilisticPca(int numComponents,
+        Device? device = null,
+        ScalarType? scalarType = null,
+        double initialVariance = 1.0,
+        Generator? generator = null,
+        int iterations = 100,
+        double tolerance = 1e-5
+    ) : base(numComponents,
+        device,
+        scalarType)
+    {
+        if (initialVariance < 0)
+        {
+            throw new ArgumentException("Starting variance must be greater than or equal to zero.", nameof(initialVariance));
+        }
+
+        if (iterations <= 0)
+        {
+            throw new ArgumentException("Number of iterations must be greater than zero.", nameof(iterations));
+        }
+
+        if (tolerance <= 0)
+        {
+            throw new ArgumentException("Tolerance must be greater than zero.", nameof(tolerance));
+        }
+
+        Variance = initialVariance;
+        Generator = generator;
+        _iterations = iterations;
+        _tolerance = tolerance;
+    }
+
+    /// <inheritdoc/>
+    /// <remarks>
+    /// Alternates E- and M-steps for at most the configured number of iterations and stops early
+    /// once the norm of the change in the weights and the change in the variance are both below
+    /// the tolerance.
+    /// </remarks>
+    public override void Fit(Tensor data)
+    {
+        base.Fit(data);
+
+        using (no_grad())
+        using (NewDisposeScope())
+        {
+            var numSamples = data.size(0);
+
+            // Kept in a local until the fit has finished, so that a failure part-way through
+            // cannot leave the LogLikelihood property pointing at a tensor owned by this scope.
+            var logLikelihoodHistory = ones(_iterations, device: Device, dtype: ScalarType) * double.NegativeInfinity;
+
+            var weights = randn(NumFeatures, NumComponents, generator: Generator, device: Device, dtype: ScalarType);
+            var identityComponents = eye(NumComponents, device: Device, dtype: ScalarType);
+            var identityFeatures = eye(NumFeatures, device: Device, dtype: ScalarType);
+
+            var mean = data.mean([0], keepdim: true);
+            var dataCentered = data - mean;
+
+            // Calculate the sample covariance
+            var covarianceTerm = dataCentered.T.matmul(dataCentered);
+            var sampleCov = covarianceTerm / numSamples;
+
+            // Calculate term 1 for variance update
+            var term1 = trace(covarianceTerm);
+
+            // Compute log likelihood constant
+            var logLikelihoodConst = NumFeatures * log(2 * Math.PI).to(Device);
+
+            double diffWeights;
+            double diffVariance;
+
+            // Repeat until convergence
+            for (int i = 0; i < _iterations; i++)
+            {
+                // E-step: Compute the posterior distribution of the latent variables
+                var M = weights.T.matmul(weights) + identityComponents * Variance;
+                var MInv = inv(M);
+                var mu = MInv.matmul(weights.T).matmul(dataCentered.T).T;
+                var SSum = numSamples * MInv * Variance;
+                var cov = mu.T.matmul(mu) + SSum;
+
+                // M-step: Compute new weights and new variance
+                var dataMu = dataCentered.T.matmul(mu);
+                var weightsNew = dataMu.matmul(inv(cov));
+
+                var term2 = 2 * dataMu.mul(weightsNew).sum();
+                var weightsNew2 = weightsNew.T.matmul(weightsNew);
+                // 'cov' already holds mu.T·mu + SSum, so it is reused instead of being rebuilt.
+                var term3 = trace(weightsNew2.matmul(cov));
+                var varianceNew = (term1 - term2 + term3) / (numSamples * NumFeatures);
+
+                // Compute the log likelihood. The identity created above is used so that it
+                // shares the device and scalar type of the weights; a bare eye(NumFeatures) is
+                // always built with the library defaults.
+                var logLikelihoodTerm = weightsNew.matmul(weightsNew.T) + identityFeatures * varianceNew;
+                var logLikelihoodTermInv = inv(logLikelihoodTerm);
+                var logLikelihood = -0.5 * numSamples * (logLikelihoodConst + logdet(logLikelihoodTerm) + trace(logLikelihoodTermInv.matmul(sampleCov)));
+
+                // Compare previous and new parameters for convergence
+                diffWeights = linalg.norm(weightsNew - weights).to_type(TorchSharp.torch.ScalarType.Float64).item<double>();
+                diffVariance = abs(varianceNew - Variance).to_type(TorchSharp.torch.ScalarType.Float64).item<double>();
+
+                // Update loglikelihood, weights and variance
+                logLikelihoodHistory[i] = logLikelihood;
+                weights = weightsNew;
+                Variance = varianceNew.to_type(TorchSharp.torch.ScalarType.Float64).item<double>();
+
+                // Check for convergence
+                if (diffWeights < _tolerance && diffVariance < _tolerance)
+                {
+                    // Drop the slots of the iterations that were never executed.
+                    logLikelihoodHistory = logLikelihoodHistory.slice(0, 0, i + 1, 1);
+                    break;
+                }
+            }
+
+            // Finalize model parameters
+            LogLikelihood = logLikelihoodHistory.MoveToOuterDisposeScope();
+            Components = weights.MoveToOuterDisposeScope();
+            Mean = mean.MoveToOuterDisposeScope();
+        }
+
+        IsFitted = true;
+    }
+
+    /// <inheritdoc/>
+    public override Tensor Transform(Tensor data)
+    {
+        base.Transform(data);
+
+        // Temporaries are released with the scope; only the result is handed to the caller.
+        using (NewDisposeScope())
+        {
+            var dataCentered = data - Mean;
+
+            // The identity must live on the same device and have the same scalar type as the
+            // fitted components; the default CPU/Float32 identity breaks the solve otherwise.
+            var identity = eye(NumComponents, device: Components.device, dtype: Components.dtype);
+            var M = Components.T.matmul(Components) + identity * Variance;
+            var MInv = Utils.InvertSPD(M, identity, device: Components.device, scalarType: Components.dtype);
+            return dataCentered.matmul(Components).matmul(MInv).MoveToOuterDisposeScope();
+        }
+    }
+
+    /// <inheritdoc/>
+    public override Tensor Reconstruct(Tensor data)
+    {
+        base.Reconstruct(data);
+        return data.matmul(Components.T) + Mean;
+    }
+}
diff --git a/src/Bonsai.ML.Pca.Torch/Properties/launchSettings.json b/src/Bonsai.ML.Pca.Torch/Properties/launchSettings.json
new file mode 100644
index 00000000..4af4f468
--- /dev/null
+++ b/src/Bonsai.ML.Pca.Torch/Properties/launchSettings.json
@@ -0,0 +1,10 @@
+{
+  "profiles": {
+    "Bonsai": {
+      "commandName": "Executable",
+      "executablePath": "$(BonsaiExecutablePath)",
+      "commandLineArgs": "--lib:\"$(TargetDir).\"",
+      "nativeDebugging": true
+    }
+  }
+}
\ No newline at end of file
diff --git a/src/Bonsai.ML.Pca.Torch/Reconstruct.cs b/src/Bonsai.ML.Pca.Torch/Reconstruct.cs
new file mode 100644
index 00000000..2701da28
--- /dev/null
+++ b/src/Bonsai.ML.Pca.Torch/Reconstruct.cs
@@ -0,0 +1,142 @@
+using System;
+using System.ComponentModel;
+using System.Reactive.Linq;
+using System.Xml.Serialization;
+using static TorchSharp.torch;
+
+namespace Bonsai.ML.Pca.Torch;
+
+///
+/// Reconstructs the input data using a PCA model.
+///
+public class Reconstruct : IPcaModelProvider
+{
+    // NOTE(review): the type arguments of IObservable and of the tuple parameters are missing
+    // from the signatures below as they appear in this listing; confirm them against the
+    // original file before relying on the exact overload set.
+
+    /// <inheritdoc/>
+    public IPcaBaseModel? Model { get; set; }
+
+    // Single place where every overload hands the data to the model.
+    private static Tensor ReconstructData(IPcaBaseModel model, Tensor data)
+    {
+        return model.Reconstruct(data);
+    }
+
+    /// <summary>
+    /// Reconstructs the input data using the specified PCA model.
+    /// </summary>
+    /// <param name="source">The sequence of data to reconstruct.</param>
+    /// <returns>The sequence of results returned by the model for each input value.</returns>
+    /// <exception cref="InvalidOperationException">The Model property is null when this method is called.</exception>
+    public IObservable Process(IObservable source)
+    {
+        // Checked once, when the sequence is built; the property is read again for every value.
+        if (Model == null)
+        {
+            throw new InvalidOperationException("The PCA model has not been specified.");
+        }
+        return source.Select(value =>
+        {
+            return ReconstructData(Model, value);
+        });
+    }
+
+    /// <summary>
+    /// Reconstructs the input data using a standard PCA model.
+    /// </summary>
+    /// <param name="source">A sequence of tuples holding the model in the first item and the data in the second.</param>
+    /// <returns>The sequence of results returned by the model for each tuple.</returns>
+    public IObservable Process(IObservable> source)
+    {
+        return source.Select(value =>
+        {
+            return ReconstructData(value.Item1, value.Item2);
+        });
+    }
+
+    /// <summary>
+    /// Reconstructs the input data using a standard PCA model.
+    /// </summary>
+    /// <param name="source">A sequence of tuples holding the data in the first item and the model in the second.</param>
+    /// <returns>The sequence of results returned by the model for each tuple.</returns>
+    public IObservable Process(IObservable> source)
+    {
+        return source.Select(value =>
+        {
+            return ReconstructData(value.Item2, value.Item1);
+        });
+    }
+
+    /// <summary>
+    /// Reconstructs the input data using a probabilistic PCA model.
+    /// </summary>
+    /// <param name="source">A sequence of tuples holding the model in the first item and the data in the second.</param>
+    /// <returns>The sequence of results returned by the model for each tuple.</returns>
+    public IObservable Process(IObservable> source)
+    {
+        return source.Select(value =>
+        {
+            return ReconstructData(value.Item1, value.Item2);
+        });
+    }
+
+    /// <summary>
+    /// Reconstructs the input data using a probabilistic PCA model.
+    /// </summary>
+    /// <param name="source">A sequence of tuples holding the data in the first item and the model in the second.</param>
+    /// <returns>The sequence of results returned by the model for each tuple.</returns>
+    public IObservable Process(IObservable> source)
+    {
+        return source.Select(value =>
+        {
+            return ReconstructData(value.Item2, value.Item1);
+        });
+    }
+
+    /// <summary>
+    /// Reconstructs the input data using an online probabilistic PCA model based on stochastic online EM.
+    /// </summary>
+    /// <param name="source">A sequence of tuples holding the model in the first item and the data in the second.</param>
+    /// <returns>The sequence of results returned by the model for each tuple.</returns>
+    public IObservable Process(IObservable> source)
+    {
+        return source.Select(value =>
+        {
+            return ReconstructData(value.Item1, value.Item2);
+        });
+    }
+
+    /// <summary>
+    /// Reconstructs the input data using an online probabilistic PCA model based on stochastic online EM.
+    /// </summary>
+    /// <param name="source">A sequence of tuples holding the data in the first item and the model in the second.</param>
+    /// <returns>The sequence of results returned by the model for each tuple.</returns>
+    public IObservable Process(IObservable> source)
+    {
+        return source.Select(value =>
+        {
+            return ReconstructData(value.Item2, value.Item1);
+        });
+    }
+
+    /// <summary>
+    /// Reconstructs the input data using an online PCA model based on the Generalized Hebbian Algorithm.
+    /// </summary>
+    /// <param name="source">A sequence of tuples holding the model in the first item and the data in the second.</param>
+    /// <returns>The sequence of results returned by the model for each tuple.</returns>
+    public IObservable Process(IObservable> source)
+    {
+        return source.Select(value =>
+        {
+            return ReconstructData(value.Item1, value.Item2);
+        });
+    }
+
+    /// <summary>
+    /// Reconstructs the input data using an online PCA model based on the Generalized Hebbian Algorithm.
+    /// </summary>
+    /// <param name="source">A sequence of tuples holding the data in the first item and the model in the second.</param>
+    /// <returns>The sequence of results returned by the model for each tuple.</returns>
+    public IObservable Process(IObservable> source)
+    {
+        return source.Select(value =>
+        {
+            return ReconstructData(value.Item2, value.Item1);
+        });
+    }
+}
diff --git a/src/Bonsai.ML.Pca.Torch/ReconstructBuilder.cs b/src/Bonsai.ML.Pca.Torch/ReconstructBuilder.cs
new file mode 100644
index 00000000..4938864f
--- /dev/null
+++ b/src/Bonsai.ML.Pca.Torch/ReconstructBuilder.cs
@@ -0,0 +1,12 @@
+using System.ComponentModel;
+
+namespace Bonsai.ML.Pca.Torch;
+
+/// <summary>
+/// Represents an operator that reconstructs the input data using a PCA model.
+/// </summary>
+[ResetCombinator]
+[Combinator]
+[Description("Reconstructs the input data using a PCA model.")]
+[WorkflowElementCategory(ElementCategory.Transform)]
+public class ReconstructBuilder() : PcaModelBuilder(new Reconstruct()) { }
diff --git a/src/Bonsai.ML.Pca.Torch/Transform.cs b/src/Bonsai.ML.Pca.Torch/Transform.cs
new file mode 100644
index 00000000..016b1346
--- /dev/null
+++ b/src/Bonsai.ML.Pca.Torch/Transform.cs
@@ -0,0 +1,140 @@
+using System;
+using System.Reactive.Linq;
+using static TorchSharp.torch;
+
+namespace Bonsai.ML.Pca.Torch;
+
+/// <summary>
+/// Transforms the input data using a PCA model.
+/// </summary>
+public class Transform : IPcaModelProvider
+{
+    // NOTE(review): the type arguments of IObservable and of the tuple parameters are missing
+    // from the signatures below as they appear in this listing; confirm them against the
+    // original file before relying on the exact overload set.
+
+    /// <inheritdoc/>
+    public IPcaBaseModel? Model { get; set; } = null;
+
+    // Single place where every overload hands the data to the model.
+    private static Tensor TransformData(IPcaBaseModel model, Tensor data)
+    {
+        return model.Transform(data);
+    }
+
+    /// <summary>
+    /// Transforms the input data.
+    /// </summary>
+    /// <param name="source">The sequence of data to transform.</param>
+    /// <returns>The sequence of results returned by the model for each input value.</returns>
+    /// <exception cref="InvalidOperationException">The Model property is null when this method is called.</exception>
+    public IObservable Process(IObservable source)
+    {
+        // Checked once, when the sequence is built; the property is read again for every value.
+        if (Model == null)
+        {
+            throw new InvalidOperationException("The PCA model has not been specified.");
+        }
+        return source.Select(value =>
+        {
+            return TransformData(Model, value);
+        });
+    }
+
+    /// <summary>
+    /// Transforms the input data using a standard PCA model.
+    /// </summary>
+    /// <param name="source">A sequence of tuples holding the model in the first item and the data in the second.</param>
+    /// <returns>The sequence of results returned by the model for each tuple.</returns>
+    public IObservable Process(IObservable> source)
+    {
+        return source.Select(value =>
+        {
+            return TransformData(value.Item1, value.Item2);
+        });
+    }
+
+    /// <summary>
+    /// Transforms the input data using a standard PCA model.
+    /// </summary>
+    /// <param name="source">A sequence of tuples holding the data in the first item and the model in the second.</param>
+    /// <returns>The sequence of results returned by the model for each tuple.</returns>
+    public IObservable Process(IObservable> source)
+    {
+        return source.Select(value =>
+        {
+            return TransformData(value.Item1, value.Item2);
+        });
+    }
+
+    /// <summary>
+    /// Transforms the input data using a probabilistic PCA model.
+    /// </summary>
+    /// <remarks>
+    /// NOTE(review): the summary used to say "standard PCA model", but this overload takes the
+    /// model first exactly like the standard overload above, so it has to accept a different
+    /// model type; it is paired with the probabilistic overload that follows. Confirm against
+    /// the original signature.
+    /// </remarks>
+    /// <param name="source">A sequence of tuples holding the model in the first item and the data in the second.</param>
+    /// <returns>The sequence of results returned by the model for each tuple.</returns>
+    public IObservable Process(IObservable> source)
+    {
+        return source.Select(value =>
+        {
+            return TransformData(value.Item1, value.Item2);
+        });
+    }
+
+    /// <summary>
+    /// Transforms the input data using a probabilistic PCA model.
+    /// </summary>
+    /// <param name="source">A sequence of tuples holding the data in the first item and the model in the second.</param>
+    /// <returns>The sequence of results returned by the model for each tuple.</returns>
+    public IObservable Process(IObservable> source)
+    {
+        return source.Select(value =>
+        {
+            return TransformData(value.Item2, value.Item1);
+        });
+    }
+
+    /// <summary>
+    /// Transforms the input data using an online probabilistic PCA model.
+    /// </summary>
+    /// <param name="source">A sequence of tuples holding the model in the first item and the data in the second.</param>
+    /// <returns>The sequence of results returned by the model for each tuple.</returns>
+    public IObservable Process(IObservable> source)
+    {
+        return source.Select(value =>
+        {
+            return TransformData(value.Item1, value.Item2);
+        });
+    }
+
+    /// <summary>
+    /// Transforms the input data using an online probabilistic PCA model.
+    /// </summary>
+    /// <param name="source">A sequence of tuples holding the data in the first item and the model in the second.</param>
+    /// <returns>The sequence of results returned by the model for each tuple.</returns>
+    public IObservable Process(IObservable> source)
+    {
+        return source.Select(value =>
+        {
+            return TransformData(value.Item2, value.Item1);
+        });
+    }
+
+    ///
+    /// Transforms the input data using an online PCA model based on the Generalized Hebbian Algorithm.
+    ///
+    /// <param name="source">A sequence of tuples holding the model in the first item and the data in the second.</param>
+    /// <returns>The sequence of results returned by the model for each tuple.</returns>
+    public IObservable Process(IObservable> source)
+    {
+        return source.Select(value =>
+        {
+            return TransformData(value.Item1, value.Item2);
+        });
+    }
+
+    /// <summary>
+    /// Transforms the input data using an online PCA model based on the Generalized Hebbian Algorithm.
+    /// </summary>
+    /// <param name="source">A sequence of tuples holding the data in the first item and the model in the second.</param>
+    /// <returns>The sequence of results returned by the model for each tuple.</returns>
+    public IObservable Process(IObservable> source)
+    {
+        return source.Select(value =>
+        {
+            return TransformData(value.Item2, value.Item1);
+        });
+    }
+}
diff --git a/src/Bonsai.ML.Pca.Torch/TransformBuilder.cs b/src/Bonsai.ML.Pca.Torch/TransformBuilder.cs
new file mode 100644
index 00000000..ee4ca8ed
--- /dev/null
+++ b/src/Bonsai.ML.Pca.Torch/TransformBuilder.cs
@@ -0,0 +1,12 @@
+using System.ComponentModel;
+
+namespace Bonsai.ML.Pca.Torch;
+
+/// <summary>
+/// Represents an operator that transforms the input data using a PCA model.
+/// </summary>
+[ResetCombinator]
+[Combinator]
+[Description("Transforms the input data using a PCA model.")]
+[WorkflowElementCategory(ElementCategory.Transform)]
+public class TransformBuilder() : PcaModelBuilder(new Transform()) { }
diff --git a/src/Bonsai.ML.Pca.Torch/Utils.cs b/src/Bonsai.ML.Pca.Torch/Utils.cs
new file mode 100644
index 00000000..b89b73e5
--- /dev/null
+++ b/src/Bonsai.ML.Pca.Torch/Utils.cs
@@ -0,0 +1,29 @@
+using System;
+using static TorchSharp.torch;
+
+namespace Bonsai.ML.Pca.Torch;
+
+internal static class Utils
+{
+    /// <summary>
+    /// Solves <c>spdMatrix · X = rhs</c> through a Cholesky factorization; passing an identity
+    /// matrix as <paramref name="rhs"/> yields the inverse of <paramref name="spdMatrix"/>.
+    /// </summary>
+    /// <param name="spdMatrix">The matrix to factorize.</param>
+    /// <param name="rhs">The right-hand side of the system.</param>
+    /// <param name="regularization">The value added to the diagonal when the first factorization fails.</param>
+    /// <param name="device">The device of the regularizer; the device of <paramref name="spdMatrix"/> is used when null.</param>
+    /// <param name="scalarType">The scalar type of the regularizer; the type of <paramref name="spdMatrix"/> is used when null.</param>
+    /// <returns>The solution of the linear system.</returns>
+    internal static Tensor InvertSPD(
+        Tensor spdMatrix,
+        Tensor rhs,
+        double regularization = 1e-6,
+        Device? device = null,
+        ScalarType? scalarType = null
+    )
+    {
+        Tensor L;
+        try
+        {
+            L = linalg.cholesky(spdMatrix);
+        }
+        catch (Exception)
+        {
+            // Deliberate best effort: retry once with a small ridge on the diagonal. A failure of
+            // this second attempt propagates to the caller.
+            var diagShape = spdMatrix.size(-1);
+
+            // Falling back to the matrix's own device and scalar type keeps the sum below valid
+            // when the caller relies on the defaults; eye() would otherwise use library defaults.
+            using var identity = eye(diagShape, device: device ?? spdMatrix.device, dtype: scalarType ?? spdMatrix.dtype);
+            using var regularizer = identity * regularization;
+            using var regularized = spdMatrix + regularizer;
+            L = linalg.cholesky(regularized);
+        }
+
+        try
+        {
+            return cholesky_solve(rhs, L);
+        }
+        finally
+        {
+            // The factor is not needed once the system has been solved.
+            L.Dispose();
+        }
+    }
+}
diff --git a/tests/Bonsai.ML.Pca.Torch.Tests/Bonsai.ML.Pca.Torch.Tests.csproj b/tests/Bonsai.ML.Pca.Torch.Tests/Bonsai.ML.Pca.Torch.Tests.csproj
new file mode 100644
index 00000000..e480f9b3
--- /dev/null
+++ b/tests/Bonsai.ML.Pca.Torch.Tests/Bonsai.ML.Pca.Torch.Tests.csproj
@@ -0,0 +1,22 @@
+
+
+    net8.0
+    enable
+    enable
+    false
+    true
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/tests/Bonsai.ML.Pca.Torch.Tests/StandardPcaTests.cs b/tests/Bonsai.ML.Pca.Torch.Tests/StandardPcaTests.cs
new file mode 100644
index 00000000..3a323bc1
--- /dev/null
+++ b/tests/Bonsai.ML.Pca.Torch.Tests/StandardPcaTests.cs
@@ -0,0 +1,130 @@
+using System.Diagnostics;
+using Bonsai.ML.Pca.Torch;
+using TorchSharp;
+using static TorchSharp.torch;
+
+namespace Bonsai.ML.Pca.Torch.Tests;
+
+///
+/// Tests for the Bonsai.ML.Pca.Torch package.
+/// +[TestClass] +public class StandardPcaTests +{ + public StandardPcaTests() + { + manual_seed(0); + set_printoptions(style: TorchSharp.TensorStringStyle.Numpy); + } + + private static readonly Tensor _expectedFirstComponent = tensor(new float[] { 1f, 0f }); + + private static Tensor Generate2dRotationMatrix(double angleDegrees) + { + var angleRad = angleDegrees * Math.PI / 180.0; + var cosA = Math.Cos(angleRad); + var sinA = Math.Sin(angleRad); + return tensor(new float[,] + { + { (float)cosA, (float)-sinA }, + { (float)sinA, (float)cosA } + }); + } + + private static Tensor Generate2dDataset(int numSamples, double scaleX = 3.0, double scaleY = 0.1, double offsetX = 0.0, double offsetY = 0.0, double rotationAngle = 0.0) + { + var x = randn(numSamples) * scaleX + offsetX; + var y = randn(numSamples) * scaleY + offsetY; + var data = stack([x, y], 1); + if (rotationAngle == 0) + return data; + var rotationMatrix = Generate2dRotationMatrix(rotationAngle); + return data.matmul(rotationMatrix); + } + + private static float Similarity(Tensor a, Tensor b) + { + if (a.dim() != 1 || b.dim() != 1) + throw new ArgumentException("Input tensors must be 1-dimensional. Instead got dimension: " + a.dim()); + if (a.size(0) != b.size(0)) + throw new ArgumentException($"Input tensors must have the same shape. Instead got {string.Join(", ", a.shape)} and {string.Join(", ", b.shape)}."); + var aNorm = a.norm(-1, keepdim: true); + var bNorm = b.norm(-1, keepdim: true); + var dotProduct = a.dot(b); + return abs(dotProduct / (aNorm * bNorm)).item(); + } + + private static void TestBasic(PcaBaseModel model) + { + // Generate a simple dataset. + var data = Generate2dDataset(1000); + Debug.WriteLine($"Data shape: {string.Join(", ", data.shape)}"); + + // Fit the model. + model.Fit(data); + + // Verify components. 
+ Debug.WriteLine($"Components: {model.Components.str()}"); + Assert.IsTrue(model.Components.shape[0] == 2 && model.Components.shape[1] == 2); + + // Verify similarity with expected first component. + var similarity = Similarity(model.Components[0], _expectedFirstComponent); + Debug.WriteLine($"Similarity with expected first component: {similarity}"); + Assert.IsTrue(similarity > 0.99); + + // Compare reconstructed data. + var transformed = model.Transform(data); + var reconstructed = model.Reconstruct(transformed); + + var reconstructionError = mean((data - reconstructed).pow(2)).item(); + Debug.WriteLine($"Reconstruction error: {reconstructionError}"); + Assert.IsTrue(reconstructionError < 1e-10); + } + + private static void TestRotation(PcaBaseModel model) + { + // Compare rotated dataset + var data = Generate2dDataset(1000, rotationAngle: 30.0); + model.Fit(data); + + var rotationMatrix = Generate2dRotationMatrix(30.0); + var rotatedExpectedFirstComponent = _expectedFirstComponent.matmul(rotationMatrix); + Debug.WriteLine($"Rotated expected first component: {rotatedExpectedFirstComponent.str()}"); + var rotatedSimilarity = Similarity(model.Components[0], rotatedExpectedFirstComponent); + Debug.WriteLine($"Similarity with expected first component (rotated data): {rotatedSimilarity}"); + Assert.IsTrue(rotatedSimilarity > 0.99); + } + + private static void TestOffset(PcaBaseModel model) + { + // Test offset centering + var dataOffset = Generate2dDataset(1000, offsetX: 5.0, offsetY: -3.0); + model.Fit(dataOffset); + + var offsetExpectedFirstComponent = _expectedFirstComponent; + var offsetSimilarity = Similarity(model.Components[0], offsetExpectedFirstComponent); + Debug.WriteLine($"Similarity with expected first component (offset data): {offsetSimilarity}"); + Assert.IsTrue(offsetSimilarity > 0.99); + + var transformed = model.Transform(dataOffset); + var reconstructed = model.Reconstruct(transformed); + + var reconstructedMeans = reconstructed.mean([0]); + 
Debug.WriteLine($"Reconstructed means (offset data): {reconstructedMeans.str()}"); + Assert.IsTrue(abs(reconstructedMeans[0] - 5.0).item() < 0.1); + Assert.IsTrue(abs(reconstructedMeans[1] + 3.0).item() < 0.1); + } + + [TestMethod] + public void TestStandardPca() + { + var pca = new Pca(numComponents: 2); + TestBasic(pca); + + pca = new Pca(numComponents: 2); + TestRotation(pca); + + pca = new Pca(numComponents: 2); + TestOffset(pca); + } +} diff --git a/tests/Bonsai.ML.Pca.Torch.Tests/TransformPCATest.bonsai b/tests/Bonsai.ML.Pca.Torch.Tests/TransformPCATest.bonsai new file mode 100644 index 00000000..e6fa55b9 --- /dev/null +++ b/tests/Bonsai.ML.Pca.Torch.Tests/TransformPCATest.bonsai @@ -0,0 +1,889 @@ + + + + + + LoadTrainingData + + + + + + + + Z:/home/nicholas/Downloads/xfads-2/workflows/data/y_train.bin + 0 + 0 + 50 + 1 + F64 + RowMajor + + + + + + 1000 + + + + + + + + + + + + + + Float32 + + + + + + + + + + + + + + + + + + + + 1 + + + + it.squeeze(null).T + + + + + 0 + + true + + + + + + + + + it.T + + + + + + + + + + + + + + + + + YTrain + + + LoadValidationData + + + + + + + + Z:/home/nicholas/Downloads/xfads-2/workflows/data/y_valid.bin + 0 + 0 + 50 + 1 + F64 + RowMajor + + + + + + 1000 + + + + + + + + + + + + + + Float32 + + + + + + + + + + + + + + + + + + + + 1 + + + + it.squeeze(null).T + + + + + 0 + + true + + + + + + + + + it.T + + + + + + + + + + + + + + + + + YValid + + + + 0 + 50 + 1 + + + + + Float32 + + + YValid + + + YTrain + + + + 2 + + PCA + 1 + 100 + 1E-05 + 0.1 + 0.9 + + + + + + + + + + + + Item2 + + + + + + + + + + + 333 + 50 + 2 + + + + + + 0,:,0 + + + + Float32 + + + + + + DataStruct + + + + Source1 + + + Item1 + + + + + + Item2 + + + + + + + + + new (Item1 as X, Item2 as Y) + + + + + + + + + + + + + + + + + Circle + 1 + + + + + + + + + + 0 + 50 + 1 + + + + + Float32 + + + YValid + + + YTrain + + + + 2 + + ProbabilisticPCA + 1 + 100 + 1E-05 + 0.1 + 0.9 + + + + + + + + + + + + Item2 + + + + + + + + + + + 333 + 50 + 2 + + + + + + 0,:,0 + + 
+ + Float32 + + + + + + DataStruct + + + + Source1 + + + Item1 + + + + + + Item2 + + + + + + + + + new (Item1 as X, Item2 as Y) + + + + + + + + + + + + + + + + + Circle + 1 + + + + + + + + + YValid + + + + 1 + + + + it.shape[1] + + + + + + + + + 0 + 16650 + 1 + + + + + Int32 + + + + + + Source1 + + + + + + + + + YValid + + + + 1 + + + + + + + Source1 + + + + + + + + + + + + + + + + it.unsqueeze(1) + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + YValidStream + + + + YTrain + + + + 1 + + + + it.shape[1] + + + + + + + + + 0 + 33350 + 1 + + + + + Int32 + + + + + + Source1 + + + + + + + + + YTrain + + + + 1 + + + + + + + Source1 + + + + + + + + + + + + + + + + it.unsqueeze(1) + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 2 + + OnlinePPCA + 0.1 + 100 + 1E-05 + + 0.7 + 3000 + 500 + + + + + + + + + + + 1 + + + + Item2 + + + FittedOnlinePPCA + + + Components + + + + 0 + + + + Float32 + + + + + + + + + + 0 + 50 + 1 + + + + + Float32 + + + FittedOnlinePPCA + + + + + + YValidStream + + + FittedOnlinePPCA + + + + + + + + + + 50 + + + + + + + + 0 + + + + + + + + + + + + + + + + + + + 1 + 50 + 2 + + + + + + 0,:,0 + + + + Float32 + + + + + + DataStruct + + + + Source1 + + + Item1 + + + + + + Item2 + + + + + + + + + new (Item1 as X, Item2 as Y) + + + + + + + + + + + + + + + + + Circle + 1 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file