diff --git a/Bonsai.ML.sln b/Bonsai.ML.sln index 62e44269..fc4e963f 100644 --- a/Bonsai.ML.sln +++ b/Bonsai.ML.sln @@ -1,4 +1,4 @@ -Microsoft Visual Studio Solution File, Format Version 12.00 +Microsoft Visual Studio Solution File, Format Version 12.00 # Visual Studio Version 17 VisualStudioVersion = 17.0.31903.59 MinimumVisualStudioVersion = 10.0.40219.1 @@ -40,6 +40,12 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "build", "build", "{DEE5DD87 EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Bonsai.ML.Tests.Utilities", "tests\Bonsai.ML.Tests.Utilities\Bonsai.ML.Tests.Utilities.csproj", "{DB1090D3-38DD-404B-96DF-66BF3C39E508}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Bonsai.ML.Lds.Torch", "src\Bonsai.ML.Lds.Torch\Bonsai.ML.Lds.Torch.csproj", "{41D4BEC7-AB1F-41E4-95FE-4DB23970FF4B}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Bonsai.ML.Lds.Torch.Tests", "tests\Bonsai.ML.Lds.Torch.Tests\Bonsai.ML.Lds.Torch.Tests.csproj", "{0B258929-0B07-4CE7-BE8D-A86BBC46AAD4}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Bonsai.ML.Lds.Torch.Design", "src\Bonsai.ML.Lds.Torch.Design\Bonsai.ML.Lds.Torch.Design.csproj", "{1F52DECD-1B2C-4F6C-996C-14C715283B80}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -102,6 +108,18 @@ Global {DB1090D3-38DD-404B-96DF-66BF3C39E508}.Debug|Any CPU.Build.0 = Debug|Any CPU {DB1090D3-38DD-404B-96DF-66BF3C39E508}.Release|Any CPU.ActiveCfg = Release|Any CPU {DB1090D3-38DD-404B-96DF-66BF3C39E508}.Release|Any CPU.Build.0 = Release|Any CPU + {41D4BEC7-AB1F-41E4-95FE-4DB23970FF4B}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {41D4BEC7-AB1F-41E4-95FE-4DB23970FF4B}.Debug|Any CPU.Build.0 = Debug|Any CPU + {41D4BEC7-AB1F-41E4-95FE-4DB23970FF4B}.Release|Any CPU.ActiveCfg = Release|Any CPU + {41D4BEC7-AB1F-41E4-95FE-4DB23970FF4B}.Release|Any CPU.Build.0 = Release|Any CPU + {0B258929-0B07-4CE7-BE8D-A86BBC46AAD4}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {0B258929-0B07-4CE7-BE8D-A86BBC46AAD4}.Debug|Any CPU.Build.0 = Debug|Any CPU + {0B258929-0B07-4CE7-BE8D-A86BBC46AAD4}.Release|Any CPU.ActiveCfg = Release|Any CPU + {0B258929-0B07-4CE7-BE8D-A86BBC46AAD4}.Release|Any CPU.Build.0 = Release|Any CPU + {1F52DECD-1B2C-4F6C-996C-14C715283B80}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {1F52DECD-1B2C-4F6C-996C-14C715283B80}.Debug|Any CPU.Build.0 = Debug|Any CPU + {1F52DECD-1B2C-4F6C-996C-14C715283B80}.Release|Any CPU.ActiveCfg = Release|Any CPU + {1F52DECD-1B2C-4F6C-996C-14C715283B80}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE diff --git a/docs/README.md b/docs/README.md index 60864b6d..9dbc6ac9 100644 --- a/docs/README.md +++ b/docs/README.md @@ -51,6 +51,9 @@ Interfaces with the [TorchSharp](https://github.com/dotnet/TorchSharp) package, > [!NOTE] > Bonsai.ML packages can be installed through Bonsai's integrated package manager and are generally ready for immediate use. However, some packages may require additional installation steps. Refer to the specific package section for detailed installation guides and documentation. +### Bonsai.ML.Lds.Torch +Linear dynamical systems implemented using the `Bonsai.ML.Torch` package for online filtering and smoothing of latent states and parameter estimation. + ## Development Roadmap The ultimate goal of the `Bonsai.ML` project is to bring powerful machine learning tools into Bonsai to enable intelligent experimental control. To achieve this, our plan is to incorporate several different packages, models, frameworks, and language integrations. You can follow our progress by going to the [Bonsai ML development roadmap](https://github.com/orgs/bonsai-rx/projects/7). diff --git a/docs/articles/Lds.Torch/lds-torch-overview.md b/docs/articles/Lds.Torch/lds-torch-overview.md new file mode 100644 index 00000000..fc140755 --- /dev/null +++ b/docs/articles/Lds.Torch/lds-torch-overview.md @@ -0,0 +1,7 @@ +# Bonsai.ML.Lds.Torch - Overview + +This package provides an implementation of the Kalman filter, Rauch-Tung-Striebel (RTS) smoother, expectation maximization (EM) algorithm, and stochastic subspace identification, developed for online filtering, smoothing, and parameter estimation from data streams in Bonsai using the TorchSharp package. + +## Installation Guide + +Install the `Bonsai.ML.Lds.Torch` package from the Bonsai package manager. You will also need to follow the [instructions for setting up the Bonsai.ML.Torch package](../Torch/torch-overview.md) for running on the CPU or GPU. \ No newline at end of file diff --git a/docs/articles/toc.yml b/docs/articles/toc.yml index 162c5226..59c5a874 100644 --- a/docs/articles/toc.yml +++ b/docs/articles/toc.yml @@ -19,4 +19,6 @@ href: PointProcessDecoder/ppd-overview.md - name: Torch - name: Overview - href: Torch/torch-overview.md \ No newline at end of file + href: Torch/torch-overview.md +- name: Lds.Torch + href: Lds.Torch/lds-torch-overview.md \ No newline at end of file diff --git a/src/Bonsai.ML.PointProcessDecoder.Design/OxyColorPresetCycle.cs b/src/Bonsai.ML.Design/OxyColorPresetCycle.cs similarity index 95% rename from src/Bonsai.ML.PointProcessDecoder.Design/OxyColorPresetCycle.cs rename to src/Bonsai.ML.Design/OxyColorPresetCycle.cs index e2976cf4..800744cf 100644 --- a/src/Bonsai.ML.PointProcessDecoder.Design/OxyColorPresetCycle.cs +++ b/src/Bonsai.ML.Design/OxyColorPresetCycle.cs @@ -1,7 +1,7 @@ using OxyPlot; -namespace Bonsai.ML.PointProcessDecoder.Design +namespace Bonsai.ML.Design { /// /// Enumerates the colors and provides a preset collection of colors to cycle through. diff --git a/src/Bonsai.ML.Lds.Torch.Design/Bonsai.ML.Lds.Torch.Design.csproj b/src/Bonsai.ML.Lds.Torch.Design/Bonsai.ML.Lds.Torch.Design.csproj new file mode 100644 index 00000000..2185e90e --- /dev/null +++ b/src/Bonsai.ML.Lds.Torch.Design/Bonsai.ML.Lds.Torch.Design.csproj @@ -0,0 +1,19 @@ + + + Visualizers for the Bonsai.ML.Lds.Torch library. + $(PackageTags) Torch LDS Design + net472 + true + + + + + + + + + + + false + + \ No newline at end of file diff --git a/src/Bonsai.ML.Lds.Torch.Design/ExpectationMaximizationVisualizer.cs b/src/Bonsai.ML.Lds.Torch.Design/ExpectationMaximizationVisualizer.cs new file mode 100644 index 00000000..0648685e --- /dev/null +++ b/src/Bonsai.ML.Lds.Torch.Design/ExpectationMaximizationVisualizer.cs @@ -0,0 +1,90 @@ +using System; +using System.Reactive; +using System.Linq; +using System.Windows.Forms; +using System.Collections.Generic; + +using Bonsai; +using Bonsai.Design; +using Bonsai.ML.Design; + +using OxyPlot; +using OxyPlot.Series; + +using static TorchSharp.torch; + +[assembly: TypeVisualizer(typeof(Bonsai.ML.Lds.Torch.Design.ExpectationMaximizationVisualizer), + Target = typeof(Bonsai.ML.Lds.Torch.ExpectationMaximizationResult))] + +namespace Bonsai.ML.Lds.Torch.Design; + +/// +/// Provides a visualizer for the state means and covariances from a Kalman filter or smoother. +/// +public class ExpectationMaximizationVisualizer : BufferedVisualizer +{ + private TimeSeriesOxyPlotBase _plot; + private LineSeries _lineSeries; + + /// + /// Gets the underlying plot control. + /// + public TimeSeriesOxyPlotBase Plot => _plot; + + /// + public override void Load(IServiceProvider provider) + { + _plot = new TimeSeriesOxyPlotBase() + { + Dock = DockStyle.Fill, + StartTime = DateTime.Now, + BufferData = true, + ValueLabel = "Log Likelihood" + }; + + _lineSeries = _plot.AddNewLineSeries("Log Likelihood", OxyColors.Blue); + + var visualizerService = (IDialogTypeVisualizerService)provider.GetService(typeof(IDialogTypeVisualizerService)); + visualizerService?.AddControl(_plot); + } + + /// + public override void Show(object value) + { + } + + /// + protected override void Show(DateTime time, object value) + { + if (value is null) return; + + if (value is not ExpectationMaximizationResult result) return; + + var logLikelihood = result.LogLikelihood; + if (logLikelihood is null) return; + + var ll = logLikelihood[-1].to_type(ScalarType.Float64).item(); + + _plot.AddToLineSeries( + lineSeries: _lineSeries, + time: time, + value: ll + ); + } + + /// + protected override void ShowBuffer(IList> values) + { + base.ShowBuffer(values); + if (values.Count > 0) + { + _plot.UpdatePlot(); + } + } + + /// + public override void Unload() + { + if (!_plot.IsDisposed) _plot.Dispose(); + } +} diff --git a/src/Bonsai.ML.Lds.Torch.Design/Properties/AssemblyInfo.cs b/src/Bonsai.ML.Lds.Torch.Design/Properties/AssemblyInfo.cs new file mode 100644 index 00000000..7a732600 --- /dev/null +++ b/src/Bonsai.ML.Lds.Torch.Design/Properties/AssemblyInfo.cs @@ -0,0 +1,6 @@ +using Bonsai; + +// General Information about an assembly is controlled through the following +// set of attributes. Change these attribute values to modify the information +// associated with an assembly. +[assembly: XmlNamespacePrefix("clr-namespace:Bonsai.ML.Lds.Torch.Design", null)] diff --git a/src/Bonsai.ML.Lds.Torch.Design/Properties/launchSettings.json b/src/Bonsai.ML.Lds.Torch.Design/Properties/launchSettings.json new file mode 100644 index 00000000..4af4f468 --- /dev/null +++ b/src/Bonsai.ML.Lds.Torch.Design/Properties/launchSettings.json @@ -0,0 +1,10 @@ +{ + "profiles": { + "Bonsai": { + "commandName": "Executable", + "executablePath": "$(BonsaiExecutablePath)", + "commandLineArgs": "--lib:\"$(TargetDir).\"", + "nativeDebugging": true + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Lds.Torch.Design/StateVisualizer.cs b/src/Bonsai.ML.Lds.Torch.Design/StateVisualizer.cs new file mode 100644 index 00000000..d1aea0ca --- /dev/null +++ b/src/Bonsai.ML.Lds.Torch.Design/StateVisualizer.cs @@ -0,0 +1,202 @@ +using System; +using System.Reactive; +using System.Linq; +using System.Windows.Forms; +using System.Collections.Generic; + +using Bonsai; +using Bonsai.Design; +using Bonsai.ML.Design; + +using OxyPlot; +using OxyPlot.Series; + +using static TorchSharp.torch; + +[assembly: TypeVisualizer(typeof(Bonsai.ML.Lds.Torch.Design.StateVisualizer), + Target = typeof(Bonsai.ML.Lds.Torch.FilteredState))] +[assembly: TypeVisualizer(typeof(Bonsai.ML.Lds.Torch.Design.StateVisualizer), + Target = typeof(Bonsai.ML.Lds.Torch.LinearDynamicalSystemState))] + +namespace Bonsai.ML.Lds.Torch.Design; + +/// +/// Provides a visualizer for the state means and covariances from a Kalman filter or smoother. +/// +public class StateVisualizer : BufferedVisualizer +{ + private TimeSeriesOxyPlotBase _plot; + private LineSeries[] _lineSeries; + private AreaSeries[] _areaSeries; + + /// + /// Gets or sets the amount of time in seconds that should be shown along the x axis. + /// + public int Capacity { get; set; } = 10; + + /// + /// Gets or sets a boolean value that determines whether to buffer the data beyond the capacity. + /// + public bool BufferData { get; set; } = false; + + /// + /// Gets the underlying plot control. + /// + public TimeSeriesOxyPlotBase Plot => _plot; + + /// + public override void Load(IServiceProvider provider) + { + _plot = new TimeSeriesOxyPlotBase() + { + Capacity = Capacity, + Dock = DockStyle.Fill, + StartTime = DateTime.Now, + BufferData = BufferData + }; + + var capacityStatusLabel = new ToolStripStatusLabel + { + Text = "Capacity: ", + AutoSize = true + }; + + var capacityStatusControl = new ToolStripTextBox + { + Text = Capacity.ToString(), + AutoSize = true + }; + + capacityStatusControl.TextChanged += (sender, e) => + { + if (int.TryParse(capacityStatusControl.Text, out int capacity)) + { + Capacity = capacity; + _plot.Capacity = Capacity; + } + }; + + var bufferDataStatusLabel = new ToolStripStatusLabel + { + Text = "Buffer Data: ", + AutoSize = true + }; + + var bufferDataStatusControl = new ToolStripComboBox + { + AutoSize = true + }; + + bufferDataStatusControl.Items.AddRange(["True", "False"]); + bufferDataStatusControl.SelectedIndex = BufferData ? 0 : 1; + + bufferDataStatusControl.SelectedIndexChanged += (sender, e) => + { + BufferData = bufferDataStatusControl.SelectedIndex == 0; + _plot.BufferData = BufferData; + }; + + _plot.StatusStrip.Items.Add(capacityStatusLabel); + _plot.StatusStrip.Items.Add(capacityStatusControl); + _plot.StatusStrip.Items.Add(bufferDataStatusLabel); + _plot.StatusStrip.Items.Add(bufferDataStatusControl); + + var visualizerService = (IDialogTypeVisualizerService)provider.GetService(typeof(IDialogTypeVisualizerService)); + visualizerService?.AddControl(_plot); + } + + /// + public override void Show(object value) + { + } + + /// + protected override void Show(DateTime time, object value) + { + if (value is null) return; + + if (value is not ILinearDynamicalSystemState state) + throw new ArgumentException($"Expected value to be a type of {nameof(ILinearDynamicalSystemState)}.", nameof(value)); + + var mean = state.Mean; + var covariance = state.Covariance; + + if (mean is null || covariance is null) return; + + if (mean.Dimensions == 1) + { + mean = mean.unsqueeze(0); + covariance = covariance.unsqueeze(0); + } + + var numTimesteps = mean.shape[0]; + var numStates = mean.shape[1]; + + if (_lineSeries is null || _areaSeries is null) + { + var colors = new OxyColorPresetCycle(); + + _lineSeries = new LineSeries[numStates]; + _areaSeries = new AreaSeries[numStates]; + + for (int i = 0; i < numStates; i++) + { + var currentColor = colors.Next(); + + _lineSeries[i] = _plot.AddNewLineSeries( + lineSeriesName: $"Mean: {i}", + color: currentColor + ); + + _areaSeries[i] = _plot.AddNewAreaSeries( + areaSeriesName: $"Variance: {i}", + color: currentColor + ); + } + } + + var covarianceDiagonal = covariance.diagonal(0, 1, 2); + + for (int i = 0; i < numTimesteps; i++) + { + for (int j = 0; j < numStates; j++) + { + + var meanVal = mean[i, j].to_type(ScalarType.Float64).item(); + + _plot.AddToLineSeries( + lineSeries: _lineSeries[j], + time: time, + value: meanVal + ); + + var sigmaVal = covarianceDiagonal[i, j].sqrt().to_type(ScalarType.Float64).item(); + + _plot.AddToAreaSeries( + areaSeries: _areaSeries[j], + time: time, + value1: meanVal + sigmaVal, + value2: meanVal - sigmaVal + ); + } + } + } + + /// + protected override void ShowBuffer(IList> values) + { + base.ShowBuffer(values); + if (values.Count > 0) + { + var time = values.LastOrDefault().Timestamp.DateTime; + _plot.SetAxes(minTime: time.AddSeconds(-Capacity), maxTime: time); + _plot.UpdatePlot(); + } + } + + /// + public override void Unload() + { + if (!_plot.IsDisposed) _plot.Dispose(); + } +} diff --git a/src/Bonsai.ML.Lds.Torch/Bonsai.ML.Lds.Torch.csproj b/src/Bonsai.ML.Lds.Torch/Bonsai.ML.Lds.Torch.csproj new file mode 100644 index 00000000..4b6a4bb6 --- /dev/null +++ b/src/Bonsai.ML.Lds.Torch/Bonsai.ML.Lds.Torch.csproj @@ -0,0 +1,13 @@ + + + A Bonsai package building on the Bonsai.ML.Torch library that implements Linear Dynamical Systems. + $(PackageTags) Torch TorchSharp LDS LinearDynamicalSystems + net472;netstandard2.0 + + + + + + + + \ No newline at end of file diff --git a/src/Bonsai.ML.Lds.Torch/CreateKalmanFilter.cs b/src/Bonsai.ML.Lds.Torch/CreateKalmanFilter.cs new file mode 100644 index 00000000..4b9dda07 --- /dev/null +++ b/src/Bonsai.ML.Lds.Torch/CreateKalmanFilter.cs @@ -0,0 +1,269 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using System.Xml.Serialization; +using static TorchSharp.torch; +using Bonsai.ML.Torch; + +namespace Bonsai.ML.Lds.Torch; + +/// +/// Creates a Kalman filter model. +/// +[Combinator] +[ResetCombinator] +[Description("Creates a Kalman filter model.")] +[WorkflowElementCategory(ElementCategory.Source)] +[TypeConverter(typeof(TensorOperatorConverter))] +public class CreateKalmanFilter : IScalarTypeProvider +{ + private Tensor _transitionMatrix; + private Tensor _measurementFunction; + private Tensor _processNoiseVariance; + private Tensor _measurementNoiseVariance; + private Tensor _initialMean; + private Tensor _initialCovariance; + private Tensor _stateOffset; + private Tensor _observationOffset; + + /// + [Description("The data type of the tensor elements.")] + [TypeConverter(typeof(ScalarTypeConverter))] + public ScalarType Type { get; set; } = ScalarType.Float32; + + /// + /// The device on which to create the tensor. + /// + [Description("The device on which to create the tensor.")] + [XmlIgnore] + public Device Device { get; set; } + + /// + /// The number of states in the Kalman filter model. + /// + public int? NumStates { get; set; } = null; + + /// + /// The number of observations in the Kalman filter model. + /// + public int? NumObservations { get; set; } = null; + + /// + /// The state transition matrix. + /// + [XmlIgnore] + [TypeConverter(typeof(TensorConverter))] + public Tensor TransitionMatrix + { + get => _transitionMatrix; + set => _transitionMatrix = value; + } + + /// + /// The XML string representation of the transition matrix for serialization. + /// + [Browsable(false)] + [XmlElement(nameof(TransitionMatrix))] + [EditorBrowsable(EditorBrowsableState.Never)] + public string TransitionMatrixXml + { + get => TensorConverter.ConvertToString(_transitionMatrix, Type); + set => _transitionMatrix = TensorConverter.ConvertFromString(value, Type); + } + + /// + /// The measurement function. + /// + [XmlIgnore] + [TypeConverter(typeof(TensorConverter))] + public Tensor MeasurementFunction + { + get => _measurementFunction; + set => _measurementFunction = value; + } + + /// + /// The XML string representation of the measurement function for serialization. + /// + [Browsable(false)] + [XmlElement(nameof(MeasurementFunction))] + [EditorBrowsable(EditorBrowsableState.Never)] + public string MeasurementFunctionXml + { + get => TensorConverter.ConvertToString(_measurementFunction, Type); + set => _measurementFunction = TensorConverter.ConvertFromString(value, Type); + } + + /// + /// The process noise variance. + /// + [XmlIgnore] + [TypeConverter(typeof(TensorConverter))] + public Tensor ProcessNoiseVariance + { + get => _processNoiseVariance; + set => _processNoiseVariance = value; + } + + /// + /// The XML string representation of the process noise variance for serialization. + /// + [Browsable(false)] + [XmlElement(nameof(ProcessNoiseVariance))] + [EditorBrowsable(EditorBrowsableState.Never)] + public string ProcessNoiseVarianceXml + { + get => TensorConverter.ConvertToString(_processNoiseVariance, Type); + set => _processNoiseVariance = TensorConverter.ConvertFromString(value, Type); + } + + /// + /// The measurement noise variance. + /// + [XmlIgnore] + [TypeConverter(typeof(TensorConverter))] + public Tensor MeasurementNoiseVariance + { + get => _measurementNoiseVariance; + set => _measurementNoiseVariance = value; + } + + /// + /// The XML string representation of the measurement noise variance for serialization. + /// + [Browsable(false)] + [XmlElement(nameof(MeasurementNoiseVariance))] + [EditorBrowsable(EditorBrowsableState.Never)] + public string MeasurementNoiseVarianceXml + { + get => TensorConverter.ConvertToString(_measurementNoiseVariance, Type); + set => _measurementNoiseVariance = TensorConverter.ConvertFromString(value, Type); + } + + /// + /// The initial mean. + /// + [XmlIgnore] + [TypeConverter(typeof(TensorConverter))] + public Tensor InitialMean + { + get => _initialMean; + set => _initialMean = value; + } + + /// + /// The XML string representation of the initial mean for serialization. + /// + [Browsable(false)] + [XmlElement(nameof(InitialMean))] + [EditorBrowsable(EditorBrowsableState.Never)] + public string InitialMeanXml + { + get => TensorConverter.ConvertToString(_initialMean, Type); + set => _initialMean = TensorConverter.ConvertFromString(value, Type); + } + + /// + /// The initial covariance. + /// + [XmlIgnore] + [TypeConverter(typeof(TensorConverter))] + public Tensor InitialCovariance + { + get => _initialCovariance; + set => _initialCovariance = value; + } + + /// + /// The XML string representation of the initial covariance for serialization. + /// + [Browsable(false)] + [XmlElement(nameof(InitialCovariance))] + [EditorBrowsable(EditorBrowsableState.Never)] + public string InitialCovarianceXml + { + get => TensorConverter.ConvertToString(_initialCovariance, Type); + set => _initialCovariance = TensorConverter.ConvertFromString(value, Type); + } + + /// + /// The state offset. + /// + [XmlIgnore] + [TypeConverter(typeof(TensorConverter))] + public Tensor StateOffset + { + get => _stateOffset; + set => _stateOffset = value; + } + + /// + /// The XML string representation of the state offset for serialization. + /// + [Browsable(false)] + [XmlElement(nameof(StateOffset))] + [EditorBrowsable(EditorBrowsableState.Never)] + public string StateOffsetXml + { + get => TensorConverter.ConvertToString(_stateOffset, Type); + set => _stateOffset = TensorConverter.ConvertFromString(value, Type); + } + + /// + /// The observation offset. + /// + [XmlIgnore] + [TypeConverter(typeof(TensorConverter))] + public Tensor ObservationOffset + { + get => _observationOffset; + set => _observationOffset = value; + } + + /// + /// The XML string representation of the observation offset for serialization. + /// + [Browsable(false)] + [XmlElement(nameof(ObservationOffset))] + [EditorBrowsable(EditorBrowsableState.Never)] + public string ObservationOffsetXml + { + get => TensorConverter.ConvertToString(_observationOffset, Type); + set => _observationOffset = TensorConverter.ConvertFromString(value, Type); + } + + + /// + /// Creates a Kalman filter model using the properties of this class. + /// + public IObservable Process() + { + return Observable.Return(new KalmanFilter( + numStates: NumStates, + numObservations: NumObservations, + transitionMatrix: TransitionMatrix, + measurementFunction: MeasurementFunction, + processNoiseVariance: ProcessNoiseVariance, + measurementNoiseVariance: MeasurementNoiseVariance, + initialMean: InitialMean, + initialCovariance: InitialCovariance, + stateOffset: StateOffset, + observationOffset: ObservationOffset, + device: Device, + scalarType: Type + )); + } + + /// + /// Creates a Kalman filter model using the parameters provided in the input sequence. + /// + public IObservable Process(IObservable source) + { + return source.Select(parameters => + { + return new KalmanFilter( + parameters: parameters + ); + }); + } +} diff --git a/src/Bonsai.ML.Lds.Torch/CreateKalmanFilterParameters.cs b/src/Bonsai.ML.Lds.Torch/CreateKalmanFilterParameters.cs new file mode 100644 index 00000000..b3aa2a09 --- /dev/null +++ b/src/Bonsai.ML.Lds.Torch/CreateKalmanFilterParameters.cs @@ -0,0 +1,280 @@ +using System; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using System.Xml.Serialization; +using static TorchSharp.torch; +using Bonsai.ML.Torch; + +namespace Bonsai.ML.Lds.Torch; + +/// +/// Initializes the parameters for a new Kalman filter model. +/// +[Combinator] +[ResetCombinator] +[Description("Initializes the parameters for a new Kalman filter model.")] +[WorkflowElementCategory(ElementCategory.Source)] +[TypeConverter(typeof(TensorOperatorConverter))] +public class CreateKalmanFilterParameters : IScalarTypeProvider +{ + private Tensor _transitionMatrix = null; + private Tensor _measurementFunction = null; + private Tensor _processNoiseCovariance = null; + private Tensor _measurementNoiseCovariance = null; + private Tensor _initialMean = null; + private Tensor _initialCovariance = null; + private Tensor _stateOffset = null; + private Tensor _observationOffset = null; + + /// + [Description("The data type of the tensor elements.")] + [TypeConverter(typeof(ScalarTypeConverter))] + public ScalarType Type { get; set; } = ScalarType.Float32; + + /// + /// The device on which to create the tensor. + /// + [XmlIgnore] + [Description("The device on which to create the tensor.")] + public Device Device { get; set; } + + /// + /// The number of states in the Kalman filter model. + /// + public int? NumStates { get; set; } = null; + + /// + /// The number of observations in the Kalman filter model. + /// + public int? NumObservations { get; set; } = null; + + /// + /// The state transition matrix. + /// + [XmlIgnore] + [TypeConverter(typeof(TensorConverter))] + public Tensor TransitionMatrix + { + get => _transitionMatrix; + set => _transitionMatrix = value; + } + + /// + /// The XML string representation of the transition matrix for serialization. + /// + [Browsable(false)] + [XmlElement(nameof(TransitionMatrix))] + [EditorBrowsable(EditorBrowsableState.Never)] + public string TransitionMatrixXml + { + get => TensorConverter.ConvertToString(_transitionMatrix, Type); + set => _transitionMatrix = TensorConverter.ConvertFromString(value, Type); + } + + /// + /// The measurement function. + /// + [XmlIgnore] + [TypeConverter(typeof(TensorConverter))] + public Tensor MeasurementFunction + { + get => _measurementFunction; + set => _measurementFunction = value; + } + + /// + /// The XML string representation of the measurement function for serialization. + /// + [Browsable(false)] + [XmlElement(nameof(MeasurementFunction))] + [EditorBrowsable(EditorBrowsableState.Never)] + public string MeasurementFunctionXml + { + get => TensorConverter.ConvertToString(_measurementFunction, Type); + set => _measurementFunction = TensorConverter.ConvertFromString(value, Type); + } + + /// + /// The process noise variance. + /// + [XmlIgnore] + [TypeConverter(typeof(TensorConverter))] + public Tensor ProcessNoiseCovariance + { + get => _processNoiseCovariance; + set => _processNoiseCovariance = value; + } + + /// + /// The XML string representation of the process noise variance for serialization. + /// + [Browsable(false)] + [XmlElement(nameof(ProcessNoiseCovariance))] + [EditorBrowsable(EditorBrowsableState.Never)] + public string ProcessNoiseCovarianceXml + { + get => TensorConverter.ConvertToString(_processNoiseCovariance, Type); + set => _processNoiseCovariance = TensorConverter.ConvertFromString(value, Type); + } + + /// + /// The measurement noise covariance matrix. + /// + [XmlIgnore] + [TypeConverter(typeof(TensorConverter))] + public Tensor MeasurementNoiseCovariance + { + get => _measurementNoiseCovariance; + set => _measurementNoiseCovariance = value; + } + + /// + /// The XML string representation of the measurement noise covariance matrix for serialization. + /// + [Browsable(false)] + [XmlElement(nameof(MeasurementNoiseCovariance))] + [EditorBrowsable(EditorBrowsableState.Never)] + public string MeasurementNoiseCovarianceXml + { + get => TensorConverter.ConvertToString(_measurementNoiseCovariance, Type); + set => _measurementNoiseCovariance = TensorConverter.ConvertFromString(value, Type); + } + + /// + /// The initial mean. + /// + [XmlIgnore] + [TypeConverter(typeof(TensorConverter))] + public Tensor InitialMean + { + get => _initialMean; + set => _initialMean = value; + } + + /// + /// The XML string representation of the initial state for serialization. + /// + [Browsable(false)] + [XmlElement(nameof(InitialMean))] + [EditorBrowsable(EditorBrowsableState.Never)] + public string InitialMeanXml + { + get => TensorConverter.ConvertToString(_initialMean, Type); + set => _initialMean = TensorConverter.ConvertFromString(value, Type); + } + + /// + /// The initial covariance. + /// + [XmlIgnore] + [TypeConverter(typeof(TensorConverter))] + public Tensor InitialCovariance + { + get => _initialCovariance; + set => _initialCovariance = value; + } + + /// + /// The XML string representation of the initial covariance for serialization. + /// + [Browsable(false)] + [XmlElement(nameof(InitialCovariance))] + [EditorBrowsable(EditorBrowsableState.Never)] + public string InitialCovarianceXml + { + get => TensorConverter.ConvertToString(_initialCovariance, Type); + set => _initialCovariance = TensorConverter.ConvertFromString(value, Type); + } + + /// + /// The state offset. + /// + [XmlIgnore] + [TypeConverter(typeof(TensorConverter))] + public Tensor StateOffset + { + get => _stateOffset; + set => _stateOffset = value; + } + + /// + /// The XML string representation of the state offset for serialization. + /// + [Browsable(false)] + [XmlElement(nameof(StateOffset))] + [EditorBrowsable(EditorBrowsableState.Never)] + public string StateOffsetXml + { + get => TensorConverter.ConvertToString(_stateOffset, Type); + set => _stateOffset = TensorConverter.ConvertFromString(value, Type); + } + + /// + /// The observation offset. + /// + [XmlIgnore] + [TypeConverter(typeof(TensorConverter))] + public Tensor ObservationOffset + { + get => _observationOffset; + set => _observationOffset = value; + } + + /// + /// The XML string representation of the observation offset for serialization. + /// + [Browsable(false)] + [XmlElement(nameof(ObservationOffset))] + [EditorBrowsable(EditorBrowsableState.Never)] + public string ObservationOffsetXml + { + get => TensorConverter.ConvertToString(_observationOffset, Type); + set => _observationOffset = TensorConverter.ConvertFromString(value, Type); + } + + /// + /// Creates parameters for a Kalman filter model using the properties of this class. + /// + public IObservable Process() + { + return Observable.Return(new KalmanFilterParameters( + numStates: NumStates, + numObservations: NumObservations, + transitionMatrix: _transitionMatrix, + measurementFunction: _measurementFunction, + processNoiseCovariance: _processNoiseCovariance, + measurementNoiseCovariance: _measurementNoiseCovariance, + initialMean: _initialMean, + initialCovariance: _initialCovariance, + stateOffset: _stateOffset, + observationOffset: _observationOffset, + scalarType: Type, + device: Device + )); + } + + /// + /// Creates parameters for a Kalman filter model for each element in the input sequence. + /// + public IObservable Process(IObservable source) + { + return source.Select(_ => + { + return new KalmanFilterParameters( + numStates: NumStates, + numObservations: NumObservations, + transitionMatrix: _transitionMatrix, + measurementFunction: _measurementFunction, + processNoiseCovariance: _processNoiseCovariance, + measurementNoiseCovariance: _measurementNoiseCovariance, + initialMean: _initialMean, + initialCovariance: _initialCovariance, + stateOffset: _stateOffset, + observationOffset: _observationOffset, + scalarType: Type, + device: Device + ); + }); + } +} diff --git a/src/Bonsai.ML.Lds.Torch/CreateLinearDynamicalSystemState.cs b/src/Bonsai.ML.Lds.Torch/CreateLinearDynamicalSystemState.cs new file mode 100644 index 00000000..a1760724 --- /dev/null +++ b/src/Bonsai.ML.Lds.Torch/CreateLinearDynamicalSystemState.cs @@ -0,0 +1,126 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using System.Xml.Serialization; +using static TorchSharp.torch; +using Bonsai.ML.Torch; + +namespace Bonsai.ML.Lds.Torch; + +/// +/// Creates a generic state object for a linear gaussian dynamical system. +/// +[Combinator] +[ResetCombinator] +[Description("Creates a new state for a linear gaussian dynamical system.")] +[WorkflowElementCategory(ElementCategory.Source)] +[TypeConverter(typeof(TensorOperatorConverter))] +public class CreateLinearDynamicalSystemState : IScalarTypeProvider +{ + private Tensor _mean = null; + private Tensor _covariance = null; + + /// + [Description("The data type of the tensor elements.")] + [TypeConverter(typeof(ScalarTypeConverter))] + public ScalarType Type { get; set; } = ScalarType.Float32; + + /// + /// The device on which to create the tensor. + /// + [Description("The device on which to create the tensor.")] + [XmlIgnore] + public Device Device { get; set; } + + /// + /// The mean of the state. + /// + [XmlIgnore] + [TypeConverter(typeof(TensorConverter))] + public Tensor Mean + { + get => _mean; + set => _mean = value; + } + + /// + /// The XML string representation of the mean for serialization. + /// + [Browsable(false)] + [XmlElement(nameof(Mean))] + [EditorBrowsable(EditorBrowsableState.Never)] + public string MeanXml + { + get => TensorConverter.ConvertToString(_mean, Type); + set => _mean = TensorConverter.ConvertFromString(value, Type); + } + + /// + /// The covariance of the state. + /// + [XmlIgnore] + [TypeConverter(typeof(TensorConverter))] + public Tensor Covariance + { + get => _covariance; + set => _covariance = value; + } + + /// + /// The XML string representation of the covariance for serialization. + /// + [Browsable(false)] + [XmlElement(nameof(Covariance))] + [EditorBrowsable(EditorBrowsableState.Never)] + public string CovarianceXml + { + get => TensorConverter.ConvertToString(_covariance, Type); + set => _covariance = TensorConverter.ConvertFromString(value, Type); + } + + /// + /// Creates an observable sequence and emits the state for a linear gaussian dynamical system. + /// + /// + public IObservable Process() + { + return Observable.Defer(() => + { + var device = Device ?? CPU; + var mean = _mean?.to(device) ?? throw new InvalidOperationException("The mean of the state must be specified."); + var covariance = _covariance?.to(device) ?? throw new InvalidOperationException("The covariance of the state must be specified."); + return Observable.Return(new LinearDynamicalSystemState(mean, covariance)); + }); + } + + /// + /// Processes an observable sequence and emits new states for a linear gaussian dynamical system. + /// + /// + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(_ => + { + var device = Device ?? CPU; + var mean = _mean?.to(device) ?? throw new InvalidOperationException("The mean of the state must be specified."); + var covariance = _covariance?.to(device) ?? throw new InvalidOperationException("The covariance of the state must be specified."); + return new LinearDynamicalSystemState(mean, covariance); + }); + } + + /// + /// Processes an observable sequence of a tuple of tensors (mean and covariance) and emits a state for a linear gaussian dynamical system. + /// + /// + /// + public IObservable Process(IObservable> source) + { + return source.Select(input => + { + return new LinearDynamicalSystemState(input.Item1, input.Item2); + }); + } +} diff --git a/src/Bonsai.ML.Lds.Torch/ExpectationMaximization.cs b/src/Bonsai.ML.Lds.Torch/ExpectationMaximization.cs new file mode 100644 index 00000000..f54d24c8 --- /dev/null +++ b/src/Bonsai.ML.Lds.Torch/ExpectationMaximization.cs @@ -0,0 +1,215 @@ +using System; +using System.ComponentModel; +using System.Reactive; +using System.Xml.Serialization; +using System.Linq; +using System.Reactive.Linq; +using TorchSharp; +using static TorchSharp.torch; +using System.Threading.Tasks; + +namespace Bonsai.ML.Lds.Torch; + +/// +/// Learn the parameters of a kalman filter using the batch EM update algorithm. +/// +[Combinator] +[ResetCombinator] +[Description("Learn the parameters of a kalman filter using the batch EM update algorithm.")] +[WorkflowElementCategory(ElementCategory.Combinator)] +public class ExpectationMaximization +{ + private int _maxIterations = 10; + private double _tolerance = 1e-4; + private bool _verbose = true; + + /// + /// The number of states in the Kalman filter model. + /// + [Description("The number of states in the Kalman filter model.")] + public int? NumStates { get; set; } = null; + + /// + /// The Kalman filter parameters used to initialize the model. + /// + [Description("The Kalman filter parameters used to initialize the model.")] + [XmlIgnore] + public KalmanFilterParameters? ModelParameters { get; set; } = null; + + /// + /// The maximum number of EM iterations to perform. + /// + [Description("The maximum number of EM iterations to perform.")] + public int MaxIterations + { + get => _maxIterations; + set => _maxIterations = value > 0 ? value : throw new ArgumentOutOfRangeException(nameof(MaxIterations), "Must be greater than zero."); + } + + /// + /// The convergence tolerance for the EM algorithm. + /// + [Description("The convergence tolerance for the EM algorithm.")] + public double Tolerance + { + get => _tolerance; + set => _tolerance = value >= 0 ? value : throw new ArgumentOutOfRangeException(nameof(Tolerance), "Must be greater than or equal to zero."); + } + + /// + /// If true, prints progress messages to the console. + /// + [Description("If true, prints progress messages to the console.")] + public bool Verbose + { + get => _verbose; + set => _verbose = value; + } + + /// + /// If true, the transition matrix will be estimated during the EM algorithm. + /// + [Description("If true, the transition matrix will be estimated during the EM algorithm.")] + public bool EstimateTransitionMatrix { get; set; } = true; + + /// + /// If true, the measurement function will be estimated during the EM algorithm. + /// + [Description("If true, the measurement function will be estimated during the EM algorithm.")] + public bool EstimateMeasurementFunction { get; set; } = true; + + /// + /// If true, the process noise covariance will be estimated during the EM algorithm. + /// + [Description("If true, the process noise covariance will be estimated during the EM algorithm.")] + public bool EstimateProcessNoiseCovariance { get; set; } = true; + + /// + /// If true, the measurement noise covariance will be estimated during the EM algorithm. + /// + [Description("If true, the measurement noise covariance will be estimated during the EM algorithm.")] + public bool EstimateMeasurementNoiseCovariance { get; set; } = true; + + /// + /// If true, the initial mean will be estimated during the EM algorithm. + /// + [Description("If true, the initial mean will be estimated during the EM algorithm.")] + public bool EstimateInitialMean { get; set; } = true; + + /// + /// If true, the initial covariance will be estimated during the EM algorithm. + /// + [Description("If true, the initial covariance will be estimated during the EM algorithm.")] + public bool EstimateInitialCovariance { get; set; } = true; + + /// + /// If true, the state offset will be estimated during the EM algorithm. + /// + [Description("If true, the state offset will be estimated during the EM algorithm.")] + public bool EstimateStateOffset { get; set; } = false; + + /// + /// If true, the observation offset will be estimated during the EM algorithm. + /// + [Description("If true, the observation offset will be estimated during the EM algorithm.")] + public bool EstimateObservationOffset { get; set; } = false; + + /// + /// Processes an observable sequence of input tensors, applying the Expectation-Maximization algorithm to learn the parameters of a Kalman filter model. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(input => Observable.Create((observer, cancellationToken) => + { + return Task.Run(() => + { + var numObservations = (int)input.size(1); + var previousLogLikelihood = double.NegativeInfinity; + var logLikelihood = zeros([MaxIterations], device: input.device); + var maxIterationsReached = false; + + var parametersToEstimate = new ParametersToEstimate( + transitionMatrix: EstimateTransitionMatrix, + measurementFunction: EstimateMeasurementFunction, + processNoiseCovariance: EstimateProcessNoiseCovariance, + measurementNoiseCovariance: EstimateMeasurementNoiseCovariance, + initialMean: EstimateInitialMean, + initialCovariance: EstimateInitialCovariance, + stateOffset: EstimateStateOffset, + observationOffset: EstimateObservationOffset); + + var parameters = ModelParameters?.Copy() ?? new KalmanFilterParameters( + numStates: NumStates, + numObservations: numObservations, + scalarType: input.dtype, + device: input.device); + + for (int i = 0; i < MaxIterations; i++) + { + // Check for cancellation before each iteration + if (cancellationToken.IsCancellationRequested) + { + observer.OnCompleted(); + return System.Reactive.Disposables.Disposable.Empty; + } + + var result = KalmanFilter.ExpectationMaximization( + observation: input, + parameters: parameters, + maxIterations: 1, + tolerance: Tolerance, + parametersToEstimate: parametersToEstimate); + + var logLikelihoodSum = result.LogLikelihood + .to_type(ScalarType.Float32) + .item(); + + logLikelihood[i] = logLikelihoodSum; + + if (Verbose) + { + Console.WriteLine("Iteration " + (i + 1) + ", Log Likelihood: " + logLikelihoodSum); + if (i == MaxIterations - 1) + { + Console.WriteLine("EM reached the maximum number of iterations."); + maxIterationsReached = true; + } + } + + if (logLikelihoodSum - previousLogLikelihood < Tolerance) + { + if (Verbose) + { + Console.WriteLine("EM converged after " + (i + 1) + " iterations."); + } + logLikelihood = logLikelihood[TensorIndex.Slice(0, i + 1)]; + break; + } + + parameters = result.Parameters; + + if (!maxIterationsReached) + { + previousLogLikelihood = logLikelihoodSum; + + observer.OnNext(new ExpectationMaximizationResult( + logLikelihood: logLikelihood[TensorIndex.Slice(0, i + 1)], + parameters: parameters, + finished: false)); + } + } + + observer.OnNext(new ExpectationMaximizationResult( + logLikelihood: logLikelihood, + parameters: parameters, + finished: true)); + + observer.OnCompleted(); + return System.Reactive.Disposables.Disposable.Empty; + }, + cancellationToken); + })).Concat(); + } +} diff --git a/src/Bonsai.ML.Lds.Torch/ExpectationMaximizationResult.cs b/src/Bonsai.ML.Lds.Torch/ExpectationMaximizationResult.cs new file mode 100644 index 00000000..67dd8f22 --- /dev/null +++ b/src/Bonsai.ML.Lds.Torch/ExpectationMaximizationResult.cs @@ -0,0 +1,30 @@ +using static TorchSharp.torch; + +namespace Bonsai.ML.Lds.Torch; + +/// +/// Represents the result of an expectation-maximization step for a Kalman filter model. +/// +/// +/// +/// +public struct ExpectationMaximizationResult( + Tensor logLikelihood, + KalmanFilterParameters parameters, + bool finished = false) +{ + /// + /// The log likelihood of the observed data given the model parameters after each iteration. + /// + public Tensor LogLikelihood = logLikelihood; + + /// + /// The final updated Kalman filter parameters after the last expectation-maximization step. + /// + public KalmanFilterParameters Parameters = parameters; + + /// + /// Indicates whether the EM algorithm has finished. + /// + public bool Finished = finished; +} \ No newline at end of file diff --git a/src/Bonsai.ML.Lds.Torch/Filter.cs b/src/Bonsai.ML.Lds.Torch/Filter.cs new file mode 100644 index 00000000..699d6dcc --- /dev/null +++ b/src/Bonsai.ML.Lds.Torch/Filter.cs @@ -0,0 +1,32 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using System.Xml.Serialization; +using static TorchSharp.torch; + +namespace Bonsai.ML.Lds.Torch; + +/// +/// Applies a Kalman filter to the input tensor sequence. +/// +[Combinator] +[ResetCombinator] +[Description("Applies a Kalman filter to the input tensor sequence.")] +[WorkflowElementCategory(ElementCategory.Transform)] +public class Filter +{ + /// + /// The Kalman filter model. + /// + [Description("The Kalman filter model.")] + [XmlIgnore] + public KalmanFilter Model { get; set; } + + /// + /// Processes an observable sequence of input tensors, applying the Kalman filter to each tensor. + /// + public IObservable Process(IObservable source) + { + return source.Select(Model.Filter); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Lds.Torch/FilteredState.cs b/src/Bonsai.ML.Lds.Torch/FilteredState.cs new file mode 100644 index 00000000..9b1d53a1 --- /dev/null +++ b/src/Bonsai.ML.Lds.Torch/FilteredState.cs @@ -0,0 +1,57 @@ +using static TorchSharp.torch; + +namespace Bonsai.ML.Lds.Torch; + +/// +/// Represents the state of a Kalman filter. +/// +/// +/// +/// +/// +/// +/// +public readonly struct FilteredState( + LinearDynamicalSystemState predictedState, + LinearDynamicalSystemState updatedState, + Tensor innovation = null, + Tensor innovationCovariance = null, + Tensor kalmanGain = null, + Tensor logLikelihood = null) : ILinearDynamicalSystemState +{ + /// + /// The predicted state following the prediction step. + /// + public readonly LinearDynamicalSystemState PredictedState => predictedState; + + /// + /// The updated state following the update step. + /// + public readonly LinearDynamicalSystemState UpdatedState => updatedState; + + /// + /// The innovation (residual) between the observation and the prediction. + /// + public readonly Tensor Innovation => innovation; + + /// + /// The innovation (residual) covariance. + /// + public readonly Tensor InnovationCovariance => innovationCovariance; + + /// + /// The Kalman gain. + /// + public readonly Tensor KalmanGain => kalmanGain; + + /// + /// The log likelihood of the observation given the updated state. + /// + public readonly Tensor LogLikelihood => logLikelihood; + + /// + public readonly Tensor Mean => updatedState.Mean; + + /// + public readonly Tensor Covariance => updatedState.Covariance; +} diff --git a/src/Bonsai.ML.Lds.Torch/ILinearDynamicalSystemState.cs b/src/Bonsai.ML.Lds.Torch/ILinearDynamicalSystemState.cs new file mode 100644 index 00000000..ee90a811 --- /dev/null +++ b/src/Bonsai.ML.Lds.Torch/ILinearDynamicalSystemState.cs @@ -0,0 +1,19 @@ +using static TorchSharp.torch; + +namespace Bonsai.ML.Lds.Torch; + +/// +/// Represents the state of a linear gaussian dynamical system. +/// +public interface ILinearDynamicalSystemState +{ + /// + /// The mean of the state. + /// + Tensor Mean { get; } + + /// + /// The covariance of the state. + /// + Tensor Covariance { get; } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Lds.Torch/KalmanFilter.cs b/src/Bonsai.ML.Lds.Torch/KalmanFilter.cs new file mode 100644 index 00000000..f6b9ba49 --- /dev/null +++ b/src/Bonsai.ML.Lds.Torch/KalmanFilter.cs @@ -0,0 +1,773 @@ +using System; +using TorchSharp.Modules; +using static TorchSharp.torch; + +namespace Bonsai.ML.Lds.Torch; + +// disable missing XML comment warnings +# pragma warning disable CS1591 + +public class KalmanFilter : nn.Module +{ + private LinearDynamicalSystemState _state; + public readonly KalmanFilterParameters Parameters; + + public KalmanFilter( + KalmanFilterParameters parameters) : base("KalmanFilter") + { + Parameters = parameters; + + _state = new LinearDynamicalSystemState( + Parameters.InitialMean, + Parameters.InitialCovariance); + + RegisterComponents(); + } + + public KalmanFilter( + int? numStates = null, + int? numObservations = null, + Tensor transitionMatrix = null, + Tensor measurementFunction = null, + Tensor initialMean = null, + Tensor initialCovariance = null, + Tensor processNoiseVariance = null, + Tensor measurementNoiseVariance = null, + Tensor stateOffset = null, + Tensor observationOffset = null, + Device device = null, + ScalarType? scalarType = null, + bool requiresGrad = false) : base("KalmanFilter") + { + Parameters = new KalmanFilterParameters( + numStates, + numObservations, + transitionMatrix, + measurementFunction, + processNoiseVariance, + measurementNoiseVariance, + initialMean, + initialCovariance, + stateOffset, + observationOffset, + device, + scalarType, + requiresGrad + ); + + _state = new LinearDynamicalSystemState( + Parameters.InitialMean, + Parameters.InitialCovariance); + + RegisterComponents(); + } + + private static LinearDynamicalSystemState FilterPredict( + LinearDynamicalSystemState state, + KalmanFilterParameters parameters) => + new(parameters.TransitionMatrix.matmul(state.Mean) + (parameters.OffsetsProvided ? parameters.StateOffset : 0), + parameters.TransitionMatrix.matmul(state.Covariance) + .matmul(parameters.TransitionMatrix.mT) + parameters.ProcessNoiseCovariance); + + private static FilteredState FilterUpdate( + Tensor observation, + LinearDynamicalSystemState state, + KalmanFilterParameters parameters) + { + if (observation is null) + return new FilteredState( + predictedState: state, + updatedState: state + ); + + // Innovation step + var innovation = observation - (parameters.MeasurementFunction.matmul(state.Mean) + (parameters.OffsetsProvided ? parameters.ObservationOffset : 0)); + var innovationCovariance = WrappedTensorDisposeScope(() => EnsureSymmetric( + parameters.MeasurementFunction.matmul(state.Covariance) + .matmul(parameters.MeasurementFunction.mT) + parameters.MeasurementNoiseCovariance)); + + // Kalman gain + var kalmanGain = WrappedTensorDisposeScope(() => InverseCholesky( + state.Covariance.matmul(parameters.MeasurementFunction.mT), + innovationCovariance)); + + // Update step + var updatedMean = state.Mean + kalmanGain.matmul(innovation); + var updatedCovariance = WrappedTensorDisposeScope(() => state.Covariance + - kalmanGain.matmul(parameters.MeasurementFunction).matmul(state.Covariance)); + + var updatedState = new LinearDynamicalSystemState(updatedMean, updatedCovariance); + + return new FilteredState( + predictedState: state, + updatedState: updatedState, + innovation: innovation, + innovationCovariance: innovationCovariance, + kalmanGain: kalmanGain + ); + } + + public FilteredState Filter(Tensor observation) + { + using var g = no_grad(); + + var obs = observation.atleast_2d(); + var timeBins = obs.size(0); + + // We get the scalar type and device from the parameters in the event that the observations are null (e.g. during forecasting) + var scalarType = Parameters.ScalarType; + var device = Parameters.Device; + + var predictedState = new LinearDynamicalSystemState( + empty([timeBins, Parameters.NumStates], dtype: scalarType, device: device), + empty([timeBins, Parameters.NumStates, Parameters.NumStates], dtype: scalarType, device: device)); + + var updatedState = new LinearDynamicalSystemState( + empty([timeBins, Parameters.NumStates], dtype: scalarType, device: device), + empty([timeBins, Parameters.NumStates, Parameters.NumStates], dtype: scalarType, device: device)); + + for (long time = 0; time < timeBins; time++) + { + // Predict + var prediction = FilterPredict(_state, Parameters); + + // Update + var update = FilterUpdate(obs[time], prediction, Parameters); + + predictedState.Mean[time] = prediction.Mean; + predictedState.Covariance[time] = prediction.Covariance; + updatedState.Mean[time] = update.UpdatedState.Mean; + updatedState.Covariance[time] = update.UpdatedState.Covariance; + + _state = update.UpdatedState; + } + + return new FilteredState( + predictedState: predictedState, + updatedState: updatedState + ); + } + + public override LinearDynamicalSystemState forward(Tensor input) + { + var filteredState = Filter(input); + return filteredState.UpdatedState; + } + + private static FilteredState Filter( + long timeBins, + Tensor observation, + KalmanFilterParameters parameters) + { + using var g = no_grad(); + + var timeBinsObservations = observation.size(0); + timeBins = Math.Min(timeBins, timeBinsObservations); + + var filteredState = new FilteredState( + predictedState: new LinearDynamicalSystemState( + empty([timeBins, parameters.NumStates], dtype: parameters.ScalarType, device: parameters.Device), + empty([timeBins, parameters.NumStates, parameters.NumStates], dtype: parameters.ScalarType, device: parameters.Device)), + updatedState: new LinearDynamicalSystemState( + empty([timeBins, parameters.NumStates], dtype: parameters.ScalarType, device: parameters.Device), + empty([timeBins, parameters.NumStates, parameters.NumStates], dtype: parameters.ScalarType, device: parameters.Device)), + innovation: empty([timeBins, parameters.NumObservations], dtype: parameters.ScalarType, device: parameters.Device), + innovationCovariance: empty([timeBins, parameters.NumObservations, parameters.NumObservations], dtype: parameters.ScalarType, device: parameters.Device), + kalmanGain: empty([timeBins, parameters.NumStates, parameters.NumObservations], dtype: parameters.ScalarType, device: parameters.Device), + logLikelihood: empty([timeBins], dtype: parameters.ScalarType, device: parameters.Device) + ); + + var state = new LinearDynamicalSystemState( + mean: parameters.InitialMean, + covariance: parameters.InitialCovariance); + + for (long time = 0; time < timeBins; time++) + { + // Predict + var prediction = FilterPredict( + state: state, + parameters: parameters); + + // Update + var update = FilterUpdate( + observation: observation[time], + state: prediction, + parameters: parameters); + + // Log Likelihood + var logLikelihoodData = -(slogdet(update.InnovationCovariance).logabsdet + + InverseCholesky(update.Innovation.T, update.InnovationCovariance) + .matmul(update.Innovation)).squeeze(); + + // Detach and assign + filteredState.LogLikelihood[time] = logLikelihoodData; + filteredState.PredictedState.Mean[time] = prediction.Mean; + filteredState.PredictedState.Covariance[time] = prediction.Covariance; + filteredState.UpdatedState.Mean[time] = update.UpdatedState.Mean; + filteredState.UpdatedState.Covariance[time] = update.UpdatedState.Covariance; + filteredState.Innovation[time] = update.Innovation; + filteredState.InnovationCovariance[time] = update.InnovationCovariance; + filteredState.KalmanGain[time] = update.KalmanGain; + + state = update.UpdatedState; + } + + return filteredState; + } + + public LinearDynamicalSystemState Smooth(FilteredState filteredState) + { + using var g = no_grad(); + + var predictedMean = filteredState.PredictedState.Mean; + var predictedCovariance = filteredState.PredictedState.Covariance; + var updatedMean = filteredState.UpdatedState.Mean; + var updatedCovariance = filteredState.UpdatedState.Covariance; + + var timeBins = predictedMean.size(0); + var smoothedMean = empty_like(updatedMean); + var smoothedCovariance = empty_like(updatedCovariance); + + // Fix the last time point + smoothedMean[-1] = updatedMean[-1]; + smoothedCovariance[-1] = updatedCovariance[-1]; + + var smoothingGain = empty([Parameters.NumStates, Parameters.NumStates], dtype: Parameters.ScalarType, device: Parameters.Device); + + // Backward pass + for (long time = timeBins - 2; time >= 0; time--) + { + // Smoothing gain + smoothingGain = WrappedTensorDisposeScope(() => updatedCovariance[time].matmul( + InverseCholesky(Parameters.TransitionMatrix.mT, predictedCovariance[time + 1]) + )); + + // Smoothed mean + smoothedMean[time] = WrappedTensorDisposeScope(() => updatedMean[time] + + smoothingGain.matmul( + (smoothedMean[time + 1] - predictedMean[time + 1]).unsqueeze(-1) + ).squeeze(-1)); + + // Smoothed covariance + smoothedCovariance[time] = WrappedTensorDisposeScope(() => updatedCovariance[time] + smoothingGain + .matmul(smoothedCovariance[time + 1] - predictedCovariance[time + 1]) + .matmul(smoothingGain.mT) + ); + } + + return new LinearDynamicalSystemState( + smoothedMean, + smoothedCovariance + ); + } + + private readonly struct SummaryStatisticsEM( + Tensor sxx11, + Tensor sxx10, + Tensor sxx00, + Tensor tx1 = null, + Tensor tx0 = null, + Tensor ty1 = null, + Tensor tyx11 = null, + Tensor tyy11 = null + ) + { + public readonly Tensor Sxx11 => sxx11; + public readonly Tensor Sxx10 => sxx10; + public readonly Tensor Sxx00 => sxx00; + public readonly Tensor Tx1 => tx1; + public readonly Tensor Tx0 => tx0; + public readonly Tensor Ty1 => ty1; + public readonly Tensor Tyx11 => tyx11; + public readonly Tensor Tyy11 => tyy11; + } + + private readonly struct SmoothedStateWithAuxiliaryVariables( + LinearDynamicalSystemState smoothedState, + Tensor smoothedInitialMean, + Tensor smoothedInitialCovariance) + { + public readonly LinearDynamicalSystemState SmoothedState => smoothedState; + public readonly Tensor SmoothedInitialMean => smoothedInitialMean; + public readonly Tensor SmoothedInitialCovariance => smoothedInitialCovariance; + } + + private static (SmoothedStateWithAuxiliaryVariables state, SummaryStatisticsEM statistics) Smooth( + FilteredState filteredState, + Tensor observations, + KalmanFilterParameters parameters + ) + { + var timeBins = filteredState.PredictedState.Mean.size(0); + + if (timeBins < 2) + throw new ArgumentException("Smoothing requires at least two time bins."); + + var predictedMean = filteredState.PredictedState.Mean; + var predictedCovariance = filteredState.PredictedState.Covariance; + var updatedMean = filteredState.UpdatedState.Mean; + var updatedCovariance = filteredState.UpdatedState.Covariance; + var kalmanGain = filteredState.KalmanGain; + + var smoothedState = new LinearDynamicalSystemState( + mean: empty_like(updatedMean), + covariance: empty_like(updatedCovariance) + ); + + var sxx00 = zeros_like(smoothedState.Covariance, dtype: parameters.ScalarType, device: parameters.Device); + var sxx10 = zeros_like(smoothedState.Covariance, dtype: parameters.ScalarType, device: parameters.Device); + var sxx11 = zeros_like(smoothedState.Covariance, dtype: parameters.ScalarType, device: parameters.Device); + + var tx1 = zeros([parameters.NumStates], dtype: parameters.ScalarType, device: parameters.Device); + var tx0 = zeros([parameters.NumStates], dtype: parameters.ScalarType, device: parameters.Device); + var ty1 = zeros([parameters.NumObservations], dtype: parameters.ScalarType, device: parameters.Device); + var tyx11 = zeros([parameters.NumObservations, parameters.NumStates], dtype: parameters.ScalarType, device: parameters.Device); + var tyy11 = zeros([parameters.NumObservations, parameters.NumObservations], dtype: parameters.ScalarType, device: parameters.Device); + + var identityStates = eye(parameters.NumStates, dtype: parameters.ScalarType, device: parameters.Device); + + // Fix the last time point + smoothedState.Mean[-1] = updatedMean[-1]; + smoothedState.Covariance[-1] = updatedCovariance[-1]; + var smoothedLagOneCovariance = WrappedTensorDisposeScope(() => + (identityStates - kalmanGain[-1] + .matmul(parameters.MeasurementFunction)) + .matmul(parameters.TransitionMatrix) + .matmul(updatedCovariance[-2])); + + sxx11[-1] = outer(smoothedState.Mean[-1], smoothedState.Mean[-1]) + smoothedState.Covariance[-1]; + + if (parameters.OffsetsProvided) + { + tx1 += smoothedState.Mean[-1]; + ty1 += observations[-1]; + tyx11 += outer(observations[-1], smoothedState.Mean[-1]); + tyy11 += outer(observations[-1], observations[-1]); + } + + var smoothingGain = empty([parameters.NumStates, parameters.NumStates], dtype: parameters.ScalarType, device: parameters.Device); + var smoothingGainNext = null as Tensor; + + // Backward pass + for (long time = timeBins - 2; time >= 0; time--) + { + // Smoothing gain + smoothingGain = smoothingGainNext ?? WrappedTensorDisposeScope(() => updatedCovariance[time].matmul( + InverseCholesky(parameters.TransitionMatrix.mT, predictedCovariance[time + 1]) + )); + + // Smoothed mean + smoothedState.Mean[time] = WrappedTensorDisposeScope(() => updatedMean[time] + + smoothingGain.matmul( + (smoothedState.Mean[time + 1] - predictedMean[time + 1]).unsqueeze(-1) + ).squeeze(-1)); + + // Smoothed covariance + smoothedState.Covariance[time] = WrappedTensorDisposeScope(() => updatedCovariance[time] + smoothingGain + .matmul(smoothedState.Covariance[time + 1] - predictedCovariance[time + 1]) + .matmul(smoothingGain.mT) + ); + + var expectationUpdate = outer(smoothedState.Mean[time], smoothedState.Mean[time]) + smoothedState.Covariance[time]; + sxx11[time] = expectationUpdate; + sxx00[time + 1] = expectationUpdate; + sxx10[time + 1] = outer(smoothedState.Mean[time + 1], smoothedState.Mean[time]) + smoothedLagOneCovariance; + + if (parameters.OffsetsProvided) + { + tx1 += smoothedState.Mean[time]; + tx0 += smoothedState.Mean[time + 1]; + ty1 += observations[time]; + tyx11 += outer(observations[time], smoothedState.Mean[time]); + tyy11 += outer(observations[time], observations[time]); + } + + // Compute next smoothing gain for lag one covariance + if (time > 0) + { + smoothingGainNext = WrappedTensorDisposeScope(() => updatedCovariance[time - 1] + .matmul(InverseCholesky(parameters.TransitionMatrix.mT, predictedCovariance[time]))); + + // Smoothed lag one covariance + smoothedLagOneCovariance = WrappedTensorDisposeScope(() => updatedCovariance[time] + .matmul(smoothingGainNext.mT) + + smoothingGain.matmul(smoothedLagOneCovariance + - parameters.TransitionMatrix.matmul(updatedCovariance[time])) + .matmul(smoothingGainNext.mT)); + } + } + + var smoothingGain0 = WrappedTensorDisposeScope(() => parameters.InitialCovariance.matmul( + InverseCholesky(parameters.TransitionMatrix.mT, predictedCovariance[0]) + )); + + // Smoothed initial mean + var smoothedInitialMean = WrappedTensorDisposeScope(() => parameters.InitialMean + smoothingGain0.matmul( + (smoothedState.Mean[0] - predictedMean[0]).unsqueeze(-1) + ).squeeze(-1)); + + // Smoothed initial covariance + var smoothedInitialCovariance = WrappedTensorDisposeScope(() => parameters.InitialCovariance + smoothingGain0 + .matmul(smoothedState.Covariance[0] - predictedCovariance[0]) + .matmul(smoothingGain0.mT)); + + // Smoothed lag one covariance at time 0 + smoothedLagOneCovariance = WrappedTensorDisposeScope(() => updatedCovariance[0] + .matmul(smoothingGain0.mT) + + smoothingGain.matmul(smoothedLagOneCovariance + - parameters.TransitionMatrix.matmul(updatedCovariance[0])) + .matmul(smoothingGain0.mT)); + + sxx10[0] = outer(smoothedState.Mean[0], smoothedInitialMean) + smoothedLagOneCovariance; + sxx00[0] = outer(smoothedInitialMean, smoothedInitialMean) + smoothedInitialCovariance; + + if (parameters.OffsetsProvided) + tx0 += smoothedInitialMean; + + var state = new SmoothedStateWithAuxiliaryVariables( + smoothedState: smoothedState, + smoothedInitialMean: smoothedInitialMean, + smoothedInitialCovariance: smoothedInitialCovariance + ); + + var stats = parameters.OffsetsProvided ? new SummaryStatisticsEM( + sxx11: sxx11, + sxx10: sxx10, + sxx00: sxx00, + tx1: tx1, + tx0: tx0, + ty1: ty1, + tyx11: tyx11, + tyy11: tyy11 + ) : new SummaryStatisticsEM( + sxx11: sxx11, + sxx10: sxx10, + sxx00: sxx00 + ); + + return (state, stats); + } + + public static ExpectationMaximizationResult ExpectationMaximization( + Tensor observation, + KalmanFilterParameters parameters, + int maxIterations = 100, + double tolerance = 1e-4, + ParametersToEstimate parametersToEstimate = new()) + { + var timeBins = observation.size(0); + var numObservations = (int)observation.size(1); + var logLikelihood = empty(maxIterations, dtype: ScalarType.Float32, device: parameters.Device); + var previousLogLikelihood = double.NegativeInfinity; + var logLikelihoodConst = -0.5 * timeBins * numObservations * Math.Log(2.0 * Math.PI); + + if (parameters.NumObservations != numObservations) + throw new ArgumentException($"The number of observation dimensions in the parameters ({parameters.NumObservations}) does not match the observations ({numObservations})."); + + var identityStates = eye(parameters.NumStates, dtype: parameters.ScalarType, device: parameters.Device); + + // Precompute constant observation terms reused across EM iterations + var observationT = observation.mT; + var autoCorrelationObservations = observationT.matmul(observation); + + using var g = no_grad(); + + for (int iteration = 0; iteration < maxIterations; iteration++) + { + // Filter observations + var filteredState = Filter( + timeBins: timeBins, + observation: observation, + parameters: parameters); + + // Compute log likelihood (avoid creating intermediate tensors) + var llSumDouble = filteredState.LogLikelihood.sum() + .to_type(ScalarType.Float64).item(); + var filteredLogLikelihoodSum = logLikelihoodConst + 0.5 * llSumDouble; + + logLikelihood[iteration] = filteredLogLikelihoodSum; + + // Check for convergence + if (filteredLogLikelihoodSum <= previousLogLikelihood) + { + Console.WriteLine($"Warning: Log likelihood decreased! New: {filteredLogLikelihoodSum}, Previous: {previousLogLikelihood}"); + break; + } + + if (filteredLogLikelihoodSum - previousLogLikelihood < tolerance) + break; + + previousLogLikelihood = filteredLogLikelihoodSum; + + // Smooth the filtered results + var (state, statistics) = Smooth( + filteredState: filteredState, + observations: observation, + parameters: parameters); + + // Sufficient statistics + var sxx00 = statistics.Sxx00.sum([0]); + var sxx11 = statistics.Sxx11.sum([0]); + var sxx10 = statistics.Sxx10.sum([0]); + + // Replace einsum with faster matmul + var crossCorrelationObservations = observationT.matmul(state.SmoothedState.Mean); + + // Update parameters + if (parametersToEstimate.TransitionMatrix) + { + if (parameters.OffsetsProvided) + { + parameters.TransitionMatrix.set_( + InverseCholesky( + sxx10 - outer(statistics.Tx1, statistics.Tx0) / timeBins, + sxx00 - outer(statistics.Tx0, statistics.Tx0) / timeBins + ) + ); + } + else + { + parameters.TransitionMatrix.set_( + InverseCholesky( + sxx10, + sxx00 + ) + ); + } + } + + if (parametersToEstimate.StateOffset) + { + if (parameters.OffsetsProvided) + { + parameters.StateOffset.set_( + (statistics.Tx1 - parameters.TransitionMatrix.matmul(statistics.Tx0)) / timeBins + ); + } + } + + if (parametersToEstimate.MeasurementFunction) + { + if (parameters.OffsetsProvided) + { + parameters.MeasurementFunction.set_( + InverseCholesky( + statistics.Tyx11 - outer(statistics.Ty1, statistics.Tx1) / timeBins, + sxx11 - outer(statistics.Tx0, statistics.Tx0) / timeBins + ) + ); + } + else + { + parameters.MeasurementFunction.set_( + InverseCholesky( + crossCorrelationObservations, + sxx11 + ) + ); + } + } + + if (parametersToEstimate.ObservationOffset) + { + if (parameters.OffsetsProvided) + { + parameters.ObservationOffset.set_( + (statistics.Ty1 - parameters.MeasurementFunction.matmul(statistics.Tx1)) / timeBins + ); + } + } + + if (parametersToEstimate.ProcessNoiseCovariance) + { + if (parameters.OffsetsProvided) + { + parameters.ProcessNoiseCovariance.set_(WrappedTensorDisposeScope(() => + EnsureSymmetric( + (sxx11 - outer(statistics.Tx1, parameters.StateOffset) - outer(parameters.StateOffset, statistics.Tx1) + timeBins * outer(parameters.StateOffset, parameters.StateOffset) - parameters.TransitionMatrix.matmul(sxx10.mT) - sxx10.matmul(parameters.TransitionMatrix.mT) + linalg.multi_dot([parameters.TransitionMatrix, sxx00, parameters.TransitionMatrix.mT]) + parameters.TransitionMatrix.matmul(outer(statistics.Tx0, parameters.StateOffset)) + outer(parameters.StateOffset, statistics.Tx0).matmul(parameters.TransitionMatrix.mT)) / timeBins + ) + )); + } + else + { + parameters.ProcessNoiseCovariance.set_(WrappedTensorDisposeScope(() => + EnsureSymmetric((sxx11 - parameters.TransitionMatrix.matmul(sxx10.mT)) / timeBins))); + } + } + + if (parametersToEstimate.MeasurementNoiseCovariance) + { + if (parameters.OffsetsProvided) + { + parameters.MeasurementNoiseCovariance.set_(WrappedTensorDisposeScope(() => + EnsureSymmetric( + (statistics.Tyy11 - outer(statistics.Ty1, parameters.ObservationOffset) - outer(parameters.ObservationOffset, statistics.Ty1) + timeBins * outer(parameters.ObservationOffset, parameters.ObservationOffset) - parameters.MeasurementFunction.matmul(statistics.Tyx11.mT) - statistics.Tyx11.matmul(parameters.MeasurementFunction.mT) + parameters.MeasurementFunction.matmul(outer(statistics.Tx1, parameters.ObservationOffset)) + linalg.multi_dot([parameters.MeasurementFunction, sxx11, parameters.MeasurementFunction.mT]) + outer(parameters.ObservationOffset, statistics.Tx1).matmul(parameters.MeasurementFunction.mT)) / timeBins + ) + )); + } + else + { + var explainedObservationCovariance = parameters.MeasurementFunction.matmul(crossCorrelationObservations.mT); + + if (parametersToEstimate.MeasurementNoiseCovariance) + parameters.MeasurementNoiseCovariance.set_(WrappedTensorDisposeScope(() => + EnsureSymmetric((autoCorrelationObservations - explainedObservationCovariance - explainedObservationCovariance.mT + + parameters.MeasurementFunction.matmul(sxx11).matmul(parameters.MeasurementFunction.mT)) / timeBins))); + } + } + + if (parametersToEstimate.InitialMean) + parameters.InitialMean.set_(state.SmoothedInitialMean); + + if (parametersToEstimate.InitialCovariance) + parameters.InitialCovariance.set_(state.SmoothedInitialCovariance); + } + + return new ExpectationMaximizationResult(logLikelihood, parameters); + } + + public static StochasticSubspaceIdentificationResult StochasticSubspaceIdentification( + Tensor observations, + int? targetNumStates = null, + int maxLag = 20, + double threshold = 0.01, + ParametersToEstimate parametersToEstimate = new()) + { + using var g = no_grad(); + + var timeBins = observations.size(0); + var numObservations = observations.size(1); + var centered = observations - observations.mean([0], keepdim: true); + + // Build Hankel matrices from observations + var numCols = (int)(timeBins - 2 * maxLag + 1); + + if (numCols <= 0) + throw new ArgumentException($"Number of time bins ({timeBins}) must be greater than 2*maxLag ({2 * maxLag}) for subspace identification."); + + var stride = centered.stride(); + var pastView = centered.as_strided([maxLag, numCols, numObservations], [stride[0], stride[0], stride[1]]); + var past = pastView.permute(0, 2, 1).reshape(maxLag * numObservations, numCols); + + var futureView = centered.narrow(0, maxLag, timeBins - maxLag) + .as_strided([maxLag, numCols, numObservations], [stride[0], stride[0], stride[1]]); + var future = futureView.permute(0, 2, 1).reshape(maxLag * numObservations, numCols); + + // Compute the projection + var Pp = past.matmul(past.mT); + var projection = InverseCholesky(future.matmul(past.mT), Pp).matmul(past); + + // Compute SVD of the past observations + var (U, S, Vt) = linalg.svd(projection, fullMatrices: false); + + // Compute the effective rank + var effectiveRank = (S > (threshold * S[0])).to_type(ScalarType.Int64).sum().item(); + var effectiveStates = Math.Max(Math.Min(effectiveRank, targetNumStates ?? effectiveRank), 1); + + var Ur = U[TensorIndex.Colon, TensorIndex.Slice(0, effectiveStates)]; + var SrSqrt = S[TensorIndex.Slice(0, effectiveStates)].diag().sqrt(); + var Vrt = Vt[TensorIndex.Slice(0, effectiveStates)]; + + // Estimate observability matrix + var observability = Ur.matmul(SrSqrt); + + // Extract measurement function from first block of observability matrix + var measurementFunction = observability[TensorIndex.Slice(0, numObservations)]; + + // Estimate state sequence + var states = SrSqrt.matmul(Vrt); + + // Estimate transition matrix using shifted states + var statesShifted = states[TensorIndex.Colon, TensorIndex.Slice(0, numCols - 1)]; + var statesNext = states[TensorIndex.Colon, TensorIndex.Slice(1, numCols)]; + + var transitionMatrix = WrappedTensorDisposeScope(() => InverseCholesky( + statesNext.matmul(statesShifted.mT), + statesShifted.matmul(statesShifted.mT))); + + // Estimate noise covariances using residuals + var stateResiduals = statesNext - transitionMatrix.matmul(statesShifted); + var processNoiseCovariance = WrappedTensorDisposeScope(() => EnsureSymmetric(stateResiduals.matmul(stateResiduals.mT) / (numCols - 1))); + + // Compute the observation residuals + var observationPredictions = measurementFunction.matmul(states); + var observationWindow = centered[TensorIndex.Slice(maxLag, maxLag + numCols)].mT; + var observationResiduals = observationWindow - observationPredictions; + var measurementNoiseCovariance = WrappedTensorDisposeScope(() => EnsureSymmetric(observationResiduals.matmul(observationResiduals.mT) / numCols)); + + // Initial state estimates + var initialMean = states[TensorIndex.Colon, 0]; + var initialCovariance = WrappedTensorDisposeScope(() => EnsureSymmetric( + states.matmul(states.mT) / numCols)); + + var parameters = new KalmanFilterParameters( + numStates: (int)effectiveStates, + numObservations: (int)numObservations, + transitionMatrix: parametersToEstimate.TransitionMatrix ? transitionMatrix : null, + measurementFunction: parametersToEstimate.MeasurementFunction ? measurementFunction : null, + processNoiseCovariance: parametersToEstimate.ProcessNoiseCovariance ? processNoiseCovariance : null, + measurementNoiseCovariance: parametersToEstimate.MeasurementNoiseCovariance ? measurementNoiseCovariance : null, + initialMean: parametersToEstimate.InitialMean ? initialMean : null, + initialCovariance: parametersToEstimate.InitialCovariance ? initialCovariance : null + ); + + return new StochasticSubspaceIdentificationResult( + parameters: parameters, + effectiveStates: effectiveStates, + singularValues: S + ); + } + + public LinearDynamicalSystemState OrthogonalizeMeanAndCovariance(LinearDynamicalSystemState state) + { + var (_, S, Vt) = linalg.svd(Parameters.MeasurementFunction); + var SVt = diag(S).matmul(Vt); + + Tensor orthogonalizedMean = null; + if (state.Mean is not null) + orthogonalizedMean = matmul(state.Mean, SVt.mT); + + Tensor orthogonalizedCovariance = null; + if (state.Covariance is not null) + { + var auxilary = matmul(SVt, state.Covariance); + orthogonalizedCovariance = matmul(auxilary, SVt.mT); + } + + return new LinearDynamicalSystemState( + orthogonalizedMean, + orthogonalizedCovariance + ); + } + + public void UpdateParameters(KalmanFilterParameters updatedParameters) + { + if (updatedParameters.TransitionMatrix is not null) + Parameters.TransitionMatrix.set_(updatedParameters.TransitionMatrix); + if (updatedParameters.MeasurementFunction is not null) + Parameters.MeasurementFunction.set_(updatedParameters.MeasurementFunction); + if (updatedParameters.ProcessNoiseCovariance is not null) + Parameters.ProcessNoiseCovariance.set_(updatedParameters.ProcessNoiseCovariance); + if (updatedParameters.MeasurementNoiseCovariance is not null) + Parameters.MeasurementNoiseCovariance.set_(updatedParameters.MeasurementNoiseCovariance); + if (updatedParameters.InitialMean is not null) + Parameters.InitialMean.set_(updatedParameters.InitialMean); + if (updatedParameters.InitialCovariance is not null) + Parameters.InitialCovariance.set_(updatedParameters.InitialCovariance); + if (updatedParameters.StateOffset is not null) + Parameters.StateOffset.set_(updatedParameters.StateOffset); + if (updatedParameters.ObservationOffset is not null) + Parameters.ObservationOffset.set_(updatedParameters.ObservationOffset); + } + + private static Tensor EnsureSymmetric(Tensor M) => 0.5 * (M + M.mT); + + private static Tensor Ensure2D(Tensor M) => M.atleast_2d(); + + private static Tensor InverseCholesky(Tensor B, Tensor A) + { + var L = linalg.cholesky(Ensure2D(A)); + var solT = cholesky_solve(Ensure2D(B).mT, L); + return solT.mT; + } +} diff --git a/src/Bonsai.ML.Lds.Torch/KalmanFilterParameters.cs b/src/Bonsai.ML.Lds.Torch/KalmanFilterParameters.cs new file mode 100644 index 00000000..521d9336 --- /dev/null +++ b/src/Bonsai.ML.Lds.Torch/KalmanFilterParameters.cs @@ -0,0 +1,338 @@ +using System; +using System.Text; +using TorchSharp; +using static TorchSharp.torch; + +namespace Bonsai.ML.Lds.Torch; + +/// +/// Represents the parameters of a Kalman filter model. +/// +public class KalmanFilterParameters : nn.Module +{ + private readonly static StringBuilder _sb = new(); + private readonly ScalarType _scalarType; + private readonly Device _device; + + /// + /// The number of states in the system. + /// + public int NumStates { get; private set; } + + /// + /// The number of observations in the system. + /// + public int NumObservations { get; private set; } + + /// + /// The state transition matrix. + /// + public Tensor TransitionMatrix { get; private set; } + + /// + /// The measurement function. + /// + public Tensor MeasurementFunction { get; private set; } + + /// + /// The process noise covariance. + /// + public Tensor ProcessNoiseCovariance { get; private set; } + + /// + /// The measurement noise covariance. + /// + public Tensor MeasurementNoiseCovariance { get; private set; } + + /// + /// The initial mean. + /// + public Tensor InitialMean { get; private set; } + + /// + /// The initial covariance. + /// + public Tensor InitialCovariance { get; private set; } + + /// + /// The optional state offset. + /// + public Tensor StateOffset { get; private set; } + + /// + /// The optional observation offset. + /// + public Tensor ObservationOffset { get; private set; } + + /// + /// Indicates whether any offsets have been provided. + /// + public bool OffsetsProvided => StateOffset is not null || ObservationOffset is not null; + + /// + /// The data type of the tensors. + /// + public ScalarType ScalarType => _scalarType; + + /// + /// The device on which the tensors are allocated. + /// + public Device Device => _device; + + /// + /// Initializes a new instance of the class with the specified parameters. + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + public KalmanFilterParameters( + int? numStates = null, + int? numObservations = null, + Tensor transitionMatrix = null, + Tensor measurementFunction = null, + Tensor processNoiseCovariance = null, + Tensor measurementNoiseCovariance = null, + Tensor initialMean = null, + Tensor initialCovariance = null, + Tensor stateOffset = null, + Tensor observationOffset = null, + Device device = null, + ScalarType? scalarType = null, + bool requiresGrad = false) : base("KalmanFilterParameters") + { + numStates ??= InferNumStates(transitionMatrix, measurementFunction, initialMean, initialCovariance, processNoiseCovariance, stateOffset); + numObservations ??= InferNumObservations(measurementFunction, measurementNoiseCovariance, observationOffset); + + if (numStates is null) + throw new ArgumentOutOfRangeException(nameof(numStates), "Number of states must be specified or inferred from the parameters."); + if (numObservations is null) + throw new ArgumentOutOfRangeException(nameof(numObservations), "Number of observations must be specified or inferred from the parameters."); + + transitionMatrix = transitionMatrix?.clone() ?? eye(numStates.Value); + measurementFunction = measurementFunction?.clone() ?? eye(numObservations.Value, numStates.Value); + initialMean = initialMean?.clone() ?? zeros(numStates.Value); + initialCovariance = initialCovariance?.clone() ?? eye(numStates.Value); + + processNoiseCovariance = processNoiseCovariance?.NumberOfElements == 1 + ? CreateCovarianceMatrixFromScalar(processNoiseCovariance, numStates.Value, "Process noise variance") + : processNoiseCovariance?.clone() + ?? CreateCovarianceMatrixFromScalar(1.0, numStates.Value, "Process noise variance"); + + measurementNoiseCovariance = measurementNoiseCovariance?.NumberOfElements == 1 + ? CreateCovarianceMatrixFromScalar(measurementNoiseCovariance, numObservations.Value, "Measurement noise variance") + : measurementNoiseCovariance?.clone() + ?? CreateCovarianceMatrixFromScalar(1.0, numObservations.Value, "Measurement noise variance"); + + NumStates = numStates.Value; + NumObservations = numObservations.Value; + TransitionMatrix = transitionMatrix; + MeasurementFunction = measurementFunction; + ProcessNoiseCovariance = processNoiseCovariance; + MeasurementNoiseCovariance = measurementNoiseCovariance; + InitialMean = initialMean; + InitialCovariance = initialCovariance; + StateOffset = stateOffset; + ObservationOffset = observationOffset; + + Validate(); + + if (device is not null) + this.to(device); + if (scalarType is not null) + this.to(scalarType.Value); + + _device = TransitionMatrix.device; + _scalarType = TransitionMatrix.dtype; + + SetGrad(requiresGrad); + } + + /// + /// Validates the Kalman filter parameters. + /// + private void Validate() + { + int numStates = InferNumStates(TransitionMatrix, MeasurementFunction, InitialMean, InitialCovariance, ProcessNoiseCovariance, StateOffset); + + int numObservations = InferNumObservations(MeasurementFunction, MeasurementNoiseCovariance, ObservationOffset); + + ValidateMatrix(TransitionMatrix, "Transition matrix", isSquare: true, expectedDimension1: numStates); + + ValidateMatrix(MeasurementFunction, "Measurement function", expectedDimension1: numObservations, expectedDimension2: numStates); + + ValidateMatrix(ProcessNoiseCovariance, "Process noise covariance", isSquare: true, expectedDimension1: numStates); + + ValidateMatrix(MeasurementNoiseCovariance, "Measurement noise covariance", isSquare: true, expectedDimension1: numObservations); + + ValidateVector(InitialMean, "Initial mean", numStates); + ValidateMatrix(InitialCovariance, "Initial covariance", isSquare: true, expectedDimension1: numStates); + + if (StateOffset is not null) + ValidateVector(StateOffset, "State offset", numStates); + + if (ObservationOffset is not null) + ValidateVector(ObservationOffset, "Observation offset", numObservations); + + NumStates = numStates; + NumObservations = numObservations; + } + + /// + /// Creates a copy of the current Kalman filter parameters. + /// + /// + public KalmanFilterParameters Copy() => new( + NumStates, + NumObservations, + TransitionMatrix?.clone(), + MeasurementFunction?.clone(), + ProcessNoiseCovariance?.clone(), + MeasurementNoiseCovariance?.clone(), + InitialMean?.clone(), + InitialCovariance?.clone(), + StateOffset?.clone(), + ObservationOffset?.clone() + ); + + /// + /// Sets the requires_grad flag for all tensors in the Kalman filter parameters. + /// + /// + public void SetGrad(bool requiresGrad) + { + TransitionMatrix = TransitionMatrix?.requires_grad_(requiresGrad); + MeasurementFunction = MeasurementFunction?.requires_grad_(requiresGrad); + ProcessNoiseCovariance = ProcessNoiseCovariance?.requires_grad_(requiresGrad); + MeasurementNoiseCovariance = MeasurementNoiseCovariance?.requires_grad_(requiresGrad); + InitialMean = InitialMean?.requires_grad_(requiresGrad); + InitialCovariance = InitialCovariance?.requires_grad_(requiresGrad); + StateOffset = StateOffset?.requires_grad_(requiresGrad); + ObservationOffset = ObservationOffset?.requires_grad_(requiresGrad); + } + + private static int InferNumStates(Tensor transitionMatrix, Tensor measurementFunction, Tensor initialMean, Tensor initialCovariance, Tensor processNoiseCovariance, Tensor stateOffset) + { + if (transitionMatrix is not null) + { + ValidateMatrix(transitionMatrix, "Transition matrix", isSquare: true); + return (int)transitionMatrix.size(0); + } + else if (measurementFunction is not null) + { + ValidateMatrix(measurementFunction, "Measurement function"); + return (int)measurementFunction.size(1); + } + else if (initialMean is not null) + { + ValidateVector(initialMean, "Initial mean"); + return (int)initialMean.size(0); + } + else if (initialCovariance is not null) + { + ValidateMatrix(initialCovariance, "Initial covariance", isSquare: true); + return (int)initialCovariance.size(0); + } + else if (processNoiseCovariance is not null) + { + ValidateMatrix(processNoiseCovariance, "Process noise covariance", isSquare: true); + return (int)processNoiseCovariance.size(0); + } + else if (stateOffset is not null) + { + ValidateVector(stateOffset, "State offset"); + return (int)stateOffset.size(0); + } + else + { + throw new ArgumentException("At least one of the parameters must be provided to infer the number of states."); + } + } + + private static int InferNumObservations(Tensor measurementFunction, Tensor measurementNoiseCovariance, Tensor observationOffset) + { + if (measurementFunction is not null) + { + ValidateMatrix(measurementFunction, "Measurement function"); + return (int)measurementFunction.size(0); + } + else if (measurementNoiseCovariance is not null) + { + ValidateMatrix(measurementNoiseCovariance, "Measurement noise covariance", isSquare: true); + return (int)measurementNoiseCovariance.size(0); + } + else if (observationOffset is not null) + { + ValidateVector(observationOffset, "Observation offset"); + return (int)observationOffset.size(0); + } + else + { + throw new ArgumentException("At least one of the measurement function or measurement noise covariance must be provided to infer the number of observations."); + } + } + + private static void ValidateMatrix(Tensor matrix, string name, bool isSquare = false, int? expectedDimension1 = null, int? expectedDimension2 = null) + { + if (matrix.NumberOfElements == 0) + throw new ArgumentException($"{name} must be a non-empty matrix."); + + if (matrix.Dimensions != 2) + throw new ArgumentException($"{name} must be 2-dimensional."); + + if (isSquare && matrix.size(0) != matrix.size(1)) + throw new ArgumentException($"{name} must be square."); + + if (expectedDimension1.HasValue && matrix.size(0) != expectedDimension1.Value) + throw new ArgumentException($"{name} must have {expectedDimension1.Value} rows."); + + if (expectedDimension2.HasValue && matrix.size(1) != expectedDimension2.Value) + throw new ArgumentException($"{name} must have {expectedDimension2.Value} columns."); + } + + private static void ValidateVector(Tensor vector, string name, int? expectedLength = null) + { + if (vector.NumberOfElements == 0) + throw new ArgumentException($"{name} must be a non-empty vector."); + + if (vector.Dimensions != 1) + throw new ArgumentException($"{name} must be a vector."); + + if (expectedLength.HasValue && vector.NumberOfElements != expectedLength.Value) + throw new ArgumentException($"{name} must be a vector with length equal to {expectedLength.Value}."); + } + + private static void ValidateScalar(Tensor scalar, string name) + { + if (scalar.NumberOfElements != 1) + throw new ArgumentException($"{name} must be a scalar."); + } + + private static Tensor CreateCovarianceMatrixFromScalar(Tensor variance, int dimension, string name) + { + ValidateScalar(variance, name); + var scalar = variance.clone().squeeze(); + return scalar * eye(dimension); + } + + /// + public override string ToString() => _sb.Length == 0 ? _sb.Append( + $"KalmanFilterParameters(NumStates={NumStates}, NumObservations={NumObservations}, TransitionMatrix={TransitionMatrix}, MeasurementFunction={MeasurementFunction}, ProcessNoiseCovariance={ProcessNoiseCovariance}, MeasurementNoiseCovariance={MeasurementNoiseCovariance}, InitialMean={InitialMean}, InitialCovariance={InitialCovariance})" + (StateOffset is not null ? $", StateOffset={StateOffset}" : "") + (ObservationOffset is not null ? $", ObservationOffset={ObservationOffset}" : "")).ToString() : _sb.ToString(); + + /// + /// Returns a string representation of the Kalman filter parameters with the specified tensor string style. + /// + /// + /// + public string ToString(TensorStringStyle tensorStringStyle) => _sb.Length == 0 ? _sb.Append( + $"KalmanFilterParameters(NumStates={NumStates}, NumObservations={NumObservations}, TransitionMatrix={TransitionMatrix.ToString(tensorStringStyle)}, MeasurementFunction={MeasurementFunction.ToString(tensorStringStyle)}, ProcessNoiseCovariance={ProcessNoiseCovariance.ToString(tensorStringStyle)}, MeasurementNoiseCovariance={MeasurementNoiseCovariance.ToString(tensorStringStyle)}, InitialMean={InitialMean.ToString(tensorStringStyle)}, InitialCovariance={InitialCovariance.ToString(tensorStringStyle)})" + (StateOffset is not null ? $", StateOffset={StateOffset.ToString(tensorStringStyle)}" : "") + (ObservationOffset is not null ? $", ObservationOffset={ObservationOffset.ToString(tensorStringStyle)}" : "")).ToString() : _sb.ToString(); +} diff --git a/src/Bonsai.ML.Lds.Torch/LinearDynamicalSystemState.cs b/src/Bonsai.ML.Lds.Torch/LinearDynamicalSystemState.cs new file mode 100644 index 00000000..60323c6c --- /dev/null +++ b/src/Bonsai.ML.Lds.Torch/LinearDynamicalSystemState.cs @@ -0,0 +1,17 @@ +using static TorchSharp.torch; + +namespace Bonsai.ML.Lds.Torch; + +/// +/// Represents the state of a linear gaussian dynamical system. +/// +/// +/// +public readonly struct LinearDynamicalSystemState(Tensor mean, Tensor covariance) : ILinearDynamicalSystemState +{ + /// + public readonly Tensor Mean => mean; + + /// + public readonly Tensor Covariance => covariance; +} diff --git a/src/Bonsai.ML.Lds.Torch/LoadKalmanFilterParameters.cs b/src/Bonsai.ML.Lds.Torch/LoadKalmanFilterParameters.cs new file mode 100644 index 00000000..bb6435b6 --- /dev/null +++ b/src/Bonsai.ML.Lds.Torch/LoadKalmanFilterParameters.cs @@ -0,0 +1,90 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using System.Xml.Serialization; +using System.IO; +using static TorchSharp.torch; + +namespace Bonsai.ML.Lds.Torch; + +/// +/// Loads the parameters of a Kalman filter model. +/// +[Combinator] +[Description("Loads the parameters of a Kalman filter model.")] +[WorkflowElementCategory(ElementCategory.Source)] +public class LoadKalmanFilterParameters +{ + /// + /// The path to the folder where the Kalman filter model parameters were saved. + /// + [Editor("Bonsai.Design.FolderNameEditor, Bonsai.Design", DesignTypes.UITypeEditor)] + [Description("The path to the folder where the Kalman filter model parameters were saved.")] + public string Path { get; set; } = string.Empty; + + /// + /// Gets or sets the data type of the tensors. + /// + [Description("Gets or sets the data type of the tensors.")] + public ScalarType? Type { get; set; } = null; + + /// + /// Gets or sets the device to use for tensor operations. + /// + [XmlIgnore] + [Description("Gets or sets the device to use for tensor operations.")] + public Device Device { get; set; } = null; + + private static Tensor LoadTensorFromFile(string basePath, string filePath) + { + if (filePath == null) return null; + + filePath = System.IO.Path.Combine(basePath, filePath); + + if (File.Exists(filePath)) + { + return Tensor.Load(filePath); + } + + return null; + } + + /// + /// Creates parameters for a Kalman filter model using the properties of this class. + /// + public IObservable Process() + { + if (string.IsNullOrEmpty(Path)) + { + throw new InvalidOperationException("The save path is not specified."); + } + + if (!Directory.Exists(Path)) + { + throw new InvalidOperationException("The save path does not exist."); + } + + var transitionMatrix = LoadTensorFromFile(Path, "TransitionMatrix.bin"); + var measurementFunction = LoadTensorFromFile(Path, "MeasurementFunction.bin"); + var processNoiseCovariance = LoadTensorFromFile(Path, "ProcessNoiseCovariance.bin"); + var measurementNoiseCovariance = LoadTensorFromFile(Path, "MeasurementNoiseCovariance.bin"); + var initialMean = LoadTensorFromFile(Path, "InitialMean.bin"); + var initialCovariance = LoadTensorFromFile(Path, "InitialCovariance.bin"); + var stateOffset = LoadTensorFromFile(Path, "StateOffset.bin"); + var observationOffset = LoadTensorFromFile(Path, "ObservationOffset.bin"); + + var parameters = new KalmanFilterParameters( + transitionMatrix: transitionMatrix, + measurementFunction: measurementFunction, + processNoiseCovariance: processNoiseCovariance, + measurementNoiseCovariance: measurementNoiseCovariance, + initialMean: initialMean, + initialCovariance: initialCovariance, + stateOffset: stateOffset, + observationOffset: observationOffset, + device: Device, + scalarType: Type); + + return Observable.Return(parameters); + } +} diff --git a/src/Bonsai.ML.Lds.Torch/Orthogonalize.cs b/src/Bonsai.ML.Lds.Torch/Orthogonalize.cs new file mode 100644 index 00000000..061832ff --- /dev/null +++ b/src/Bonsai.ML.Lds.Torch/Orthogonalize.cs @@ -0,0 +1,61 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using System.Xml.Serialization; +using static TorchSharp.torch; + +namespace Bonsai.ML.Lds.Torch; + +/// +/// Orthogonalizes the state and covariance estimates from a Kalman filter or smoother. +/// +[Combinator] +[ResetCombinator] +[Description("Orthogonalizes the state and covariance estimates from a Kalman filter or smoother.")] +[WorkflowElementCategory(ElementCategory.Transform)] +public class Orthogonalize +{ + /// + /// The Kalman filter model. + /// + [Description("The Kalman filter model.")] + [XmlIgnore] + public KalmanFilter Model { get; set; } + + /// + /// Processes an observable sequence of smoothed results, orthogonalizing the mean and covariance estimates. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(Model.OrthogonalizeMeanAndCovariance); + } + + /// + /// Processes an observable sequence of filtered results, orthogonalizing the mean and covariance estimates. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(input => + { + return Model.OrthogonalizeMeanAndCovariance(input.UpdatedState); + }); + } + + /// + /// Processes an observable sequence of mean and covariance tuples, orthogonalizing the mean and covariance estimates. + /// + /// + /// + public IObservable Process(IObservable> source) + { + return source.Select(input => + { + var state = new LinearDynamicalSystemState(input.Item1, input.Item2); + return Model.OrthogonalizeMeanAndCovariance(state); + }); + } +} diff --git a/src/Bonsai.ML.Lds.Torch/ParametersToEstimate.cs b/src/Bonsai.ML.Lds.Torch/ParametersToEstimate.cs new file mode 100644 index 00000000..2127ef6e --- /dev/null +++ b/src/Bonsai.ML.Lds.Torch/ParametersToEstimate.cs @@ -0,0 +1,66 @@ +namespace Bonsai.ML.Lds.Torch; + +/// +/// Represents the parameters to estimate for a Kalman filter model. +/// +/// +/// Initializes a new instance of the struct with the specified parameters. +/// +/// +/// +/// +/// +/// +/// +/// +/// +public struct ParametersToEstimate( + bool transitionMatrix = true, + bool measurementFunction = true, + bool processNoiseCovariance = true, + bool measurementNoiseCovariance = true, + bool initialMean = true, + bool initialCovariance = true, + bool stateOffset = false, + bool observationOffset = false) +{ + /// + /// The state transition matrix. + /// + public bool TransitionMatrix = transitionMatrix; + + /// + /// The measurement function. + /// + public bool MeasurementFunction = measurementFunction; + + /// + /// The process noise covariance. + /// + public bool ProcessNoiseCovariance = processNoiseCovariance; + + /// + /// The measurement noise covariance. + /// + public bool MeasurementNoiseCovariance = measurementNoiseCovariance; + + /// + /// The initial mean. + /// + public bool InitialMean = initialMean; + + /// + /// The initial covariance. + /// + public bool InitialCovariance = initialCovariance; + + /// + /// The state offset. + /// + public bool StateOffset = stateOffset; + + /// + /// The observation offset. + /// + public bool ObservationOffset = observationOffset; +} diff --git a/src/Bonsai.ML.Lds.Torch/Properties/launchSettings.json b/src/Bonsai.ML.Lds.Torch/Properties/launchSettings.json new file mode 100644 index 00000000..4af4f468 --- /dev/null +++ b/src/Bonsai.ML.Lds.Torch/Properties/launchSettings.json @@ -0,0 +1,10 @@ +{ + "profiles": { + "Bonsai": { + "commandName": "Executable", + "executablePath": "$(BonsaiExecutablePath)", + "commandLineArgs": "--lib:\"$(TargetDir).\"", + "nativeDebugging": true + } + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Lds.Torch/SaveKalmanFilterParameters.cs b/src/Bonsai.ML.Lds.Torch/SaveKalmanFilterParameters.cs new file mode 100644 index 00000000..83dff315 --- /dev/null +++ b/src/Bonsai.ML.Lds.Torch/SaveKalmanFilterParameters.cs @@ -0,0 +1,146 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using System.IO; +using TorchSharp; + +namespace Bonsai.ML.Lds.Torch; + +/// +/// Saves the parameters of a Kalman filter model. +/// +[Combinator] +[Description("Saves the parameters of a Kalman filter model.")] +[WorkflowElementCategory(ElementCategory.Sink)] +public class SaveKalmanFilterParameters +{ + /// + /// The path to the folder where the Kalman filter model parameters will be saved. + /// + [Editor("Bonsai.Design.FolderNameEditor, Bonsai.Design", DesignTypes.UITypeEditor)] + [Description("The path to the folder where the Kalman filter model parameters will be saved.")] + public string Path { get; set; } = string.Empty; + + /// + /// If true, the contents of the folder will be overwritten if it already exists. + /// + [Description("If true, the contents of the folder will be overwritten if it already exists.")] + public bool Overwrite { get; set; } = false; + + /// + /// Specifies the type of suffix to add to the save path. + /// The suffix is added as a subfolder to the save path. + /// If DateTime is used, a suffix with the current date and time is added to the save path in the format, '[Path]/yyyyMMddHHmmss/...'. + /// If Guid is used, a suffix with a unique 128-bit identifier is added to the save path in the format, '[Path]/[128-bit identifier]/...'. + /// + [Description("Specifies the type of suffix to add to the save path.")] + public SuffixType AddSuffix { get; set; } = SuffixType.None; + + private void SaveKalmanFilterParametersToDisk(KalmanFilterParameters parameters) + { + if (string.IsNullOrEmpty(Path)) + { + Path = Directory.GetCurrentDirectory(); + } + + var path = AddSuffix switch + { + SuffixType.DateTime => System.IO.Path.Combine(Path, $"{HighResolutionScheduler.Now:yyyyMMddHHmmss}"), + SuffixType.Guid => System.IO.Path.Combine(Path, $"{Guid.NewGuid()}"), + _ => Path + }; + + var transitionMatrixPath = System.IO.Path.Combine(path, "TransitionMatrix.bin"); + var measurementFunctionPath = System.IO.Path.Combine(path, "MeasurementFunction.bin"); + var processNoiseCovariancePath = System.IO.Path.Combine(path, "ProcessNoiseCovariance.bin"); + var measurementNoiseCovariancePath = System.IO.Path.Combine(path, "MeasurementNoiseCovariance.bin"); + var initialMeanPath = System.IO.Path.Combine(path, "InitialMean.bin"); + var initialCovariancePath = System.IO.Path.Combine(path, "InitialCovariance.bin"); + var stateOffsetPath = System.IO.Path.Combine(path, "StateOffset.bin"); + var observationOffsetPath = System.IO.Path.Combine(path, "ObservationOffset.bin"); + + if (Directory.Exists(path)) + { + if (!Overwrite && ( + File.Exists(transitionMatrixPath) || + File.Exists(measurementFunctionPath) || + File.Exists(processNoiseCovariancePath) || + File.Exists(measurementNoiseCovariancePath) || + File.Exists(initialMeanPath) || + File.Exists(initialCovariancePath) || + File.Exists(stateOffsetPath) || + File.Exists(observationOffsetPath)) + ) + { + throw new InvalidOperationException("The save path already exists."); + } + else + { + if (File.Exists(transitionMatrixPath)) + File.Delete(transitionMatrixPath); + if (File.Exists(measurementFunctionPath)) + File.Delete(measurementFunctionPath); + if (File.Exists(processNoiseCovariancePath)) + File.Delete(processNoiseCovariancePath); + if (File.Exists(measurementNoiseCovariancePath)) + File.Delete(measurementNoiseCovariancePath); + if (File.Exists(initialMeanPath)) + File.Delete(initialMeanPath); + if (File.Exists(initialCovariancePath)) + File.Delete(initialCovariancePath); + if (File.Exists(stateOffsetPath)) + File.Delete(stateOffsetPath); + if (File.Exists(observationOffsetPath)) + File.Delete(observationOffsetPath); + } + } + + Directory.CreateDirectory(path); + + parameters.TransitionMatrix?.Save(transitionMatrixPath); + parameters.MeasurementFunction?.Save(measurementFunctionPath); + parameters.ProcessNoiseCovariance?.Save(processNoiseCovariancePath); + parameters.MeasurementNoiseCovariance?.Save(measurementNoiseCovariancePath); + parameters.InitialMean?.Save(initialMeanPath); + parameters.InitialCovariance?.Save(initialCovariancePath); + parameters.StateOffset?.Save(stateOffsetPath); + parameters.ObservationOffset?.Save(observationOffsetPath); + } + + /// + /// Processes an observable sequence of Kalman filter parameters, saving to files. + /// + public IObservable Process(IObservable source) + { + return source.Do(SaveKalmanFilterParametersToDisk); + } + + /// + /// Processes an observable sequence of Kalman filter models, saving their parameters to files. + /// + public IObservable Process(IObservable source) + { + return source.Do(model => SaveKalmanFilterParametersToDisk(model.Parameters)); + } + + /// + /// Specifies the type of suffix to add to the save path. + /// + public enum SuffixType + { + /// + /// No suffix is added to the save path. + /// + None, + + /// + /// A suffix with the current date and time is added to the save path. + /// + DateTime, + + /// + /// A suffix with a unique identifier is added to the save path. + /// + Guid + } +} diff --git a/src/Bonsai.ML.Lds.Torch/Smooth.cs b/src/Bonsai.ML.Lds.Torch/Smooth.cs new file mode 100644 index 00000000..69985696 --- /dev/null +++ b/src/Bonsai.ML.Lds.Torch/Smooth.cs @@ -0,0 +1,51 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using System.Xml.Serialization; +using static TorchSharp.torch; + +namespace Bonsai.ML.Lds.Torch; + +/// +/// Applies a Kalman smoother to the input filtered result sequence. +/// +[Combinator] +[ResetCombinator] +[Description("Applies a Kalman smoother to the input filtered result sequence.")] +[WorkflowElementCategory(ElementCategory.Transform)] +public class Smooth +{ + /// + /// The Kalman filter model. + /// + [Description("The Kalman filter model.")] + [XmlIgnore] + public KalmanFilter Model { get; set; } + + /// + /// Processes an observable sequence of filtered results, applying the Kalman smoother to each result. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(Model.Smooth); + } + + /// + /// Processes an observable sequence of tuples containing the components of a filtered result (predictedMean, predictedCovariance, updatedMean, updatedCovariance), applying the Kalman smoother to each result. + /// + /// + /// + public IObservable Process(IObservable> source) + { + return source.Select((input) => + { + var filteredState = new FilteredState( + predictedState: new LinearDynamicalSystemState(input.Item1, input.Item2), + updatedState: new LinearDynamicalSystemState(input.Item3, input.Item4) + ); + return Model.Smooth(filteredState); + }); + } +} diff --git a/src/Bonsai.ML.Lds.Torch/StochasticSubspaceIdentification.cs b/src/Bonsai.ML.Lds.Torch/StochasticSubspaceIdentification.cs new file mode 100644 index 00000000..be041086 --- /dev/null +++ b/src/Bonsai.ML.Lds.Torch/StochasticSubspaceIdentification.cs @@ -0,0 +1,134 @@ +using System; +using System.Linq; +using System.ComponentModel; +using System.Reactive.Linq; +using TorchSharp; +using static TorchSharp.torch; +using System.Threading.Tasks; + +namespace Bonsai.ML.Lds.Torch; + +/// +/// Learn the parameters of a kalman filter using the stochastic subspace identification method. +/// +[Combinator] +[Description("Learn the parameters of a kalman filter using the stochastic subspace identification method.")] +[WorkflowElementCategory(ElementCategory.Combinator)] +public class StochasticSubspaceIdentification +{ + private int? _targetNumStates = 2; + private int _maxLag = 20; + private double _threshold = 1e-4; + + /// + /// The target number of states in the Kalman filter model. + /// + [Description("The target number of states in the Kalman filter model.")] + public int? TargetNumStates + { + get => _targetNumStates; + set => _targetNumStates = value > 0 ? value : throw new ArgumentOutOfRangeException(nameof(value), "Number of states must be greater than zero."); + } + + /// + /// The maximum lag to consider for the subspace identification. + /// + [Description("The maximum lag to consider for the subspace identification.")] + public int MaxLag + { + get => _maxLag; + set => _maxLag = value > 0 ? value : throw new ArgumentOutOfRangeException(nameof(MaxLag), "Must be greater than zero."); + } + + /// + /// The threshold for the singular values to determine the effective number of states. + /// + [Description("The threshold for the singular values to determine the effective number of states.")] + public double Threshold + { + get => _threshold; + set => _threshold = value >= 0 && value < 1 ? value : throw new ArgumentOutOfRangeException(nameof(Threshold), "Must be greater than or equal to zero and less than one."); + } + + /// + /// If true, the transition matrix will be estimated during the SSI algorithm. + /// + [Description("If true, the transition matrix will be estimated during the SSI algorithm.")] + public bool EstimateTransitionMatrix { get; set; } = true; + + /// + /// If true, the measurement function will be estimated during the SSI algorithm. + /// + [Description("If true, the measurement function will be estimated during the SSI algorithm.")] + public bool EstimateMeasurementFunction { get; set; } = true; + + /// + /// If true, the process noise covariance will be estimated during the SSI algorithm. + /// + [Description("If true, the process noise covariance will be estimated during the SSI algorithm.")] + public bool EstimateProcessNoiseCovariance { get; set; } = true; + + /// + /// If true, the measurement noise covariance will be estimated during the SSI algorithm. + /// + [Description("If true, the measurement noise covariance will be estimated during the SSI algorithm.")] + public bool EstimateMeasurementNoiseCovariance { get; set; } = true; + + /// + /// If true, the initial mean will be estimated during the SSI algorithm. + /// + [Description("If true, the initial mean will be estimated during the SSI algorithm.")] + public bool EstimateInitialMean { get; set; } = true; + + /// + /// If true, the initial covariance will be estimated during the SSI algorithm. + /// + [Description("If true, the initial covariance will be estimated during the SSI algorithm.")] + public bool EstimateInitialCovariance { get; set; } = true; + + /// + /// If true, the state offset will be estimated during the SSI algorithm. + /// + [Description("If true, the state offset will be estimated during the SSI algorithm.")] + public bool EstimateStateOffset { get; set; } = false; + + /// + /// If true, the observation offset will be estimated during the SSI algorithm. + /// + [Description("If true, the observation offset will be estimated during the SSI algorithm.")] + public bool EstimateObservationOffset { get; set; } = false; + + /// + /// Processes an observable sequence of input tensors, applying the Stochastic Subspace Identification algorithm to learn the parameters of a Kalman filter model. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(input => Observable.Create(observer => + { + return Task.Run(() => + { + var parametersToEstimate = new ParametersToEstimate( + transitionMatrix: EstimateTransitionMatrix, + measurementFunction: EstimateMeasurementFunction, + processNoiseCovariance: EstimateProcessNoiseCovariance, + measurementNoiseCovariance: EstimateMeasurementNoiseCovariance, + initialMean: EstimateInitialMean, + initialCovariance: EstimateInitialCovariance, + stateOffset: EstimateStateOffset, + observationOffset: EstimateObservationOffset); + + observer.OnNext(KalmanFilter.StochasticSubspaceIdentification( + observations: input, + targetNumStates: TargetNumStates, + maxLag: MaxLag, + threshold: Threshold, + parametersToEstimate: parametersToEstimate)); + + observer.OnCompleted(); + return System.Reactive.Disposables.Disposable.Empty; + }); + })).Concat(); + } +} diff --git a/src/Bonsai.ML.Lds.Torch/StochasticSubspaceIdentificationResult.cs b/src/Bonsai.ML.Lds.Torch/StochasticSubspaceIdentificationResult.cs new file mode 100644 index 00000000..b46d16c4 --- /dev/null +++ b/src/Bonsai.ML.Lds.Torch/StochasticSubspaceIdentificationResult.cs @@ -0,0 +1,31 @@ +using static TorchSharp.torch; + +namespace Bonsai.ML.Lds.Torch; + +/// +/// Represents the result of subspace identification using the Kalman-Ho method. +/// +/// The identified Kalman filter parameters. +/// The effective states of the system determined by SVD. +/// The singular values from the SVD decomposition. +public struct StochasticSubspaceIdentificationResult( + KalmanFilterParameters parameters, + long effectiveStates, + Tensor singularValues) +{ + /// + /// The identified Kalman filter parameters from subspace identification. + /// + public KalmanFilterParameters Parameters = parameters; + + /// + /// The effective states of the system determined by SVD truncation. + /// + public long EffectiveStates = effectiveStates; + + /// + /// The singular values from the SVD decomposition of the Hankel matrix. + /// These can be used to assess the quality of the identification and choose the model order. + /// + public Tensor SingularValues = singularValues; +} diff --git a/src/Bonsai.ML.Lds.Torch/UpdateKalmanFilterParameters.cs b/src/Bonsai.ML.Lds.Torch/UpdateKalmanFilterParameters.cs new file mode 100644 index 00000000..edcf14c3 --- /dev/null +++ b/src/Bonsai.ML.Lds.Torch/UpdateKalmanFilterParameters.cs @@ -0,0 +1,33 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using System.Xml.Serialization; + +namespace Bonsai.ML.Lds.Torch; + +/// +/// Updates the parameters of a Kalman filter model instance using the provided Kalman filter parameters. +/// +[Combinator] +[ResetCombinator] +[Description("Updates the parameters of a Kalman filter model instance using the provided Kalman filter parameters.")] +[WorkflowElementCategory(ElementCategory.Sink)] +public class UpdateKalmanFilterParameters +{ + /// + /// The Kalman filter model. + /// + [XmlIgnore] + [Description("The Kalman filter model.")] + public KalmanFilter Model { get; set; } + + /// + /// Updates the parameters of a Kalman filter model using elements from the input sequence. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Do(Model.UpdateParameters); + } +} \ No newline at end of file diff --git a/tests/Bonsai.ML.Lds.Python.Tests/ReceptiveFieldSimpleCellTest.cs b/tests/Bonsai.ML.Lds.Python.Tests/ReceptiveFieldSimpleCellTest.cs index 61de29a4..a2f159de 100644 --- a/tests/Bonsai.ML.Lds.Python.Tests/ReceptiveFieldSimpleCellTest.cs +++ b/tests/Bonsai.ML.Lds.Python.Tests/ReceptiveFieldSimpleCellTest.cs @@ -93,9 +93,9 @@ private static bool CompareJSONData(string basePath, double tolerance = 1e-9) { for (int j = 0; j < bonsaiOutput.X.GetLength(1); j++) { - if (Math.Abs(bonsaiOutput.X[i,j] - pythonOutput.X[i,j]) > tolerance || Math.Abs(originalOutput.X[i,j] - pythonOutput.X[i,j]) > tolerance) + if (Math.Abs(bonsaiOutput.X[i, j] - pythonOutput.X[i, j]) > tolerance || Math.Abs(originalOutput.X[i, j] - pythonOutput.X[i, j]) > tolerance) { - Console.WriteLine($"Discrepency found comparing X at index ({i},{j}) with tolerance {tolerance}: bonsaiOutput = {bonsaiOutput.X[i,j]}, pythonOutput = {pythonOutput.X[i,j]}, originalOutput = {originalOutput.X[i,j]}."); + Console.WriteLine($"Discrepency found comparing X at index ({i},{j}) with tolerance {tolerance}: bonsaiOutput = {bonsaiOutput.X[i, j]}, pythonOutput = {pythonOutput.X[i, j]}, originalOutput = {originalOutput.X[i, j]}."); return false; } } @@ -104,9 +104,9 @@ private static bool CompareJSONData(string basePath, double tolerance = 1e-9) { for (int j = 0; j < bonsaiOutput.P.GetLength(1); j++) { - if (Math.Abs(bonsaiOutput.P[i,j] - pythonOutput.P[i,j]) > tolerance || Math.Abs(originalOutput.P[i,j] - pythonOutput.P[i,j]) > tolerance) + if (Math.Abs(bonsaiOutput.P[i, j] - pythonOutput.P[i, j]) > tolerance || Math.Abs(originalOutput.P[i, j] - pythonOutput.P[i, j]) > tolerance) { - Console.WriteLine($"Discrepency found comparing P at index ({i},{j}) with tolerance {tolerance}: bonsaiOutput = {bonsaiOutput.P[i,j]}, pythonOutput = {pythonOutput.P[i,j]}, originalOutput = {originalOutput.P[i,j]}."); + Console.WriteLine($"Discrepency found comparing P at index ({i},{j}) with tolerance {tolerance}: bonsaiOutput = {bonsaiOutput.P[i, j]}, pythonOutput = {pythonOutput.P[i, j]}, originalOutput = {originalOutput.P[i, j]}."); return false; } } diff --git a/tests/Bonsai.ML.Lds.Torch.Tests/Bonsai.ML.Lds.Torch.Tests.csproj b/tests/Bonsai.ML.Lds.Torch.Tests/Bonsai.ML.Lds.Torch.Tests.csproj new file mode 100644 index 00000000..9893cf36 --- /dev/null +++ b/tests/Bonsai.ML.Lds.Torch.Tests/Bonsai.ML.Lds.Torch.Tests.csproj @@ -0,0 +1,34 @@ + + + + net8.0 + enable + enable + false + true + + + + + + + + + + + + + + + + Always + + + + + + + + + + diff --git a/tests/Bonsai.ML.Lds.Torch.Tests/NeuralLatentsTest.bonsai b/tests/Bonsai.ML.Lds.Torch.Tests/NeuralLatentsTest.bonsai new file mode 100644 index 00000000..af3f1d26 --- /dev/null +++ b/tests/Bonsai.ML.Lds.Torch.Tests/NeuralLatentsTest.bonsai @@ -0,0 +1,412 @@ + + + + + + LoadData + + + + + transformed_binned_spikes.pt + true + + + + + + + + + 1 + 0 + + + + + ObservationT + + + + + python_V0_0.pt + true + + + + + + + Covariance + + + + + python_m0_0.pt + true + + + + + + + Mean + + + + + python_Z0.pt + true + + + + + + + MeasurementFunction + + + + + python_R0.pt + true + + + + + + + MeasurementNoiseCovariance + + + + + python_B0.pt + true + + + + + + + TransitionMatrix + + + + + python_Q0.pt + true + + + + + + + ProcessNoiseCovariance + + + + + + + + + + + + + + + + + + + + + + + + LoadModel + + + + TransitionMatrix + + + MeasurementFunction + + + ProcessNoiseCovariance + + + MeasurementNoiseCovariance + + + Mean + + + Covariance + + + + + + + + + + + + + + + + + Float64 + + + + + + + + + + + + + Float64 + 2 + 2 + + + + + + + + + + KalmanFilterModel + + + + + + + + + + + + + + + + + + + + LearnParameters + + + + ObservationT + + + KalmanFilterModel + + + Parameters + + + + + + + + + + 1 + 0.0001 + true + true + true + true + true + true + true + + + + Parameters + + + KalmanFilterModel + + + + + + + + + + + ExpectationMaximizationResult + + + KalmanFilterModel + + + + + + + + + + + + + + + + + + + + + + Smoother + + + + ObservationT + + + KalmanFilterModel + + + + + + + + + + + UpdatedFilteredResult + + + ExpectationMaximizationResult + + + + + + UpdatedFilteredResult + + + KalmanFilterModel + + + + + + + + + + + UpdatedSmoothedResult + + + UpdatedSmoothedResult + + + KalmanFilterModel + + + + + + + + + + + OrthogonalizedResult + + + OrthogonalizedResult + + + Mean + + + + bonsai_means.pt + true + + + + OrthogonalizedResult + + + Covariance + + + + bonsai_covs.pt + true + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 1 + + + + + + + + + + \ No newline at end of file diff --git a/tests/Bonsai.ML.Lds.Torch.Tests/NeuralLatentsTest.cs b/tests/Bonsai.ML.Lds.Torch.Tests/NeuralLatentsTest.cs new file mode 100644 index 00000000..c641fe93 --- /dev/null +++ b/tests/Bonsai.ML.Lds.Torch.Tests/NeuralLatentsTest.cs @@ -0,0 +1,110 @@ +using Newtonsoft.Json; +using System; +using System.Diagnostics; +using System.IO; +using System.IO.Compression; +using System.Net.Http; +using System.Reactive.Linq; +using System.Runtime.InteropServices; +using System.Threading.Tasks; +using Bonsai.ML.Tests.Utilities; +using static TorchSharp.torch; +using TorchSharp; +using System.Text; + +namespace Bonsai.ML.Lds.Torch.Tests; + +/// +/// Tests for the neural latents workflow. +/// +[TestClass] +public class NeuralLatentsTest +{ + private readonly string basePath = Path.Combine(AppDomain.CurrentDomain.BaseDirectory); + + private static void DownloadData(string basePath) + { + string zipFileUrl = "https://zenodo.org/records/17427805/files/Bonsai.ML.Lds.Torch.Tests.zip"; + + try + { + byte[] responseBytes; + using (var httpClient = new HttpClient()) + { + httpClient.DefaultRequestHeaders.Add("User-Agent", "Bonsai.ML.Tests"); + responseBytes = httpClient.GetByteArrayAsync(zipFileUrl).Result; + Console.WriteLine("File downloaded successfully."); + } + + using MemoryStream zipStream = new(responseBytes); + using ZipArchive zip = new(zipStream, ZipArchiveMode.Read); + zip.ExtractToDirectory(basePath); + Console.WriteLine("File extracted successfully."); + } + catch (Exception ex) + { + Console.WriteLine($"An error occurred: {ex.Message}"); + } + } + + private static async Task RunBonsaiWorkflow(string basePath) + { + Console.WriteLine($"Running Bonsai workflow..."); + var currentDirectory = Environment.CurrentDirectory; + Environment.CurrentDirectory = basePath; + + try + { + var workflowPath = Path.Combine(basePath, "NeuralLatentsTest.bonsai"); + await WorkflowHelper.RunWorkflow( + workflowPath); + Console.WriteLine("Run bonsai workflow finished."); + } + finally { Environment.CurrentDirectory = currentDirectory; } + } + + /// + /// Setup for the test. + /// + [TestInitialize] + [DeploymentItem("NeuralLatentsTest.bonsai")] + public async Task TestSetup() + { + Directory.CreateDirectory(basePath); + DownloadData(basePath); + await RunBonsaiWorkflow(basePath); + } + + /// + /// Cleanup files generated for test. + /// + [TestCleanup] + public void TestCleanup() + { + var ptFiles = Directory.GetFiles(basePath, "*.pt"); + var zipFiles = Directory.GetFiles(basePath, "*.zip"); + foreach (var file in ptFiles) File.Delete(file); + foreach (var file in zipFiles) File.Delete(file); + } + + /// + /// Compares the results from the Python script and the Bonsai workflow. + /// + [TestMethod] + public void CompareTensorData() + { + var bonsaiMeansFileName = Path.Combine(basePath, "bonsai_means.pt"); + var bonsaiCovariancesFileName = Path.Combine(basePath, "bonsai_covs.pt"); + + var pythonMeansFileName = Path.Combine(basePath, "python_means.pt"); + var pythonCovariancesFileName = Path.Combine(basePath, "python_covs.pt"); + + var bonsaiMeans = Tensor.Load(bonsaiMeansFileName); + var bonsaiCovariances = Tensor.Load(bonsaiCovariancesFileName); + var pythonMeans = Tensor.Load(pythonMeansFileName).permute(1, 0); + var pythonCovariances = Tensor.Load(pythonCovariancesFileName).permute(2, 0, 1); + + Assert.IsTrue(allclose(bonsaiMeans, pythonMeans)); + Assert.IsTrue(allclose(bonsaiCovariances, pythonCovariances)); + } +} diff --git a/tests/Bonsai.ML.Lds.Torch.Tests/bootstrap_test_environment.py b/tests/Bonsai.ML.Lds.Torch.Tests/bootstrap_test_environment.py new file mode 100644 index 00000000..08beff7c --- /dev/null +++ b/tests/Bonsai.ML.Lds.Torch.Tests/bootstrap_test_environment.py @@ -0,0 +1,94 @@ +import sys +import os +import subprocess +import argparse + +def get_venv_path(): + return os.path.join(os.path.dirname(os.path.realpath(__file__)), '.venv') + +def get_pip_path(venv_path: str = None): + if venv_path is None: + venv_path = get_venv_path() + if sys.platform.startswith('linux'): + return os.path.join(venv_path, 'bin', 'pip') + else: + return os.path.join(venv_path, 'Scripts', 'pip.exe') + +def get_python_path(venv_path: str = None): + if venv_path is None: + venv_path = get_venv_path() + if sys.platform.startswith('linux'): + return os.path.join(venv_path, 'bin', 'python') + else: + return os.path.join(venv_path, 'Scripts', 'python.exe') + +def get_base_dir(base_dir: str = None): + # function to get the base directory + if base_dir is not None: + return base_dir + try: + return os.path.dirname(os.path.realpath(__file__)) + except: + return os.getcwd() + +def create_venv(parent_dir: str = None): + # function to create a virtual environment + if parent_dir is None: + parent_dir = os.path.dirname(os.path.realpath(__file__)) + venv_path = os.path.join(parent_dir, ".venv") + subprocess.check_call([sys.executable, "-m", "venv", venv_path]) + return venv_path + +def activate_venv(venv_path: str = None): + # function to activate the virtual environment + if venv_path is None: + venv_path = get_venv_path() + if sys.platform.startswith('linux'): + bin_path = os.path.join(venv_path, 'bin') + os.environ["PATH"] = os.pathsep.join([bin_path, *os.environ.get("PATH", "").split(os.pathsep)]) + sys.path.insert(0, os.path.join(venv_path, 'lib', f'python{sys.version_info.major}.{sys.version_info.minor}', 'site-packages')) + else: + bin_path = os.path.join(venv_path, 'Scripts') + os.environ["PATH"] = os.pathsep.join([bin_path, *os.environ.get("PATH", "").split(os.pathsep)]) + sys.path.insert(0, os.path.join(venv_path, 'Lib', 'site-packages')) + +def install(venv_path: str = None, pip_args: list[str] = None): + # function to install pip packages into a virtual environment + if venv_path is None: + venv_path = get_venv_path() + pip_path = get_pip_path(venv_path) + if pip_args is None: + raise ValueError("pip_args must be provided") + subprocess.check_call([pip_path, "install", *pip_args]) + +def install_requirements(requirements_file: str, venv_path: str = None): + # function to install pip packages from a requirements file into a virtual environment + if venv_path is None: + venv_path = get_venv_path() + pip_path = get_pip_path(venv_path) + subprocess.check_call([pip_path, "install", "-r", requirements_file]) + +parser = argparse.ArgumentParser() +parser.add_argument("base_dir", type=str, default=None) +args = parser.parse_args() + +base_dir = get_base_dir(args.base_dir) +venv_path = create_venv(base_dir) +activate_venv(venv_path) + +install(venv_path, ["--no-cache-dir", "torch"]) +install(venv_path, ["--no-cache-dir", "plotly"]) +install(venv_path, ["--no-cache-dir", "remfile"]) +install(venv_path, ["--no-cache-dir", "dandi"]) +install(venv_path, ["--no-cache-dir", "ssm@git+https://github.com/ncguilbeault/lds_python@75e3e5e92ce6344009b62a5034db49b238db63ef"]) + +python_path = get_python_path(venv_path) + +script_path = os.path.join(base_dir, "estimate_neural_latents.py") +process = subprocess.Popen([python_path, script_path, base_dir]) +return_code = process.wait() + +if return_code == 0: + print("Script completed successfully.") +else: + print(f"Script exited with errors. Return code: {return_code}") \ No newline at end of file diff --git a/tests/Bonsai.ML.Lds.Torch.Tests/estimate_neural_latents.py b/tests/Bonsai.ML.Lds.Torch.Tests/estimate_neural_latents.py new file mode 100644 index 00000000..9dc3666d --- /dev/null +++ b/tests/Bonsai.ML.Lds.Torch.Tests/estimate_neural_latents.py @@ -0,0 +1,112 @@ +import numpy as np +import remfile, h5py + +from dandi.dandiapi import DandiAPIClient +from pynwb import NWBHDF5IO + +import ssm.inference +import ssm.learning +import ssm.neural_latents.utils +import ssm.neural_latents.plotting + +import argparse +import os + +# Parse arguments +try: + parser = argparse.ArgumentParser() + parser.add_argument("base_dir", type=str, default=None) + args = parser.parse_args() + + base_dir = args.base_dir +except: + base_dir = os.path.realpath(os.path.dirname(__file__)) + +# data +dandiset_ID = "000140" +dandi_filepath = "sub-Jenkins/sub-Jenkins_ses-small_desc-train_behavior+ecephys.nwb" +bin_size = 0.02 + +# model +n_latents = 10 + +# estimation initial conditions +sigma_B = 0.1 +sigma_Z = 0.1 +sigma_Q = 0.1 +sigma_R = 0.1 +sigma_m0 = 0.1 +sigma_V0 = 0.1 + +# estimation parameters +max_iter = 1 +tol = 0.1 +vars_to_estimate = {"B": True, "Q": True, "Z": True, "R": True, + "m0": True, "V0": True, } + +with DandiAPIClient() as client: + asset = client.get_dandiset(dandiset_ID, + "draft").get_asset_by_path(dandi_filepath) + s3_path = asset.get_content_url(follow_redirects=1, strip_query=True) + cache = remfile.DiskCache("./remfile_cache") + rf = remfile.File(s3_path, disk_cache=cache) + with h5py.File(rf, "r") as h: + with NWBHDF5IO(file=h, mode="r") as io: + nwbfile = io.read() + units_df = nwbfile.units.to_dataframe() + trials_df = nwbfile.intervals["trials"].to_dataframe() + +# n_clusters +n_clusters = units_df.shape[0] + +# continuous spikes times +continuous_spikes_times = [None for n in range(n_clusters)] +for n in range(n_clusters): + continuous_spikes_times[n] = units_df.iloc[n]['spike_times'] + +binned_spikes, bin_edges = ssm.neural_latents.utils.bin_spike_times( + spike_times=continuous_spikes_times, bin_size=bin_size) +bin_centers = (bin_edges[1:] + bin_edges[:-1])/2 +transformed_binned_spikes = np.sqrt(binned_spikes + 0.5) + +transformed_binned_spikes.astype(float).tofile(os.path.join(base_dir, "transformed_binned_spikes.bin")) + +np.random.seed(0) + +B0 = np.diag(np.random.normal(loc=0, scale=sigma_B, size=n_latents)) +Z0 = np.random.normal(loc=0, scale=sigma_Z, size=(n_clusters, n_latents)) +Q0 = np.diag(np.abs(np.random.normal(loc=0, scale=sigma_Q, size=n_latents))) +R0 = np.diag(np.abs(np.random.normal(loc=0, scale=sigma_R, size=n_clusters))) +m0_0 = np.random.normal(loc=0, scale=sigma_m0, size=n_latents) +V0_0 = np.diag(np.abs(np.random.normal(loc=0, scale=sigma_V0, size=n_latents))) + +# Save initial parameters to binary file +B0.astype(float).tofile(os.path.join(base_dir, "python_B0.bin")) +Z0.astype(float).tofile(os.path.join(base_dir, "python_Z0.bin")) +Q0.astype(float).tofile(os.path.join(base_dir, "python_Q0.bin")) +R0.astype(float).tofile(os.path.join(base_dir, "python_R0.bin")) +m0_0.astype(float).tofile(os.path.join(base_dir, "python_m0_0.bin")) +V0_0.astype(float).tofile(os.path.join(base_dir, "python_V0_0.bin")) + +optim_res = ssm.learning.em_SS_LDS( + y=transformed_binned_spikes, B0=B0, Q0=Q0, Z0=Z0, R0=R0, + m0_0=m0_0, V0_0=V0_0, max_iter=max_iter, tol=tol, + vars_to_estimate=vars_to_estimate, +) + +filter_res = ssm.inference.filterLDS_SS_withMissingValues_np( + y=transformed_binned_spikes, B=optim_res["B"], Q=optim_res["Q"], + m0=optim_res["m0"], V0=optim_res["V0"], Z=optim_res["Z"], R=optim_res["R"]) + +smoothing_res = ssm.inference.smoothLDS_SS( + B=optim_res["B"], xnn=filter_res["xnn"], Pnn=filter_res["Pnn"], + xnn1=filter_res["xnn1"], Pnn1=filter_res["Pnn1"], + m0=optim_res["m0"], V0=optim_res["V0"]) + +o_means, o_covs = ssm.neural_latents.utils.ortogonalizeMeansAndCovs( + means=smoothing_res["xnN"], + covs=smoothing_res["PnN"], Z=optim_res["Z"]) + +# save outputs to binary file +o_means.astype(float).tofile(os.path.join(base_dir, "python_means.bin")) +o_covs.astype(float).tofile(os.path.join(base_dir, "python_covs.bin"))