Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
a26b727
Added new PCA project
ncguilbeault Aug 6, 2025
4bfcb4a
Added main components of PCA package
ncguilbeault Aug 7, 2025
2a941e8
Added probabalistic PCA method to package
ncguilbeault Aug 11, 2025
5be1999
Added online probabalistic PCA method
ncguilbeault Aug 11, 2025
1afd121
Added synthetic data tests for new PCA package
ncguilbeault Aug 11, 2025
e9b7587
Updated online ppca method to return closely what is expected from ba…
ncguilbeault Aug 11, 2025
84d9a94
Refactored `OnlinePPCA` to enforce non-nullable ReorthogonalizePeriod…
ncguilbeault Aug 11, 2025
bfd8857
Added PCA data reconstruction functionality
ncguilbeault Aug 11, 2025
3d4756f
Added method to fit model and transform data
ncguilbeault Aug 11, 2025
dd03327
Moved `InvertSPD` method to seperate static `Utils` class and updated…
ncguilbeault Aug 11, 2025
fb6c741
Added implementation of device and data type properties of base class…
ncguilbeault Aug 11, 2025
885c7a0
Refactored to use improved naming conventions
ncguilbeault Nov 3, 2025
7268686
Added XML documentation to code
ncguilbeault Nov 10, 2025
4d3d5dc
Added method for online PCA using the generalized hebbian rule
ncguilbeault Jan 13, 2026
d7fe674
Refactored PCA model to use more convenient shape (samples x features…
ncguilbeault Jan 14, 2026
16b8f05
Updated online PCA models
ncguilbeault Jan 15, 2026
5621afd
Updated to support "hidding" model property if provided by the input
ncguilbeault Jan 26, 2026
5b60a4f
Changed time offset to sample offset for better consistency
ncguilbeault Jan 27, 2026
7aa2efd
Added standard PCA unit tests
ncguilbeault Jan 29, 2026
d76c2c3
Removed `static` from test method to fix ms test warning
ncguilbeault Mar 25, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions Bonsai.ML.sln
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Bonsai.ML.Lds.Torch.Tests",
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Bonsai.ML.Lds.Torch.Design", "src\Bonsai.ML.Lds.Torch.Design\Bonsai.ML.Lds.Torch.Design.csproj", "{1F52DECD-1B2C-4F6C-996C-14C715283B80}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Bonsai.ML.Pca.Torch", "src\Bonsai.ML.Pca.Torch\Bonsai.ML.Pca.Torch.csproj", "{1A2DEED3-795E-4C28-9C5E-BA3D76B2A485}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Bonsai.ML.Pca.Torch.Tests", "tests\Bonsai.ML.Pca.Torch.Tests\Bonsai.ML.Pca.Torch.Tests.csproj", "{4ABCC6B2-024A-450F-85CB-2A9B2D2D2A10}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
Expand Down Expand Up @@ -120,6 +124,14 @@ Global
{1F52DECD-1B2C-4F6C-996C-14C715283B80}.Debug|Any CPU.Build.0 = Debug|Any CPU
{1F52DECD-1B2C-4F6C-996C-14C715283B80}.Release|Any CPU.ActiveCfg = Release|Any CPU
{1F52DECD-1B2C-4F6C-996C-14C715283B80}.Release|Any CPU.Build.0 = Release|Any CPU
{1A2DEED3-795E-4C28-9C5E-BA3D76B2A485}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{1A2DEED3-795E-4C28-9C5E-BA3D76B2A485}.Debug|Any CPU.Build.0 = Debug|Any CPU
{1A2DEED3-795E-4C28-9C5E-BA3D76B2A485}.Release|Any CPU.ActiveCfg = Release|Any CPU
{1A2DEED3-795E-4C28-9C5E-BA3D76B2A485}.Release|Any CPU.Build.0 = Release|Any CPU
{4ABCC6B2-024A-450F-85CB-2A9B2D2D2A10}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{4ABCC6B2-024A-450F-85CB-2A9B2D2D2A10}.Debug|Any CPU.Build.0 = Debug|Any CPU
{4ABCC6B2-024A-450F-85CB-2A9B2D2D2A10}.Release|Any CPU.ActiveCfg = Release|Any CPU
{4ABCC6B2-024A-450F-85CB-2A9B2D2D2A10}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
Expand Down
15 changes: 15 additions & 0 deletions src/Bonsai.ML.Pca.Torch/Bonsai.ML.Pca.Torch.csproj
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<Description>Bonsai.ML.Pca.Torch Bonsai library.</Description>
<PackageTags>$(PackageTags) PCA Principal Component Analysis</PackageTags>
<TargetFrameworks>net472;netstandard2.0</TargetFrameworks>
<Nullable>enable</Nullable>
</PropertyGroup>

<ItemGroup>
<ProjectReference Include="..\Bonsai.ML\Bonsai.ML.csproj" />
<ProjectReference Include="..\Bonsai.ML.Torch\Bonsai.ML.Torch.csproj" />
</ItemGroup>

</Project>
198 changes: 198 additions & 0 deletions src/Bonsai.ML.Pca.Torch/CreatePca.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
using System;
using System.ComponentModel;
using System.Collections.Generic;
using System.Reactive.Linq;
using System.Linq.Expressions;
using Bonsai.Expressions;
using System.Reflection;
using static TorchSharp.torch;
using System.Xml.Serialization;

namespace Bonsai.ML.Pca.Torch;

/// <summary>
/// Creates a PCA model.
/// </summary>
[Combinator]
[ResetCombinator]
[WorkflowElementCategory(ElementCategory.Source)]
[TypeDescriptionProvider(typeof(PcaDescriptionProvider))]
[Description("Creates a PCA model.")]
public class CreatePca : ZeroArgumentExpressionBuilder, INamedElement
{
/// <inheritdoc/>
public string Name => $"CreatePca.{ModelType}";

/// <summary>
/// The number of principal components to compute.
/// </summary>
public int NumComponents { get; set; } = 2;

/// <summary>
/// The device on which to create the PCA model.
/// </summary>
[XmlIgnore]
[Description("The device on which to create the PCA model.")]
public Device? Device { get; set; }

/// <summary>
/// The scalar type of the PCA model.
/// </summary>
[Description("The scalar type of the PCA model.")]
public ScalarType? ScalarType { get; set; }

/// <summary>
/// The type of PCA model to create.
/// </summary>
[RefreshProperties(RefreshProperties.All)]
[Description("The type of PCA model to create.")]
public PcaModelType ModelType { get; set; } = PcaModelType.Pca;

/// <summary>
/// The initial variance for probabilistic PCA models.
/// </summary>
[Description("The initial variance for probabilistic PCA models.")]
public double InitialVariance { get; set; } = 1.0;

/// <summary>
/// The number of iterations for fitting probabilistic PCA models.
/// </summary>
[Description("The number of iterations for fitting probabilistic PCA models.")]
public int Iterations { get; set; } = 100;

/// <summary>
/// The tolerance for convergence in probabilistic PCA models.
/// </summary>
[Description("The tolerance for convergence in probabilistic PCA models.")]
public double Tolerance { get; set; } = 1e-5;

/// <summary>
/// The constant learning rate parameter for the online probabilistic PCA model.
/// </summary>
[Description("The constant learning rate parameter for the online probabilistic PCA model. Only one of Rho or Kappa must be specified.")]
public double? Rho { get; set; } = 0.1;

/// <summary>
/// The forgetting factor for the online probabilistic PCA model.
/// </summary>
[Description("The forgetting factor for the online probabilistic PCA model. Only one of Rho or Kappa must be specified.")]
public double? Kappa { get; set; } = 0.9;

/// <summary>
/// The sample offset for the online probabilistic PCA model.
/// </summary>
[Description("The sample offset for the online probabilistic PCA model. If null, decaying learning rate starts from the first sample.")]
public int? SampleOffset { get; set; } = null;

/// <summary>
/// The period for reorthogonalizing the components in the online probabilistic PCA model.
/// </summary>
[Description("The period for reorthogonalizing the components in the online probabilistic PCA model. If null, reorthogonalization is not performed.")]
public int? ReorthogonalizePeriod { get; set; } = null;

/// <summary>
/// The random number generator used for initializing probabilistic PCA models.
/// </summary>
[XmlIgnore]
public Generator? Generator { get; set; } = null;

/// <summary>
/// The learning rate for the Online PCA GHA model.
/// </summary>
public double LearningRate { get; set; } = 0.1;

internal IEnumerable<string> GetModelProperties()
{
yield return nameof(NumComponents);
yield return nameof(Device);
yield return nameof(ScalarType);
yield return nameof(ModelType);

if (ModelType == PcaModelType.ProbabilisticPca)
{
yield return nameof(InitialVariance);
yield return nameof(Iterations);
yield return nameof(Tolerance);
yield return nameof(Generator);
}

if (ModelType == PcaModelType.OnlineProbabilisticPca)
{
yield return nameof(InitialVariance);
yield return nameof(Rho);
yield return nameof(Kappa);
yield return nameof(SampleOffset);
yield return nameof(ReorthogonalizePeriod);
yield return nameof(Generator);
}

if (ModelType == PcaModelType.OnlinePcaGha)
{
yield return nameof(LearningRate);
yield return nameof(Generator);
}
}

private static PcaBaseModel CreateModel(CreatePca pcaBuilder)
{
return pcaBuilder.ModelType switch
{
PcaModelType.Pca => new Pca(
numComponents: pcaBuilder.NumComponents,
device: pcaBuilder.Device,
scalarType: pcaBuilder.ScalarType),
PcaModelType.ProbabilisticPca => new ProbabilisticPca(
numComponents: pcaBuilder.NumComponents,
device: pcaBuilder.Device,
scalarType: pcaBuilder.ScalarType,
initialVariance: pcaBuilder.InitialVariance,
generator: pcaBuilder.Generator,
iterations: pcaBuilder.Iterations,
tolerance: pcaBuilder.Tolerance),
PcaModelType.OnlineProbabilisticPca => new OnlineProbabilisticPca(
numComponents: pcaBuilder.NumComponents,
device: pcaBuilder.Device,
scalarType: pcaBuilder.ScalarType,
initialVariance: pcaBuilder.InitialVariance,
generator: pcaBuilder.Generator,
rho: pcaBuilder.Rho,
kappa: pcaBuilder.Kappa,
sampleOffset: pcaBuilder.SampleOffset,
reorthogonalizePeriod: pcaBuilder.ReorthogonalizePeriod),
PcaModelType.OnlinePcaGha => new OnlinePcaGha(
numComponents: pcaBuilder.NumComponents,
learningRate: pcaBuilder.LearningRate,
device: pcaBuilder.Device,
scalarType: pcaBuilder.ScalarType,
generator: pcaBuilder.Generator),
_ => throw new NotSupportedException($"Model type {pcaBuilder.ModelType} is not supported."),
};
}

private static Type GetModelType(PcaModelType modelType)
{
return modelType switch
{
PcaModelType.Pca => typeof(Pca),
PcaModelType.ProbabilisticPca => typeof(ProbabilisticPca),
PcaModelType.OnlineProbabilisticPca => typeof(OnlineProbabilisticPca),
PcaModelType.OnlinePcaGha => typeof(OnlinePcaGha),
_ => throw new NotSupportedException($"Model type {modelType} is not supported."),
};
}

/// <inheritdoc/>
public override Expression Build(IEnumerable<Expression> arguments)
{
var processMethod = GetType().GetMethod(
nameof(Process),
BindingFlags.NonPublic | BindingFlags.Static);
processMethod = processMethod.MakeGenericMethod(GetModelType(ModelType));
return Expression.Call(processMethod, Expression.Constant(this));
}

private static IObservable<T> Process<T>(CreatePca instance) where T : PcaBaseModel
{
return Observable.Return((T)CreateModel(instance));
}
}
143 changes: 143 additions & 0 deletions src/Bonsai.ML.Pca.Torch/Fit.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
using System;
using System.ComponentModel;
using System.Reactive.Linq;
using System.Xml.Serialization;
using static TorchSharp.torch;

namespace Bonsai.ML.Pca.Torch;

/// <summary>
/// Fits the PCA model to the input data.
/// </summary>
public class Fit : IPcaModelProvider
{
/// <inheritdoc/>
public IPcaBaseModel? Model { get; set; } = null;

private void FitModel(IPcaBaseModel model, Tensor data)
{
model.Fit(data);
}

/// <summary>
/// Fits the PCA model to the input data.
/// </summary>
/// <param name="source"></param>
/// <returns></returns>
/// <exception cref="InvalidOperationException"></exception>
public IObservable<Tensor> Process(IObservable<Tensor> source)
{
if (Model is null)
{
throw new InvalidOperationException("The PCA model has not been specified.");
}
return source.Do(value =>
{
FitModel(Model, value);
});
}

/// <summary>
/// Fits a standard PCA model to the input data.
/// </summary>
/// <param name="source"></param>
/// <returns></returns>
public IObservable<Tuple<Pca, Tensor>> Process(IObservable<Tuple<Pca, Tensor>> source)
{
return source.Do((value) =>
{
FitModel(value.Item1, value.Item2);
});
}

/// <summary>
/// Fits a standard PCA model to the input data.
/// </summary>
/// <param name="source"></param>
/// <returns></returns>
public IObservable<Tuple<Tensor, Pca>> Process(IObservable<Tuple<Tensor, Pca>> source)
{
return source.Do((value) =>
{
FitModel(value.Item2, value.Item1);
});
}

/// <summary>
/// Fits a probabilistic PCA model to the input data.
/// </summary>
/// <param name="source"></param>
/// <returns></returns>
public IObservable<Tuple<ProbabilisticPca, Tensor>> Process(IObservable<Tuple<ProbabilisticPca, Tensor>> source)
{
return source.Do((value) =>
{
FitModel(value.Item1, value.Item2);
});
}

/// <summary>
/// Fits a probabilistic PCA model to the input data.
/// </summary>
/// <param name="source"></param>
/// <returns></returns>
public IObservable<Tuple<Tensor, ProbabilisticPca>> Process(IObservable<Tuple<Tensor, ProbabilisticPca>> source)
{
return source.Do((value) =>
{
FitModel(value.Item2, value.Item1);
});
}

/// <summary>
/// Fits an online probabilistic PCA model to the input data.
/// </summary>
/// <param name="source"></param>
/// <returns></returns>
public IObservable<Tuple<OnlineProbabilisticPca, Tensor>> Process(IObservable<Tuple<OnlineProbabilisticPca, Tensor>> source)
{
return source.Do((value) =>
{
FitModel(value.Item1, value.Item2);
});
}

/// <summary>
/// Fits an online probabilistic PCA model to the input data.
/// </summary>
/// <param name="source"></param>
/// <returns></returns>
public IObservable<Tuple<Tensor, OnlineProbabilisticPca>> Process(IObservable<Tuple<Tensor, OnlineProbabilisticPca>> source)
{
return source.Do((value) =>
{
FitModel(value.Item2, value.Item1);
});
}

/// <summary>
/// Fits an online PCA model using the Generalized Hebbian Algorithm to the input data.
/// </summary>
/// <param name="source"></param>
/// <returns></returns>
public IObservable<Tuple<OnlinePcaGha, Tensor>> Process(IObservable<Tuple<OnlinePcaGha, Tensor>> source)
{
return source.Do((value) =>
{
FitModel(value.Item1, value.Item2);
});
}

/// <summary>
/// Fits an online PCA model using the Generalized Hebbian Algorithm to the input data.
/// </summary>
/// <param name="source"></param>
/// <returns></returns>
public IObservable<Tuple<Tensor, OnlinePcaGha>> Process(IObservable<Tuple<Tensor, OnlinePcaGha>> source)
{
return source.Do((value) =>
{
FitModel(value.Item2, value.Item1);
});
}
}
Loading
Loading