diff --git a/Bonsai.ML.sln b/Bonsai.ML.sln
index fc4e963f..56e66f13 100644
--- a/Bonsai.ML.sln
+++ b/Bonsai.ML.sln
@@ -46,6 +46,10 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Bonsai.ML.Lds.Torch.Tests",
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Bonsai.ML.Lds.Torch.Design", "src\Bonsai.ML.Lds.Torch.Design\Bonsai.ML.Lds.Torch.Design.csproj", "{1F52DECD-1B2C-4F6C-996C-14C715283B80}"
EndProject
+Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Bonsai.ML.Pca.Torch", "src\Bonsai.ML.Pca.Torch\Bonsai.ML.Pca.Torch.csproj", "{1A2DEED3-795E-4C28-9C5E-BA3D76B2A485}"
+EndProject
+Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Bonsai.ML.Pca.Torch.Tests", "tests\Bonsai.ML.Pca.Torch.Tests\Bonsai.ML.Pca.Torch.Tests.csproj", "{4ABCC6B2-024A-450F-85CB-2A9B2D2D2A10}"
+EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
@@ -120,6 +124,14 @@ Global
{1F52DECD-1B2C-4F6C-996C-14C715283B80}.Debug|Any CPU.Build.0 = Debug|Any CPU
{1F52DECD-1B2C-4F6C-996C-14C715283B80}.Release|Any CPU.ActiveCfg = Release|Any CPU
{1F52DECD-1B2C-4F6C-996C-14C715283B80}.Release|Any CPU.Build.0 = Release|Any CPU
+ {1A2DEED3-795E-4C28-9C5E-BA3D76B2A485}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {1A2DEED3-795E-4C28-9C5E-BA3D76B2A485}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {1A2DEED3-795E-4C28-9C5E-BA3D76B2A485}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {1A2DEED3-795E-4C28-9C5E-BA3D76B2A485}.Release|Any CPU.Build.0 = Release|Any CPU
+ {4ABCC6B2-024A-450F-85CB-2A9B2D2D2A10}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {4ABCC6B2-024A-450F-85CB-2A9B2D2D2A10}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {4ABCC6B2-024A-450F-85CB-2A9B2D2D2A10}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {4ABCC6B2-024A-450F-85CB-2A9B2D2D2A10}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
diff --git a/src/Bonsai.ML.Pca.Torch/Bonsai.ML.Pca.Torch.csproj b/src/Bonsai.ML.Pca.Torch/Bonsai.ML.Pca.Torch.csproj
new file mode 100644
index 00000000..b97c06fd
--- /dev/null
+++ b/src/Bonsai.ML.Pca.Torch/Bonsai.ML.Pca.Torch.csproj
@@ -0,0 +1,15 @@
+
+
+
+ Bonsai.ML.Pca.Torch Bonsai library.
+ $(PackageTags) PCA Principal Component Analysis
+ net472;netstandard2.0
+ enable
+
+
+
+
+
+
+
+
diff --git a/src/Bonsai.ML.Pca.Torch/CreatePca.cs b/src/Bonsai.ML.Pca.Torch/CreatePca.cs
new file mode 100644
index 00000000..bf7f6567
--- /dev/null
+++ b/src/Bonsai.ML.Pca.Torch/CreatePca.cs
@@ -0,0 +1,198 @@
+using System;
+using System.ComponentModel;
+using System.Collections.Generic;
+using System.Reactive.Linq;
+using System.Linq.Expressions;
+using Bonsai.Expressions;
+using System.Reflection;
+using static TorchSharp.torch;
+using System.Xml.Serialization;
+
+namespace Bonsai.ML.Pca.Torch;
+
+///
+/// Creates a PCA model.
+///
+[Combinator]
+[ResetCombinator]
+[WorkflowElementCategory(ElementCategory.Source)]
+[TypeDescriptionProvider(typeof(PcaDescriptionProvider))]
+[Description("Creates a PCA model.")]
+public class CreatePca : ZeroArgumentExpressionBuilder, INamedElement
+{
+ ///
+ public string Name => $"CreatePca.{ModelType}";
+
+ ///
+ /// The number of principal components to compute.
+ ///
+ public int NumComponents { get; set; } = 2;
+
+ ///
+ /// The device on which to create the PCA model.
+ ///
+ [XmlIgnore]
+ [Description("The device on which to create the PCA model.")]
+ public Device? Device { get; set; }
+
+ ///
+ /// The scalar type of the PCA model.
+ ///
+ [Description("The scalar type of the PCA model.")]
+ public ScalarType? ScalarType { get; set; }
+
+ ///
+ /// The type of PCA model to create.
+ ///
+ [RefreshProperties(RefreshProperties.All)]
+ [Description("The type of PCA model to create.")]
+ public PcaModelType ModelType { get; set; } = PcaModelType.Pca;
+
+ ///
+ /// The initial variance for probabilistic PCA models.
+ ///
+ [Description("The initial variance for probabilistic PCA models.")]
+ public double InitialVariance { get; set; } = 1.0;
+
+ ///
+ /// The number of iterations for fitting probabilistic PCA models.
+ ///
+ [Description("The number of iterations for fitting probabilistic PCA models.")]
+ public int Iterations { get; set; } = 100;
+
+ ///
+ /// The tolerance for convergence in probabilistic PCA models.
+ ///
+ [Description("The tolerance for convergence in probabilistic PCA models.")]
+ public double Tolerance { get; set; } = 1e-5;
+
+ ///
+ /// The constant learning rate parameter for the online probabilistic PCA model.
+ ///
+ [Description("The constant learning rate parameter for the online probabilistic PCA model. Only one of Rho or Kappa must be specified.")]
+ public double? Rho { get; set; } = 0.1;
+
+ ///
+ /// The forgetting factor for the online probabilistic PCA model.
+ ///
+ [Description("The forgetting factor for the online probabilistic PCA model. Only one of Rho or Kappa must be specified.")]
+ public double? Kappa { get; set; } = 0.9;
+
+ ///
+ /// The sample offset for the online probabilistic PCA model.
+ ///
+ [Description("The sample offset for the online probabilistic PCA model. If null, decaying learning rate starts from the first sample.")]
+ public int? SampleOffset { get; set; } = null;
+
+ ///
+ /// The period for reorthogonalizing the components in the online probabilistic PCA model.
+ ///
+ [Description("The period for reorthogonalizing the components in the online probabilistic PCA model. If null, reorthogonalization is not performed.")]
+ public int? ReorthogonalizePeriod { get; set; } = null;
+
+ ///
+ /// The random number generator used for initializing probabilistic PCA models.
+ ///
+ [XmlIgnore]
+ public Generator? Generator { get; set; } = null;
+
+ ///
+ /// The learning rate for the Online PCA GHA model.
+ ///
+ public double LearningRate { get; set; } = 0.1;
+
+ internal IEnumerable GetModelProperties()
+ {
+ yield return nameof(NumComponents);
+ yield return nameof(Device);
+ yield return nameof(ScalarType);
+ yield return nameof(ModelType);
+
+ if (ModelType == PcaModelType.ProbabilisticPca)
+ {
+ yield return nameof(InitialVariance);
+ yield return nameof(Iterations);
+ yield return nameof(Tolerance);
+ yield return nameof(Generator);
+ }
+
+ if (ModelType == PcaModelType.OnlineProbabilisticPca)
+ {
+ yield return nameof(InitialVariance);
+ yield return nameof(Rho);
+ yield return nameof(Kappa);
+ yield return nameof(SampleOffset);
+ yield return nameof(ReorthogonalizePeriod);
+ yield return nameof(Generator);
+ }
+
+ if (ModelType == PcaModelType.OnlinePcaGha)
+ {
+ yield return nameof(LearningRate);
+ yield return nameof(Generator);
+ }
+ }
+
+ private static PcaBaseModel CreateModel(CreatePca pcaBuilder)
+ {
+ return pcaBuilder.ModelType switch
+ {
+ PcaModelType.Pca => new Pca(
+ numComponents: pcaBuilder.NumComponents,
+ device: pcaBuilder.Device,
+ scalarType: pcaBuilder.ScalarType),
+ PcaModelType.ProbabilisticPca => new ProbabilisticPca(
+ numComponents: pcaBuilder.NumComponents,
+ device: pcaBuilder.Device,
+ scalarType: pcaBuilder.ScalarType,
+ initialVariance: pcaBuilder.InitialVariance,
+ generator: pcaBuilder.Generator,
+ iterations: pcaBuilder.Iterations,
+ tolerance: pcaBuilder.Tolerance),
+ PcaModelType.OnlineProbabilisticPca => new OnlineProbabilisticPca(
+ numComponents: pcaBuilder.NumComponents,
+ device: pcaBuilder.Device,
+ scalarType: pcaBuilder.ScalarType,
+ initialVariance: pcaBuilder.InitialVariance,
+ generator: pcaBuilder.Generator,
+ rho: pcaBuilder.Rho,
+ kappa: pcaBuilder.Kappa,
+ sampleOffset: pcaBuilder.SampleOffset,
+ reorthogonalizePeriod: pcaBuilder.ReorthogonalizePeriod),
+ PcaModelType.OnlinePcaGha => new OnlinePcaGha(
+ numComponents: pcaBuilder.NumComponents,
+ learningRate: pcaBuilder.LearningRate,
+ device: pcaBuilder.Device,
+ scalarType: pcaBuilder.ScalarType,
+ generator: pcaBuilder.Generator),
+ _ => throw new NotSupportedException($"Model type {pcaBuilder.ModelType} is not supported."),
+ };
+ }
+
+ private static Type GetModelType(PcaModelType modelType)
+ {
+ return modelType switch
+ {
+ PcaModelType.Pca => typeof(Pca),
+ PcaModelType.ProbabilisticPca => typeof(ProbabilisticPca),
+ PcaModelType.OnlineProbabilisticPca => typeof(OnlineProbabilisticPca),
+ PcaModelType.OnlinePcaGha => typeof(OnlinePcaGha),
+ _ => throw new NotSupportedException($"Model type {modelType} is not supported."),
+ };
+ }
+
+ ///
+ public override Expression Build(IEnumerable arguments)
+ {
+ var processMethod = GetType().GetMethod(
+ nameof(Process),
+ BindingFlags.NonPublic | BindingFlags.Static);
+ processMethod = processMethod.MakeGenericMethod(GetModelType(ModelType));
+ return Expression.Call(processMethod, Expression.Constant(this));
+ }
+
+ private static IObservable Process(CreatePca instance) where T : PcaBaseModel
+ {
+ return Observable.Return((T)CreateModel(instance));
+ }
+}
diff --git a/src/Bonsai.ML.Pca.Torch/Fit.cs b/src/Bonsai.ML.Pca.Torch/Fit.cs
new file mode 100644
index 00000000..eb762f93
--- /dev/null
+++ b/src/Bonsai.ML.Pca.Torch/Fit.cs
@@ -0,0 +1,143 @@
+using System;
+using System.ComponentModel;
+using System.Reactive.Linq;
+using System.Xml.Serialization;
+using static TorchSharp.torch;
+
+namespace Bonsai.ML.Pca.Torch;
+
+///
+/// Fits the PCA model to the input data.
+///
+public class Fit : IPcaModelProvider
+{
+ ///
+ public IPcaBaseModel? Model { get; set; } = null;
+
+ private void FitModel(IPcaBaseModel model, Tensor data)
+ {
+ model.Fit(data);
+ }
+
+ ///
+ /// Fits the PCA model to the input data.
+ ///
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable source)
+ {
+ if (Model is null)
+ {
+ throw new InvalidOperationException("The PCA model has not been specified.");
+ }
+ return source.Do(value =>
+ {
+ FitModel(Model, value);
+ });
+ }
+
+ ///
+ /// Fits a standard PCA model to the input data.
+ ///
+ ///
+ ///
+ public IObservable> Process(IObservable> source)
+ {
+ return source.Do((value) =>
+ {
+ FitModel(value.Item1, value.Item2);
+ });
+ }
+
+ ///
+ /// Fits a standard PCA model to the input data.
+ ///
+ ///
+ ///
+ public IObservable> Process(IObservable> source)
+ {
+ return source.Do((value) =>
+ {
+ FitModel(value.Item2, value.Item1);
+ });
+ }
+
+ ///
+ /// Fits a probabilistic PCA model to the input data.
+ ///
+ ///
+ ///
+ public IObservable> Process(IObservable> source)
+ {
+ return source.Do((value) =>
+ {
+ FitModel(value.Item1, value.Item2);
+ });
+ }
+
+ ///
+ /// Fits a probabilistic PCA model to the input data.
+ ///
+ ///
+ ///
+ public IObservable> Process(IObservable> source)
+ {
+ return source.Do((value) =>
+ {
+ FitModel(value.Item2, value.Item1);
+ });
+ }
+
+ ///
+ /// Fits an online probabilistic PCA model to the input data.
+ ///
+ ///
+ ///
+ public IObservable> Process(IObservable> source)
+ {
+ return source.Do((value) =>
+ {
+ FitModel(value.Item1, value.Item2);
+ });
+ }
+
+ ///
+ /// Fits an online probabilistic PCA model to the input data.
+ ///
+ ///
+ ///
+ public IObservable> Process(IObservable> source)
+ {
+ return source.Do((value) =>
+ {
+ FitModel(value.Item2, value.Item1);
+ });
+ }
+
+ ///
+ /// Fits an online PCA model using the Generalized Hebbian Algorithm to the input data.
+ ///
+ ///
+ ///
+ public IObservable> Process(IObservable> source)
+ {
+ return source.Do((value) =>
+ {
+ FitModel(value.Item1, value.Item2);
+ });
+ }
+
+ ///
+ /// Fits an online PCA model using the Generalized Hebbian Algorithm to the input data.
+ ///
+ ///
+ ///
+ public IObservable> Process(IObservable> source)
+ {
+ return source.Do((value) =>
+ {
+ FitModel(value.Item2, value.Item1);
+ });
+ }
+}
diff --git a/src/Bonsai.ML.Pca.Torch/FitAndTransform.cs b/src/Bonsai.ML.Pca.Torch/FitAndTransform.cs
new file mode 100644
index 00000000..b7634711
--- /dev/null
+++ b/src/Bonsai.ML.Pca.Torch/FitAndTransform.cs
@@ -0,0 +1,143 @@
+using System;
+using System.ComponentModel;
+using System.Reactive.Linq;
+using System.Xml.Serialization;
+using static TorchSharp.torch;
+
+namespace Bonsai.ML.Pca.Torch;
+
+///
+/// Fits the PCA model to the input data and transforms it.
+///
+public class FitAndTransform : IPcaModelProvider
+{
+ ///
+ public IPcaBaseModel? Model { get; set; }
+
+ private static void FitModelAndTransformData(IPcaBaseModel model, Tensor data)
+ {
+ model.FitAndTransform(data);
+ }
+
+
+ ///
+ /// Fits the PCA model to the input data and transforms it.
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable source)
+ {
+ if (Model == null)
+ {
+ throw new InvalidOperationException("The PCA model has not been specified.");
+ }
+ return source.Do(value =>
+ {
+ FitModelAndTransformData(Model, value);
+ });
+ }
+
+ ///
+ /// Fits a standard PCA model to the input data and transforms it.
+ ///
+ ///
+ ///
+ public IObservable> Process(IObservable> source)
+ {
+ return source.Do((value) =>
+ {
+ FitModelAndTransformData(value.Item1, value.Item2);
+ });
+ }
+
+ ///
+ /// Fits a standard PCA model to the input data and transforms it.
+ ///
+ ///
+ ///
+ public IObservable> Process(IObservable> source)
+ {
+ return source.Do((value) =>
+ {
+ FitModelAndTransformData(value.Item2, value.Item1);
+ });
+ }
+
+ ///
+ /// Fits a probabilistic PCA model to the input data and transforms it.
+ ///
+ ///
+ ///
+ public IObservable> Process(IObservable> source)
+ {
+ return source.Do((value) =>
+ {
+ FitModelAndTransformData(value.Item1, value.Item2);
+ });
+ }
+
+ ///
+ /// Fits a probabilistic PCA model to the input data and transforms it.
+ ///
+ ///
+ ///
+ public IObservable> Process(IObservable> source)
+ {
+ return source.Do((value) =>
+ {
+ FitModelAndTransformData(value.Item2, value.Item1);
+ });
+ }
+
+ ///
+ /// Fits an online probabilistic PCA model to the input data and transforms it.
+ ///
+ ///
+ ///
+ public IObservable> Process(IObservable> source)
+ {
+ return source.Do((value) =>
+ {
+ FitModelAndTransformData(value.Item1, value.Item2);
+ });
+ }
+
+ ///
+ /// Fits an online probabilistic PCA model to the input data and transforms it.
+ ///
+ ///
+ ///
+ public IObservable> Process(IObservable> source)
+ {
+ return source.Do((value) =>
+ {
+ FitModelAndTransformData(value.Item2, value.Item1);
+ });
+ }
+
+ ///
+ /// Fits an online PCA model using the Generalized Hebbian Algorithm to the input data and transforms it.
+ ///
+ ///
+ ///
+ public IObservable> Process(IObservable> source)
+ {
+ return source.Do((value) =>
+ {
+ FitModelAndTransformData(value.Item1, value.Item2);
+ });
+ }
+
+ ///
+ /// Fits an online PCA model using the Generalized Hebbian Algorithm to the input data and transforms it.
+ ///
+ ///
+ ///
+ public IObservable> Process(IObservable> source)
+ {
+ return source.Do((value) =>
+ {
+ FitModelAndTransformData(value.Item2, value.Item1);
+ });
+ }
+}
diff --git a/src/Bonsai.ML.Pca.Torch/FitAndTransformBuilder.cs b/src/Bonsai.ML.Pca.Torch/FitAndTransformBuilder.cs
new file mode 100644
index 00000000..2cb35f18
--- /dev/null
+++ b/src/Bonsai.ML.Pca.Torch/FitAndTransformBuilder.cs
@@ -0,0 +1,12 @@
+using System.ComponentModel;
+
+namespace Bonsai.ML.Pca.Torch;
+
+///
+/// Represents an operator that fits a PCA model and transforms the input data.
+///
+[ResetCombinator]
+[Combinator]
+[Description("Fits a PCA model and transforms the input data.")]
+[WorkflowElementCategory(ElementCategory.Transform)]
+public class FitAndTransformBuilder() : PcaModelBuilder(new FitAndTransform()) { }
diff --git a/src/Bonsai.ML.Pca.Torch/FitBuilder.cs b/src/Bonsai.ML.Pca.Torch/FitBuilder.cs
new file mode 100644
index 00000000..5955fb2a
--- /dev/null
+++ b/src/Bonsai.ML.Pca.Torch/FitBuilder.cs
@@ -0,0 +1,12 @@
+using System.ComponentModel;
+
+namespace Bonsai.ML.Pca.Torch;
+
+///
+/// Represents an operator that fits a PCA model to the input data.
+///
+[ResetCombinator]
+[Combinator]
+[Description("Fits a PCA model to the input data.")]
+[WorkflowElementCategory(ElementCategory.Sink)]
+public class FitBuilder() : PcaModelBuilder(new Fit()) { }
diff --git a/src/Bonsai.ML.Pca.Torch/IPcaBaseModel.cs b/src/Bonsai.ML.Pca.Torch/IPcaBaseModel.cs
new file mode 100644
index 00000000..7759f109
--- /dev/null
+++ b/src/Bonsai.ML.Pca.Torch/IPcaBaseModel.cs
@@ -0,0 +1,78 @@
+using static TorchSharp.torch;
+
+namespace Bonsai.ML.Pca.Torch;
+
+///
+/// Defines the interface for PCA models.
+///
+public interface IPcaBaseModel
+{
+ ///
+ /// Gets a value indicating whether the model has been fitted to data.
+ ///
+ public bool IsFitted { get; }
+
+ ///
+ /// Gets the number of features in the fitted data.
+ ///
+ public int NumFeatures { get; }
+
+ ///
+ /// Gets the principal components of the model.
+ ///
+ public Tensor Components { get; }
+
+ ///
+ /// Gets the number of principal components kept by the model.
+ ///
+ public int NumComponents { get; }
+
+ ///
+ /// Gets the device on which the model operates.
+ ///
+ public Device Device { get; }
+
+ ///
+ /// Gets the data type used by the model.
+ ///
+ public ScalarType? ScalarType { get; }
+
+ ///
+ /// Fits the PCA model to the given data.
+ ///
+ ///
+ /// The input data should be a 2D tensor with shape (samples x features).
+ ///
+ ///
+ public void Fit(Tensor data);
+
+ ///
+ /// Transforms the input data using the PCA model.
+ ///
+ ///
+ /// The input data should be a 2D tensor with shape (samples x features).
+ ///
+ ///
+ ///
+ public Tensor Transform(Tensor data);
+
+ ///
+ /// Fits the PCA model to the given data and transforms it.
+ ///
+ ///
+ /// The input data should be a 2D tensor with shape (samples x features).
+ ///
+ ///
+ ///
+ public Tensor FitAndTransform(Tensor data);
+
+ ///
+ /// Reconstructs the input data using the PCA model.
+ ///
+ ///
+ /// The input data should be a 2D tensor with shape (samples x features).
+ ///
+ ///
+ ///
+ public Tensor Reconstruct(Tensor data);
+}
diff --git a/src/Bonsai.ML.Pca.Torch/IPcaModelProvider.cs b/src/Bonsai.ML.Pca.Torch/IPcaModelProvider.cs
new file mode 100644
index 00000000..512e15da
--- /dev/null
+++ b/src/Bonsai.ML.Pca.Torch/IPcaModelProvider.cs
@@ -0,0 +1,12 @@
+namespace Bonsai.ML.Pca.Torch;
+
+///
+/// Defines an interface for PCA model providers.
+///
+public interface IPcaModelProvider
+{
+ ///
+ /// Gets or sets the PCA model.
+ ///
+ public IPcaBaseModel? Model { get; set; }
+}
diff --git a/src/Bonsai.ML.Pca.Torch/OnlinePcaGha.cs b/src/Bonsai.ML.Pca.Torch/OnlinePcaGha.cs
new file mode 100644
index 00000000..2240701a
--- /dev/null
+++ b/src/Bonsai.ML.Pca.Torch/OnlinePcaGha.cs
@@ -0,0 +1,99 @@
+using static TorchSharp.torch;
+
+namespace Bonsai.ML.Pca.Torch;
+
+///
+/// Implements streaming/online PCA using the Generalized Hebbian Algorithm (GHA).
+///
+///
+///
+///
+///
+///
+public class OnlinePcaGha(
+ int numComponents,
+ double learningRate = 0.1,
+ Device? device = null,
+ ScalarType? scalarType = null,
+ Generator? generator = null
+) : PcaBaseModel(numComponents, device, scalarType)
+{
+ ///
+ /// Gets the number of samples that have been used to fit the model.
+ ///
+ public int SampleCount { get; private set; } = 0;
+
+ ///
+ /// Gets the mean of the fitted data.
+ ///
+ public Tensor Mean { get; private set; } = empty(0);
+
+ ///
+ /// Gets or sets the learning rate.
+ ///
+ public double LearningRate { get; set; } = learningRate;
+
+ ///
+ public override Tensor Components { get; protected set; } = empty(0);
+
+ ///
+ /// Gets the random number generator used for initializing the model.
+ ///
+ public Generator? Generator { get; private set; } = generator;
+
+ ///
+ public override void Fit(Tensor data)
+ {
+ base.Fit(data);
+
+ var numSamples = data.size(0);
+
+ using (no_grad())
+ using (NewDisposeScope())
+ {
+ // Initialize components randomly
+ if (Components.numel() == 0)
+ Components = randn([NumFeatures, NumComponents], dtype: ScalarType, device: Device, generator: Generator);
+
+ if (Mean.numel() == 0)
+ Mean = data.mean([0], keepdim: true);
+ else
+ {
+ Mean *= SampleCount / (SampleCount + numSamples);
+ Mean += data.mean([0], keepdim: true) * numSamples / (SampleCount + numSamples);
+ }
+
+ SampleCount += (int)numSamples;
+ var dataCentered = data - Mean;
+
+ var projection = dataCentered.matmul(Components);
+ var hebbianTerm = dataCentered.T.matmul(projection);
+ var crossTerm = projection.T.matmul(projection);
+ var lowerTriangular = crossTerm.tril(0);
+ var correlation = Components.matmul(lowerTriangular);
+ var componentsUpdate = (hebbianTerm - correlation) * (LearningRate / numSamples);
+ var weights = Components + componentsUpdate;
+ var norms = weights.norm(dim: 0, keepdim: true, p: 2).clamp_min(1e-12);
+
+ Components = linalg.qr(weights / norms, mode: linalg.QRMode.Reduced).Q.MoveToOuterDisposeScope();
+ Mean = Mean.MoveToOuterDisposeScope();
+ }
+
+ IsFitted = true;
+ }
+
+ ///
+ public override Tensor Transform(Tensor data)
+ {
+ base.Transform(data);
+ var dataCentered = data - Mean;
+ return dataCentered.matmul(Components);
+ }
+
+ ///
+ public override Tensor Reconstruct(Tensor data)
+ {
+ base.Reconstruct(data);
+ return data.matmul(Components.T) + Mean;
+ }
+}
diff --git a/src/Bonsai.ML.Pca.Torch/OnlineProbabilisticPca.cs b/src/Bonsai.ML.Pca.Torch/OnlineProbabilisticPca.cs
new file mode 100644
index 00000000..b5e251bb
--- /dev/null
+++ b/src/Bonsai.ML.Pca.Torch/OnlineProbabilisticPca.cs
@@ -0,0 +1,271 @@
+using System;
+using static TorchSharp.torch;
+using static TorchSharp.torch.linalg;
+
+namespace Bonsai.ML.Pca.Torch;
+
+///
+/// Implements an online probabilistic PCA model using the stochastic online EM algorithm.
+///
+public class OnlineProbabilisticPca : PcaBaseModel
+{
+ private Tensor _identityComponents = empty(0);
+ private Tensor _mx = empty(0); // E[x]
+ private Tensor _Cxz = empty(0); // E[xz^T]
+ private Tensor _mz = empty(0); // E[z]
+ private Tensor _Czz = empty(0); // E[zz^T]
+ private Tensor _sxx = empty(0); // E[||x||^2]
+ private readonly Func UpdateSchedule;
+ private int _stepCount = 0;
+ private readonly bool _reorthogonalize = false;
+
+ ///
+ /// Rho is a constant learning rate parameter.
+ ///
+ ///
+ /// Rho must be in the range (0, 1). Only one of Rho or Kappa should be specified.
+ ///
+ public double? Rho { get; private set; }
+
+ ///
+ /// Kappa is the exponent in the learning rate schedule.
+ ///
+ ///
+ /// Kappa must be in the range (0.5, 1]. Only one
+ /// of Rho or Kappa should be specified.
+ ///
+ public double? Kappa { get; private set; }
+
+ ///
+ /// Gets the mean of the fitted data.
+ ///
+ public Tensor Means { get; private set; } = empty(0);
+
+ ///
+ /// Gets the variance of the isotropic Gaussian noise model.
+ ///
+ public double Variance { get; private set; }
+
+ ///
+ /// Gets the period for reorthogonalizing the principal components.
+ ///
+ ///
+ /// Represented as the number of update steps between reorthogonalization operations.
+ /// If not specified, reorthogonalization is not performed.
+ ///
+ public int ReorthogonalizePeriod { get; private set; }
+
+ ///
+ /// Gets the sample offset used in the learning rate schedule when Kappa is specified.
+ ///
+ public int SampleOffset { get; private set; }
+
+ ///
+ public override Tensor Components { get; protected set; } = empty(0);
+
+ ///
+ /// Gets the random number generator used for initializing the model.
+ ///
+ public Generator? Generator { get; private set; }
+
+ ///
+ /// Initializes a new instance of the class.
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ public OnlineProbabilisticPca(int numComponents,
+ Device? device = null,
+ ScalarType? scalarType = null,
+ double initialVariance = 1.0,
+ Generator? generator = null,
+ double? rho = 0.1,
+ double? kappa = null,
+ int? sampleOffset = null,
+ int? reorthogonalizePeriod = null
+ ) : base(numComponents,
+ device,
+ scalarType)
+ {
+ if (initialVariance <= 0)
+ {
+ throw new ArgumentException("Starting variance must be greater than or equal to zero.", nameof(initialVariance));
+ }
+
+ if (kappa.HasValue && rho.HasValue)
+ {
+ throw new ArgumentException("Only one of rho or kappa should be specified, not both.", nameof(rho));
+ }
+
+ if (!kappa.HasValue && !rho.HasValue)
+ {
+ throw new ArgumentException("Either rho or kappa must be specified.", nameof(rho));
+ }
+
+ if (rho.HasValue)
+ {
+ if (rho.Value <= 0 || rho.Value >= 1)
+ {
+ throw new ArgumentException("Rho must be in the range (0, 1).", nameof(rho));
+ }
+
+ UpdateSchedule = () => rho.Value;
+ }
+ else
+ {
+ sampleOffset ??= 0;
+ if (sampleOffset < 0)
+ {
+ throw new ArgumentException("Sample offset must be a positive integer.", nameof(sampleOffset));
+ }
+
+ if (!kappa.HasValue)
+ {
+ throw new ArgumentException("Kappa must be specified when using a learning rate schedule.", nameof(kappa));
+ }
+
+ if (kappa <= 0.5 || kappa > 1)
+ {
+ throw new ArgumentException("Kappa must be in the range (0.5, 1].", nameof(kappa));
+ }
+
+ UpdateSchedule = () => Math.Pow(_stepCount + SampleOffset, -kappa.Value);
+ }
+
+ if (reorthogonalizePeriod.HasValue)
+ {
+ _reorthogonalize = true;
+ ReorthogonalizePeriod = reorthogonalizePeriod.Value;
+ }
+
+ Generator = generator;
+ Rho = rho;
+ Kappa = kappa;
+ SampleOffset = sampleOffset ?? 0;
+ Variance = initialVariance;
+ }
+
+ ///
+ public override void Fit(Tensor data)
+ {
+ base.Fit(data);
+
+ using (no_grad())
+ using (NewDisposeScope())
+ {
+
+ _stepCount++;
+ var rho = UpdateSchedule();
+
+ // Initialize dimensions
+ var numSamples = data.size(0);
+
+ // Initialize parameters
+ if (Means.numel() == 0)
+ {
+ Means = zeros(NumFeatures, device: Device, dtype: ScalarType).MoveToOuterDisposeScope();
+ var weights = qr(randn(NumFeatures, NumComponents, generator: Generator, device: Device, dtype: ScalarType), mode: QRMode.Reduced).Q;
+ Components = (weights * Variance).MoveToOuterDisposeScope();
+ _identityComponents = eye(NumComponents, device: Device, dtype: ScalarType).MoveToOuterDisposeScope();
+ _mx = zeros(NumFeatures, device: Device, dtype: ScalarType).MoveToOuterDisposeScope(); // d
+ _Cxz = zeros(NumFeatures, NumComponents, device: Device, dtype: ScalarType).MoveToOuterDisposeScope(); // d x q
+ _mz = zeros(NumComponents, device: Device, dtype: ScalarType).MoveToOuterDisposeScope(); // q
+ _Czz = zeros(NumComponents, NumComponents, device: Device, dtype: ScalarType).MoveToOuterDisposeScope(); // q x q
+ _sxx = zeros(1, device: Device, dtype: ScalarType).MoveToOuterDisposeScope(); // scalar
+ }
+
+ // Covariance matrix
+ var cov = _identityComponents * Variance;
+
+ // Center data using current mean
+ var dataCentered = data - Means;
+
+ // E-step
+ var M = Components.T.matmul(Components) + cov;
+ var MInv = Utils.InvertSPD(M, _identityComponents);
+ var projection = dataCentered.matmul(Components);
+ var EzT = Utils.InvertSPD(M, projection.T);
+ var Ez = EzT.T;
+
+ // Update statistics
+ var mx = data.mean([0]);
+ var sxx = data.pow(2).sum(dim: 1).mean();
+ var Cxz = data.T.matmul(Ez) / numSamples;
+ var mz = Ez.mean([0]);
+ var Czz = EzT.matmul(Ez) / numSamples + Variance * MInv;
+
+ // Update parameters
+ var rhoFactor = 1 - rho;
+ _mx = (rhoFactor * _mx + rho * mx).MoveToOuterDisposeScope();
+ _Cxz = (rhoFactor * _Cxz + rho * Cxz).MoveToOuterDisposeScope();
+ _mz = (rhoFactor * _mz + rho * mz).MoveToOuterDisposeScope();
+ _sxx = (rhoFactor * _sxx + rho * sxx).MoveToOuterDisposeScope();
+ _Czz = (rhoFactor * _Czz + rho * Czz).MoveToOuterDisposeScope();
+
+ // Update mean
+ Means = _mx.MoveToOuterDisposeScope();
+
+ // Centered statistics
+ var Sxz = _Cxz - outer(Means, _mz);
+ var Szz = _Czz;
+ var Sxx = _sxx - Means.dot(Means);
+
+ // M-step
+ var weightsUpdated = Utils.InvertSPD(Szz, Sxz.T).T;
+
+ if (_reorthogonalize &&
+ _stepCount % ReorthogonalizePeriod == 0)
+ {
+ var (U, S, Vh) = svd(weightsUpdated, fullMatrices: false);
+ var R = Vh.T;
+ weightsUpdated = U.matmul(diag(S));
+ _Cxz = _Cxz.matmul(R.T);
+ _Czz = R.matmul(_Czz).matmul(R.T);
+ _mz = R.matmul(_mz);
+ }
+
+ // Reorder components based on the strength of the components
+ var strength = sum(weightsUpdated * weightsUpdated, dim: 0);
+ var indices = argsort(strength, descending: true);
+ Components = weightsUpdated.index_select(1, indices).MoveToOuterDisposeScope();
+ _Cxz = _Cxz.index_select(1, indices).MoveToOuterDisposeScope();
+ _mz = _mz.index_select(0, indices).MoveToOuterDisposeScope();
+ _Czz = _Czz.index_select(0, indices).index_select(1, indices).MoveToOuterDisposeScope();
+
+ Sxz = _Cxz - outer(Means, _mz);
+ Szz = _Czz;
+
+ // Update variance
+ Variance = ((Sxx - 2 * trace(Components.T.matmul(Sxz)) + trace(Components.T.matmul(Components).matmul(Szz))) / (double)NumFeatures)
+ .clamp_min(0.0)
+ .to_type(TorchSharp.torch.ScalarType.Float64)
+ .item();
+ }
+
+ IsFitted = true;
+ }
+
+ ///
+ public override Tensor Transform(Tensor data)
+ {
+ base.Transform(data);
+ var dataCentered = data - Means;
+ var M = Components.T.matmul(Components) + _identityComponents * Variance;
+ var projection = dataCentered.matmul(Components);
+ return Utils.InvertSPD(M, projection.T).T;
+ }
+
+ ///
+ public override Tensor Reconstruct(Tensor data)
+ {
+ base.Reconstruct(data);
+ return data.matmul(Components.T) + Means;
+ }
+}
diff --git a/src/Bonsai.ML.Pca.Torch/Pca.cs b/src/Bonsai.ML.Pca.Torch/Pca.cs
new file mode 100644
index 00000000..a33c8ce4
--- /dev/null
+++ b/src/Bonsai.ML.Pca.Torch/Pca.cs
@@ -0,0 +1,73 @@
+using static TorchSharp.torch;
+using static TorchSharp.torch.linalg;
+
+namespace Bonsai.ML.Pca.Torch;
+
+///
+/// Represents a standard Principal Component Analysis (PCA) model.
+///
+public class Pca(int numComponents,
+ Device? device = null,
+ ScalarType? scalarType = null
+) : PcaBaseModel(numComponents,
+ device,
+ scalarType)
+{
+ ///
+ /// Gets the mean of the fitted data.
+ ///
+ public Tensor Mean { get; private set; } = empty(0);
+
+ ///
+ public override Tensor Components { get; protected set; } = empty(0);
+
+ ///
+ /// The singular values of the fitted data.
+ ///
+ public Tensor SingularValues { get; private set; } = empty(0);
+
+ ///
+ public override void Fit(Tensor data)
+ {
+ base.Fit(data);
+
+ using (no_grad())
+ using (NewDisposeScope())
+ {
+ var mean = data.mean([0], keepdim: true);
+ var dataCentered = data - mean;
+ var (U, S, Vh) = svd(dataCentered, fullMatrices: false);
+ var components = Vh.slice(0, 0, NumComponents, 1).T;
+ var singularValues = S.slice(0, 0, NumComponents, 1);
+
+ if (ScalarType is not null && ScalarType != data.dtype)
+ {
+ var scalarType = ScalarType.Value;
+ mean = mean.to(scalarType);
+ components = components.to(scalarType);
+ singularValues = singularValues.to(scalarType);
+ }
+
+ Mean = mean.MoveToOuterDisposeScope();
+ Components = components.MoveToOuterDisposeScope();
+ SingularValues = singularValues.MoveToOuterDisposeScope();
+ }
+
+ IsFitted = true;
+ }
+
+ ///
+ public override Tensor Transform(Tensor data)
+ {
+ base.Transform(data);
+ var dataCentered = data - Mean;
+ return dataCentered.matmul(Components);
+ }
+
+ ///
+ public override Tensor Reconstruct(Tensor data)
+ {
+ base.Reconstruct(data);
+ return data.matmul(Components.T) + Mean;
+ }
+}
diff --git a/src/Bonsai.ML.Pca.Torch/PcaBaseModel.cs b/src/Bonsai.ML.Pca.Torch/PcaBaseModel.cs
new file mode 100644
index 00000000..99ba852a
--- /dev/null
+++ b/src/Bonsai.ML.Pca.Torch/PcaBaseModel.cs
@@ -0,0 +1,114 @@
+using System;
+using System.Linq;
+using static TorchSharp.torch;
+
+namespace Bonsai.ML.Pca.Torch;
+
+///
+/// Provides an abstract base class for PCA models.
+///
+public abstract class PcaBaseModel : IPcaBaseModel
+{
+ ///
+ public bool IsFitted { get; protected set; }
+
+ ///
+ public int NumFeatures { get; protected set; } = -1;
+
+ ///
+ public abstract Tensor Components { get; protected set; }
+
+ ///
+ public int NumComponents { get; private set; }
+
+ ///
+ public Device Device { get; private set; }
+
+ ///
+ public ScalarType? ScalarType { get; private set; }
+
+ ///
+ /// Initializes a new instance of the class.
+ ///
+ ///
+ ///
+ ///
+ ///
+ public PcaBaseModel(int numComponents,
+ Device? device = null,
+ ScalarType? scalarType = null)
+ {
+ if (numComponents <= 0)
+ {
+ throw new ArgumentException("Number of components must be greater than zero.", nameof(numComponents));
+ }
+
+ NumComponents = numComponents;
+ Device = device ?? CPU;
+ ScalarType = scalarType;
+ }
+
+ ///
+ public virtual void Fit(Tensor data)
+ {
+ CheckDataCompatibility(data);
+
+ var d = data.size(1);
+
+ if (NumComponents > d)
+ throw new ArgumentException($"Number of components cannot be greater than the number of features. Number of components: {NumComponents}, number of features: {d}.", nameof(data));
+
+ NumFeatures = (int)d;
+ }
+
+ ///
+ public virtual Tensor Transform(Tensor data)
+ {
+ CheckFitted();
+ CheckDataCompatibility(data);
+ CheckDataFeatures(data);
+ return data;
+ }
+
+ ///
+ public virtual Tensor FitAndTransform(Tensor data)
+ {
+ Fit(data);
+ return Transform(data);
+ }
+
+ ///
+ public virtual Tensor Reconstruct(Tensor data)
+ {
+ CheckFitted();
+ CheckDataCompatibility(data);
+ CheckDataFeatures(data);
+ return data;
+ }
+
+ private void CheckFitted()
+ {
+ if (!IsFitted)
+ throw new InvalidOperationException("Model has not yet been fitted. You should call one of the Fit() or the FitAndTransform() methods first.");
+ }
+
+ private void CheckDataCompatibility(Tensor data)
+ {
+ if (data.NumberOfElements == 0)
+ throw new ArgumentException("Data must be a non-empty 2D tensor with shape (samples x features).", nameof(data));
+
+ if (data.dim() != 2)
+ {
+ var shapeStr = string.Join(",", data.shape.Select(x => x.ToString()).ToArray());
+ throw new ArgumentException($"Data must be a 2D tensor with shape (samples x features). Data shape: {shapeStr}.", nameof(data));
+ }
+ }
+
+ private void CheckDataFeatures(Tensor data)
+ {
+ var d = data.size(1);
+
+ if (d != NumFeatures)
+ throw new ArgumentException("The number of features in the data does not match the number of features in the fitted model.", nameof(data));
+ }
+}
diff --git a/src/Bonsai.ML.Pca.Torch/PcaDesciptionProvider.cs b/src/Bonsai.ML.Pca.Torch/PcaDesciptionProvider.cs
new file mode 100644
index 00000000..39b50eb0
--- /dev/null
+++ b/src/Bonsai.ML.Pca.Torch/PcaDesciptionProvider.cs
@@ -0,0 +1,26 @@
+using System;
+using System.ComponentModel;
+
+namespace Bonsai.ML.Pca.Torch;
+
+///
+/// Provides a custom type description provider for PCA models.
+///
+///
+/// Initializes a new instance of the class.
+///
+///
+public class PcaDescriptionProvider(TypeDescriptionProvider baseProvider) : TypeDescriptionProvider(baseProvider)
+{
+ ///
+ /// Initializes a new instance of the class.
+ ///
+ public PcaDescriptionProvider() : this(TypeDescriptor.GetProvider(typeof(object))) { }
+
+ ///
+ public override ICustomTypeDescriptor GetTypeDescriptor(Type objectType, object instance)
+ {
+ var defaultDescriptor = base.GetTypeDescriptor(objectType, instance);
+ return new PcaDescriptor(defaultDescriptor, instance);
+ }
+}
\ No newline at end of file
diff --git a/src/Bonsai.ML.Pca.Torch/PcaDescriptor.cs b/src/Bonsai.ML.Pca.Torch/PcaDescriptor.cs
new file mode 100644
index 00000000..307c4d8f
--- /dev/null
+++ b/src/Bonsai.ML.Pca.Torch/PcaDescriptor.cs
@@ -0,0 +1,44 @@
+using System;
+using System.ComponentModel;
+using System.Reactive.Linq;
+using System.Collections.Generic;
+using System.Linq;
+
+namespace Bonsai.ML.Pca.Torch;
+
+///
+/// Provides a custom type descriptor for PCA model creation.
+///
+///
+/// Initializes a new instance of the class.
+///
+///
+///
+public class PcaDescriptor(ICustomTypeDescriptor parent, object instance) : CustomTypeDescriptor(parent)
+{
+ private readonly object _instance = instance;
+
+ ///
+ public override PropertyDescriptorCollection GetProperties(Attribute[] attributes)
+ {
+ var allProperties = base.GetProperties(attributes);
+
+ if (_instance is CreatePca createPca)
+ {
+ var modelProperties = new HashSet(createPca.GetModelProperties());
+ var filtered = allProperties.Cast()
+ .Where(p => modelProperties.Contains(p.Name))
+ .ToArray();
+ return new PropertyDescriptorCollection(filtered);
+ }
+
+ return allProperties;
+ }
+
+ ///
+ public override PropertyDescriptorCollection GetProperties()
+ => GetProperties([]);
+
+ ///
+ public override object GetPropertyOwner(PropertyDescriptor pd) => _instance;
+}
diff --git a/src/Bonsai.ML.Pca.Torch/PcaModelBuilder.cs b/src/Bonsai.ML.Pca.Torch/PcaModelBuilder.cs
new file mode 100644
index 00000000..8ef69b25
--- /dev/null
+++ b/src/Bonsai.ML.Pca.Torch/PcaModelBuilder.cs
@@ -0,0 +1,94 @@
+using System;
+using System.Collections.Generic;
+using System.ComponentModel;
+using System.Linq;
+using System.Linq.Expressions;
+using System.Reflection;
+using System.Xml.Serialization;
+using Bonsai.Expressions;
+using static TorchSharp.torch;
+
+namespace Bonsai.ML.Pca.Torch;
+
+///
+/// Represents an abstract PCA model builder.
+///
+public abstract class PcaModelBuilder(IPcaModelProvider operatorInstance) : SingleArgumentExpressionBuilder, ICustomTypeDescriptor
+{
+ private readonly IPcaModelProvider _operator = operatorInstance;
+
+ ///
+ /// The PCA model.
+ ///
+ [XmlIgnore]
+ [Description("The PCA model.")]
+ public IPcaBaseModel? Model
+ {
+ get => _operator.Model;
+ set => _operator.Model = value;
+ }
+
+ ///
+ /// Determines whether the model is available in the input expression (and thus Model property should not be exposed) or if it should be explicitly set as a property of the operator.
+ ///
+ [Browsable(false)]
+ [XmlIgnore]
+ [RefreshProperties(RefreshProperties.All)]
+ public bool HasModel { get; private set; } = true;
+
+ ///
+ public override Expression Build(IEnumerable arguments)
+ {
+ var input = arguments?.FirstOrDefault();
+ MethodInfo processMethod;
+ if (input is null)
+ {
+ HasModel = true;
+ processMethod = typeof(T).GetMethod("Process", [typeof(IObservable)]);
+ return Expression.Call(Expression.Constant(_operator), processMethod!, [input]);
+ }
+
+ var obsType = input.Type;
+ var t = obsType.GetGenericArguments().Single();
+
+ HasModel = !(t.IsGenericType && t.FullName?.StartsWith("System.Tuple`") == true);
+
+ if (!HasModel)
+ {
+ var args = t.GetGenericArguments();
+ if (args.Length != 2 || (!typeof(IPcaBaseModel).IsAssignableFrom(args[0]) && !typeof(IPcaBaseModel).IsAssignableFrom(args[1])))
+ throw new InvalidOperationException("The input type is not valid. Expected an observable sequence of tuples containing a PCA model and a tensor.");
+ }
+
+ processMethod = HasModel
+ ? typeof(T).GetMethod("Process", [typeof(IObservable)])
+ : typeof(T).GetMethod("Process", [typeof(IObservable<>).MakeGenericType(t)]);
+
+ return Expression.Call(Expression.Constant(_operator), processMethod!, [input]);
+ }
+
+
+ PropertyDescriptorCollection ICustomTypeDescriptor.GetProperties(Attribute[]? attributes)
+ {
+ var props = TypeDescriptor.GetProperties(this, attributes, true);
+ if (HasModel) return props;
+
+ var filtered = props.Cast()
+ .Where(p => p.Name != nameof(Model))
+ .ToArray();
+
+ return new PropertyDescriptorCollection(filtered);
+ }
+
+ AttributeCollection ICustomTypeDescriptor.GetAttributes() => TypeDescriptor.GetAttributes(this, true);
+ string? ICustomTypeDescriptor.GetClassName() => TypeDescriptor.GetClassName(this, true);
+ string? ICustomTypeDescriptor.GetComponentName() => TypeDescriptor.GetComponentName(this, true);
+ TypeConverter ICustomTypeDescriptor.GetConverter() => TypeDescriptor.GetConverter(this, true);
+ EventDescriptor? ICustomTypeDescriptor.GetDefaultEvent() => TypeDescriptor.GetDefaultEvent(this, true);
+ PropertyDescriptor? ICustomTypeDescriptor.GetDefaultProperty() => TypeDescriptor.GetDefaultProperty(this, true);
+ object? ICustomTypeDescriptor.GetEditor(Type editorBaseType) => TypeDescriptor.GetEditor(this, editorBaseType, true);
+ EventDescriptorCollection ICustomTypeDescriptor.GetEvents(Attribute[]? attributes) => TypeDescriptor.GetEvents(this, attributes, true);
+ EventDescriptorCollection ICustomTypeDescriptor.GetEvents() => TypeDescriptor.GetEvents(this, true);
+ PropertyDescriptorCollection ICustomTypeDescriptor.GetProperties() => ((ICustomTypeDescriptor)this).GetProperties(Array.Empty());
+ object ICustomTypeDescriptor.GetPropertyOwner(PropertyDescriptor? pd) => this;
+}
diff --git a/src/Bonsai.ML.Pca.Torch/PcaModelType.cs b/src/Bonsai.ML.Pca.Torch/PcaModelType.cs
new file mode 100644
index 00000000..d8194720
--- /dev/null
+++ b/src/Bonsai.ML.Pca.Torch/PcaModelType.cs
@@ -0,0 +1,27 @@
+namespace Bonsai.ML.Pca.Torch;
+
+///
+/// Specifies the type of PCA model.
+///
+public enum PcaModelType
+{
+ ///
+ /// Standard PCA model.
+ ///
+ Pca,
+
+ ///
+ /// Probabilistic PCA model.
+ ///
+ ProbabilisticPca,
+
+ ///
+ /// Online Probabilistic PCA model.
+ ///
+ OnlineProbabilisticPca,
+
+ ///
+ /// Online PCA model using the Generalized Hebbian Algorithm.
+ ///
+ OnlinePcaGha
+}
diff --git a/src/Bonsai.ML.Pca.Torch/ProbabalisticPca.cs b/src/Bonsai.ML.Pca.Torch/ProbabalisticPca.cs
new file mode 100644
index 00000000..843298f2
--- /dev/null
+++ b/src/Bonsai.ML.Pca.Torch/ProbabalisticPca.cs
@@ -0,0 +1,181 @@
+using System;
+using static TorchSharp.torch;
+using static TorchSharp.torch.linalg;
+
+namespace Bonsai.ML.Pca.Torch;
+
+///
+/// Represents a probabilistic PCA model.
+///
+public class ProbabilisticPca : PcaBaseModel
+{
+ private readonly int _iterations;
+ private readonly double _tolerance;
+
+ ///
+ /// Gets the mean of the fitted data.
+ ///
+ public Tensor Mean { get; private set; } = empty(0);
+
+ ///
+ /// Gets the variance of the isotropic Gaussian noise model.
+ ///
+ public double Variance { get; private set; }
+
+ ///
+ /// Gets the log likelihood of the fitted model.
+ ///
+ public Tensor LogLikelihood { get; private set; } = empty(0);
+
+ ///
+ public override Tensor Components { get; protected set; } = empty(0);
+
+ ///
+ /// Gets the random number generator used for initializing the model.
+ ///
+ public Generator? Generator { get; private set; }
+
+ ///
+ /// Initializes a new instance of the class.
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ public ProbabilisticPca(int numComponents,
+ Device? device = null,
+ ScalarType? scalarType = null,
+ double initialVariance = 1.0,
+ Generator? generator = null,
+ int iterations = 100,
+ double tolerance = 1e-5
+ ) : base(numComponents,
+ device,
+ scalarType)
+ {
+ if (initialVariance < 0)
+ {
+ throw new ArgumentException("Starting variance must be greater than or equal to zero.", nameof(initialVariance));
+ }
+
+ if (iterations <= 0)
+ {
+ throw new ArgumentException("Number of iterations must be greater than zero.", nameof(iterations));
+ }
+
+ if (tolerance <= 0)
+ {
+ throw new ArgumentException("Tolerance must be greater than zero.", nameof(tolerance));
+ }
+
+ Variance = initialVariance;
+ Generator = generator;
+ _iterations = iterations;
+ _tolerance = tolerance;
+ }
+
+ ///
+ public override void Fit(Tensor data)
+ {
+ base.Fit(data);
+
+ using (no_grad())
+ using (NewDisposeScope())
+ {
+ var numSamples = data.size(0);
+
+ // Initialize log likelihood
+ LogLikelihood = ones(_iterations, device: Device, dtype: ScalarType) * double.NegativeInfinity;
+
+ var weights = randn(NumFeatures, NumComponents, generator: Generator, device: Device, dtype: ScalarType);
+ var identityComponents = eye(NumComponents, device: Device, dtype: ScalarType);
+ var identityFeatures = eye(NumFeatures, device: Device, dtype: ScalarType);
+
+ var mean = data.mean([0], keepdim: true);
+ var dataCentered = data - mean;
+
+ // Calculate the sample covariance
+ var covarianceTerm = dataCentered.T.matmul(dataCentered);
+ var sampleCov = covarianceTerm / numSamples;
+
+ // Calculate term 1 for variance update
+ var term1 = trace(covarianceTerm);
+
+ // Compute log likelihood constant
+ var logLikelihoodConst = NumFeatures * log(2 * Math.PI).to(Device);
+
+ double diffWeights;
+ double diffVariance;
+
+ // Repeat until convergence
+ for (int i = 0; i < _iterations; i++)
+ {
+ // E-step: Compute the posterior distribution of the latent variables
+ var M = weights.T.matmul(weights) + identityComponents * Variance;
+ var MInv = inv(M);
+ var mu = MInv.matmul(weights.T).matmul(dataCentered.T).T;
+ var SSum = numSamples * MInv * Variance;
+ var cov = mu.T.matmul(mu) + SSum;
+
+ // M-step: Compute new weights and new variance
+ var dataMu = dataCentered.T.matmul(mu);
+ var weightsNew = dataMu.matmul(inv(cov));
+
+ var term2 = 2 * dataMu.mul(weightsNew).sum();
+ var mu2 = mu.T.matmul(mu);
+ var weightsNew2 = weightsNew.T.matmul(weightsNew);
+ var term3 = trace(weightsNew2.matmul(mu2 + SSum));
+ var varianceNew = (term1 - term2 + term3) / (numSamples * NumFeatures);
+
+ // Compute the log likelihood
+ var logLikelihoodTerm = weightsNew.matmul(weightsNew.T) + eye(NumFeatures) * varianceNew;
+ var logLikelihoodTermInv = inv(logLikelihoodTerm);
+ var logLikelihood = -0.5 * numSamples * (logLikelihoodConst + logdet(logLikelihoodTerm) + trace(logLikelihoodTermInv.matmul(sampleCov)));
+
+ // Compare previous and new parameters for convergence
+ diffWeights = linalg.norm(weightsNew - weights).to_type(TorchSharp.torch.ScalarType.Float64).item();
+ diffVariance = abs(varianceNew - Variance).to_type(TorchSharp.torch.ScalarType.Float64).item();
+
+ // Update loglikelihood, weights and variance
+ LogLikelihood[i] = logLikelihood;
+ weights = weightsNew;
+ Variance = varianceNew.to_type(TorchSharp.torch.ScalarType.Float64).item();
+
+ // Check for convergence
+ if (diffWeights < _tolerance && diffVariance < _tolerance)
+ {
+ LogLikelihood = LogLikelihood.slice(0, 0, i + 1, 1);
+ break;
+ }
+ }
+
+ // Finalize model parameters
+ LogLikelihood = LogLikelihood.MoveToOuterDisposeScope();
+ Components = weights.MoveToOuterDisposeScope();
+ Mean = mean.MoveToOuterDisposeScope();
+ }
+
+ IsFitted = true;
+ }
+
+ ///
+ public override Tensor Transform(Tensor data)
+ {
+ base.Transform(data);
+ var dataCentered = data - Mean;
+ var M = Components.T.matmul(Components) + eye(NumComponents) * Variance;
+ var MInv = Utils.InvertSPD(M, eye(NumComponents));
+ return dataCentered.matmul(Components).matmul(MInv);
+ }
+
+ ///
+ public override Tensor Reconstruct(Tensor data)
+ {
+ base.Reconstruct(data);
+ return data.matmul(Components.T) + Mean;
+ }
+}
diff --git a/src/Bonsai.ML.Pca.Torch/Properties/launchSettings.json b/src/Bonsai.ML.Pca.Torch/Properties/launchSettings.json
new file mode 100644
index 00000000..4af4f468
--- /dev/null
+++ b/src/Bonsai.ML.Pca.Torch/Properties/launchSettings.json
@@ -0,0 +1,10 @@
+{
+ "profiles": {
+ "Bonsai": {
+ "commandName": "Executable",
+ "executablePath": "$(BonsaiExecutablePath)",
+ "commandLineArgs": "--lib:\"$(TargetDir).\"",
+ "nativeDebugging": true
+ }
+ }
+}
\ No newline at end of file
diff --git a/src/Bonsai.ML.Pca.Torch/Reconstruct.cs b/src/Bonsai.ML.Pca.Torch/Reconstruct.cs
new file mode 100644
index 00000000..2701da28
--- /dev/null
+++ b/src/Bonsai.ML.Pca.Torch/Reconstruct.cs
@@ -0,0 +1,142 @@
+using System;
+using System.ComponentModel;
+using System.Reactive.Linq;
+using System.Xml.Serialization;
+using static TorchSharp.torch;
+
+namespace Bonsai.ML.Pca.Torch;
+
+///
+/// Reconstructs the input data using a PCA model.
+///
+public class Reconstruct : IPcaModelProvider
+{
+ ///
+ public IPcaBaseModel? Model { get; set; }
+
+ private static Tensor ReconstructData(IPcaBaseModel model, Tensor data)
+ {
+ return model.Reconstruct(data);
+ }
+
+ ///
+ /// Reconstructs the input data using the specified PCA model.
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable source)
+ {
+ if (Model == null)
+ {
+ throw new InvalidOperationException("The PCA model has not been specified.");
+ }
+ return source.Select(value =>
+ {
+ return ReconstructData(Model, value);
+ });
+ }
+
+ ///
+ /// Reconstructs the input data using a standard PCA model.
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable> source)
+ {
+ return source.Select(value =>
+ {
+ return ReconstructData(value.Item1, value.Item2);
+ });
+ }
+
+ ///
+ /// Reconstructs the input data using a standard PCA model.
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable> source)
+ {
+ return source.Select(value =>
+ {
+ return ReconstructData(value.Item2, value.Item1);
+ });
+ }
+
+ ///
+ /// Reconstructs the input data using a probabilistic PCA model.
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable> source)
+ {
+ return source.Select(value =>
+ {
+ return ReconstructData(value.Item1, value.Item2);
+ });
+ }
+
+ ///
+ /// Reconstructs the input data using a probabilistic PCA model.
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable> source)
+ {
+ return source.Select(value =>
+ {
+ return ReconstructData(value.Item2, value.Item1);
+ });
+ }
+
+ ///
+ /// Reconstructs the input data using an online probabilistic PCA model based on stochastic online EM.
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable> source)
+ {
+ return source.Select(value =>
+ {
+ return ReconstructData(value.Item1, value.Item2);
+ });
+ }
+
+ ///
+ /// Reconstructs the input data using an online probabilistic PCA model based on stochastic online EM.
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable> source)
+ {
+ return source.Select(value =>
+ {
+ return ReconstructData(value.Item2, value.Item1);
+ });
+ }
+
+ ///
+ /// Reconstructs the input data using an online PCA model based on the Generalized Hebbian Algorithm.
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable> source)
+ {
+ return source.Select(value =>
+ {
+ return ReconstructData(value.Item1, value.Item2);
+ });
+ }
+
+ ///
+ /// Reconstructs the input data using an online PCA model based on the Generalized Hebbian Algorithm
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable> source)
+ {
+ return source.Select(value =>
+ {
+ return ReconstructData(value.Item2, value.Item1);
+ });
+ }
+}
diff --git a/src/Bonsai.ML.Pca.Torch/ReconstructBuilder.cs b/src/Bonsai.ML.Pca.Torch/ReconstructBuilder.cs
new file mode 100644
index 00000000..4938864f
--- /dev/null
+++ b/src/Bonsai.ML.Pca.Torch/ReconstructBuilder.cs
@@ -0,0 +1,12 @@
+using System.ComponentModel;
+
+namespace Bonsai.ML.Pca.Torch;
+
+///
+/// Represents an operator that reconstructs the input data using a PCA model.
+///
+[ResetCombinator]
+[Combinator]
+[Description("Reconstructs the input data using a PCA model.")]
+[WorkflowElementCategory(ElementCategory.Transform)]
+public class ReconstructBuilder() : PcaModelBuilder(new Reconstruct()) { }
diff --git a/src/Bonsai.ML.Pca.Torch/Transform.cs b/src/Bonsai.ML.Pca.Torch/Transform.cs
new file mode 100644
index 00000000..016b1346
--- /dev/null
+++ b/src/Bonsai.ML.Pca.Torch/Transform.cs
@@ -0,0 +1,140 @@
+using System;
+using System.Reactive.Linq;
+using static TorchSharp.torch;
+
+namespace Bonsai.ML.Pca.Torch;
+
+///
+/// Transforms the input data using a PCA model.
+///
+public class Transform : IPcaModelProvider
+{
+ ///
+ public IPcaBaseModel? Model { get; set; } = null;
+
+ private static Tensor TransformData(IPcaBaseModel model, Tensor data)
+ {
+ return model.Transform(data);
+ }
+
+ ///
+ /// Transforms the input data.
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable source)
+ {
+ if (Model == null)
+ {
+ throw new InvalidOperationException("The PCA model has not been specified.");
+ }
+ return source.Select(value =>
+ {
+ return TransformData(Model, value);
+ });
+ }
+
+ ///
+ /// Transforms the input data using a standard PCA model.
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable> source)
+ {
+ return source.Select(value =>
+ {
+ return TransformData(value.Item1, value.Item2);
+ });
+ }
+
+ ///
+ /// Transforms the input data using a standard PCA model.
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable> source)
+ {
+ return source.Select(value =>
+ {
+ return TransformData(value.Item2, value.Item1);
+ });
+ }
+
+ ///
+ /// Transforms the input data using a standard PCA model.
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable> source)
+ {
+ return source.Select(value =>
+ {
+ return TransformData(value.Item1, value.Item2);
+ });
+ }
+
+ ///
+ /// Transforms the input data using a probabilistic PCA model.
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable> source)
+ {
+ return source.Select(value =>
+ {
+ return TransformData(value.Item2, value.Item1);
+ });
+ }
+
+ ///
+ /// Transforms the input data using an online probabilistic PCA model.
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable> source)
+ {
+ return source.Select(value =>
+ {
+ return TransformData(value.Item1, value.Item2);
+ });
+ }
+
+ ///
+ /// Transforms the input data using an online probabilistic PCA model.
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable> source)
+ {
+ return source.Select(value =>
+ {
+ return TransformData(value.Item2, value.Item1);
+ });
+ }
+
+ ///
+ /// Transforms the input data using an online PCA model based on the Generalized Hebbian Algorithm.
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable> source)
+ {
+ return source.Select(value =>
+ {
+ return TransformData(value.Item1, value.Item2);
+ });
+ }
+
+ ///
+ /// Transforms the input data using an online PCA model based on the Generalized Hebbian Algorithm.
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable> source)
+ {
+ return source.Select(value =>
+ {
+ return TransformData(value.Item2, value.Item1);
+ });
+ }
+}
diff --git a/src/Bonsai.ML.Pca.Torch/TransformBuilder.cs b/src/Bonsai.ML.Pca.Torch/TransformBuilder.cs
new file mode 100644
index 00000000..ee4ca8ed
--- /dev/null
+++ b/src/Bonsai.ML.Pca.Torch/TransformBuilder.cs
@@ -0,0 +1,12 @@
+using System.ComponentModel;
+
+namespace Bonsai.ML.Pca.Torch;
+
+///
+/// Represents an operator that transforms the input data using a PCA model.
+///
+[ResetCombinator]
+[Combinator]
+[Description("Transforms the input data using a PCA model.")]
+[WorkflowElementCategory(ElementCategory.Transform)]
+public class TransformBuilder() : PcaModelBuilder(new Transform()) { }
diff --git a/src/Bonsai.ML.Pca.Torch/Utils.cs b/src/Bonsai.ML.Pca.Torch/Utils.cs
new file mode 100644
index 00000000..b89b73e5
--- /dev/null
+++ b/src/Bonsai.ML.Pca.Torch/Utils.cs
@@ -0,0 +1,29 @@
+using System;
+using static TorchSharp.torch;
+
+namespace Bonsai.ML.Pca.Torch;
+
+internal static class Utils
+{
+ internal static Tensor InvertSPD(
+ Tensor spdMatrix,
+ Tensor rhs,
+ double regularization = 1e-6,
+ Device? device = null,
+ ScalarType? scalarType = null
+ )
+ {
+ var diagShape = spdMatrix.size(-1);
+ Tensor L;
+ try
+ {
+ L = linalg.cholesky(spdMatrix);
+ }
+ catch (Exception)
+ {
+ var regularizer = eye(diagShape, device: device, dtype: scalarType) * regularization;
+ L = linalg.cholesky(spdMatrix + regularizer);
+ }
+ return cholesky_solve(rhs, L);
+ }
+}
diff --git a/tests/Bonsai.ML.Pca.Torch.Tests/Bonsai.ML.Pca.Torch.Tests.csproj b/tests/Bonsai.ML.Pca.Torch.Tests/Bonsai.ML.Pca.Torch.Tests.csproj
new file mode 100644
index 00000000..e480f9b3
--- /dev/null
+++ b/tests/Bonsai.ML.Pca.Torch.Tests/Bonsai.ML.Pca.Torch.Tests.csproj
@@ -0,0 +1,22 @@
+
+
+ net8.0
+ enable
+ enable
+ false
+ true
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/tests/Bonsai.ML.Pca.Torch.Tests/StandardPcaTests.cs b/tests/Bonsai.ML.Pca.Torch.Tests/StandardPcaTests.cs
new file mode 100644
index 00000000..3a323bc1
--- /dev/null
+++ b/tests/Bonsai.ML.Pca.Torch.Tests/StandardPcaTests.cs
@@ -0,0 +1,130 @@
+using System.Diagnostics;
+using Bonsai.ML.Pca.Torch;
+using TorchSharp;
+using static TorchSharp.torch;
+
+namespace Bonsai.ML.Pca.Torch.Tests;
+
+///
+/// Tests for the Bonsai.ML.Pca.Torch package.
+///
+[TestClass]
+public class StandardPcaTests
+{
+ public StandardPcaTests()
+ {
+ manual_seed(0);
+ set_printoptions(style: TorchSharp.TensorStringStyle.Numpy);
+ }
+
+ private static readonly Tensor _expectedFirstComponent = tensor(new float[] { 1f, 0f });
+
+ private static Tensor Generate2dRotationMatrix(double angleDegrees)
+ {
+ var angleRad = angleDegrees * Math.PI / 180.0;
+ var cosA = Math.Cos(angleRad);
+ var sinA = Math.Sin(angleRad);
+ return tensor(new float[,]
+ {
+ { (float)cosA, (float)-sinA },
+ { (float)sinA, (float)cosA }
+ });
+ }
+
+ private static Tensor Generate2dDataset(int numSamples, double scaleX = 3.0, double scaleY = 0.1, double offsetX = 0.0, double offsetY = 0.0, double rotationAngle = 0.0)
+ {
+ var x = randn(numSamples) * scaleX + offsetX;
+ var y = randn(numSamples) * scaleY + offsetY;
+ var data = stack([x, y], 1);
+ if (rotationAngle == 0)
+ return data;
+ var rotationMatrix = Generate2dRotationMatrix(rotationAngle);
+ return data.matmul(rotationMatrix);
+ }
+
+ private static float Similarity(Tensor a, Tensor b)
+ {
+ if (a.dim() != 1 || b.dim() != 1)
+ throw new ArgumentException("Input tensors must be 1-dimensional. Instead got dimension: " + a.dim());
+ if (a.size(0) != b.size(0))
+ throw new ArgumentException($"Input tensors must have the same shape. Instead got {string.Join(", ", a.shape)} and {string.Join(", ", b.shape)}.");
+ var aNorm = a.norm(-1, keepdim: true);
+ var bNorm = b.norm(-1, keepdim: true);
+ var dotProduct = a.dot(b);
+ return abs(dotProduct / (aNorm * bNorm)).item();
+ }
+
+ private static void TestBasic(PcaBaseModel model)
+ {
+ // Generate a simple dataset.
+ var data = Generate2dDataset(1000);
+ Debug.WriteLine($"Data shape: {string.Join(", ", data.shape)}");
+
+ // Fit the model.
+ model.Fit(data);
+
+ // Verify components.
+ Debug.WriteLine($"Components: {model.Components.str()}");
+ Assert.IsTrue(model.Components.shape[0] == 2 && model.Components.shape[1] == 2);
+
+ // Verify similarity with expected first component.
+ var similarity = Similarity(model.Components[0], _expectedFirstComponent);
+ Debug.WriteLine($"Similarity with expected first component: {similarity}");
+ Assert.IsTrue(similarity > 0.99);
+
+ // Compare reconstructed data.
+ var transformed = model.Transform(data);
+ var reconstructed = model.Reconstruct(transformed);
+
+ var reconstructionError = mean((data - reconstructed).pow(2)).item();
+ Debug.WriteLine($"Reconstruction error: {reconstructionError}");
+ Assert.IsTrue(reconstructionError < 1e-10);
+ }
+
+ private static void TestRotation(PcaBaseModel model)
+ {
+ // Compare rotated dataset
+ var data = Generate2dDataset(1000, rotationAngle: 30.0);
+ model.Fit(data);
+
+ var rotationMatrix = Generate2dRotationMatrix(30.0);
+ var rotatedExpectedFirstComponent = _expectedFirstComponent.matmul(rotationMatrix);
+ Debug.WriteLine($"Rotated expected first component: {rotatedExpectedFirstComponent.str()}");
+ var rotatedSimilarity = Similarity(model.Components[0], rotatedExpectedFirstComponent);
+ Debug.WriteLine($"Similarity with expected first component (rotated data): {rotatedSimilarity}");
+ Assert.IsTrue(rotatedSimilarity > 0.99);
+ }
+
+ private static void TestOffset(PcaBaseModel model)
+ {
+ // Test offset centering
+ var dataOffset = Generate2dDataset(1000, offsetX: 5.0, offsetY: -3.0);
+ model.Fit(dataOffset);
+
+ var offsetExpectedFirstComponent = _expectedFirstComponent;
+ var offsetSimilarity = Similarity(model.Components[0], offsetExpectedFirstComponent);
+ Debug.WriteLine($"Similarity with expected first component (offset data): {offsetSimilarity}");
+ Assert.IsTrue(offsetSimilarity > 0.99);
+
+ var transformed = model.Transform(dataOffset);
+ var reconstructed = model.Reconstruct(transformed);
+
+ var reconstructedMeans = reconstructed.mean([0]);
+ Debug.WriteLine($"Reconstructed means (offset data): {reconstructedMeans.str()}");
+ Assert.IsTrue(abs(reconstructedMeans[0] - 5.0).item() < 0.1);
+ Assert.IsTrue(abs(reconstructedMeans[1] + 3.0).item() < 0.1);
+ }
+
+ [TestMethod]
+ public void TestStandardPca()
+ {
+ var pca = new Pca(numComponents: 2);
+ TestBasic(pca);
+
+ pca = new Pca(numComponents: 2);
+ TestRotation(pca);
+
+ pca = new Pca(numComponents: 2);
+ TestOffset(pca);
+ }
+}
diff --git a/tests/Bonsai.ML.Pca.Torch.Tests/TransformPCATest.bonsai b/tests/Bonsai.ML.Pca.Torch.Tests/TransformPCATest.bonsai
new file mode 100644
index 00000000..e6fa55b9
--- /dev/null
+++ b/tests/Bonsai.ML.Pca.Torch.Tests/TransformPCATest.bonsai
@@ -0,0 +1,889 @@
+
+
+
+
+
+ LoadTrainingData
+
+
+
+
+
+
+
+ Z:/home/nicholas/Downloads/xfads-2/workflows/data/y_train.bin
+ 0
+ 0
+ 50
+ 1
+ F64
+ RowMajor
+
+
+
+
+
+ 1000
+
+
+
+
+
+
+
+
+
+
+
+
+
+ Float32
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ 1
+
+
+
+ it.squeeze(null).T
+
+
+
+
+ 0
+
+ true
+
+
+
+
+
+
+
+
+ it.T
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ YTrain
+
+
+ LoadValidationData
+
+
+
+
+
+
+
+ Z:/home/nicholas/Downloads/xfads-2/workflows/data/y_valid.bin
+ 0
+ 0
+ 50
+ 1
+ F64
+ RowMajor
+
+
+
+
+
+ 1000
+
+
+
+
+
+
+
+
+
+
+
+
+
+ Float32
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ 1
+
+
+
+ it.squeeze(null).T
+
+
+
+
+ 0
+
+ true
+
+
+
+
+
+
+
+
+ it.T
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ YValid
+
+
+
+ 0
+ 50
+ 1
+
+
+
+
+ Float32
+
+
+ YValid
+
+
+ YTrain
+
+
+
+ 2
+
+ PCA
+ 1
+ 100
+ 1E-05
+ 0.1
+ 0.9
+
+
+
+
+
+
+
+
+
+
+
+ Item2
+
+
+
+
+
+
+
+
+
+
+ 333
+ 50
+ 2
+
+
+
+
+
+ 0,:,0
+
+
+
+ Float32
+
+
+
+
+
+ DataStruct
+
+
+
+ Source1
+
+
+ Item1
+
+
+
+
+
+ Item2
+
+
+
+
+
+
+
+
+ new (Item1 as X, Item2 as Y)
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ Circle
+ 1
+
+
+
+
+
+
+
+
+
+ 0
+ 50
+ 1
+
+
+
+
+ Float32
+
+
+ YValid
+
+
+ YTrain
+
+
+
+ 2
+
+ ProbabilisticPCA
+ 1
+ 100
+ 1E-05
+ 0.1
+ 0.9
+
+
+
+
+
+
+
+
+
+
+
+ Item2
+
+
+
+
+
+
+
+
+
+
+ 333
+ 50
+ 2
+
+
+
+
+
+ 0,:,0
+
+
+
+ Float32
+
+
+
+
+
+ DataStruct
+
+
+
+ Source1
+
+
+ Item1
+
+
+
+
+
+ Item2
+
+
+
+
+
+
+
+
+ new (Item1 as X, Item2 as Y)
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ Circle
+ 1
+
+
+
+
+
+
+
+
+ YValid
+
+
+
+ 1
+
+
+
+ it.shape[1]
+
+
+
+
+
+
+
+
+ 0
+ 16650
+ 1
+
+
+
+
+ Int32
+
+
+
+
+
+ Source1
+
+
+
+
+
+
+
+
+ YValid
+
+
+
+ 1
+
+
+
+
+
+
+ Source1
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ it.unsqueeze(1)
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ YValidStream
+
+
+
+ YTrain
+
+
+
+ 1
+
+
+
+ it.shape[1]
+
+
+
+
+
+
+
+
+ 0
+ 33350
+ 1
+
+
+
+
+ Int32
+
+
+
+
+
+ Source1
+
+
+
+
+
+
+
+
+ YTrain
+
+
+
+ 1
+
+
+
+
+
+
+ Source1
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ it.unsqueeze(1)
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ 2
+
+ OnlinePPCA
+ 0.1
+ 100
+ 1E-05
+
+ 0.7
+ 3000
+ 500
+
+
+
+
+
+
+
+
+
+
+ 1
+
+
+
+ Item2
+
+
+ FittedOnlinePPCA
+
+
+ Components
+
+
+
+ 0
+
+
+
+ Float32
+
+
+
+
+
+
+
+
+
+ 0
+ 50
+ 1
+
+
+
+
+ Float32
+
+
+ FittedOnlinePPCA
+
+
+
+
+
+ YValidStream
+
+
+ FittedOnlinePPCA
+
+
+
+
+
+
+
+
+
+ 50
+
+
+
+
+
+
+
+ 0
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ 1
+ 50
+ 2
+
+
+
+
+
+ 0,:,0
+
+
+
+ Float32
+
+
+
+
+
+ DataStruct
+
+
+
+ Source1
+
+
+ Item1
+
+
+
+
+
+ Item2
+
+
+
+
+
+
+
+
+ new (Item1 as X, Item2 as Y)
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ Circle
+ 1
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file