From 12f5a9019f11721f006a09379b30faafda73068f Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 12 Sep 2025 11:17:11 +0100 Subject: [PATCH 01/92] Added kalman filter neural latents package to the Bonsai.ML project --- Bonsai.ML.sln | 6 ++++++ src/Bonsai.ML.Torch.LDS/Bonsai.ML.Torch.LDS.csproj | 10 ++++++++++ src/Bonsai.ML.Torch.LDS/Properties/launchSettings.json | 10 ++++++++++ 3 files changed, 26 insertions(+) create mode 100644 src/Bonsai.ML.Torch.LDS/Bonsai.ML.Torch.LDS.csproj create mode 100644 src/Bonsai.ML.Torch.LDS/Properties/launchSettings.json diff --git a/Bonsai.ML.sln b/Bonsai.ML.sln index 62e44269..e2b989d4 100644 --- a/Bonsai.ML.sln +++ b/Bonsai.ML.sln @@ -40,6 +40,8 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "build", "build", "{DEE5DD87 EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Bonsai.ML.Tests.Utilities", "tests\Bonsai.ML.Tests.Utilities\Bonsai.ML.Tests.Utilities.csproj", "{DB1090D3-38DD-404B-96DF-66BF3C39E508}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Bonsai.ML.Torch.LDS", "src\Bonsai.ML.Torch.LDS\Bonsai.ML.Torch.LDS.csproj", "{41D4BEC7-AB1F-41E4-95FE-4DB23970FF4B}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -102,6 +104,10 @@ Global {DB1090D3-38DD-404B-96DF-66BF3C39E508}.Debug|Any CPU.Build.0 = Debug|Any CPU {DB1090D3-38DD-404B-96DF-66BF3C39E508}.Release|Any CPU.ActiveCfg = Release|Any CPU {DB1090D3-38DD-404B-96DF-66BF3C39E508}.Release|Any CPU.Build.0 = Release|Any CPU + {41D4BEC7-AB1F-41E4-95FE-4DB23970FF4B}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {41D4BEC7-AB1F-41E4-95FE-4DB23970FF4B}.Debug|Any CPU.Build.0 = Debug|Any CPU + {41D4BEC7-AB1F-41E4-95FE-4DB23970FF4B}.Release|Any CPU.ActiveCfg = Release|Any CPU + {41D4BEC7-AB1F-41E4-95FE-4DB23970FF4B}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE diff --git a/src/Bonsai.ML.Torch.LDS/Bonsai.ML.Torch.LDS.csproj b/src/Bonsai.ML.Torch.LDS/Bonsai.ML.Torch.LDS.csproj new file mode 100644 index 00000000..ec1fb9cb --- /dev/null +++ b/src/Bonsai.ML.Torch.LDS/Bonsai.ML.Torch.LDS.csproj @@ -0,0 +1,10 @@ + + + Bonsai.ML.Torch.LDS Bonsai library. + Bonsai Rx Bonsai ML Torch TorchSharp LDS LinearDynamicalSystems + net472 + + + + + \ No newline at end of file diff --git a/src/Bonsai.ML.Torch.LDS/Properties/launchSettings.json b/src/Bonsai.ML.Torch.LDS/Properties/launchSettings.json new file mode 100644 index 00000000..4af4f468 --- /dev/null +++ b/src/Bonsai.ML.Torch.LDS/Properties/launchSettings.json @@ -0,0 +1,10 @@ +{ + "profiles": { + "Bonsai": { + "commandName": "Executable", + "executablePath": "$(BonsaiExecutablePath)", + "commandLineArgs": "--lib:\"$(TargetDir).\"", + "nativeDebugging": true + } + } +} \ No newline at end of file From 79d48e3feb50ab3d8bf40795b071191424e0fb5c Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 17 Sep 2025 10:39:32 +0100 Subject: [PATCH 02/92] Added `Bonsai.ML.Torch` project to references --- src/Bonsai.ML.Torch.LDS/Bonsai.ML.Torch.LDS.csproj | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/Bonsai.ML.Torch.LDS/Bonsai.ML.Torch.LDS.csproj b/src/Bonsai.ML.Torch.LDS/Bonsai.ML.Torch.LDS.csproj index ec1fb9cb..7f755850 100644 --- a/src/Bonsai.ML.Torch.LDS/Bonsai.ML.Torch.LDS.csproj +++ b/src/Bonsai.ML.Torch.LDS/Bonsai.ML.Torch.LDS.csproj @@ -7,4 +7,7 @@ + + + \ No newline at end of file From dd273a8357844a67ac8897f51c43c2b6bc4081f7 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 17 Sep 2025 10:45:40 +0100 Subject: [PATCH 03/92] Added KalmanFilter class --- src/Bonsai.ML.Torch.LDS/KalmanFilter.cs | 647 ++++++++++++++++++++++++ 1 file changed, 647 insertions(+) create mode 100644 src/Bonsai.ML.Torch.LDS/KalmanFilter.cs diff --git a/src/Bonsai.ML.Torch.LDS/KalmanFilter.cs b/src/Bonsai.ML.Torch.LDS/KalmanFilter.cs new file mode 100644 index 00000000..e783f196 --- /dev/null +++ b/src/Bonsai.ML.Torch.LDS/KalmanFilter.cs @@ -0,0 +1,647 @@ +using System; +using TorchSharp; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.LDS; + +internal class KalmanFilter : nn.Module +{ + private readonly Tensor _transitionMatrix; + private readonly Tensor _measurementFunction; + private readonly Tensor _initialState; + private readonly Tensor _initialCovariance; + private readonly Tensor _processNoiseCovariance; + private readonly Tensor _measurementNoiseCovariance; + private readonly Tensor _identityStates; + private readonly Tensor _identityObservations; + private readonly Tensor _state; + private readonly Tensor _covariance; + private readonly int _numStates; + private readonly int _numObservations; + private readonly Device _device; + private readonly ScalarType _scalarType; + + public KalmanFilterParameters Parameters + { + get + { + return new KalmanFilterParameters( + _transitionMatrix, + _measurementFunction, + _processNoiseCovariance, + _measurementNoiseCovariance, + _initialState, + _initialCovariance + ); + } + } + + public KalmanFilter( + KalmanFilterParameters parameters, + Device device = null, + ScalarType scalarType = ScalarType.Float32) : base("KalmanFilter") + { + device ??= CPU; + + _device = device; + _scalarType = scalarType; + + var transitionMatrix = parameters.TransitionMatrix; + var measurementFunction = parameters.MeasurementFunction; + var initialState = parameters.InitialState; + var initialCovariance = parameters.InitialCovariance; + var processNoiseCovariance = parameters.ProcessNoiseCovariance; + var measurementNoiseCovariance = parameters.MeasurementNoiseCovariance; + + if (transitionMatrix is null) + { + throw new ArgumentException("Transition matrix cannot be null."); + } + else + { + if (transitionMatrix.Dimensions != 2 || + transitionMatrix.size(0) != transitionMatrix.size(1)) + { + throw new ArgumentException("Transition matrix must be square."); + } + _transitionMatrix = transitionMatrix.clone().to_type(_scalarType).requires_grad_(false); + _numStates = (int)transitionMatrix.size(0); + _identityStates = eye(_numStates, dtype: _scalarType, device: _device); + } + + if (measurementFunction is null) + { + throw new ArgumentException("Measurement function cannot be null."); + } + else + { + if (measurementFunction.Dimensions != 2 || + measurementFunction.size(1) != _numStates) + { + throw new ArgumentException("Observation matrix must have dimensions [numObservations, numStates]."); + } + _measurementFunction = measurementFunction.clone().to_type(_scalarType).requires_grad_(false); + _numObservations = (int)measurementFunction.size(0); + _identityObservations = eye(_numObservations, dtype: _scalarType, device: _device); + } + + if (initialState is null) + { + throw new ArgumentException("Initial state cannot be null."); + } + else + { + if (initialState.NumberOfElements != _numStates) + { + throw new ArgumentException("Initial state must be a vector with length equal to the number of states."); + } + _initialState = initialState.clone().to_type(_scalarType).requires_grad_(false); + } + + if (initialCovariance is null) + { + throw new ArgumentException("Initial covariance cannot be null."); + } + else + { + if (initialCovariance.Dimensions != 2 || + initialCovariance.size(0) != _numStates || + initialCovariance.size(1) != _numStates) + { + throw new ArgumentException("Initial covariance must be square with dimensions equal to the number of states."); + } + _initialCovariance = initialCovariance.clone().to_type(_scalarType).requires_grad_(false); + } + + if (processNoiseCovariance is null) + { + throw new ArgumentException("Process noise covariance cannot be null."); + } + else + { + if (processNoiseCovariance.Dimensions != 2 || + processNoiseCovariance.size(0) != _numStates || + processNoiseCovariance.size(0) != _numStates) + { + throw new ArgumentException("Process noise covariance must be square with dimensions equal to the number of states."); + } + _processNoiseCovariance = processNoiseCovariance.clone().to_type(_scalarType).requires_grad_(false); + } + + if (measurementNoiseCovariance is null) + { + throw new ArgumentException("Measurement noise covariance cannot be null."); + } + else + { + if (measurementNoiseCovariance.Dimensions != 2 || + measurementNoiseCovariance.size(0) != _numObservations || + measurementNoiseCovariance.size(1) != _numObservations) + { + throw new ArgumentException("Measurement noise variance must be a scalar."); + } + _measurementNoiseCovariance = measurementNoiseCovariance.clone().to_type(_scalarType).requires_grad_(false); + } + + _state = _initialState.clone(); + _covariance = _initialCovariance.clone(); + } + + public KalmanFilter( + int numStates, + int numObservations, + Tensor transitionMatrix = null, + Tensor measurementFunction = null, + Tensor initialState = null, + Tensor initialCovariance = null, + Tensor processNoiseVariance = null, + Tensor measurementNoiseVariance = null, + Device device = null, + ScalarType scalarType = ScalarType.Float32) : base("KalmanFilter") + { + device ??= CPU; + + _device = device; + _scalarType = scalarType; + _numStates = numStates; + _identityStates = eye(numStates, dtype: _scalarType, device: _device); + _identityObservations = eye(numObservations, dtype: _scalarType, device: _device); + + if (transitionMatrix is null) + { + _transitionMatrix = eye(numStates, dtype: _scalarType, device: _device); + } + else + { + if (transitionMatrix.Dimensions != 2 || + transitionMatrix.shape[0] != numStates || + transitionMatrix.shape[1] != numStates) + { + throw new ArgumentException("Transition matrix must be square with dimensions equal to the number of states."); + } + _transitionMatrix = transitionMatrix.clone().to_type(_scalarType).requires_grad_(false); + } + + if (measurementFunction is null) + { + _measurementFunction = eye(numObservations, numStates, dtype: _scalarType, device: _device); + } + else + { + if (measurementFunction.Dimensions != 2 || + measurementFunction.shape[0] != numObservations || + measurementFunction.shape[1] != numStates) + { + throw new ArgumentException("Observation matrix must have dimensions [numObservations, numStates]."); + } + _measurementFunction = measurementFunction.clone().to_type(_scalarType).requires_grad_(false); + } + + if (initialState is null) + { + _initialState = zeros(numStates, dtype: _scalarType, device: _device); + } + else + { + if (initialState.NumberOfElements != numStates) + { + throw new ArgumentException("Initial state must be a vector with length equal to the number of states."); + } + _initialState = initialState.clone().to_type(_scalarType).requires_grad_(false); + } + + if (initialCovariance is null) + { + _initialCovariance = eye(numStates, dtype: _scalarType, device: _device); + } + else + { + if (initialCovariance.Dimensions != 2 || + initialCovariance.shape[0] != numStates || + initialCovariance.shape[1] != numStates) + { + throw new ArgumentException("Initial covariance must be square with dimensions equal to the number of states."); + } + _initialCovariance = initialCovariance.clone().to_type(_scalarType).requires_grad_(false); + } + + if (processNoiseVariance is null) + { + _processNoiseCovariance = eye(numStates, dtype: _scalarType, device: _device); + } + else + { + if (processNoiseVariance.NumberOfElements != 1) + { + throw new ArgumentException("Process noise variance must be a scalar."); + } + _processNoiseCovariance = (processNoiseVariance * eye(numStates, dtype: _scalarType, device: _device)).requires_grad_(false); + } + + if (measurementNoiseVariance is null) + { + _measurementNoiseCovariance = eye(numObservations, dtype: _scalarType, device: _device); + } + else + { + if (measurementNoiseVariance.NumberOfElements != 1) + { + throw new ArgumentException("Measurement noise variance must be a scalar."); + } + _measurementNoiseCovariance = (measurementNoiseVariance * eye(numObservations, dtype: _scalarType, device: _device)).requires_grad_(false); + } + + _state = _initialState.clone(); + _covariance = _initialCovariance.clone(); + + RegisterComponents(); + } + + private struct PredictedResult(Tensor predictedState, Tensor predictedCovariance) + { + public Tensor PredictedState = predictedState; + public Tensor PredictedCovariance = predictedCovariance; + } + + private PredictedResult FilterPredict( + Tensor state, + Tensor covariance + ) + { + var predictedState = _transitionMatrix.matmul(state); + var predictedCovariance = EnsureSymmetric(_transitionMatrix.matmul(covariance) + .matmul(_transitionMatrix.mT) + + _processNoiseCovariance); + + return new PredictedResult(predictedState, predictedCovariance); + } + + private struct UpdatedResult( + Tensor updatedState, + Tensor updatedCovariance, + Tensor innovation, + Tensor innovationCovariance, + Tensor kalmanGain + ) + { + public Tensor UpdatedState = updatedState; + public Tensor UpdatedCovariance = updatedCovariance; + public Tensor Innovation = innovation; + public Tensor InnovationCovariance = innovationCovariance; + public Tensor KalmanGain = kalmanGain; + } + + private UpdatedResult FilterUpdate( + Tensor predictedState, + Tensor predictedCovariance, + Tensor observation + ) + { + // Innovation step + var innovation = observation - _measurementFunction.matmul(predictedState); + var innovationCovariance = EnsureSymmetric( + _measurementFunction.matmul(predictedCovariance) + .matmul(_measurementFunction.mT) + + _measurementNoiseCovariance); + + // Kalman gain + var kalmanGain = InverseCholesky( + predictedCovariance.matmul(_measurementFunction.mT), + innovationCovariance); + + // Update step + var updatedState = predictedState + kalmanGain.matmul(innovation); + var updatedCovariance = EnsureSymmetric(predictedCovariance + - kalmanGain.matmul(_measurementFunction) + .matmul(predictedCovariance)); + + return new UpdatedResult( + updatedState, + updatedCovariance, + innovation, + innovationCovariance, + kalmanGain); + } + + public FilteredResult Filter(Tensor observation) + { + var obs = observation.atleast_2d(); + + var timeBins = obs.size(0); + var logLikelihood = empty(timeBins, dtype: _scalarType, device: _device); + var predictedState = empty(new long[] { timeBins, _numStates }, dtype: _scalarType, device: _device); + var predictedCovariance = empty(new long[] { timeBins, _numStates, _numStates }, dtype: _scalarType, device: _device); + var updatedState = empty(new long[] { timeBins, _numStates }, dtype: _scalarType, device: _device); + var updatedCovariance = empty(new long[] { timeBins, _numStates, _numStates }, dtype: _scalarType, device: _device); + var kalmanGain = empty(new long[] { timeBins, _numStates, _numObservations }, dtype: _scalarType, device: _device); + + for (long time = 0; time < timeBins; time++) + { + using (var d = NewDisposeScope()) + { + // Predict + var prediction = FilterPredict(_state, _covariance); + + // Update + var update = FilterUpdate( + prediction.PredictedState, + prediction.PredictedCovariance, + obs[time] + ); + + // Log Likelihood + var invInnovationCov = InverseCholesky(_identityObservations, update.InnovationCovariance); + var logLikelihoodData = -1.0 * (slogdet(update.InnovationCovariance).Item2 + + update.Innovation.T.matmul(invInnovationCov) + .matmul(update.Innovation)); + + logLikelihoodData.DetachFromDisposeScope(); + prediction.PredictedState.DetachFromDisposeScope(); + prediction.PredictedCovariance.DetachFromDisposeScope(); + update.UpdatedState.DetachFromDisposeScope(); + update.UpdatedCovariance.DetachFromDisposeScope(); + update.KalmanGain.DetachFromDisposeScope(); + + logLikelihood[time] = logLikelihoodData; + predictedState[time] = prediction.PredictedState; + predictedCovariance[time] = prediction.PredictedCovariance; + updatedState[time] = update.UpdatedState; + updatedCovariance[time] = update.UpdatedCovariance; + kalmanGain[time] = update.KalmanGain; + + _state.set_(update.UpdatedState); + _covariance.set_(update.UpdatedCovariance); + } + } + + var filteredResult = new FilteredResult( + predictedState, + predictedCovariance, + updatedState, + updatedCovariance, + logLikelihood, + kalmanGain + ); + + return filteredResult; + } + + public SmoothedResult Smooth(FilteredResult filteredResult) + { + var predictedState = filteredResult.PredictedState; + var predictedCovariance = filteredResult.PredictedCovariance; + var updatedState = filteredResult.UpdatedState; + var updatedCovariance = filteredResult.UpdatedCovariance; + var kalmanGain = filteredResult.KalmanGain; + + var timeBins = predictedState.size(0); + var smoothedState = empty_like(updatedState); + var smoothedCovariance = empty_like(updatedCovariance); + var smoothedLagOneCovariance = empty(new long[] { timeBins, _numStates, _numStates }, dtype: _scalarType, device: _device); + + // Fix the last time point + smoothedState[-1] = updatedState[-1]; + smoothedCovariance[-1] = updatedCovariance[-1]; + smoothedLagOneCovariance[-1] = (_identityStates - kalmanGain[-1] + .matmul(_measurementFunction)) + .matmul(_transitionMatrix) + .matmul(updatedCovariance[-2]); + + var smoothingGain = empty(new long[] { _numStates, _numStates }, dtype: _scalarType, device: _device); + + // Backward pass + for (long time = timeBins - 2; time >= 0; time--) + { + using (var d = NewDisposeScope()) + { + // Smoothing gain + smoothingGain = updatedCovariance[time].matmul( + InverseCholesky(_transitionMatrix.mT, predictedCovariance[time + 1]) + ).DetachFromDisposeScope(); + + // Smoothed state + smoothedState[time] = updatedState[time] + + smoothingGain.matmul( + (smoothedState[time + 1] - predictedState[time + 1]).unsqueeze(-1) + ).squeeze(-1) + .DetachFromDisposeScope(); + + // Smoothed covariance + smoothedCovariance[time] = EnsureSymmetric( + updatedCovariance[time] + smoothingGain + .matmul(smoothedCovariance[time + 1] - predictedCovariance[time + 1]) + .matmul(smoothingGain.mT) + ).DetachFromDisposeScope(); + + // Compute next smoothing gain for lag one covariance + if (time > 0) + { + var smoothingGainNext = updatedCovariance[time - 1] + .matmul(InverseCholesky(_transitionMatrix.mT, predictedCovariance[time])); + + // Smoothed lag one covariance + + smoothedLagOneCovariance[time] = smoothedCovariance[time] + .matmul(smoothingGainNext.mT) + + smoothingGain.matmul(smoothedLagOneCovariance[time + 1] + - _transitionMatrix.matmul(updatedCovariance[time])) + .matmul(smoothingGainNext.mT) + .DetachFromDisposeScope(); + } + } + } + + // Smoothed initial state + var smoothedInitialState = _initialState + smoothingGain.matmul( + (smoothedState[0] - predictedState[0]).unsqueeze(-1) + ).squeeze(-1); + + // Smoothed initial covariance + var smoothedInitialCovariance = EnsureSymmetric( + _initialCovariance[0] + smoothingGain + .matmul(smoothedCovariance[0] - predictedCovariance[0]) + .matmul(smoothingGain.mT) + ); + + // Smoothing gain at time 0 + var smoothingGain0 = _initialCovariance.matmul( + InverseCholesky(_transitionMatrix.mT, predictedCovariance[0]) + ); + + // Smoothed lag one covariance at time 0 + smoothedLagOneCovariance[0] = smoothedCovariance[0] + .matmul(smoothingGain0.mT) + + smoothingGain.matmul(smoothedLagOneCovariance[1] + - _transitionMatrix.matmul(updatedCovariance[0])) + .matmul(smoothingGain0.mT) + .DetachFromDisposeScope(); + + return new SmoothedResult( + smoothedState, + smoothedCovariance, + smoothedLagOneCovariance, + smoothedInitialState, + smoothedInitialCovariance + ); + } + + public ExpectationMaximizationResult ExpectationMaximization( + Tensor observation, + int maxIterations = 100, + double tolerance = 1e-4, + bool updateParameters = true + ) + { + var timeBins = observation.size(0); + var logLikelihood = empty(maxIterations, dtype: ScalarType.Float32, device: _device); + var previousLogLikelihood = double.NegativeInfinity; + var logLikelihoodConst = -0.5 * timeBins * _numObservations * Math.Log(2 * Math.PI); + var updatedParameters = Parameters; + + for (int iteration = 0; iteration < maxIterations; iteration++) + { + using (var d = NewDisposeScope()) + { + // Filter observations + var filterResult = Filter(observation); + + // Compute log likelihood + var filteredLogLikelihood = logLikelihoodConst + 0.5 * filterResult.LogLikelihood.sum(); + var filteredLogLikelihoodSum = filteredLogLikelihood + .cpu() + .to_type(ScalarType.Float32) + .ReadCpuSingle(0); + + logLikelihood[iteration] = filteredLogLikelihoodSum; + + // Check for convergence + if (filteredLogLikelihoodSum <= previousLogLikelihood) + { + // throw new ArgumentException("Log likelihood decreased, something is wrong! New likelihood: " + filteredLogLikelihoodSum + ", Previous likelihood: " + previousLogLikelihood); + Console.WriteLine("Warning: Log likelihood decreased! New likelihood: " + filteredLogLikelihoodSum + ", Previous likelihood: " + previousLogLikelihood); + break; + } + + if (filteredLogLikelihoodSum - previousLogLikelihood < tolerance) + { + break; + } + + previousLogLikelihood = filteredLogLikelihoodSum; + + // Smooth the filtered results + var smoothedResult = Smooth(filterResult); + + var smoothedState = smoothedResult.SmoothedState; + var smoothedCovariance = smoothedResult.SmoothedCovariance; + var smoothedLagOneCovariance = smoothedResult.SmoothedLagOneCovariance; + + var smoothedInitialState = smoothedResult.SmoothedInitialState; + var smoothedInitialCovariance = smoothedResult.SmoothedInitialCovariance; + + // Sufficient statistics + var Ezzt = smoothedCovariance + einsum("tn,tm->tnm", smoothedState, smoothedState); + var Ezztm1 = smoothedLagOneCovariance[torch.TensorIndex.Slice(1)] + + einsum("tn,tm->tnm", + smoothedState[torch.TensorIndex.Slice(1)], + smoothedState[torch.TensorIndex.Slice(0, -1)]); + + var S00 = Ezzt[torch.TensorIndex.Slice(0, -1)].sum(new long[] { 0 }); + var S10 = Ezztm1.sum(new long[] { 0 }); + var S11 = Ezzt[torch.TensorIndex.Slice(1)].sum(new long[] { 0 }); + + var Syz = einsum("tp,tn->pn", observation, smoothedState); + var Eyy = einsum("tp,tq->pq", observation, observation); + + // Update transition matrix + var updatedTransitionMatrix = InverseCholesky(S10, S00); + + // Update measurement function + var updatedMeasurementFunction = InverseCholesky(Syz, S11); + + // Update process noise covariance + var updatedProcessNoiseCovariance = EnsureSymmetric( + (S11 - InverseCholesky(S10, S00).matmul(S10.T)) + / timeBins + ); + + // Update measurement noise covariance + var CSyzT = updatedMeasurementFunction.matmul(Syz.mT); + var updatedMeasurementNoiseCovariance = EnsureSymmetric( + (Eyy - CSyzT - CSyzT.mT + + updatedMeasurementFunction.matmul(S11) + .matmul(updatedMeasurementFunction.mT)) + / timeBins + ); + + // Update initial state + var updatedInitialState = smoothedInitialState; + + // Update initial covariance + var updatedInitialCovariance = smoothedInitialCovariance; + + updatedTransitionMatrix.DetachFromDisposeScope(); + updatedMeasurementFunction.DetachFromDisposeScope(); + updatedProcessNoiseCovariance.DetachFromDisposeScope(); + updatedMeasurementNoiseCovariance.DetachFromDisposeScope(); + updatedInitialState.DetachFromDisposeScope(); + updatedInitialCovariance.DetachFromDisposeScope(); + + updatedParameters = new KalmanFilterParameters( + updatedTransitionMatrix, + updatedMeasurementFunction, + updatedProcessNoiseCovariance, + updatedMeasurementNoiseCovariance, + updatedInitialState, + updatedInitialCovariance + ); + + if (updateParameters) + { + UpdateParameters(updatedParameters); + } + } + + } + + return new ExpectationMaximizationResult( + logLikelihood.DetachFromDisposeScope(), + updatedParameters + ); + } + + public OrthogonalizedResult OrthogonalizeStateAndCovariance(Tensor state, Tensor covariance) + { + var (U, S, Vt) = linalg.svd(_measurementFunction); + var SVt = diag(S).matmul(Vt); + + var orthogonalizedState = einsum("tk,kj->tj", state, SVt.mT); + + var auxilary = einsum("ik,tkj->tij", SVt, covariance); + var orthogonalizedCovariance = einsum("tij,jk->tik", auxilary, SVt.mT); + + return new OrthogonalizedResult(orthogonalizedState, orthogonalizedCovariance); + } + + public void UpdateParameters(KalmanFilterParameters updatedParameters) + { + _transitionMatrix.set_(updatedParameters.TransitionMatrix); + _measurementFunction.set_(updatedParameters.MeasurementFunction); + _processNoiseCovariance.set_(updatedParameters.ProcessNoiseCovariance); + _measurementNoiseCovariance.set_(updatedParameters.MeasurementNoiseCovariance); + _initialState.set_(updatedParameters.InitialState); + _initialCovariance.set_(updatedParameters.InitialCovariance); + } + + private static Tensor EnsureSymmetric(Tensor M) + { + return 0.5f * (M + M.transpose(0, 1)); + } + + private static Tensor InverseCholesky(Tensor B, Tensor A) + { + var L = linalg.cholesky(A); + var solT = cholesky_solve(B.transpose(0, 1), L); + return solT.transpose(0, 1); + } +} From 18bf1626506fd6ef9b704f3f252098bcbbb034e4 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 17 Sep 2025 12:04:04 +0100 Subject: [PATCH 04/92] Updated KalmanFilter class for better matrix, vector, and scalar validation during initialization --- src/Bonsai.ML.Torch.LDS/KalmanFilter.cs | 654 +++++++++--------------- 1 file changed, 235 insertions(+), 419 deletions(-) diff --git a/src/Bonsai.ML.Torch.LDS/KalmanFilter.cs b/src/Bonsai.ML.Torch.LDS/KalmanFilter.cs index e783f196..21173eef 100644 --- a/src/Bonsai.ML.Torch.LDS/KalmanFilter.cs +++ b/src/Bonsai.ML.Torch.LDS/KalmanFilter.cs @@ -21,127 +21,32 @@ internal class KalmanFilter : nn.Module private readonly Device _device; private readonly ScalarType _scalarType; - public KalmanFilterParameters Parameters - { - get - { - return new KalmanFilterParameters( - _transitionMatrix, - _measurementFunction, - _processNoiseCovariance, - _measurementNoiseCovariance, - _initialState, - _initialCovariance - ); - } - } + public KalmanFilterParameters Parameters => new( + _transitionMatrix, + _measurementFunction, + _processNoiseCovariance, + _measurementNoiseCovariance, + _initialState, + _initialCovariance + ); public KalmanFilter( KalmanFilterParameters parameters, Device device = null, ScalarType scalarType = ScalarType.Float32) : base("KalmanFilter") { - device ??= CPU; - - _device = device; + _device = device ?? CPU; _scalarType = scalarType; - var transitionMatrix = parameters.TransitionMatrix; - var measurementFunction = parameters.MeasurementFunction; - var initialState = parameters.InitialState; - var initialCovariance = parameters.InitialCovariance; - var processNoiseCovariance = parameters.ProcessNoiseCovariance; - var measurementNoiseCovariance = parameters.MeasurementNoiseCovariance; - - if (transitionMatrix is null) - { - throw new ArgumentException("Transition matrix cannot be null."); - } - else - { - if (transitionMatrix.Dimensions != 2 || - transitionMatrix.size(0) != transitionMatrix.size(1)) - { - throw new ArgumentException("Transition matrix must be square."); - } - _transitionMatrix = transitionMatrix.clone().to_type(_scalarType).requires_grad_(false); - _numStates = (int)transitionMatrix.size(0); - _identityStates = eye(_numStates, dtype: _scalarType, device: _device); - } - - if (measurementFunction is null) - { - throw new ArgumentException("Measurement function cannot be null."); - } - else - { - if (measurementFunction.Dimensions != 2 || - measurementFunction.size(1) != _numStates) - { - throw new ArgumentException("Observation matrix must have dimensions [numObservations, numStates]."); - } - _measurementFunction = measurementFunction.clone().to_type(_scalarType).requires_grad_(false); - _numObservations = (int)measurementFunction.size(0); - _identityObservations = eye(_numObservations, dtype: _scalarType, device: _device); - } + ValidateAndSetMatrix(parameters.TransitionMatrix, "Transition matrix", _scalarType, _device, out _transitionMatrix, out _numStates, out _, isSquare: true); + ValidateAndSetMatrix(parameters.MeasurementFunction, "Measurement function", _scalarType, _device, out _measurementFunction, out _numObservations, out _); + ValidateAndSetVector(parameters.InitialState, "Initial state", _scalarType, _device, out _initialState, out _, expectedLength: _numStates); + ValidateAndSetMatrix(parameters.InitialCovariance, "Initial covariance", _scalarType, _device, out _initialCovariance, out _, out _, isSquare: true, expectedDimension1: _numStates); + ValidateAndSetMatrix(parameters.ProcessNoiseCovariance, "Process noise covariance", _scalarType, _device, out _processNoiseCovariance, out _, out _, isSquare: true, expectedDimension1: _numStates); + ValidateAndSetMatrix(parameters.MeasurementNoiseCovariance, "Measurement noise covariance", _scalarType, _device, out _measurementNoiseCovariance, out _, out _, isSquare: true, expectedDimension1: _numObservations); - if (initialState is null) - { - throw new ArgumentException("Initial state cannot be null."); - } - else - { - if (initialState.NumberOfElements != _numStates) - { - throw new ArgumentException("Initial state must be a vector with length equal to the number of states."); - } - _initialState = initialState.clone().to_type(_scalarType).requires_grad_(false); - } - - if (initialCovariance is null) - { - throw new ArgumentException("Initial covariance cannot be null."); - } - else - { - if (initialCovariance.Dimensions != 2 || - initialCovariance.size(0) != _numStates || - initialCovariance.size(1) != _numStates) - { - throw new ArgumentException("Initial covariance must be square with dimensions equal to the number of states."); - } - _initialCovariance = initialCovariance.clone().to_type(_scalarType).requires_grad_(false); - } - - if (processNoiseCovariance is null) - { - throw new ArgumentException("Process noise covariance cannot be null."); - } - else - { - if (processNoiseCovariance.Dimensions != 2 || - processNoiseCovariance.size(0) != _numStates || - processNoiseCovariance.size(0) != _numStates) - { - throw new ArgumentException("Process noise covariance must be square with dimensions equal to the number of states."); - } - _processNoiseCovariance = processNoiseCovariance.clone().to_type(_scalarType).requires_grad_(false); - } - - if (measurementNoiseCovariance is null) - { - throw new ArgumentException("Measurement noise covariance cannot be null."); - } - else - { - if (measurementNoiseCovariance.Dimensions != 2 || - measurementNoiseCovariance.size(0) != _numObservations || - measurementNoiseCovariance.size(1) != _numObservations) - { - throw new ArgumentException("Measurement noise variance must be a scalar."); - } - _measurementNoiseCovariance = measurementNoiseCovariance.clone().to_type(_scalarType).requires_grad_(false); - } + _identityStates = eye(_numStates, dtype: _scalarType, device: _device); + _identityObservations = eye(_numObservations, dtype: _scalarType, device: _device); _state = _initialState.clone(); _covariance = _initialCovariance.clone(); @@ -159,97 +64,32 @@ public KalmanFilter( Device device = null, ScalarType scalarType = ScalarType.Float32) : base("KalmanFilter") { - device ??= CPU; - - _device = device; + _device = device ?? CPU; _scalarType = scalarType; _numStates = numStates; - _identityStates = eye(numStates, dtype: _scalarType, device: _device); - _identityObservations = eye(numObservations, dtype: _scalarType, device: _device); + _numObservations = numObservations; - if (transitionMatrix is null) - { - _transitionMatrix = eye(numStates, dtype: _scalarType, device: _device); - } - else - { - if (transitionMatrix.Dimensions != 2 || - transitionMatrix.shape[0] != numStates || - transitionMatrix.shape[1] != numStates) - { - throw new ArgumentException("Transition matrix must be square with dimensions equal to the number of states."); - } - _transitionMatrix = transitionMatrix.clone().to_type(_scalarType).requires_grad_(false); - } + _identityStates = eye(_numStates, dtype: _scalarType, device: _device); + _identityObservations = eye(_numObservations, dtype: _scalarType, device: _device); - if (measurementFunction is null) - { - _measurementFunction = eye(numObservations, numStates, dtype: _scalarType, device: _device); - } - else - { - if (measurementFunction.Dimensions != 2 || - measurementFunction.shape[0] != numObservations || - measurementFunction.shape[1] != numStates) - { - throw new ArgumentException("Observation matrix must have dimensions [numObservations, numStates]."); - } - _measurementFunction = measurementFunction.clone().to_type(_scalarType).requires_grad_(false); - } + _transitionMatrix = transitionMatrix?.clone().to_type(_scalarType).to(_device).requires_grad_(false) + ?? eye(_numStates, dtype: _scalarType, device: _device); + ValidateMatrix(_transitionMatrix, "Transition matrix", isSquare: true, expectedDimension1: _numStates); - if (initialState is null) - { - _initialState = zeros(numStates, dtype: _scalarType, device: _device); - } - else - { - if (initialState.NumberOfElements != numStates) - { - throw new ArgumentException("Initial state must be a vector with length equal to the number of states."); - } - _initialState = initialState.clone().to_type(_scalarType).requires_grad_(false); - } + _measurementFunction = measurementFunction?.clone().to_type(_scalarType).to(_device).requires_grad_(false) + ?? eye(_numObservations, _numStates, dtype: _scalarType, device: _device); + ValidateMatrix(_measurementFunction, "Measurement function", expectedDimension1: _numObservations, expectedDimension2: _numStates); - if (initialCovariance is null) - { - _initialCovariance = eye(numStates, dtype: _scalarType, device: _device); - } - else - { - if (initialCovariance.Dimensions != 2 || - initialCovariance.shape[0] != numStates || - initialCovariance.shape[1] != numStates) - { - throw new ArgumentException("Initial covariance must be square with dimensions equal to the number of states."); - } - _initialCovariance = initialCovariance.clone().to_type(_scalarType).requires_grad_(false); - } + _initialState = initialState?.clone().to_type(_scalarType).to(_device).requires_grad_(false) + ?? zeros(_numStates, dtype: _scalarType, device: _device); + ValidateVector(_initialState, "Initial state", _numStates); - if (processNoiseVariance is null) - { - _processNoiseCovariance = eye(numStates, dtype: _scalarType, device: _device); - } - else - { - if (processNoiseVariance.NumberOfElements != 1) - { - throw new ArgumentException("Process noise variance must be a scalar."); - } - _processNoiseCovariance = (processNoiseVariance * eye(numStates, dtype: _scalarType, device: _device)).requires_grad_(false); - } + _initialCovariance = initialCovariance?.clone().to_type(_scalarType).to(_device).requires_grad_(false) + ?? eye(_numStates, dtype: _scalarType, device: _device); + ValidateMatrix(_initialCovariance, "Initial covariance", isSquare: true, expectedDimension1: _numStates); - if (measurementNoiseVariance is null) - { - _measurementNoiseCovariance = eye(numObservations, dtype: _scalarType, device: _device); - } - else - { - if (measurementNoiseVariance.NumberOfElements != 1) - { - throw new ArgumentException("Measurement noise variance must be a scalar."); - } - _measurementNoiseCovariance = (measurementNoiseVariance * eye(numObservations, dtype: _scalarType, device: _device)).requires_grad_(false); - } + _processNoiseCovariance = CreateCovarianceMatrix(processNoiseVariance, _scalarType, _device, numStates, "Process noise variance"); + _measurementNoiseCovariance = CreateCovarianceMatrix(measurementNoiseVariance, _scalarType, _device, numObservations, "Measurement noise variance"); _state = _initialState.clone(); _covariance = _initialCovariance.clone(); @@ -257,52 +97,104 @@ public KalmanFilter( RegisterComponents(); } - private struct PredictedResult(Tensor predictedState, Tensor predictedCovariance) + private static void ValidateAndSetMatrix(Tensor matrix, string name, ScalarType scalarType, Device device, out Tensor result, out int rows, out int columns, bool isSquare = false, int? expectedDimension1 = null, int? expectedDimension2 = null) + { + ValidateMatrix(matrix, name, isSquare, expectedDimension1, expectedDimension2); + result = matrix.clone().to_type(scalarType).to(device).requires_grad_(false); + rows = (int)matrix.size(0); + columns = (int)matrix.size(1); + } + + private static void ValidateAndSetVector(Tensor vector, string name, ScalarType scalarType, Device device, out Tensor result, out int length, int? expectedLength = null) + { + ValidateVector(vector, name, expectedLength); + result = vector.clone().to_type(scalarType).to(device).requires_grad_(false); + length = (int)vector.size(0); + } + + private static void ValidateAndSetScalar(Tensor scalar, string name, ScalarType scalarType, Device device, out Tensor result) + { + ValidateScalar(scalar, name); + result = scalar.clone().squeeze().to_type(scalarType).to(device).requires_grad_(false); + } + + private static void ValidateMatrix(Tensor matrix, string name, bool isSquare = false, int? expectedDimension1 = null, int? expectedDimension2 = null) + { + if (matrix is null) + throw new ArgumentException($"{name} cannot be null."); + + if (matrix.Dimensions != 2) + throw new ArgumentException($"{name} must be 2-dimensional."); + + if (isSquare && matrix.size(0) != matrix.size(1)) + throw new ArgumentException($"{name} must be square."); + + if (expectedDimension1.HasValue && matrix.size(0) != expectedDimension1.Value) + throw new ArgumentException($"{name} must have {expectedDimension1.Value} rows."); + + if (expectedDimension2.HasValue && matrix.size(1) != expectedDimension2.Value) + throw new ArgumentException($"{name} must have {expectedDimension2.Value} columns."); + } + + private static void ValidateVector(Tensor vector, string name, int? expectedLength = null) + { + if (vector is null) + throw new ArgumentException($"{name} cannot be null."); + + if (vector.Dimensions != 1) + throw new ArgumentException($"{name} must be a vector."); + + if (expectedLength.HasValue && vector.NumberOfElements != expectedLength.Value) + throw new ArgumentException($"{name} must be a vector with length equal to {expectedLength.Value}."); + } + + private static void ValidateScalar(Tensor scalar, string name) { - public Tensor PredictedState = predictedState; - public Tensor PredictedCovariance = predictedCovariance; + if (scalar is null) + throw new ArgumentException($"{name} cannot be null."); + + if (scalar.NumberOfElements != 1) + throw new ArgumentException($"{name} must be a scalar."); } - private PredictedResult FilterPredict( - Tensor state, - Tensor covariance - ) + private Tensor CreateCovarianceMatrix(Tensor variance, ScalarType scalarType, Device device, int dimension, string name) { - var predictedState = _transitionMatrix.matmul(state); - var predictedCovariance = EnsureSymmetric(_transitionMatrix.matmul(covariance) - .matmul(_transitionMatrix.mT) - + _processNoiseCovariance); + ValidateAndSetScalar(variance, name, scalarType, device, out var scalar); + return (scalar * eye(dimension, dtype: scalarType, device: device)).requires_grad_(false); + } - return new PredictedResult(predictedState, predictedCovariance); + private readonly struct PredictedResult(Tensor predictedState, Tensor predictedCovariance) + { + public readonly Tensor PredictedState = predictedState; + public readonly Tensor PredictedCovariance = predictedCovariance; } - private struct UpdatedResult( + private PredictedResult FilterPredict(Tensor state, Tensor covariance) => + new(_transitionMatrix.matmul(state), + EnsureSymmetric(_transitionMatrix.matmul(covariance) + .matmul(_transitionMatrix.mT) + _processNoiseCovariance)); + + private readonly struct UpdatedResult( Tensor updatedState, Tensor updatedCovariance, Tensor innovation, Tensor innovationCovariance, - Tensor kalmanGain - ) + Tensor kalmanGain) { - public Tensor UpdatedState = updatedState; - public Tensor UpdatedCovariance = updatedCovariance; - public Tensor Innovation = innovation; - public Tensor InnovationCovariance = innovationCovariance; - public Tensor KalmanGain = kalmanGain; + public readonly Tensor UpdatedState = updatedState; + public readonly Tensor UpdatedCovariance = updatedCovariance; + public readonly Tensor Innovation = innovation; + public readonly Tensor InnovationCovariance = innovationCovariance; + public readonly Tensor KalmanGain = kalmanGain; } - private UpdatedResult FilterUpdate( - Tensor predictedState, - Tensor predictedCovariance, - Tensor observation - ) + private UpdatedResult FilterUpdate(Tensor predictedState, Tensor predictedCovariance, Tensor observation) { // Innovation step var innovation = observation - _measurementFunction.matmul(predictedState); var innovationCovariance = EnsureSymmetric( _measurementFunction.matmul(predictedCovariance) - .matmul(_measurementFunction.mT) - + _measurementNoiseCovariance); + .matmul(_measurementFunction.mT) + _measurementNoiseCovariance); // Kalman gain var kalmanGain = InverseCholesky( @@ -312,22 +204,16 @@ Tensor observation // Update step var updatedState = predictedState + kalmanGain.matmul(innovation); var updatedCovariance = EnsureSymmetric(predictedCovariance - - kalmanGain.matmul(_measurementFunction) - .matmul(predictedCovariance)); - - return new UpdatedResult( - updatedState, - updatedCovariance, - innovation, - innovationCovariance, - kalmanGain); + - kalmanGain.matmul(_measurementFunction).matmul(predictedCovariance)); + + return new UpdatedResult(updatedState, updatedCovariance, innovation, innovationCovariance, kalmanGain); } public FilteredResult Filter(Tensor observation) { var obs = observation.atleast_2d(); - var timeBins = obs.size(0); + var logLikelihood = empty(timeBins, dtype: _scalarType, device: _device); var predictedState = empty(new long[] { timeBins, _numStates }, dtype: _scalarType, device: _device); var predictedCovariance = empty(new long[] { timeBins, _numStates, _numStates }, dtype: _scalarType, device: _device); @@ -337,53 +223,32 @@ public FilteredResult Filter(Tensor observation) for (long time = 0; time < timeBins; time++) { - using (var d = NewDisposeScope()) - { - // Predict - var prediction = FilterPredict(_state, _covariance); - - // Update - var update = FilterUpdate( - prediction.PredictedState, - prediction.PredictedCovariance, - obs[time] - ); - - // Log Likelihood - var invInnovationCov = InverseCholesky(_identityObservations, update.InnovationCovariance); - var logLikelihoodData = -1.0 * (slogdet(update.InnovationCovariance).Item2 - + update.Innovation.T.matmul(invInnovationCov) - .matmul(update.Innovation)); - - logLikelihoodData.DetachFromDisposeScope(); - prediction.PredictedState.DetachFromDisposeScope(); - prediction.PredictedCovariance.DetachFromDisposeScope(); - update.UpdatedState.DetachFromDisposeScope(); - update.UpdatedCovariance.DetachFromDisposeScope(); - update.KalmanGain.DetachFromDisposeScope(); - - logLikelihood[time] = logLikelihoodData; - predictedState[time] = prediction.PredictedState; - predictedCovariance[time] = prediction.PredictedCovariance; - updatedState[time] = update.UpdatedState; - updatedCovariance[time] = update.UpdatedCovariance; - kalmanGain[time] = update.KalmanGain; - - _state.set_(update.UpdatedState); - _covariance.set_(update.UpdatedCovariance); - } + using var d = NewDisposeScope(); + + // Predict + var prediction = FilterPredict(_state, _covariance); + + // Update + var update = FilterUpdate(prediction.PredictedState, prediction.PredictedCovariance, obs[time]); + + // Log Likelihood + var invInnovationCov = InverseCholesky(_identityObservations, update.InnovationCovariance); + var logLikelihoodData = -1.0 * (slogdet(update.InnovationCovariance).logabsdet + + update.Innovation.T.matmul(invInnovationCov).matmul(update.Innovation)); + + // Detach and assign + logLikelihood[time] = logLikelihoodData.DetachFromDisposeScope(); + predictedState[time] = prediction.PredictedState.DetachFromDisposeScope(); + predictedCovariance[time] = prediction.PredictedCovariance.DetachFromDisposeScope(); + updatedState[time] = update.UpdatedState.DetachFromDisposeScope(); + updatedCovariance[time] = update.UpdatedCovariance.DetachFromDisposeScope(); + kalmanGain[time] = update.KalmanGain.DetachFromDisposeScope(); + + _state.set_(update.UpdatedState); + _covariance.set_(update.UpdatedCovariance); } - var filteredResult = new FilteredResult( - predictedState, - predictedCovariance, - updatedState, - updatedCovariance, - logLikelihood, - kalmanGain - ); - - return filteredResult; + return new FilteredResult(predictedState, predictedCovariance, updatedState, updatedCovariance, logLikelihood, kalmanGain); } public SmoothedResult Smooth(FilteredResult filteredResult) @@ -412,42 +277,40 @@ public SmoothedResult Smooth(FilteredResult filteredResult) // Backward pass for (long time = timeBins - 2; time >= 0; time--) { - using (var d = NewDisposeScope()) + using var d = NewDisposeScope(); + // Smoothing gain + smoothingGain = updatedCovariance[time].matmul( + InverseCholesky(_transitionMatrix.mT, predictedCovariance[time + 1]) + ).DetachFromDisposeScope(); + + // Smoothed state + smoothedState[time] = updatedState[time] + + smoothingGain.matmul( + (smoothedState[time + 1] - predictedState[time + 1]).unsqueeze(-1) + ).squeeze(-1) + .DetachFromDisposeScope(); + + // Smoothed covariance + smoothedCovariance[time] = EnsureSymmetric( + updatedCovariance[time] + smoothingGain + .matmul(smoothedCovariance[time + 1] - predictedCovariance[time + 1]) + .matmul(smoothingGain.mT) + ).DetachFromDisposeScope(); + + // Compute next smoothing gain for lag one covariance + if (time > 0) { - // Smoothing gain - smoothingGain = updatedCovariance[time].matmul( - InverseCholesky(_transitionMatrix.mT, predictedCovariance[time + 1]) - ).DetachFromDisposeScope(); - - // Smoothed state - smoothedState[time] = updatedState[time] - + smoothingGain.matmul( - (smoothedState[time + 1] - predictedState[time + 1]).unsqueeze(-1) - ).squeeze(-1) - .DetachFromDisposeScope(); + var smoothingGainNext = updatedCovariance[time - 1] + .matmul(InverseCholesky(_transitionMatrix.mT, predictedCovariance[time])); - // Smoothed covariance - smoothedCovariance[time] = EnsureSymmetric( - updatedCovariance[time] + smoothingGain - .matmul(smoothedCovariance[time + 1] - predictedCovariance[time + 1]) - .matmul(smoothingGain.mT) - ).DetachFromDisposeScope(); + // Smoothed lag one covariance - // Compute next smoothing gain for lag one covariance - if (time > 0) - { - var smoothingGainNext = updatedCovariance[time - 1] - .matmul(InverseCholesky(_transitionMatrix.mT, predictedCovariance[time])); - - // Smoothed lag one covariance - - smoothedLagOneCovariance[time] = smoothedCovariance[time] + smoothedLagOneCovariance[time] = smoothedCovariance[time] + .matmul(smoothingGainNext.mT) + + smoothingGain.matmul(smoothedLagOneCovariance[time + 1] + - _transitionMatrix.matmul(updatedCovariance[time])) .matmul(smoothingGainNext.mT) - + smoothingGain.matmul(smoothedLagOneCovariance[time + 1] - - _transitionMatrix.matmul(updatedCovariance[time])) - .matmul(smoothingGainNext.mT) - .DetachFromDisposeScope(); - } + .DetachFromDisposeScope(); } } @@ -489,8 +352,7 @@ public ExpectationMaximizationResult ExpectationMaximization( Tensor observation, int maxIterations = 100, double tolerance = 1e-4, - bool updateParameters = true - ) + bool updateParameters = true) { var timeBins = observation.size(0); var logLikelihood = empty(maxIterations, dtype: ScalarType.Float32, device: _device); @@ -500,114 +362,70 @@ public ExpectationMaximizationResult ExpectationMaximization( for (int iteration = 0; iteration < maxIterations; iteration++) { - using (var d = NewDisposeScope()) + using var d = NewDisposeScope(); + + // Filter observations + var filterResult = Filter(observation); + + // Compute log likelihood + var filteredLogLikelihood = logLikelihoodConst + 0.5 * filterResult.LogLikelihood.sum(); + var filteredLogLikelihoodSum = filteredLogLikelihood.to_type(ScalarType.Float64).item(); + + logLikelihood[iteration] = filteredLogLikelihoodSum; + + // Check for convergence + if (filteredLogLikelihoodSum <= previousLogLikelihood) { - // Filter observations - var filterResult = Filter(observation); - - // Compute log likelihood - var filteredLogLikelihood = logLikelihoodConst + 0.5 * filterResult.LogLikelihood.sum(); - var filteredLogLikelihoodSum = filteredLogLikelihood - .cpu() - .to_type(ScalarType.Float32) - .ReadCpuSingle(0); - - logLikelihood[iteration] = filteredLogLikelihoodSum; - - // Check for convergence - if (filteredLogLikelihoodSum <= previousLogLikelihood) - { - // throw new ArgumentException("Log likelihood decreased, something is wrong! New likelihood: " + filteredLogLikelihoodSum + ", Previous likelihood: " + previousLogLikelihood); - Console.WriteLine("Warning: Log likelihood decreased! New likelihood: " + filteredLogLikelihoodSum + ", Previous likelihood: " + previousLogLikelihood); - break; - } - - if (filteredLogLikelihoodSum - previousLogLikelihood < tolerance) - { - break; - } - - previousLogLikelihood = filteredLogLikelihoodSum; - - // Smooth the filtered results - var smoothedResult = Smooth(filterResult); - - var smoothedState = smoothedResult.SmoothedState; - var smoothedCovariance = smoothedResult.SmoothedCovariance; - var smoothedLagOneCovariance = smoothedResult.SmoothedLagOneCovariance; - - var smoothedInitialState = smoothedResult.SmoothedInitialState; - var smoothedInitialCovariance = smoothedResult.SmoothedInitialCovariance; - - // Sufficient statistics - var Ezzt = smoothedCovariance + einsum("tn,tm->tnm", smoothedState, smoothedState); - var Ezztm1 = smoothedLagOneCovariance[torch.TensorIndex.Slice(1)] - + einsum("tn,tm->tnm", - smoothedState[torch.TensorIndex.Slice(1)], - smoothedState[torch.TensorIndex.Slice(0, -1)]); - - var S00 = Ezzt[torch.TensorIndex.Slice(0, -1)].sum(new long[] { 0 }); - var S10 = Ezztm1.sum(new long[] { 0 }); - var S11 = Ezzt[torch.TensorIndex.Slice(1)].sum(new long[] { 0 }); - - var Syz = einsum("tp,tn->pn", observation, smoothedState); - var Eyy = einsum("tp,tq->pq", observation, observation); - - // Update transition matrix - var updatedTransitionMatrix = InverseCholesky(S10, S00); - - // Update measurement function - var updatedMeasurementFunction = InverseCholesky(Syz, S11); - - // Update process noise covariance - var updatedProcessNoiseCovariance = EnsureSymmetric( - (S11 - InverseCholesky(S10, S00).matmul(S10.T)) - / timeBins - ); - - // Update measurement noise covariance - var CSyzT = updatedMeasurementFunction.matmul(Syz.mT); - var updatedMeasurementNoiseCovariance = EnsureSymmetric( - (Eyy - CSyzT - CSyzT.mT - + updatedMeasurementFunction.matmul(S11) - .matmul(updatedMeasurementFunction.mT)) - / timeBins - ); - - // Update initial state - var updatedInitialState = smoothedInitialState; - - // Update initial covariance - var updatedInitialCovariance = smoothedInitialCovariance; - - updatedTransitionMatrix.DetachFromDisposeScope(); - updatedMeasurementFunction.DetachFromDisposeScope(); - updatedProcessNoiseCovariance.DetachFromDisposeScope(); - updatedMeasurementNoiseCovariance.DetachFromDisposeScope(); - updatedInitialState.DetachFromDisposeScope(); - updatedInitialCovariance.DetachFromDisposeScope(); - - updatedParameters = new KalmanFilterParameters( - updatedTransitionMatrix, - updatedMeasurementFunction, - updatedProcessNoiseCovariance, - updatedMeasurementNoiseCovariance, - updatedInitialState, - updatedInitialCovariance - ); - - if (updateParameters) - { - UpdateParameters(updatedParameters); - } + Console.WriteLine($"Warning: Log likelihood decreased! New: {filteredLogLikelihoodSum}, Previous: {previousLogLikelihood}"); + break; } + if (filteredLogLikelihoodSum - previousLogLikelihood < tolerance) + break; + + previousLogLikelihood = filteredLogLikelihoodSum; + + // Smooth the filtered results + var smoothedResult = Smooth(filterResult); + + // Sufficient statistics + var Ezzt = smoothedResult.SmoothedCovariance + einsum("tn,tm->tnm", smoothedResult.SmoothedState, smoothedResult.SmoothedState); + var Ezztm1 = smoothedResult.SmoothedLagOneCovariance[torch.TensorIndex.Slice(1)] + + einsum("tn,tm->tnm", + smoothedResult.SmoothedState[torch.TensorIndex.Slice(1)], + smoothedResult.SmoothedState[torch.TensorIndex.Slice(0, -1)]); + + var S00 = Ezzt[torch.TensorIndex.Slice(0, -1)].sum(new long[] { 0 }); + var S10 = Ezztm1.sum(new long[] { 0 }); + var S11 = Ezzt[torch.TensorIndex.Slice(1)].sum(new long[] { 0 }); + + var Syz = einsum("tp,tn->pn", observation, smoothedResult.SmoothedState); + var Eyy = einsum("tp,tq->pq", observation, observation); + + // Update parameters + var updatedTransitionMatrix = InverseCholesky(S10, S00).DetachFromDisposeScope(); + var updatedMeasurementFunction = InverseCholesky(Syz, S11).DetachFromDisposeScope(); + var updatedProcessNoiseCovariance = EnsureSymmetric((S11 - InverseCholesky(S10, S00).matmul(S10.T)) / timeBins).DetachFromDisposeScope(); + + var CSyzT = updatedMeasurementFunction.matmul(Syz.mT); + var updatedMeasurementNoiseCovariance = EnsureSymmetric( + (Eyy - CSyzT - CSyzT.mT + updatedMeasurementFunction.matmul(S11).matmul(updatedMeasurementFunction.mT)) / timeBins + ).DetachFromDisposeScope(); + + updatedParameters = new KalmanFilterParameters( + updatedTransitionMatrix, + updatedMeasurementFunction, + updatedProcessNoiseCovariance, + updatedMeasurementNoiseCovariance, + smoothedResult.SmoothedInitialState.DetachFromDisposeScope(), + smoothedResult.SmoothedInitialCovariance.DetachFromDisposeScope() + ); + + if (updateParameters) + UpdateParameters(updatedParameters); } - return new ExpectationMaximizationResult( - logLikelihood.DetachFromDisposeScope(), - updatedParameters - ); + return new ExpectationMaximizationResult(logLikelihood.DetachFromDisposeScope(), updatedParameters); } public OrthogonalizedResult OrthogonalizeStateAndCovariance(Tensor state, Tensor covariance) @@ -633,15 +451,13 @@ public void UpdateParameters(KalmanFilterParameters updatedParameters) _initialCovariance.set_(updatedParameters.InitialCovariance); } - private static Tensor EnsureSymmetric(Tensor M) - { - return 0.5f * (M + M.transpose(0, 1)); - } + private static Tensor EnsureSymmetric(Tensor M) => 0.5f * (M + M.transpose(0, 1)); private static Tensor InverseCholesky(Tensor B, Tensor A) { + using var d = NewDisposeScope(); var L = linalg.cholesky(A); var solT = cholesky_solve(B.transpose(0, 1), L); - return solT.transpose(0, 1); + return solT.transpose(0, 1).MoveToOuterDisposeScope(); } } From e17b252cb049e2427db7a81bcb794e93d2dd1f58 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 17 Sep 2025 13:13:56 +0100 Subject: [PATCH 05/92] Adding classes from original scripted demo --- src/Bonsai.ML.Torch.LDS/CreateKalmanFilter.cs | 247 ++++++++++++++++++ .../CreateKalmanFilterParameters.cs | 217 +++++++++++++++ .../ExpectationMaximization.cs | 128 +++++++++ .../ExpectationMaximizationResult.cs | 23 ++ src/Bonsai.ML.Torch.LDS/Filter.cs | 30 +++ src/Bonsai.ML.Torch.LDS/FilteredResult.cs | 51 ++++ .../KalmanFilterModelManager.cs | 123 +++++++++ .../KalmanFilterNameConverter.cs | 42 +++ .../KalmanFilterParameters.cs | 54 ++++ src/Bonsai.ML.Torch.LDS/Orthogonalize.cs | 45 ++++ .../OrthogonalizedResult.cs | 17 ++ src/Bonsai.ML.Torch.LDS/Smooth.cs | 33 +++ src/Bonsai.ML.Torch.LDS/SmoothedResult.cs | 44 ++++ src/Bonsai.ML.Torch.LDS/UpdateParameters.cs | 34 +++ 14 files changed, 1088 insertions(+) create mode 100644 src/Bonsai.ML.Torch.LDS/CreateKalmanFilter.cs create mode 100644 src/Bonsai.ML.Torch.LDS/CreateKalmanFilterParameters.cs create mode 100644 src/Bonsai.ML.Torch.LDS/ExpectationMaximization.cs create mode 100644 src/Bonsai.ML.Torch.LDS/ExpectationMaximizationResult.cs create mode 100644 src/Bonsai.ML.Torch.LDS/Filter.cs create mode 100644 src/Bonsai.ML.Torch.LDS/FilteredResult.cs create mode 100644 src/Bonsai.ML.Torch.LDS/KalmanFilterModelManager.cs create mode 100644 src/Bonsai.ML.Torch.LDS/KalmanFilterNameConverter.cs create mode 100644 src/Bonsai.ML.Torch.LDS/KalmanFilterParameters.cs create mode 100644 src/Bonsai.ML.Torch.LDS/Orthogonalize.cs create mode 100644 src/Bonsai.ML.Torch.LDS/OrthogonalizedResult.cs create mode 100644 src/Bonsai.ML.Torch.LDS/Smooth.cs create mode 100644 src/Bonsai.ML.Torch.LDS/SmoothedResult.cs create mode 100644 src/Bonsai.ML.Torch.LDS/UpdateParameters.cs diff --git a/src/Bonsai.ML.Torch.LDS/CreateKalmanFilter.cs b/src/Bonsai.ML.Torch.LDS/CreateKalmanFilter.cs new file mode 100644 index 00000000..82222b5c --- /dev/null +++ b/src/Bonsai.ML.Torch.LDS/CreateKalmanFilter.cs @@ -0,0 +1,247 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using System.Xml.Serialization; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.LDS; + +/// +/// Creates a Kalman filter model. +/// +[Combinator] +[Description("Creates a Kalman filter model.")] +[WorkflowElementCategory(ElementCategory.Source)] +public class CreateKalmanFilter : IScalarTypeProvider +{ + /// + /// A unique name for the Kalman filter model. + /// + public string ModelName { get; set; } = "KalmanFilter"; + + private int _numStates = 2; + /// + /// The number of states in the Kalman filter model. + /// + public int NumStates + { + get => _numStates; + set => _numStates = value > 0 ? value : throw new ArgumentOutOfRangeException(nameof(value), "Number of states must be greater than zero."); + } + + private int _numObservations = 2; + /// + /// The number of observations in the Kalman filter model. + /// + public int NumObservations + { + get => _numObservations; + set => _numObservations = value > 0 ? value : throw new ArgumentOutOfRangeException(nameof(value), "Number of observations must be greater than zero."); + } + + private ScalarType _scalarType = ScalarType.Float32; + /// + /// The data type of the tensor elements. + /// + [Description("The data type of the tensor elements.")] + [TypeConverter(typeof(ScalarTypeConverter))] + public ScalarType Type + { + get => _scalarType; + set + { + _scalarType = value; + ConvertTensorsScalarType(value); + } + } + + /// + /// The device on which to create the tensor. + /// + [Description("The device on which to create the tensor.")] + [XmlIgnore] + public Device Device { get; set; } + + private void ConvertTensorsScalarType(ScalarType scalarType) + { + _transitionMatrix = _transitionMatrix?.to_type(scalarType); + _measurementFunction = _measurementFunction?.to_type(scalarType); + _processNoiseVariance = _processNoiseVariance?.to_type(scalarType); + _measurementNoiseVariance = _measurementNoiseVariance?.to_type(scalarType); + _initialState = _initialState?.to_type(scalarType); + _initialCovariance = _initialCovariance?.to_type(scalarType); + } + + // Tensor properties with XML serialization support + private Tensor _transitionMatrix; + /// + /// The state transition matrix. + /// + [XmlIgnore] + [TypeConverter(typeof(TensorConverter))] + public Tensor TransitionMatrix + { + get => _transitionMatrix; + set => _transitionMatrix = value?.to_type(Type); + } + + /// + /// The XML string representation of the transition matrix for serialization. + /// + [Browsable(false)] + [XmlElement(nameof(TransitionMatrix))] + [EditorBrowsable(EditorBrowsableState.Never)] + public string TransitionMatrixXml + { + get => TensorConverter.ConvertToString(TransitionMatrix, _scalarType); + set => TransitionMatrix = TensorConverter.ConvertFromString(value, _scalarType); + } + + private Tensor _measurementFunction; + /// + /// The measurement function. + /// + [XmlIgnore] + public Tensor MeasurementFunction + { + get => _measurementFunction; + set => _measurementFunction = value?.to_type(Type); + } + + /// + /// The XML string representation of the measurement function for serialization. + /// + [Browsable(false)] + [XmlElement(nameof(MeasurementFunction))] + [EditorBrowsable(EditorBrowsableState.Never)] + public string MeasurementFunctionXml + { + get => TensorConverter.ConvertToString(MeasurementFunction, _scalarType); + set => MeasurementFunction = TensorConverter.ConvertFromString(value, _scalarType); + } + + private Tensor _processNoiseVariance; + /// + /// The process noise variance. + /// + [XmlIgnore] + public Tensor ProcessNoiseVariance + { + get => _processNoiseVariance; + set => _processNoiseVariance = value?.to_type(Type); + } + + /// + /// The XML string representation of the process noise variance for serialization. + /// + [Browsable(false)] + [XmlElement(nameof(ProcessNoiseVariance))] + [EditorBrowsable(EditorBrowsableState.Never)] + public string ProcessNoiseVarianceXml + { + get => TensorConverter.ConvertToString(ProcessNoiseVariance, _scalarType); + set => ProcessNoiseVariance = TensorConverter.ConvertFromString(value, _scalarType); + } + + private Tensor _measurementNoiseVariance; + /// + /// The measurement noise variance. + /// + [XmlIgnore] + public Tensor MeasurementNoiseVariance + { + get => _measurementNoiseVariance; + set => _measurementNoiseVariance = value?.to_type(Type); + } + + /// + /// The XML string representation of the measurement noise variance for serialization. + /// + [Browsable(false)] + [XmlElement(nameof(MeasurementNoiseVariance))] + [EditorBrowsable(EditorBrowsableState.Never)] + public string MeasurementNoiseVarianceXml + { + get => TensorConverter.ConvertToString(MeasurementNoiseVariance, _scalarType); + set => MeasurementNoiseVariance = TensorConverter.ConvertFromString(value, _scalarType); + } + + private Tensor _initialState; + /// + /// The initial state. + /// + [XmlIgnore] + public Tensor InitialState + { + get => _initialState; + set => _initialState = value?.to_type(Type); + } + + /// + /// The XML string representation of the initial state for serialization. + /// + [Browsable(false)] + [XmlElement(nameof(InitialState))] + [EditorBrowsable(EditorBrowsableState.Never)] + public string InitialStateXml + { + get => TensorConverter.ConvertToString(InitialState, _scalarType); + set => InitialState = TensorConverter.ConvertFromString(value, _scalarType); + } + + private Tensor _initialCovariance; + /// + /// The initial covariance. + /// + [XmlIgnore] + public Tensor InitialCovariance + { + get => _initialCovariance; + set => _initialCovariance = value?.to_type(Type); + } + + /// + /// The XML string representation of the initial covariance for serialization. + /// + [Browsable(false)] + [XmlElement(nameof(InitialCovariance))] + [EditorBrowsable(EditorBrowsableState.Never)] + public string InitialCovarianceXml + { + get => TensorConverter.ConvertToString(InitialCovariance, _scalarType); + set => InitialCovariance = TensorConverter.ConvertFromString(value, _scalarType); + } + + /// + /// Creates a Kalman filter model using the properties of this class. + /// + public IObservable Process() + { + return Observable.Return( + new KalmanFilter( + numStates: NumStates, + numObservations: NumObservations, + transitionMatrix: _transitionMatrix, + measurementFunction: _measurementFunction, + initialState: _initialState, + initialCovariance: _initialCovariance, + processNoiseVariance: _processNoiseVariance, + measurementNoiseVariance: _measurementNoiseVariance, + device: Device, + scalarType: Type + )); + } + + /// + /// Creates a Kalman filter model using the parameters provided in the input sequence. + /// + public IObservable Process(IObservable source) + { + return source.Select(parameters => + new KalmanFilter( + parameters: parameters, + device: Device, + scalarType: Type + )); + } +} diff --git a/src/Bonsai.ML.Torch.LDS/CreateKalmanFilterParameters.cs b/src/Bonsai.ML.Torch.LDS/CreateKalmanFilterParameters.cs new file mode 100644 index 00000000..77aa932a --- /dev/null +++ b/src/Bonsai.ML.Torch.LDS/CreateKalmanFilterParameters.cs @@ -0,0 +1,217 @@ +using System; +using System.ComponentModel; +using System.Collections.Generic; +using System.Linq; +using System.Reactive.Linq; +using System.Xml.Serialization; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.LDS; + +/// +/// Initializes the parameters for a new Kalman filter model. +/// +[Combinator] +[Description("Initializes the parameters for a new Kalman filter model.")] +[WorkflowElementCategory(ElementCategory.Source)] +public class CreateKalmanFilterParameters : IScalarTypeProvider +{ + /// + /// The data type of the tensor elements. + /// + [Description("The data type of the tensor elements.")] + [TypeConverter(typeof(ScalarTypeConverter))] + public ScalarType Type + { + get => _scalarType; + set + { + _scalarType = value; + ConvertTensorsScalarType(value); + } + } + private ScalarType _scalarType = ScalarType.Float32; + + private Tensor _transitionMatrix = null; + /// + /// The state transition matrix. + /// + [XmlIgnore] + [TypeConverter(typeof(TensorConverter))] + public Tensor TransitionMatrix + { + get => _transitionMatrix; + set => _transitionMatrix = value.to_type(Type); + } + + /// + /// The XML string representation of the transition matrix for serialization. + /// + [Browsable(false)] + [XmlElement(nameof(TransitionMatrix))] + [EditorBrowsable(EditorBrowsableState.Never)] + public string TransitionMatrixXml + { + get => TensorConverter.ConvertToString(TransitionMatrix, _scalarType); + set => TransitionMatrix = TensorConverter.ConvertFromString(value, _scalarType); + } + + private Tensor _measurementFunction = null; + /// + /// The measurement function. + /// + [XmlIgnore] + public Tensor MeasurementFunction + { + get => _measurementFunction; + set => _measurementFunction = value.to_type(Type); + } + + /// + /// The XML string representation of the measurement function for serialization. + /// + [Browsable(false)] + [XmlElement(nameof(MeasurementFunction))] + [EditorBrowsable(EditorBrowsableState.Never)] + public string MeasurementFunctionXml + { + get => TensorConverter.ConvertToString(MeasurementFunction, _scalarType); + set => MeasurementFunction = TensorConverter.ConvertFromString(value, _scalarType); + } + + private Tensor _processNoiseCovariance = null; + /// + /// The process noise variance. + /// + [XmlIgnore] + public Tensor ProcessNoiseCovariance + { + get => _processNoiseCovariance; + set => _processNoiseCovariance = value.to_type(Type); + } + + /// + /// The XML string representation of the process noise variance for serialization. + /// + [Browsable(false)] + [XmlElement(nameof(ProcessNoiseCovariance))] + [EditorBrowsable(EditorBrowsableState.Never)] + public string ProcessNoiseCovarianceXml + { + get => TensorConverter.ConvertToString(ProcessNoiseCovariance, _scalarType); + set => ProcessNoiseCovariance = TensorConverter.ConvertFromString(value, _scalarType); + } + + private Tensor _measurementNoiseCovariance = null; + /// + /// The measurement noise covariance matrix. + /// + [XmlIgnore] + public Tensor MeasurementNoiseCovariance + { + get => _measurementNoiseCovariance; + set => _measurementNoiseCovariance = value.to_type(Type); + } + + /// + /// The XML string representation of the measurement noise covariance matrix for serialization. + /// + [Browsable(false)] + [XmlElement(nameof(MeasurementNoiseCovariance))] + [EditorBrowsable(EditorBrowsableState.Never)] + public string MeasurementNoiseCovarianceXml + { + get => TensorConverter.ConvertToString(MeasurementNoiseCovariance, _scalarType); + set => MeasurementNoiseCovariance = TensorConverter.ConvertFromString(value, _scalarType); + } + + private Tensor _initialState; + /// + /// The initial state. + /// + [XmlIgnore] + public Tensor InitialState + { + get => _initialState; + set => _initialState = value.to_type(Type); + } + + /// + /// The XML string representation of the initial state for serialization. + /// + [Browsable(false)] + [XmlElement(nameof(InitialState))] + [EditorBrowsable(EditorBrowsableState.Never)] + public string InitialStateXml + { + get => TensorConverter.ConvertToString(InitialState, _scalarType); + set => InitialState = TensorConverter.ConvertFromString(value, _scalarType); + } + + private Tensor _initialCovariance; + /// + /// The initial covariance. + /// + [XmlIgnore] + public Tensor InitialCovariance + { + get => _initialCovariance; + set => _initialCovariance = value.to_type(Type); + } + + /// + /// The XML string representation of the initial covariance for serialization. + /// + [Browsable(false)] + [XmlElement(nameof(InitialCovariance))] + [EditorBrowsable(EditorBrowsableState.Never)] + public string InitialCovarianceXml + { + get => TensorConverter.ConvertToString(InitialCovariance, _scalarType); + set => InitialCovariance = TensorConverter.ConvertFromString(value, _scalarType); + } + + private void ConvertTensorsScalarType(ScalarType scalarType) + { + _transitionMatrix = _transitionMatrix.to_type(scalarType); + _measurementFunction = _measurementFunction.to_type(scalarType); + _processNoiseCovariance = _processNoiseCovariance.to_type(scalarType); + _measurementNoiseCovariance = _measurementNoiseCovariance.to_type(scalarType); + _initialState = _initialState.to_type(scalarType); + _initialCovariance = _initialCovariance.to_type(scalarType); + } + + /// + /// Creates parameters for a Kalman filter model using the properties of this class. + /// + public IObservable Process() + { + return Observable.Return( + new KalmanFilterParameters( + TransitionMatrix, + MeasurementFunction, + ProcessNoiseCovariance, + MeasurementNoiseCovariance, + InitialState, + InitialCovariance + ) + ); + } + + /// + /// Creates parameters for a Kalman filter model for each element in the input sequence. + /// + public IObservable Process(IObservable source) + { + return source.Select(_ => + new KalmanFilterParameters( + TransitionMatrix, + MeasurementFunction, + ProcessNoiseCovariance, + MeasurementNoiseCovariance, + InitialState, + InitialCovariance + ) + ); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch.LDS/ExpectationMaximization.cs b/src/Bonsai.ML.Torch.LDS/ExpectationMaximization.cs new file mode 100644 index 00000000..c27e4613 --- /dev/null +++ b/src/Bonsai.ML.Torch.LDS/ExpectationMaximization.cs @@ -0,0 +1,128 @@ +using System; +using System.ComponentModel; +using System.Reactive; +using System.Reactive.Linq; +using System.Reflection; +using System.Runtime.InteropServices; +using System.Threading; +using System.Threading.Tasks; +using System.Xml.Serialization; +using Bonsai; +using Bonsai.ML.Torch; +using Bonsai.ML.Torch.NeuralNets; +using Bonsai.Reactive; +using TorchSharp; +using TorchSharp.Modules; + +namespace Bonsai.ML.Torch.LDS; + +/// +/// Learn the parameters kalman filter using the batch EM update algorithm. +/// +[Combinator] +[ResetCombinator] +[Description("Learn the parameters kalman filter using the batch EM update algorithm.")] +[WorkflowElementCategory(ElementCategory.Transform)] +public class ExpectationMaximization +{ + [TypeConverter(typeof(KalmanFilterNameConverter))] + public string ModelName { get; set; } = "KalmanFilter"; + + private int _maxIterations = 10; + public int MaxIterations + { + get => _maxIterations; + set + { + if (value < 1) throw new ArgumentOutOfRangeException("MaxIterations must be at least 1."); + _maxIterations = value; + } + } + + private double _tolerance = 1e-4; + public double Tolerance + { + get => _tolerance; + set + { + if (value < 0) throw new ArgumentOutOfRangeException("Tolerance must be non-negative."); + _tolerance = value; + } + } + + private bool _verbose = true; + public bool Verbose + { + get => _verbose; + set => _verbose = value; + } + + public IObservable Process(IObservable source) + { + return source.SelectMany(input => + { + var model = KalmanFilterModelManager.GetKalmanFilter(ModelName); + return Observable.FromAsync(cancellationToken => + { + return Task.Run(() => + { + var previousLogLikelihood = double.NegativeInfinity; + var logLikelihood = torch.zeros(new long[] { MaxIterations }, device: input.device); + + for (int i = 0; i < MaxIterations; i++) + { + if (cancellationToken.IsCancellationRequested) + { + break; + } + + ExpectationMaximizationResult result; + using (KalmanFilterModelManager.Read(model)) + { + result = model.ExpectationMaximization(input, 1, Tolerance, false); + } + + var logLikelihoodSum = result.LogLikelihood + .cpu() + .to_type(torch.ScalarType.Float32) + .ReadCpuSingle(0); + + logLikelihood[i] = logLikelihoodSum; + + if (Verbose) + { + Console.WriteLine("Iteration " + (i + 1) + ", Log Likelihood: " + logLikelihoodSum); + if (i == MaxIterations - 1) + { + Console.WriteLine("EM reached the maximum number of iterations."); + } + } + + if (logLikelihoodSum - previousLogLikelihood < Tolerance) + { + if (Verbose) + { + Console.WriteLine("EM converged after " + (i + 1) + " iterations."); + } + logLikelihood = logLikelihood[torch.TensorIndex.Slice(0, i + 1)]; + break; + } + previousLogLikelihood = logLikelihoodSum; + + using (KalmanFilterModelManager.Write(model)) + { + model.UpdateParameters(result.Parameters); + } + } + + var expectationMaximizationResult = new ExpectationMaximizationResult( + logLikelihood, + model.Parameters); + + return expectationMaximizationResult; + }, + cancellationToken); + }); + }); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch.LDS/ExpectationMaximizationResult.cs b/src/Bonsai.ML.Torch.LDS/ExpectationMaximizationResult.cs new file mode 100644 index 00000000..a51a54b0 --- /dev/null +++ b/src/Bonsai.ML.Torch.LDS/ExpectationMaximizationResult.cs @@ -0,0 +1,23 @@ +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.LDS; + +/// +/// Represents the result of an expectation-maximization step for a Kalman filter model. +/// +/// +/// +public struct ExpectationMaximizationResult( + Tensor logLikelihood, + KalmanFilterParameters parameters) +{ + /// + /// The log likelihood of the observed data given the model parameters after each iteration. + /// + public Tensor LogLikelihood = logLikelihood; + + /// + /// The final updated Kalman filter parameters after the last expectation-maximization step. + /// + public KalmanFilterParameters Parameters = parameters; +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch.LDS/Filter.cs b/src/Bonsai.ML.Torch.LDS/Filter.cs new file mode 100644 index 00000000..2e8e3c37 --- /dev/null +++ b/src/Bonsai.ML.Torch.LDS/Filter.cs @@ -0,0 +1,30 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using System.Xml.Serialization; +using Bonsai; +using TorchSharp; + +namespace Bonsai.ML.Torch.LDS; + +[Combinator] +[ResetCombinator] +[Description("")] +[WorkflowElementCategory(ElementCategory.Transform)] +public class Filter +{ + [TypeConverter(typeof(KalmanFilterNameConverter))] + public string ModelName { get; set; } = "KalmanFilter"; + + public IObservable Process(IObservable source) + { + return source.Select((input) => + { + var kalmanFilter = KalmanFilterModelManager.GetKalmanFilter(ModelName); + using (KalmanFilterModelManager.Read(kalmanFilter)) + { + return kalmanFilter.Filter(input); + } + }); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch.LDS/FilteredResult.cs b/src/Bonsai.ML.Torch.LDS/FilteredResult.cs new file mode 100644 index 00000000..51e4d9d7 --- /dev/null +++ b/src/Bonsai.ML.Torch.LDS/FilteredResult.cs @@ -0,0 +1,51 @@ +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.LDS; + +/// +/// Represents the result of a Kalman filter update step. +/// +/// +/// +/// +/// +/// +/// +public struct FilteredResult( + Tensor predictedState, + Tensor predictedCovariance, + Tensor updatedState, + Tensor updatedCovariance, + Tensor logLikelihood, + Tensor kalmanGain) +{ + /// + /// The predicted state after the prediction step. + /// + public Tensor PredictedState = predictedState; + + /// + /// The predicted covariance after the prediction step. + /// + public Tensor PredictedCovariance = predictedCovariance; + + /// + /// The updated state after the update step. + /// + public Tensor UpdatedState = updatedState; + + /// + /// The updated covariance after the update step. + /// + public Tensor UpdatedCovariance = updatedCovariance; + + /// + /// The log likelihood of the measurement given the predicted state. + /// + public Tensor LogLikelihood = logLikelihood; + + /// + /// The Kalman gain. + /// + public Tensor KalmanGain = kalmanGain; +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch.LDS/KalmanFilterModelManager.cs b/src/Bonsai.ML.Torch.LDS/KalmanFilterModelManager.cs new file mode 100644 index 00000000..76b844d0 --- /dev/null +++ b/src/Bonsai.ML.Torch.LDS/KalmanFilterModelManager.cs @@ -0,0 +1,123 @@ +using System; +using System.Reactive.Disposables; +using System.Threading; +using System.Runtime.CompilerServices; +using System.Collections.Generic; +using static TorchSharp.torch; +using Bonsai.ML.Torch.LDS; +using TorchSharp; + +// +// Manages instances of the Kalman Filter in a thread-safe manner. +// +internal sealed class KalmanFilterModelManager +{ + private static readonly ConditionalWeakTable _moduleLocks = new(); + + public static ReaderWriterLockSlim GetLock(KalmanFilter instance) => + _moduleLocks.GetValue(instance, _ => new ReaderWriterLockSlim(LockRecursionPolicy.NoRecursion)); + + public static IDisposable Read(KalmanFilter instance) + { + var lockObject = GetLock(instance); + lockObject.EnterReadLock(); + return new ManagedLock(lockObject, Mode.Read); + } + + public static IDisposable Write(KalmanFilter instance) + { + var lockObject = GetLock(instance); + lockObject.EnterWriteLock(); + return new ManagedLock(lockObject, Mode.Write); + } + + private enum Mode + { + Read, + Write + } + + private static readonly Dictionary _models = new(); + + public static KalmanFilter GetKalmanFilter(string name) + { + return _models.TryGetValue(name, out var model) ? model : throw new InvalidOperationException($"Kalman filter with name {name} not found."); + } + + internal static KalmanFilterDisposable Reserve( + string name, + int numStates, + int numObservations, + Tensor transitionMatrix, + Tensor measurementFunction, + Tensor processNoiseVariance, + Tensor measurementNoiseVariance, + Tensor initialState, + Tensor initialCovariance, + Device? device = null, + ScalarType? scalarType = null + ) + { + if (_models.ContainsKey(name)) + { + throw new InvalidOperationException($"A Kalman filter with name {name} already exists."); + } + + var kalmanFilter = new KalmanFilter( + numStates: numStates, + numObservations: numObservations, + transitionMatrix: transitionMatrix, + measurementFunction: measurementFunction, + processNoiseVariance: processNoiseVariance, + measurementNoiseVariance: measurementNoiseVariance, + initialState: initialState, + initialCovariance: initialCovariance, + device: device, + scalarType: scalarType ?? ScalarType.Float32 + ); + + _models.Add(name, kalmanFilter); + + return new KalmanFilterDisposable(kalmanFilter, Disposable.Create(() => + { + _models.Remove(name); + kalmanFilter.Dispose(); + })); + } + + private readonly struct ManagedLock( + ReaderWriterLockSlim lockObject, + Mode mode) : IDisposable + { + private readonly ReaderWriterLockSlim _lockObject = lockObject; + private readonly Mode _mode = mode; + + public void Dispose() + { + // Exit in the reverse mode we entered. + switch (_mode) + { + case Mode.Read when _lockObject.IsReadLockHeld: _lockObject.ExitReadLock(); break; + case Mode.Write when _lockObject.IsWriteLockHeld: _lockObject.ExitWriteLock(); break; + } + } + } + + internal sealed class KalmanFilterDisposable(KalmanFilter model, IDisposable disposable) : IDisposable + { + private IDisposable? resource = disposable ?? throw new ArgumentNullException(nameof(disposable)); + + public bool IsDisposed => resource is null; + + private readonly KalmanFilter model = model ?? throw new ArgumentNullException(nameof(model)); + + public KalmanFilter Model => model; + + public void Dispose() + { + var disposable = Interlocked.Exchange(ref resource, null); + disposable?.Dispose(); + } + } +} + diff --git a/src/Bonsai.ML.Torch.LDS/KalmanFilterNameConverter.cs b/src/Bonsai.ML.Torch.LDS/KalmanFilterNameConverter.cs new file mode 100644 index 00000000..af31bf37 --- /dev/null +++ b/src/Bonsai.ML.Torch.LDS/KalmanFilterNameConverter.cs @@ -0,0 +1,42 @@ +using Bonsai; +using Bonsai.Expressions; +using System.Linq; +using System.ComponentModel; + +namespace Bonsai.ML.Torch.LDS; + +public class KalmanFilterNameConverter : StringConverter +{ + /// + public override bool GetStandardValuesSupported(ITypeDescriptorContext context) + { + return true; + } + + /// + public override StandardValuesCollection GetStandardValues(ITypeDescriptorContext context) + { + if (context != null) + { + var workflowBuilder = (WorkflowBuilder)context.GetService(typeof(WorkflowBuilder)); + if (workflowBuilder != null) + { + var models = (from builder in workflowBuilder.Workflow.Descendants() + where builder.GetType() != typeof(DisableBuilder) + let managedModelNode = ExpressionBuilder.GetWorkflowElement(builder) + where managedModelNode != null && managedModelNode is CreateKalmanFilter + let createKalmanFilter = (CreateKalmanFilter)managedModelNode + where createKalmanFilter != null && !string.IsNullOrEmpty(createKalmanFilter.ModelName) + select createKalmanFilter.ModelName) + .Distinct() + .ToList(); + if (models.Count > 0) + { + return new StandardValuesCollection(models); + } + } + } + + return new StandardValuesCollection(new string[] { }); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch.LDS/KalmanFilterParameters.cs b/src/Bonsai.ML.Torch.LDS/KalmanFilterParameters.cs new file mode 100644 index 00000000..ad708605 --- /dev/null +++ b/src/Bonsai.ML.Torch.LDS/KalmanFilterParameters.cs @@ -0,0 +1,54 @@ +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.LDS; + +/// +/// Represents the parameters of a Kalman filter model. +/// +/// +/// Initializes a new instance of the struct with the specified parameters. +/// +/// +/// +/// +/// +/// +/// +public struct KalmanFilterParameters( + Tensor transitionMatrix, + Tensor measurementFunction, + Tensor processNoiseCovariance, + Tensor measurementNoiseCovariance, + Tensor initialState, + Tensor initialCovariance) +{ + /// + /// The state transition matrix. + /// + public Tensor TransitionMatrix = transitionMatrix; + + /// + /// The measurement function. + /// + public Tensor MeasurementFunction = measurementFunction; + + /// + /// The process noise covariance. + /// + public Tensor ProcessNoiseCovariance = processNoiseCovariance; + + /// + /// The measurement noise covariance. + /// + public Tensor MeasurementNoiseCovariance = measurementNoiseCovariance; + + /// + /// The initial state. + /// + public Tensor InitialState = initialState; + + /// + /// The initial covariance. + /// + public Tensor InitialCovariance = initialCovariance; +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch.LDS/Orthogonalize.cs b/src/Bonsai.ML.Torch.LDS/Orthogonalize.cs new file mode 100644 index 00000000..2081056d --- /dev/null +++ b/src/Bonsai.ML.Torch.LDS/Orthogonalize.cs @@ -0,0 +1,45 @@ +using TorchSharp; +using System; +using Bonsai; +using System.ComponentModel; +using System.Reactive.Linq; +using System.Xml.Serialization; + +namespace Bonsai.ML.Torch.LDS; + +[Combinator] +[Description("")] +[WorkflowElementCategory(ElementCategory.Transform)] +public class Orthogonalize +{ + [TypeConverter(typeof(KalmanFilterNameConverter))] + public string ModelName { get; set; } = "KalmanFilter"; + + public IObservable Process(IObservable source) + { + return source.Select(input => + { + var kalmanFilter = KalmanFilterModelManager.GetKalmanFilter(ModelName); + var smoothedState = input.SmoothedState; + var smoothedCovariance = input.SmoothedCovariance; + using (KalmanFilterModelManager.Read(kalmanFilter)) + { + return kalmanFilter.OrthogonalizeStateAndCovariance(smoothedState, smoothedCovariance); + } + }); + } + + public IObservable Process(IObservable source) + { + return source.Select(input => + { + var kalmanFilter = KalmanFilterModelManager.GetKalmanFilter(ModelName); + var filteredState = input.UpdatedState; + var filteredCovariance = input.UpdatedCovariance; + using (KalmanFilterModelManager.Read(kalmanFilter)) + { + return kalmanFilter.OrthogonalizeStateAndCovariance(filteredState, filteredCovariance); + } + }); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch.LDS/OrthogonalizedResult.cs b/src/Bonsai.ML.Torch.LDS/OrthogonalizedResult.cs new file mode 100644 index 00000000..e2419f07 --- /dev/null +++ b/src/Bonsai.ML.Torch.LDS/OrthogonalizedResult.cs @@ -0,0 +1,17 @@ +using TorchSharp; + +namespace Bonsai.ML.Torch.LDS; + +public struct OrthogonalizedResult +{ + public torch.Tensor OrthogonalizedState; + public torch.Tensor OrthogonalizedCovariance; + + public OrthogonalizedResult( + torch.Tensor orthogonalizedState, + torch.Tensor orthogonalizedCovariance) + { + OrthogonalizedState = orthogonalizedState; + OrthogonalizedCovariance = orthogonalizedCovariance; + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch.LDS/Smooth.cs b/src/Bonsai.ML.Torch.LDS/Smooth.cs new file mode 100644 index 00000000..d06f6946 --- /dev/null +++ b/src/Bonsai.ML.Torch.LDS/Smooth.cs @@ -0,0 +1,33 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using System.Xml.Serialization; +using Bonsai; +using Bonsai.ML.Torch; +using Bonsai.ML.Torch.NeuralNets; +using TorchSharp; +using TorchSharp.Modules; + +namespace Bonsai.ML.Torch.LDS; + +[Combinator] +[ResetCombinator] +[Description("")] +[WorkflowElementCategory(ElementCategory.Transform)] +public class Smooth +{ + [TypeConverter(typeof(KalmanFilterNameConverter))] + public string ModelName { get; set; } = "KalmanFilter"; + + public IObservable Process(IObservable source) + { + return source.Select((input) => + { + var kalmanFilter = KalmanFilterModelManager.GetKalmanFilter(ModelName); + using (KalmanFilterModelManager.Read(kalmanFilter)) + { + return kalmanFilter.Smooth(input); + } + }); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch.LDS/SmoothedResult.cs b/src/Bonsai.ML.Torch.LDS/SmoothedResult.cs new file mode 100644 index 00000000..1e45cd80 --- /dev/null +++ b/src/Bonsai.ML.Torch.LDS/SmoothedResult.cs @@ -0,0 +1,44 @@ +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.LDS; + +/// +/// Represents the result of a Kalman smoother step. +/// +/// +/// +/// +/// +/// +public struct SmoothedResult( + Tensor smoothedState, + Tensor smoothedCovariance, + Tensor smoothedLagOneCovariance, + Tensor smoothedInitialState = null, + Tensor smoothedInitialCovariance = null) +{ + /// + /// The smoothed state after the smoothing step. + /// + public Tensor SmoothedState = smoothedState; + + /// + /// The smoothed covariance after the smoothing step. + /// + public Tensor SmoothedCovariance = smoothedCovariance; + + /// + /// The smoothed lag-one covariance after the smoothing step. + /// + public Tensor SmoothedLagOneCovariance = smoothedLagOneCovariance; + + /// + /// The smoothed initial state after the smoothing step. + /// + public Tensor SmoothedInitialState = smoothedInitialState; + + /// + /// The smoothed initial covariance after the smoothing step. + /// + public Tensor SmoothedInitialCovariance = smoothedInitialCovariance; +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch.LDS/UpdateParameters.cs b/src/Bonsai.ML.Torch.LDS/UpdateParameters.cs new file mode 100644 index 00000000..3286c2b9 --- /dev/null +++ b/src/Bonsai.ML.Torch.LDS/UpdateParameters.cs @@ -0,0 +1,34 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using System.Runtime.InteropServices; +using System.Xml.Serialization; +using Bonsai; +using Bonsai.ML.Torch; +using Bonsai.ML.Torch.NeuralNets; +using TorchSharp; +using TorchSharp.Modules; + +namespace Bonsai.ML.Torch.LDS; + +[Combinator] +[ResetCombinator] +[Description("Learn the parameters kalman filter using the batch EM update algorithm.")] +[WorkflowElementCategory(ElementCategory.Sink)] +public class UpdateParameters +{ + [TypeConverter(typeof(KalmanFilterNameConverter))] + public string ModelName { get; set; } = "KalmanFilter"; + + public IObservable Process(IObservable source) + { + return source.Do((input) => + { + var kalmanFilter = KalmanFilterModelManager.GetKalmanFilter(ModelName); + using (KalmanFilterModelManager.Write(kalmanFilter)) + { + kalmanFilter.UpdateParameters(input); + } + }); + } +} \ No newline at end of file From 5b88d0ba29b2a1cb4c6e237d0a303e71c780b27d Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 17 Sep 2025 13:16:32 +0100 Subject: [PATCH 06/92] Added test repo for `Torch.LDS` to compare output with Python implementation --- Bonsai.ML.sln | 8 +- .../Bonsai.ML.Torch.LDS.Tests.csproj | 27 +++ tests/Bonsai.ML.Torch.LDS.Tests/UnitTest1.cs | 170 ++++++++++++++++++ .../bootstrap_test_environment.py | 87 +++++++++ .../requirements.txt | 152 ++++++++++++++++ 5 files changed, 443 insertions(+), 1 deletion(-) create mode 100644 tests/Bonsai.ML.Torch.LDS.Tests/Bonsai.ML.Torch.LDS.Tests.csproj create mode 100644 tests/Bonsai.ML.Torch.LDS.Tests/UnitTest1.cs create mode 100644 tests/Bonsai.ML.Torch.LDS.Tests/bootstrap_test_environment.py create mode 100644 tests/Bonsai.ML.Torch.LDS.Tests/requirements.txt diff --git a/Bonsai.ML.sln b/Bonsai.ML.sln index e2b989d4..445012ab 100644 --- a/Bonsai.ML.sln +++ b/Bonsai.ML.sln @@ -1,4 +1,4 @@ -Microsoft Visual Studio Solution File, Format Version 12.00 +Microsoft Visual Studio Solution File, Format Version 12.00 # Visual Studio Version 17 VisualStudioVersion = 17.0.31903.59 MinimumVisualStudioVersion = 10.0.40219.1 @@ -42,6 +42,8 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Bonsai.ML.Tests.Utilities", EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Bonsai.ML.Torch.LDS", "src\Bonsai.ML.Torch.LDS\Bonsai.ML.Torch.LDS.csproj", "{41D4BEC7-AB1F-41E4-95FE-4DB23970FF4B}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Bonsai.ML.Torch.LDS.Tests", "tests\Bonsai.ML.Torch.LDS.Tests\Bonsai.ML.Torch.LDS.Tests.csproj", "{0B258929-0B07-4CE7-BE8D-A86BBC46AAD4}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -108,6 +110,10 @@ Global {41D4BEC7-AB1F-41E4-95FE-4DB23970FF4B}.Debug|Any CPU.Build.0 = Debug|Any CPU {41D4BEC7-AB1F-41E4-95FE-4DB23970FF4B}.Release|Any CPU.ActiveCfg = Release|Any CPU {41D4BEC7-AB1F-41E4-95FE-4DB23970FF4B}.Release|Any CPU.Build.0 = Release|Any CPU + {0B258929-0B07-4CE7-BE8D-A86BBC46AAD4}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {0B258929-0B07-4CE7-BE8D-A86BBC46AAD4}.Debug|Any CPU.Build.0 = Debug|Any CPU + {0B258929-0B07-4CE7-BE8D-A86BBC46AAD4}.Release|Any CPU.ActiveCfg = Release|Any CPU + {0B258929-0B07-4CE7-BE8D-A86BBC46AAD4}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE diff --git a/tests/Bonsai.ML.Torch.LDS.Tests/Bonsai.ML.Torch.LDS.Tests.csproj b/tests/Bonsai.ML.Torch.LDS.Tests/Bonsai.ML.Torch.LDS.Tests.csproj new file mode 100644 index 00000000..182878b4 --- /dev/null +++ b/tests/Bonsai.ML.Torch.LDS.Tests/Bonsai.ML.Torch.LDS.Tests.csproj @@ -0,0 +1,27 @@ + + + + net8.0 + enable + enable + + false + true + + + + + + + + + + + + + + + + + + diff --git a/tests/Bonsai.ML.Torch.LDS.Tests/UnitTest1.cs b/tests/Bonsai.ML.Torch.LDS.Tests/UnitTest1.cs new file mode 100644 index 00000000..f82b84c3 --- /dev/null +++ b/tests/Bonsai.ML.Torch.LDS.Tests/UnitTest1.cs @@ -0,0 +1,170 @@ +using Newtonsoft.Json; +using System; +using System.Diagnostics; +using System.IO; +using System.IO.Compression; +using System.Net.Http; +using System.Runtime.InteropServices; +using System.Threading.Tasks; +using Bonsai.ML.Tests; + +namespace Bonsai.ML.Torch.LDS.Tests; + +/// +/// Tests for the neural latents workflow. +/// +[TestClass] +public class NeuralLatentsWorkflowTest +{ + private readonly string basePath = Path.Combine(AppDomain.CurrentDomain.BaseDirectory); + + private static void RunProcess(string fileName, string fmtArg) + { + var start = new ProcessStartInfo + { + FileName = fileName, + Arguments = fmtArg, + RedirectStandardOutput = true, + RedirectStandardInput = true, + RedirectStandardError = true, + UseShellExecute = false, + CreateNoWindow = true, + }; + + using var process = new Process { StartInfo = start }; + process.Start(); + var output = process.StandardOutput.ReadToEnd(); + var error = process.StandardError.ReadToEnd(); + process.WaitForExit(); + + if (!string.IsNullOrEmpty(output)) + { + Console.WriteLine("Standard Output: "); + Console.WriteLine(output); + } + + if (!string.IsNullOrEmpty(error)) + { + Console.WriteLine("Standard Error: "); + Console.WriteLine(error); + } + } + + private static void DownloadData(string basePath) + { + string zipFileUrl = "https://zenodo.org/records/10879253/files/ReceptiveFieldSimpleCell.zip"; + string outputPath = Path.Combine(basePath, "data"); + + try + { + byte[] responseBytes; + using (var httpClient = new HttpClient()) + { + responseBytes = httpClient.GetByteArrayAsync(zipFileUrl).Result; + Console.WriteLine("File downloaded successfully."); + } + + using MemoryStream zipStream = new(responseBytes); + using ZipArchive zip = new(zipStream, ZipArchiveMode.Read); + zip.ExtractToDirectory(outputPath); + Console.WriteLine("File extracted successfully."); + } + catch (Exception ex) + { + Console.WriteLine($"An error occurred: {ex.Message}"); + } + } + + private static void RunPythonScript(string basePath) + { + var pythonExec = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) + ? "python" + : "python3"; + var scriptPath = Path.Combine(basePath, "bootstrap_test_environment.py"); + RunProcess(pythonExec, $"\"{scriptPath}\" {basePath}"); + + Console.WriteLine("Run python script finished."); + } + + private async Task RunBonsaiWorkflow(string basePath) + { + var currentDirectory = Environment.CurrentDirectory; + Environment.CurrentDirectory = basePath; + try + { + var workflowPath = Path.Combine(basePath, "receptive_field.bonsai"); + await WorkflowHelper.RunWorkflow( + workflowPath); + Console.WriteLine("Run bonsai workflow finished."); + } + finally { Environment.CurrentDirectory = currentDirectory; } + } + + private bool CompareJSONData(string basePath, double tolerance = 1e-9) + { + var originalFileName = Path.Combine(basePath, "original-receptivefield.json"); + var bonsaiFileName = Path.Combine(basePath, "bonsai-receptivefield.json"); + var pythonFileName = Path.Combine(basePath, "python-receptivefield.json"); + + var originalOutput = GetStateFromJson(originalFileName); + var bonsaiOutput = GetStateFromJson(bonsaiFileName); + var pythonOutput = GetStateFromJson(pythonFileName); + + try + { + for (int i = 0; i < bonsaiOutput.X.GetLength(0); i++) + { + for (int j = 0; j < bonsaiOutput.X.GetLength(1); j++) + { + if (Math.Abs(bonsaiOutput.X[i,j] - pythonOutput.X[i,j]) > tolerance || Math.Abs(originalOutput.X[i,j] - pythonOutput.X[i,j]) > tolerance) + { + Console.WriteLine($"Discrepency found comparing X at index ({i},{j}) with tolerance {tolerance}: bonsaiOutput = {bonsaiOutput.X[i,j]}, pythonOutput = {pythonOutput.X[i,j]}, originalOutput = {originalOutput.X[i,j]}."); + return false; + } + } + } + for (int i = 0; i < bonsaiOutput.P.GetLength(0); i++) + { + for (int j = 0; j < bonsaiOutput.P.GetLength(1); j++) + { + if (Math.Abs(bonsaiOutput.P[i,j] - pythonOutput.P[i,j]) > tolerance || Math.Abs(originalOutput.P[i,j] - pythonOutput.P[i,j]) > tolerance) + { + Console.WriteLine($"Discrepency found comparing P at index ({i},{j}) with tolerance {tolerance}: bonsaiOutput = {bonsaiOutput.P[i,j]}, pythonOutput = {pythonOutput.P[i,j]}, originalOutput = {originalOutput.P[i,j]}."); + return false; + } + } + } + } + catch + { + return false; + } + return true; + } + + /// + /// Setup for the test. + /// + [TestInitialize] + [DeploymentItem("bootstrap_test_environment.py")] + [DeploymentItem("receptive_field.py")] + [DeploymentItem("receptive_field.bonsai")] + [DeploymentItem("original-receptivefield.json")] + public async Task TestSetup() + { + Directory.CreateDirectory(basePath); + DownloadData(basePath); + RunPythonScript(basePath); + await RunBonsaiWorkflow(basePath); + } + + /// + /// Compares the results from the Python script and the Bonsai workflow. + /// + [TestMethod] + public void CompareResults() + { + var result = CompareJSONData(basePath); + Assert.IsTrue(result); + } +} diff --git a/tests/Bonsai.ML.Torch.LDS.Tests/bootstrap_test_environment.py b/tests/Bonsai.ML.Torch.LDS.Tests/bootstrap_test_environment.py new file mode 100644 index 00000000..5ef73288 --- /dev/null +++ b/tests/Bonsai.ML.Torch.LDS.Tests/bootstrap_test_environment.py @@ -0,0 +1,87 @@ +import sys +import os +import subprocess +import argparse + +def get_venv_path(): + return os.path.join(os.path.dirname(os.path.realpath(__file__)), '.venv') + +def get_pip_path(venv_path: str = None): + if venv_path is None: + venv_path = get_venv_path() + if sys.platform.startswith('linux'): + return os.path.join(venv_path, 'bin', 'pip') + else: + return os.path.join(venv_path, 'Scripts', 'pip.exe') + +def get_python_path(venv_path: str = None): + if venv_path is None: + venv_path = get_venv_path() + if sys.platform.startswith('linux'): + return os.path.join(venv_path, 'bin', 'python') + else: + return os.path.join(venv_path, 'Scripts', 'python.exe') + +def get_base_dir(base_dir: str = None): + # function to get the base directory + if base_dir is not None: + return base_dir + try: + return os.path.dirname(os.path.realpath(__file__)) + except: + return os.getcwd() + +def create_venv(parent_dir: str = None): + # function to create a virtual environment + if parent_dir is None: + parent_dir = os.path.dirname(os.path.realpath(__file__)) + venv_path = os.path.join(parent_dir, ".venv") + subprocess.check_call([sys.executable, "-m", "venv", venv_path]) + return venv_path + +def activate_venv(venv_path: str = None): + # function to activate the virtual environment + if venv_path is None: + venv_path = get_venv_path() + if sys.platform.startswith('linux'): + bin_path = os.path.join(venv_path, 'bin') + os.environ["PATH"] = os.pathsep.join([bin_path, *os.environ.get("PATH", "").split(os.pathsep)]) + sys.path.insert(0, os.path.join(venv_path, 'lib', f'python{sys.version_info.major}.{sys.version_info.minor}', 'site-packages')) + else: + bin_path = os.path.join(venv_path, 'Scripts') + os.environ["PATH"] = os.pathsep.join([bin_path, *os.environ.get("PATH", "").split(os.pathsep)]) + sys.path.insert(0, os.path.join(venv_path, 'Lib', 'site-packages')) + +def install(package: str, venv_path: str = None): + # function to install pip packages into a virtual environment + if venv_path is None: + venv_path = get_venv_path() + pip_path = get_pip_path(venv_path) + subprocess.check_call([pip_path, "install", package]) + +def install_requirements(requirements_file: str, venv_path: str = None): + # function to install pip packages from a requirements file into a virtual environment + if venv_path is None: + venv_path = get_venv_path() + pip_path = get_pip_path(venv_path) + subprocess.check_call([pip_path, "install", "-r", requirements_file]) + +parser = argparse.ArgumentParser() +parser.add_argument("base_dir", type=str, default=None) +args = parser.parse_args() + +base_dir = get_base_dir(args.base_dir) +venv_path = create_venv(base_dir) +activate_venv(venv_path) +install_requirements(os.path.join(base_dir, "requirements.txt"), venv_path) + +python_path = get_python_path(venv_path) + +script_path = os.path.join(base_dir, "receptive_field.py") +process = subprocess.Popen([python_path, script_path, base_dir, str(args.n_samples)]) +return_code = process.wait() + +if return_code == 0: + print("Script completed successfully.") +else: + print(f"Script exited with errors. Return code: {return_code}") \ No newline at end of file diff --git a/tests/Bonsai.ML.Torch.LDS.Tests/requirements.txt b/tests/Bonsai.ML.Torch.LDS.Tests/requirements.txt new file mode 100644 index 00000000..05d89b7e --- /dev/null +++ b/tests/Bonsai.ML.Torch.LDS.Tests/requirements.txt @@ -0,0 +1,152 @@ +acres==0.5.0 +aiobotocore==2.24.1 +aiohappyeyeballs==2.6.1 +aiohttp==3.12.15 +aioitertools==0.12.0 +aiosignal==1.4.0 +annotated-types==0.7.0 +arrow==1.3.0 +asciitree==0.3.3 +asttokens==3.0.0 +attrs==25.3.0 +bids-validator-deno==2.0.11 +bidsschematools==1.0.14 +blessed==1.21.0 +botocore==1.39.11 +certifi==2025.8.3 +cffi==1.17.1 +charset-normalizer==3.4.3 +ci-info==0.3.0 +click==8.1.8 +click-didyoumean==0.3.1 +comm==0.2.3 +cryptography==45.0.7 +dandi==0.71.3 +dandischema==0.11.1 +debugpy==1.8.16 +decorator==5.2.1 +deprecated==1.2.18 +dnspython==2.7.0 +email-validator==2.3.0 +etelemetry==0.3.1 +executing==2.2.1 +fasteners==0.20 +fastjsonschema==2.21.2 +filelock==3.19.1 +fqdn==1.5.1 +frozenlist==1.7.0 +fscacher==0.4.4 +fsspec==2025.9.0 +h5py==3.14.0 +hdmf==4.1.0 +hdmf-zarr==0.11.3 +humanize==4.13.0 +idna==3.10 +interleave==0.3.0 +ipykernel==6.30.1 +ipython==9.5.0 +ipython-pygments-lexers==1.1.1 +isodate==0.7.2 +isoduration==20.11.0 +jaraco-classes==3.4.0 +jaraco-context==6.0.1 +jaraco-functools==4.3.0 +jedi==0.19.2 +jeepney==0.9.0 +jinja2==3.1.6 +jmespath==1.0.1 +joblib==1.5.2 +jsonpointer==3.0.0 +jsonschema==4.25.1 +jsonschema-specifications==2025.4.1 +jupyter-client==8.6.3 +jupyter-core==5.8.1 +keyring==25.6.0 +keyrings-alt==5.0.2 +markupsafe==3.0.2 +matplotlib-inline==0.1.7 +ml-dtypes==0.5.3 +more-itertools==10.8.0 +mpmath==1.3.0 +multidict==6.6.4 +narwhals==2.3.0 +natsort==8.4.0 +nbformat==5.10.4 +nest-asyncio==1.6.0 +networkx==3.5 +numcodecs==0.15.1 +numpy==2.3.2 +nvidia-cublas-cu12==12.8.4.1 +nvidia-cuda-cupti-cu12==12.8.90 +nvidia-cuda-nvrtc-cu12==12.8.93 +nvidia-cuda-runtime-cu12==12.8.90 +nvidia-cudnn-cu12==9.10.2.21 +nvidia-cufft-cu12==11.3.3.83 +nvidia-cufile-cu12==1.13.1.3 +nvidia-curand-cu12==10.3.9.90 +nvidia-cusolver-cu12==11.7.3.90 +nvidia-cusparse-cu12==12.5.8.93 +nvidia-cusparselt-cu12==0.7.1 +nvidia-nccl-cu12==2.27.3 +nvidia-nvjitlink-cu12==12.8.93 +nvidia-nvtx-cu12==12.8.90 +nwbinspector==0.6.5 +packaging==25.0 +pandas==2.3.2 +parso==0.8.5 +pexpect==4.9.0 +platformdirs==4.4.0 +plotly==6.3.0 +prompt-toolkit==3.0.52 +propcache==0.3.2 +psutil==7.0.0 +ptyprocess==0.7.0 +pure-eval==0.2.3 +pycparser==2.22 +pycryptodomex==3.23.0 +pydantic==2.11.7 +pydantic-core==2.33.2 +pygments==2.19.2 +pynwb==3.1.2 +pyout==0.8.1 +python-dateutil==2.9.0.post0 +pytz==2025.2 +pyyaml==6.0.2 +pyzmq==27.0.2 +referencing==0.36.2 +remfile==0.1.13 +requests==2.32.5 +rfc3339-validator==0.1.4 +rfc3987==1.3.8 +rpds-py==0.27.1 +ruamel-yaml==0.18.15 +ruamel-yaml-clib==0.2.12 +s3fs==2025.9.0 +scipy==1.16.1 +secretstorage==3.3.3 +semantic-version==2.10.0 +setuptools==80.9.0 +six==1.17.0 +-e file:///home/nicholas/lds_python +stack-data==0.6.3 +sympy==1.14.0 +tenacity==9.1.2 +tensorstore==0.1.76 +threadpoolctl==3.6.0 +torch==2.8.0 +tornado==6.5.2 +tqdm==4.67.1 +traitlets==5.14.3 +triton==3.4.0 +types-python-dateutil==2.9.0.20250822 +typing-extensions==4.15.0 +typing-inspection==0.4.1 +tzdata==2025.2 +uri-template==1.3.0 +urllib3==2.5.0 +wcwidth==0.2.13 +webcolors==24.11.1 +wrapt==1.17.3 +yarl==1.20.1 +zarr==2.18.7 +zarr-checksum==0.4.7 From 9f740cb9e108b3e9f29bdc70b2ccc719f85f7c33 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 17 Sep 2025 14:31:11 +0100 Subject: [PATCH 07/92] Updated Torch.LDS test for running test against Python script for estimating neural latents --- .../Bonsai.ML.Torch.LDS.Tests.csproj | 9 +- .../NeuralLatentsTest.bonsai | 806 ++++++++++++++++++ .../NeuralLatentsTest.cs | 121 +++ tests/Bonsai.ML.Torch.LDS.Tests/UnitTest1.cs | 170 ---- .../bootstrap_test_environment.py | 4 +- .../estimate_neural_latents.py | 120 +++ 6 files changed, 1056 insertions(+), 174 deletions(-) create mode 100644 tests/Bonsai.ML.Torch.LDS.Tests/NeuralLatentsTest.bonsai create mode 100644 tests/Bonsai.ML.Torch.LDS.Tests/NeuralLatentsTest.cs delete mode 100644 tests/Bonsai.ML.Torch.LDS.Tests/UnitTest1.cs create mode 100644 tests/Bonsai.ML.Torch.LDS.Tests/estimate_neural_latents.py diff --git a/tests/Bonsai.ML.Torch.LDS.Tests/Bonsai.ML.Torch.LDS.Tests.csproj b/tests/Bonsai.ML.Torch.LDS.Tests/Bonsai.ML.Torch.LDS.Tests.csproj index 182878b4..af262cc7 100644 --- a/tests/Bonsai.ML.Torch.LDS.Tests/Bonsai.ML.Torch.LDS.Tests.csproj +++ b/tests/Bonsai.ML.Torch.LDS.Tests/Bonsai.ML.Torch.LDS.Tests.csproj @@ -4,7 +4,6 @@ net8.0 enable enable - false true @@ -21,7 +20,13 @@ - + + Always + + + + + diff --git a/tests/Bonsai.ML.Torch.LDS.Tests/NeuralLatentsTest.bonsai b/tests/Bonsai.ML.Torch.LDS.Tests/NeuralLatentsTest.bonsai new file mode 100644 index 00000000..43d1b506 --- /dev/null +++ b/tests/Bonsai.ML.Torch.LDS.Tests/NeuralLatentsTest.bonsai @@ -0,0 +1,806 @@ + + + + + + + CUDA + -1 + + + + CUDA + + + LoadData + + + + + ../data/transformed_binned_spikes.bin + 0 + 0 + 142 + 0 + F64 + RowMajor + + + + + + + + + CUDA + + + + + + + + + + + + Float32 + + + + + + -1 + 142 + + + + + ObservationT + + + + + ../data/stop_times.bin + 0 + 0 + 1 + 0 + F64 + RowMajor + + + + + + + + + CUDA + + + + + + + + + + + + Float32 + + + + + + -1 + 1 + + + + + TrialEnd + + + + + ../data/JoaquinModelParameters/covariance.bin + 0 + 0 + 10 + 0 + F64 + RowMajor + + + + + + + + + CUDA + + + + + + + + + + + + Float32 + + + + + + 10 + 10 + + + + + Covariance + + + + + ../data/JoaquinModelParameters/state.bin + 0 + 0 + 10 + 0 + F64 + RowMajor + + + + + + + + + CUDA + + + + + + + + + + + + Float32 + + + + + + 10 + + + + + State + + + + + ../data/JoaquinModelParameters/measurementFunction.bin + 0 + 0 + 10 + 0 + F64 + RowMajor + + + + + + + + + CUDA + + + + + + + + + + + + Float32 + + + + + + 142 + 10 + + + + + MeasurementFunction + + + + + ../data/JoaquinModelParameters/measurementNoiseCovariance.bin + 0 + 0 + 142 + 0 + F64 + RowMajor + + + + + + + + + CUDA + + + + + + + + + + + + Float32 + + + + + + 142 + 142 + + + + + MeasurementNoiseCovariance + + + + + ../data/JoaquinModelParameters/transitionMatrix.bin + 0 + 0 + 10 + 0 + F64 + RowMajor + + + + + + + + + CUDA + + + + + + + + + + + + Float32 + + + + + + 10 + 10 + + + + + TransitionMatrix + + + + + ../data/JoaquinModelParameters/processNoiseCovariance.bin + 0 + 0 + 10 + 0 + F64 + RowMajor + + + + + + + + + CUDA + + + + + + + + + + + + Float32 + + + + + + 10 + 10 + + + + + ProcessNoiseCovariance + + + + + ../data/bin_centers.bin + 0 + 0 + 1 + 0 + F64 + RowMajor + + + + + + + + + CUDA + + + + + + + + + + + + Float32 + + + + + + -1 + 1 + + + + + Time + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + LoadModel + + + + TransitionMatrix + + + MeasurementFunction + + + ProcessNoiseCovariance + + + MeasurementNoiseCovariance + + + State + + + Covariance + + + + + + + + + + + + + + + + + + + CUDA + + + + + + + + + 10 + 142 + + + + KalmanFilterModel + + + + + + + + + + + + + + + + + + + + + + LearnParameters + + + + ObservationT + + + KalmanFilterModel + + + + + + + + + 5 + 0.1 + true + + + + ExpectationMaximizationResult + + + KalmanFilterModel + + + + + + ExpectationMaximizationResult + + + LogLikelihood + + + + 0 + + + + Float32 + + + + 0 + + + + + + + + + + + + + + + + + + + + + UpdateParameters + + + + ExpectationMaximizationResult + + + Parameters + + + KalmanFilterModel + + + + + + + + + + + ParametersUpdated + + + + + + + + + + + + + + + Filter + + + + ObservationT + + + KalmanFilterModel + + + + + + + + + + + UpdatedFilteredResult + + + ParametersUpdated + + + + + + UpdatedFilteredResult + + + KalmanFilterModel + + + + + + + + + + + UpdatedSmoothedResult + + + ParametersUpdated + + + + + + UpdatedSmoothedResult + + + KalmanFilterModel + + + + + + + + + + + OrthogonalizedResult + + + ParametersUpdated + + + + + + + PT0S + PT0.01S + + + + OrthogonalizedResult + + + + + + TimerSource + + + OrthogonalizedResult + + + OrthogonalizedState + + + + 5000:5500,0 + + + + + 0 + + + + Float32 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/tests/Bonsai.ML.Torch.LDS.Tests/NeuralLatentsTest.cs b/tests/Bonsai.ML.Torch.LDS.Tests/NeuralLatentsTest.cs new file mode 100644 index 00000000..426307b9 --- /dev/null +++ b/tests/Bonsai.ML.Torch.LDS.Tests/NeuralLatentsTest.cs @@ -0,0 +1,121 @@ +using Newtonsoft.Json; +using System; +using System.Diagnostics; +using System.IO; +using System.IO.Compression; +using System.Net.Http; +using System.Runtime.InteropServices; +using System.Threading.Tasks; +using Bonsai.ML.Tests.Utilities; + +namespace Bonsai.ML.Torch.LDS.Tests; + +/// +/// Tests for the neural latents workflow. +/// +[TestClass] +public class NeuralLatentsTest +{ + private readonly string basePath = Path.Combine(AppDomain.CurrentDomain.BaseDirectory); + + private static void RunPythonScript(string basePath) + { + var pythonExec = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) + ? "python" + : "python3"; + var scriptPath = Path.Combine(basePath, "bootstrap_test_environment.py"); + ProcessHelper.RunProcess(pythonExec, $"\"{scriptPath}\" {basePath}"); + + Console.WriteLine("Run python script finished."); + } + + private static async Task RunBonsaiWorkflow(string basePath) + { + var currentDirectory = Environment.CurrentDirectory; + Environment.CurrentDirectory = basePath; + try + { + var workflowPath = Path.Combine(basePath, "NeuralLatentsTest.bonsai"); + await WorkflowHelper.RunWorkflow( + workflowPath); + Console.WriteLine("Run bonsai workflow finished."); + } + finally { Environment.CurrentDirectory = currentDirectory; } + } + + private static double[] ReadBinaryFile(string fileName) + { + using var fileStream = new FileStream(fileName, FileMode.Open, FileAccess.Read); + using var binaryReader = new BinaryReader(fileStream); + var fileLength = fileStream.Length; + var numDoubles = fileLength / sizeof(double); + var data = new double[numDoubles]; + for (int i = 0; i < numDoubles; i++) + { + data[i] = binaryReader.ReadDouble(); + } + return data; + } + + private static bool CompareBinaryData(string basePath, double tolerance = 1e-4) + { + var bonsaiMeansFileName = Path.Combine(basePath, "bonsai_means.bin"); + var bonsaiCovariancesFileName = Path.Combine(basePath, "bonsai_covs.bin"); + + var pythonMeansFileName = Path.Combine(basePath, "python_means.bin"); + var pythonCovariancesFileName = Path.Combine(basePath, "python_covs.bin"); + + var bonsaiMeans = ReadBinaryFile(bonsaiMeansFileName); + var bonsaiCovariances = ReadBinaryFile(bonsaiCovariancesFileName); + var pythonMeans = ReadBinaryFile(pythonMeansFileName); + var pythonCovariances = ReadBinaryFile(pythonCovariancesFileName); + + if (bonsaiMeans.Length != pythonMeans.Length || + bonsaiCovariances.Length != pythonCovariances.Length) + { + return false; + } + + for (int i = 0; i < bonsaiMeans.Length; i++) + { + if (Math.Abs(bonsaiMeans[i] - pythonMeans[i]) > tolerance) + { + return false; + } + } + + for (int i = 0; i < bonsaiCovariances.Length; i++) + { + if (Math.Abs(bonsaiCovariances[i] - pythonCovariances[i]) > tolerance) + { + return false; + } + } + + return true; + } + + /// + /// Setup for the test. + /// + [TestInitialize] + [DeploymentItem("bootstrap_test_environment.py")] + [DeploymentItem("estimate_neural_latents.py")] + [DeploymentItem("NeuralLatentsTest.bonsai")] + public async Task TestSetup() + { + Directory.CreateDirectory(basePath); + RunPythonScript(basePath); + await RunBonsaiWorkflow(basePath); + } + + /// + /// Compares the results from the Python script and the Bonsai workflow. + /// + [TestMethod] + public void CompareResults() + { + var result = CompareBinaryData(basePath); + Assert.IsTrue(result); + } +} diff --git a/tests/Bonsai.ML.Torch.LDS.Tests/UnitTest1.cs b/tests/Bonsai.ML.Torch.LDS.Tests/UnitTest1.cs deleted file mode 100644 index f82b84c3..00000000 --- a/tests/Bonsai.ML.Torch.LDS.Tests/UnitTest1.cs +++ /dev/null @@ -1,170 +0,0 @@ -using Newtonsoft.Json; -using System; -using System.Diagnostics; -using System.IO; -using System.IO.Compression; -using System.Net.Http; -using System.Runtime.InteropServices; -using System.Threading.Tasks; -using Bonsai.ML.Tests; - -namespace Bonsai.ML.Torch.LDS.Tests; - -/// -/// Tests for the neural latents workflow. -/// -[TestClass] -public class NeuralLatentsWorkflowTest -{ - private readonly string basePath = Path.Combine(AppDomain.CurrentDomain.BaseDirectory); - - private static void RunProcess(string fileName, string fmtArg) - { - var start = new ProcessStartInfo - { - FileName = fileName, - Arguments = fmtArg, - RedirectStandardOutput = true, - RedirectStandardInput = true, - RedirectStandardError = true, - UseShellExecute = false, - CreateNoWindow = true, - }; - - using var process = new Process { StartInfo = start }; - process.Start(); - var output = process.StandardOutput.ReadToEnd(); - var error = process.StandardError.ReadToEnd(); - process.WaitForExit(); - - if (!string.IsNullOrEmpty(output)) - { - Console.WriteLine("Standard Output: "); - Console.WriteLine(output); - } - - if (!string.IsNullOrEmpty(error)) - { - Console.WriteLine("Standard Error: "); - Console.WriteLine(error); - } - } - - private static void DownloadData(string basePath) - { - string zipFileUrl = "https://zenodo.org/records/10879253/files/ReceptiveFieldSimpleCell.zip"; - string outputPath = Path.Combine(basePath, "data"); - - try - { - byte[] responseBytes; - using (var httpClient = new HttpClient()) - { - responseBytes = httpClient.GetByteArrayAsync(zipFileUrl).Result; - Console.WriteLine("File downloaded successfully."); - } - - using MemoryStream zipStream = new(responseBytes); - using ZipArchive zip = new(zipStream, ZipArchiveMode.Read); - zip.ExtractToDirectory(outputPath); - Console.WriteLine("File extracted successfully."); - } - catch (Exception ex) - { - Console.WriteLine($"An error occurred: {ex.Message}"); - } - } - - private static void RunPythonScript(string basePath) - { - var pythonExec = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) - ? "python" - : "python3"; - var scriptPath = Path.Combine(basePath, "bootstrap_test_environment.py"); - RunProcess(pythonExec, $"\"{scriptPath}\" {basePath}"); - - Console.WriteLine("Run python script finished."); - } - - private async Task RunBonsaiWorkflow(string basePath) - { - var currentDirectory = Environment.CurrentDirectory; - Environment.CurrentDirectory = basePath; - try - { - var workflowPath = Path.Combine(basePath, "receptive_field.bonsai"); - await WorkflowHelper.RunWorkflow( - workflowPath); - Console.WriteLine("Run bonsai workflow finished."); - } - finally { Environment.CurrentDirectory = currentDirectory; } - } - - private bool CompareJSONData(string basePath, double tolerance = 1e-9) - { - var originalFileName = Path.Combine(basePath, "original-receptivefield.json"); - var bonsaiFileName = Path.Combine(basePath, "bonsai-receptivefield.json"); - var pythonFileName = Path.Combine(basePath, "python-receptivefield.json"); - - var originalOutput = GetStateFromJson(originalFileName); - var bonsaiOutput = GetStateFromJson(bonsaiFileName); - var pythonOutput = GetStateFromJson(pythonFileName); - - try - { - for (int i = 0; i < bonsaiOutput.X.GetLength(0); i++) - { - for (int j = 0; j < bonsaiOutput.X.GetLength(1); j++) - { - if (Math.Abs(bonsaiOutput.X[i,j] - pythonOutput.X[i,j]) > tolerance || Math.Abs(originalOutput.X[i,j] - pythonOutput.X[i,j]) > tolerance) - { - Console.WriteLine($"Discrepency found comparing X at index ({i},{j}) with tolerance {tolerance}: bonsaiOutput = {bonsaiOutput.X[i,j]}, pythonOutput = {pythonOutput.X[i,j]}, originalOutput = {originalOutput.X[i,j]}."); - return false; - } - } - } - for (int i = 0; i < bonsaiOutput.P.GetLength(0); i++) - { - for (int j = 0; j < bonsaiOutput.P.GetLength(1); j++) - { - if (Math.Abs(bonsaiOutput.P[i,j] - pythonOutput.P[i,j]) > tolerance || Math.Abs(originalOutput.P[i,j] - pythonOutput.P[i,j]) > tolerance) - { - Console.WriteLine($"Discrepency found comparing P at index ({i},{j}) with tolerance {tolerance}: bonsaiOutput = {bonsaiOutput.P[i,j]}, pythonOutput = {pythonOutput.P[i,j]}, originalOutput = {originalOutput.P[i,j]}."); - return false; - } - } - } - } - catch - { - return false; - } - return true; - } - - /// - /// Setup for the test. - /// - [TestInitialize] - [DeploymentItem("bootstrap_test_environment.py")] - [DeploymentItem("receptive_field.py")] - [DeploymentItem("receptive_field.bonsai")] - [DeploymentItem("original-receptivefield.json")] - public async Task TestSetup() - { - Directory.CreateDirectory(basePath); - DownloadData(basePath); - RunPythonScript(basePath); - await RunBonsaiWorkflow(basePath); - } - - /// - /// Compares the results from the Python script and the Bonsai workflow. - /// - [TestMethod] - public void CompareResults() - { - var result = CompareJSONData(basePath); - Assert.IsTrue(result); - } -} diff --git a/tests/Bonsai.ML.Torch.LDS.Tests/bootstrap_test_environment.py b/tests/Bonsai.ML.Torch.LDS.Tests/bootstrap_test_environment.py index 5ef73288..a842b565 100644 --- a/tests/Bonsai.ML.Torch.LDS.Tests/bootstrap_test_environment.py +++ b/tests/Bonsai.ML.Torch.LDS.Tests/bootstrap_test_environment.py @@ -77,8 +77,8 @@ def install_requirements(requirements_file: str, venv_path: str = None): python_path = get_python_path(venv_path) -script_path = os.path.join(base_dir, "receptive_field.py") -process = subprocess.Popen([python_path, script_path, base_dir, str(args.n_samples)]) +script_path = os.path.join(base_dir, "estimate_neural_latents.py") +process = subprocess.Popen([python_path, script_path, base_dir]) return_code = process.wait() if return_code == 0: diff --git a/tests/Bonsai.ML.Torch.LDS.Tests/estimate_neural_latents.py b/tests/Bonsai.ML.Torch.LDS.Tests/estimate_neural_latents.py new file mode 100644 index 00000000..fad7f533 --- /dev/null +++ b/tests/Bonsai.ML.Torch.LDS.Tests/estimate_neural_latents.py @@ -0,0 +1,120 @@ +import numpy as np +import plotly.graph_objects as go +import remfile, h5py + +from dandi.dandiapi import DandiAPIClient +from pynwb import NWBHDF5IO + +import ssm.inference +import ssm.learning +import ssm.neural_latents.utils +import ssm.neural_latents.plotting + +import argparse +import os + +# Parse arguments +parser = argparse.ArgumentParser() +parser.add_argument("base_dir", type=str, default=None) +args = parser.parse_args() + +base_dir = args.base_dir + +# data +dandiset_ID = "000140" +dandi_filepath = "sub-Jenkins/sub-Jenkins_ses-small_desc-train_behavior+ecephys.nwb" +bin_size = 0.02 + +# plot +events_names = ["start_time", "target_on_time", "go_cue_time", + "move_onset_time", "stop_time"] +events_linetypes = ["dot", "dash", "dashdot", "longdash", "solid"] +events_colors_spikes = ["white", "white", "white", "white", "white"] +events_colors_latents = ["black", "black", "black", "black", "black"] +cb_alpha = 0.3 +from_time = 100.0 +to_time = 130.0 + +# model +n_latents = 10 + +# estimation initial conditions +sigma_B = 0.1 +sigma_Z = 0.1 +sigma_Q = 0.1 +sigma_R = 0.1 +sigma_m0 = 0.1 +sigma_V0 = 0.1 + +# estimation parameters +max_iter = 5 +tol = 1e-1 +vars_to_estimate = {"B": True, "Q": True, "Z": True, "R": True, + "m0": True, "V0": True, } + +with DandiAPIClient() as client: + asset = client.get_dandiset(dandiset_ID, + "draft").get_asset_by_path(dandi_filepath) + s3_path = asset.get_content_url(follow_redirects=1, strip_query=True) + cache = remfile.DiskCache("./remfile_cache") + rf = remfile.File(s3_path, disk_cache=cache) + with h5py.File(rf, "r") as h: + with NWBHDF5IO(file=h, mode="r") as io: + nwbfile = io.read() + units_df = nwbfile.units.to_dataframe() + trials_df = nwbfile.intervals["trials"].to_dataframe() + +# n_clusters +n_clusters = units_df.shape[0] + +# continuous spikes times +continuous_spikes_times = [None for n in range(n_clusters)] +for n in range(n_clusters): + continuous_spikes_times[n] = units_df.iloc[n]['spike_times'] + +binned_spikes, bin_edges = ssm.neural_latents.utils.bin_spike_times( + spike_times=continuous_spikes_times, bin_size=bin_size) +bin_centers = (bin_edges[1:] + bin_edges[:-1])/2 +transformed_binned_spikes = np.sqrt(binned_spikes + 0.5) + +transformed_binned_spikes.astype(float).tofile(os.path.join(base_dir, "transformed_binned_spikes.bin")) + +np.random.seed(0) + +B0 = np.diag(np.random.normal(loc=0, scale=sigma_B, size=n_latents)) +Z0 = np.random.normal(loc=0, scale=sigma_Z, size=(n_clusters, n_latents)) +Q0 = np.diag(np.abs(np.random.normal(loc=0, scale=sigma_Q, size=n_latents))) +R0 = np.diag(np.abs(np.random.normal(loc=0, scale=sigma_R, size=n_clusters))) +m0_0 = np.random.normal(loc=0, scale=sigma_m0, size=n_latents) +V0_0 = np.diag(np.abs(np.random.normal(loc=0, scale=sigma_V0, size=n_latents))) + +# Save initial parameters to binary file +B0.astype(float).tofile(os.path.join(base_dir, "python_B0.bin")) +Z0.astype(float).tofile(os.path.join(base_dir, "python_Z0.bin")) +Q0.astype(float).tofile(os.path.join(base_dir, "python_Q0.bin")) +R0.astype(float).tofile(os.path.join(base_dir, "python_R0.bin")) +m0_0.astype(float).tofile(os.path.join(base_dir, "python_m0_0.bin")) +V0_0.astype(float).tofile(os.path.join(base_dir, "python_V0_0.bin")) + +optim_res = ssm.learning.em_SS_LDS( + y=transformed_binned_spikes, B0=B0, Q0=Q0, Z0=Z0, R0=R0, + m0_0=m0_0, V0_0=V0_0, max_iter=max_iter, tol=tol, + vars_to_estimate=vars_to_estimate, +) + +filter_res = ssm.inference.filterLDS_SS_withMissingValues_np( + y=transformed_binned_spikes, B=optim_res["B"], Q=optim_res["Q"], + m0=optim_res["m0"], V0=optim_res["V0"], Z=optim_res["Z"], R=optim_res["R"]) + +smoothing_res = ssm.inference.smoothLDS_SS( + B=optim_res["B"], xnn=filter_res["xnn"], Pnn=filter_res["Pnn"], + xnn1=filter_res["xnn1"], Pnn1=filter_res["Pnn1"], + m0=optim_res["m0"], V0=optim_res["V0"]) + +o_means, o_covs = ssm.neural_latents.utils.ortogonalizeMeansAndCovs( + means=smoothing_res["xnN"], + covs=smoothing_res["PnN"], Z=optim_res["Z"]) + +# save outputs to binary file +o_means.astype(float).tofile(os.path.join(base_dir, "python_means.bin")) +o_covs.astype(float).tofile(os.path.join(base_dir, "python_covs.bin")) From 22f73763f747ea38a53751ead07cf6d01e0fb5f6 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 17 Sep 2025 14:32:41 +0100 Subject: [PATCH 08/92] Added null checks to properties --- .../CreateKalmanFilterParameters.cs | 33 +++++++++++-------- 1 file changed, 19 insertions(+), 14 deletions(-) diff --git a/src/Bonsai.ML.Torch.LDS/CreateKalmanFilterParameters.cs b/src/Bonsai.ML.Torch.LDS/CreateKalmanFilterParameters.cs index 77aa932a..e8d6ca92 100644 --- a/src/Bonsai.ML.Torch.LDS/CreateKalmanFilterParameters.cs +++ b/src/Bonsai.ML.Torch.LDS/CreateKalmanFilterParameters.cs @@ -41,7 +41,7 @@ public ScalarType Type public Tensor TransitionMatrix { get => _transitionMatrix; - set => _transitionMatrix = value.to_type(Type); + set => _transitionMatrix = value?.to_type(Type); } /// @@ -61,10 +61,11 @@ public string TransitionMatrixXml /// The measurement function. /// [XmlIgnore] + [TypeConverter(typeof(TensorConverter))] public Tensor MeasurementFunction { get => _measurementFunction; - set => _measurementFunction = value.to_type(Type); + set => _measurementFunction = value?.to_type(Type); } /// @@ -84,10 +85,11 @@ public string MeasurementFunctionXml /// The process noise variance. /// [XmlIgnore] + [TypeConverter(typeof(TensorConverter))] public Tensor ProcessNoiseCovariance { get => _processNoiseCovariance; - set => _processNoiseCovariance = value.to_type(Type); + set => _processNoiseCovariance = value?.to_type(Type); } /// @@ -107,10 +109,11 @@ public string ProcessNoiseCovarianceXml /// The measurement noise covariance matrix. /// [XmlIgnore] + [TypeConverter(typeof(TensorConverter))] public Tensor MeasurementNoiseCovariance { get => _measurementNoiseCovariance; - set => _measurementNoiseCovariance = value.to_type(Type); + set => _measurementNoiseCovariance = value?.to_type(Type); } /// @@ -125,15 +128,16 @@ public string MeasurementNoiseCovarianceXml set => MeasurementNoiseCovariance = TensorConverter.ConvertFromString(value, _scalarType); } - private Tensor _initialState; + private Tensor _initialState = null; /// /// The initial state. /// [XmlIgnore] + [TypeConverter(typeof(TensorConverter))] public Tensor InitialState { get => _initialState; - set => _initialState = value.to_type(Type); + set => _initialState = value?.to_type(Type); } /// @@ -148,15 +152,16 @@ public string InitialStateXml set => InitialState = TensorConverter.ConvertFromString(value, _scalarType); } - private Tensor _initialCovariance; + private Tensor _initialCovariance = null; /// /// The initial covariance. /// [XmlIgnore] + [TypeConverter(typeof(TensorConverter))] public Tensor InitialCovariance { get => _initialCovariance; - set => _initialCovariance = value.to_type(Type); + set => _initialCovariance = value?.to_type(Type); } /// @@ -173,12 +178,12 @@ public string InitialCovarianceXml private void ConvertTensorsScalarType(ScalarType scalarType) { - _transitionMatrix = _transitionMatrix.to_type(scalarType); - _measurementFunction = _measurementFunction.to_type(scalarType); - _processNoiseCovariance = _processNoiseCovariance.to_type(scalarType); - _measurementNoiseCovariance = _measurementNoiseCovariance.to_type(scalarType); - _initialState = _initialState.to_type(scalarType); - _initialCovariance = _initialCovariance.to_type(scalarType); + _transitionMatrix = _transitionMatrix?.to_type(scalarType); + _measurementFunction = _measurementFunction?.to_type(scalarType); + _processNoiseCovariance = _processNoiseCovariance?.to_type(scalarType); + _measurementNoiseCovariance = _measurementNoiseCovariance?.to_type(scalarType); + _initialState = _initialState?.to_type(scalarType); + _initialCovariance = _initialCovariance?.to_type(scalarType); } /// From 1740570d5f9a0d6ccd2deae48af6279af0ada1cf Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Thu, 18 Sep 2025 21:37:02 +0100 Subject: [PATCH 09/92] Updated target framework to include netstandard2.0 --- src/Bonsai.ML.Torch.LDS/Bonsai.ML.Torch.LDS.csproj | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Bonsai.ML.Torch.LDS/Bonsai.ML.Torch.LDS.csproj b/src/Bonsai.ML.Torch.LDS/Bonsai.ML.Torch.LDS.csproj index 7f755850..30bf344b 100644 --- a/src/Bonsai.ML.Torch.LDS/Bonsai.ML.Torch.LDS.csproj +++ b/src/Bonsai.ML.Torch.LDS/Bonsai.ML.Torch.LDS.csproj @@ -2,7 +2,7 @@ Bonsai.ML.Torch.LDS Bonsai library. Bonsai Rx Bonsai ML Torch TorchSharp LDS LinearDynamicalSystems - net472 + net472;netstandard2.0 From d7dc4a12b00f060671afd756789e14d7b7318f22 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Thu, 18 Sep 2025 21:38:12 +0100 Subject: [PATCH 10/92] Updated `CreateKalmanFilter` class to use `TensorConverter` for tensor properties --- src/Bonsai.ML.Torch.LDS/CreateKalmanFilter.cs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/Bonsai.ML.Torch.LDS/CreateKalmanFilter.cs b/src/Bonsai.ML.Torch.LDS/CreateKalmanFilter.cs index 82222b5c..6d512477 100644 --- a/src/Bonsai.ML.Torch.LDS/CreateKalmanFilter.cs +++ b/src/Bonsai.ML.Torch.LDS/CreateKalmanFilter.cs @@ -102,6 +102,7 @@ public string TransitionMatrixXml /// The measurement function. /// [XmlIgnore] + [TypeConverter(typeof(TensorConverter))] public Tensor MeasurementFunction { get => _measurementFunction; @@ -125,6 +126,7 @@ public string MeasurementFunctionXml /// The process noise variance. /// [XmlIgnore] + [TypeConverter(typeof(TensorConverter))] public Tensor ProcessNoiseVariance { get => _processNoiseVariance; @@ -148,6 +150,7 @@ public string ProcessNoiseVarianceXml /// The measurement noise variance. /// [XmlIgnore] + [TypeConverter(typeof(TensorConverter))] public Tensor MeasurementNoiseVariance { get => _measurementNoiseVariance; @@ -171,6 +174,7 @@ public string MeasurementNoiseVarianceXml /// The initial state. /// [XmlIgnore] + [TypeConverter(typeof(TensorConverter))] public Tensor InitialState { get => _initialState; @@ -194,6 +198,7 @@ public string InitialStateXml /// The initial covariance. /// [XmlIgnore] + [TypeConverter(typeof(TensorConverter))] public Tensor InitialCovariance { get => _initialCovariance; From b4939f181d204991b7b758aface85ddc6a75c286 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Thu, 18 Sep 2025 21:38:48 +0100 Subject: [PATCH 11/92] Refactored to use the KalmanFilterModelManager class --- src/Bonsai.ML.Torch.LDS/CreateKalmanFilter.cs | 51 +++++++++++-------- 1 file changed, 30 insertions(+), 21 deletions(-) diff --git a/src/Bonsai.ML.Torch.LDS/CreateKalmanFilter.cs b/src/Bonsai.ML.Torch.LDS/CreateKalmanFilter.cs index 6d512477..6f56cf68 100644 --- a/src/Bonsai.ML.Torch.LDS/CreateKalmanFilter.cs +++ b/src/Bonsai.ML.Torch.LDS/CreateKalmanFilter.cs @@ -222,31 +222,40 @@ public string InitialCovarianceXml /// public IObservable Process() { - return Observable.Return( - new KalmanFilter( - numStates: NumStates, - numObservations: NumObservations, - transitionMatrix: _transitionMatrix, - measurementFunction: _measurementFunction, - initialState: _initialState, - initialCovariance: _initialCovariance, - processNoiseVariance: _processNoiseVariance, - measurementNoiseVariance: _measurementNoiseVariance, - device: Device, - scalarType: Type - )); - } - + return Observable.Using(() => KalmanFilterModelManager.Reserve( + ModelName, + NumStates, + NumObservations, + _transitionMatrix, + _measurementFunction, + _initialState, + _initialCovariance, + _processNoiseVariance, + _measurementNoiseVariance, + Device, + Type + ), resource => Observable.Return(resource.Model) + .Concat(Observable.Never(resource.Model)) + .Finally(resource.Dispose) + ); + } + /// /// Creates a Kalman filter model using the parameters provided in the input sequence. /// public IObservable Process(IObservable source) { - return source.Select(parameters => - new KalmanFilter( - parameters: parameters, - device: Device, - scalarType: Type - )); + return source.SelectMany(parameters => + { + return Observable.Using(() => KalmanFilterModelManager.Reserve( + ModelName, + parameters, + Device, + Type + ), resource => Observable.Return(resource.Model) + .Concat(Observable.Never(resource.Model)) + .Finally(resource.Dispose) + ); + }); } } From 3adfccf856263348391eb0bfc6d3d8bacaecd63f Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Thu, 18 Sep 2025 21:39:59 +0100 Subject: [PATCH 12/92] Removed attempt to move EM algorithm to background process in favor of keeping it synchronous --- .../ExpectationMaximization.cs | 94 ++++++++----------- 1 file changed, 41 insertions(+), 53 deletions(-) diff --git a/src/Bonsai.ML.Torch.LDS/ExpectationMaximization.cs b/src/Bonsai.ML.Torch.LDS/ExpectationMaximization.cs index c27e4613..72a518b1 100644 --- a/src/Bonsai.ML.Torch.LDS/ExpectationMaximization.cs +++ b/src/Bonsai.ML.Torch.LDS/ExpectationMaximization.cs @@ -59,70 +59,58 @@ public bool Verbose public IObservable Process(IObservable source) { - return source.SelectMany(input => + return source.Select(input => { var model = KalmanFilterModelManager.GetKalmanFilter(ModelName); - return Observable.FromAsync(cancellationToken => + var previousLogLikelihood = double.NegativeInfinity; + var logLikelihood = torch.zeros(new long[] { MaxIterations }, device: input.device); + + for (int i = 0; i < MaxIterations; i++) { - return Task.Run(() => + ExpectationMaximizationResult result; + using (KalmanFilterModelManager.Read(model)) { - var previousLogLikelihood = double.NegativeInfinity; - var logLikelihood = torch.zeros(new long[] { MaxIterations }, device: input.device); - - for (int i = 0; i < MaxIterations; i++) - { - if (cancellationToken.IsCancellationRequested) - { - break; - } + result = model.ExpectationMaximization(input, 1, Tolerance, false); + } - ExpectationMaximizationResult result; - using (KalmanFilterModelManager.Read(model)) - { - result = model.ExpectationMaximization(input, 1, Tolerance, false); - } + var logLikelihoodSum = result.LogLikelihood + .cpu() + .to_type(torch.ScalarType.Float32) + .ReadCpuSingle(0); - var logLikelihoodSum = result.LogLikelihood - .cpu() - .to_type(torch.ScalarType.Float32) - .ReadCpuSingle(0); + logLikelihood[i] = logLikelihoodSum; - logLikelihood[i] = logLikelihoodSum; - - if (Verbose) - { - Console.WriteLine("Iteration " + (i + 1) + ", Log Likelihood: " + logLikelihoodSum); - if (i == MaxIterations - 1) - { - Console.WriteLine("EM reached the maximum number of iterations."); - } - } - - if (logLikelihoodSum - previousLogLikelihood < Tolerance) - { - if (Verbose) - { - Console.WriteLine("EM converged after " + (i + 1) + " iterations."); - } - logLikelihood = logLikelihood[torch.TensorIndex.Slice(0, i + 1)]; - break; - } - previousLogLikelihood = logLikelihoodSum; + if (Verbose) + { + Console.WriteLine("Iteration " + (i + 1) + ", Log Likelihood: " + logLikelihoodSum); + if (i == MaxIterations - 1) + { + Console.WriteLine("EM reached the maximum number of iterations."); + } + } - using (KalmanFilterModelManager.Write(model)) - { - model.UpdateParameters(result.Parameters); - } + if (logLikelihoodSum - previousLogLikelihood < Tolerance) + { + if (Verbose) + { + Console.WriteLine("EM converged after " + (i + 1) + " iterations."); } + logLikelihood = logLikelihood[torch.TensorIndex.Slice(0, i + 1)]; + break; + } + previousLogLikelihood = logLikelihoodSum; + + using (KalmanFilterModelManager.Write(model)) + { + model.UpdateParameters(result.Parameters); + } + } - var expectationMaximizationResult = new ExpectationMaximizationResult( - logLikelihood, - model.Parameters); + var expectationMaximizationResult = new ExpectationMaximizationResult( + logLikelihood, + model.Parameters); - return expectationMaximizationResult; - }, - cancellationToken); - }); + return expectationMaximizationResult; }); } } \ No newline at end of file From 408e93d9936b443394f1254c6d4c619fc9ec904f Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Thu, 18 Sep 2025 21:42:41 +0100 Subject: [PATCH 13/92] Removed extra properties of structs to streamline filter and smooth processes during KalmanFilterClass refector --- src/Bonsai.ML.Torch.LDS/FilteredResult.cs | 18 ++---------------- src/Bonsai.ML.Torch.LDS/SmoothedResult.cs | 9 +-------- 2 files changed, 3 insertions(+), 24 deletions(-) diff --git a/src/Bonsai.ML.Torch.LDS/FilteredResult.cs b/src/Bonsai.ML.Torch.LDS/FilteredResult.cs index 51e4d9d7..6cd4285d 100644 --- a/src/Bonsai.ML.Torch.LDS/FilteredResult.cs +++ b/src/Bonsai.ML.Torch.LDS/FilteredResult.cs @@ -3,21 +3,17 @@ namespace Bonsai.ML.Torch.LDS; /// -/// Represents the result of a Kalman filter update step. +/// Represents the result of a Kalman filter. /// /// /// /// /// -/// -/// public struct FilteredResult( Tensor predictedState, Tensor predictedCovariance, Tensor updatedState, - Tensor updatedCovariance, - Tensor logLikelihood, - Tensor kalmanGain) + Tensor updatedCovariance) { /// /// The predicted state after the prediction step. @@ -38,14 +34,4 @@ public struct FilteredResult( /// The updated covariance after the update step. /// public Tensor UpdatedCovariance = updatedCovariance; - - /// - /// The log likelihood of the measurement given the predicted state. - /// - public Tensor LogLikelihood = logLikelihood; - - /// - /// The Kalman gain. - /// - public Tensor KalmanGain = kalmanGain; } \ No newline at end of file diff --git a/src/Bonsai.ML.Torch.LDS/SmoothedResult.cs b/src/Bonsai.ML.Torch.LDS/SmoothedResult.cs index 1e45cd80..94230cba 100644 --- a/src/Bonsai.ML.Torch.LDS/SmoothedResult.cs +++ b/src/Bonsai.ML.Torch.LDS/SmoothedResult.cs @@ -3,17 +3,15 @@ namespace Bonsai.ML.Torch.LDS; /// -/// Represents the result of a Kalman smoother step. +/// Represents the result of a Kalman smoother. /// /// /// -/// /// /// public struct SmoothedResult( Tensor smoothedState, Tensor smoothedCovariance, - Tensor smoothedLagOneCovariance, Tensor smoothedInitialState = null, Tensor smoothedInitialCovariance = null) { @@ -27,11 +25,6 @@ public struct SmoothedResult( /// public Tensor SmoothedCovariance = smoothedCovariance; - /// - /// The smoothed lag-one covariance after the smoothing step. - /// - public Tensor SmoothedLagOneCovariance = smoothedLagOneCovariance; - /// /// The smoothed initial state after the smoothing step. /// From b4ad0adf0581b3ca862b417d4d69f606ba3ec589 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Thu, 18 Sep 2025 21:44:05 +0100 Subject: [PATCH 14/92] Refactored kalman filter class to support static methods, to streamline non-static filter and smooth procedures, and to use wrapped tensor objects instead of explicit dispose scopes --- src/Bonsai.ML.Torch.LDS/KalmanFilter.cs | 486 ++++++++++++++++++------ 1 file changed, 367 insertions(+), 119 deletions(-) diff --git a/src/Bonsai.ML.Torch.LDS/KalmanFilter.cs b/src/Bonsai.ML.Torch.LDS/KalmanFilter.cs index 21173eef..82a8b0c8 100644 --- a/src/Bonsai.ML.Torch.LDS/KalmanFilter.cs +++ b/src/Bonsai.ML.Torch.LDS/KalmanFilter.cs @@ -13,7 +13,6 @@ internal class KalmanFilter : nn.Module private readonly Tensor _processNoiseCovariance; private readonly Tensor _measurementNoiseCovariance; private readonly Tensor _identityStates; - private readonly Tensor _identityObservations; private readonly Tensor _state; private readonly Tensor _covariance; private readonly int _numStates; @@ -46,10 +45,9 @@ public KalmanFilter( ValidateAndSetMatrix(parameters.MeasurementNoiseCovariance, "Measurement noise covariance", _scalarType, _device, out _measurementNoiseCovariance, out _, out _, isSquare: true, expectedDimension1: _numObservations); _identityStates = eye(_numStates, dtype: _scalarType, device: _device); - _identityObservations = eye(_numObservations, dtype: _scalarType, device: _device); - _state = _initialState.clone(); - _covariance = _initialCovariance.clone(); + _state = empty(_numStates, dtype: _scalarType, device: _device).requires_grad_(false); + _covariance = empty([_numStates, _numStates], dtype: _scalarType, device: _device).requires_grad_(false); } public KalmanFilter( @@ -70,7 +68,6 @@ public KalmanFilter( _numObservations = numObservations; _identityStates = eye(_numStates, dtype: _scalarType, device: _device); - _identityObservations = eye(_numObservations, dtype: _scalarType, device: _device); _transitionMatrix = transitionMatrix?.clone().to_type(_scalarType).to(_device).requires_grad_(false) ?? eye(_numStates, dtype: _scalarType, device: _device); @@ -91,8 +88,8 @@ public KalmanFilter( _processNoiseCovariance = CreateCovarianceMatrix(processNoiseVariance, _scalarType, _device, numStates, "Process noise variance"); _measurementNoiseCovariance = CreateCovarianceMatrix(measurementNoiseVariance, _scalarType, _device, numObservations, "Measurement noise variance"); - _state = _initialState.clone(); - _covariance = _initialCovariance.clone(); + _state = empty(_numStates, dtype: _scalarType, device: _device).requires_grad_(false); + _covariance = empty([_numStates, _numStates], dtype: _scalarType, device: _device).requires_grad_(false); RegisterComponents(); } @@ -163,16 +160,29 @@ private Tensor CreateCovarianceMatrix(Tensor variance, ScalarType scalarType, De return (scalar * eye(dimension, dtype: scalarType, device: device)).requires_grad_(false); } - private readonly struct PredictedResult(Tensor predictedState, Tensor predictedCovariance) + private readonly struct PredictedResult( + Tensor predictedState, + Tensor predictedCovariance) { public readonly Tensor PredictedState = predictedState; public readonly Tensor PredictedCovariance = predictedCovariance; } - private PredictedResult FilterPredict(Tensor state, Tensor covariance) => - new(_transitionMatrix.matmul(state), - EnsureSymmetric(_transitionMatrix.matmul(covariance) - .matmul(_transitionMatrix.mT) + _processNoiseCovariance)); + private PredictedResult FilterPredict( + Tensor state, + Tensor covariance) => + new(_transitionMatrix.matmul(state), + _transitionMatrix.matmul(covariance) + .matmul(_transitionMatrix.mT) + _processNoiseCovariance); + + private static PredictedResult FilterPredict( + Tensor state, + Tensor covariance, + Tensor transitionMatrix, + Tensor processNoiseCovariance) => + new(transitionMatrix.matmul(state), + transitionMatrix.matmul(covariance) + .matmul(transitionMatrix.mT) + processNoiseCovariance); private readonly struct UpdatedResult( Tensor updatedState, @@ -188,67 +198,193 @@ private readonly struct UpdatedResult( public readonly Tensor KalmanGain = kalmanGain; } - private UpdatedResult FilterUpdate(Tensor predictedState, Tensor predictedCovariance, Tensor observation) + private UpdatedResult FilterUpdate( + Tensor predictedState, + Tensor predictedCovariance, + Tensor observation) { // Innovation step var innovation = observation - _measurementFunction.matmul(predictedState); - var innovationCovariance = EnsureSymmetric( + var innovationCovariance = WrappedTensorDisposeScope(() => EnsureSymmetric( _measurementFunction.matmul(predictedCovariance) - .matmul(_measurementFunction.mT) + _measurementNoiseCovariance); + .matmul(_measurementFunction.mT) + _measurementNoiseCovariance)); // Kalman gain - var kalmanGain = InverseCholesky( + var kalmanGain = WrappedTensorDisposeScope(() => InverseCholesky( predictedCovariance.matmul(_measurementFunction.mT), - innovationCovariance); + innovationCovariance)); // Update step var updatedState = predictedState + kalmanGain.matmul(innovation); - var updatedCovariance = EnsureSymmetric(predictedCovariance - - kalmanGain.matmul(_measurementFunction).matmul(predictedCovariance)); + var updatedCovariance = WrappedTensorDisposeScope(() => predictedCovariance + - kalmanGain.matmul(_measurementFunction).matmul(predictedCovariance)); return new UpdatedResult(updatedState, updatedCovariance, innovation, innovationCovariance, kalmanGain); } + + private static UpdatedResult FilterUpdate( + Tensor predictedState, + Tensor predictedCovariance, + Tensor observation, + Tensor measurementFunction, + Tensor measurementNoiseCovariance) + { + // Innovation step + var innovation = observation - measurementFunction.matmul(predictedState); + var innovationCovariance = WrappedTensorDisposeScope(() => EnsureSymmetric( + measurementFunction.matmul(predictedCovariance) + .matmul(measurementFunction.mT) + measurementNoiseCovariance)); + + // Kalman gain + var kalmanGain = WrappedTensorDisposeScope(() => InverseCholesky( + predictedCovariance.matmul(measurementFunction.mT), + innovationCovariance)); + + // Update step + var updatedState = predictedState + kalmanGain.matmul(innovation); + var updatedCovariance = WrappedTensorDisposeScope(() => predictedCovariance + - kalmanGain.matmul(measurementFunction).matmul(predictedCovariance)); + + return new UpdatedResult( + updatedState: updatedState, + updatedCovariance: updatedCovariance, + innovation: innovation, + innovationCovariance: innovationCovariance, + kalmanGain: kalmanGain + ); + } public FilteredResult Filter(Tensor observation) { var obs = observation.atleast_2d(); var timeBins = obs.size(0); - - var logLikelihood = empty(timeBins, dtype: _scalarType, device: _device); + var predictedState = empty(new long[] { timeBins, _numStates }, dtype: _scalarType, device: _device); var predictedCovariance = empty(new long[] { timeBins, _numStates, _numStates }, dtype: _scalarType, device: _device); var updatedState = empty(new long[] { timeBins, _numStates }, dtype: _scalarType, device: _device); var updatedCovariance = empty(new long[] { timeBins, _numStates, _numStates }, dtype: _scalarType, device: _device); - var kalmanGain = empty(new long[] { timeBins, _numStates, _numObservations }, dtype: _scalarType, device: _device); + + if (_state.NumberOfElements == 0) + _state.set_(_initialState); + if (_covariance.NumberOfElements == 0) + _covariance.set_(_initialCovariance); for (long time = 0; time < timeBins; time++) { - using var d = NewDisposeScope(); - // Predict var prediction = FilterPredict(_state, _covariance); // Update var update = FilterUpdate(prediction.PredictedState, prediction.PredictedCovariance, obs[time]); - // Log Likelihood - var invInnovationCov = InverseCholesky(_identityObservations, update.InnovationCovariance); - var logLikelihoodData = -1.0 * (slogdet(update.InnovationCovariance).logabsdet - + update.Innovation.T.matmul(invInnovationCov).matmul(update.Innovation)); - - // Detach and assign - logLikelihood[time] = logLikelihoodData.DetachFromDisposeScope(); - predictedState[time] = prediction.PredictedState.DetachFromDisposeScope(); - predictedCovariance[time] = prediction.PredictedCovariance.DetachFromDisposeScope(); - updatedState[time] = update.UpdatedState.DetachFromDisposeScope(); - updatedCovariance[time] = update.UpdatedCovariance.DetachFromDisposeScope(); - kalmanGain[time] = update.KalmanGain.DetachFromDisposeScope(); + predictedState[time] = prediction.PredictedState; + predictedCovariance[time] = prediction.PredictedCovariance; + updatedState[time] = update.UpdatedState; + updatedCovariance[time] = update.UpdatedCovariance; _state.set_(update.UpdatedState); _covariance.set_(update.UpdatedCovariance); } - return new FilteredResult(predictedState, predictedCovariance, updatedState, updatedCovariance, logLikelihood, kalmanGain); + return new FilteredResult( + predictedState: predictedState, + predictedCovariance: predictedCovariance, + updatedState: updatedState, + updatedCovariance: updatedCovariance); + } + + private readonly struct FilteredResultWithAuxiliaryVariables( + Tensor predictedState, + Tensor predictedCovariance, + Tensor updatedState, + Tensor updatedCovariance, + Tensor innovation, + Tensor innovationCovariance, + Tensor logLikelihood, + Tensor kalmanGain) + { + public readonly Tensor PredictedState = predictedState; + public readonly Tensor PredictedCovariance = predictedCovariance; + public readonly Tensor UpdatedState = updatedState; + public readonly Tensor UpdatedCovariance = updatedCovariance; + public readonly Tensor Innovation = innovation; + public readonly Tensor InnovationCovariance = innovationCovariance; + public readonly Tensor LogLikelihood = logLikelihood; + public readonly Tensor KalmanGain = kalmanGain; + } + + private static FilteredResultWithAuxiliaryVariables Filter( + Tensor observation, + long timeBins, + int numStates, + int numObservations, + Tensor transitionMatrix, + Tensor measurementFunction, + Tensor processNoiseCovariance, + Tensor measurementNoiseCovariance, + Tensor initialState, + Tensor initialCovariance, + ScalarType scalarType, + Device device) + { + var logLikelihood = empty(timeBins, dtype: scalarType, device: device); + var predictedState = empty(new long[] { timeBins, numStates }, dtype: scalarType, device: device); + var predictedCovariance = empty(new long[] { timeBins, numStates, numStates }, dtype: scalarType, device: device); + var updatedState = empty(new long[] { timeBins, numStates }, dtype: scalarType, device: device); + var updatedCovariance = empty(new long[] { timeBins, numStates, numStates }, dtype: scalarType, device: device); + var innovation = empty(new long[] { timeBins, numObservations }, dtype: scalarType, device: device); + var innovationCovariance = empty(new long[] { timeBins, numObservations, numObservations }, dtype: scalarType, device: device); + var kalmanGain = empty(new long[] { timeBins, numStates, numObservations }, dtype: scalarType, device: device); + + var state = initialState; + var covariance = initialCovariance; + + for (long time = 0; time < timeBins; time++) + { + // Predict + var prediction = FilterPredict( + state: state, + covariance: covariance, + transitionMatrix: transitionMatrix, + processNoiseCovariance: processNoiseCovariance); + + // Update + var update = FilterUpdate( + predictedState: prediction.PredictedState, + predictedCovariance: prediction.PredictedCovariance, + observation: observation[time], + measurementFunction: measurementFunction, + measurementNoiseCovariance: measurementNoiseCovariance); + + // Log Likelihood + var logLikelihoodData = -(slogdet(update.InnovationCovariance).logabsdet + + InverseCholesky(update.Innovation.T, update.InnovationCovariance) + .matmul(update.Innovation)).squeeze(); + + // Detach and assign + logLikelihood[time] = logLikelihoodData; + predictedState[time] = prediction.PredictedState; + predictedCovariance[time] = prediction.PredictedCovariance; + updatedState[time] = update.UpdatedState; + updatedCovariance[time] = update.UpdatedCovariance; + innovation[time] = update.Innovation; + innovationCovariance[time] = update.InnovationCovariance; + kalmanGain[time] = update.KalmanGain; + + state = update.UpdatedState; + covariance = update.UpdatedCovariance; + } + + return new FilteredResultWithAuxiliaryVariables( + predictedState: predictedState, + predictedCovariance: predictedCovariance, + updatedState: updatedState, + updatedCovariance: updatedCovariance, + innovation: innovation, + innovationCovariance: innovationCovariance, + logLikelihood: logLikelihood, + kalmanGain: kalmanGain + ); } public SmoothedResult Smooth(FilteredResult filteredResult) @@ -257,94 +393,189 @@ public SmoothedResult Smooth(FilteredResult filteredResult) var predictedCovariance = filteredResult.PredictedCovariance; var updatedState = filteredResult.UpdatedState; var updatedCovariance = filteredResult.UpdatedCovariance; - var kalmanGain = filteredResult.KalmanGain; var timeBins = predictedState.size(0); var smoothedState = empty_like(updatedState); var smoothedCovariance = empty_like(updatedCovariance); - var smoothedLagOneCovariance = empty(new long[] { timeBins, _numStates, _numStates }, dtype: _scalarType, device: _device); // Fix the last time point smoothedState[-1] = updatedState[-1]; smoothedCovariance[-1] = updatedCovariance[-1]; - smoothedLagOneCovariance[-1] = (_identityStates - kalmanGain[-1] - .matmul(_measurementFunction)) - .matmul(_transitionMatrix) - .matmul(updatedCovariance[-2]); var smoothingGain = empty(new long[] { _numStates, _numStates }, dtype: _scalarType, device: _device); // Backward pass for (long time = timeBins - 2; time >= 0; time--) { - using var d = NewDisposeScope(); // Smoothing gain - smoothingGain = updatedCovariance[time].matmul( + smoothingGain = WrappedTensorDisposeScope(() => updatedCovariance[time].matmul( InverseCholesky(_transitionMatrix.mT, predictedCovariance[time + 1]) - ).DetachFromDisposeScope(); + )); + + // Smoothed state + smoothedState[time] = WrappedTensorDisposeScope(() => updatedState[time] + + smoothingGain.matmul( + (smoothedState[time + 1] - predictedState[time + 1]).unsqueeze(-1) + ).squeeze(-1)); + + // Smoothed covariance + smoothedCovariance[time] = WrappedTensorDisposeScope(() => updatedCovariance[time] + smoothingGain + .matmul(smoothedCovariance[time + 1] - predictedCovariance[time + 1]) + .matmul(smoothingGain.mT) + ); + } + + // Smoothed initial state + var smoothedInitialState = WrappedTensorDisposeScope(() => _initialState + smoothingGain.matmul( + (smoothedState[0] - predictedState[0]).unsqueeze(-1) + ).squeeze(-1)); + + // Smoothed initial covariance + var smoothedInitialCovariance = WrappedTensorDisposeScope(() => _initialCovariance[0] + smoothingGain + .matmul(smoothedCovariance[0] - predictedCovariance[0]) + .matmul(smoothingGain.mT)); + + return new SmoothedResult( + smoothedState, + smoothedCovariance, + smoothedInitialState, + smoothedInitialCovariance + ); + } + + private readonly struct SmoothedResultWithAuxiliaryVariables( + Tensor smoothedState, + Tensor smoothedCovariance, + Tensor smoothedInitialState, + Tensor smoothedInitialCovariance, + Tensor autoCorrelationStatesCurrent, + Tensor crossCorrelationStates, + Tensor autoCorrelationStatesNext) + { + public readonly Tensor SmoothedState = smoothedState; + public readonly Tensor SmoothedCovariance = smoothedCovariance; + public readonly Tensor SmoothedInitialState = smoothedInitialState; + public readonly Tensor SmoothedInitialCovariance = smoothedInitialCovariance; + public readonly Tensor AutoCorrelationStatesCurrent = autoCorrelationStatesCurrent; + public readonly Tensor CrossCorrelationStates = crossCorrelationStates; + public readonly Tensor AutoCorrelationStatesNext = autoCorrelationStatesNext; + } + + private static SmoothedResultWithAuxiliaryVariables Smooth( + FilteredResultWithAuxiliaryVariables filteredResult, + long timeBins, + int numStates, + Tensor transitionMatrix, + Tensor measurementFunction, + Tensor initialState, + Tensor initialCovariance, + Tensor identityStates, + ScalarType scalarType, + Device device + ) + { + if (timeBins < 2) + throw new ArgumentException("Smoothing requires at least two time bins."); + + var predictedState = filteredResult.PredictedState; + var predictedCovariance = filteredResult.PredictedCovariance; + var updatedState = filteredResult.UpdatedState; + var updatedCovariance = filteredResult.UpdatedCovariance; + var kalmanGain = filteredResult.KalmanGain; + + var smoothedState = empty_like(updatedState); + var smoothedCovariance = empty_like(updatedCovariance); + + var autoCorrelationStatesCurrent = zeros_like(smoothedCovariance, dtype: scalarType, device: device); + var crossCorrelationStates = zeros_like(smoothedCovariance, dtype: scalarType, device: device); + var autoCorrelationStatesNext = zeros_like(smoothedCovariance, dtype: scalarType, device: device); + + // Fix the last time point + smoothedState[-1] = updatedState[-1]; + smoothedCovariance[-1] = updatedCovariance[-1]; + var smoothedLagOneCovariance = WrappedTensorDisposeScope(() => + (identityStates - kalmanGain[-1] + .matmul(measurementFunction)) + .matmul(transitionMatrix) + .matmul(updatedCovariance[-2])); + + autoCorrelationStatesNext[-1] = outer(updatedState[-1], updatedState[-1]) + updatedCovariance[-1]; + + var smoothingGain = empty([numStates, numStates], dtype: scalarType, device: device); + var smoothingGainNext = null as Tensor; + + // Backward pass + for (long time = timeBins - 2; time >= 0; time--) + { + // Smoothing gain + smoothingGain = smoothingGainNext ?? WrappedTensorDisposeScope(() => updatedCovariance[time].matmul( + InverseCholesky(transitionMatrix.mT, predictedCovariance[time + 1]) + )); // Smoothed state - smoothedState[time] = updatedState[time] + smoothedState[time] = WrappedTensorDisposeScope(() => updatedState[time] + smoothingGain.matmul( (smoothedState[time + 1] - predictedState[time + 1]).unsqueeze(-1) - ).squeeze(-1) - .DetachFromDisposeScope(); + ).squeeze(-1)); // Smoothed covariance - smoothedCovariance[time] = EnsureSymmetric( - updatedCovariance[time] + smoothingGain + smoothedCovariance[time] = WrappedTensorDisposeScope(() => updatedCovariance[time] + smoothingGain .matmul(smoothedCovariance[time + 1] - predictedCovariance[time + 1]) .matmul(smoothingGain.mT) - ).DetachFromDisposeScope(); + ); + + var expectationUpdate = outer(smoothedState[time], smoothedState[time]) + smoothedCovariance[time]; + autoCorrelationStatesNext[time] = expectationUpdate; + autoCorrelationStatesCurrent[time + 1] = expectationUpdate; + crossCorrelationStates[time + 1] = outer(smoothedState[time + 1], smoothedState[time]) + smoothedLagOneCovariance; // Compute next smoothing gain for lag one covariance if (time > 0) { - var smoothingGainNext = updatedCovariance[time - 1] - .matmul(InverseCholesky(_transitionMatrix.mT, predictedCovariance[time])); + smoothingGainNext = WrappedTensorDisposeScope(() => updatedCovariance[time - 1] + .matmul(InverseCholesky(transitionMatrix.mT, predictedCovariance[time]))); // Smoothed lag one covariance - - smoothedLagOneCovariance[time] = smoothedCovariance[time] + smoothedLagOneCovariance = WrappedTensorDisposeScope(() => updatedCovariance[time] .matmul(smoothingGainNext.mT) - + smoothingGain.matmul(smoothedLagOneCovariance[time + 1] - - _transitionMatrix.matmul(updatedCovariance[time])) - .matmul(smoothingGainNext.mT) - .DetachFromDisposeScope(); + + smoothingGain.matmul(smoothedLagOneCovariance + - transitionMatrix.matmul(updatedCovariance[time])) + .matmul(smoothingGainNext.mT)); } } + var smoothingGain0 = WrappedTensorDisposeScope(() => initialCovariance.matmul( + InverseCholesky(transitionMatrix.mT, predictedCovariance[0]) + )); + // Smoothed initial state - var smoothedInitialState = _initialState + smoothingGain.matmul( + var smoothedInitialState = WrappedTensorDisposeScope(() => initialState + smoothingGain0.matmul( (smoothedState[0] - predictedState[0]).unsqueeze(-1) - ).squeeze(-1); + ).squeeze(-1)); // Smoothed initial covariance - var smoothedInitialCovariance = EnsureSymmetric( - _initialCovariance[0] + smoothingGain + var smoothedInitialCovariance = WrappedTensorDisposeScope(() => initialCovariance + smoothingGain0 .matmul(smoothedCovariance[0] - predictedCovariance[0]) - .matmul(smoothingGain.mT) - ); - - // Smoothing gain at time 0 - var smoothingGain0 = _initialCovariance.matmul( - InverseCholesky(_transitionMatrix.mT, predictedCovariance[0]) - ); + .matmul(smoothingGain0.mT)); // Smoothed lag one covariance at time 0 - smoothedLagOneCovariance[0] = smoothedCovariance[0] + smoothedLagOneCovariance = WrappedTensorDisposeScope(() => updatedCovariance[0] .matmul(smoothingGain0.mT) - + smoothingGain.matmul(smoothedLagOneCovariance[1] - - _transitionMatrix.matmul(updatedCovariance[0])) - .matmul(smoothingGain0.mT) - .DetachFromDisposeScope(); - - return new SmoothedResult( - smoothedState, - smoothedCovariance, - smoothedLagOneCovariance, - smoothedInitialState, - smoothedInitialCovariance + + smoothingGain.matmul(smoothedLagOneCovariance + - transitionMatrix.matmul(updatedCovariance[0])) + .matmul(smoothingGain0.mT)); + + crossCorrelationStates[0] = outer(smoothedState[0], smoothedInitialState) + smoothedLagOneCovariance; + autoCorrelationStatesCurrent[0] = outer(smoothedInitialState, smoothedInitialState) + smoothedInitialCovariance; + + return new SmoothedResultWithAuxiliaryVariables( + smoothedState: smoothedState, + smoothedCovariance: smoothedCovariance, + smoothedInitialState: smoothedInitialState, + smoothedInitialCovariance: smoothedInitialCovariance, + autoCorrelationStatesCurrent: autoCorrelationStatesCurrent, + crossCorrelationStates: crossCorrelationStates, + autoCorrelationStatesNext: autoCorrelationStatesNext ); } @@ -357,18 +588,28 @@ public ExpectationMaximizationResult ExpectationMaximization( var timeBins = observation.size(0); var logLikelihood = empty(maxIterations, dtype: ScalarType.Float32, device: _device); var previousLogLikelihood = double.NegativeInfinity; - var logLikelihoodConst = -0.5 * timeBins * _numObservations * Math.Log(2 * Math.PI); + var logLikelihoodConst = -0.5 * timeBins * _numObservations * log(2.0 * Math.PI); var updatedParameters = Parameters; for (int iteration = 0; iteration < maxIterations; iteration++) { - using var d = NewDisposeScope(); - // Filter observations - var filterResult = Filter(observation); + var filteredResult = Filter( + observation: observation, + timeBins: timeBins, + numStates: _numStates, + numObservations: _numObservations, + transitionMatrix: _transitionMatrix, + measurementFunction: _measurementFunction, + processNoiseCovariance: _processNoiseCovariance, + measurementNoiseCovariance: _measurementNoiseCovariance, + initialState: _initialState, + initialCovariance: _initialCovariance, + scalarType: _scalarType, + device: _device); // Compute log likelihood - var filteredLogLikelihood = logLikelihoodConst + 0.5 * filterResult.LogLikelihood.sum(); + var filteredLogLikelihood = logLikelihoodConst + 0.5 * filteredResult.LogLikelihood.sum(); var filteredLogLikelihoodSum = filteredLogLikelihood.to_type(ScalarType.Float64).item(); logLikelihood[iteration] = filteredLogLikelihoodSum; @@ -386,46 +627,52 @@ public ExpectationMaximizationResult ExpectationMaximization( previousLogLikelihood = filteredLogLikelihoodSum; // Smooth the filtered results - var smoothedResult = Smooth(filterResult); + var smoothedResult = Smooth( + filteredResult: filteredResult, + timeBins: timeBins, + numStates: _numStates, + transitionMatrix: _transitionMatrix, + measurementFunction: _measurementFunction, + initialState: _initialState, + initialCovariance: _initialCovariance, + identityStates: _identityStates, + scalarType: _scalarType, + device: _device); // Sufficient statistics - var Ezzt = smoothedResult.SmoothedCovariance + einsum("tn,tm->tnm", smoothedResult.SmoothedState, smoothedResult.SmoothedState); - var Ezztm1 = smoothedResult.SmoothedLagOneCovariance[torch.TensorIndex.Slice(1)] - + einsum("tn,tm->tnm", - smoothedResult.SmoothedState[torch.TensorIndex.Slice(1)], - smoothedResult.SmoothedState[torch.TensorIndex.Slice(0, -1)]); - - var S00 = Ezzt[torch.TensorIndex.Slice(0, -1)].sum(new long[] { 0 }); - var S10 = Ezztm1.sum(new long[] { 0 }); - var S11 = Ezzt[torch.TensorIndex.Slice(1)].sum(new long[] { 0 }); + var autoCorrelationStatesCurrent = smoothedResult.AutoCorrelationStatesCurrent.sum([0]); + var autoCorrelationStatesNext = smoothedResult.AutoCorrelationStatesNext.sum([0]); + var crossCorrelationStates = smoothedResult.CrossCorrelationStates.sum([0]); - var Syz = einsum("tp,tn->pn", observation, smoothedResult.SmoothedState); - var Eyy = einsum("tp,tq->pq", observation, observation); + var crossCorrelationObservations = einsum("tp,tn->pn", observation, smoothedResult.SmoothedState); + var autoCorrelationObservations = einsum("tp,tq->pq", observation, observation); // Update parameters - var updatedTransitionMatrix = InverseCholesky(S10, S00).DetachFromDisposeScope(); - var updatedMeasurementFunction = InverseCholesky(Syz, S11).DetachFromDisposeScope(); - var updatedProcessNoiseCovariance = EnsureSymmetric((S11 - InverseCholesky(S10, S00).matmul(S10.T)) / timeBins).DetachFromDisposeScope(); + var updatedTransitionMatrix = InverseCholesky(crossCorrelationStates, autoCorrelationStatesCurrent); + var updatedMeasurementFunction = InverseCholesky(crossCorrelationObservations, autoCorrelationStatesNext); + var updatedProcessNoiseCovariance = WrappedTensorDisposeScope(() => + EnsureSymmetric((autoCorrelationStatesNext - InverseCholesky(crossCorrelationStates, autoCorrelationStatesCurrent).matmul(crossCorrelationStates.T)) / timeBins)); - var CSyzT = updatedMeasurementFunction.matmul(Syz.mT); - var updatedMeasurementNoiseCovariance = EnsureSymmetric( - (Eyy - CSyzT - CSyzT.mT + updatedMeasurementFunction.matmul(S11).matmul(updatedMeasurementFunction.mT)) / timeBins - ).DetachFromDisposeScope(); + var CSyzT = updatedMeasurementFunction.matmul(crossCorrelationObservations.mT); + var updatedMeasurementNoiseCovariance = WrappedTensorDisposeScope(() => + EnsureSymmetric((autoCorrelationObservations - CSyzT - CSyzT.mT + + updatedMeasurementFunction.matmul(autoCorrelationStatesNext) + .matmul(updatedMeasurementFunction.mT)) / timeBins)); updatedParameters = new KalmanFilterParameters( - updatedTransitionMatrix, - updatedMeasurementFunction, - updatedProcessNoiseCovariance, - updatedMeasurementNoiseCovariance, - smoothedResult.SmoothedInitialState.DetachFromDisposeScope(), - smoothedResult.SmoothedInitialCovariance.DetachFromDisposeScope() + transitionMatrix: updatedTransitionMatrix, + measurementFunction: updatedMeasurementFunction, + processNoiseCovariance: updatedProcessNoiseCovariance, + measurementNoiseCovariance: updatedMeasurementNoiseCovariance, + initialState: smoothedResult.SmoothedInitialState, + initialCovariance: smoothedResult.SmoothedInitialCovariance ); if (updateParameters) UpdateParameters(updatedParameters); } - return new ExpectationMaximizationResult(logLikelihood.DetachFromDisposeScope(), updatedParameters); + return new ExpectationMaximizationResult(logLikelihood, updatedParameters); } public OrthogonalizedResult OrthogonalizeStateAndCovariance(Tensor state, Tensor covariance) @@ -451,13 +698,14 @@ public void UpdateParameters(KalmanFilterParameters updatedParameters) _initialCovariance.set_(updatedParameters.InitialCovariance); } - private static Tensor EnsureSymmetric(Tensor M) => 0.5f * (M + M.transpose(0, 1)); + private static Tensor EnsureSymmetric(Tensor M) => 0.5 * (M + M.mT); + + private static Tensor Ensure2D(Tensor M) => M.atleast_2d(); private static Tensor InverseCholesky(Tensor B, Tensor A) { - using var d = NewDisposeScope(); - var L = linalg.cholesky(A); - var solT = cholesky_solve(B.transpose(0, 1), L); - return solT.transpose(0, 1).MoveToOuterDisposeScope(); + var L = linalg.cholesky(Ensure2D(A)); + var solT = cholesky_solve(Ensure2D(B).mT, L); + return solT.mT; } } From 91549b7d8d9b8bbb0e4945061afecf1bf13bd878 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Thu, 18 Sep 2025 21:44:33 +0100 Subject: [PATCH 15/92] Added reserve method to model manager to support model creation from parameters --- .../KalmanFilterModelManager.cs | 27 +++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/src/Bonsai.ML.Torch.LDS/KalmanFilterModelManager.cs b/src/Bonsai.ML.Torch.LDS/KalmanFilterModelManager.cs index 76b844d0..6e362845 100644 --- a/src/Bonsai.ML.Torch.LDS/KalmanFilterModelManager.cs +++ b/src/Bonsai.ML.Torch.LDS/KalmanFilterModelManager.cs @@ -85,6 +85,33 @@ internal static KalmanFilterDisposable Reserve( })); } + internal static KalmanFilterDisposable Reserve( + string name, + KalmanFilterParameters parameters, + Device? device = null, + ScalarType? scalarType = null + ) + { + if (_models.ContainsKey(name)) + { + throw new InvalidOperationException($"A Kalman filter with name {name} already exists."); + } + + var kalmanFilter = new KalmanFilter( + parameters: parameters, + device: device, + scalarType: scalarType ?? ScalarType.Float32 + ); + + _models.Add(name, kalmanFilter); + + return new KalmanFilterDisposable(kalmanFilter, Disposable.Create(() => + { + _models.Remove(name); + kalmanFilter.Dispose(); + })); + } + private readonly struct ManagedLock( ReaderWriterLockSlim lockObject, Mode mode) : IDisposable From 22efde7c496568f20c9ad56a03aaa804508d3833 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Thu, 18 Sep 2025 21:44:51 +0100 Subject: [PATCH 16/92] Added project references to tests --- .../Bonsai.ML.Torch.LDS.Tests/Bonsai.ML.Torch.LDS.Tests.csproj | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/Bonsai.ML.Torch.LDS.Tests/Bonsai.ML.Torch.LDS.Tests.csproj b/tests/Bonsai.ML.Torch.LDS.Tests/Bonsai.ML.Torch.LDS.Tests.csproj index af262cc7..14ddd981 100644 --- a/tests/Bonsai.ML.Torch.LDS.Tests/Bonsai.ML.Torch.LDS.Tests.csproj +++ b/tests/Bonsai.ML.Torch.LDS.Tests/Bonsai.ML.Torch.LDS.Tests.csproj @@ -27,6 +27,8 @@ + + From 30799b225c7e978f413072a54e9abba19a0f38b0 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Thu, 18 Sep 2025 21:45:26 +0100 Subject: [PATCH 17/92] Removed unused variables that were previously there for plotting --- .../estimate_neural_latents.py | 25 +++++++------------ 1 file changed, 9 insertions(+), 16 deletions(-) diff --git a/tests/Bonsai.ML.Torch.LDS.Tests/estimate_neural_latents.py b/tests/Bonsai.ML.Torch.LDS.Tests/estimate_neural_latents.py index fad7f533..107dcb8f 100644 --- a/tests/Bonsai.ML.Torch.LDS.Tests/estimate_neural_latents.py +++ b/tests/Bonsai.ML.Torch.LDS.Tests/estimate_neural_latents.py @@ -14,27 +14,20 @@ import os # Parse arguments -parser = argparse.ArgumentParser() -parser.add_argument("base_dir", type=str, default=None) -args = parser.parse_args() +try: + parser = argparse.ArgumentParser() + parser.add_argument("base_dir", type=str, default=None) + args = parser.parse_args() -base_dir = args.base_dir + base_dir = args.base_dir +except: + base_dir = os.path.realpath(os.path.dirname(__file__)) # data dandiset_ID = "000140" dandi_filepath = "sub-Jenkins/sub-Jenkins_ses-small_desc-train_behavior+ecephys.nwb" bin_size = 0.02 -# plot -events_names = ["start_time", "target_on_time", "go_cue_time", - "move_onset_time", "stop_time"] -events_linetypes = ["dot", "dash", "dashdot", "longdash", "solid"] -events_colors_spikes = ["white", "white", "white", "white", "white"] -events_colors_latents = ["black", "black", "black", "black", "black"] -cb_alpha = 0.3 -from_time = 100.0 -to_time = 130.0 - # model n_latents = 10 @@ -47,8 +40,8 @@ sigma_V0 = 0.1 # estimation parameters -max_iter = 5 -tol = 1e-1 +max_iter = 1 +tol = 0.1 vars_to_estimate = {"B": True, "Q": True, "Z": True, "R": True, "m0": True, "V0": True, } From 5c4adc6fa6b8bfb7d91dc043423904f25c4b9c17 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Thu, 18 Sep 2025 21:46:30 +0100 Subject: [PATCH 18/92] Modified requirements.txt to use ssm from github instead of local --- tests/Bonsai.ML.Torch.LDS.Tests/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/Bonsai.ML.Torch.LDS.Tests/requirements.txt b/tests/Bonsai.ML.Torch.LDS.Tests/requirements.txt index 05d89b7e..ff3efa23 100644 --- a/tests/Bonsai.ML.Torch.LDS.Tests/requirements.txt +++ b/tests/Bonsai.ML.Torch.LDS.Tests/requirements.txt @@ -127,7 +127,7 @@ secretstorage==3.3.3 semantic-version==2.10.0 setuptools==80.9.0 six==1.17.0 --e file:///home/nicholas/lds_python +ssm @ git+https://github.com/ncguilbeault/lds_python@75e3e5e92ce6344009b62a5034db49b238db63ef stack-data==0.6.3 sympy==1.14.0 tenacity==9.1.2 From e7a995b1a861e1cd0f86a7377402bea41bb0300a Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 19 Sep 2025 09:06:20 +0100 Subject: [PATCH 19/92] Corrected bug with initialization of state and covariance --- src/Bonsai.ML.Torch.LDS/KalmanFilter.cs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/Bonsai.ML.Torch.LDS/KalmanFilter.cs b/src/Bonsai.ML.Torch.LDS/KalmanFilter.cs index 82a8b0c8..e7a9dd8a 100644 --- a/src/Bonsai.ML.Torch.LDS/KalmanFilter.cs +++ b/src/Bonsai.ML.Torch.LDS/KalmanFilter.cs @@ -46,8 +46,8 @@ public KalmanFilter( _identityStates = eye(_numStates, dtype: _scalarType, device: _device); - _state = empty(_numStates, dtype: _scalarType, device: _device).requires_grad_(false); - _covariance = empty([_numStates, _numStates], dtype: _scalarType, device: _device).requires_grad_(false); + _state = empty(0, dtype: _scalarType, device: _device).requires_grad_(false); + _covariance = empty(0, dtype: _scalarType, device: _device).requires_grad_(false); } public KalmanFilter( @@ -88,8 +88,8 @@ public KalmanFilter( _processNoiseCovariance = CreateCovarianceMatrix(processNoiseVariance, _scalarType, _device, numStates, "Process noise variance"); _measurementNoiseCovariance = CreateCovarianceMatrix(measurementNoiseVariance, _scalarType, _device, numObservations, "Measurement noise variance"); - _state = empty(_numStates, dtype: _scalarType, device: _device).requires_grad_(false); - _covariance = empty([_numStates, _numStates], dtype: _scalarType, device: _device).requires_grad_(false); + _state = empty(0, dtype: _scalarType, device: _device).requires_grad_(false); + _covariance = empty(0, dtype: _scalarType, device: _device).requires_grad_(false); RegisterComponents(); } From 0a680269bf3a9b0ebfd991ffd7e17578ad841e43 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 19 Sep 2025 09:48:35 +0100 Subject: [PATCH 20/92] Refactored EM function for improved readability --- src/Bonsai.ML.Torch.LDS/KalmanFilter.cs | 68 ++++++++++++++----------- 1 file changed, 38 insertions(+), 30 deletions(-) diff --git a/src/Bonsai.ML.Torch.LDS/KalmanFilter.cs b/src/Bonsai.ML.Torch.LDS/KalmanFilter.cs index e7a9dd8a..b8f54f29 100644 --- a/src/Bonsai.ML.Torch.LDS/KalmanFilter.cs +++ b/src/Bonsai.ML.Torch.LDS/KalmanFilter.cs @@ -589,7 +589,13 @@ public ExpectationMaximizationResult ExpectationMaximization( var logLikelihood = empty(maxIterations, dtype: ScalarType.Float32, device: _device); var previousLogLikelihood = double.NegativeInfinity; var logLikelihoodConst = -0.5 * timeBins * _numObservations * log(2.0 * Math.PI); - var updatedParameters = Parameters; + + var transitionMatrix = _transitionMatrix; + var measurementFunction = _measurementFunction; + var processNoiseCovariance = _processNoiseCovariance; + var measurementNoiseCovariance = _measurementNoiseCovariance; + var initialState = _initialState; + var initialCovariance = _initialCovariance; for (int iteration = 0; iteration < maxIterations; iteration++) { @@ -599,12 +605,12 @@ public ExpectationMaximizationResult ExpectationMaximization( timeBins: timeBins, numStates: _numStates, numObservations: _numObservations, - transitionMatrix: _transitionMatrix, - measurementFunction: _measurementFunction, - processNoiseCovariance: _processNoiseCovariance, - measurementNoiseCovariance: _measurementNoiseCovariance, - initialState: _initialState, - initialCovariance: _initialCovariance, + transitionMatrix: transitionMatrix, + measurementFunction: measurementFunction, + processNoiseCovariance: processNoiseCovariance, + measurementNoiseCovariance: measurementNoiseCovariance, + initialState: initialState, + initialCovariance: initialCovariance, scalarType: _scalarType, device: _device); @@ -631,10 +637,10 @@ public ExpectationMaximizationResult ExpectationMaximization( filteredResult: filteredResult, timeBins: timeBins, numStates: _numStates, - transitionMatrix: _transitionMatrix, - measurementFunction: _measurementFunction, - initialState: _initialState, - initialCovariance: _initialCovariance, + transitionMatrix: transitionMatrix, + measurementFunction: measurementFunction, + initialState: initialState, + initialCovariance: initialCovariance, identityStates: _identityStates, scalarType: _scalarType, device: _device); @@ -648,30 +654,32 @@ public ExpectationMaximizationResult ExpectationMaximization( var autoCorrelationObservations = einsum("tp,tq->pq", observation, observation); // Update parameters - var updatedTransitionMatrix = InverseCholesky(crossCorrelationStates, autoCorrelationStatesCurrent); - var updatedMeasurementFunction = InverseCholesky(crossCorrelationObservations, autoCorrelationStatesNext); - var updatedProcessNoiseCovariance = WrappedTensorDisposeScope(() => + transitionMatrix = InverseCholesky(crossCorrelationStates, autoCorrelationStatesCurrent); + measurementFunction = InverseCholesky(crossCorrelationObservations, autoCorrelationStatesNext); + processNoiseCovariance = WrappedTensorDisposeScope(() => EnsureSymmetric((autoCorrelationStatesNext - InverseCholesky(crossCorrelationStates, autoCorrelationStatesCurrent).matmul(crossCorrelationStates.T)) / timeBins)); - var CSyzT = updatedMeasurementFunction.matmul(crossCorrelationObservations.mT); - var updatedMeasurementNoiseCovariance = WrappedTensorDisposeScope(() => - EnsureSymmetric((autoCorrelationObservations - CSyzT - CSyzT.mT - + updatedMeasurementFunction.matmul(autoCorrelationStatesNext) - .matmul(updatedMeasurementFunction.mT)) / timeBins)); - - updatedParameters = new KalmanFilterParameters( - transitionMatrix: updatedTransitionMatrix, - measurementFunction: updatedMeasurementFunction, - processNoiseCovariance: updatedProcessNoiseCovariance, - measurementNoiseCovariance: updatedMeasurementNoiseCovariance, - initialState: smoothedResult.SmoothedInitialState, - initialCovariance: smoothedResult.SmoothedInitialCovariance - ); + var explainedObservationCovariance = measurementFunction.matmul(crossCorrelationObservations.mT); + measurementNoiseCovariance = WrappedTensorDisposeScope(() => + EnsureSymmetric((autoCorrelationObservations - explainedObservationCovariance - explainedObservationCovariance.mT + measurementFunction.matmul(autoCorrelationStatesNext) + .matmul(measurementFunction.mT)) / timeBins)); - if (updateParameters) - UpdateParameters(updatedParameters); + initialState = smoothedResult.SmoothedInitialState; + initialCovariance = smoothedResult.SmoothedInitialCovariance; } + var updatedParameters = new KalmanFilterParameters( + transitionMatrix: transitionMatrix, + measurementFunction: measurementFunction, + processNoiseCovariance: processNoiseCovariance, + measurementNoiseCovariance: measurementNoiseCovariance, + initialState: initialState, + initialCovariance: initialCovariance + ); + + if (updateParameters) + UpdateParameters(updatedParameters); + return new ExpectationMaximizationResult(logLikelihood, updatedParameters); } From 0d161e4a47ac358a05628ccc40556b81beecdbde Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 19 Sep 2025 09:51:04 +0100 Subject: [PATCH 21/92] Updated test to correctly compare python and bonsai tensor results --- .../NeuralLatentsTest.bonsai | 662 +++--------------- .../NeuralLatentsTest.cs | 116 +-- 2 files changed, 166 insertions(+), 612 deletions(-) diff --git a/tests/Bonsai.ML.Torch.LDS.Tests/NeuralLatentsTest.bonsai b/tests/Bonsai.ML.Torch.LDS.Tests/NeuralLatentsTest.bonsai index 43d1b506..98b8cfa2 100644 --- a/tests/Bonsai.ML.Torch.LDS.Tests/NeuralLatentsTest.bonsai +++ b/tests/Bonsai.ML.Torch.LDS.Tests/NeuralLatentsTest.bonsai @@ -3,484 +3,123 @@ xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xmlns:p1="clr-namespace:Bonsai.ML.Torch;assembly=Bonsai.ML.Torch" xmlns:rx="clr-namespace:Bonsai.Reactive;assembly=Bonsai.Core" - xmlns:p2="clr-namespace:Bonsai.Dsp;assembly=Bonsai.Dsp" - xmlns:p3="clr-namespace:;assembly=Extensions" + xmlns:p2="clr-namespace:Bonsai.ML.Torch.LDS;assembly=Bonsai.ML.Torch.LDS" xmlns="https://bonsai-rx.org/2018/workflow"> - - - CUDA - -1 - - - - CUDA - LoadData - - ../data/transformed_binned_spikes.bin - 0 - 0 - 142 - 0 - F64 - RowMajor - - - - - - - - - CUDA - - - - - - - - - - - - Float32 - - - - - - -1 - 142 - - - - - ObservationT - - - - - ../data/stop_times.bin - 0 - 0 - 1 - 0 - F64 - RowMajor - - - - - + + transformed_binned_spikes.pt - - CUDA - - - - - - - + - - Float32 - - - - + - -1 1 + 0 - TrialEnd + ObservationT - - ../data/JoaquinModelParameters/covariance.bin - 0 - 0 - 10 - 0 - F64 - RowMajor - - - - - - - - - CUDA - - - - - - - - - - - - Float32 + + python_V0_0.pt - - - 10 - 10 - - + Covariance - - ../data/JoaquinModelParameters/state.bin - 0 - 0 - 10 - 0 - F64 - RowMajor + + python_m0_0.pt - - - - - - CUDA - - - - - - - - - - - - Float32 - - - - - - 10 - - + State - - ../data/JoaquinModelParameters/measurementFunction.bin - 0 - 0 - 10 - 0 - F64 - RowMajor - - - - - - - - - CUDA - - - - - - - - - - - - Float32 + + python_Z0.pt - - - 142 - 10 - - + MeasurementFunction - - ../data/JoaquinModelParameters/measurementNoiseCovariance.bin - 0 - 0 - 142 - 0 - F64 - RowMajor - - - - - - - - - CUDA - - - - - - - - - - - - Float32 + + python_R0.pt - - - 142 - 142 - - + MeasurementNoiseCovariance - - ../data/JoaquinModelParameters/transitionMatrix.bin - 0 - 0 - 10 - 0 - F64 - RowMajor + + python_B0.pt - - - - - - CUDA - - - - - - - - - - - - Float32 - - - - - - 10 - 10 - - + TransitionMatrix - - ../data/JoaquinModelParameters/processNoiseCovariance.bin - 0 - 0 - 10 - 0 - F64 - RowMajor + + python_Q0.pt - - - - - - CUDA - - - - - - - - - - - - Float32 - - - - - - 10 - 10 - - + ProcessNoiseCovariance - - - ../data/bin_centers.bin - 0 - 0 - 1 - 0 - F64 - RowMajor - - - - - - - - - CUDA - - - - - - - - - - - - Float32 - - - - - - -1 - 1 - - - - - Time - - - + - - + - - - + - - - + + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - @@ -520,20 +159,28 @@ - - - - CUDA - - - - - + + Float64 + [] + [] + [] + [] + [] + [] + - - 10 - 142 + + KalmanFilter + 2 + 2 + Float64 + [] + [] + [] + [] + [] + [] @@ -550,11 +197,9 @@ - + - - - + @@ -565,19 +210,12 @@ ObservationT - - KalmanFilterModel - - - - - - - - 5 - 0.1 - true + + KalmanFilter + 1 + 0.1 + true @@ -589,101 +227,32 @@ - - ExpectationMaximizationResult - - - LogLikelihood - - - - 0 - - - - Float32 - - - - 0 - - - - - - - - - - - - - - - - - - - - - UpdateParameters - - - - ExpectationMaximizationResult - - - Parameters - - - KalmanFilterModel - - - - - - - - - - - ParametersUpdated - - - - + + - - - Filter + Smoother ObservationT - - KalmanFilterModel - - - - - - - + + KalmanFilter + UpdatedFilteredResult - ParametersUpdated + ExpectationMaximizationResult @@ -691,116 +260,81 @@ UpdatedFilteredResult - - KalmanFilterModel - - - - - - - + + KalmanFilter + UpdatedSmoothedResult - - ParametersUpdated - - - - UpdatedSmoothedResult - - KalmanFilterModel - - - - - - - + + KalmanFilter + OrthogonalizedResult - ParametersUpdated + OrthogonalizedResult - - + + OrthogonalizedState - - PT0S - PT0.01S + + bonsai_means.pt OrthogonalizedResult - - - - - TimerSource - - - OrthogonalizedResult - - OrthogonalizedState + OrthogonalizedCovariance - - 5000:5500,0 + + bonsai_covs.pt - - 0 - - - - Float32 + + - + - - - - - + + + + - - - - - + + + + + - - - - - - - - - + + + 1 + + + - + + \ No newline at end of file diff --git a/tests/Bonsai.ML.Torch.LDS.Tests/NeuralLatentsTest.cs b/tests/Bonsai.ML.Torch.LDS.Tests/NeuralLatentsTest.cs index 426307b9..a6c64d73 100644 --- a/tests/Bonsai.ML.Torch.LDS.Tests/NeuralLatentsTest.cs +++ b/tests/Bonsai.ML.Torch.LDS.Tests/NeuralLatentsTest.cs @@ -7,6 +7,8 @@ using System.Runtime.InteropServices; using System.Threading.Tasks; using Bonsai.ML.Tests.Utilities; +using static TorchSharp.torch; +using TorchSharp; namespace Bonsai.ML.Torch.LDS.Tests; @@ -29,22 +31,9 @@ private static void RunPythonScript(string basePath) Console.WriteLine("Run python script finished."); } - private static async Task RunBonsaiWorkflow(string basePath) - { - var currentDirectory = Environment.CurrentDirectory; - Environment.CurrentDirectory = basePath; - try - { - var workflowPath = Path.Combine(basePath, "NeuralLatentsTest.bonsai"); - await WorkflowHelper.RunWorkflow( - workflowPath); - Console.WriteLine("Run bonsai workflow finished."); - } - finally { Environment.CurrentDirectory = currentDirectory; } - } - private static double[] ReadBinaryFile(string fileName) { + Console.WriteLine($"Reading binary file: {fileName}"); using var fileStream = new FileStream(fileName, FileMode.Open, FileAccess.Read); using var binaryReader = new BinaryReader(fileStream); var fileLength = fileStream.Length; @@ -54,45 +43,64 @@ private static double[] ReadBinaryFile(string fileName) { data[i] = binaryReader.ReadDouble(); } + Console.WriteLine($"Read {numDoubles} doubles from {fileName}"); return data; } - private static bool CompareBinaryData(string basePath, double tolerance = 1e-4) + private static void WriteToTensor(string fileName, long[] shape) { - var bonsaiMeansFileName = Path.Combine(basePath, "bonsai_means.bin"); - var bonsaiCovariancesFileName = Path.Combine(basePath, "bonsai_covs.bin"); - - var pythonMeansFileName = Path.Combine(basePath, "python_means.bin"); - var pythonCovariancesFileName = Path.Combine(basePath, "python_covs.bin"); - - var bonsaiMeans = ReadBinaryFile(bonsaiMeansFileName); - var bonsaiCovariances = ReadBinaryFile(bonsaiCovariancesFileName); - var pythonMeans = ReadBinaryFile(pythonMeansFileName); - var pythonCovariances = ReadBinaryFile(pythonCovariancesFileName); - - if (bonsaiMeans.Length != pythonMeans.Length || - bonsaiCovariances.Length != pythonCovariances.Length) - { - return false; - } + Console.WriteLine($"Reading filename: {fileName} and creating tensor with shape [{string.Join(", ", shape)}]"); + var data = ReadBinaryFile(fileName); + var tensor = from_array(data).reshape(shape); + var outputFileName = Path.ChangeExtension(fileName, ".pt"); + tensor.Save(outputFileName); + Console.WriteLine($"Saved tensor to {outputFileName}"); + } - for (int i = 0; i < bonsaiMeans.Length; i++) - { - if (Math.Abs(bonsaiMeans[i] - pythonMeans[i]) > tolerance) - { - return false; - } - } + private static void ConvertBinaryFiles(string basePath) + { + var transformedBinnedSpikesFileName = Path.Combine(basePath, "transformed_binned_spikes.bin"); + WriteToTensor(transformedBinnedSpikesFileName, [142, -1]); + + var transitionMatrixFileName = Path.Combine(basePath, "python_B0.bin"); + WriteToTensor(transitionMatrixFileName, [10, 10]); + + var measurementFunctionFileName = Path.Combine(basePath, "python_Z0.bin"); + WriteToTensor(measurementFunctionFileName, [142, 10]); + + var processNoiseFileName = Path.Combine(basePath, "python_Q0.bin"); + WriteToTensor(processNoiseFileName, [10, 10]); - for (int i = 0; i < bonsaiCovariances.Length; i++) + var observationNoiseFileName = Path.Combine(basePath, "python_R0.bin"); + WriteToTensor(observationNoiseFileName, [142, 142]); + + var initialStateFileName = Path.Combine(basePath, "python_m0_0.bin"); + WriteToTensor(initialStateFileName, [10]); + + var initialCovarianceFileName = Path.Combine(basePath, "python_V0_0.bin"); + WriteToTensor(initialCovarianceFileName, [10, 10]); + + var outputMeansFileName = Path.Combine(basePath, "python_means.bin"); + WriteToTensor(outputMeansFileName, [10, -1]); + + var outputCovariancesFileName = Path.Combine(basePath, "python_covs.bin"); + WriteToTensor(outputCovariancesFileName, [10, 10, -1]); + } + + private static async Task RunBonsaiWorkflow(string basePath) + { + Console.WriteLine($"Running Bonsai workflow..."); + var currentDirectory = Environment.CurrentDirectory; + Environment.CurrentDirectory = basePath; + + try { - if (Math.Abs(bonsaiCovariances[i] - pythonCovariances[i]) > tolerance) - { - return false; - } + var workflowPath = Path.Combine(basePath, "NeuralLatentsTest.bonsai"); + await WorkflowHelper.RunWorkflow( + workflowPath); + Console.WriteLine("Run bonsai workflow finished."); } - - return true; + finally { Environment.CurrentDirectory = currentDirectory; } } /// @@ -106,6 +114,7 @@ public async Task TestSetup() { Directory.CreateDirectory(basePath); RunPythonScript(basePath); + ConvertBinaryFiles(basePath); await RunBonsaiWorkflow(basePath); } @@ -113,9 +122,20 @@ public async Task TestSetup() /// Compares the results from the Python script and the Bonsai workflow. /// [TestMethod] - public void CompareResults() + public void CompareTensorData() { - var result = CompareBinaryData(basePath); - Assert.IsTrue(result); + var bonsaiMeansFileName = Path.Combine(basePath, "bonsai_means.pt"); + var bonsaiCovariancesFileName = Path.Combine(basePath, "bonsai_covs.pt"); + + var pythonMeansFileName = Path.Combine(basePath, "python_means.pt"); + var pythonCovariancesFileName = Path.Combine(basePath, "python_covs.pt"); + + var bonsaiMeans = Tensor.Load(bonsaiMeansFileName); + var bonsaiCovariances = Tensor.Load(bonsaiCovariancesFileName); + var pythonMeans = Tensor.Load(pythonMeansFileName).permute(1, 0); + var pythonCovariances = Tensor.Load(pythonCovariancesFileName).permute(2, 0, 1); + + Assert.IsTrue(allclose(bonsaiMeans, pythonMeans)); + Assert.IsTrue(allclose(bonsaiCovariances, pythonCovariances)); } } From ef3d11aa23990f7d23fe2e9aa8da8da36eba9bf7 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Mon, 29 Sep 2025 10:46:49 +0100 Subject: [PATCH 22/92] Removed unused import --- src/Bonsai.ML.Torch.LDS/CreateKalmanFilterParameters.cs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/Bonsai.ML.Torch.LDS/CreateKalmanFilterParameters.cs b/src/Bonsai.ML.Torch.LDS/CreateKalmanFilterParameters.cs index e8d6ca92..85f8461c 100644 --- a/src/Bonsai.ML.Torch.LDS/CreateKalmanFilterParameters.cs +++ b/src/Bonsai.ML.Torch.LDS/CreateKalmanFilterParameters.cs @@ -1,6 +1,5 @@ using System; using System.ComponentModel; -using System.Collections.Generic; using System.Linq; using System.Reactive.Linq; using System.Xml.Serialization; From 5394bd3090689e0e6610500c49745fc9959e580d Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Mon, 29 Sep 2025 11:43:58 +0100 Subject: [PATCH 23/92] Removed unused imports and added XML docs --- .../ExpectationMaximization.cs | 57 +++--- src/Bonsai.ML.Torch.LDS/Filter.cs | 19 +- src/Bonsai.ML.Torch.LDS/KalmanFilter.cs | 193 ++++++++++-------- .../KalmanFilterModelManager.cs | 3 +- .../KalmanFilterNameConverter.cs | 16 +- src/Bonsai.ML.Torch.LDS/Orthogonalize.cs | 42 +++- .../OrthogonalizedResult.cs | 31 ++- src/Bonsai.ML.Torch.LDS/Smooth.cs | 36 +++- src/Bonsai.ML.Torch.LDS/UpdateParameters.cs | 21 +- 9 files changed, 256 insertions(+), 162 deletions(-) diff --git a/src/Bonsai.ML.Torch.LDS/ExpectationMaximization.cs b/src/Bonsai.ML.Torch.LDS/ExpectationMaximization.cs index 72a518b1..7421050c 100644 --- a/src/Bonsai.ML.Torch.LDS/ExpectationMaximization.cs +++ b/src/Bonsai.ML.Torch.LDS/ExpectationMaximization.cs @@ -1,69 +1,72 @@ using System; using System.ComponentModel; -using System.Reactive; using System.Reactive.Linq; -using System.Reflection; -using System.Runtime.InteropServices; -using System.Threading; -using System.Threading.Tasks; -using System.Xml.Serialization; -using Bonsai; -using Bonsai.ML.Torch; -using Bonsai.ML.Torch.NeuralNets; -using Bonsai.Reactive; using TorchSharp; -using TorchSharp.Modules; +using static TorchSharp.torch; namespace Bonsai.ML.Torch.LDS; /// -/// Learn the parameters kalman filter using the batch EM update algorithm. +/// Learn the parameters of a kalman filter using the batch EM update algorithm. /// [Combinator] [ResetCombinator] -[Description("Learn the parameters kalman filter using the batch EM update algorithm.")] +[Description("Learn the parameters of a kalman filter using the batch EM update algorithm.")] [WorkflowElementCategory(ElementCategory.Transform)] public class ExpectationMaximization { + /// + /// The name of the Kalman filter model to be trained. + /// [TypeConverter(typeof(KalmanFilterNameConverter))] + [Description("The name of the Kalman filter model to be trained.")] public string ModelName { get; set; } = "KalmanFilter"; private int _maxIterations = 10; + /// + /// The maximum number of EM iterations to perform. + /// + [Description("The maximum number of EM iterations to perform.")] public int MaxIterations { get => _maxIterations; - set - { - if (value < 1) throw new ArgumentOutOfRangeException("MaxIterations must be at least 1."); - _maxIterations = value; - } + set => _maxIterations = value > 0 ? value : throw new ArgumentOutOfRangeException(nameof(MaxIterations), "Must be greater than zero."); } private double _tolerance = 1e-4; + /// + /// The convergence tolerance for the EM algorithm. + /// + [Description("The convergence tolerance for the EM algorithm.")] public double Tolerance { get => _tolerance; - set - { - if (value < 0) throw new ArgumentOutOfRangeException("Tolerance must be non-negative."); - _tolerance = value; - } + set => _tolerance = value >= 0 ? value : throw new ArgumentOutOfRangeException(nameof(Tolerance), "Must be greater than or equal to zero."); } private bool _verbose = true; + /// + /// If true, prints progress messages to the console. + /// + [Description("If true, prints progress messages to the console.")] public bool Verbose { get => _verbose; set => _verbose = value; } - public IObservable Process(IObservable source) + /// + /// Processes an observable sequence of input tensors, applying the Expectation-Maximization algorithm to learn the parameters of a Kalman filter model. + /// + /// + /// + public IObservable Process(IObservable source) { return source.Select(input => { var model = KalmanFilterModelManager.GetKalmanFilter(ModelName); var previousLogLikelihood = double.NegativeInfinity; - var logLikelihood = torch.zeros(new long[] { MaxIterations }, device: input.device); + var logLikelihood = zeros(new long[] { MaxIterations }, device: input.device); for (int i = 0; i < MaxIterations; i++) { @@ -75,7 +78,7 @@ public IObservable Process(IObservable Process(IObservable +/// Applies a Kalman filter to the input tensor sequence. +/// [Combinator] -[ResetCombinator] -[Description("")] +[Description("Applies a Kalman filter to the input tensor sequence.")] [WorkflowElementCategory(ElementCategory.Transform)] public class Filter { + /// + /// The name of the Kalman filter model to be used. + /// [TypeConverter(typeof(KalmanFilterNameConverter))] + [Description("The name of the Kalman filter model to be used.")] public string ModelName { get; set; } = "KalmanFilter"; - public IObservable Process(IObservable source) + /// + /// Processes an observable sequence of input tensors, applying the Kalman filter to each tensor. + /// + public IObservable Process(IObservable source) { return source.Select((input) => { diff --git a/src/Bonsai.ML.Torch.LDS/KalmanFilter.cs b/src/Bonsai.ML.Torch.LDS/KalmanFilter.cs index b8f54f29..8c5d7ba9 100644 --- a/src/Bonsai.ML.Torch.LDS/KalmanFilter.cs +++ b/src/Bonsai.ML.Torch.LDS/KalmanFilter.cs @@ -1,5 +1,4 @@ using System; -using TorchSharp; using static TorchSharp.torch; namespace Bonsai.ML.Torch.LDS; @@ -256,6 +255,8 @@ private static UpdatedResult FilterUpdate( public FilteredResult Filter(Tensor observation) { + using var g = no_grad(); + var obs = observation.atleast_2d(); var timeBins = obs.size(0); @@ -389,6 +390,8 @@ private static FilteredResultWithAuxiliaryVariables Filter( public SmoothedResult Smooth(FilteredResult filteredResult) { + using var g = no_grad(); + var predictedState = filteredResult.PredictedState; var predictedCovariance = filteredResult.PredictedCovariance; var updatedState = filteredResult.UpdatedState; @@ -448,17 +451,17 @@ private readonly struct SmoothedResultWithAuxiliaryVariables( Tensor smoothedCovariance, Tensor smoothedInitialState, Tensor smoothedInitialCovariance, - Tensor autoCorrelationStatesCurrent, - Tensor crossCorrelationStates, - Tensor autoCorrelationStatesNext) + Tensor S00, + Tensor S10, + Tensor S11) { public readonly Tensor SmoothedState = smoothedState; public readonly Tensor SmoothedCovariance = smoothedCovariance; public readonly Tensor SmoothedInitialState = smoothedInitialState; public readonly Tensor SmoothedInitialCovariance = smoothedInitialCovariance; - public readonly Tensor AutoCorrelationStatesCurrent = autoCorrelationStatesCurrent; - public readonly Tensor CrossCorrelationStates = crossCorrelationStates; - public readonly Tensor AutoCorrelationStatesNext = autoCorrelationStatesNext; + public readonly Tensor S00 = S00; + public readonly Tensor S10 = S10; + public readonly Tensor S11 = S11; } private static SmoothedResultWithAuxiliaryVariables Smooth( @@ -486,9 +489,9 @@ Device device var smoothedState = empty_like(updatedState); var smoothedCovariance = empty_like(updatedCovariance); - var autoCorrelationStatesCurrent = zeros_like(smoothedCovariance, dtype: scalarType, device: device); - var crossCorrelationStates = zeros_like(smoothedCovariance, dtype: scalarType, device: device); - var autoCorrelationStatesNext = zeros_like(smoothedCovariance, dtype: scalarType, device: device); + var S00 = zeros_like(smoothedCovariance, dtype: scalarType, device: device); + var S10 = zeros_like(smoothedCovariance, dtype: scalarType, device: device); + var S11 = zeros_like(smoothedCovariance, dtype: scalarType, device: device); // Fix the last time point smoothedState[-1] = updatedState[-1]; @@ -499,7 +502,7 @@ Device device .matmul(transitionMatrix) .matmul(updatedCovariance[-2])); - autoCorrelationStatesNext[-1] = outer(updatedState[-1], updatedState[-1]) + updatedCovariance[-1]; + S11[-1] = outer(updatedState[-1], updatedState[-1]) + updatedCovariance[-1]; var smoothingGain = empty([numStates, numStates], dtype: scalarType, device: device); var smoothingGainNext = null as Tensor; @@ -525,9 +528,9 @@ Device device ); var expectationUpdate = outer(smoothedState[time], smoothedState[time]) + smoothedCovariance[time]; - autoCorrelationStatesNext[time] = expectationUpdate; - autoCorrelationStatesCurrent[time + 1] = expectationUpdate; - crossCorrelationStates[time + 1] = outer(smoothedState[time + 1], smoothedState[time]) + smoothedLagOneCovariance; + S11[time] = expectationUpdate; + S00[time + 1] = expectationUpdate; + S10[time + 1] = outer(smoothedState[time + 1], smoothedState[time]) + smoothedLagOneCovariance; // Compute next smoothing gain for lag one covariance if (time > 0) @@ -565,17 +568,17 @@ Device device - transitionMatrix.matmul(updatedCovariance[0])) .matmul(smoothingGain0.mT)); - crossCorrelationStates[0] = outer(smoothedState[0], smoothedInitialState) + smoothedLagOneCovariance; - autoCorrelationStatesCurrent[0] = outer(smoothedInitialState, smoothedInitialState) + smoothedInitialCovariance; + S10[0] = outer(smoothedState[0], smoothedInitialState) + smoothedLagOneCovariance; + S00[0] = outer(smoothedInitialState, smoothedInitialState) + smoothedInitialCovariance; return new SmoothedResultWithAuxiliaryVariables( smoothedState: smoothedState, smoothedCovariance: smoothedCovariance, smoothedInitialState: smoothedInitialState, smoothedInitialCovariance: smoothedInitialCovariance, - autoCorrelationStatesCurrent: autoCorrelationStatesCurrent, - crossCorrelationStates: crossCorrelationStates, - autoCorrelationStatesNext: autoCorrelationStatesNext + S00: S00, + S10: S10, + S11: S11 ); } @@ -588,7 +591,7 @@ public ExpectationMaximizationResult ExpectationMaximization( var timeBins = observation.size(0); var logLikelihood = empty(maxIterations, dtype: ScalarType.Float32, device: _device); var previousLogLikelihood = double.NegativeInfinity; - var logLikelihoodConst = -0.5 * timeBins * _numObservations * log(2.0 * Math.PI); + var logLikelihoodConst = -0.5 * timeBins * _numObservations * Math.Log(2.0 * Math.PI); var transitionMatrix = _transitionMatrix; var measurementFunction = _measurementFunction; @@ -597,75 +600,85 @@ public ExpectationMaximizationResult ExpectationMaximization( var initialState = _initialState; var initialCovariance = _initialCovariance; - for (int iteration = 0; iteration < maxIterations; iteration++) - { - // Filter observations - var filteredResult = Filter( - observation: observation, - timeBins: timeBins, - numStates: _numStates, - numObservations: _numObservations, - transitionMatrix: transitionMatrix, - measurementFunction: measurementFunction, - processNoiseCovariance: processNoiseCovariance, - measurementNoiseCovariance: measurementNoiseCovariance, - initialState: initialState, - initialCovariance: initialCovariance, - scalarType: _scalarType, - device: _device); - - // Compute log likelihood - var filteredLogLikelihood = logLikelihoodConst + 0.5 * filteredResult.LogLikelihood.sum(); - var filteredLogLikelihoodSum = filteredLogLikelihood.to_type(ScalarType.Float64).item(); - - logLikelihood[iteration] = filteredLogLikelihoodSum; + // Precompute constant observation terms reused across EM iterations + var observationT = observation.mT; + var autoCorrelationObservations = observationT.matmul(observation); - // Check for convergence - if (filteredLogLikelihoodSum <= previousLogLikelihood) + using (var _ = no_grad()) + { + for (int iteration = 0; iteration < maxIterations; iteration++) { - Console.WriteLine($"Warning: Log likelihood decreased! New: {filteredLogLikelihoodSum}, Previous: {previousLogLikelihood}"); - break; + // Filter observations + var filteredResult = Filter( + observation: observation, + timeBins: timeBins, + numStates: _numStates, + numObservations: _numObservations, + transitionMatrix: transitionMatrix, + measurementFunction: measurementFunction, + processNoiseCovariance: processNoiseCovariance, + measurementNoiseCovariance: measurementNoiseCovariance, + initialState: initialState, + initialCovariance: initialCovariance, + scalarType: _scalarType, + device: _device); + + // Compute log likelihood (avoid creating intermediate tensors) + var llSumDouble = filteredResult.LogLikelihood.sum() + .to_type(ScalarType.Float64).item(); + var filteredLogLikelihoodSum = logLikelihoodConst + 0.5 * llSumDouble; + + logLikelihood[iteration] = filteredLogLikelihoodSum; + + // Check for convergence + if (filteredLogLikelihoodSum <= previousLogLikelihood) + { + Console.WriteLine($"Warning: Log likelihood decreased! New: {filteredLogLikelihoodSum}, Previous: {previousLogLikelihood}"); + break; + } + + if (filteredLogLikelihoodSum - previousLogLikelihood < tolerance) + break; + + previousLogLikelihood = filteredLogLikelihoodSum; + + // Smooth the filtered results + var smoothedResult = Smooth( + filteredResult: filteredResult, + timeBins: timeBins, + numStates: _numStates, + transitionMatrix: transitionMatrix, + measurementFunction: measurementFunction, + initialState: initialState, + initialCovariance: initialCovariance, + identityStates: _identityStates, + scalarType: _scalarType, + device: _device); + + // Sufficient statistics + var S00 = smoothedResult.S00.sum([0]); + var S11 = smoothedResult.S11.sum([0]); + var S10 = smoothedResult.S10.sum([0]); + + // Replace einsum with faster matmul + var crossCorrelationObservations = observationT.matmul(smoothedResult.SmoothedState); + + // Update parameters + transitionMatrix = InverseCholesky(S10, S00); + measurementFunction = InverseCholesky(crossCorrelationObservations, S11); + + // Reuse transitionMatrix (avoid an extra solve) + processNoiseCovariance = WrappedTensorDisposeScope(() => + EnsureSymmetric((S11 - transitionMatrix.matmul(S10.mT)) / timeBins)); + + var explainedObservationCovariance = measurementFunction.matmul(crossCorrelationObservations.mT); + measurementNoiseCovariance = WrappedTensorDisposeScope(() => + EnsureSymmetric((autoCorrelationObservations - explainedObservationCovariance - explainedObservationCovariance.mT + + measurementFunction.matmul(S11).matmul(measurementFunction.mT)) / timeBins)); + + initialState = smoothedResult.SmoothedInitialState; + initialCovariance = smoothedResult.SmoothedInitialCovariance; } - - if (filteredLogLikelihoodSum - previousLogLikelihood < tolerance) - break; - - previousLogLikelihood = filteredLogLikelihoodSum; - - // Smooth the filtered results - var smoothedResult = Smooth( - filteredResult: filteredResult, - timeBins: timeBins, - numStates: _numStates, - transitionMatrix: transitionMatrix, - measurementFunction: measurementFunction, - initialState: initialState, - initialCovariance: initialCovariance, - identityStates: _identityStates, - scalarType: _scalarType, - device: _device); - - // Sufficient statistics - var autoCorrelationStatesCurrent = smoothedResult.AutoCorrelationStatesCurrent.sum([0]); - var autoCorrelationStatesNext = smoothedResult.AutoCorrelationStatesNext.sum([0]); - var crossCorrelationStates = smoothedResult.CrossCorrelationStates.sum([0]); - - var crossCorrelationObservations = einsum("tp,tn->pn", observation, smoothedResult.SmoothedState); - var autoCorrelationObservations = einsum("tp,tq->pq", observation, observation); - - // Update parameters - transitionMatrix = InverseCholesky(crossCorrelationStates, autoCorrelationStatesCurrent); - measurementFunction = InverseCholesky(crossCorrelationObservations, autoCorrelationStatesNext); - processNoiseCovariance = WrappedTensorDisposeScope(() => - EnsureSymmetric((autoCorrelationStatesNext - InverseCholesky(crossCorrelationStates, autoCorrelationStatesCurrent).matmul(crossCorrelationStates.T)) / timeBins)); - - var explainedObservationCovariance = measurementFunction.matmul(crossCorrelationObservations.mT); - measurementNoiseCovariance = WrappedTensorDisposeScope(() => - EnsureSymmetric((autoCorrelationObservations - explainedObservationCovariance - explainedObservationCovariance.mT + measurementFunction.matmul(autoCorrelationStatesNext) - .matmul(measurementFunction.mT)) / timeBins)); - - initialState = smoothedResult.SmoothedInitialState; - initialCovariance = smoothedResult.SmoothedInitialCovariance; } var updatedParameters = new KalmanFilterParameters( @@ -685,13 +698,13 @@ public ExpectationMaximizationResult ExpectationMaximization( public OrthogonalizedResult OrthogonalizeStateAndCovariance(Tensor state, Tensor covariance) { - var (U, S, Vt) = linalg.svd(_measurementFunction); + var (_, S, Vt) = linalg.svd(_measurementFunction); var SVt = diag(S).matmul(Vt); - var orthogonalizedState = einsum("tk,kj->tj", state, SVt.mT); + var orthogonalizedState = matmul(state, SVt.mT); - var auxilary = einsum("ik,tkj->tij", SVt, covariance); - var orthogonalizedCovariance = einsum("tij,jk->tik", auxilary, SVt.mT); + var auxilary = matmul(SVt, covariance); + var orthogonalizedCovariance = matmul(auxilary, SVt.mT); return new OrthogonalizedResult(orthogonalizedState, orthogonalizedCovariance); } diff --git a/src/Bonsai.ML.Torch.LDS/KalmanFilterModelManager.cs b/src/Bonsai.ML.Torch.LDS/KalmanFilterModelManager.cs index 6e362845..fb132ea2 100644 --- a/src/Bonsai.ML.Torch.LDS/KalmanFilterModelManager.cs +++ b/src/Bonsai.ML.Torch.LDS/KalmanFilterModelManager.cs @@ -5,10 +5,9 @@ using System.Collections.Generic; using static TorchSharp.torch; using Bonsai.ML.Torch.LDS; -using TorchSharp; // -// Manages instances of the Kalman Filter in a thread-safe manner. +// Manages instances of the Kalman Filter model with a thread-safe locking mechanism for reading state tensors and writing parameters. // internal sealed class KalmanFilterModelManager { diff --git a/src/Bonsai.ML.Torch.LDS/KalmanFilterNameConverter.cs b/src/Bonsai.ML.Torch.LDS/KalmanFilterNameConverter.cs index af31bf37..78eb99b2 100644 --- a/src/Bonsai.ML.Torch.LDS/KalmanFilterNameConverter.cs +++ b/src/Bonsai.ML.Torch.LDS/KalmanFilterNameConverter.cs @@ -1,10 +1,12 @@ -using Bonsai; using Bonsai.Expressions; using System.Linq; using System.ComponentModel; namespace Bonsai.ML.Torch.LDS; +/// +/// Provides a type converter to select the name of an existing Kalman filter model in the workflow. +/// public class KalmanFilterNameConverter : StringConverter { /// @@ -22,12 +24,12 @@ public override StandardValuesCollection GetStandardValues(ITypeDescriptorContex if (workflowBuilder != null) { var models = (from builder in workflowBuilder.Workflow.Descendants() - where builder.GetType() != typeof(DisableBuilder) - let managedModelNode = ExpressionBuilder.GetWorkflowElement(builder) - where managedModelNode != null && managedModelNode is CreateKalmanFilter - let createKalmanFilter = (CreateKalmanFilter)managedModelNode - where createKalmanFilter != null && !string.IsNullOrEmpty(createKalmanFilter.ModelName) - select createKalmanFilter.ModelName) + where builder.GetType() != typeof(DisableBuilder) + let managedModelNode = ExpressionBuilder.GetWorkflowElement(builder) + where managedModelNode != null && managedModelNode is CreateKalmanFilter + let createKalmanFilter = (CreateKalmanFilter)managedModelNode + where createKalmanFilter != null && !string.IsNullOrEmpty(createKalmanFilter.ModelName) + select createKalmanFilter.ModelName) .Distinct() .ToList(); if (models.Count > 0) diff --git a/src/Bonsai.ML.Torch.LDS/Orthogonalize.cs b/src/Bonsai.ML.Torch.LDS/Orthogonalize.cs index 2081056d..54cc01fe 100644 --- a/src/Bonsai.ML.Torch.LDS/Orthogonalize.cs +++ b/src/Bonsai.ML.Torch.LDS/Orthogonalize.cs @@ -1,20 +1,30 @@ -using TorchSharp; using System; -using Bonsai; using System.ComponentModel; using System.Reactive.Linq; -using System.Xml.Serialization; +using static TorchSharp.torch; namespace Bonsai.ML.Torch.LDS; +/// +/// Orthogonalizes the state and covariance estimates from a Kalman filter or smoother. +/// [Combinator] -[Description("")] +[Description("Orthogonalizes the state and covariance estimates from a Kalman filter or smoother.")] [WorkflowElementCategory(ElementCategory.Transform)] public class Orthogonalize { + /// + /// The name of the Kalman filter model to be used. + /// [TypeConverter(typeof(KalmanFilterNameConverter))] + [Description("The name of the Kalman filter model to be used.")] public string ModelName { get; set; } = "KalmanFilter"; + /// + /// Processes an observable sequence of smoothed results, orthogonalizing the state and covariance estimates. + /// + /// + /// public IObservable Process(IObservable source) { return source.Select(input => @@ -29,6 +39,11 @@ public IObservable Process(IObservable sou }); } + /// + /// Processes an observable sequence of filtered results, orthogonalizing the state and covariance estimates. + /// + /// + /// public IObservable Process(IObservable source) { return source.Select(input => @@ -42,4 +57,23 @@ public IObservable Process(IObservable sou } }); } + + /// + /// Processes an observable sequence of state and covariance tuples, orthogonalizing the state and covariance estimates. + /// + /// + /// + public IObservable Process(IObservable> source) + { + return source.Select(input => + { + var kalmanFilter = KalmanFilterModelManager.GetKalmanFilter(ModelName); + var state = input.Item1; + var covariance = input.Item2; + using (KalmanFilterModelManager.Read(kalmanFilter)) + { + return kalmanFilter.OrthogonalizeStateAndCovariance(state, covariance); + } + }); + } } \ No newline at end of file diff --git a/src/Bonsai.ML.Torch.LDS/OrthogonalizedResult.cs b/src/Bonsai.ML.Torch.LDS/OrthogonalizedResult.cs index e2419f07..0666f8dc 100644 --- a/src/Bonsai.ML.Torch.LDS/OrthogonalizedResult.cs +++ b/src/Bonsai.ML.Torch.LDS/OrthogonalizedResult.cs @@ -1,17 +1,26 @@ -using TorchSharp; +using static TorchSharp.torch; namespace Bonsai.ML.Torch.LDS; -public struct OrthogonalizedResult +/// +/// Represents the result of orthogonalizing the state and covariance estimates. +/// +/// +/// Initializes a new instance of the struct. +/// +/// +/// +public struct OrthogonalizedResult( + Tensor orthogonalizedState, + Tensor orthogonalizedCovariance) { - public torch.Tensor OrthogonalizedState; - public torch.Tensor OrthogonalizedCovariance; + /// + /// The orthogonalized state estimate. + /// + public Tensor OrthogonalizedState = orthogonalizedState; - public OrthogonalizedResult( - torch.Tensor orthogonalizedState, - torch.Tensor orthogonalizedCovariance) - { - OrthogonalizedState = orthogonalizedState; - OrthogonalizedCovariance = orthogonalizedCovariance; - } + /// + /// The orthogonalized covariance estimate. + /// + public Tensor OrthogonalizedCovariance = orthogonalizedCovariance; } \ No newline at end of file diff --git a/src/Bonsai.ML.Torch.LDS/Smooth.cs b/src/Bonsai.ML.Torch.LDS/Smooth.cs index d06f6946..8c27de22 100644 --- a/src/Bonsai.ML.Torch.LDS/Smooth.cs +++ b/src/Bonsai.ML.Torch.LDS/Smooth.cs @@ -1,24 +1,31 @@ using System; using System.ComponentModel; using System.Reactive.Linq; -using System.Xml.Serialization; -using Bonsai; -using Bonsai.ML.Torch; -using Bonsai.ML.Torch.NeuralNets; -using TorchSharp; -using TorchSharp.Modules; +using static TorchSharp.torch; namespace Bonsai.ML.Torch.LDS; +/// +/// Applies a Kalman smoother to the input filtered result sequence. +/// [Combinator] [ResetCombinator] -[Description("")] +[Description("Applies a Kalman smoother to the input filtered result sequence.")] [WorkflowElementCategory(ElementCategory.Transform)] public class Smooth { + /// + /// The name of the Kalman filter model to be used. + /// [TypeConverter(typeof(KalmanFilterNameConverter))] + [Description("The name of the Kalman filter model to be used.")] public string ModelName { get; set; } = "KalmanFilter"; + /// + /// Processes an observable sequence of filtered results, applying the Kalman smoother to each result. + /// + /// + /// public IObservable Process(IObservable source) { return source.Select((input) => @@ -30,4 +37,19 @@ public IObservable Process(IObservable source) } }); } + + /// + /// Processes an observable sequence of tuples containing the components of a filtered result (predictedState, predictedCovariance, updatedState, updatedCovariance), applying the Kalman smoother to each result. + /// + /// + /// + public IObservable Process(IObservable> source) + { + return source.Select((input) => + { + var kalmanFilter = KalmanFilterModelManager.GetKalmanFilter(ModelName); + var filteredResult = new FilteredResult(input.Item1, input.Item2, input.Item3, input.Item4); + return kalmanFilter.Smooth(filteredResult); + }); + } } \ No newline at end of file diff --git a/src/Bonsai.ML.Torch.LDS/UpdateParameters.cs b/src/Bonsai.ML.Torch.LDS/UpdateParameters.cs index 3286c2b9..30998c21 100644 --- a/src/Bonsai.ML.Torch.LDS/UpdateParameters.cs +++ b/src/Bonsai.ML.Torch.LDS/UpdateParameters.cs @@ -1,25 +1,30 @@ using System; using System.ComponentModel; using System.Reactive.Linq; -using System.Runtime.InteropServices; -using System.Xml.Serialization; -using Bonsai; -using Bonsai.ML.Torch; -using Bonsai.ML.Torch.NeuralNets; -using TorchSharp; -using TorchSharp.Modules; namespace Bonsai.ML.Torch.LDS; +/// +/// Updates the parameters of a Kalman filter model instance using the provided Kalman filter parameters. +/// [Combinator] [ResetCombinator] -[Description("Learn the parameters kalman filter using the batch EM update algorithm.")] +[Description("Updates the parameters of a Kalman filter model instance using the provided Kalman filter parameters.")] [WorkflowElementCategory(ElementCategory.Sink)] public class UpdateParameters { + /// + /// The name of the Kalman filter model to be used. + /// [TypeConverter(typeof(KalmanFilterNameConverter))] + [Description("The name of the Kalman filter model to be used.")] public string ModelName { get; set; } = "KalmanFilter"; + /// + /// Updates the parameters of a Kalman filter model using elements from the input sequence. + /// + /// + /// public IObservable Process(IObservable source) { return source.Do((input) => From 1bb1d4584c9b16457a3cbe9b82766e20fcf766a4 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Mon, 29 Sep 2025 11:45:30 +0100 Subject: [PATCH 24/92] Removed convoluted lock mechanism in favor of no lock --- .../ExpectationMaximization.cs | 10 +---- src/Bonsai.ML.Torch.LDS/Filter.cs | 5 +-- .../KalmanFilterModelManager.cs | 43 ------------------- src/Bonsai.ML.Torch.LDS/Orthogonalize.cs | 15 ++----- src/Bonsai.ML.Torch.LDS/Smooth.cs | 5 +-- src/Bonsai.ML.Torch.LDS/UpdateParameters.cs | 5 +-- 6 files changed, 7 insertions(+), 76 deletions(-) diff --git a/src/Bonsai.ML.Torch.LDS/ExpectationMaximization.cs b/src/Bonsai.ML.Torch.LDS/ExpectationMaximization.cs index 7421050c..b974e709 100644 --- a/src/Bonsai.ML.Torch.LDS/ExpectationMaximization.cs +++ b/src/Bonsai.ML.Torch.LDS/ExpectationMaximization.cs @@ -70,11 +70,7 @@ public IObservable Process(IObservable so for (int i = 0; i < MaxIterations; i++) { - ExpectationMaximizationResult result; - using (KalmanFilterModelManager.Read(model)) - { - result = model.ExpectationMaximization(input, 1, Tolerance, false); - } + var result = model.ExpectationMaximization(input, 1, Tolerance, false); var logLikelihoodSum = result.LogLikelihood .cpu() @@ -102,11 +98,7 @@ public IObservable Process(IObservable so break; } previousLogLikelihood = logLikelihoodSum; - - using (KalmanFilterModelManager.Write(model)) - { model.UpdateParameters(result.Parameters); - } } var expectationMaximizationResult = new ExpectationMaximizationResult( diff --git a/src/Bonsai.ML.Torch.LDS/Filter.cs b/src/Bonsai.ML.Torch.LDS/Filter.cs index 7d52a231..78a1f8e7 100644 --- a/src/Bonsai.ML.Torch.LDS/Filter.cs +++ b/src/Bonsai.ML.Torch.LDS/Filter.cs @@ -28,10 +28,7 @@ public IObservable Process(IObservable source) return source.Select((input) => { var kalmanFilter = KalmanFilterModelManager.GetKalmanFilter(ModelName); - using (KalmanFilterModelManager.Read(kalmanFilter)) - { - return kalmanFilter.Filter(input); - } + return kalmanFilter.Filter(input); }); } } \ No newline at end of file diff --git a/src/Bonsai.ML.Torch.LDS/KalmanFilterModelManager.cs b/src/Bonsai.ML.Torch.LDS/KalmanFilterModelManager.cs index fb132ea2..e1265267 100644 --- a/src/Bonsai.ML.Torch.LDS/KalmanFilterModelManager.cs +++ b/src/Bonsai.ML.Torch.LDS/KalmanFilterModelManager.cs @@ -11,31 +11,6 @@ // internal sealed class KalmanFilterModelManager { - private static readonly ConditionalWeakTable _moduleLocks = new(); - - public static ReaderWriterLockSlim GetLock(KalmanFilter instance) => - _moduleLocks.GetValue(instance, _ => new ReaderWriterLockSlim(LockRecursionPolicy.NoRecursion)); - - public static IDisposable Read(KalmanFilter instance) - { - var lockObject = GetLock(instance); - lockObject.EnterReadLock(); - return new ManagedLock(lockObject, Mode.Read); - } - - public static IDisposable Write(KalmanFilter instance) - { - var lockObject = GetLock(instance); - lockObject.EnterWriteLock(); - return new ManagedLock(lockObject, Mode.Write); - } - - private enum Mode - { - Read, - Write - } - private static readonly Dictionary _models = new(); public static KalmanFilter GetKalmanFilter(string name) @@ -111,24 +86,6 @@ internal static KalmanFilterDisposable Reserve( })); } - private readonly struct ManagedLock( - ReaderWriterLockSlim lockObject, - Mode mode) : IDisposable - { - private readonly ReaderWriterLockSlim _lockObject = lockObject; - private readonly Mode _mode = mode; - - public void Dispose() - { - // Exit in the reverse mode we entered. - switch (_mode) - { - case Mode.Read when _lockObject.IsReadLockHeld: _lockObject.ExitReadLock(); break; - case Mode.Write when _lockObject.IsWriteLockHeld: _lockObject.ExitWriteLock(); break; - } - } - } - internal sealed class KalmanFilterDisposable(KalmanFilter model, IDisposable disposable) : IDisposable { private IDisposable? resource = disposable ?? throw new ArgumentNullException(nameof(disposable)); diff --git a/src/Bonsai.ML.Torch.LDS/Orthogonalize.cs b/src/Bonsai.ML.Torch.LDS/Orthogonalize.cs index 54cc01fe..5922ffd9 100644 --- a/src/Bonsai.ML.Torch.LDS/Orthogonalize.cs +++ b/src/Bonsai.ML.Torch.LDS/Orthogonalize.cs @@ -32,10 +32,7 @@ public IObservable Process(IObservable sou var kalmanFilter = KalmanFilterModelManager.GetKalmanFilter(ModelName); var smoothedState = input.SmoothedState; var smoothedCovariance = input.SmoothedCovariance; - using (KalmanFilterModelManager.Read(kalmanFilter)) - { - return kalmanFilter.OrthogonalizeStateAndCovariance(smoothedState, smoothedCovariance); - } + return kalmanFilter.OrthogonalizeStateAndCovariance(smoothedState, smoothedCovariance); }); } @@ -51,10 +48,7 @@ public IObservable Process(IObservable sou var kalmanFilter = KalmanFilterModelManager.GetKalmanFilter(ModelName); var filteredState = input.UpdatedState; var filteredCovariance = input.UpdatedCovariance; - using (KalmanFilterModelManager.Read(kalmanFilter)) - { - return kalmanFilter.OrthogonalizeStateAndCovariance(filteredState, filteredCovariance); - } + return kalmanFilter.OrthogonalizeStateAndCovariance(filteredState, filteredCovariance); }); } @@ -70,10 +64,7 @@ public IObservable Process(IObservable Process(IObservable source) return source.Select((input) => { var kalmanFilter = KalmanFilterModelManager.GetKalmanFilter(ModelName); - using (KalmanFilterModelManager.Read(kalmanFilter)) - { - return kalmanFilter.Smooth(input); - } + return kalmanFilter.Smooth(input); }); } diff --git a/src/Bonsai.ML.Torch.LDS/UpdateParameters.cs b/src/Bonsai.ML.Torch.LDS/UpdateParameters.cs index 30998c21..3852a8c5 100644 --- a/src/Bonsai.ML.Torch.LDS/UpdateParameters.cs +++ b/src/Bonsai.ML.Torch.LDS/UpdateParameters.cs @@ -30,10 +30,7 @@ public IObservable Process(IObservable { var kalmanFilter = KalmanFilterModelManager.GetKalmanFilter(ModelName); - using (KalmanFilterModelManager.Write(kalmanFilter)) - { - kalmanFilter.UpdateParameters(input); - } + kalmanFilter.UpdateParameters(input); }); } } \ No newline at end of file From b9f6880f5779f5752b80d3e57166fc07b4a167c4 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Mon, 29 Sep 2025 11:46:57 +0100 Subject: [PATCH 25/92] Added cleanup to test --- .../NeuralLatentsTest.cs | 25 +++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/tests/Bonsai.ML.Torch.LDS.Tests/NeuralLatentsTest.cs b/tests/Bonsai.ML.Torch.LDS.Tests/NeuralLatentsTest.cs index a6c64d73..3fb6620e 100644 --- a/tests/Bonsai.ML.Torch.LDS.Tests/NeuralLatentsTest.cs +++ b/tests/Bonsai.ML.Torch.LDS.Tests/NeuralLatentsTest.cs @@ -110,6 +110,7 @@ await WorkflowHelper.RunWorkflow( [DeploymentItem("bootstrap_test_environment.py")] [DeploymentItem("estimate_neural_latents.py")] [DeploymentItem("NeuralLatentsTest.bonsai")] + [DeploymentItem("requirements.txt")] public async Task TestSetup() { Directory.CreateDirectory(basePath); @@ -118,6 +119,30 @@ public async Task TestSetup() await RunBonsaiWorkflow(basePath); } + /// + /// Cleanup files generated for test. + /// + [TestCleanup] + public void TestCleanup() + { + var ptFiles = Directory.GetFiles(basePath, "*.pt"); + var binFiles = Directory.GetFiles(basePath, "*.bin"); + foreach (var file in ptFiles) File.Delete(file); + foreach (var file in binFiles) File.Delete(file); + + var virtualEnvPath = Path.Combine(basePath, ".venv"); + if (Directory.Exists(virtualEnvPath)) + { + Directory.Delete(virtualEnvPath, true); + } + + var remfileCachePath = Path.Combine(basePath, "remfile_cache"); + if (Directory.Exists(remfileCachePath)) + { + Directory.Delete(remfileCachePath, true); + } + } + /// /// Compares the results from the Python script and the Bonsai workflow. /// From a68f90e34dfe783abc4c5f3569141712c3130268 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Mon, 29 Sep 2025 13:36:55 +0100 Subject: [PATCH 26/92] Updated package installs and removed requirements --- .../bootstrap_test_environment.py | 7 +- .../requirements.txt | 152 ------------------ 2 files changed, 6 insertions(+), 153 deletions(-) delete mode 100644 tests/Bonsai.ML.Torch.LDS.Tests/requirements.txt diff --git a/tests/Bonsai.ML.Torch.LDS.Tests/bootstrap_test_environment.py b/tests/Bonsai.ML.Torch.LDS.Tests/bootstrap_test_environment.py index a842b565..7340b4ea 100644 --- a/tests/Bonsai.ML.Torch.LDS.Tests/bootstrap_test_environment.py +++ b/tests/Bonsai.ML.Torch.LDS.Tests/bootstrap_test_environment.py @@ -73,7 +73,12 @@ def install_requirements(requirements_file: str, venv_path: str = None): base_dir = get_base_dir(args.base_dir) venv_path = create_venv(base_dir) activate_venv(venv_path) -install_requirements(os.path.join(base_dir, "requirements.txt"), venv_path) + +install("torch", venv_path) +install("plotly", venv_path) +install("remfile", venv_path) +install("dandi", venv_path) +install("ssm@git+https://github.com/ncguilbeault/lds_python@75e3e5e92ce6344009b62a5034db49b238db63ef", venv_path) python_path = get_python_path(venv_path) diff --git a/tests/Bonsai.ML.Torch.LDS.Tests/requirements.txt b/tests/Bonsai.ML.Torch.LDS.Tests/requirements.txt deleted file mode 100644 index ff3efa23..00000000 --- a/tests/Bonsai.ML.Torch.LDS.Tests/requirements.txt +++ /dev/null @@ -1,152 +0,0 @@ -acres==0.5.0 -aiobotocore==2.24.1 -aiohappyeyeballs==2.6.1 -aiohttp==3.12.15 -aioitertools==0.12.0 -aiosignal==1.4.0 -annotated-types==0.7.0 -arrow==1.3.0 -asciitree==0.3.3 -asttokens==3.0.0 -attrs==25.3.0 -bids-validator-deno==2.0.11 -bidsschematools==1.0.14 -blessed==1.21.0 -botocore==1.39.11 -certifi==2025.8.3 -cffi==1.17.1 -charset-normalizer==3.4.3 -ci-info==0.3.0 -click==8.1.8 -click-didyoumean==0.3.1 -comm==0.2.3 -cryptography==45.0.7 -dandi==0.71.3 -dandischema==0.11.1 -debugpy==1.8.16 -decorator==5.2.1 -deprecated==1.2.18 -dnspython==2.7.0 -email-validator==2.3.0 -etelemetry==0.3.1 -executing==2.2.1 -fasteners==0.20 -fastjsonschema==2.21.2 -filelock==3.19.1 -fqdn==1.5.1 -frozenlist==1.7.0 -fscacher==0.4.4 -fsspec==2025.9.0 -h5py==3.14.0 -hdmf==4.1.0 -hdmf-zarr==0.11.3 -humanize==4.13.0 -idna==3.10 -interleave==0.3.0 -ipykernel==6.30.1 -ipython==9.5.0 -ipython-pygments-lexers==1.1.1 -isodate==0.7.2 -isoduration==20.11.0 -jaraco-classes==3.4.0 -jaraco-context==6.0.1 -jaraco-functools==4.3.0 -jedi==0.19.2 -jeepney==0.9.0 -jinja2==3.1.6 -jmespath==1.0.1 -joblib==1.5.2 -jsonpointer==3.0.0 -jsonschema==4.25.1 -jsonschema-specifications==2025.4.1 -jupyter-client==8.6.3 -jupyter-core==5.8.1 -keyring==25.6.0 -keyrings-alt==5.0.2 -markupsafe==3.0.2 -matplotlib-inline==0.1.7 -ml-dtypes==0.5.3 -more-itertools==10.8.0 -mpmath==1.3.0 -multidict==6.6.4 -narwhals==2.3.0 -natsort==8.4.0 -nbformat==5.10.4 -nest-asyncio==1.6.0 -networkx==3.5 -numcodecs==0.15.1 -numpy==2.3.2 -nvidia-cublas-cu12==12.8.4.1 -nvidia-cuda-cupti-cu12==12.8.90 -nvidia-cuda-nvrtc-cu12==12.8.93 -nvidia-cuda-runtime-cu12==12.8.90 -nvidia-cudnn-cu12==9.10.2.21 -nvidia-cufft-cu12==11.3.3.83 -nvidia-cufile-cu12==1.13.1.3 -nvidia-curand-cu12==10.3.9.90 -nvidia-cusolver-cu12==11.7.3.90 -nvidia-cusparse-cu12==12.5.8.93 -nvidia-cusparselt-cu12==0.7.1 -nvidia-nccl-cu12==2.27.3 -nvidia-nvjitlink-cu12==12.8.93 -nvidia-nvtx-cu12==12.8.90 -nwbinspector==0.6.5 -packaging==25.0 -pandas==2.3.2 -parso==0.8.5 -pexpect==4.9.0 -platformdirs==4.4.0 -plotly==6.3.0 -prompt-toolkit==3.0.52 -propcache==0.3.2 -psutil==7.0.0 -ptyprocess==0.7.0 -pure-eval==0.2.3 -pycparser==2.22 -pycryptodomex==3.23.0 -pydantic==2.11.7 -pydantic-core==2.33.2 -pygments==2.19.2 -pynwb==3.1.2 -pyout==0.8.1 -python-dateutil==2.9.0.post0 -pytz==2025.2 -pyyaml==6.0.2 -pyzmq==27.0.2 -referencing==0.36.2 -remfile==0.1.13 -requests==2.32.5 -rfc3339-validator==0.1.4 -rfc3987==1.3.8 -rpds-py==0.27.1 -ruamel-yaml==0.18.15 -ruamel-yaml-clib==0.2.12 -s3fs==2025.9.0 -scipy==1.16.1 -secretstorage==3.3.3 -semantic-version==2.10.0 -setuptools==80.9.0 -six==1.17.0 -ssm @ git+https://github.com/ncguilbeault/lds_python@75e3e5e92ce6344009b62a5034db49b238db63ef -stack-data==0.6.3 -sympy==1.14.0 -tenacity==9.1.2 -tensorstore==0.1.76 -threadpoolctl==3.6.0 -torch==2.8.0 -tornado==6.5.2 -tqdm==4.67.1 -traitlets==5.14.3 -triton==3.4.0 -types-python-dateutil==2.9.0.20250822 -typing-extensions==4.15.0 -typing-inspection==0.4.1 -tzdata==2025.2 -uri-template==1.3.0 -urllib3==2.5.0 -wcwidth==0.2.13 -webcolors==24.11.1 -wrapt==1.17.3 -yarl==1.20.1 -zarr==2.18.7 -zarr-checksum==0.4.7 From 91dea0c76f074c597eff03def7f5b557314bc76c Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 30 Sep 2025 11:00:41 +0100 Subject: [PATCH 27/92] Added the package to the documentation and included a basic article for installation --- docs/README.md | 3 +++ docs/articles/Torch.LDS/torch-lds-overview.md | 7 +++++++ docs/articles/toc.yml | 4 +++- 3 files changed, 13 insertions(+), 1 deletion(-) create mode 100644 docs/articles/Torch.LDS/torch-lds-overview.md diff --git a/docs/README.md b/docs/README.md index 60864b6d..419f9d6a 100644 --- a/docs/README.md +++ b/docs/README.md @@ -51,6 +51,9 @@ Interfaces with the [TorchSharp](https://github.com/dotnet/TorchSharp) package, > [!NOTE] > Bonsai.ML packages can be installed through Bonsai's integrated package manager and are generally ready for immediate use. However, some packages may require additional installation steps. Refer to the specific package section for detailed installation guides and documentation. +### Bonsai.ML.Torch.LDS +Linear dynamical systems implemented using the `Bonsai.ML.Torch` package for online filtering and smoothing of latent states and parameter estimation. + ## Development Roadmap The ultimate goal of the `Bonsai.ML` project is to bring powerful machine learning tools into Bonsai to enable intelligent experimental control. To achieve this, our plan is to incorporate several different packages, models, frameworks, and language integrations. You can follow our progress by going to the [Bonsai ML development roadmap](https://github.com/orgs/bonsai-rx/projects/7). diff --git a/docs/articles/Torch.LDS/torch-lds-overview.md b/docs/articles/Torch.LDS/torch-lds-overview.md new file mode 100644 index 00000000..be2d8260 --- /dev/null +++ b/docs/articles/Torch.LDS/torch-lds-overview.md @@ -0,0 +1,7 @@ +# Bonsai.ML.Torch.LDS - Overview + +This package provides an implementation of the Kalman filter, Rauch-Tung-Striebel (RTS) smoother, and expectation maximization (EM) algorithm developed for online filtering, smoothing, and parameter estimation from data streams in Bonsai using the TorchSharp package. + +## Installation Guide + +Install the `Bonsai.ML.Torch.LDS` package from the Bonsai package manager. You will also need to follow the [instructions for setting up the Bonsai.ML.Torch package](../Torch/torch-overview.md) for running on the CPU or GPU. \ No newline at end of file diff --git a/docs/articles/toc.yml b/docs/articles/toc.yml index 162c5226..7f42b6bd 100644 --- a/docs/articles/toc.yml +++ b/docs/articles/toc.yml @@ -19,4 +19,6 @@ href: PointProcessDecoder/ppd-overview.md - name: Torch - name: Overview - href: Torch/torch-overview.md \ No newline at end of file + href: Torch/torch-overview.md +- name: Torch.LDS + href: Torch.LDS/torch-lds-overview.md \ No newline at end of file From 3b47e1ca106872c0b1c2672ae8b3ec1325c95701 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 30 Sep 2025 17:36:49 +0100 Subject: [PATCH 28/92] Updated variable naming from state to mean to more accurately represent the meaning of the variable --- src/Bonsai.ML.Torch.LDS/CreateKalmanFilter.cs | 24 +- .../CreateKalmanFilterParameters.cs | 24 +- .../ExpectationMaximization.cs | 79 ++++- src/Bonsai.ML.Torch.LDS/FilteredResult.cs | 16 +- src/Bonsai.ML.Torch.LDS/KalmanFilter.cs | 281 +++++++++++------- .../KalmanFilterModelManager.cs | 4 +- .../KalmanFilterParameters.cs | 8 +- src/Bonsai.ML.Torch.LDS/Orthogonalize.cs | 18 +- .../OrthogonalizedResult.cs | 10 +- src/Bonsai.ML.Torch.LDS/Smooth.cs | 2 +- src/Bonsai.ML.Torch.LDS/SmoothedResult.cs | 12 +- 11 files changed, 305 insertions(+), 173 deletions(-) diff --git a/src/Bonsai.ML.Torch.LDS/CreateKalmanFilter.cs b/src/Bonsai.ML.Torch.LDS/CreateKalmanFilter.cs index 6f56cf68..1729f05e 100644 --- a/src/Bonsai.ML.Torch.LDS/CreateKalmanFilter.cs +++ b/src/Bonsai.ML.Torch.LDS/CreateKalmanFilter.cs @@ -68,7 +68,7 @@ private void ConvertTensorsScalarType(ScalarType scalarType) _measurementFunction = _measurementFunction?.to_type(scalarType); _processNoiseVariance = _processNoiseVariance?.to_type(scalarType); _measurementNoiseVariance = _measurementNoiseVariance?.to_type(scalarType); - _initialState = _initialState?.to_type(scalarType); + _initialMean = _initialMean?.to_type(scalarType); _initialCovariance = _initialCovariance?.to_type(scalarType); } @@ -169,28 +169,28 @@ public string MeasurementNoiseVarianceXml set => MeasurementNoiseVariance = TensorConverter.ConvertFromString(value, _scalarType); } - private Tensor _initialState; + private Tensor _initialMean; /// - /// The initial state. + /// The initial mean. /// [XmlIgnore] [TypeConverter(typeof(TensorConverter))] - public Tensor InitialState + public Tensor InitialMean { - get => _initialState; - set => _initialState = value?.to_type(Type); + get => _initialMean; + set => _initialMean = value?.to_type(Type); } /// - /// The XML string representation of the initial state for serialization. + /// The XML string representation of the initial mean for serialization. /// [Browsable(false)] - [XmlElement(nameof(InitialState))] + [XmlElement(nameof(InitialMean))] [EditorBrowsable(EditorBrowsableState.Never)] - public string InitialStateXml + public string InitialMeanXml { - get => TensorConverter.ConvertToString(InitialState, _scalarType); - set => InitialState = TensorConverter.ConvertFromString(value, _scalarType); + get => TensorConverter.ConvertToString(InitialMean, _scalarType); + set => InitialMean = TensorConverter.ConvertFromString(value, _scalarType); } private Tensor _initialCovariance; @@ -228,7 +228,7 @@ public string InitialCovarianceXml NumObservations, _transitionMatrix, _measurementFunction, - _initialState, + _initialMean, _initialCovariance, _processNoiseVariance, _measurementNoiseVariance, diff --git a/src/Bonsai.ML.Torch.LDS/CreateKalmanFilterParameters.cs b/src/Bonsai.ML.Torch.LDS/CreateKalmanFilterParameters.cs index 85f8461c..0c6d5b2d 100644 --- a/src/Bonsai.ML.Torch.LDS/CreateKalmanFilterParameters.cs +++ b/src/Bonsai.ML.Torch.LDS/CreateKalmanFilterParameters.cs @@ -127,28 +127,28 @@ public string MeasurementNoiseCovarianceXml set => MeasurementNoiseCovariance = TensorConverter.ConvertFromString(value, _scalarType); } - private Tensor _initialState = null; + private Tensor _initialMean = null; /// - /// The initial state. + /// The initial mean. /// [XmlIgnore] [TypeConverter(typeof(TensorConverter))] - public Tensor InitialState + public Tensor InitialMean { - get => _initialState; - set => _initialState = value?.to_type(Type); + get => _initialMean; + set => _initialMean = value?.to_type(Type); } /// /// The XML string representation of the initial state for serialization. /// [Browsable(false)] - [XmlElement(nameof(InitialState))] + [XmlElement(nameof(InitialMean))] [EditorBrowsable(EditorBrowsableState.Never)] - public string InitialStateXml + public string InitialMeanXml { - get => TensorConverter.ConvertToString(InitialState, _scalarType); - set => InitialState = TensorConverter.ConvertFromString(value, _scalarType); + get => TensorConverter.ConvertToString(InitialMean, _scalarType); + set => InitialMean = TensorConverter.ConvertFromString(value, _scalarType); } private Tensor _initialCovariance = null; @@ -181,7 +181,7 @@ private void ConvertTensorsScalarType(ScalarType scalarType) _measurementFunction = _measurementFunction?.to_type(scalarType); _processNoiseCovariance = _processNoiseCovariance?.to_type(scalarType); _measurementNoiseCovariance = _measurementNoiseCovariance?.to_type(scalarType); - _initialState = _initialState?.to_type(scalarType); + _initialMean = _initialMean?.to_type(scalarType); _initialCovariance = _initialCovariance?.to_type(scalarType); } @@ -196,7 +196,7 @@ public IObservable Process() MeasurementFunction, ProcessNoiseCovariance, MeasurementNoiseCovariance, - InitialState, + InitialMean, InitialCovariance ) ); @@ -213,7 +213,7 @@ public IObservable Process(IObservable source) MeasurementFunction, ProcessNoiseCovariance, MeasurementNoiseCovariance, - InitialState, + InitialMean, InitialCovariance ) ); diff --git a/src/Bonsai.ML.Torch.LDS/ExpectationMaximization.cs b/src/Bonsai.ML.Torch.LDS/ExpectationMaximization.cs index b974e709..9634a6f9 100644 --- a/src/Bonsai.ML.Torch.LDS/ExpectationMaximization.cs +++ b/src/Bonsai.ML.Torch.LDS/ExpectationMaximization.cs @@ -2,6 +2,7 @@ using System.ComponentModel; using System.Reactive.Linq; using TorchSharp; +using System.Collections.Generic; using static TorchSharp.torch; namespace Bonsai.ML.Torch.LDS; @@ -55,6 +56,72 @@ public bool Verbose set => _verbose = value; } + private bool _estimateTransitionMatrix = true; + /// + /// If true, the transition matrix will be estimated during the EM algorithm. + /// + [Description("If true, the transition matrix will be estimated during the EM algorithm.")] + public bool EstimateTransitionMatrix + { + get => _estimateTransitionMatrix; + set => _estimateTransitionMatrix = value; + } + + private bool _estimateMeasurementFunction = true; + /// + /// If true, the measurement function will be estimated during the EM algorithm. + /// + [Description("If true, the measurement function will be estimated during the EM algorithm.")] + public bool EstimateMeasurementFunction + { + get => _estimateMeasurementFunction; + set => _estimateMeasurementFunction = value; + } + + private bool _estimateProcessNoiseCovariance = true; + /// + /// If true, the process noise covariance will be estimated during the EM algorithm. + /// + [Description("If true, the process noise covariance will be estimated during the EM algorithm.")] + public bool EstimateProcessNoiseCovariance + { + get => _estimateProcessNoiseCovariance; + set => _estimateProcessNoiseCovariance = value; + } + + private bool _estimateMeasurementNoiseCovariance = true; + /// + /// If true, the measurement noise covariance will be estimated during the EM algorithm. + /// + [Description("If true, the measurement noise covariance will be estimated during the EM algorithm.")] + public bool EstimateMeasurementNoiseCovariance + { + get => _estimateMeasurementNoiseCovariance; + set => _estimateMeasurementNoiseCovariance = value; + } + + private bool _estimateInitialMean = true; + /// + /// If true, the initial mean will be estimated during the EM algorithm. + /// + [Description("If true, the initial mean will be estimated during the EM algorithm.")] + public bool EstimateInitialMean + { + get => _estimateInitialMean; + set => _estimateInitialMean = value; + } + + private bool _estimateInitialCovariance = true; + /// + /// If true, the initial covariance will be estimated during the EM algorithm. + /// + [Description("If true, the initial covariance will be estimated during the EM algorithm.")] + public bool EstimateInitialCovariance + { + get => _estimateInitialCovariance; + set => _estimateInitialCovariance = value; + } + /// /// Processes an observable sequence of input tensors, applying the Expectation-Maximization algorithm to learn the parameters of a Kalman filter model. /// @@ -68,9 +135,19 @@ public IObservable Process(IObservable so var previousLogLikelihood = double.NegativeInfinity; var logLikelihood = zeros(new long[] { MaxIterations }, device: input.device); + var parametersToEstimate = new Dictionary + { + { "TransitionMatrix", EstimateTransitionMatrix }, + { "MeasurementFunction", EstimateMeasurementFunction }, + { "ProcessNoiseCovariance", EstimateProcessNoiseCovariance }, + { "MeasurementNoiseCovariance", EstimateMeasurementNoiseCovariance }, + { "InitialState", EstimateInitialMean }, + { "InitialCovariance", EstimateInitialCovariance } + }; + for (int i = 0; i < MaxIterations; i++) { - var result = model.ExpectationMaximization(input, 1, Tolerance, false); + var result = model.ExpectationMaximization(input, 1, Tolerance, parametersToEstimate, false); var logLikelihoodSum = result.LogLikelihood .cpu() diff --git a/src/Bonsai.ML.Torch.LDS/FilteredResult.cs b/src/Bonsai.ML.Torch.LDS/FilteredResult.cs index 6cd4285d..0eec4356 100644 --- a/src/Bonsai.ML.Torch.LDS/FilteredResult.cs +++ b/src/Bonsai.ML.Torch.LDS/FilteredResult.cs @@ -5,20 +5,20 @@ namespace Bonsai.ML.Torch.LDS; /// /// Represents the result of a Kalman filter. /// -/// +/// /// -/// +/// /// public struct FilteredResult( - Tensor predictedState, + Tensor predictedMean, Tensor predictedCovariance, - Tensor updatedState, + Tensor updatedMean, Tensor updatedCovariance) { /// - /// The predicted state after the prediction step. + /// The predicted mean after the prediction step. /// - public Tensor PredictedState = predictedState; + public Tensor PredictedMean = predictedMean; /// /// The predicted covariance after the prediction step. @@ -26,9 +26,9 @@ public struct FilteredResult( public Tensor PredictedCovariance = predictedCovariance; /// - /// The updated state after the update step. + /// The updated mean after the update step. /// - public Tensor UpdatedState = updatedState; + public Tensor UpdatedMean = updatedMean; /// /// The updated covariance after the update step. diff --git a/src/Bonsai.ML.Torch.LDS/KalmanFilter.cs b/src/Bonsai.ML.Torch.LDS/KalmanFilter.cs index 8c5d7ba9..e84ef87b 100644 --- a/src/Bonsai.ML.Torch.LDS/KalmanFilter.cs +++ b/src/Bonsai.ML.Torch.LDS/KalmanFilter.cs @@ -1,4 +1,5 @@ using System; +using System.Collections.Generic; using static TorchSharp.torch; namespace Bonsai.ML.Torch.LDS; @@ -7,12 +8,12 @@ internal class KalmanFilter : nn.Module { private readonly Tensor _transitionMatrix; private readonly Tensor _measurementFunction; - private readonly Tensor _initialState; + private readonly Tensor _initialMean; private readonly Tensor _initialCovariance; private readonly Tensor _processNoiseCovariance; private readonly Tensor _measurementNoiseCovariance; private readonly Tensor _identityStates; - private readonly Tensor _state; + private readonly Tensor _mean; private readonly Tensor _covariance; private readonly int _numStates; private readonly int _numObservations; @@ -24,7 +25,7 @@ internal class KalmanFilter : nn.Module _measurementFunction, _processNoiseCovariance, _measurementNoiseCovariance, - _initialState, + _initialMean, _initialCovariance ); @@ -38,14 +39,14 @@ public KalmanFilter( ValidateAndSetMatrix(parameters.TransitionMatrix, "Transition matrix", _scalarType, _device, out _transitionMatrix, out _numStates, out _, isSquare: true); ValidateAndSetMatrix(parameters.MeasurementFunction, "Measurement function", _scalarType, _device, out _measurementFunction, out _numObservations, out _); - ValidateAndSetVector(parameters.InitialState, "Initial state", _scalarType, _device, out _initialState, out _, expectedLength: _numStates); + ValidateAndSetVector(parameters.InitialMean, "Initial mean", _scalarType, _device, out _initialMean, out _, expectedLength: _numStates); ValidateAndSetMatrix(parameters.InitialCovariance, "Initial covariance", _scalarType, _device, out _initialCovariance, out _, out _, isSquare: true, expectedDimension1: _numStates); ValidateAndSetMatrix(parameters.ProcessNoiseCovariance, "Process noise covariance", _scalarType, _device, out _processNoiseCovariance, out _, out _, isSquare: true, expectedDimension1: _numStates); ValidateAndSetMatrix(parameters.MeasurementNoiseCovariance, "Measurement noise covariance", _scalarType, _device, out _measurementNoiseCovariance, out _, out _, isSquare: true, expectedDimension1: _numObservations); _identityStates = eye(_numStates, dtype: _scalarType, device: _device); - _state = empty(0, dtype: _scalarType, device: _device).requires_grad_(false); + _mean = empty(0, dtype: _scalarType, device: _device).requires_grad_(false); _covariance = empty(0, dtype: _scalarType, device: _device).requires_grad_(false); } @@ -54,7 +55,7 @@ public KalmanFilter( int numObservations, Tensor transitionMatrix = null, Tensor measurementFunction = null, - Tensor initialState = null, + Tensor initialMean = null, Tensor initialCovariance = null, Tensor processNoiseVariance = null, Tensor measurementNoiseVariance = null, @@ -68,7 +69,7 @@ public KalmanFilter( _identityStates = eye(_numStates, dtype: _scalarType, device: _device); - _transitionMatrix = transitionMatrix?.clone().to_type(_scalarType).to(_device).requires_grad_(false) + _transitionMatrix = transitionMatrix?.clone().to_type(_scalarType).to(_device).requires_grad_(false) ?? eye(_numStates, dtype: _scalarType, device: _device); ValidateMatrix(_transitionMatrix, "Transition matrix", isSquare: true, expectedDimension1: _numStates); @@ -76,9 +77,9 @@ public KalmanFilter( ?? eye(_numObservations, _numStates, dtype: _scalarType, device: _device); ValidateMatrix(_measurementFunction, "Measurement function", expectedDimension1: _numObservations, expectedDimension2: _numStates); - _initialState = initialState?.clone().to_type(_scalarType).to(_device).requires_grad_(false) + _initialMean = initialMean?.clone().to_type(_scalarType).to(_device).requires_grad_(false) ?? zeros(_numStates, dtype: _scalarType, device: _device); - ValidateVector(_initialState, "Initial state", _numStates); + ValidateVector(_initialMean, "Initial mean", _numStates); _initialCovariance = initialCovariance?.clone().to_type(_scalarType).to(_device).requires_grad_(false) ?? eye(_numStates, dtype: _scalarType, device: _device); @@ -87,7 +88,7 @@ public KalmanFilter( _processNoiseCovariance = CreateCovarianceMatrix(processNoiseVariance, _scalarType, _device, numStates, "Process noise variance"); _measurementNoiseCovariance = CreateCovarianceMatrix(measurementNoiseVariance, _scalarType, _device, numObservations, "Measurement noise variance"); - _state = empty(0, dtype: _scalarType, device: _device).requires_grad_(false); + _mean = empty(0, dtype: _scalarType, device: _device).requires_grad_(false); _covariance = empty(0, dtype: _scalarType, device: _device).requires_grad_(false); RegisterComponents(); @@ -139,7 +140,7 @@ private static void ValidateVector(Tensor vector, string name, int? expectedLeng if (vector.Dimensions != 1) throw new ArgumentException($"{name} must be a vector."); - + if (expectedLength.HasValue && vector.NumberOfElements != expectedLength.Value) throw new ArgumentException($"{name} must be a vector with length equal to {expectedLength.Value}."); } @@ -160,37 +161,37 @@ private Tensor CreateCovarianceMatrix(Tensor variance, ScalarType scalarType, De } private readonly struct PredictedResult( - Tensor predictedState, + Tensor predictedMean, Tensor predictedCovariance) { - public readonly Tensor PredictedState = predictedState; + public readonly Tensor PredictedMean = predictedMean; public readonly Tensor PredictedCovariance = predictedCovariance; } private PredictedResult FilterPredict( - Tensor state, - Tensor covariance) => - new(_transitionMatrix.matmul(state), + Tensor mean, + Tensor covariance) => + new(_transitionMatrix.matmul(mean), _transitionMatrix.matmul(covariance) .matmul(_transitionMatrix.mT) + _processNoiseCovariance); private static PredictedResult FilterPredict( - Tensor state, + Tensor mean, Tensor covariance, Tensor transitionMatrix, - Tensor processNoiseCovariance) => - new(transitionMatrix.matmul(state), + Tensor processNoiseCovariance) => + new(transitionMatrix.matmul(mean), transitionMatrix.matmul(covariance) .matmul(transitionMatrix.mT) + processNoiseCovariance); private readonly struct UpdatedResult( - Tensor updatedState, + Tensor updatedMean, Tensor updatedCovariance, Tensor innovation, Tensor innovationCovariance, Tensor kalmanGain) { - public readonly Tensor UpdatedState = updatedState; + public readonly Tensor UpdatedMean = updatedMean; public readonly Tensor UpdatedCovariance = updatedCovariance; public readonly Tensor Innovation = innovation; public readonly Tensor InnovationCovariance = innovationCovariance; @@ -198,12 +199,12 @@ private readonly struct UpdatedResult( } private UpdatedResult FilterUpdate( - Tensor predictedState, + Tensor predictedMean, Tensor predictedCovariance, Tensor observation) { // Innovation step - var innovation = observation - _measurementFunction.matmul(predictedState); + var innovation = observation - _measurementFunction.matmul(predictedMean); var innovationCovariance = WrappedTensorDisposeScope(() => EnsureSymmetric( _measurementFunction.matmul(predictedCovariance) .matmul(_measurementFunction.mT) + _measurementNoiseCovariance)); @@ -214,22 +215,22 @@ private UpdatedResult FilterUpdate( innovationCovariance)); // Update step - var updatedState = predictedState + kalmanGain.matmul(innovation); + var updatedMean = predictedMean + kalmanGain.matmul(innovation); var updatedCovariance = WrappedTensorDisposeScope(() => predictedCovariance - kalmanGain.matmul(_measurementFunction).matmul(predictedCovariance)); - return new UpdatedResult(updatedState, updatedCovariance, innovation, innovationCovariance, kalmanGain); + return new UpdatedResult(updatedMean, updatedCovariance, innovation, innovationCovariance, kalmanGain); } - + private static UpdatedResult FilterUpdate( - Tensor predictedState, + Tensor predictedMean, Tensor predictedCovariance, Tensor observation, Tensor measurementFunction, Tensor measurementNoiseCovariance) { // Innovation step - var innovation = observation - measurementFunction.matmul(predictedState); + var innovation = observation - measurementFunction.matmul(predictedMean); var innovationCovariance = WrappedTensorDisposeScope(() => EnsureSymmetric( measurementFunction.matmul(predictedCovariance) .matmul(measurementFunction.mT) + measurementNoiseCovariance)); @@ -240,12 +241,12 @@ private static UpdatedResult FilterUpdate( innovationCovariance)); // Update step - var updatedState = predictedState + kalmanGain.matmul(innovation); + var updatedMean = predictedMean + kalmanGain.matmul(innovation); var updatedCovariance = WrappedTensorDisposeScope(() => predictedCovariance - kalmanGain.matmul(measurementFunction).matmul(predictedCovariance)); return new UpdatedResult( - updatedState: updatedState, + updatedMean: updatedMean, updatedCovariance: updatedCovariance, innovation: innovation, innovationCovariance: innovationCovariance, @@ -260,53 +261,53 @@ public FilteredResult Filter(Tensor observation) var obs = observation.atleast_2d(); var timeBins = obs.size(0); - var predictedState = empty(new long[] { timeBins, _numStates }, dtype: _scalarType, device: _device); + var predictedMean = empty(new long[] { timeBins, _numStates }, dtype: _scalarType, device: _device); var predictedCovariance = empty(new long[] { timeBins, _numStates, _numStates }, dtype: _scalarType, device: _device); - var updatedState = empty(new long[] { timeBins, _numStates }, dtype: _scalarType, device: _device); + var updatedMean = empty(new long[] { timeBins, _numStates }, dtype: _scalarType, device: _device); var updatedCovariance = empty(new long[] { timeBins, _numStates, _numStates }, dtype: _scalarType, device: _device); - if (_state.NumberOfElements == 0) - _state.set_(_initialState); + if (_mean.NumberOfElements == 0) + _mean.set_(_initialMean); if (_covariance.NumberOfElements == 0) _covariance.set_(_initialCovariance); for (long time = 0; time < timeBins; time++) { // Predict - var prediction = FilterPredict(_state, _covariance); + var prediction = FilterPredict(_mean, _covariance); // Update - var update = FilterUpdate(prediction.PredictedState, prediction.PredictedCovariance, obs[time]); + var update = FilterUpdate(prediction.PredictedMean, prediction.PredictedCovariance, obs[time]); - predictedState[time] = prediction.PredictedState; + predictedMean[time] = prediction.PredictedMean; predictedCovariance[time] = prediction.PredictedCovariance; - updatedState[time] = update.UpdatedState; + updatedMean[time] = update.UpdatedMean; updatedCovariance[time] = update.UpdatedCovariance; - _state.set_(update.UpdatedState); + _mean.set_(update.UpdatedMean); _covariance.set_(update.UpdatedCovariance); } return new FilteredResult( - predictedState: predictedState, + predictedMean: predictedMean, predictedCovariance: predictedCovariance, - updatedState: updatedState, + updatedMean: updatedMean, updatedCovariance: updatedCovariance); } - + private readonly struct FilteredResultWithAuxiliaryVariables( - Tensor predictedState, + Tensor predictedMean, Tensor predictedCovariance, - Tensor updatedState, + Tensor updatedMean, Tensor updatedCovariance, Tensor innovation, Tensor innovationCovariance, Tensor logLikelihood, Tensor kalmanGain) { - public readonly Tensor PredictedState = predictedState; + public readonly Tensor PredictedMean = predictedMean; public readonly Tensor PredictedCovariance = predictedCovariance; - public readonly Tensor UpdatedState = updatedState; + public readonly Tensor UpdatedMean = updatedMean; public readonly Tensor UpdatedCovariance = updatedCovariance; public readonly Tensor Innovation = innovation; public readonly Tensor InnovationCovariance = innovationCovariance; @@ -323,35 +324,35 @@ private static FilteredResultWithAuxiliaryVariables Filter( Tensor measurementFunction, Tensor processNoiseCovariance, Tensor measurementNoiseCovariance, - Tensor initialState, + Tensor initialMean, Tensor initialCovariance, ScalarType scalarType, Device device) { var logLikelihood = empty(timeBins, dtype: scalarType, device: device); - var predictedState = empty(new long[] { timeBins, numStates }, dtype: scalarType, device: device); + var predictedMean = empty(new long[] { timeBins, numStates }, dtype: scalarType, device: device); var predictedCovariance = empty(new long[] { timeBins, numStates, numStates }, dtype: scalarType, device: device); - var updatedState = empty(new long[] { timeBins, numStates }, dtype: scalarType, device: device); + var updatedMean = empty(new long[] { timeBins, numStates }, dtype: scalarType, device: device); var updatedCovariance = empty(new long[] { timeBins, numStates, numStates }, dtype: scalarType, device: device); var innovation = empty(new long[] { timeBins, numObservations }, dtype: scalarType, device: device); var innovationCovariance = empty(new long[] { timeBins, numObservations, numObservations }, dtype: scalarType, device: device); var kalmanGain = empty(new long[] { timeBins, numStates, numObservations }, dtype: scalarType, device: device); - var state = initialState; + var mean = initialMean; var covariance = initialCovariance; for (long time = 0; time < timeBins; time++) { // Predict var prediction = FilterPredict( - state: state, + mean: mean, covariance: covariance, transitionMatrix: transitionMatrix, processNoiseCovariance: processNoiseCovariance); // Update var update = FilterUpdate( - predictedState: prediction.PredictedState, + predictedMean: prediction.PredictedMean, predictedCovariance: prediction.PredictedCovariance, observation: observation[time], measurementFunction: measurementFunction, @@ -364,22 +365,22 @@ private static FilteredResultWithAuxiliaryVariables Filter( // Detach and assign logLikelihood[time] = logLikelihoodData; - predictedState[time] = prediction.PredictedState; + predictedMean[time] = prediction.PredictedMean; predictedCovariance[time] = prediction.PredictedCovariance; - updatedState[time] = update.UpdatedState; + updatedMean[time] = update.UpdatedMean; updatedCovariance[time] = update.UpdatedCovariance; innovation[time] = update.Innovation; innovationCovariance[time] = update.InnovationCovariance; kalmanGain[time] = update.KalmanGain; - state = update.UpdatedState; + mean = update.UpdatedMean; covariance = update.UpdatedCovariance; } return new FilteredResultWithAuxiliaryVariables( - predictedState: predictedState, + predictedMean: predictedMean, predictedCovariance: predictedCovariance, - updatedState: updatedState, + updatedMean: updatedMean, updatedCovariance: updatedCovariance, innovation: innovation, innovationCovariance: innovationCovariance, @@ -391,18 +392,18 @@ private static FilteredResultWithAuxiliaryVariables Filter( public SmoothedResult Smooth(FilteredResult filteredResult) { using var g = no_grad(); - - var predictedState = filteredResult.PredictedState; + + var predictedMean = filteredResult.PredictedMean; var predictedCovariance = filteredResult.PredictedCovariance; - var updatedState = filteredResult.UpdatedState; + var updatedMean = filteredResult.UpdatedMean; var updatedCovariance = filteredResult.UpdatedCovariance; - var timeBins = predictedState.size(0); - var smoothedState = empty_like(updatedState); + var timeBins = predictedMean.size(0); + var smoothedMean = empty_like(updatedMean); var smoothedCovariance = empty_like(updatedCovariance); // Fix the last time point - smoothedState[-1] = updatedState[-1]; + smoothedMean[-1] = updatedMean[-1]; smoothedCovariance[-1] = updatedCovariance[-1]; var smoothingGain = empty(new long[] { _numStates, _numStates }, dtype: _scalarType, device: _device); @@ -415,10 +416,10 @@ public SmoothedResult Smooth(FilteredResult filteredResult) InverseCholesky(_transitionMatrix.mT, predictedCovariance[time + 1]) )); - // Smoothed state - smoothedState[time] = WrappedTensorDisposeScope(() => updatedState[time] + // Smoothed mean + smoothedMean[time] = WrappedTensorDisposeScope(() => updatedMean[time] + smoothingGain.matmul( - (smoothedState[time + 1] - predictedState[time + 1]).unsqueeze(-1) + (smoothedMean[time + 1] - predictedMean[time + 1]).unsqueeze(-1) ).squeeze(-1)); // Smoothed covariance @@ -428,9 +429,9 @@ public SmoothedResult Smooth(FilteredResult filteredResult) ); } - // Smoothed initial state - var smoothedInitialState = WrappedTensorDisposeScope(() => _initialState + smoothingGain.matmul( - (smoothedState[0] - predictedState[0]).unsqueeze(-1) + // Smoothed initial mean + var smoothedInitialMean = WrappedTensorDisposeScope(() => _initialMean + smoothingGain.matmul( + (smoothedMean[0] - predictedMean[0]).unsqueeze(-1) ).squeeze(-1)); // Smoothed initial covariance @@ -439,25 +440,25 @@ public SmoothedResult Smooth(FilteredResult filteredResult) .matmul(smoothingGain.mT)); return new SmoothedResult( - smoothedState, + smoothedMean, smoothedCovariance, - smoothedInitialState, + smoothedInitialMean, smoothedInitialCovariance ); } private readonly struct SmoothedResultWithAuxiliaryVariables( - Tensor smoothedState, + Tensor smoothedMean, Tensor smoothedCovariance, - Tensor smoothedInitialState, + Tensor smoothedInitialMean, Tensor smoothedInitialCovariance, Tensor S00, Tensor S10, Tensor S11) { - public readonly Tensor SmoothedState = smoothedState; + public readonly Tensor SmoothedMean = smoothedMean; public readonly Tensor SmoothedCovariance = smoothedCovariance; - public readonly Tensor SmoothedInitialState = smoothedInitialState; + public readonly Tensor SmoothedInitialMean = smoothedInitialMean; public readonly Tensor SmoothedInitialCovariance = smoothedInitialCovariance; public readonly Tensor S00 = S00; public readonly Tensor S10 = S10; @@ -470,7 +471,7 @@ private static SmoothedResultWithAuxiliaryVariables Smooth( int numStates, Tensor transitionMatrix, Tensor measurementFunction, - Tensor initialState, + Tensor initialMean, Tensor initialCovariance, Tensor identityStates, ScalarType scalarType, @@ -479,14 +480,14 @@ Device device { if (timeBins < 2) throw new ArgumentException("Smoothing requires at least two time bins."); - - var predictedState = filteredResult.PredictedState; + + var predictedMean = filteredResult.PredictedMean; var predictedCovariance = filteredResult.PredictedCovariance; - var updatedState = filteredResult.UpdatedState; + var updatedMean = filteredResult.UpdatedMean; var updatedCovariance = filteredResult.UpdatedCovariance; var kalmanGain = filteredResult.KalmanGain; - var smoothedState = empty_like(updatedState); + var smoothedMean = empty_like(updatedMean); var smoothedCovariance = empty_like(updatedCovariance); var S00 = zeros_like(smoothedCovariance, dtype: scalarType, device: device); @@ -494,7 +495,7 @@ Device device var S11 = zeros_like(smoothedCovariance, dtype: scalarType, device: device); // Fix the last time point - smoothedState[-1] = updatedState[-1]; + smoothedMean[-1] = updatedMean[-1]; smoothedCovariance[-1] = updatedCovariance[-1]; var smoothedLagOneCovariance = WrappedTensorDisposeScope(() => (identityStates - kalmanGain[-1] @@ -502,7 +503,7 @@ Device device .matmul(transitionMatrix) .matmul(updatedCovariance[-2])); - S11[-1] = outer(updatedState[-1], updatedState[-1]) + updatedCovariance[-1]; + S11[-1] = outer(updatedMean[-1], updatedMean[-1]) + updatedCovariance[-1]; var smoothingGain = empty([numStates, numStates], dtype: scalarType, device: device); var smoothingGainNext = null as Tensor; @@ -515,10 +516,10 @@ Device device InverseCholesky(transitionMatrix.mT, predictedCovariance[time + 1]) )); - // Smoothed state - smoothedState[time] = WrappedTensorDisposeScope(() => updatedState[time] + // Smoothed mean + smoothedMean[time] = WrappedTensorDisposeScope(() => updatedMean[time] + smoothingGain.matmul( - (smoothedState[time + 1] - predictedState[time + 1]).unsqueeze(-1) + (smoothedMean[time + 1] - predictedMean[time + 1]).unsqueeze(-1) ).squeeze(-1)); // Smoothed covariance @@ -527,10 +528,10 @@ Device device .matmul(smoothingGain.mT) ); - var expectationUpdate = outer(smoothedState[time], smoothedState[time]) + smoothedCovariance[time]; + var expectationUpdate = outer(smoothedMean[time], smoothedMean[time]) + smoothedCovariance[time]; S11[time] = expectationUpdate; S00[time + 1] = expectationUpdate; - S10[time + 1] = outer(smoothedState[time + 1], smoothedState[time]) + smoothedLagOneCovariance; + S10[time + 1] = outer(smoothedMean[time + 1], smoothedMean[time]) + smoothedLagOneCovariance; // Compute next smoothing gain for lag one covariance if (time > 0) @@ -551,9 +552,9 @@ Device device InverseCholesky(transitionMatrix.mT, predictedCovariance[0]) )); - // Smoothed initial state - var smoothedInitialState = WrappedTensorDisposeScope(() => initialState + smoothingGain0.matmul( - (smoothedState[0] - predictedState[0]).unsqueeze(-1) + // Smoothed initial mean + var smoothedInitialMean = WrappedTensorDisposeScope(() => initialMean + smoothingGain0.matmul( + (smoothedMean[0] - predictedMean[0]).unsqueeze(-1) ).squeeze(-1)); // Smoothed initial covariance @@ -568,13 +569,13 @@ Device device - transitionMatrix.matmul(updatedCovariance[0])) .matmul(smoothingGain0.mT)); - S10[0] = outer(smoothedState[0], smoothedInitialState) + smoothedLagOneCovariance; - S00[0] = outer(smoothedInitialState, smoothedInitialState) + smoothedInitialCovariance; + S10[0] = outer(smoothedMean[0], smoothedInitialMean) + smoothedLagOneCovariance; + S00[0] = outer(smoothedInitialMean, smoothedInitialMean) + smoothedInitialCovariance; return new SmoothedResultWithAuxiliaryVariables( - smoothedState: smoothedState, + smoothedMean: smoothedMean, smoothedCovariance: smoothedCovariance, - smoothedInitialState: smoothedInitialState, + smoothedInitialMean: smoothedInitialMean, smoothedInitialCovariance: smoothedInitialCovariance, S00: S00, S10: S10, @@ -582,12 +583,55 @@ Device device ); } + private static Dictionary ValidateAndSetParametersToEstimate(Dictionary parametersToUpdate) + { + var validParameters = new HashSet + { + "TransitionMatrix", + "MeasurementFunction", + "ProcessNoiseCovariance", + "MeasurementNoiseCovariance", + "InitialMean", + "InitialCovariance" + }; + + if (parametersToUpdate is null) + return new Dictionary + { + { "TransitionMatrix", true }, + { "MeasurementFunction", true }, + { "ProcessNoiseCovariance", true }, + { "MeasurementNoiseCovariance", true }, + { "InitialMean", true }, + { "InitialCovariance", true } + }; + + // Check for invalid parameter names + foreach (var key in parametersToUpdate.Keys) + { + if (!validParameters.Contains(key)) + throw new ArgumentException($"Invalid parameter name '{key}' in parametersToUpdate. Valid names are: {string.Join(", ", validParameters)}"); + } + + // Ensure all valid parameters are present in the dictionary + foreach (var param in validParameters) + { + if (!parametersToUpdate.ContainsKey(param)) + parametersToUpdate[param] = false; + } + + return parametersToUpdate; + } + public ExpectationMaximizationResult ExpectationMaximization( Tensor observation, int maxIterations = 100, double tolerance = 1e-4, + Dictionary parametersToUpdate = null, bool updateParameters = true) { + parametersToUpdate = ValidateAndSetParametersToEstimate(parametersToUpdate); + var timeBins = observation.size(0); var logLikelihood = empty(maxIterations, dtype: ScalarType.Float32, device: _device); var previousLogLikelihood = double.NegativeInfinity; @@ -597,7 +641,7 @@ public ExpectationMaximizationResult ExpectationMaximization( var measurementFunction = _measurementFunction; var processNoiseCovariance = _processNoiseCovariance; var measurementNoiseCovariance = _measurementNoiseCovariance; - var initialState = _initialState; + var initialMean = _initialMean; var initialCovariance = _initialCovariance; // Precompute constant observation terms reused across EM iterations @@ -618,7 +662,7 @@ public ExpectationMaximizationResult ExpectationMaximization( measurementFunction: measurementFunction, processNoiseCovariance: processNoiseCovariance, measurementNoiseCovariance: measurementNoiseCovariance, - initialState: initialState, + initialMean: initialMean, initialCovariance: initialCovariance, scalarType: _scalarType, device: _device); @@ -649,7 +693,7 @@ public ExpectationMaximizationResult ExpectationMaximization( numStates: _numStates, transitionMatrix: transitionMatrix, measurementFunction: measurementFunction, - initialState: initialState, + initialMean: initialMean, initialCovariance: initialCovariance, identityStates: _identityStates, scalarType: _scalarType, @@ -661,23 +705,31 @@ public ExpectationMaximizationResult ExpectationMaximization( var S10 = smoothedResult.S10.sum([0]); // Replace einsum with faster matmul - var crossCorrelationObservations = observationT.matmul(smoothedResult.SmoothedState); + var crossCorrelationObservations = observationT.matmul(smoothedResult.SmoothedMean); // Update parameters - transitionMatrix = InverseCholesky(S10, S00); - measurementFunction = InverseCholesky(crossCorrelationObservations, S11); + if (parametersToUpdate["TransitionMatrix"]) + transitionMatrix = InverseCholesky(S10, S00); + + if (parametersToUpdate["MeasurementFunction"]) + measurementFunction = InverseCholesky(crossCorrelationObservations, S11); - // Reuse transitionMatrix (avoid an extra solve) - processNoiseCovariance = WrappedTensorDisposeScope(() => - EnsureSymmetric((S11 - transitionMatrix.matmul(S10.mT)) / timeBins)); + if (parametersToUpdate["ProcessNoiseCovariance"]) + processNoiseCovariance = WrappedTensorDisposeScope(() => + EnsureSymmetric((S11 - transitionMatrix.matmul(S10.mT)) / timeBins)); var explainedObservationCovariance = measurementFunction.matmul(crossCorrelationObservations.mT); - measurementNoiseCovariance = WrappedTensorDisposeScope(() => - EnsureSymmetric((autoCorrelationObservations - explainedObservationCovariance - explainedObservationCovariance.mT - + measurementFunction.matmul(S11).matmul(measurementFunction.mT)) / timeBins)); - initialState = smoothedResult.SmoothedInitialState; - initialCovariance = smoothedResult.SmoothedInitialCovariance; + if (parametersToUpdate["MeasurementNoiseCovariance"]) + measurementNoiseCovariance = WrappedTensorDisposeScope(() => + EnsureSymmetric((autoCorrelationObservations - explainedObservationCovariance - explainedObservationCovariance.mT + + measurementFunction.matmul(S11).matmul(measurementFunction.mT)) / timeBins)); + + if (parametersToUpdate["InitialMean"]) + initialMean = smoothedResult.SmoothedInitialMean; + + if (parametersToUpdate["InitialCovariance"]) + initialCovariance = smoothedResult.SmoothedInitialCovariance; } } @@ -686,7 +738,7 @@ public ExpectationMaximizationResult ExpectationMaximization( measurementFunction: measurementFunction, processNoiseCovariance: processNoiseCovariance, measurementNoiseCovariance: measurementNoiseCovariance, - initialState: initialState, + initialMean: initialMean, initialCovariance: initialCovariance ); @@ -696,17 +748,20 @@ public ExpectationMaximizationResult ExpectationMaximization( return new ExpectationMaximizationResult(logLikelihood, updatedParameters); } - public OrthogonalizedResult OrthogonalizeStateAndCovariance(Tensor state, Tensor covariance) + public OrthogonalizedResult OrthogonalizeMeanAndCovariance(Tensor mean, Tensor covariance) { var (_, S, Vt) = linalg.svd(_measurementFunction); var SVt = diag(S).matmul(Vt); - var orthogonalizedState = matmul(state, SVt.mT); + var orthogonalizedMean = matmul(mean, SVt.mT); var auxilary = matmul(SVt, covariance); var orthogonalizedCovariance = matmul(auxilary, SVt.mT); - return new OrthogonalizedResult(orthogonalizedState, orthogonalizedCovariance); + return new OrthogonalizedResult( + orthogonalizedMean: orthogonalizedMean, + orthogonalizedCovariance: orthogonalizedCovariance + ); } public void UpdateParameters(KalmanFilterParameters updatedParameters) @@ -715,7 +770,7 @@ public void UpdateParameters(KalmanFilterParameters updatedParameters) _measurementFunction.set_(updatedParameters.MeasurementFunction); _processNoiseCovariance.set_(updatedParameters.ProcessNoiseCovariance); _measurementNoiseCovariance.set_(updatedParameters.MeasurementNoiseCovariance); - _initialState.set_(updatedParameters.InitialState); + _initialMean.set_(updatedParameters.InitialMean); _initialCovariance.set_(updatedParameters.InitialCovariance); } diff --git a/src/Bonsai.ML.Torch.LDS/KalmanFilterModelManager.cs b/src/Bonsai.ML.Torch.LDS/KalmanFilterModelManager.cs index e1265267..00bc072d 100644 --- a/src/Bonsai.ML.Torch.LDS/KalmanFilterModelManager.cs +++ b/src/Bonsai.ML.Torch.LDS/KalmanFilterModelManager.cs @@ -26,7 +26,7 @@ internal static KalmanFilterDisposable Reserve( Tensor measurementFunction, Tensor processNoiseVariance, Tensor measurementNoiseVariance, - Tensor initialState, + Tensor initialMean, Tensor initialCovariance, Device? device = null, ScalarType? scalarType = null @@ -44,7 +44,7 @@ internal static KalmanFilterDisposable Reserve( measurementFunction: measurementFunction, processNoiseVariance: processNoiseVariance, measurementNoiseVariance: measurementNoiseVariance, - initialState: initialState, + initialMean: initialMean, initialCovariance: initialCovariance, device: device, scalarType: scalarType ?? ScalarType.Float32 diff --git a/src/Bonsai.ML.Torch.LDS/KalmanFilterParameters.cs b/src/Bonsai.ML.Torch.LDS/KalmanFilterParameters.cs index ad708605..5fadf1c3 100644 --- a/src/Bonsai.ML.Torch.LDS/KalmanFilterParameters.cs +++ b/src/Bonsai.ML.Torch.LDS/KalmanFilterParameters.cs @@ -12,14 +12,14 @@ namespace Bonsai.ML.Torch.LDS; /// /// /// -/// +/// /// public struct KalmanFilterParameters( Tensor transitionMatrix, Tensor measurementFunction, Tensor processNoiseCovariance, Tensor measurementNoiseCovariance, - Tensor initialState, + Tensor initialMean, Tensor initialCovariance) { /// @@ -43,9 +43,9 @@ public struct KalmanFilterParameters( public Tensor MeasurementNoiseCovariance = measurementNoiseCovariance; /// - /// The initial state. + /// The initial mean. /// - public Tensor InitialState = initialState; + public Tensor InitialMean = initialMean; /// /// The initial covariance. diff --git a/src/Bonsai.ML.Torch.LDS/Orthogonalize.cs b/src/Bonsai.ML.Torch.LDS/Orthogonalize.cs index 5922ffd9..06142419 100644 --- a/src/Bonsai.ML.Torch.LDS/Orthogonalize.cs +++ b/src/Bonsai.ML.Torch.LDS/Orthogonalize.cs @@ -21,7 +21,7 @@ public class Orthogonalize public string ModelName { get; set; } = "KalmanFilter"; /// - /// Processes an observable sequence of smoothed results, orthogonalizing the state and covariance estimates. + /// Processes an observable sequence of smoothed results, orthogonalizing the mean and covariance estimates. /// /// /// @@ -30,14 +30,14 @@ public IObservable Process(IObservable sou return source.Select(input => { var kalmanFilter = KalmanFilterModelManager.GetKalmanFilter(ModelName); - var smoothedState = input.SmoothedState; + var smoothedMean = input.SmoothedMean; var smoothedCovariance = input.SmoothedCovariance; - return kalmanFilter.OrthogonalizeStateAndCovariance(smoothedState, smoothedCovariance); + return kalmanFilter.OrthogonalizeMeanAndCovariance(smoothedMean, smoothedCovariance); }); } /// - /// Processes an observable sequence of filtered results, orthogonalizing the state and covariance estimates. + /// Processes an observable sequence of filtered results, orthogonalizing the mean and covariance estimates. /// /// /// @@ -46,14 +46,14 @@ public IObservable Process(IObservable sou return source.Select(input => { var kalmanFilter = KalmanFilterModelManager.GetKalmanFilter(ModelName); - var filteredState = input.UpdatedState; + var filteredMean = input.UpdatedMean; var filteredCovariance = input.UpdatedCovariance; - return kalmanFilter.OrthogonalizeStateAndCovariance(filteredState, filteredCovariance); + return kalmanFilter.OrthogonalizeMeanAndCovariance(filteredMean, filteredCovariance); }); } /// - /// Processes an observable sequence of state and covariance tuples, orthogonalizing the state and covariance estimates. + /// Processes an observable sequence of mean and covariance tuples, orthogonalizing the mean and covariance estimates. /// /// /// @@ -62,9 +62,9 @@ public IObservable Process(IObservable { var kalmanFilter = KalmanFilterModelManager.GetKalmanFilter(ModelName); - var state = input.Item1; + var mean = input.Item1; var covariance = input.Item2; - return kalmanFilter.OrthogonalizeStateAndCovariance(state, covariance); + return kalmanFilter.OrthogonalizeMeanAndCovariance(mean, covariance); }); } } \ No newline at end of file diff --git a/src/Bonsai.ML.Torch.LDS/OrthogonalizedResult.cs b/src/Bonsai.ML.Torch.LDS/OrthogonalizedResult.cs index 0666f8dc..dcf810a0 100644 --- a/src/Bonsai.ML.Torch.LDS/OrthogonalizedResult.cs +++ b/src/Bonsai.ML.Torch.LDS/OrthogonalizedResult.cs @@ -3,21 +3,21 @@ namespace Bonsai.ML.Torch.LDS; /// -/// Represents the result of orthogonalizing the state and covariance estimates. +/// Represents the result of orthogonalizing the mean and covariance estimates. /// /// /// Initializes a new instance of the struct. /// -/// +/// /// public struct OrthogonalizedResult( - Tensor orthogonalizedState, + Tensor orthogonalizedMean, Tensor orthogonalizedCovariance) { /// - /// The orthogonalized state estimate. + /// The orthogonalized mean estimate. /// - public Tensor OrthogonalizedState = orthogonalizedState; + public Tensor OrthogonalizedMean = orthogonalizedMean; /// /// The orthogonalized covariance estimate. diff --git a/src/Bonsai.ML.Torch.LDS/Smooth.cs b/src/Bonsai.ML.Torch.LDS/Smooth.cs index 08540794..16b84fcb 100644 --- a/src/Bonsai.ML.Torch.LDS/Smooth.cs +++ b/src/Bonsai.ML.Torch.LDS/Smooth.cs @@ -36,7 +36,7 @@ public IObservable Process(IObservable source) } /// - /// Processes an observable sequence of tuples containing the components of a filtered result (predictedState, predictedCovariance, updatedState, updatedCovariance), applying the Kalman smoother to each result. + /// Processes an observable sequence of tuples containing the components of a filtered result (predictedMean, predictedCovariance, updatedMean, updatedCovariance), applying the Kalman smoother to each result. /// /// /// diff --git a/src/Bonsai.ML.Torch.LDS/SmoothedResult.cs b/src/Bonsai.ML.Torch.LDS/SmoothedResult.cs index 94230cba..17f55f22 100644 --- a/src/Bonsai.ML.Torch.LDS/SmoothedResult.cs +++ b/src/Bonsai.ML.Torch.LDS/SmoothedResult.cs @@ -5,20 +5,20 @@ namespace Bonsai.ML.Torch.LDS; /// /// Represents the result of a Kalman smoother. /// -/// +/// /// -/// +/// /// public struct SmoothedResult( - Tensor smoothedState, + Tensor smoothedMean, Tensor smoothedCovariance, - Tensor smoothedInitialState = null, + Tensor smoothedInitialMean = null, Tensor smoothedInitialCovariance = null) { /// /// The smoothed state after the smoothing step. /// - public Tensor SmoothedState = smoothedState; + public Tensor SmoothedMean = smoothedMean; /// /// The smoothed covariance after the smoothing step. @@ -28,7 +28,7 @@ public struct SmoothedResult( /// /// The smoothed initial state after the smoothing step. /// - public Tensor SmoothedInitialState = smoothedInitialState; + public Tensor SmoothedInitialMean = smoothedInitialMean; /// /// The smoothed initial covariance after the smoothing step. From 3e3ca3b1c4489fc3ebbbf46dc2763bd5a3995160 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 30 Sep 2025 17:49:00 +0100 Subject: [PATCH 29/92] Removed the line declaring requirements.txt is a deployment item --- tests/Bonsai.ML.Torch.LDS.Tests/NeuralLatentsTest.cs | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/Bonsai.ML.Torch.LDS.Tests/NeuralLatentsTest.cs b/tests/Bonsai.ML.Torch.LDS.Tests/NeuralLatentsTest.cs index 3fb6620e..70e2165e 100644 --- a/tests/Bonsai.ML.Torch.LDS.Tests/NeuralLatentsTest.cs +++ b/tests/Bonsai.ML.Torch.LDS.Tests/NeuralLatentsTest.cs @@ -110,7 +110,6 @@ await WorkflowHelper.RunWorkflow( [DeploymentItem("bootstrap_test_environment.py")] [DeploymentItem("estimate_neural_latents.py")] [DeploymentItem("NeuralLatentsTest.bonsai")] - [DeploymentItem("requirements.txt")] public async Task TestSetup() { Directory.CreateDirectory(basePath); From 7201ef915d9a46be72e7a5c87bdddb5c9c3de0b7 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 30 Sep 2025 17:49:31 +0100 Subject: [PATCH 30/92] Updated test workflow to use the variable name mean instead of state --- .../NeuralLatentsTest.bonsai | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/tests/Bonsai.ML.Torch.LDS.Tests/NeuralLatentsTest.bonsai b/tests/Bonsai.ML.Torch.LDS.Tests/NeuralLatentsTest.bonsai index 98b8cfa2..f3959c7f 100644 --- a/tests/Bonsai.ML.Torch.LDS.Tests/NeuralLatentsTest.bonsai +++ b/tests/Bonsai.ML.Torch.LDS.Tests/NeuralLatentsTest.bonsai @@ -52,7 +52,7 @@ - State + Mean @@ -140,7 +140,7 @@ MeasurementNoiseCovariance - State + Mean Covariance @@ -154,7 +154,7 @@ - + @@ -165,7 +165,7 @@ [] [] [] - [] + [] @@ -179,7 +179,7 @@ [] [] [] - [] + [] @@ -216,6 +216,12 @@ 1 0.1 true + true + true + true + true + true + true @@ -283,7 +289,7 @@ OrthogonalizedResult - OrthogonalizedState + OrthogonalizedMean From faff97e8d63964c63a7ebd50fd1441bf4a9bbbfa Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 30 Sep 2025 17:59:46 +0100 Subject: [PATCH 31/92] Added functionality to allow fine grained control over which parameters to estimate in the EM algorithm --- .../ExpectationMaximization.cs | 58 +++++-------------- src/Bonsai.ML.Torch.LDS/KalmanFilter.cs | 56 +++--------------- .../ParametersToEstimate.cs | 52 +++++++++++++++++ 3 files changed, 72 insertions(+), 94 deletions(-) create mode 100644 src/Bonsai.ML.Torch.LDS/ParametersToEstimate.cs diff --git a/src/Bonsai.ML.Torch.LDS/ExpectationMaximization.cs b/src/Bonsai.ML.Torch.LDS/ExpectationMaximization.cs index 9634a6f9..71d1d7b7 100644 --- a/src/Bonsai.ML.Torch.LDS/ExpectationMaximization.cs +++ b/src/Bonsai.ML.Torch.LDS/ExpectationMaximization.cs @@ -56,71 +56,41 @@ public bool Verbose set => _verbose = value; } - private bool _estimateTransitionMatrix = true; /// /// If true, the transition matrix will be estimated during the EM algorithm. /// [Description("If true, the transition matrix will be estimated during the EM algorithm.")] - public bool EstimateTransitionMatrix - { - get => _estimateTransitionMatrix; - set => _estimateTransitionMatrix = value; - } + public bool EstimateTransitionMatrix { get; set; } = true; - private bool _estimateMeasurementFunction = true; /// /// If true, the measurement function will be estimated during the EM algorithm. /// [Description("If true, the measurement function will be estimated during the EM algorithm.")] - public bool EstimateMeasurementFunction - { - get => _estimateMeasurementFunction; - set => _estimateMeasurementFunction = value; - } + public bool EstimateMeasurementFunction { get; set; } = true; - private bool _estimateProcessNoiseCovariance = true; /// /// If true, the process noise covariance will be estimated during the EM algorithm. /// [Description("If true, the process noise covariance will be estimated during the EM algorithm.")] - public bool EstimateProcessNoiseCovariance - { - get => _estimateProcessNoiseCovariance; - set => _estimateProcessNoiseCovariance = value; - } + public bool EstimateProcessNoiseCovariance { get; set; } = true; - private bool _estimateMeasurementNoiseCovariance = true; /// /// If true, the measurement noise covariance will be estimated during the EM algorithm. /// [Description("If true, the measurement noise covariance will be estimated during the EM algorithm.")] - public bool EstimateMeasurementNoiseCovariance - { - get => _estimateMeasurementNoiseCovariance; - set => _estimateMeasurementNoiseCovariance = value; - } + public bool EstimateMeasurementNoiseCovariance { get; set; } = true; - private bool _estimateInitialMean = true; /// /// If true, the initial mean will be estimated during the EM algorithm. /// [Description("If true, the initial mean will be estimated during the EM algorithm.")] - public bool EstimateInitialMean - { - get => _estimateInitialMean; - set => _estimateInitialMean = value; - } + public bool EstimateInitialMean { get; set; } = true; - private bool _estimateInitialCovariance = true; /// /// If true, the initial covariance will be estimated during the EM algorithm. /// [Description("If true, the initial covariance will be estimated during the EM algorithm.")] - public bool EstimateInitialCovariance - { - get => _estimateInitialCovariance; - set => _estimateInitialCovariance = value; - } + public bool EstimateInitialCovariance { get; set; } = true; /// /// Processes an observable sequence of input tensors, applying the Expectation-Maximization algorithm to learn the parameters of a Kalman filter model. @@ -135,15 +105,13 @@ public IObservable Process(IObservable so var previousLogLikelihood = double.NegativeInfinity; var logLikelihood = zeros(new long[] { MaxIterations }, device: input.device); - var parametersToEstimate = new Dictionary - { - { "TransitionMatrix", EstimateTransitionMatrix }, - { "MeasurementFunction", EstimateMeasurementFunction }, - { "ProcessNoiseCovariance", EstimateProcessNoiseCovariance }, - { "MeasurementNoiseCovariance", EstimateMeasurementNoiseCovariance }, - { "InitialState", EstimateInitialMean }, - { "InitialCovariance", EstimateInitialCovariance } - }; + var parametersToEstimate = new ParametersToEstimate( + transitionMatrix: EstimateTransitionMatrix, + measurementFunction: EstimateMeasurementFunction, + processNoiseCovariance: EstimateProcessNoiseCovariance, + measurementNoiseCovariance: EstimateMeasurementNoiseCovariance, + initialMean: EstimateInitialMean, + initialCovariance: EstimateInitialCovariance); for (int i = 0; i < MaxIterations; i++) { diff --git a/src/Bonsai.ML.Torch.LDS/KalmanFilter.cs b/src/Bonsai.ML.Torch.LDS/KalmanFilter.cs index e84ef87b..708461d9 100644 --- a/src/Bonsai.ML.Torch.LDS/KalmanFilter.cs +++ b/src/Bonsai.ML.Torch.LDS/KalmanFilter.cs @@ -583,55 +583,13 @@ Device device ); } - private static Dictionary ValidateAndSetParametersToEstimate(Dictionary parametersToUpdate) - { - var validParameters = new HashSet - { - "TransitionMatrix", - "MeasurementFunction", - "ProcessNoiseCovariance", - "MeasurementNoiseCovariance", - "InitialMean", - "InitialCovariance" - }; - - if (parametersToUpdate is null) - return new Dictionary - { - { "TransitionMatrix", true }, - { "MeasurementFunction", true }, - { "ProcessNoiseCovariance", true }, - { "MeasurementNoiseCovariance", true }, - { "InitialMean", true }, - { "InitialCovariance", true } - }; - - // Check for invalid parameter names - foreach (var key in parametersToUpdate.Keys) - { - if (!validParameters.Contains(key)) - throw new ArgumentException($"Invalid parameter name '{key}' in parametersToUpdate. Valid names are: {string.Join(", ", validParameters)}"); - } - - // Ensure all valid parameters are present in the dictionary - foreach (var param in validParameters) - { - if (!parametersToUpdate.ContainsKey(param)) - parametersToUpdate[param] = false; - } - - return parametersToUpdate; - } - public ExpectationMaximizationResult ExpectationMaximization( Tensor observation, int maxIterations = 100, double tolerance = 1e-4, - Dictionary parametersToUpdate = null, + ParametersToEstimate parametersToEstimate = new(), bool updateParameters = true) { - parametersToUpdate = ValidateAndSetParametersToEstimate(parametersToUpdate); - var timeBins = observation.size(0); var logLikelihood = empty(maxIterations, dtype: ScalarType.Float32, device: _device); var previousLogLikelihood = double.NegativeInfinity; @@ -708,27 +666,27 @@ public ExpectationMaximizationResult ExpectationMaximization( var crossCorrelationObservations = observationT.matmul(smoothedResult.SmoothedMean); // Update parameters - if (parametersToUpdate["TransitionMatrix"]) + if (parametersToEstimate.TransitionMatrix) transitionMatrix = InverseCholesky(S10, S00); - if (parametersToUpdate["MeasurementFunction"]) + if (parametersToEstimate.MeasurementFunction) measurementFunction = InverseCholesky(crossCorrelationObservations, S11); - if (parametersToUpdate["ProcessNoiseCovariance"]) + if (parametersToEstimate.ProcessNoiseCovariance) processNoiseCovariance = WrappedTensorDisposeScope(() => EnsureSymmetric((S11 - transitionMatrix.matmul(S10.mT)) / timeBins)); var explainedObservationCovariance = measurementFunction.matmul(crossCorrelationObservations.mT); - if (parametersToUpdate["MeasurementNoiseCovariance"]) + if (parametersToEstimate.MeasurementNoiseCovariance) measurementNoiseCovariance = WrappedTensorDisposeScope(() => EnsureSymmetric((autoCorrelationObservations - explainedObservationCovariance - explainedObservationCovariance.mT + measurementFunction.matmul(S11).matmul(measurementFunction.mT)) / timeBins)); - if (parametersToUpdate["InitialMean"]) + if (parametersToEstimate.InitialMean) initialMean = smoothedResult.SmoothedInitialMean; - if (parametersToUpdate["InitialCovariance"]) + if (parametersToEstimate.InitialCovariance) initialCovariance = smoothedResult.SmoothedInitialCovariance; } } diff --git a/src/Bonsai.ML.Torch.LDS/ParametersToEstimate.cs b/src/Bonsai.ML.Torch.LDS/ParametersToEstimate.cs new file mode 100644 index 00000000..d0819541 --- /dev/null +++ b/src/Bonsai.ML.Torch.LDS/ParametersToEstimate.cs @@ -0,0 +1,52 @@ +namespace Bonsai.ML.Torch.LDS; + +/// +/// Represents the parameters to estimate for a Kalman filter model. +/// +/// +/// Initializes a new instance of the struct with the specified parameters. +/// +/// +/// +/// +/// +/// +/// +public struct ParametersToEstimate( + bool transitionMatrix = true, + bool measurementFunction = true, + bool processNoiseCovariance = true, + bool measurementNoiseCovariance = true, + bool initialMean = true, + bool initialCovariance = true) +{ + /// + /// The state transition matrix. + /// + public bool TransitionMatrix = transitionMatrix; + + /// + /// The measurement function. + /// + public bool MeasurementFunction = measurementFunction; + + /// + /// The process noise covariance. + /// + public bool ProcessNoiseCovariance = processNoiseCovariance; + + /// + /// The measurement noise covariance. + /// + public bool MeasurementNoiseCovariance = measurementNoiseCovariance; + + /// + /// The initial mean. + /// + public bool InitialMean = initialMean; + + /// + /// The initial covariance. + /// + public bool InitialCovariance = initialCovariance; +} \ No newline at end of file From e647b4a8c0ac1c0b009edb131820eb07fee1a7fb Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Thu, 2 Oct 2025 14:43:07 +0100 Subject: [PATCH 32/92] Updated package info with better description and shared package tags --- src/Bonsai.ML.Torch.LDS/Bonsai.ML.Torch.LDS.csproj | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Bonsai.ML.Torch.LDS/Bonsai.ML.Torch.LDS.csproj b/src/Bonsai.ML.Torch.LDS/Bonsai.ML.Torch.LDS.csproj index 30bf344b..4b6a4bb6 100644 --- a/src/Bonsai.ML.Torch.LDS/Bonsai.ML.Torch.LDS.csproj +++ b/src/Bonsai.ML.Torch.LDS/Bonsai.ML.Torch.LDS.csproj @@ -1,7 +1,7 @@ - Bonsai.ML.Torch.LDS Bonsai library. - Bonsai Rx Bonsai ML Torch TorchSharp LDS LinearDynamicalSystems + A Bonsai package building on the Bonsai.ML.Torch library that implements Linear Dynamical Systems. + $(PackageTags) Torch TorchSharp LDS LinearDynamicalSystems net472;netstandard2.0 From 6a4ea9a4d63a3514db4e6f66131037a13c9b9b74 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Mon, 6 Oct 2025 11:39:49 +0100 Subject: [PATCH 33/92] Added `Bonsai.ML.Torch.LDS.Design` package for visualizing latents --- Bonsai.ML.sln | 6 ++++++ .../Bonsai.ML.Torch.LDS.Design.csproj | 15 +++++++++++++++ .../Properties/AssemblyInfo.cs | 6 ++++++ .../Properties/launchSettings.json | 10 ++++++++++ 4 files changed, 37 insertions(+) create mode 100644 src/Bonsai.ML.Torch.LDS.Design/Bonsai.ML.Torch.LDS.Design.csproj create mode 100644 src/Bonsai.ML.Torch.LDS.Design/Properties/AssemblyInfo.cs create mode 100644 src/Bonsai.ML.Torch.LDS.Design/Properties/launchSettings.json diff --git a/Bonsai.ML.sln b/Bonsai.ML.sln index 445012ab..d096b62f 100644 --- a/Bonsai.ML.sln +++ b/Bonsai.ML.sln @@ -44,6 +44,8 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Bonsai.ML.Torch.LDS", "src\ EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Bonsai.ML.Torch.LDS.Tests", "tests\Bonsai.ML.Torch.LDS.Tests\Bonsai.ML.Torch.LDS.Tests.csproj", "{0B258929-0B07-4CE7-BE8D-A86BBC46AAD4}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Bonsai.ML.Torch.LDS.Design", "src\Bonsai.ML.Torch.LDS.Design\Bonsai.ML.Torch.LDS.Design.csproj", "{1F52DECD-1B2C-4F6C-996C-14C715283B80}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -114,6 +116,10 @@ Global {0B258929-0B07-4CE7-BE8D-A86BBC46AAD4}.Debug|Any CPU.Build.0 = Debug|Any CPU {0B258929-0B07-4CE7-BE8D-A86BBC46AAD4}.Release|Any CPU.ActiveCfg = Release|Any CPU {0B258929-0B07-4CE7-BE8D-A86BBC46AAD4}.Release|Any CPU.Build.0 = Release|Any CPU + {1F52DECD-1B2C-4F6C-996C-14C715283B80}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {1F52DECD-1B2C-4F6C-996C-14C715283B80}.Debug|Any CPU.Build.0 = Debug|Any CPU + {1F52DECD-1B2C-4F6C-996C-14C715283B80}.Release|Any CPU.ActiveCfg = Release|Any CPU + {1F52DECD-1B2C-4F6C-996C-14C715283B80}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE diff --git a/src/Bonsai.ML.Torch.LDS.Design/Bonsai.ML.Torch.LDS.Design.csproj b/src/Bonsai.ML.Torch.LDS.Design/Bonsai.ML.Torch.LDS.Design.csproj new file mode 100644 index 00000000..009aabc3 --- /dev/null +++ b/src/Bonsai.ML.Torch.LDS.Design/Bonsai.ML.Torch.LDS.Design.csproj @@ -0,0 +1,15 @@ + + + Visualizers for the Bonsai.ML.Torch.LDS library. + $(PackageTags) Torch LDS Design + net472 + true + + + + + + + + + \ No newline at end of file diff --git a/src/Bonsai.ML.Torch.LDS.Design/Properties/AssemblyInfo.cs b/src/Bonsai.ML.Torch.LDS.Design/Properties/AssemblyInfo.cs new file mode 100644 index 00000000..b28e8258 --- /dev/null +++ b/src/Bonsai.ML.Torch.LDS.Design/Properties/AssemblyInfo.cs @@ -0,0 +1,6 @@ +using Bonsai; + +// General Information about an assembly is controlled through the following +// set of attributes. Change these attribute values to modify the information +// associated with an assembly. +[assembly: XmlNamespacePrefix("clr-namespace:Bonsai.ML.Torch.LDS.Design", null)] diff --git a/src/Bonsai.ML.Torch.LDS.Design/Properties/launchSettings.json b/src/Bonsai.ML.Torch.LDS.Design/Properties/launchSettings.json new file mode 100644 index 00000000..4af4f468 --- /dev/null +++ b/src/Bonsai.ML.Torch.LDS.Design/Properties/launchSettings.json @@ -0,0 +1,10 @@ +{ + "profiles": { + "Bonsai": { + "commandName": "Executable", + "executablePath": "$(BonsaiExecutablePath)", + "commandLineArgs": "--lib:\"$(TargetDir).\"", + "nativeDebugging": true + } + } +} \ No newline at end of file From 567dae97138057be9c616213658f0c25fcb43bf4 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Mon, 6 Oct 2025 11:40:32 +0100 Subject: [PATCH 34/92] Moved color cycle class to shared `Bonsai.ML.Design` library --- .../OxyColorPresetCycle.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) rename src/{Bonsai.ML.PointProcessDecoder.Design => Bonsai.ML.Design}/OxyColorPresetCycle.cs (95%) diff --git a/src/Bonsai.ML.PointProcessDecoder.Design/OxyColorPresetCycle.cs b/src/Bonsai.ML.Design/OxyColorPresetCycle.cs similarity index 95% rename from src/Bonsai.ML.PointProcessDecoder.Design/OxyColorPresetCycle.cs rename to src/Bonsai.ML.Design/OxyColorPresetCycle.cs index e2976cf4..800744cf 100644 --- a/src/Bonsai.ML.PointProcessDecoder.Design/OxyColorPresetCycle.cs +++ b/src/Bonsai.ML.Design/OxyColorPresetCycle.cs @@ -1,7 +1,7 @@ using OxyPlot; -namespace Bonsai.ML.PointProcessDecoder.Design +namespace Bonsai.ML.Design { /// /// Enumerates the colors and provides a preset collection of colors to cycle through. From 174752cd19571b4643697c1ecddb2424d59db415 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Mon, 6 Oct 2025 11:44:03 +0100 Subject: [PATCH 35/92] Added keyword arguments to function call for precision --- src/Bonsai.ML.Torch.LDS/CreateKalmanFilter.cs | 30 +++++++++---------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/src/Bonsai.ML.Torch.LDS/CreateKalmanFilter.cs b/src/Bonsai.ML.Torch.LDS/CreateKalmanFilter.cs index 1729f05e..f0c358ed 100644 --- a/src/Bonsai.ML.Torch.LDS/CreateKalmanFilter.cs +++ b/src/Bonsai.ML.Torch.LDS/CreateKalmanFilter.cs @@ -223,17 +223,17 @@ public string InitialCovarianceXml public IObservable Process() { return Observable.Using(() => KalmanFilterModelManager.Reserve( - ModelName, - NumStates, - NumObservations, - _transitionMatrix, - _measurementFunction, - _initialMean, - _initialCovariance, - _processNoiseVariance, - _measurementNoiseVariance, - Device, - Type + name: ModelName, + numStates: NumStates, + numObservations: NumObservations, + transitionMatrix: _transitionMatrix.NumberOfElements > 0 ? _transitionMatrix : null, + measurementFunction: _measurementFunction.NumberOfElements > 0 ? _measurementFunction : null, + initialMean: _initialMean.NumberOfElements > 0 ? _initialMean : null, + initialCovariance: _initialCovariance.NumberOfElements > 0 ? _initialCovariance : null, + processNoiseVariance: _processNoiseVariance.NumberOfElements > 0 ? _processNoiseVariance : null, + measurementNoiseVariance: _measurementNoiseVariance.NumberOfElements > 0 ? _measurementNoiseVariance : null, + device: Device, + scalarType: Type ), resource => Observable.Return(resource.Model) .Concat(Observable.Never(resource.Model)) .Finally(resource.Dispose) @@ -248,10 +248,10 @@ public string InitialCovarianceXml return source.SelectMany(parameters => { return Observable.Using(() => KalmanFilterModelManager.Reserve( - ModelName, - parameters, - Device, - Type + name: ModelName, + parameters: parameters, + device: Device, + scalarType: Type ), resource => Observable.Return(resource.Model) .Concat(Observable.Never(resource.Model)) .Finally(resource.Dispose) From 3cbf26f3b3e14bcd25dcc284a052ad09be541f1f Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Mon, 6 Oct 2025 11:44:37 +0100 Subject: [PATCH 36/92] Added `StateVisualizer` class to design package to support visualizing states from Kalman filter output --- .../StateVisualizer.cs | 201 ++++++++++++++++++ 1 file changed, 201 insertions(+) create mode 100644 src/Bonsai.ML.Torch.LDS.Design/StateVisualizer.cs diff --git a/src/Bonsai.ML.Torch.LDS.Design/StateVisualizer.cs b/src/Bonsai.ML.Torch.LDS.Design/StateVisualizer.cs new file mode 100644 index 00000000..ab36b8f0 --- /dev/null +++ b/src/Bonsai.ML.Torch.LDS.Design/StateVisualizer.cs @@ -0,0 +1,201 @@ +using System; +using System.Reactive; +using System.Linq; +using System.Windows.Forms; +using System.Collections.Generic; + +using Bonsai; +using Bonsai.Design; +using Bonsai.ML.Design; + +using OxyPlot; +using OxyPlot.Series; + +using static TorchSharp.torch; + +[assembly: TypeVisualizer(typeof(Bonsai.ML.Torch.LDS.Design.StateVisualizer), + Target = typeof(Bonsai.ML.Torch.LDS.FilteredResult))] +[assembly: TypeVisualizer(typeof(Bonsai.ML.Torch.LDS.Design.StateVisualizer), + Target = typeof(Bonsai.ML.Torch.LDS.SmoothedResult))] +[assembly: TypeVisualizer(typeof(Bonsai.ML.Torch.LDS.Design.StateVisualizer), + Target = typeof(Bonsai.ML.Torch.LDS.OrthogonalizedResult))] + +namespace Bonsai.ML.Torch.LDS.Design; + +public class StateVisualizer : BufferedVisualizer +{ + private TimeSeriesOxyPlotBase _plot; + private LineSeries[] _lineSeries; + private AreaSeries[] _areaSeries; + + /// + /// Gets or sets the amount of time in seconds that should be shown along the x axis. + /// + public int Capacity { get; set; } = 10; + + /// + /// Gets or sets a boolean value that determines whether to buffer the data beyond the capacity. + /// + public bool BufferData { get; set; } = false; + + /// + public override void Load(IServiceProvider provider) + { + _plot = new TimeSeriesOxyPlotBase() + { + Capacity = Capacity, + Dock = DockStyle.Fill, + StartTime = DateTime.Now, + BufferData = BufferData + }; + + var capacityStatusLabel = new ToolStripStatusLabel + { + Text = "Capacity: ", + }; + + var capacityStatusControl = new ToolStripTextBox + { + Text = Capacity.ToString(), + }; + + capacityStatusControl.TextChanged += (sender, e) => + { + if (int.TryParse(capacityStatusControl.Text, out int capacity)) + { + Capacity = capacity; + _plot.Capacity = Capacity; + } + }; + + var bufferDataStatusLabel = new ToolStripStatusLabel + { + Text = "Buffer Data: ", + }; + + var bufferDataStatusControl = new ToolStripComboBox + { + Width = 100, + }; + + bufferDataStatusControl.Items.AddRange(new[] { "True", "False" }); + bufferDataStatusControl.SelectedIndex = BufferData ? 0 : 1; + + bufferDataStatusControl.SelectedIndexChanged += (sender, e) => + { + BufferData = bufferDataStatusControl.SelectedIndex == 0; + _plot.BufferData = BufferData; + }; + + _plot.StatusStrip.Items.Add(capacityStatusLabel); + _plot.StatusStrip.Items.Add(capacityStatusControl); + _plot.StatusStrip.Items.Add(bufferDataStatusLabel); + _plot.StatusStrip.Items.Add(bufferDataStatusControl); + + var visualizerService = (IDialogTypeVisualizerService)provider.GetService(typeof(IDialogTypeVisualizerService)); + visualizerService?.AddControl(_plot); + } + + /// + public override void Show(object value) + { + } + + /// + protected override void Show(DateTime time, object value) + { + if (value is null) return; + + var mean = value switch + { + FilteredResult filteredResult => filteredResult.UpdatedMean, + SmoothedResult smoothedResult => smoothedResult.SmoothedMean, + OrthogonalizedResult orthogonalizedResult => orthogonalizedResult.OrthogonalizedMean, + _ => throw new ArgumentException($"Expected value to be of type {nameof(FilteredResult)}, {nameof(SmoothedResult)}, or {nameof(OrthogonalizedResult)}.", nameof(value)) + }; + + var covariance = value switch + { + FilteredResult filteredResult => filteredResult.UpdatedCovariance, + SmoothedResult smoothedResult => smoothedResult.SmoothedCovariance, + OrthogonalizedResult orthogonalizedResult => orthogonalizedResult.OrthogonalizedCovariance, + _ => throw new ArgumentException($"Expected value to be of type {nameof(FilteredResult)}, {nameof(SmoothedResult)}, or {nameof(OrthogonalizedResult)}.", nameof(value)) + }; + + if (mean.Dimensions == 1) + { + mean = mean.unsqueeze(0); + covariance = covariance.unsqueeze(0); + } + + var numTimesteps = mean.shape[0]; + var numStates = mean.shape[1]; + + if (_lineSeries is null || _areaSeries is null) + { + var colors = new OxyColorPresetCycle(); + + _lineSeries = new LineSeries[numStates]; + _areaSeries = new AreaSeries[numStates]; + + for (int i = 0; i < numStates; i++) + { + var currentColor = colors.Next(); + + _lineSeries[i] = _plot.AddNewLineSeries( + lineSeriesName: $"Mean: {i}", + color: currentColor + ); + + _areaSeries[i] = _plot.AddNewAreaSeries( + areaSeriesName: $"Variance: {i}", + color: currentColor + ); + } + } + + var covarianceDiagonal = covariance.diagonal(0, 1, 2); + + for (int i = 0; i < numTimesteps; i++) + { + for (int j = 0; j < numStates; j++) + { + + var meanVal = mean[i, j].to_type(ScalarType.Float64).item(); + + _plot.AddToLineSeries( + lineSeries: _lineSeries[j], + time: time, + value: meanVal + ); + + var sigmaVal = covarianceDiagonal[i,j].sqrt().to_type(ScalarType.Float64).item(); + + _plot.AddToAreaSeries( + areaSeries: _areaSeries[j], + time: time, + value1: meanVal + sigmaVal, + value2: meanVal - sigmaVal + ); + } + } + } + + /// + protected override void ShowBuffer(IList> values) + { + base.ShowBuffer(values); + if (values.Count > 0) + { + var time = values.LastOrDefault().Timestamp.DateTime; + _plot.SetAxes(minTime: time.AddSeconds(-Capacity), maxTime: time); + _plot.UpdatePlot(); + } + } + + /// + public override void Unload() + { + if (!_plot.IsDisposed) _plot.Dispose(); + } +} From 3d9c83b2d9faaafc7cbf79bdfb757ba76002aa21 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Mon, 6 Oct 2025 11:45:16 +0100 Subject: [PATCH 37/92] Added explicit conversion to null for empty tensors --- .../CreateKalmanFilterParameters.cs | 39 ++++++++++--------- 1 file changed, 21 insertions(+), 18 deletions(-) diff --git a/src/Bonsai.ML.Torch.LDS/CreateKalmanFilterParameters.cs b/src/Bonsai.ML.Torch.LDS/CreateKalmanFilterParameters.cs index 0c6d5b2d..bf7c492d 100644 --- a/src/Bonsai.ML.Torch.LDS/CreateKalmanFilterParameters.cs +++ b/src/Bonsai.ML.Torch.LDS/CreateKalmanFilterParameters.cs @@ -190,16 +190,16 @@ private void ConvertTensorsScalarType(ScalarType scalarType) /// public IObservable Process() { - return Observable.Return( - new KalmanFilterParameters( - TransitionMatrix, - MeasurementFunction, - ProcessNoiseCovariance, - MeasurementNoiseCovariance, - InitialMean, - InitialCovariance - ) + var parameters = new KalmanFilterParameters( + TransitionMatrix.NumberOfElements > 0 ? TransitionMatrix : null, + MeasurementFunction.NumberOfElements > 0 ? MeasurementFunction : null, + ProcessNoiseCovariance.NumberOfElements > 0 ? ProcessNoiseCovariance : null, + MeasurementNoiseCovariance.NumberOfElements > 0 ? MeasurementNoiseCovariance : null, + InitialMean.NumberOfElements > 0 ? InitialMean : null, + InitialCovariance.NumberOfElements > 0 ? InitialCovariance : null ); + + return Observable.Return(parameters); } /// @@ -208,14 +208,17 @@ public IObservable Process() public IObservable Process(IObservable source) { return source.Select(_ => - new KalmanFilterParameters( - TransitionMatrix, - MeasurementFunction, - ProcessNoiseCovariance, - MeasurementNoiseCovariance, - InitialMean, - InitialCovariance - ) - ); + { + var parameters = new KalmanFilterParameters( + TransitionMatrix.NumberOfElements > 0 ? TransitionMatrix : null, + MeasurementFunction.NumberOfElements > 0 ? MeasurementFunction : null, + ProcessNoiseCovariance.NumberOfElements > 0 ? ProcessNoiseCovariance : null, + MeasurementNoiseCovariance.NumberOfElements > 0 ? MeasurementNoiseCovariance : null, + InitialMean.NumberOfElements > 0 ? InitialMean : null, + InitialCovariance.NumberOfElements > 0 ? InitialCovariance : null + ); + + return parameters; + }); } } \ No newline at end of file From 4313d3b3f3e3fb132f3b195a76db87e1dc84a0b6 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Mon, 6 Oct 2025 11:47:30 +0100 Subject: [PATCH 38/92] Removed initial values from non-static Kalman smoother --- src/Bonsai.ML.Torch.LDS/KalmanFilter.cs | 14 +------------- src/Bonsai.ML.Torch.LDS/SmoothedResult.cs | 16 +--------------- 2 files changed, 2 insertions(+), 28 deletions(-) diff --git a/src/Bonsai.ML.Torch.LDS/KalmanFilter.cs b/src/Bonsai.ML.Torch.LDS/KalmanFilter.cs index 708461d9..bb047064 100644 --- a/src/Bonsai.ML.Torch.LDS/KalmanFilter.cs +++ b/src/Bonsai.ML.Torch.LDS/KalmanFilter.cs @@ -429,21 +429,9 @@ public SmoothedResult Smooth(FilteredResult filteredResult) ); } - // Smoothed initial mean - var smoothedInitialMean = WrappedTensorDisposeScope(() => _initialMean + smoothingGain.matmul( - (smoothedMean[0] - predictedMean[0]).unsqueeze(-1) - ).squeeze(-1)); - - // Smoothed initial covariance - var smoothedInitialCovariance = WrappedTensorDisposeScope(() => _initialCovariance[0] + smoothingGain - .matmul(smoothedCovariance[0] - predictedCovariance[0]) - .matmul(smoothingGain.mT)); - return new SmoothedResult( smoothedMean, - smoothedCovariance, - smoothedInitialMean, - smoothedInitialCovariance + smoothedCovariance ); } diff --git a/src/Bonsai.ML.Torch.LDS/SmoothedResult.cs b/src/Bonsai.ML.Torch.LDS/SmoothedResult.cs index 17f55f22..8eb9cbe5 100644 --- a/src/Bonsai.ML.Torch.LDS/SmoothedResult.cs +++ b/src/Bonsai.ML.Torch.LDS/SmoothedResult.cs @@ -7,13 +7,9 @@ namespace Bonsai.ML.Torch.LDS; /// /// /// -/// -/// public struct SmoothedResult( Tensor smoothedMean, - Tensor smoothedCovariance, - Tensor smoothedInitialMean = null, - Tensor smoothedInitialCovariance = null) + Tensor smoothedCovariance) { /// /// The smoothed state after the smoothing step. @@ -24,14 +20,4 @@ public struct SmoothedResult( /// The smoothed covariance after the smoothing step. /// public Tensor SmoothedCovariance = smoothedCovariance; - - /// - /// The smoothed initial state after the smoothing step. - /// - public Tensor SmoothedInitialMean = smoothedInitialMean; - - /// - /// The smoothed initial covariance after the smoothing step. - /// - public Tensor SmoothedInitialCovariance = smoothedInitialCovariance; } \ No newline at end of file From 55dd7a6fd9b10b15f76041396056835362177568 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Mon, 6 Oct 2025 14:15:29 +0100 Subject: [PATCH 39/92] Removed explicit null conversion from empty tensor --- src/Bonsai.ML.Torch.LDS/CreateKalmanFilter.cs | 24 ++++++++-------- .../CreateKalmanFilterParameters.cs | 28 +++++++++---------- 2 files changed, 24 insertions(+), 28 deletions(-) diff --git a/src/Bonsai.ML.Torch.LDS/CreateKalmanFilter.cs b/src/Bonsai.ML.Torch.LDS/CreateKalmanFilter.cs index f0c358ed..82b7a8a3 100644 --- a/src/Bonsai.ML.Torch.LDS/CreateKalmanFilter.cs +++ b/src/Bonsai.ML.Torch.LDS/CreateKalmanFilter.cs @@ -40,9 +40,7 @@ public int NumObservations } private ScalarType _scalarType = ScalarType.Float32; - /// - /// The data type of the tensor elements. - /// + /// [Description("The data type of the tensor elements.")] [TypeConverter(typeof(ScalarTypeConverter))] public ScalarType Type @@ -224,16 +222,16 @@ public string InitialCovarianceXml { return Observable.Using(() => KalmanFilterModelManager.Reserve( name: ModelName, - numStates: NumStates, - numObservations: NumObservations, - transitionMatrix: _transitionMatrix.NumberOfElements > 0 ? _transitionMatrix : null, - measurementFunction: _measurementFunction.NumberOfElements > 0 ? _measurementFunction : null, - initialMean: _initialMean.NumberOfElements > 0 ? _initialMean : null, - initialCovariance: _initialCovariance.NumberOfElements > 0 ? _initialCovariance : null, - processNoiseVariance: _processNoiseVariance.NumberOfElements > 0 ? _processNoiseVariance : null, - measurementNoiseVariance: _measurementNoiseVariance.NumberOfElements > 0 ? _measurementNoiseVariance : null, + numStates: _numStates, + numObservations: _numObservations, + transitionMatrix: _transitionMatrix, + measurementFunction: _measurementFunction, + initialMean: _initialMean, + initialCovariance: _initialCovariance, + processNoiseVariance: _processNoiseVariance, + measurementNoiseVariance: _measurementNoiseVariance, device: Device, - scalarType: Type + scalarType: _scalarType ), resource => Observable.Return(resource.Model) .Concat(Observable.Never(resource.Model)) .Finally(resource.Dispose) @@ -251,7 +249,7 @@ public string InitialCovarianceXml name: ModelName, parameters: parameters, device: Device, - scalarType: Type + scalarType: _scalarType ), resource => Observable.Return(resource.Model) .Concat(Observable.Never(resource.Model)) .Finally(resource.Dispose) diff --git a/src/Bonsai.ML.Torch.LDS/CreateKalmanFilterParameters.cs b/src/Bonsai.ML.Torch.LDS/CreateKalmanFilterParameters.cs index bf7c492d..43ef7adb 100644 --- a/src/Bonsai.ML.Torch.LDS/CreateKalmanFilterParameters.cs +++ b/src/Bonsai.ML.Torch.LDS/CreateKalmanFilterParameters.cs @@ -15,9 +15,7 @@ namespace Bonsai.ML.Torch.LDS; [WorkflowElementCategory(ElementCategory.Source)] public class CreateKalmanFilterParameters : IScalarTypeProvider { - /// - /// The data type of the tensor elements. - /// + /// [Description("The data type of the tensor elements.")] [TypeConverter(typeof(ScalarTypeConverter))] public ScalarType Type @@ -191,12 +189,12 @@ private void ConvertTensorsScalarType(ScalarType scalarType) public IObservable Process() { var parameters = new KalmanFilterParameters( - TransitionMatrix.NumberOfElements > 0 ? TransitionMatrix : null, - MeasurementFunction.NumberOfElements > 0 ? MeasurementFunction : null, - ProcessNoiseCovariance.NumberOfElements > 0 ? ProcessNoiseCovariance : null, - MeasurementNoiseCovariance.NumberOfElements > 0 ? MeasurementNoiseCovariance : null, - InitialMean.NumberOfElements > 0 ? InitialMean : null, - InitialCovariance.NumberOfElements > 0 ? InitialCovariance : null + transitionMatrix: _transitionMatrix, + measurementFunction: _measurementFunction, + processNoiseCovariance: _processNoiseCovariance, + measurementNoiseCovariance: _measurementNoiseCovariance, + initialMean: _initialMean, + initialCovariance: _initialCovariance ); return Observable.Return(parameters); @@ -210,12 +208,12 @@ public IObservable Process(IObservable source) return source.Select(_ => { var parameters = new KalmanFilterParameters( - TransitionMatrix.NumberOfElements > 0 ? TransitionMatrix : null, - MeasurementFunction.NumberOfElements > 0 ? MeasurementFunction : null, - ProcessNoiseCovariance.NumberOfElements > 0 ? ProcessNoiseCovariance : null, - MeasurementNoiseCovariance.NumberOfElements > 0 ? MeasurementNoiseCovariance : null, - InitialMean.NumberOfElements > 0 ? InitialMean : null, - InitialCovariance.NumberOfElements > 0 ? InitialCovariance : null + transitionMatrix: _transitionMatrix, + measurementFunction: _measurementFunction, + processNoiseCovariance: _processNoiseCovariance, + measurementNoiseCovariance: _measurementNoiseCovariance, + initialMean: _initialMean, + initialCovariance: _initialCovariance ); return parameters; From 0aab968f8b8e144db177c2b0aba4833e2a62ef32 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Mon, 6 Oct 2025 14:18:25 +0100 Subject: [PATCH 40/92] Updated `KalmanFilter` with to allow automatically populating null parameter values with sufficient non-null parameter info --- src/Bonsai.ML.Torch.LDS/KalmanFilter.cs | 140 +++++++++++++++++------- 1 file changed, 100 insertions(+), 40 deletions(-) diff --git a/src/Bonsai.ML.Torch.LDS/KalmanFilter.cs b/src/Bonsai.ML.Torch.LDS/KalmanFilter.cs index bb047064..4769d56e 100644 --- a/src/Bonsai.ML.Torch.LDS/KalmanFilter.cs +++ b/src/Bonsai.ML.Torch.LDS/KalmanFilter.cs @@ -37,22 +37,37 @@ public KalmanFilter( _device = device ?? CPU; _scalarType = scalarType; - ValidateAndSetMatrix(parameters.TransitionMatrix, "Transition matrix", _scalarType, _device, out _transitionMatrix, out _numStates, out _, isSquare: true); - ValidateAndSetMatrix(parameters.MeasurementFunction, "Measurement function", _scalarType, _device, out _measurementFunction, out _numObservations, out _); - ValidateAndSetVector(parameters.InitialMean, "Initial mean", _scalarType, _device, out _initialMean, out _, expectedLength: _numStates); - ValidateAndSetMatrix(parameters.InitialCovariance, "Initial covariance", _scalarType, _device, out _initialCovariance, out _, out _, isSquare: true, expectedDimension1: _numStates); - ValidateAndSetMatrix(parameters.ProcessNoiseCovariance, "Process noise covariance", _scalarType, _device, out _processNoiseCovariance, out _, out _, isSquare: true, expectedDimension1: _numStates); - ValidateAndSetMatrix(parameters.MeasurementNoiseCovariance, "Measurement noise covariance", _scalarType, _device, out _measurementNoiseCovariance, out _, out _, isSquare: true, expectedDimension1: _numObservations); + ValidateNumStates(parameters.TransitionMatrix, parameters.MeasurementFunction, parameters.InitialMean, parameters.InitialCovariance, parameters.ProcessNoiseCovariance, out _numStates); + ValidateNumObservations(parameters.MeasurementFunction, parameters.MeasurementNoiseCovariance, out _numObservations); _identityStates = eye(_numStates, dtype: _scalarType, device: _device); + _transitionMatrix = parameters.TransitionMatrix?.clone().to_type(_scalarType).to(_device).requires_grad_(false) + ?? eye(_numStates, dtype: _scalarType, device: _device).requires_grad_(false); + ValidateMatrix(_transitionMatrix, "Transition matrix", isSquare: true, expectedDimension1: _numStates); + + _measurementFunction = parameters.MeasurementFunction?.clone().to_type(_scalarType).to(_device).requires_grad_(false) + ?? eye(_numObservations, _numStates, dtype: _scalarType, device: _device).requires_grad_(false); + ValidateMatrix(_measurementFunction, "Measurement function", expectedDimension1: _numObservations, expectedDimension2: _numStates); + + _initialMean = parameters.InitialMean?.clone().to_type(_scalarType).to(_device).requires_grad_(false) + ?? zeros(_numStates, dtype: _scalarType, device: _device).requires_grad_(false); + ValidateVector(_initialMean, "Initial mean", expectedLength: _numStates); + + _initialCovariance = parameters.InitialCovariance?.clone().to_type(_scalarType).to(_device).requires_grad_(false) + ?? eye(_numStates, dtype: _scalarType, device: _device).requires_grad_(false); + ValidateMatrix(_initialCovariance, "Initial covariance", isSquare: true, expectedDimension1: _numStates); + + _processNoiseCovariance = parameters.ProcessNoiseCovariance ?? CreateCovarianceMatrix(tensor(1.0), _scalarType, _device, _numStates, "Process noise variance"); + _measurementNoiseCovariance = parameters.MeasurementNoiseCovariance ?? CreateCovarianceMatrix(tensor(1.0), _scalarType, _device, _numObservations, "Measurement noise variance"); + _mean = empty(0, dtype: _scalarType, device: _device).requires_grad_(false); _covariance = empty(0, dtype: _scalarType, device: _device).requires_grad_(false); } public KalmanFilter( - int numStates, - int numObservations, + int? numStates = null, + int? numObservations = null, Tensor transitionMatrix = null, Tensor measurementFunction = null, Tensor initialMean = null, @@ -64,29 +79,46 @@ public KalmanFilter( { _device = device ?? CPU; _scalarType = scalarType; - _numStates = numStates; - _numObservations = numObservations; + + if (numStates is null) + { + ValidateNumStates(transitionMatrix, measurementFunction, initialMean, initialCovariance, processNoiseVariance, out var inferredNumStates); + _numStates = inferredNumStates; + } + else + _numStates = numStates.Value > 0 ? numStates.Value : throw new ArgumentOutOfRangeException(nameof(numStates), "Number of states must be greater than zero."); + + if (numObservations is null) + { + ValidateNumObservations(measurementFunction, measurementNoiseVariance, out var inferredNumObservations); + _numObservations = inferredNumObservations; + } + else + _numObservations = numObservations.Value > 0 ? numObservations.Value : throw new ArgumentOutOfRangeException(nameof(numObservations), "Number of observations must be greater than zero."); _identityStates = eye(_numStates, dtype: _scalarType, device: _device); _transitionMatrix = transitionMatrix?.clone().to_type(_scalarType).to(_device).requires_grad_(false) - ?? eye(_numStates, dtype: _scalarType, device: _device); + ?? eye(_numStates, dtype: _scalarType, device: _device).requires_grad_(false); ValidateMatrix(_transitionMatrix, "Transition matrix", isSquare: true, expectedDimension1: _numStates); _measurementFunction = measurementFunction?.clone().to_type(_scalarType).to(_device).requires_grad_(false) - ?? eye(_numObservations, _numStates, dtype: _scalarType, device: _device); + ?? eye(_numObservations, _numStates, dtype: _scalarType, device: _device).requires_grad_(false); ValidateMatrix(_measurementFunction, "Measurement function", expectedDimension1: _numObservations, expectedDimension2: _numStates); _initialMean = initialMean?.clone().to_type(_scalarType).to(_device).requires_grad_(false) - ?? zeros(_numStates, dtype: _scalarType, device: _device); + ?? zeros(_numStates, dtype: _scalarType, device: _device).requires_grad_(false); ValidateVector(_initialMean, "Initial mean", _numStates); _initialCovariance = initialCovariance?.clone().to_type(_scalarType).to(_device).requires_grad_(false) - ?? eye(_numStates, dtype: _scalarType, device: _device); + ?? eye(_numStates, dtype: _scalarType, device: _device).requires_grad_(false); ValidateMatrix(_initialCovariance, "Initial covariance", isSquare: true, expectedDimension1: _numStates); - _processNoiseCovariance = CreateCovarianceMatrix(processNoiseVariance, _scalarType, _device, numStates, "Process noise variance"); - _measurementNoiseCovariance = CreateCovarianceMatrix(measurementNoiseVariance, _scalarType, _device, numObservations, "Measurement noise variance"); + processNoiseVariance ??= tensor(1.0, dtype: _scalarType, device: _device); + measurementNoiseVariance ??= tensor(1.0, dtype: _scalarType, device: _device); + + _processNoiseCovariance = CreateCovarianceMatrix(processNoiseVariance, _scalarType, _device, _numStates, "Process noise variance"); + _measurementNoiseCovariance = CreateCovarianceMatrix(measurementNoiseVariance, _scalarType, _device, _numObservations, "Measurement noise variance"); _mean = empty(0, dtype: _scalarType, device: _device).requires_grad_(false); _covariance = empty(0, dtype: _scalarType, device: _device).requires_grad_(false); @@ -94,32 +126,62 @@ public KalmanFilter( RegisterComponents(); } - private static void ValidateAndSetMatrix(Tensor matrix, string name, ScalarType scalarType, Device device, out Tensor result, out int rows, out int columns, bool isSquare = false, int? expectedDimension1 = null, int? expectedDimension2 = null) + private static void ValidateNumStates(Tensor transitionMatrix, Tensor measurementFunction, Tensor initialMean, Tensor initialCovariance, Tensor processNoiseCovariance, out int numStates) { - ValidateMatrix(matrix, name, isSquare, expectedDimension1, expectedDimension2); - result = matrix.clone().to_type(scalarType).to(device).requires_grad_(false); - rows = (int)matrix.size(0); - columns = (int)matrix.size(1); - } - - private static void ValidateAndSetVector(Tensor vector, string name, ScalarType scalarType, Device device, out Tensor result, out int length, int? expectedLength = null) - { - ValidateVector(vector, name, expectedLength); - result = vector.clone().to_type(scalarType).to(device).requires_grad_(false); - length = (int)vector.size(0); + if (transitionMatrix is not null) + { + ValidateMatrix(transitionMatrix, "Transition matrix", isSquare: true); + numStates = (int)transitionMatrix.size(0); + } + else if (measurementFunction is not null) + { + ValidateMatrix(measurementFunction, "Measurement function"); + numStates = (int)measurementFunction.size(1); + } + else if (initialMean is not null) + { + ValidateVector(initialMean, "Initial mean"); + numStates = (int)initialMean.size(0); + } + else if (initialCovariance is not null) + { + ValidateMatrix(initialCovariance, "Initial covariance", isSquare: true); + numStates = (int)initialCovariance.size(0); + } + else if (processNoiseCovariance is not null) + { + ValidateMatrix(processNoiseCovariance, "Process noise covariance", isSquare: true); + numStates = (int)processNoiseCovariance.size(0); + } + else + { + throw new ArgumentException("At least one of the parameters must be provided to infer the number of states."); + } } - private static void ValidateAndSetScalar(Tensor scalar, string name, ScalarType scalarType, Device device, out Tensor result) + private static void ValidateNumObservations(Tensor measurementFunction, Tensor measurementNoiseCovariance, out int numObservations) { - ValidateScalar(scalar, name); - result = scalar.clone().squeeze().to_type(scalarType).to(device).requires_grad_(false); + if (measurementFunction is not null) + { + ValidateMatrix(measurementFunction, "Measurement function"); + numObservations = (int)measurementFunction.size(0); + } + else if (measurementNoiseCovariance is not null) + { + ValidateMatrix(measurementNoiseCovariance, "Measurement noise covariance", isSquare: true); + numObservations = (int)measurementNoiseCovariance.size(0); + } + else + { + throw new ArgumentException("At least one of the measurement function or measurement noise covariance must be provided to infer the number of observations."); + } } private static void ValidateMatrix(Tensor matrix, string name, bool isSquare = false, int? expectedDimension1 = null, int? expectedDimension2 = null) { - if (matrix is null) - throw new ArgumentException($"{name} cannot be null."); - + if (matrix.NumberOfElements == 0) + throw new ArgumentException($"{name} must be a non-empty matrix."); + if (matrix.Dimensions != 2) throw new ArgumentException($"{name} must be 2-dimensional."); @@ -135,8 +197,8 @@ private static void ValidateMatrix(Tensor matrix, string name, bool isSquare = f private static void ValidateVector(Tensor vector, string name, int? expectedLength = null) { - if (vector is null) - throw new ArgumentException($"{name} cannot be null."); + if (vector.NumberOfElements == 0) + throw new ArgumentException($"{name} must be a non-empty vector."); if (vector.Dimensions != 1) throw new ArgumentException($"{name} must be a vector."); @@ -147,16 +209,14 @@ private static void ValidateVector(Tensor vector, string name, int? expectedLeng private static void ValidateScalar(Tensor scalar, string name) { - if (scalar is null) - throw new ArgumentException($"{name} cannot be null."); - if (scalar.NumberOfElements != 1) throw new ArgumentException($"{name} must be a scalar."); } private Tensor CreateCovarianceMatrix(Tensor variance, ScalarType scalarType, Device device, int dimension, string name) { - ValidateAndSetScalar(variance, name, scalarType, device, out var scalar); + ValidateScalar(variance, name); + var scalar = variance.clone().squeeze().to_type(scalarType).to(device); return (scalar * eye(dimension, dtype: scalarType, device: device)).requires_grad_(false); } From 2989912d692a257b7cd315e5b0ec83fab418c9c8 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Mon, 6 Oct 2025 14:18:51 +0100 Subject: [PATCH 41/92] Updated to allow parameters to contain null tensor values --- src/Bonsai.ML.Torch.LDS/KalmanFilterParameters.cs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/Bonsai.ML.Torch.LDS/KalmanFilterParameters.cs b/src/Bonsai.ML.Torch.LDS/KalmanFilterParameters.cs index 5fadf1c3..7b56a163 100644 --- a/src/Bonsai.ML.Torch.LDS/KalmanFilterParameters.cs +++ b/src/Bonsai.ML.Torch.LDS/KalmanFilterParameters.cs @@ -15,12 +15,12 @@ namespace Bonsai.ML.Torch.LDS; /// /// public struct KalmanFilterParameters( - Tensor transitionMatrix, - Tensor measurementFunction, - Tensor processNoiseCovariance, - Tensor measurementNoiseCovariance, - Tensor initialMean, - Tensor initialCovariance) + Tensor transitionMatrix = null, + Tensor measurementFunction = null, + Tensor processNoiseCovariance = null, + Tensor measurementNoiseCovariance = null, + Tensor initialMean = null, + Tensor initialCovariance = null) { /// /// The state transition matrix. From fdad188cd6b4884820b66633d425d2ac7bcaefbe Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Mon, 6 Oct 2025 14:20:11 +0100 Subject: [PATCH 42/92] Added XML docs to class --- src/Bonsai.ML.Torch.LDS.Design/StateVisualizer.cs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/Bonsai.ML.Torch.LDS.Design/StateVisualizer.cs b/src/Bonsai.ML.Torch.LDS.Design/StateVisualizer.cs index ab36b8f0..3372bc67 100644 --- a/src/Bonsai.ML.Torch.LDS.Design/StateVisualizer.cs +++ b/src/Bonsai.ML.Torch.LDS.Design/StateVisualizer.cs @@ -22,6 +22,9 @@ namespace Bonsai.ML.Torch.LDS.Design; +/// +/// Provides a visualizer for the state means and covariances from a Kalman filter or smoother. +/// public class StateVisualizer : BufferedVisualizer { private TimeSeriesOxyPlotBase _plot; From 9196517b54ac7a86bc329c70f7b07c60c408c9a5 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Mon, 6 Oct 2025 14:24:52 +0100 Subject: [PATCH 43/92] Refactored to use autosize and expose plot control --- src/Bonsai.ML.Torch.LDS.Design/StateVisualizer.cs | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/src/Bonsai.ML.Torch.LDS.Design/StateVisualizer.cs b/src/Bonsai.ML.Torch.LDS.Design/StateVisualizer.cs index 3372bc67..04f555b6 100644 --- a/src/Bonsai.ML.Torch.LDS.Design/StateVisualizer.cs +++ b/src/Bonsai.ML.Torch.LDS.Design/StateVisualizer.cs @@ -41,6 +41,11 @@ public class StateVisualizer : BufferedVisualizer /// public bool BufferData { get; set; } = false; + /// + /// Gets the underlying plot control. + /// + public TimeSeriesOxyPlotBase Plot => _plot; + /// public override void Load(IServiceProvider provider) { @@ -55,11 +60,13 @@ public override void Load(IServiceProvider provider) var capacityStatusLabel = new ToolStripStatusLabel { Text = "Capacity: ", + AutoSize = true }; var capacityStatusControl = new ToolStripTextBox { Text = Capacity.ToString(), + AutoSize = true }; capacityStatusControl.TextChanged += (sender, e) => @@ -74,14 +81,15 @@ public override void Load(IServiceProvider provider) var bufferDataStatusLabel = new ToolStripStatusLabel { Text = "Buffer Data: ", + AutoSize = true }; var bufferDataStatusControl = new ToolStripComboBox { - Width = 100, + AutoSize = true }; - bufferDataStatusControl.Items.AddRange(new[] { "True", "False" }); + bufferDataStatusControl.Items.AddRange(["True", "False"]); bufferDataStatusControl.SelectedIndex = BufferData ? 0 : 1; bufferDataStatusControl.SelectedIndexChanged += (sender, e) => @@ -172,7 +180,7 @@ protected override void Show(DateTime time, object value) value: meanVal ); - var sigmaVal = covarianceDiagonal[i,j].sqrt().to_type(ScalarType.Float64).item(); + var sigmaVal = covarianceDiagonal[i, j].sqrt().to_type(ScalarType.Float64).item(); _plot.AddToAreaSeries( areaSeries: _areaSeries[j], From d65888a055e21cb5f38e20cd0fdaa1061af929c6 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Mon, 6 Oct 2025 14:27:14 +0100 Subject: [PATCH 44/92] Updated `NeuralLatentsTest` with default null values --- .../NeuralLatentsTest.bonsai | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/tests/Bonsai.ML.Torch.LDS.Tests/NeuralLatentsTest.bonsai b/tests/Bonsai.ML.Torch.LDS.Tests/NeuralLatentsTest.bonsai index f3959c7f..98b8375e 100644 --- a/tests/Bonsai.ML.Torch.LDS.Tests/NeuralLatentsTest.bonsai +++ b/tests/Bonsai.ML.Torch.LDS.Tests/NeuralLatentsTest.bonsai @@ -161,12 +161,12 @@ Float64 - [] - [] - [] - [] + + + + - [] + @@ -175,12 +175,12 @@ 2 2 Float64 - [] - [] - [] - [] + + + + - [] + From 7c1c070a6d39a137317ad53e251bdc6194a41c14 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Mon, 6 Oct 2025 16:38:09 +0100 Subject: [PATCH 45/92] Added categories to class properties and renamed `ModelName` to just `Name` for consistency --- src/Bonsai.ML.Torch.LDS/CreateKalmanFilter.cs | 71 +++++++++++-------- .../KalmanFilterNameConverter.cs | 4 +- 2 files changed, 43 insertions(+), 32 deletions(-) diff --git a/src/Bonsai.ML.Torch.LDS/CreateKalmanFilter.cs b/src/Bonsai.ML.Torch.LDS/CreateKalmanFilter.cs index 82b7a8a3..d37083cd 100644 --- a/src/Bonsai.ML.Torch.LDS/CreateKalmanFilter.cs +++ b/src/Bonsai.ML.Torch.LDS/CreateKalmanFilter.cs @@ -17,32 +17,14 @@ public class CreateKalmanFilter : IScalarTypeProvider /// /// A unique name for the Kalman filter model. /// - public string ModelName { get; set; } = "KalmanFilter"; - - private int _numStates = 2; - /// - /// The number of states in the Kalman filter model. - /// - public int NumStates - { - get => _numStates; - set => _numStates = value > 0 ? value : throw new ArgumentOutOfRangeException(nameof(value), "Number of states must be greater than zero."); - } - - private int _numObservations = 2; - /// - /// The number of observations in the Kalman filter model. - /// - public int NumObservations - { - get => _numObservations; - set => _numObservations = value > 0 ? value : throw new ArgumentOutOfRangeException(nameof(value), "Number of observations must be greater than zero."); - } + [Category("Required Parameters")] + public string Name { get; set; } = "KalmanFilter"; private ScalarType _scalarType = ScalarType.Float32; /// [Description("The data type of the tensor elements.")] [TypeConverter(typeof(ScalarTypeConverter))] + [Category("Required Parameters")] public ScalarType Type { get => _scalarType; @@ -58,16 +40,29 @@ public ScalarType Type /// [Description("The device on which to create the tensor.")] [XmlIgnore] + [Category("Required Parameters")] public Device Device { get; set; } - private void ConvertTensorsScalarType(ScalarType scalarType) + private int _numStates = 2; + /// + /// The number of states in the Kalman filter model. + /// + [Category("Required Parameters")] + public int NumStates { - _transitionMatrix = _transitionMatrix?.to_type(scalarType); - _measurementFunction = _measurementFunction?.to_type(scalarType); - _processNoiseVariance = _processNoiseVariance?.to_type(scalarType); - _measurementNoiseVariance = _measurementNoiseVariance?.to_type(scalarType); - _initialMean = _initialMean?.to_type(scalarType); - _initialCovariance = _initialCovariance?.to_type(scalarType); + get => _numStates; + set => _numStates = value > 0 ? value : throw new ArgumentOutOfRangeException(nameof(value), "Number of states must be greater than zero."); + } + + private int _numObservations = 2; + /// + /// The number of observations in the Kalman filter model. + /// + [Category("Required Parameters")] + public int NumObservations + { + get => _numObservations; + set => _numObservations = value > 0 ? value : throw new ArgumentOutOfRangeException(nameof(value), "Number of observations must be greater than zero."); } // Tensor properties with XML serialization support @@ -77,6 +72,7 @@ private void ConvertTensorsScalarType(ScalarType scalarType) /// [XmlIgnore] [TypeConverter(typeof(TensorConverter))] + [Category("Optional Parameters")] public Tensor TransitionMatrix { get => _transitionMatrix; @@ -101,6 +97,7 @@ public string TransitionMatrixXml /// [XmlIgnore] [TypeConverter(typeof(TensorConverter))] + [Category("Optional Parameters")] public Tensor MeasurementFunction { get => _measurementFunction; @@ -125,6 +122,7 @@ public string MeasurementFunctionXml /// [XmlIgnore] [TypeConverter(typeof(TensorConverter))] + [Category("Optional Parameters")] public Tensor ProcessNoiseVariance { get => _processNoiseVariance; @@ -149,6 +147,7 @@ public string ProcessNoiseVarianceXml /// [XmlIgnore] [TypeConverter(typeof(TensorConverter))] + [Category("Optional Parameters")] public Tensor MeasurementNoiseVariance { get => _measurementNoiseVariance; @@ -173,6 +172,7 @@ public string MeasurementNoiseVarianceXml /// [XmlIgnore] [TypeConverter(typeof(TensorConverter))] + [Category("Optional Parameters")] public Tensor InitialMean { get => _initialMean; @@ -197,6 +197,7 @@ public string InitialMeanXml /// [XmlIgnore] [TypeConverter(typeof(TensorConverter))] + [Category("Optional Parameters")] public Tensor InitialCovariance { get => _initialCovariance; @@ -215,13 +216,23 @@ public string InitialCovarianceXml set => InitialCovariance = TensorConverter.ConvertFromString(value, _scalarType); } + private void ConvertTensorsScalarType(ScalarType scalarType) + { + _transitionMatrix = _transitionMatrix?.to_type(scalarType); + _measurementFunction = _measurementFunction?.to_type(scalarType); + _processNoiseVariance = _processNoiseVariance?.to_type(scalarType); + _measurementNoiseVariance = _measurementNoiseVariance?.to_type(scalarType); + _initialMean = _initialMean?.to_type(scalarType); + _initialCovariance = _initialCovariance?.to_type(scalarType); + } + /// /// Creates a Kalman filter model using the properties of this class. /// public IObservable Process() { return Observable.Using(() => KalmanFilterModelManager.Reserve( - name: ModelName, + name: Name, numStates: _numStates, numObservations: _numObservations, transitionMatrix: _transitionMatrix, @@ -246,7 +257,7 @@ public string InitialCovarianceXml return source.SelectMany(parameters => { return Observable.Using(() => KalmanFilterModelManager.Reserve( - name: ModelName, + name: Name, parameters: parameters, device: Device, scalarType: _scalarType diff --git a/src/Bonsai.ML.Torch.LDS/KalmanFilterNameConverter.cs b/src/Bonsai.ML.Torch.LDS/KalmanFilterNameConverter.cs index 78eb99b2..0dee35d0 100644 --- a/src/Bonsai.ML.Torch.LDS/KalmanFilterNameConverter.cs +++ b/src/Bonsai.ML.Torch.LDS/KalmanFilterNameConverter.cs @@ -28,8 +28,8 @@ where builder.GetType() != typeof(DisableBuilder) let managedModelNode = ExpressionBuilder.GetWorkflowElement(builder) where managedModelNode != null && managedModelNode is CreateKalmanFilter let createKalmanFilter = (CreateKalmanFilter)managedModelNode - where createKalmanFilter != null && !string.IsNullOrEmpty(createKalmanFilter.ModelName) - select createKalmanFilter.ModelName) + where createKalmanFilter != null && !string.IsNullOrEmpty(createKalmanFilter.Name) + select createKalmanFilter.Name) .Distinct() .ToList(); if (models.Count > 0) From f671f3392e425709d442045689d3ff67882ad2f9 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 7 Oct 2025 12:53:30 +0100 Subject: [PATCH 46/92] Added `ResetCombinator` to class attributes --- src/Bonsai.ML.Torch.LDS/CreateKalmanFilter.cs | 1 + src/Bonsai.ML.Torch.LDS/CreateKalmanFilterParameters.cs | 1 + 2 files changed, 2 insertions(+) diff --git a/src/Bonsai.ML.Torch.LDS/CreateKalmanFilter.cs b/src/Bonsai.ML.Torch.LDS/CreateKalmanFilter.cs index d37083cd..a3ba1ba2 100644 --- a/src/Bonsai.ML.Torch.LDS/CreateKalmanFilter.cs +++ b/src/Bonsai.ML.Torch.LDS/CreateKalmanFilter.cs @@ -10,6 +10,7 @@ namespace Bonsai.ML.Torch.LDS; /// Creates a Kalman filter model. /// [Combinator] +[ResetCombinator] [Description("Creates a Kalman filter model.")] [WorkflowElementCategory(ElementCategory.Source)] public class CreateKalmanFilter : IScalarTypeProvider diff --git a/src/Bonsai.ML.Torch.LDS/CreateKalmanFilterParameters.cs b/src/Bonsai.ML.Torch.LDS/CreateKalmanFilterParameters.cs index 43ef7adb..abd77ee8 100644 --- a/src/Bonsai.ML.Torch.LDS/CreateKalmanFilterParameters.cs +++ b/src/Bonsai.ML.Torch.LDS/CreateKalmanFilterParameters.cs @@ -11,6 +11,7 @@ namespace Bonsai.ML.Torch.LDS; /// Initializes the parameters for a new Kalman filter model. /// [Combinator] +[ResetCombinator] [Description("Initializes the parameters for a new Kalman filter model.")] [WorkflowElementCategory(ElementCategory.Source)] public class CreateKalmanFilterParameters : IScalarTypeProvider From 39897c1b068fbc39035ec2cb6814da1d4b45e6c4 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 7 Oct 2025 12:54:25 +0100 Subject: [PATCH 47/92] Added generic class and interface to represent LDS state --- src/Bonsai.ML.Torch.LDS/CreateLdsState.cs | 125 ++++++++++++++++++++++ src/Bonsai.ML.Torch.LDS/ILdsState.cs | 19 ++++ src/Bonsai.ML.Torch.LDS/LdsState.cs | 17 +++ 3 files changed, 161 insertions(+) create mode 100644 src/Bonsai.ML.Torch.LDS/CreateLdsState.cs create mode 100644 src/Bonsai.ML.Torch.LDS/ILdsState.cs create mode 100644 src/Bonsai.ML.Torch.LDS/LdsState.cs diff --git a/src/Bonsai.ML.Torch.LDS/CreateLdsState.cs b/src/Bonsai.ML.Torch.LDS/CreateLdsState.cs new file mode 100644 index 00000000..af7367a4 --- /dev/null +++ b/src/Bonsai.ML.Torch.LDS/CreateLdsState.cs @@ -0,0 +1,125 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using System.Xml.Serialization; +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.LDS; + +/// +/// Creates a generic state object for a linear gaussian dynamical system. +/// +[Combinator] +[ResetCombinator] +[Description("Creates a new state for a linear gaussian dynamical system.")] +[WorkflowElementCategory(ElementCategory.Source)] +public class CreateLdsState : IScalarTypeProvider +{ + private ScalarType _scalarType = ScalarType.Float32; + /// + [Description("The data type of the tensor elements.")] + [TypeConverter(typeof(ScalarTypeConverter))] + public ScalarType Type + { + get => _scalarType; + set + { + _scalarType = value; + ConvertTensorsScalarType(value); + } + } + + /// + /// The device on which to create the tensor. + /// + [Description("The device on which to create the tensor.")] + [XmlIgnore] + public Device Device { get; set; } + + private void ConvertTensorsScalarType(ScalarType scalarType) + { + _mean = _mean?.to_type(scalarType); + _covariance = _covariance?.to_type(scalarType); + } + + private Tensor _mean = null; + /// + /// The mean of the state. + /// + [XmlIgnore] + [TypeConverter(typeof(TensorConverter))] + public Tensor Mean + { + get => _mean; + set => _mean = value?.to_type(Type); + } + + /// + /// The XML string representation of the mean for serialization. + /// + [Browsable(false)] + [XmlElement(nameof(Mean))] + [EditorBrowsable(EditorBrowsableState.Never)] + public string MeanXml + { + get => TensorConverter.ConvertToString(Mean, _scalarType); + set => Mean = TensorConverter.ConvertFromString(value, _scalarType); + } + + private Tensor _covariance = null; + /// + /// The covariance of the state. + /// + [XmlIgnore] + [TypeConverter(typeof(TensorConverter))] + public Tensor Covariance + { + get => _covariance; + set => _covariance = value?.to_type(Type); + } + + /// + /// The XML string representation of the covariance for serialization. + /// + [Browsable(false)] + [XmlElement(nameof(Covariance))] + [EditorBrowsable(EditorBrowsableState.Never)] + public string CovarianceXml + { + get => TensorConverter.ConvertToString(Covariance, _scalarType); + set => Covariance = TensorConverter.ConvertFromString(value, _scalarType); + } + + /// + /// Creates an observable sequence and emits the state for a linear gaussian dynamical system. + /// + /// + public IObservable Process() + { + return Observable.Defer(() => + { + var device = Device ?? CPU; + var mean = Mean?.to(device) ?? throw new InvalidOperationException("The mean of the state must be specified."); + var covariance = Covariance?.to(device) ?? throw new InvalidOperationException("The covariance of the state must be specified."); + return Observable.Return(new LdsState(mean, covariance)); + }); + } + + /// + /// Processes an observable sequence and emits new states for a linear gaussian dynamical system. + /// + /// + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(_ => + { + var device = Device ?? CPU; + var mean = Mean?.to(device) ?? throw new InvalidOperationException("The mean of the state must be specified."); + var covariance = Covariance?.to(device) ?? throw new InvalidOperationException("The covariance of the state must be specified."); + return new LdsState(mean, covariance); + }); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch.LDS/ILdsState.cs b/src/Bonsai.ML.Torch.LDS/ILdsState.cs new file mode 100644 index 00000000..ddeda06a --- /dev/null +++ b/src/Bonsai.ML.Torch.LDS/ILdsState.cs @@ -0,0 +1,19 @@ +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.LDS; + +/// +/// Represents the state of a linear gaussian dynamical system. +/// +public interface ILdsState +{ + /// + /// The mean of the state. + /// + Tensor Mean { get; } + + /// + /// The covariance of the state. + /// + Tensor Covariance { get; } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Torch.LDS/LdsState.cs b/src/Bonsai.ML.Torch.LDS/LdsState.cs new file mode 100644 index 00000000..8c778a12 --- /dev/null +++ b/src/Bonsai.ML.Torch.LDS/LdsState.cs @@ -0,0 +1,17 @@ +using static TorchSharp.torch; + +namespace Bonsai.ML.Torch.LDS; + +/// +/// Represents the state of a linear gaussian dynamical system. +/// +/// +/// +public class LdsState(Tensor mean, Tensor covariance) : ILdsState +{ + /// + public Tensor Mean => mean; + + /// + public Tensor Covariance => covariance; +} \ No newline at end of file From fcd19ad15ab78a79526fc7f96b53abe29f544162 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 7 Oct 2025 12:56:02 +0100 Subject: [PATCH 48/92] Removed `ResetCombinator` attribute from classes where it is not needed --- src/Bonsai.ML.Torch.LDS/ExpectationMaximization.cs | 1 - src/Bonsai.ML.Torch.LDS/Smooth.cs | 1 - 2 files changed, 2 deletions(-) diff --git a/src/Bonsai.ML.Torch.LDS/ExpectationMaximization.cs b/src/Bonsai.ML.Torch.LDS/ExpectationMaximization.cs index 71d1d7b7..7023cec5 100644 --- a/src/Bonsai.ML.Torch.LDS/ExpectationMaximization.cs +++ b/src/Bonsai.ML.Torch.LDS/ExpectationMaximization.cs @@ -11,7 +11,6 @@ namespace Bonsai.ML.Torch.LDS; /// Learn the parameters of a kalman filter using the batch EM update algorithm. /// [Combinator] -[ResetCombinator] [Description("Learn the parameters of a kalman filter using the batch EM update algorithm.")] [WorkflowElementCategory(ElementCategory.Transform)] public class ExpectationMaximization diff --git a/src/Bonsai.ML.Torch.LDS/Smooth.cs b/src/Bonsai.ML.Torch.LDS/Smooth.cs index 16b84fcb..2f57f281 100644 --- a/src/Bonsai.ML.Torch.LDS/Smooth.cs +++ b/src/Bonsai.ML.Torch.LDS/Smooth.cs @@ -9,7 +9,6 @@ namespace Bonsai.ML.Torch.LDS; /// Applies a Kalman smoother to the input filtered result sequence. /// [Combinator] -[ResetCombinator] [Description("Applies a Kalman smoother to the input filtered result sequence.")] [WorkflowElementCategory(ElementCategory.Transform)] public class Smooth From ffac1025a805cb1e57fb819012ec026aaf5373bd Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 7 Oct 2025 12:57:41 +0100 Subject: [PATCH 49/92] Changed naming from `xResult` to `xState` to for improved naming consistency --- .../StateVisualizer.cs | 28 ++--- src/Bonsai.ML.Torch.LDS/Filter.cs | 2 +- .../{FilteredResult.cs => FilteredState.cs} | 12 +- src/Bonsai.ML.Torch.LDS/KalmanFilter.cs | 110 ++++++++++-------- src/Bonsai.ML.Torch.LDS/Orthogonalize.cs | 38 +++++- ...alizedResult.cs => OrthogonalizedState.cs} | 14 ++- src/Bonsai.ML.Torch.LDS/Smooth.cs | 8 +- .../{SmoothedResult.cs => SmoothedState.cs} | 12 +- 8 files changed, 140 insertions(+), 84 deletions(-) rename src/Bonsai.ML.Torch.LDS/{FilteredResult.cs => FilteredState.cs} (77%) rename src/Bonsai.ML.Torch.LDS/{OrthogonalizedResult.cs => OrthogonalizedState.cs} (56%) rename src/Bonsai.ML.Torch.LDS/{SmoothedResult.cs => SmoothedState.cs} (64%) diff --git a/src/Bonsai.ML.Torch.LDS.Design/StateVisualizer.cs b/src/Bonsai.ML.Torch.LDS.Design/StateVisualizer.cs index 04f555b6..46cda2d0 100644 --- a/src/Bonsai.ML.Torch.LDS.Design/StateVisualizer.cs +++ b/src/Bonsai.ML.Torch.LDS.Design/StateVisualizer.cs @@ -14,11 +14,13 @@ using static TorchSharp.torch; [assembly: TypeVisualizer(typeof(Bonsai.ML.Torch.LDS.Design.StateVisualizer), - Target = typeof(Bonsai.ML.Torch.LDS.FilteredResult))] + Target = typeof(Bonsai.ML.Torch.LDS.FilteredState))] [assembly: TypeVisualizer(typeof(Bonsai.ML.Torch.LDS.Design.StateVisualizer), - Target = typeof(Bonsai.ML.Torch.LDS.SmoothedResult))] + Target = typeof(Bonsai.ML.Torch.LDS.SmoothedState))] [assembly: TypeVisualizer(typeof(Bonsai.ML.Torch.LDS.Design.StateVisualizer), - Target = typeof(Bonsai.ML.Torch.LDS.OrthogonalizedResult))] + Target = typeof(Bonsai.ML.Torch.LDS.OrthogonalizedState))] +[assembly: TypeVisualizer(typeof(Bonsai.ML.Torch.LDS.Design.StateVisualizer), + Target = typeof(Bonsai.ML.Torch.LDS.LdsState))] namespace Bonsai.ML.Torch.LDS.Design; @@ -117,21 +119,13 @@ protected override void Show(DateTime time, object value) { if (value is null) return; - var mean = value switch - { - FilteredResult filteredResult => filteredResult.UpdatedMean, - SmoothedResult smoothedResult => smoothedResult.SmoothedMean, - OrthogonalizedResult orthogonalizedResult => orthogonalizedResult.OrthogonalizedMean, - _ => throw new ArgumentException($"Expected value to be of type {nameof(FilteredResult)}, {nameof(SmoothedResult)}, or {nameof(OrthogonalizedResult)}.", nameof(value)) - }; + if (value is not ILdsState state) + throw new ArgumentException($"Expected value to be a type of {nameof(ILdsState)}.", nameof(value)); - var covariance = value switch - { - FilteredResult filteredResult => filteredResult.UpdatedCovariance, - SmoothedResult smoothedResult => smoothedResult.SmoothedCovariance, - OrthogonalizedResult orthogonalizedResult => orthogonalizedResult.OrthogonalizedCovariance, - _ => throw new ArgumentException($"Expected value to be of type {nameof(FilteredResult)}, {nameof(SmoothedResult)}, or {nameof(OrthogonalizedResult)}.", nameof(value)) - }; + var mean = state.Mean; + var covariance = state.Covariance; + + if (mean is null || covariance is null) return; if (mean.Dimensions == 1) { diff --git a/src/Bonsai.ML.Torch.LDS/Filter.cs b/src/Bonsai.ML.Torch.LDS/Filter.cs index 78a1f8e7..853d4026 100644 --- a/src/Bonsai.ML.Torch.LDS/Filter.cs +++ b/src/Bonsai.ML.Torch.LDS/Filter.cs @@ -23,7 +23,7 @@ public class Filter /// /// Processes an observable sequence of input tensors, applying the Kalman filter to each tensor. /// - public IObservable Process(IObservable source) + public IObservable Process(IObservable source) { return source.Select((input) => { diff --git a/src/Bonsai.ML.Torch.LDS/FilteredResult.cs b/src/Bonsai.ML.Torch.LDS/FilteredState.cs similarity index 77% rename from src/Bonsai.ML.Torch.LDS/FilteredResult.cs rename to src/Bonsai.ML.Torch.LDS/FilteredState.cs index 0eec4356..e2b9a8af 100644 --- a/src/Bonsai.ML.Torch.LDS/FilteredResult.cs +++ b/src/Bonsai.ML.Torch.LDS/FilteredState.cs @@ -3,17 +3,17 @@ namespace Bonsai.ML.Torch.LDS; /// -/// Represents the result of a Kalman filter. +/// Represents the state of a Kalman filter. /// /// /// /// /// -public struct FilteredResult( +public struct FilteredState( Tensor predictedMean, Tensor predictedCovariance, Tensor updatedMean, - Tensor updatedCovariance) + Tensor updatedCovariance) : ILdsState { /// /// The predicted mean after the prediction step. @@ -34,4 +34,10 @@ public struct FilteredResult( /// The updated covariance after the update step. /// public Tensor UpdatedCovariance = updatedCovariance; + + /// + public readonly Tensor Mean => UpdatedMean; + + /// + public readonly Tensor Covariance => UpdatedCovariance; } \ No newline at end of file diff --git a/src/Bonsai.ML.Torch.LDS/KalmanFilter.cs b/src/Bonsai.ML.Torch.LDS/KalmanFilter.cs index 4769d56e..7b9c4211 100644 --- a/src/Bonsai.ML.Torch.LDS/KalmanFilter.cs +++ b/src/Bonsai.ML.Torch.LDS/KalmanFilter.cs @@ -220,7 +220,7 @@ private Tensor CreateCovarianceMatrix(Tensor variance, ScalarType scalarType, De return (scalar * eye(dimension, dtype: scalarType, device: device)).requires_grad_(false); } - private readonly struct PredictedResult( + private readonly struct PredictedState( Tensor predictedMean, Tensor predictedCovariance) { @@ -228,14 +228,14 @@ private readonly struct PredictedResult( public readonly Tensor PredictedCovariance = predictedCovariance; } - private PredictedResult FilterPredict( + private PredictedState FilterPredict( Tensor mean, Tensor covariance) => new(_transitionMatrix.matmul(mean), _transitionMatrix.matmul(covariance) .matmul(_transitionMatrix.mT) + _processNoiseCovariance); - private static PredictedResult FilterPredict( + private static PredictedState FilterPredict( Tensor mean, Tensor covariance, Tensor transitionMatrix, @@ -244,7 +244,7 @@ private static PredictedResult FilterPredict( transitionMatrix.matmul(covariance) .matmul(transitionMatrix.mT) + processNoiseCovariance); - private readonly struct UpdatedResult( + private readonly struct UpdatedState( Tensor updatedMean, Tensor updatedCovariance, Tensor innovation, @@ -258,7 +258,7 @@ private readonly struct UpdatedResult( public readonly Tensor KalmanGain = kalmanGain; } - private UpdatedResult FilterUpdate( + private UpdatedState FilterUpdate( Tensor predictedMean, Tensor predictedCovariance, Tensor observation) @@ -279,10 +279,10 @@ private UpdatedResult FilterUpdate( var updatedCovariance = WrappedTensorDisposeScope(() => predictedCovariance - kalmanGain.matmul(_measurementFunction).matmul(predictedCovariance)); - return new UpdatedResult(updatedMean, updatedCovariance, innovation, innovationCovariance, kalmanGain); + return new UpdatedState(updatedMean, updatedCovariance, innovation, innovationCovariance, kalmanGain); } - private static UpdatedResult FilterUpdate( + private static UpdatedState FilterUpdate( Tensor predictedMean, Tensor predictedCovariance, Tensor observation, @@ -305,7 +305,7 @@ private static UpdatedResult FilterUpdate( var updatedCovariance = WrappedTensorDisposeScope(() => predictedCovariance - kalmanGain.matmul(measurementFunction).matmul(predictedCovariance)); - return new UpdatedResult( + return new UpdatedState( updatedMean: updatedMean, updatedCovariance: updatedCovariance, innovation: innovation, @@ -314,7 +314,7 @@ private static UpdatedResult FilterUpdate( ); } - public FilteredResult Filter(Tensor observation) + public FilteredState Filter(Tensor observation) { using var g = no_grad(); @@ -348,14 +348,14 @@ public FilteredResult Filter(Tensor observation) _covariance.set_(update.UpdatedCovariance); } - return new FilteredResult( + return new FilteredState( predictedMean: predictedMean, predictedCovariance: predictedCovariance, updatedMean: updatedMean, updatedCovariance: updatedCovariance); } - private readonly struct FilteredResultWithAuxiliaryVariables( + private readonly struct FilteredStateWithAuxiliaryVariables( Tensor predictedMean, Tensor predictedCovariance, Tensor updatedMean, @@ -375,7 +375,7 @@ private readonly struct FilteredResultWithAuxiliaryVariables( public readonly Tensor KalmanGain = kalmanGain; } - private static FilteredResultWithAuxiliaryVariables Filter( + private static FilteredStateWithAuxiliaryVariables Filter( Tensor observation, long timeBins, int numStates, @@ -437,7 +437,7 @@ private static FilteredResultWithAuxiliaryVariables Filter( covariance = update.UpdatedCovariance; } - return new FilteredResultWithAuxiliaryVariables( + return new FilteredStateWithAuxiliaryVariables( predictedMean: predictedMean, predictedCovariance: predictedCovariance, updatedMean: updatedMean, @@ -449,14 +449,14 @@ private static FilteredResultWithAuxiliaryVariables Filter( ); } - public SmoothedResult Smooth(FilteredResult filteredResult) + public SmoothedState Smooth(FilteredState filteredState) { using var g = no_grad(); - var predictedMean = filteredResult.PredictedMean; - var predictedCovariance = filteredResult.PredictedCovariance; - var updatedMean = filteredResult.UpdatedMean; - var updatedCovariance = filteredResult.UpdatedCovariance; + var predictedMean = filteredState.PredictedMean; + var predictedCovariance = filteredState.PredictedCovariance; + var updatedMean = filteredState.UpdatedMean; + var updatedCovariance = filteredState.UpdatedCovariance; var timeBins = predictedMean.size(0); var smoothedMean = empty_like(updatedMean); @@ -489,13 +489,13 @@ public SmoothedResult Smooth(FilteredResult filteredResult) ); } - return new SmoothedResult( + return new SmoothedState( smoothedMean, smoothedCovariance ); } - private readonly struct SmoothedResultWithAuxiliaryVariables( + private readonly struct SmoothedStateWithAuxiliaryVariables( Tensor smoothedMean, Tensor smoothedCovariance, Tensor smoothedInitialMean, @@ -513,8 +513,8 @@ private readonly struct SmoothedResultWithAuxiliaryVariables( public readonly Tensor S11 = S11; } - private static SmoothedResultWithAuxiliaryVariables Smooth( - FilteredResultWithAuxiliaryVariables filteredResult, + private static SmoothedStateWithAuxiliaryVariables Smooth( + FilteredStateWithAuxiliaryVariables filteredState, long timeBins, int numStates, Tensor transitionMatrix, @@ -529,11 +529,11 @@ Device device if (timeBins < 2) throw new ArgumentException("Smoothing requires at least two time bins."); - var predictedMean = filteredResult.PredictedMean; - var predictedCovariance = filteredResult.PredictedCovariance; - var updatedMean = filteredResult.UpdatedMean; - var updatedCovariance = filteredResult.UpdatedCovariance; - var kalmanGain = filteredResult.KalmanGain; + var predictedMean = filteredState.PredictedMean; + var predictedCovariance = filteredState.PredictedCovariance; + var updatedMean = filteredState.UpdatedMean; + var updatedCovariance = filteredState.UpdatedCovariance; + var kalmanGain = filteredState.KalmanGain; var smoothedMean = empty_like(updatedMean); var smoothedCovariance = empty_like(updatedCovariance); @@ -620,7 +620,7 @@ Device device S10[0] = outer(smoothedMean[0], smoothedInitialMean) + smoothedLagOneCovariance; S00[0] = outer(smoothedInitialMean, smoothedInitialMean) + smoothedInitialCovariance; - return new SmoothedResultWithAuxiliaryVariables( + return new SmoothedStateWithAuxiliaryVariables( smoothedMean: smoothedMean, smoothedCovariance: smoothedCovariance, smoothedInitialMean: smoothedInitialMean, @@ -659,7 +659,7 @@ public ExpectationMaximizationResult ExpectationMaximization( for (int iteration = 0; iteration < maxIterations; iteration++) { // Filter observations - var filteredResult = Filter( + var filteredState = Filter( observation: observation, timeBins: timeBins, numStates: _numStates, @@ -674,7 +674,7 @@ public ExpectationMaximizationResult ExpectationMaximization( device: _device); // Compute log likelihood (avoid creating intermediate tensors) - var llSumDouble = filteredResult.LogLikelihood.sum() + var llSumDouble = filteredState.LogLikelihood.sum() .to_type(ScalarType.Float64).item(); var filteredLogLikelihoodSum = logLikelihoodConst + 0.5 * llSumDouble; @@ -693,8 +693,8 @@ public ExpectationMaximizationResult ExpectationMaximization( previousLogLikelihood = filteredLogLikelihoodSum; // Smooth the filtered results - var smoothedResult = Smooth( - filteredResult: filteredResult, + var smoothedState = Smooth( + filteredState: filteredState, timeBins: timeBins, numStates: _numStates, transitionMatrix: transitionMatrix, @@ -706,12 +706,12 @@ public ExpectationMaximizationResult ExpectationMaximization( device: _device); // Sufficient statistics - var S00 = smoothedResult.S00.sum([0]); - var S11 = smoothedResult.S11.sum([0]); - var S10 = smoothedResult.S10.sum([0]); + var S00 = smoothedState.S00.sum([0]); + var S11 = smoothedState.S11.sum([0]); + var S10 = smoothedState.S10.sum([0]); // Replace einsum with faster matmul - var crossCorrelationObservations = observationT.matmul(smoothedResult.SmoothedMean); + var crossCorrelationObservations = observationT.matmul(smoothedState.SmoothedMean); // Update parameters if (parametersToEstimate.TransitionMatrix) @@ -732,10 +732,10 @@ public ExpectationMaximizationResult ExpectationMaximization( + measurementFunction.matmul(S11).matmul(measurementFunction.mT)) / timeBins)); if (parametersToEstimate.InitialMean) - initialMean = smoothedResult.SmoothedInitialMean; + initialMean = smoothedState.SmoothedInitialMean; if (parametersToEstimate.InitialCovariance) - initialCovariance = smoothedResult.SmoothedInitialCovariance; + initialCovariance = smoothedState.SmoothedInitialCovariance; } } @@ -754,17 +754,23 @@ public ExpectationMaximizationResult ExpectationMaximization( return new ExpectationMaximizationResult(logLikelihood, updatedParameters); } - public OrthogonalizedResult OrthogonalizeMeanAndCovariance(Tensor mean, Tensor covariance) + public OrthogonalizedState OrthogonalizeMeanAndCovariance(Tensor mean, Tensor covariance) { var (_, S, Vt) = linalg.svd(_measurementFunction); var SVt = diag(S).matmul(Vt); - var orthogonalizedMean = matmul(mean, SVt.mT); + Tensor orthogonalizedMean = null; + if (mean is not null) + orthogonalizedMean = matmul(mean, SVt.mT); - var auxilary = matmul(SVt, covariance); - var orthogonalizedCovariance = matmul(auxilary, SVt.mT); + Tensor orthogonalizedCovariance = null; + if (covariance is not null) + { + var auxilary = matmul(SVt, covariance); + orthogonalizedCovariance = matmul(auxilary, SVt.mT); + } - return new OrthogonalizedResult( + return new OrthogonalizedState( orthogonalizedMean: orthogonalizedMean, orthogonalizedCovariance: orthogonalizedCovariance ); @@ -772,12 +778,18 @@ public OrthogonalizedResult OrthogonalizeMeanAndCovariance(Tensor mean, Tensor c public void UpdateParameters(KalmanFilterParameters updatedParameters) { - _transitionMatrix.set_(updatedParameters.TransitionMatrix); - _measurementFunction.set_(updatedParameters.MeasurementFunction); - _processNoiseCovariance.set_(updatedParameters.ProcessNoiseCovariance); - _measurementNoiseCovariance.set_(updatedParameters.MeasurementNoiseCovariance); - _initialMean.set_(updatedParameters.InitialMean); - _initialCovariance.set_(updatedParameters.InitialCovariance); + if (updatedParameters.TransitionMatrix is not null) + _transitionMatrix.set_(updatedParameters.TransitionMatrix); + if (updatedParameters.MeasurementFunction is not null) + _measurementFunction.set_(updatedParameters.MeasurementFunction); + if (updatedParameters.ProcessNoiseCovariance is not null) + _processNoiseCovariance.set_(updatedParameters.ProcessNoiseCovariance); + if (updatedParameters.MeasurementNoiseCovariance is not null) + _measurementNoiseCovariance.set_(updatedParameters.MeasurementNoiseCovariance); + if (updatedParameters.InitialMean is not null) + _initialMean.set_(updatedParameters.InitialMean); + if (updatedParameters.InitialCovariance is not null) + _initialCovariance.set_(updatedParameters.InitialCovariance); } private static Tensor EnsureSymmetric(Tensor M) => 0.5 * (M + M.mT); diff --git a/src/Bonsai.ML.Torch.LDS/Orthogonalize.cs b/src/Bonsai.ML.Torch.LDS/Orthogonalize.cs index 06142419..598bb6f7 100644 --- a/src/Bonsai.ML.Torch.LDS/Orthogonalize.cs +++ b/src/Bonsai.ML.Torch.LDS/Orthogonalize.cs @@ -25,7 +25,7 @@ public class Orthogonalize /// /// /// - public IObservable Process(IObservable source) + public IObservable Process(IObservable source) { return source.Select(input => { @@ -41,7 +41,7 @@ public IObservable Process(IObservable sou /// /// /// - public IObservable Process(IObservable source) + public IObservable Process(IObservable source) { return source.Select(input => { @@ -52,12 +52,44 @@ public IObservable Process(IObservable sou }); } + /// + /// Processes an observable sequence of filtered results, orthogonalizing the mean and covariance estimates. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(input => + { + var kalmanFilter = KalmanFilterModelManager.GetKalmanFilter(ModelName); + var mean = input.Mean; + var covariance = input.Covariance; + return kalmanFilter.OrthogonalizeMeanAndCovariance(mean, covariance); + }); + } + + /// + /// Processes an observable sequence of LDS states, orthogonalizing the mean and covariance estimates. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(input => + { + var kalmanFilter = KalmanFilterModelManager.GetKalmanFilter(ModelName); + var mean = input.Mean; + var covariance = input.Covariance; + return kalmanFilter.OrthogonalizeMeanAndCovariance(mean, covariance); + }); + } + /// /// Processes an observable sequence of mean and covariance tuples, orthogonalizing the mean and covariance estimates. /// /// /// - public IObservable Process(IObservable> source) + public IObservable Process(IObservable> source) { return source.Select(input => { diff --git a/src/Bonsai.ML.Torch.LDS/OrthogonalizedResult.cs b/src/Bonsai.ML.Torch.LDS/OrthogonalizedState.cs similarity index 56% rename from src/Bonsai.ML.Torch.LDS/OrthogonalizedResult.cs rename to src/Bonsai.ML.Torch.LDS/OrthogonalizedState.cs index dcf810a0..e0cddea4 100644 --- a/src/Bonsai.ML.Torch.LDS/OrthogonalizedResult.cs +++ b/src/Bonsai.ML.Torch.LDS/OrthogonalizedState.cs @@ -3,16 +3,16 @@ namespace Bonsai.ML.Torch.LDS; /// -/// Represents the result of orthogonalizing the mean and covariance estimates. +/// Represents the state of an LDS after orthogonalizing the state mean and covariance estimates. /// /// -/// Initializes a new instance of the struct. +/// Initializes a new instance of the struct. /// /// /// -public struct OrthogonalizedResult( +public struct OrthogonalizedState( Tensor orthogonalizedMean, - Tensor orthogonalizedCovariance) + Tensor orthogonalizedCovariance) : ILdsState { /// /// The orthogonalized mean estimate. @@ -23,4 +23,10 @@ public struct OrthogonalizedResult( /// The orthogonalized covariance estimate. /// public Tensor OrthogonalizedCovariance = orthogonalizedCovariance; + + /// + public readonly Tensor Mean => OrthogonalizedMean; + + /// + public readonly Tensor Covariance => OrthogonalizedCovariance; } \ No newline at end of file diff --git a/src/Bonsai.ML.Torch.LDS/Smooth.cs b/src/Bonsai.ML.Torch.LDS/Smooth.cs index 2f57f281..a441f509 100644 --- a/src/Bonsai.ML.Torch.LDS/Smooth.cs +++ b/src/Bonsai.ML.Torch.LDS/Smooth.cs @@ -25,7 +25,7 @@ public class Smooth /// /// /// - public IObservable Process(IObservable source) + public IObservable Process(IObservable source) { return source.Select((input) => { @@ -39,13 +39,13 @@ public IObservable Process(IObservable source) /// /// /// - public IObservable Process(IObservable> source) + public IObservable Process(IObservable> source) { return source.Select((input) => { var kalmanFilter = KalmanFilterModelManager.GetKalmanFilter(ModelName); - var filteredResult = new FilteredResult(input.Item1, input.Item2, input.Item3, input.Item4); - return kalmanFilter.Smooth(filteredResult); + var filteredState = new FilteredState(input.Item1, input.Item2, input.Item3, input.Item4); + return kalmanFilter.Smooth(filteredState); }); } } \ No newline at end of file diff --git a/src/Bonsai.ML.Torch.LDS/SmoothedResult.cs b/src/Bonsai.ML.Torch.LDS/SmoothedState.cs similarity index 64% rename from src/Bonsai.ML.Torch.LDS/SmoothedResult.cs rename to src/Bonsai.ML.Torch.LDS/SmoothedState.cs index 8eb9cbe5..db87a119 100644 --- a/src/Bonsai.ML.Torch.LDS/SmoothedResult.cs +++ b/src/Bonsai.ML.Torch.LDS/SmoothedState.cs @@ -3,13 +3,13 @@ namespace Bonsai.ML.Torch.LDS; /// -/// Represents the result of a Kalman smoother. +/// Represents the state of a Kalman smoother. /// /// /// -public struct SmoothedResult( +public struct SmoothedState( Tensor smoothedMean, - Tensor smoothedCovariance) + Tensor smoothedCovariance) : ILdsState { /// /// The smoothed state after the smoothing step. @@ -20,4 +20,10 @@ public struct SmoothedResult( /// The smoothed covariance after the smoothing step. /// public Tensor SmoothedCovariance = smoothedCovariance; + + /// + public readonly Tensor Mean => SmoothedMean; + + /// + public readonly Tensor Covariance => SmoothedCovariance; } \ No newline at end of file From 8be2f91ed4b887350074c6a588c2a5a65f99c6c7 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 7 Oct 2025 12:58:45 +0100 Subject: [PATCH 50/92] Added property `Bonsai.ML.Torch.LDS.Design` project to ignore repackaging SkiaSharp libraries and avoid the nuget warning --- .../Bonsai.ML.Torch.LDS.Design.csproj | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/Bonsai.ML.Torch.LDS.Design/Bonsai.ML.Torch.LDS.Design.csproj b/src/Bonsai.ML.Torch.LDS.Design/Bonsai.ML.Torch.LDS.Design.csproj index 009aabc3..e943d766 100644 --- a/src/Bonsai.ML.Torch.LDS.Design/Bonsai.ML.Torch.LDS.Design.csproj +++ b/src/Bonsai.ML.Torch.LDS.Design/Bonsai.ML.Torch.LDS.Design.csproj @@ -12,4 +12,8 @@ + + + false + \ No newline at end of file From ae4488949bb4e32497da10a358441b57d7083fe8 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 7 Oct 2025 18:42:32 +0100 Subject: [PATCH 51/92] Refactored `ExpectationMaximization` to emit values on each iteration and added a visualizer to observe log likelihood over time --- .../ExpectationMaximizationVisualizer.cs | 90 ++++++++++++++++ .../ExpectationMaximization.cs | 100 +++++++++++------- .../ExpectationMaximizationResult.cs | 9 +- 3 files changed, 157 insertions(+), 42 deletions(-) create mode 100644 src/Bonsai.ML.Torch.LDS.Design/ExpectationMaximizationVisualizer.cs diff --git a/src/Bonsai.ML.Torch.LDS.Design/ExpectationMaximizationVisualizer.cs b/src/Bonsai.ML.Torch.LDS.Design/ExpectationMaximizationVisualizer.cs new file mode 100644 index 00000000..bf366598 --- /dev/null +++ b/src/Bonsai.ML.Torch.LDS.Design/ExpectationMaximizationVisualizer.cs @@ -0,0 +1,90 @@ +using System; +using System.Reactive; +using System.Linq; +using System.Windows.Forms; +using System.Collections.Generic; + +using Bonsai; +using Bonsai.Design; +using Bonsai.ML.Design; + +using OxyPlot; +using OxyPlot.Series; + +using static TorchSharp.torch; + +[assembly: TypeVisualizer(typeof(Bonsai.ML.Torch.LDS.Design.ExpectationMaximizationVisualizer), + Target = typeof(Bonsai.ML.Torch.LDS.ExpectationMaximizationResult))] + +namespace Bonsai.ML.Torch.LDS.Design; + +/// +/// Provides a visualizer for the state means and covariances from a Kalman filter or smoother. +/// +public class ExpectationMaximizationVisualizer : BufferedVisualizer +{ + private TimeSeriesOxyPlotBase _plot; + private LineSeries _lineSeries; + + /// + /// Gets the underlying plot control. + /// + public TimeSeriesOxyPlotBase Plot => _plot; + + /// + public override void Load(IServiceProvider provider) + { + _plot = new TimeSeriesOxyPlotBase() + { + Dock = DockStyle.Fill, + StartTime = DateTime.Now, + BufferData = true, + ValueLabel = "Log Likelihood" + }; + + _lineSeries = _plot.AddNewLineSeries("Log Likelihood", OxyColors.Blue); + + var visualizerService = (IDialogTypeVisualizerService)provider.GetService(typeof(IDialogTypeVisualizerService)); + visualizerService?.AddControl(_plot); + } + + /// + public override void Show(object value) + { + } + + /// + protected override void Show(DateTime time, object value) + { + if (value is null) return; + + if (value is not ExpectationMaximizationResult result) return; + + var logLikelihood = result.LogLikelihood; + if (logLikelihood is null) return; + + var ll = logLikelihood[-1].to_type(ScalarType.Float64).item(); + + _plot.AddToLineSeries( + lineSeries: _lineSeries, + time: time, + value: ll + ); + } + + /// + protected override void ShowBuffer(IList> values) + { + base.ShowBuffer(values); + if (values.Count > 0) + { + _plot.UpdatePlot(); + } + } + + /// + public override void Unload() + { + if (!_plot.IsDisposed) _plot.Dispose(); + } +} diff --git a/src/Bonsai.ML.Torch.LDS/ExpectationMaximization.cs b/src/Bonsai.ML.Torch.LDS/ExpectationMaximization.cs index 7023cec5..b692559b 100644 --- a/src/Bonsai.ML.Torch.LDS/ExpectationMaximization.cs +++ b/src/Bonsai.ML.Torch.LDS/ExpectationMaximization.cs @@ -2,8 +2,8 @@ using System.ComponentModel; using System.Reactive.Linq; using TorchSharp; -using System.Collections.Generic; using static TorchSharp.torch; +using System.Threading.Tasks; namespace Bonsai.ML.Torch.LDS; @@ -98,58 +98,76 @@ public bool Verbose /// public IObservable Process(IObservable source) { - return source.Select(input => + return source.SelectMany(input => Observable.Create((observer, cancellationToken) => { - var model = KalmanFilterModelManager.GetKalmanFilter(ModelName); - var previousLogLikelihood = double.NegativeInfinity; - var logLikelihood = zeros(new long[] { MaxIterations }, device: input.device); - - var parametersToEstimate = new ParametersToEstimate( - transitionMatrix: EstimateTransitionMatrix, - measurementFunction: EstimateMeasurementFunction, - processNoiseCovariance: EstimateProcessNoiseCovariance, - measurementNoiseCovariance: EstimateMeasurementNoiseCovariance, - initialMean: EstimateInitialMean, - initialCovariance: EstimateInitialCovariance); - - for (int i = 0; i < MaxIterations; i++) + return Task.Run(() => { - var result = model.ExpectationMaximization(input, 1, Tolerance, parametersToEstimate, false); + var model = KalmanFilterModelManager.GetKalmanFilter(ModelName); + var previousLogLikelihood = double.NegativeInfinity; + var logLikelihood = zeros(new long[] { MaxIterations }, device: input.device); + + var parametersToEstimate = new ParametersToEstimate( + transitionMatrix: EstimateTransitionMatrix, + measurementFunction: EstimateMeasurementFunction, + processNoiseCovariance: EstimateProcessNoiseCovariance, + measurementNoiseCovariance: EstimateMeasurementNoiseCovariance, + initialMean: EstimateInitialMean, + initialCovariance: EstimateInitialCovariance); + + for (int i = 0; i < MaxIterations; i++) + { + // Check for cancellation before each iteration + if (cancellationToken.IsCancellationRequested) + { + observer.OnCompleted(); + return System.Reactive.Disposables.Disposable.Empty; + } - var logLikelihoodSum = result.LogLikelihood - .cpu() - .to_type(ScalarType.Float32) - .ReadCpuSingle(0); + var result = model.ExpectationMaximization(input, 1, Tolerance, parametersToEstimate, false); - logLikelihood[i] = logLikelihoodSum; + var logLikelihoodSum = result.LogLikelihood + .cpu() + .to_type(ScalarType.Float32) + .ReadCpuSingle(0); - if (Verbose) - { - Console.WriteLine("Iteration " + (i + 1) + ", Log Likelihood: " + logLikelihoodSum); - if (i == MaxIterations - 1) + logLikelihood[i] = logLikelihoodSum; + + if (Verbose) { - Console.WriteLine("EM reached the maximum number of iterations."); + Console.WriteLine("Iteration " + (i + 1) + ", Log Likelihood: " + logLikelihoodSum); + if (i == MaxIterations - 1) + { + Console.WriteLine("EM reached the maximum number of iterations."); + } } - } - if (logLikelihoodSum - previousLogLikelihood < Tolerance) - { - if (Verbose) + if (logLikelihoodSum - previousLogLikelihood < Tolerance) { - Console.WriteLine("EM converged after " + (i + 1) + " iterations."); + if (Verbose) + { + Console.WriteLine("EM converged after " + (i + 1) + " iterations."); + } + logLikelihood = logLikelihood[torch.TensorIndex.Slice(0, i + 1)]; + break; } - logLikelihood = logLikelihood[torch.TensorIndex.Slice(0, i + 1)]; - break; + previousLogLikelihood = logLikelihoodSum; + model.UpdateParameters(result.Parameters); + + observer.OnNext(new ExpectationMaximizationResult( + logLikelihood: logLikelihood[torch.TensorIndex.Slice(0, i + 1)], + parameters: model.Parameters, + finished: false)); } - previousLogLikelihood = logLikelihoodSum; - model.UpdateParameters(result.Parameters); - } - var expectationMaximizationResult = new ExpectationMaximizationResult( - logLikelihood, - model.Parameters); + observer.OnNext(new ExpectationMaximizationResult( + logLikelihood: logLikelihood, + parameters: model.Parameters, + finished: true)); - return expectationMaximizationResult; - }); + observer.OnCompleted(); + return System.Reactive.Disposables.Disposable.Empty; + }, + cancellationToken); + })); } } \ No newline at end of file diff --git a/src/Bonsai.ML.Torch.LDS/ExpectationMaximizationResult.cs b/src/Bonsai.ML.Torch.LDS/ExpectationMaximizationResult.cs index a51a54b0..75b32751 100644 --- a/src/Bonsai.ML.Torch.LDS/ExpectationMaximizationResult.cs +++ b/src/Bonsai.ML.Torch.LDS/ExpectationMaximizationResult.cs @@ -7,9 +7,11 @@ namespace Bonsai.ML.Torch.LDS; /// /// /// +/// public struct ExpectationMaximizationResult( Tensor logLikelihood, - KalmanFilterParameters parameters) + KalmanFilterParameters parameters, + bool finished = false) { /// /// The log likelihood of the observed data given the model parameters after each iteration. @@ -20,4 +22,9 @@ public struct ExpectationMaximizationResult( /// The final updated Kalman filter parameters after the last expectation-maximization step. /// public KalmanFilterParameters Parameters = parameters; + + /// + /// Indicates whether the EM algorithm has finished. + /// + public bool Finished = finished; } \ No newline at end of file From fd7a98b43d46e619859a2d6d8de345d3f8241c29 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 8 Oct 2025 09:35:53 +0100 Subject: [PATCH 52/92] Changed `ExpectationMaximization` operator to a type of `Combinator` element to better reflect its operation --- src/Bonsai.ML.Torch.LDS/ExpectationMaximization.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Bonsai.ML.Torch.LDS/ExpectationMaximization.cs b/src/Bonsai.ML.Torch.LDS/ExpectationMaximization.cs index b692559b..09cc1707 100644 --- a/src/Bonsai.ML.Torch.LDS/ExpectationMaximization.cs +++ b/src/Bonsai.ML.Torch.LDS/ExpectationMaximization.cs @@ -12,7 +12,7 @@ namespace Bonsai.ML.Torch.LDS; /// [Combinator] [Description("Learn the parameters of a kalman filter using the batch EM update algorithm.")] -[WorkflowElementCategory(ElementCategory.Transform)] +[WorkflowElementCategory(ElementCategory.Combinator)] public class ExpectationMaximization { /// From 69fc3831baa53b113e44e3c1fcddaeb7d5a7e2e0 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 10 Oct 2025 16:03:37 +0100 Subject: [PATCH 53/92] Updated name of `Bonsai.ML.Torch.LDS` package to `Bonsai.ML.Lds.Torch` in line with new package and namespace structure --- Bonsai.ML.sln | 6 +++--- .../Bonsai.ML.Lds.Torch.Design.csproj} | 4 ++-- .../ExpectationMaximizationVisualizer.cs | 6 +++--- .../Properties/AssemblyInfo.cs | 2 +- .../Properties/launchSettings.json | 0 .../StateVisualizer.cs | 20 +++++++++---------- .../Bonsai.ML.Lds.Torch.csproj} | 0 .../CreateKalmanFilter.cs | 2 +- .../CreateKalmanFilterParameters.cs | 2 +- .../CreateLdsState.cs | 2 +- .../ExpectationMaximization.cs | 2 +- .../ExpectationMaximizationResult.cs | 2 +- .../Filter.cs | 2 +- .../FilteredState.cs | 2 +- .../ILdsState.cs | 2 +- .../KalmanFilter.cs | 2 +- .../KalmanFilterModelManager.cs | 2 +- .../KalmanFilterNameConverter.cs | 2 +- .../KalmanFilterParameters.cs | 2 +- .../LdsState.cs | 2 +- .../Orthogonalize.cs | 2 +- .../OrthogonalizedState.cs | 2 +- .../ParametersToEstimate.cs | 2 +- .../Properties/launchSettings.json | 0 .../Smooth.cs | 2 +- .../SmoothedState.cs | 2 +- .../UpdateParameters.cs | 2 +- .../Bonsai.ML.Lds.Torch.Tests.csproj} | 2 +- .../NeuralLatentsTest.bonsai | 2 +- .../NeuralLatentsTest.cs | 2 +- .../bootstrap_test_environment.py | 0 .../estimate_neural_latents.py | 0 32 files changed, 41 insertions(+), 41 deletions(-) rename src/{Bonsai.ML.Torch.LDS.Design/Bonsai.ML.Torch.LDS.Design.csproj => Bonsai.ML.Lds.Torch.Design/Bonsai.ML.Lds.Torch.Design.csproj} (82%) rename src/{Bonsai.ML.Torch.LDS.Design => Bonsai.ML.Lds.Torch.Design}/ExpectationMaximizationVisualizer.cs (92%) rename src/{Bonsai.ML.Torch.LDS.Design => Bonsai.ML.Lds.Torch.Design}/Properties/AssemblyInfo.cs (77%) rename src/{Bonsai.ML.Torch.LDS.Design => Bonsai.ML.Lds.Torch.Design}/Properties/launchSettings.json (100%) rename src/{Bonsai.ML.Torch.LDS.Design => Bonsai.ML.Lds.Torch.Design}/StateVisualizer.cs (91%) rename src/{Bonsai.ML.Torch.LDS/Bonsai.ML.Torch.LDS.csproj => Bonsai.ML.Lds.Torch/Bonsai.ML.Lds.Torch.csproj} (100%) rename src/{Bonsai.ML.Torch.LDS => Bonsai.ML.Lds.Torch}/CreateKalmanFilter.cs (99%) rename src/{Bonsai.ML.Torch.LDS => Bonsai.ML.Lds.Torch}/CreateKalmanFilterParameters.cs (99%) rename src/{Bonsai.ML.Torch.LDS => Bonsai.ML.Lds.Torch}/CreateLdsState.cs (99%) rename src/{Bonsai.ML.Torch.LDS => Bonsai.ML.Lds.Torch}/ExpectationMaximization.cs (99%) rename src/{Bonsai.ML.Torch.LDS => Bonsai.ML.Lds.Torch}/ExpectationMaximizationResult.cs (96%) rename src/{Bonsai.ML.Torch.LDS => Bonsai.ML.Lds.Torch}/Filter.cs (97%) rename src/{Bonsai.ML.Torch.LDS => Bonsai.ML.Lds.Torch}/FilteredState.cs (97%) rename src/{Bonsai.ML.Torch.LDS => Bonsai.ML.Lds.Torch}/ILdsState.cs (91%) rename src/{Bonsai.ML.Torch.LDS => Bonsai.ML.Lds.Torch}/KalmanFilter.cs (99%) rename src/{Bonsai.ML.Torch.LDS => Bonsai.ML.Lds.Torch}/KalmanFilterModelManager.cs (99%) rename src/{Bonsai.ML.Torch.LDS => Bonsai.ML.Lds.Torch}/KalmanFilterNameConverter.cs (98%) rename src/{Bonsai.ML.Torch.LDS => Bonsai.ML.Lds.Torch}/KalmanFilterParameters.cs (98%) rename src/{Bonsai.ML.Torch.LDS => Bonsai.ML.Lds.Torch}/LdsState.cs (92%) rename src/{Bonsai.ML.Torch.LDS => Bonsai.ML.Lds.Torch}/Orthogonalize.cs (99%) rename src/{Bonsai.ML.Torch.LDS => Bonsai.ML.Lds.Torch}/OrthogonalizedState.cs (96%) rename src/{Bonsai.ML.Torch.LDS => Bonsai.ML.Lds.Torch}/ParametersToEstimate.cs (98%) rename src/{Bonsai.ML.Torch.LDS => Bonsai.ML.Lds.Torch}/Properties/launchSettings.json (100%) rename src/{Bonsai.ML.Torch.LDS => Bonsai.ML.Lds.Torch}/Smooth.cs (98%) rename src/{Bonsai.ML.Torch.LDS => Bonsai.ML.Lds.Torch}/SmoothedState.cs (95%) rename src/{Bonsai.ML.Torch.LDS => Bonsai.ML.Lds.Torch}/UpdateParameters.cs (97%) rename tests/{Bonsai.ML.Torch.LDS.Tests/Bonsai.ML.Torch.LDS.Tests.csproj => Bonsai.ML.Lds.Torch.Tests/Bonsai.ML.Lds.Torch.Tests.csproj} (92%) rename tests/{Bonsai.ML.Torch.LDS.Tests => Bonsai.ML.Lds.Torch.Tests}/NeuralLatentsTest.bonsai (99%) rename tests/{Bonsai.ML.Torch.LDS.Tests => Bonsai.ML.Lds.Torch.Tests}/NeuralLatentsTest.cs (99%) rename tests/{Bonsai.ML.Torch.LDS.Tests => Bonsai.ML.Lds.Torch.Tests}/bootstrap_test_environment.py (100%) rename tests/{Bonsai.ML.Torch.LDS.Tests => Bonsai.ML.Lds.Torch.Tests}/estimate_neural_latents.py (100%) diff --git a/Bonsai.ML.sln b/Bonsai.ML.sln index d096b62f..fc4e963f 100644 --- a/Bonsai.ML.sln +++ b/Bonsai.ML.sln @@ -40,11 +40,11 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "build", "build", "{DEE5DD87 EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Bonsai.ML.Tests.Utilities", "tests\Bonsai.ML.Tests.Utilities\Bonsai.ML.Tests.Utilities.csproj", "{DB1090D3-38DD-404B-96DF-66BF3C39E508}" EndProject -Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Bonsai.ML.Torch.LDS", "src\Bonsai.ML.Torch.LDS\Bonsai.ML.Torch.LDS.csproj", "{41D4BEC7-AB1F-41E4-95FE-4DB23970FF4B}" +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Bonsai.ML.Lds.Torch", "src\Bonsai.ML.Lds.Torch\Bonsai.ML.Lds.Torch.csproj", "{41D4BEC7-AB1F-41E4-95FE-4DB23970FF4B}" EndProject -Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Bonsai.ML.Torch.LDS.Tests", "tests\Bonsai.ML.Torch.LDS.Tests\Bonsai.ML.Torch.LDS.Tests.csproj", "{0B258929-0B07-4CE7-BE8D-A86BBC46AAD4}" +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Bonsai.ML.Lds.Torch.Tests", "tests\Bonsai.ML.Lds.Torch.Tests\Bonsai.ML.Lds.Torch.Tests.csproj", "{0B258929-0B07-4CE7-BE8D-A86BBC46AAD4}" EndProject -Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Bonsai.ML.Torch.LDS.Design", "src\Bonsai.ML.Torch.LDS.Design\Bonsai.ML.Torch.LDS.Design.csproj", "{1F52DECD-1B2C-4F6C-996C-14C715283B80}" +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Bonsai.ML.Lds.Torch.Design", "src\Bonsai.ML.Lds.Torch.Design\Bonsai.ML.Lds.Torch.Design.csproj", "{1F52DECD-1B2C-4F6C-996C-14C715283B80}" EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution diff --git a/src/Bonsai.ML.Torch.LDS.Design/Bonsai.ML.Torch.LDS.Design.csproj b/src/Bonsai.ML.Lds.Torch.Design/Bonsai.ML.Lds.Torch.Design.csproj similarity index 82% rename from src/Bonsai.ML.Torch.LDS.Design/Bonsai.ML.Torch.LDS.Design.csproj rename to src/Bonsai.ML.Lds.Torch.Design/Bonsai.ML.Lds.Torch.Design.csproj index e943d766..2185e90e 100644 --- a/src/Bonsai.ML.Torch.LDS.Design/Bonsai.ML.Torch.LDS.Design.csproj +++ b/src/Bonsai.ML.Lds.Torch.Design/Bonsai.ML.Lds.Torch.Design.csproj @@ -1,6 +1,6 @@ - Visualizers for the Bonsai.ML.Torch.LDS library. + Visualizers for the Bonsai.ML.Lds.Torch library. $(PackageTags) Torch LDS Design net472 true @@ -9,7 +9,7 @@ - + diff --git a/src/Bonsai.ML.Torch.LDS.Design/ExpectationMaximizationVisualizer.cs b/src/Bonsai.ML.Lds.Torch.Design/ExpectationMaximizationVisualizer.cs similarity index 92% rename from src/Bonsai.ML.Torch.LDS.Design/ExpectationMaximizationVisualizer.cs rename to src/Bonsai.ML.Lds.Torch.Design/ExpectationMaximizationVisualizer.cs index bf366598..0648685e 100644 --- a/src/Bonsai.ML.Torch.LDS.Design/ExpectationMaximizationVisualizer.cs +++ b/src/Bonsai.ML.Lds.Torch.Design/ExpectationMaximizationVisualizer.cs @@ -13,10 +13,10 @@ using static TorchSharp.torch; -[assembly: TypeVisualizer(typeof(Bonsai.ML.Torch.LDS.Design.ExpectationMaximizationVisualizer), - Target = typeof(Bonsai.ML.Torch.LDS.ExpectationMaximizationResult))] +[assembly: TypeVisualizer(typeof(Bonsai.ML.Lds.Torch.Design.ExpectationMaximizationVisualizer), + Target = typeof(Bonsai.ML.Lds.Torch.ExpectationMaximizationResult))] -namespace Bonsai.ML.Torch.LDS.Design; +namespace Bonsai.ML.Lds.Torch.Design; /// /// Provides a visualizer for the state means and covariances from a Kalman filter or smoother. diff --git a/src/Bonsai.ML.Torch.LDS.Design/Properties/AssemblyInfo.cs b/src/Bonsai.ML.Lds.Torch.Design/Properties/AssemblyInfo.cs similarity index 77% rename from src/Bonsai.ML.Torch.LDS.Design/Properties/AssemblyInfo.cs rename to src/Bonsai.ML.Lds.Torch.Design/Properties/AssemblyInfo.cs index b28e8258..7a732600 100644 --- a/src/Bonsai.ML.Torch.LDS.Design/Properties/AssemblyInfo.cs +++ b/src/Bonsai.ML.Lds.Torch.Design/Properties/AssemblyInfo.cs @@ -3,4 +3,4 @@ // General Information about an assembly is controlled through the following // set of attributes. Change these attribute values to modify the information // associated with an assembly. -[assembly: XmlNamespacePrefix("clr-namespace:Bonsai.ML.Torch.LDS.Design", null)] +[assembly: XmlNamespacePrefix("clr-namespace:Bonsai.ML.Lds.Torch.Design", null)] diff --git a/src/Bonsai.ML.Torch.LDS.Design/Properties/launchSettings.json b/src/Bonsai.ML.Lds.Torch.Design/Properties/launchSettings.json similarity index 100% rename from src/Bonsai.ML.Torch.LDS.Design/Properties/launchSettings.json rename to src/Bonsai.ML.Lds.Torch.Design/Properties/launchSettings.json diff --git a/src/Bonsai.ML.Torch.LDS.Design/StateVisualizer.cs b/src/Bonsai.ML.Lds.Torch.Design/StateVisualizer.cs similarity index 91% rename from src/Bonsai.ML.Torch.LDS.Design/StateVisualizer.cs rename to src/Bonsai.ML.Lds.Torch.Design/StateVisualizer.cs index 46cda2d0..3f1837d9 100644 --- a/src/Bonsai.ML.Torch.LDS.Design/StateVisualizer.cs +++ b/src/Bonsai.ML.Lds.Torch.Design/StateVisualizer.cs @@ -13,16 +13,16 @@ using static TorchSharp.torch; -[assembly: TypeVisualizer(typeof(Bonsai.ML.Torch.LDS.Design.StateVisualizer), - Target = typeof(Bonsai.ML.Torch.LDS.FilteredState))] -[assembly: TypeVisualizer(typeof(Bonsai.ML.Torch.LDS.Design.StateVisualizer), - Target = typeof(Bonsai.ML.Torch.LDS.SmoothedState))] -[assembly: TypeVisualizer(typeof(Bonsai.ML.Torch.LDS.Design.StateVisualizer), - Target = typeof(Bonsai.ML.Torch.LDS.OrthogonalizedState))] -[assembly: TypeVisualizer(typeof(Bonsai.ML.Torch.LDS.Design.StateVisualizer), - Target = typeof(Bonsai.ML.Torch.LDS.LdsState))] - -namespace Bonsai.ML.Torch.LDS.Design; +[assembly: TypeVisualizer(typeof(Bonsai.ML.Lds.Torch.Design.StateVisualizer), + Target = typeof(Bonsai.ML.Lds.Torch.FilteredState))] +[assembly: TypeVisualizer(typeof(Bonsai.ML.Lds.Torch.Design.StateVisualizer), + Target = typeof(Bonsai.ML.Lds.Torch.SmoothedState))] +[assembly: TypeVisualizer(typeof(Bonsai.ML.Lds.Torch.Design.StateVisualizer), + Target = typeof(Bonsai.ML.Lds.Torch.OrthogonalizedState))] +[assembly: TypeVisualizer(typeof(Bonsai.ML.Lds.Torch.Design.StateVisualizer), + Target = typeof(Bonsai.ML.Lds.Torch.LdsState))] + +namespace Bonsai.ML.Lds.Torch.Design; /// /// Provides a visualizer for the state means and covariances from a Kalman filter or smoother. diff --git a/src/Bonsai.ML.Torch.LDS/Bonsai.ML.Torch.LDS.csproj b/src/Bonsai.ML.Lds.Torch/Bonsai.ML.Lds.Torch.csproj similarity index 100% rename from src/Bonsai.ML.Torch.LDS/Bonsai.ML.Torch.LDS.csproj rename to src/Bonsai.ML.Lds.Torch/Bonsai.ML.Lds.Torch.csproj diff --git a/src/Bonsai.ML.Torch.LDS/CreateKalmanFilter.cs b/src/Bonsai.ML.Lds.Torch/CreateKalmanFilter.cs similarity index 99% rename from src/Bonsai.ML.Torch.LDS/CreateKalmanFilter.cs rename to src/Bonsai.ML.Lds.Torch/CreateKalmanFilter.cs index a3ba1ba2..961bb0e9 100644 --- a/src/Bonsai.ML.Torch.LDS/CreateKalmanFilter.cs +++ b/src/Bonsai.ML.Lds.Torch/CreateKalmanFilter.cs @@ -4,7 +4,7 @@ using System.Xml.Serialization; using static TorchSharp.torch; -namespace Bonsai.ML.Torch.LDS; +namespace Bonsai.ML.Lds.Torch; /// /// Creates a Kalman filter model. diff --git a/src/Bonsai.ML.Torch.LDS/CreateKalmanFilterParameters.cs b/src/Bonsai.ML.Lds.Torch/CreateKalmanFilterParameters.cs similarity index 99% rename from src/Bonsai.ML.Torch.LDS/CreateKalmanFilterParameters.cs rename to src/Bonsai.ML.Lds.Torch/CreateKalmanFilterParameters.cs index abd77ee8..55ff3ec0 100644 --- a/src/Bonsai.ML.Torch.LDS/CreateKalmanFilterParameters.cs +++ b/src/Bonsai.ML.Lds.Torch/CreateKalmanFilterParameters.cs @@ -5,7 +5,7 @@ using System.Xml.Serialization; using static TorchSharp.torch; -namespace Bonsai.ML.Torch.LDS; +namespace Bonsai.ML.Lds.Torch; /// /// Initializes the parameters for a new Kalman filter model. diff --git a/src/Bonsai.ML.Torch.LDS/CreateLdsState.cs b/src/Bonsai.ML.Lds.Torch/CreateLdsState.cs similarity index 99% rename from src/Bonsai.ML.Torch.LDS/CreateLdsState.cs rename to src/Bonsai.ML.Lds.Torch/CreateLdsState.cs index af7367a4..24444408 100644 --- a/src/Bonsai.ML.Torch.LDS/CreateLdsState.cs +++ b/src/Bonsai.ML.Lds.Torch/CreateLdsState.cs @@ -4,7 +4,7 @@ using System.Xml.Serialization; using static TorchSharp.torch; -namespace Bonsai.ML.Torch.LDS; +namespace Bonsai.ML.Lds.Torch; /// /// Creates a generic state object for a linear gaussian dynamical system. diff --git a/src/Bonsai.ML.Torch.LDS/ExpectationMaximization.cs b/src/Bonsai.ML.Lds.Torch/ExpectationMaximization.cs similarity index 99% rename from src/Bonsai.ML.Torch.LDS/ExpectationMaximization.cs rename to src/Bonsai.ML.Lds.Torch/ExpectationMaximization.cs index 09cc1707..418e9654 100644 --- a/src/Bonsai.ML.Torch.LDS/ExpectationMaximization.cs +++ b/src/Bonsai.ML.Lds.Torch/ExpectationMaximization.cs @@ -5,7 +5,7 @@ using static TorchSharp.torch; using System.Threading.Tasks; -namespace Bonsai.ML.Torch.LDS; +namespace Bonsai.ML.Lds.Torch; /// /// Learn the parameters of a kalman filter using the batch EM update algorithm. diff --git a/src/Bonsai.ML.Torch.LDS/ExpectationMaximizationResult.cs b/src/Bonsai.ML.Lds.Torch/ExpectationMaximizationResult.cs similarity index 96% rename from src/Bonsai.ML.Torch.LDS/ExpectationMaximizationResult.cs rename to src/Bonsai.ML.Lds.Torch/ExpectationMaximizationResult.cs index 75b32751..67dd8f22 100644 --- a/src/Bonsai.ML.Torch.LDS/ExpectationMaximizationResult.cs +++ b/src/Bonsai.ML.Lds.Torch/ExpectationMaximizationResult.cs @@ -1,6 +1,6 @@ using static TorchSharp.torch; -namespace Bonsai.ML.Torch.LDS; +namespace Bonsai.ML.Lds.Torch; /// /// Represents the result of an expectation-maximization step for a Kalman filter model. diff --git a/src/Bonsai.ML.Torch.LDS/Filter.cs b/src/Bonsai.ML.Lds.Torch/Filter.cs similarity index 97% rename from src/Bonsai.ML.Torch.LDS/Filter.cs rename to src/Bonsai.ML.Lds.Torch/Filter.cs index 853d4026..4fcd7058 100644 --- a/src/Bonsai.ML.Torch.LDS/Filter.cs +++ b/src/Bonsai.ML.Lds.Torch/Filter.cs @@ -3,7 +3,7 @@ using System.Reactive.Linq; using static TorchSharp.torch; -namespace Bonsai.ML.Torch.LDS; +namespace Bonsai.ML.Lds.Torch; /// /// Applies a Kalman filter to the input tensor sequence. diff --git a/src/Bonsai.ML.Torch.LDS/FilteredState.cs b/src/Bonsai.ML.Lds.Torch/FilteredState.cs similarity index 97% rename from src/Bonsai.ML.Torch.LDS/FilteredState.cs rename to src/Bonsai.ML.Lds.Torch/FilteredState.cs index e2b9a8af..c2ea7bc8 100644 --- a/src/Bonsai.ML.Torch.LDS/FilteredState.cs +++ b/src/Bonsai.ML.Lds.Torch/FilteredState.cs @@ -1,6 +1,6 @@ using static TorchSharp.torch; -namespace Bonsai.ML.Torch.LDS; +namespace Bonsai.ML.Lds.Torch; /// /// Represents the state of a Kalman filter. diff --git a/src/Bonsai.ML.Torch.LDS/ILdsState.cs b/src/Bonsai.ML.Lds.Torch/ILdsState.cs similarity index 91% rename from src/Bonsai.ML.Torch.LDS/ILdsState.cs rename to src/Bonsai.ML.Lds.Torch/ILdsState.cs index ddeda06a..6527cfe5 100644 --- a/src/Bonsai.ML.Torch.LDS/ILdsState.cs +++ b/src/Bonsai.ML.Lds.Torch/ILdsState.cs @@ -1,6 +1,6 @@ using static TorchSharp.torch; -namespace Bonsai.ML.Torch.LDS; +namespace Bonsai.ML.Lds.Torch; /// /// Represents the state of a linear gaussian dynamical system. diff --git a/src/Bonsai.ML.Torch.LDS/KalmanFilter.cs b/src/Bonsai.ML.Lds.Torch/KalmanFilter.cs similarity index 99% rename from src/Bonsai.ML.Torch.LDS/KalmanFilter.cs rename to src/Bonsai.ML.Lds.Torch/KalmanFilter.cs index 7b9c4211..8bfdab18 100644 --- a/src/Bonsai.ML.Torch.LDS/KalmanFilter.cs +++ b/src/Bonsai.ML.Lds.Torch/KalmanFilter.cs @@ -2,7 +2,7 @@ using System.Collections.Generic; using static TorchSharp.torch; -namespace Bonsai.ML.Torch.LDS; +namespace Bonsai.ML.Lds.Torch; internal class KalmanFilter : nn.Module { diff --git a/src/Bonsai.ML.Torch.LDS/KalmanFilterModelManager.cs b/src/Bonsai.ML.Lds.Torch/KalmanFilterModelManager.cs similarity index 99% rename from src/Bonsai.ML.Torch.LDS/KalmanFilterModelManager.cs rename to src/Bonsai.ML.Lds.Torch/KalmanFilterModelManager.cs index 00bc072d..1196aa83 100644 --- a/src/Bonsai.ML.Torch.LDS/KalmanFilterModelManager.cs +++ b/src/Bonsai.ML.Lds.Torch/KalmanFilterModelManager.cs @@ -4,7 +4,7 @@ using System.Runtime.CompilerServices; using System.Collections.Generic; using static TorchSharp.torch; -using Bonsai.ML.Torch.LDS; +using Bonsai.ML.Lds.Torch; // // Manages instances of the Kalman Filter model with a thread-safe locking mechanism for reading state tensors and writing parameters. diff --git a/src/Bonsai.ML.Torch.LDS/KalmanFilterNameConverter.cs b/src/Bonsai.ML.Lds.Torch/KalmanFilterNameConverter.cs similarity index 98% rename from src/Bonsai.ML.Torch.LDS/KalmanFilterNameConverter.cs rename to src/Bonsai.ML.Lds.Torch/KalmanFilterNameConverter.cs index 0dee35d0..cfb881ac 100644 --- a/src/Bonsai.ML.Torch.LDS/KalmanFilterNameConverter.cs +++ b/src/Bonsai.ML.Lds.Torch/KalmanFilterNameConverter.cs @@ -2,7 +2,7 @@ using System.Linq; using System.ComponentModel; -namespace Bonsai.ML.Torch.LDS; +namespace Bonsai.ML.Lds.Torch; /// /// Provides a type converter to select the name of an existing Kalman filter model in the workflow. diff --git a/src/Bonsai.ML.Torch.LDS/KalmanFilterParameters.cs b/src/Bonsai.ML.Lds.Torch/KalmanFilterParameters.cs similarity index 98% rename from src/Bonsai.ML.Torch.LDS/KalmanFilterParameters.cs rename to src/Bonsai.ML.Lds.Torch/KalmanFilterParameters.cs index 7b56a163..1e5f8e82 100644 --- a/src/Bonsai.ML.Torch.LDS/KalmanFilterParameters.cs +++ b/src/Bonsai.ML.Lds.Torch/KalmanFilterParameters.cs @@ -1,6 +1,6 @@ using static TorchSharp.torch; -namespace Bonsai.ML.Torch.LDS; +namespace Bonsai.ML.Lds.Torch; /// /// Represents the parameters of a Kalman filter model. diff --git a/src/Bonsai.ML.Torch.LDS/LdsState.cs b/src/Bonsai.ML.Lds.Torch/LdsState.cs similarity index 92% rename from src/Bonsai.ML.Torch.LDS/LdsState.cs rename to src/Bonsai.ML.Lds.Torch/LdsState.cs index 8c778a12..99dd4a08 100644 --- a/src/Bonsai.ML.Torch.LDS/LdsState.cs +++ b/src/Bonsai.ML.Lds.Torch/LdsState.cs @@ -1,6 +1,6 @@ using static TorchSharp.torch; -namespace Bonsai.ML.Torch.LDS; +namespace Bonsai.ML.Lds.Torch; /// /// Represents the state of a linear gaussian dynamical system. diff --git a/src/Bonsai.ML.Torch.LDS/Orthogonalize.cs b/src/Bonsai.ML.Lds.Torch/Orthogonalize.cs similarity index 99% rename from src/Bonsai.ML.Torch.LDS/Orthogonalize.cs rename to src/Bonsai.ML.Lds.Torch/Orthogonalize.cs index 598bb6f7..6fa67e53 100644 --- a/src/Bonsai.ML.Torch.LDS/Orthogonalize.cs +++ b/src/Bonsai.ML.Lds.Torch/Orthogonalize.cs @@ -3,7 +3,7 @@ using System.Reactive.Linq; using static TorchSharp.torch; -namespace Bonsai.ML.Torch.LDS; +namespace Bonsai.ML.Lds.Torch; /// /// Orthogonalizes the state and covariance estimates from a Kalman filter or smoother. diff --git a/src/Bonsai.ML.Torch.LDS/OrthogonalizedState.cs b/src/Bonsai.ML.Lds.Torch/OrthogonalizedState.cs similarity index 96% rename from src/Bonsai.ML.Torch.LDS/OrthogonalizedState.cs rename to src/Bonsai.ML.Lds.Torch/OrthogonalizedState.cs index e0cddea4..f7dee6c1 100644 --- a/src/Bonsai.ML.Torch.LDS/OrthogonalizedState.cs +++ b/src/Bonsai.ML.Lds.Torch/OrthogonalizedState.cs @@ -1,6 +1,6 @@ using static TorchSharp.torch; -namespace Bonsai.ML.Torch.LDS; +namespace Bonsai.ML.Lds.Torch; /// /// Represents the state of an LDS after orthogonalizing the state mean and covariance estimates. diff --git a/src/Bonsai.ML.Torch.LDS/ParametersToEstimate.cs b/src/Bonsai.ML.Lds.Torch/ParametersToEstimate.cs similarity index 98% rename from src/Bonsai.ML.Torch.LDS/ParametersToEstimate.cs rename to src/Bonsai.ML.Lds.Torch/ParametersToEstimate.cs index d0819541..039c4018 100644 --- a/src/Bonsai.ML.Torch.LDS/ParametersToEstimate.cs +++ b/src/Bonsai.ML.Lds.Torch/ParametersToEstimate.cs @@ -1,4 +1,4 @@ -namespace Bonsai.ML.Torch.LDS; +namespace Bonsai.ML.Lds.Torch; /// /// Represents the parameters to estimate for a Kalman filter model. diff --git a/src/Bonsai.ML.Torch.LDS/Properties/launchSettings.json b/src/Bonsai.ML.Lds.Torch/Properties/launchSettings.json similarity index 100% rename from src/Bonsai.ML.Torch.LDS/Properties/launchSettings.json rename to src/Bonsai.ML.Lds.Torch/Properties/launchSettings.json diff --git a/src/Bonsai.ML.Torch.LDS/Smooth.cs b/src/Bonsai.ML.Lds.Torch/Smooth.cs similarity index 98% rename from src/Bonsai.ML.Torch.LDS/Smooth.cs rename to src/Bonsai.ML.Lds.Torch/Smooth.cs index a441f509..c37de3e6 100644 --- a/src/Bonsai.ML.Torch.LDS/Smooth.cs +++ b/src/Bonsai.ML.Lds.Torch/Smooth.cs @@ -3,7 +3,7 @@ using System.Reactive.Linq; using static TorchSharp.torch; -namespace Bonsai.ML.Torch.LDS; +namespace Bonsai.ML.Lds.Torch; /// /// Applies a Kalman smoother to the input filtered result sequence. diff --git a/src/Bonsai.ML.Torch.LDS/SmoothedState.cs b/src/Bonsai.ML.Lds.Torch/SmoothedState.cs similarity index 95% rename from src/Bonsai.ML.Torch.LDS/SmoothedState.cs rename to src/Bonsai.ML.Lds.Torch/SmoothedState.cs index db87a119..70f5c24c 100644 --- a/src/Bonsai.ML.Torch.LDS/SmoothedState.cs +++ b/src/Bonsai.ML.Lds.Torch/SmoothedState.cs @@ -1,6 +1,6 @@ using static TorchSharp.torch; -namespace Bonsai.ML.Torch.LDS; +namespace Bonsai.ML.Lds.Torch; /// /// Represents the state of a Kalman smoother. diff --git a/src/Bonsai.ML.Torch.LDS/UpdateParameters.cs b/src/Bonsai.ML.Lds.Torch/UpdateParameters.cs similarity index 97% rename from src/Bonsai.ML.Torch.LDS/UpdateParameters.cs rename to src/Bonsai.ML.Lds.Torch/UpdateParameters.cs index 3852a8c5..dd21fecb 100644 --- a/src/Bonsai.ML.Torch.LDS/UpdateParameters.cs +++ b/src/Bonsai.ML.Lds.Torch/UpdateParameters.cs @@ -2,7 +2,7 @@ using System.ComponentModel; using System.Reactive.Linq; -namespace Bonsai.ML.Torch.LDS; +namespace Bonsai.ML.Lds.Torch; /// /// Updates the parameters of a Kalman filter model instance using the provided Kalman filter parameters. diff --git a/tests/Bonsai.ML.Torch.LDS.Tests/Bonsai.ML.Torch.LDS.Tests.csproj b/tests/Bonsai.ML.Lds.Torch.Tests/Bonsai.ML.Lds.Torch.Tests.csproj similarity index 92% rename from tests/Bonsai.ML.Torch.LDS.Tests/Bonsai.ML.Torch.LDS.Tests.csproj rename to tests/Bonsai.ML.Lds.Torch.Tests/Bonsai.ML.Lds.Torch.Tests.csproj index 14ddd981..5c6cf23a 100644 --- a/tests/Bonsai.ML.Torch.LDS.Tests/Bonsai.ML.Torch.LDS.Tests.csproj +++ b/tests/Bonsai.ML.Lds.Torch.Tests/Bonsai.ML.Lds.Torch.Tests.csproj @@ -27,7 +27,7 @@ - + diff --git a/tests/Bonsai.ML.Torch.LDS.Tests/NeuralLatentsTest.bonsai b/tests/Bonsai.ML.Lds.Torch.Tests/NeuralLatentsTest.bonsai similarity index 99% rename from tests/Bonsai.ML.Torch.LDS.Tests/NeuralLatentsTest.bonsai rename to tests/Bonsai.ML.Lds.Torch.Tests/NeuralLatentsTest.bonsai index 98b8375e..4711470a 100644 --- a/tests/Bonsai.ML.Torch.LDS.Tests/NeuralLatentsTest.bonsai +++ b/tests/Bonsai.ML.Lds.Torch.Tests/NeuralLatentsTest.bonsai @@ -3,7 +3,7 @@ xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xmlns:p1="clr-namespace:Bonsai.ML.Torch;assembly=Bonsai.ML.Torch" xmlns:rx="clr-namespace:Bonsai.Reactive;assembly=Bonsai.Core" - xmlns:p2="clr-namespace:Bonsai.ML.Torch.LDS;assembly=Bonsai.ML.Torch.LDS" + xmlns:p2="clr-namespace:Bonsai.ML.Lds.Torch;assembly=Bonsai.ML.Lds.Torch" xmlns="https://bonsai-rx.org/2018/workflow"> diff --git a/tests/Bonsai.ML.Torch.LDS.Tests/NeuralLatentsTest.cs b/tests/Bonsai.ML.Lds.Torch.Tests/NeuralLatentsTest.cs similarity index 99% rename from tests/Bonsai.ML.Torch.LDS.Tests/NeuralLatentsTest.cs rename to tests/Bonsai.ML.Lds.Torch.Tests/NeuralLatentsTest.cs index 70e2165e..132564a9 100644 --- a/tests/Bonsai.ML.Torch.LDS.Tests/NeuralLatentsTest.cs +++ b/tests/Bonsai.ML.Lds.Torch.Tests/NeuralLatentsTest.cs @@ -10,7 +10,7 @@ using static TorchSharp.torch; using TorchSharp; -namespace Bonsai.ML.Torch.LDS.Tests; +namespace Bonsai.ML.Lds.Torch.Tests; /// /// Tests for the neural latents workflow. diff --git a/tests/Bonsai.ML.Torch.LDS.Tests/bootstrap_test_environment.py b/tests/Bonsai.ML.Lds.Torch.Tests/bootstrap_test_environment.py similarity index 100% rename from tests/Bonsai.ML.Torch.LDS.Tests/bootstrap_test_environment.py rename to tests/Bonsai.ML.Lds.Torch.Tests/bootstrap_test_environment.py diff --git a/tests/Bonsai.ML.Torch.LDS.Tests/estimate_neural_latents.py b/tests/Bonsai.ML.Lds.Torch.Tests/estimate_neural_latents.py similarity index 100% rename from tests/Bonsai.ML.Torch.LDS.Tests/estimate_neural_latents.py rename to tests/Bonsai.ML.Lds.Torch.Tests/estimate_neural_latents.py From 513feb7606933b71d660cfea4a3e21c3321d32a9 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 10 Oct 2025 16:07:18 +0100 Subject: [PATCH 54/92] Added `Bonsai.ML.Torch` using statements to classes that depend on this package --- src/Bonsai.ML.Lds.Torch/CreateKalmanFilter.cs | 1 + src/Bonsai.ML.Lds.Torch/CreateKalmanFilterParameters.cs | 1 + src/Bonsai.ML.Lds.Torch/CreateLdsState.cs | 1 + 3 files changed, 3 insertions(+) diff --git a/src/Bonsai.ML.Lds.Torch/CreateKalmanFilter.cs b/src/Bonsai.ML.Lds.Torch/CreateKalmanFilter.cs index 961bb0e9..63a150e4 100644 --- a/src/Bonsai.ML.Lds.Torch/CreateKalmanFilter.cs +++ b/src/Bonsai.ML.Lds.Torch/CreateKalmanFilter.cs @@ -3,6 +3,7 @@ using System.Reactive.Linq; using System.Xml.Serialization; using static TorchSharp.torch; +using Bonsai.ML.Torch; namespace Bonsai.ML.Lds.Torch; diff --git a/src/Bonsai.ML.Lds.Torch/CreateKalmanFilterParameters.cs b/src/Bonsai.ML.Lds.Torch/CreateKalmanFilterParameters.cs index 55ff3ec0..02748620 100644 --- a/src/Bonsai.ML.Lds.Torch/CreateKalmanFilterParameters.cs +++ b/src/Bonsai.ML.Lds.Torch/CreateKalmanFilterParameters.cs @@ -4,6 +4,7 @@ using System.Reactive.Linq; using System.Xml.Serialization; using static TorchSharp.torch; +using Bonsai.ML.Torch; namespace Bonsai.ML.Lds.Torch; diff --git a/src/Bonsai.ML.Lds.Torch/CreateLdsState.cs b/src/Bonsai.ML.Lds.Torch/CreateLdsState.cs index 24444408..3bfa5101 100644 --- a/src/Bonsai.ML.Lds.Torch/CreateLdsState.cs +++ b/src/Bonsai.ML.Lds.Torch/CreateLdsState.cs @@ -3,6 +3,7 @@ using System.Reactive.Linq; using System.Xml.Serialization; using static TorchSharp.torch; +using Bonsai.ML.Torch; namespace Bonsai.ML.Lds.Torch; From cda33a3ea0c814d02810b242c2e29d94f97ac9cf Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Mon, 13 Oct 2025 16:51:11 +0100 Subject: [PATCH 55/92] Added operators to save and load the parameters of a Kalman filter model --- .../LoadKalmanFilterParameters.cs | 117 ++++++++++++++++++ .../SaveKalmanFilterParameters.cs | 101 +++++++++++++++ 2 files changed, 218 insertions(+) create mode 100644 src/Bonsai.ML.Lds.Torch/LoadKalmanFilterParameters.cs create mode 100644 src/Bonsai.ML.Lds.Torch/SaveKalmanFilterParameters.cs diff --git a/src/Bonsai.ML.Lds.Torch/LoadKalmanFilterParameters.cs b/src/Bonsai.ML.Lds.Torch/LoadKalmanFilterParameters.cs new file mode 100644 index 00000000..84a8a812 --- /dev/null +++ b/src/Bonsai.ML.Lds.Torch/LoadKalmanFilterParameters.cs @@ -0,0 +1,117 @@ +using System; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using System.Xml.Serialization; +using Bonsai.ML.Torch; +using System.IO; +using TorchSharp; +using static TorchSharp.torch; + +namespace Bonsai.ML.Lds.Torch; + +/// +/// Loads the parameters of a Kalman filter model. +/// +[Combinator] +[Description("Loads the parameters of a Kalman filter model.")] +[WorkflowElementCategory(ElementCategory.Source)] +public class LoadKalmanFilterParameters +{ + /// + /// Reads the path to a .bin file containing the transition matrix. + /// + [Description("Reads the path to a .bin file containing the transition matrix.")] + [FileNameFilter("Binary Data (*.bin)|*.bin|All Files|*.*")] + [Editor("Bonsai.Design.OpenFileNameEditor, Bonsai.Design", DesignTypes.UITypeEditor)] + public string TransitionMatrixFilePath { get; set; } = "transition_matrix.bin"; + + /// + /// Reads the path to a .bin file containing the measurement function. + /// + [Description("Reads the path to a .bin file containing the measurement function.")] + [FileNameFilter("Binary Data (*.bin)|*.bin|All Files|*.*")] + [Editor("Bonsai.Design.OpenFileNameEditor, Bonsai.Design", DesignTypes.UITypeEditor)] + public string MeasurementFunctionFilePath { get; set; } = "measurement_function.bin"; + + /// + /// Reads the path to a .bin file containing the process noise covariance. + /// + [Description("Reads the path to a .bin file containing the process noise covariance.")] + [FileNameFilter("Binary Data (*.bin)|*.bin|All Files|*.*")] + [Editor("Bonsai.Design.OpenFileNameEditor, Bonsai.Design", DesignTypes.UITypeEditor)] + public string ProcessNoiseCovarianceFilePath { get; set; } = "process_noise_covariance.bin"; + + /// + /// Reads the path to a .bin file containing the measurement noise covariance. + /// + [Description("Reads the path to a .bin file containing the measurement noise covariance.")] + [FileNameFilter("Binary Data (*.bin)|*.bin|All Files|*.*")] + [Editor("Bonsai.Design.OpenFileNameEditor, Bonsai.Design", DesignTypes.UITypeEditor)] + public string MeasurementNoiseCovarianceFilePath { get; set; } = "measurement_noise_covariance.bin"; + + /// + /// Reads the path to a .bin file containing the initial mean. + /// + [Description("Reads the path to a .bin file containing the initial mean.")] + [FileNameFilter("Binary Data (*.bin)|*.bin|All Files|*.*")] + [Editor("Bonsai.Design.OpenFileNameEditor, Bonsai.Design", DesignTypes.UITypeEditor)] + public string InitialMeanFilePath { get; set; } = "initial_mean.bin"; + + /// + /// Reads the path to a .bin file containing the initial covariance. + /// + [Description("Reads the path to a .bin file containing the initial covariance.")] + [FileNameFilter("Binary Data (*.bin)|*.bin|All Files|*.*")] + [Editor("Bonsai.Design.OpenFileNameEditor, Bonsai.Design", DesignTypes.UITypeEditor)] + public string InitialCovarianceFilePath { get; set; } = "initial_covariance.bin"; + + /// + /// Gets or sets the data type of the tensors. + /// + [Description("Gets or sets the data type of the tensors.")] + public ScalarType Type { get; set; } = ScalarType.Float32; + + /// + /// Gets or sets the device to use for tensor operations. + /// + [XmlIgnore] + [Description("Gets or sets the device to use for tensor operations.")] + public Device Device { get; set; } + + private Tensor LoadTensorFromFile(string filePath) + { + if (filePath == null) return null; + if (!File.Exists(filePath)) + { + throw new FileNotFoundException($"The specified file was not found: {filePath}"); + } + return Tensor.Load(filePath)?.to(Device).to_type(Type); + } + + /// + /// Creates parameters for a Kalman filter model using the properties of this class. + /// + public IObservable Process() + { + Device ??= CPU; + + var transitionMatrix = LoadTensorFromFile(TransitionMatrixFilePath); + var measurementFunction = LoadTensorFromFile(MeasurementFunctionFilePath); + var processNoiseCovariance = LoadTensorFromFile(ProcessNoiseCovarianceFilePath); + var measurementNoiseCovariance = LoadTensorFromFile(MeasurementNoiseCovarianceFilePath); + var initialMean = LoadTensorFromFile(InitialMeanFilePath); + var initialCovariance = LoadTensorFromFile(InitialCovarianceFilePath); + + var parameters = new KalmanFilterParameters( + transitionMatrix: transitionMatrix, + measurementFunction: measurementFunction, + processNoiseCovariance: processNoiseCovariance, + measurementNoiseCovariance: measurementNoiseCovariance, + initialMean: initialMean, + initialCovariance: initialCovariance + ); + + return Observable.Return(parameters); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Lds.Torch/SaveKalmanFilterParameters.cs b/src/Bonsai.ML.Lds.Torch/SaveKalmanFilterParameters.cs new file mode 100644 index 00000000..198219f5 --- /dev/null +++ b/src/Bonsai.ML.Lds.Torch/SaveKalmanFilterParameters.cs @@ -0,0 +1,101 @@ +using System; +using System.ComponentModel; +using System.Linq; +using System.Reactive.Linq; +using System.Xml.Serialization; +using Bonsai.ML.Torch; +using System.IO; +using TorchSharp; +using static TorchSharp.torch; + +namespace Bonsai.ML.Lds.Torch; + +/// +/// Saves the parameters of a Kalman filter model. +/// +[Combinator] +[Description("Saves the parameters of a Kalman filter model.")] +[WorkflowElementCategory(ElementCategory.Sink)] +public class SaveKalmanFilterParameters +{ + /// + /// Specifies the path to use for saving the transition matrix of a Kalman filter to a .bin file. + /// + [Description("Specifies the path to use for saving the transition matrix of a Kalman filter to a .bin file.")] + [FileNameFilter("Binary Data (*.bin)|*.bin|All Files|*.*")] + [Editor("Bonsai.Design.SaveFileNameEditor, Bonsai.Design", DesignTypes.UITypeEditor)] + public string TransitionMatrixFilePath { get; set; } = "transition_matrix.bin"; + + /// + /// Specifies the path to use for saving the measurement function of a Kalman filter to a .bin file. + /// + [Description("Specifies the path to use for saving the measurement function of a Kalman filter to a .bin file.")] + [FileNameFilter("Binary Data (*.bin)|*.bin|All Files|*.*")] + [Editor("Bonsai.Design.SaveFileNameEditor, Bonsai.Design", DesignTypes.UITypeEditor)] + public string MeasurementFunctionFilePath { get; set; } = "measurement_function.bin"; + + /// + /// Specifies the path to use for saving the process noise covariance of a Kalman filter to a .bin file. + /// + [Description("Specifies the path to use for saving the process noise covariance of a Kalman filter to a .bin file.")] + [FileNameFilter("Binary Data (*.bin)|*.bin|All Files|*.*")] + [Editor("Bonsai.Design.SaveFileNameEditor, Bonsai.Design", DesignTypes.UITypeEditor)] + public string ProcessNoiseCovarianceFilePath { get; set; } = "process_noise_covariance.bin"; + + /// + /// Specifies the path to use for saving the measurement noise covariance of a Kalman filter to a .bin file. + /// + [Description("Specifies the path to use for saving the measurement noise covariance of a Kalman filter to a .bin file.")] + [FileNameFilter("Binary Data (*.bin)|*.bin|All Files|*.*")] + [Editor("Bonsai.Design.SaveFileNameEditor, Bonsai.Design", DesignTypes.UITypeEditor)] + public string MeasurementNoiseCovarianceFilePath { get; set; } = "measurement_noise_covariance.bin"; + + /// + /// Specifies the path to use for saving the initial mean of a Kalman filter to a .bin file. + /// + [Description("Specifies the path to use for saving the initial mean of a Kalman filter to a .bin file.")] + [FileNameFilter("Binary Data (*.bin)|*.bin|All Files|*.*")] + [Editor("Bonsai.Design.SaveFileNameEditor, Bonsai.Design", DesignTypes.UITypeEditor)] + public string InitialMeanFilePath { get; set; } = "initial_mean.bin"; + + /// + /// Specifies the path to use for saving the initial covariance of a Kalman filter to a .bin file. + /// + [Description("Specifies the path to use for saving the initial covariance of a Kalman filter to a .bin file.")] + [FileNameFilter("Binary Data (*.bin)|*.bin|All Files|*.*")] + [Editor("Bonsai.Design.SaveFileNameEditor, Bonsai.Design", DesignTypes.UITypeEditor)] + public string InitialCovarianceFilePath { get; set; } = "initial_covariance.bin"; + + /// + /// The name of the Kalman filter model to be used. + /// + [TypeConverter(typeof(KalmanFilterNameConverter))] + [Description("The name of the Kalman filter model to be used.")] + public string ModelName { get; set; } = "KalmanFilter"; + + private void SaveTensorToFile(Tensor tensor, string filePath) + { + if (filePath != null) + { + tensor.Save(filePath); + } + } + + /// + /// Creates parameters for a Kalman filter model using the properties of this class. + /// + public IObservable Process(IObservable source) + { + return source.Do(input => + { + var kalmanFilter = KalmanFilterModelManager.GetKalmanFilter(ModelName); + var parameters = kalmanFilter.Parameters; + SaveTensorToFile(parameters.TransitionMatrix, TransitionMatrixFilePath); + SaveTensorToFile(parameters.MeasurementFunction, MeasurementFunctionFilePath); + SaveTensorToFile(parameters.ProcessNoiseCovariance, ProcessNoiseCovarianceFilePath); + SaveTensorToFile(parameters.MeasurementNoiseCovariance, MeasurementNoiseCovarianceFilePath); + SaveTensorToFile(parameters.InitialMean, InitialMeanFilePath); + SaveTensorToFile(parameters.InitialCovariance, InitialCovarianceFilePath); + }); + } +} \ No newline at end of file From d9a1f462d27f33a228467ecf35a7e4a827450c04 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 15 Oct 2025 19:42:48 +0100 Subject: [PATCH 56/92] Added Stochastic Subspace Identification method --- .../ExpectationMaximization.cs | 6 +- src/Bonsai.ML.Lds.Torch/KalmanFilter.cs | 95 ++++++++++++++++ .../StochasticSubspaceIdentification.cs | 101 ++++++++++++++++++ .../StochasticSubspaceIdentificationResult.cs | 31 ++++++ 4 files changed, 231 insertions(+), 2 deletions(-) create mode 100644 src/Bonsai.ML.Lds.Torch/StochasticSubspaceIdentification.cs create mode 100644 src/Bonsai.ML.Lds.Torch/StochasticSubspaceIdentificationResult.cs diff --git a/src/Bonsai.ML.Lds.Torch/ExpectationMaximization.cs b/src/Bonsai.ML.Lds.Torch/ExpectationMaximization.cs index 418e9654..af105949 100644 --- a/src/Bonsai.ML.Lds.Torch/ExpectationMaximization.cs +++ b/src/Bonsai.ML.Lds.Torch/ExpectationMaximization.cs @@ -1,5 +1,7 @@ using System; using System.ComponentModel; +using System.Reactive; +using System.Linq; using System.Reactive.Linq; using TorchSharp; using static TorchSharp.torch; @@ -98,7 +100,7 @@ public bool Verbose /// public IObservable Process(IObservable source) { - return source.SelectMany(input => Observable.Create((observer, cancellationToken) => + return source.Select(input => Observable.Create((observer, cancellationToken) => { return Task.Run(() => { @@ -168,6 +170,6 @@ public IObservable Process(IObservable so return System.Reactive.Disposables.Disposable.Empty; }, cancellationToken); - })); + })).Concat(); } } \ No newline at end of file diff --git a/src/Bonsai.ML.Lds.Torch/KalmanFilter.cs b/src/Bonsai.ML.Lds.Torch/KalmanFilter.cs index 8bfdab18..3adb51ee 100644 --- a/src/Bonsai.ML.Lds.Torch/KalmanFilter.cs +++ b/src/Bonsai.ML.Lds.Torch/KalmanFilter.cs @@ -754,6 +754,101 @@ public ExpectationMaximizationResult ExpectationMaximization( return new ExpectationMaximizationResult(logLikelihood, updatedParameters); } + public static StochasticSubspaceIdentificationResult StochasticSubspaceIdentification( + Tensor observations, + int maxLag = 20, + double threshold = 0.01, + ParametersToEstimate parametersToEstimate = new()) + { + using var g = no_grad(); + + var timeBins = observations.size(0); + var numObs = observations.size(1); + + // Build Hankel matrices from observations + var numCols = (int)(timeBins - 2 * maxLag + 1); + + if (numCols <= 0) + throw new ArgumentException($"Number of time bins ({timeBins}) must be greater than 2*maxLag ({2 * maxLag}) for subspace identification."); + + var stride = observations.stride(); + var pastView = observations.as_strided([maxLag, numCols, numObs], [stride[0], stride[0], stride[1]]); + var past = pastView.permute(0, 2, 1).reshape(maxLag * numObs, numCols); + + var futureView = observations.narrow(0, maxLag, timeBins - maxLag) + .as_strided([maxLag, numCols, numObs], [stride[0], stride[0], stride[1]]); + var future = futureView.permute(0, 2, 1).reshape(maxLag * numObs, numCols); + + // Compute the projection + var Pp = past.matmul(past.mT); + var projection = InverseCholesky(future.matmul(past.mT), Pp).matmul(past); + + // Compute SVD of the past observations + var (U, S, Vt) = linalg.svd(projection, fullMatrices: false); + + // Truncate to estimated state dimension + // var rankTolerance = S[0] * effectiveRankCutoff; + // var effectiveRank = 0; + // for (int i = 0; i < S.shape[0]; i++) + // { + // if (S[i].item() > rankTolerance.item()) + // effectiveRank = i + 1; + // } + // effectiveRank = Math.Min(effectiveRank, numStates); + var effectiveStates = (int)argmax(S < threshold).item(); + + var Ur = U[TensorIndex.Colon, TensorIndex.Slice(0, effectiveStates)]; + var SrSqrt = S[TensorIndex.Slice(0, effectiveStates)].sqrt(); + var Vrt = Vt[TensorIndex.Slice(0, effectiveStates)]; + + // Estimate observability matrix + var observability = Ur.matmul(SrSqrt); + + // Extract measurement function from first block of observability matrix + var measurementFunction = observability[TensorIndex.Slice(0, numObs)]; + + // Estimate state sequence + var states = SrSqrt.diag().matmul(Vrt); + + // Estimate transition matrix A using shifted states + var statesShifted = states[TensorIndex.Colon, TensorIndex.Slice(0, numCols - 1)]; + var statesNext = states[TensorIndex.Colon, TensorIndex.Slice(1, numCols)]; + + var transitionMatrix = WrappedTensorDisposeScope(() => InverseCholesky( + statesNext.matmul(statesShifted.mT), + statesShifted.matmul(statesShifted.mT))); + + // Estimate noise covariances using residuals + var stateResiduals = statesNext - transitionMatrix.matmul(statesShifted); + var processNoiseCovariance = WrappedTensorDisposeScope(() => stateResiduals.matmul(stateResiduals.mT) / (numCols - 1)); + + // Vectorized computation of observation residuals + var observationPredictions = measurementFunction.matmul(states); + var observationWindow = observations[TensorIndex.Slice(maxLag, maxLag + numCols)].mT; + var observationResiduals = observationWindow - observationPredictions; + var measurementNoiseCovariance = WrappedTensorDisposeScope(() => observationResiduals.matmul(observationResiduals.mT) / numCols); + + // Initial state estimates + var initialMean = states[TensorIndex.Colon, 0]; + var initialCovariance = WrappedTensorDisposeScope(() => EnsureSymmetric( + states.matmul(states.mT) / numCols)); + + var parameters = new KalmanFilterParameters( + transitionMatrix: parametersToEstimate.TransitionMatrix ? transitionMatrix : null, + measurementFunction: parametersToEstimate.MeasurementFunction ? measurementFunction : null, + processNoiseCovariance: parametersToEstimate.ProcessNoiseCovariance ? processNoiseCovariance : null, + measurementNoiseCovariance: parametersToEstimate.MeasurementNoiseCovariance ? measurementNoiseCovariance : null, + initialMean: parametersToEstimate.InitialMean ? initialMean : null, + initialCovariance: parametersToEstimate.InitialCovariance ? initialCovariance : null + ); + + return new StochasticSubspaceIdentificationResult( + parameters: parameters, + effectiveStates: effectiveStates, + singularValues: S + ); + } + public OrthogonalizedState OrthogonalizeMeanAndCovariance(Tensor mean, Tensor covariance) { var (_, S, Vt) = linalg.svd(_measurementFunction); diff --git a/src/Bonsai.ML.Lds.Torch/StochasticSubspaceIdentification.cs b/src/Bonsai.ML.Lds.Torch/StochasticSubspaceIdentification.cs new file mode 100644 index 00000000..694cdffe --- /dev/null +++ b/src/Bonsai.ML.Lds.Torch/StochasticSubspaceIdentification.cs @@ -0,0 +1,101 @@ +using System; +using System.Linq; +using System.ComponentModel; +using System.Reactive.Linq; +using TorchSharp; +using static TorchSharp.torch; +using System.Threading.Tasks; + +namespace Bonsai.ML.Lds.Torch; + +/// +/// Learn the parameters of a kalman filter using the stochastic subspace identification method. +/// +[Combinator] +[Description("Learn the parameters of a kalman filter using the stochastic subspace identification method.")] +[WorkflowElementCategory(ElementCategory.Combinator)] +public class StochasticSubspaceIdentification +{ + private int _maxLag = 20; + /// + /// The maximum lag to consider for the subspace identification. + /// + [Description("The maximum lag to consider for the subspace identification.")] + public int MaxLag + { + get => _maxLag; + set => _maxLag = value > 0 ? value : throw new ArgumentOutOfRangeException(nameof(MaxLag), "Must be greater than zero."); + } + + private double _threshold = 1e-4; + /// + /// The threshold for the singular values to determine the effective number of states. + /// + [Description("The threshold for the singular values to determine the effective number of states.")] + public double Threshold + { + get => _threshold; + set => _threshold = value >= 0 ? value : throw new ArgumentOutOfRangeException(nameof(Threshold), "Must be greater than or equal to zero."); + } + + /// + /// If true, the transition matrix will be estimated during the EM algorithm. + /// + [Description("If true, the transition matrix will be estimated during the EM algorithm.")] + public bool EstimateTransitionMatrix { get; set; } = true; + + /// + /// If true, the measurement function will be estimated during the EM algorithm. + /// + [Description("If true, the measurement function will be estimated during the EM algorithm.")] + public bool EstimateMeasurementFunction { get; set; } = true; + + /// + /// If true, the process noise covariance will be estimated during the EM algorithm. + /// + [Description("If true, the process noise covariance will be estimated during the EM algorithm.")] + public bool EstimateProcessNoiseCovariance { get; set; } = true; + + /// + /// If true, the measurement noise covariance will be estimated during the EM algorithm. + /// + [Description("If true, the measurement noise covariance will be estimated during the EM algorithm.")] + public bool EstimateMeasurementNoiseCovariance { get; set; } = true; + + /// + /// If true, the initial mean will be estimated during the EM algorithm. + /// + [Description("If true, the initial mean will be estimated during the EM algorithm.")] + public bool EstimateInitialMean { get; set; } = true; + + /// + /// If true, the initial covariance will be estimated during the EM algorithm. + /// + [Description("If true, the initial covariance will be estimated during the EM algorithm.")] + public bool EstimateInitialCovariance { get; set; } = true; + + /// + /// Processes an observable sequence of input tensors, applying the Expectation-Maximization algorithm to learn the parameters of a Kalman filter model. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(input => + { + var parametersToEstimate = new ParametersToEstimate( + transitionMatrix: EstimateTransitionMatrix, + measurementFunction: EstimateMeasurementFunction, + processNoiseCovariance: EstimateProcessNoiseCovariance, + measurementNoiseCovariance: EstimateMeasurementNoiseCovariance, + initialMean: EstimateInitialMean, + initialCovariance: EstimateInitialCovariance); + + return KalmanFilter.StochasticSubspaceIdentification( + observations: input, + maxLag: MaxLag, + threshold: Threshold, + parametersToEstimate: parametersToEstimate); + }); + } +} \ No newline at end of file diff --git a/src/Bonsai.ML.Lds.Torch/StochasticSubspaceIdentificationResult.cs b/src/Bonsai.ML.Lds.Torch/StochasticSubspaceIdentificationResult.cs new file mode 100644 index 00000000..34a5c63b --- /dev/null +++ b/src/Bonsai.ML.Lds.Torch/StochasticSubspaceIdentificationResult.cs @@ -0,0 +1,31 @@ +using static TorchSharp.torch; + +namespace Bonsai.ML.Lds.Torch; + +/// +/// Represents the result of subspace identification using the Kalman-Ho method. +/// +/// The identified Kalman filter parameters. +/// The effective states of the system determined by SVD. +/// The singular values from the SVD decomposition. +public struct StochasticSubspaceIdentificationResult( + KalmanFilterParameters parameters, + int effectiveStates, + Tensor singularValues) +{ + /// + /// The identified Kalman filter parameters from subspace identification. + /// + public KalmanFilterParameters Parameters = parameters; + + /// + /// The effective states of the system determined by SVD truncation. + /// + public int EffectiveStates = effectiveStates; + + /// + /// The singular values from the SVD decomposition of the Hankel matrix. + /// These can be used to assess the quality of the identification and choose the model order. + /// + public Tensor SingularValues = singularValues; +} From c0db6d4da993919d562b5bcaf93e7cd4003999b2 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 15 Oct 2025 22:32:51 +0100 Subject: [PATCH 57/92] Added method for `StochasticSubspaceIdentification` which is much faster than EM --- src/Bonsai.ML.Lds.Torch/KalmanFilter.cs | 21 ++++------- .../StochasticSubspaceIdentification.cs | 36 +++++++++++-------- .../StochasticSubspaceIdentificationResult.cs | 4 +-- 3 files changed, 30 insertions(+), 31 deletions(-) diff --git a/src/Bonsai.ML.Lds.Torch/KalmanFilter.cs b/src/Bonsai.ML.Lds.Torch/KalmanFilter.cs index 3adb51ee..039dce23 100644 --- a/src/Bonsai.ML.Lds.Torch/KalmanFilter.cs +++ b/src/Bonsai.ML.Lds.Torch/KalmanFilter.cs @@ -786,19 +786,12 @@ public static StochasticSubspaceIdentificationResult StochasticSubspaceIdentific // Compute SVD of the past observations var (U, S, Vt) = linalg.svd(projection, fullMatrices: false); - // Truncate to estimated state dimension - // var rankTolerance = S[0] * effectiveRankCutoff; - // var effectiveRank = 0; - // for (int i = 0; i < S.shape[0]; i++) - // { - // if (S[i].item() > rankTolerance.item()) - // effectiveRank = i + 1; - // } - // effectiveRank = Math.Min(effectiveRank, numStates); - var effectiveStates = (int)argmax(S < threshold).item(); + // Compute the effective rank + var effectiveRank = (S > (threshold * S[0])).to_type(ScalarType.Int64).sum().item(); + var effectiveStates = Math.Max(effectiveRank, 1); var Ur = U[TensorIndex.Colon, TensorIndex.Slice(0, effectiveStates)]; - var SrSqrt = S[TensorIndex.Slice(0, effectiveStates)].sqrt(); + var SrSqrt = S[TensorIndex.Slice(0, effectiveStates)].diag().sqrt(); var Vrt = Vt[TensorIndex.Slice(0, effectiveStates)]; // Estimate observability matrix @@ -808,9 +801,9 @@ public static StochasticSubspaceIdentificationResult StochasticSubspaceIdentific var measurementFunction = observability[TensorIndex.Slice(0, numObs)]; // Estimate state sequence - var states = SrSqrt.diag().matmul(Vrt); + var states = SrSqrt.matmul(Vrt); - // Estimate transition matrix A using shifted states + // Estimate transition matrix using shifted states var statesShifted = states[TensorIndex.Colon, TensorIndex.Slice(0, numCols - 1)]; var statesNext = states[TensorIndex.Colon, TensorIndex.Slice(1, numCols)]; @@ -822,7 +815,7 @@ public static StochasticSubspaceIdentificationResult StochasticSubspaceIdentific var stateResiduals = statesNext - transitionMatrix.matmul(statesShifted); var processNoiseCovariance = WrappedTensorDisposeScope(() => stateResiduals.matmul(stateResiduals.mT) / (numCols - 1)); - // Vectorized computation of observation residuals + // Compute the observation residuals var observationPredictions = measurementFunction.matmul(states); var observationWindow = observations[TensorIndex.Slice(maxLag, maxLag + numCols)].mT; var observationResiduals = observationWindow - observationPredictions; diff --git a/src/Bonsai.ML.Lds.Torch/StochasticSubspaceIdentification.cs b/src/Bonsai.ML.Lds.Torch/StochasticSubspaceIdentification.cs index 694cdffe..7092442d 100644 --- a/src/Bonsai.ML.Lds.Torch/StochasticSubspaceIdentification.cs +++ b/src/Bonsai.ML.Lds.Torch/StochasticSubspaceIdentification.cs @@ -35,7 +35,7 @@ public int MaxLag public double Threshold { get => _threshold; - set => _threshold = value >= 0 ? value : throw new ArgumentOutOfRangeException(nameof(Threshold), "Must be greater than or equal to zero."); + set => _threshold = value >= 0 && value < 1 ? value : throw new ArgumentOutOfRangeException(nameof(Threshold), "Must be greater than or equal to zero and less than one."); } /// @@ -81,21 +81,27 @@ public double Threshold /// public IObservable Process(IObservable source) { - return source.Select(input => + return source.Select(input => Observable.Create(observer => { - var parametersToEstimate = new ParametersToEstimate( - transitionMatrix: EstimateTransitionMatrix, - measurementFunction: EstimateMeasurementFunction, - processNoiseCovariance: EstimateProcessNoiseCovariance, - measurementNoiseCovariance: EstimateMeasurementNoiseCovariance, - initialMean: EstimateInitialMean, - initialCovariance: EstimateInitialCovariance); + return Task.Run(() => + { + var parametersToEstimate = new ParametersToEstimate( + transitionMatrix: EstimateTransitionMatrix, + measurementFunction: EstimateMeasurementFunction, + processNoiseCovariance: EstimateProcessNoiseCovariance, + measurementNoiseCovariance: EstimateMeasurementNoiseCovariance, + initialMean: EstimateInitialMean, + initialCovariance: EstimateInitialCovariance); - return KalmanFilter.StochasticSubspaceIdentification( - observations: input, - maxLag: MaxLag, - threshold: Threshold, - parametersToEstimate: parametersToEstimate); - }); + observer.OnNext(KalmanFilter.StochasticSubspaceIdentification( + observations: input, + maxLag: MaxLag, + threshold: Threshold, + parametersToEstimate: parametersToEstimate)); + + observer.OnCompleted(); + return System.Reactive.Disposables.Disposable.Empty; + }); + })).Concat(); } } \ No newline at end of file diff --git a/src/Bonsai.ML.Lds.Torch/StochasticSubspaceIdentificationResult.cs b/src/Bonsai.ML.Lds.Torch/StochasticSubspaceIdentificationResult.cs index 34a5c63b..b46d16c4 100644 --- a/src/Bonsai.ML.Lds.Torch/StochasticSubspaceIdentificationResult.cs +++ b/src/Bonsai.ML.Lds.Torch/StochasticSubspaceIdentificationResult.cs @@ -10,7 +10,7 @@ namespace Bonsai.ML.Lds.Torch; /// The singular values from the SVD decomposition. public struct StochasticSubspaceIdentificationResult( KalmanFilterParameters parameters, - int effectiveStates, + long effectiveStates, Tensor singularValues) { /// @@ -21,7 +21,7 @@ public struct StochasticSubspaceIdentificationResult( /// /// The effective states of the system determined by SVD truncation. /// - public int EffectiveStates = effectiveStates; + public long EffectiveStates = effectiveStates; /// /// The singular values from the SVD decomposition of the Hankel matrix. From 152875fe27c1a45f903c2fe59abb6f6e06e539ef Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 15 Oct 2025 23:45:09 +0100 Subject: [PATCH 58/92] Refactored implementation to use explicit `KalmanFilter` property which is better suited for online parameter estimation with static methods --- src/Bonsai.ML.Lds.Torch/CreateKalmanFilter.cs | 56 ++----- .../ExpectationMaximization.cs | 78 ++++++++-- src/Bonsai.ML.Lds.Torch/Filter.cs | 16 +- src/Bonsai.ML.Lds.Torch/KalmanFilter.cs | 143 +++++++++++++++++- .../KalmanFilterModelManager.cs | 106 ------------- .../KalmanFilterNameConverter.cs | 44 ------ src/Bonsai.ML.Lds.Torch/Orthogonalize.cs | 25 ++- .../SaveKalmanFilterParameters.cs | 14 +- src/Bonsai.ML.Lds.Torch/Smooth.cs | 19 +-- .../StochasticSubspaceIdentification.cs | 12 ++ src/Bonsai.ML.Lds.Torch/UpdateParameters.cs | 15 +- .../NeuralLatentsTest.cs | 24 +++ 12 files changed, 288 insertions(+), 264 deletions(-) delete mode 100644 src/Bonsai.ML.Lds.Torch/KalmanFilterModelManager.cs delete mode 100644 src/Bonsai.ML.Lds.Torch/KalmanFilterNameConverter.cs diff --git a/src/Bonsai.ML.Lds.Torch/CreateKalmanFilter.cs b/src/Bonsai.ML.Lds.Torch/CreateKalmanFilter.cs index 63a150e4..8bf1745e 100644 --- a/src/Bonsai.ML.Lds.Torch/CreateKalmanFilter.cs +++ b/src/Bonsai.ML.Lds.Torch/CreateKalmanFilter.cs @@ -16,17 +16,10 @@ namespace Bonsai.ML.Lds.Torch; [WorkflowElementCategory(ElementCategory.Source)] public class CreateKalmanFilter : IScalarTypeProvider { - /// - /// A unique name for the Kalman filter model. - /// - [Category("Required Parameters")] - public string Name { get; set; } = "KalmanFilter"; - private ScalarType _scalarType = ScalarType.Float32; /// [Description("The data type of the tensor elements.")] [TypeConverter(typeof(ScalarTypeConverter))] - [Category("Required Parameters")] public ScalarType Type { get => _scalarType; @@ -42,14 +35,12 @@ public ScalarType Type /// [Description("The device on which to create the tensor.")] [XmlIgnore] - [Category("Required Parameters")] public Device Device { get; set; } private int _numStates = 2; /// /// The number of states in the Kalman filter model. /// - [Category("Required Parameters")] public int NumStates { get => _numStates; @@ -60,7 +51,6 @@ public int NumStates /// /// The number of observations in the Kalman filter model. /// - [Category("Required Parameters")] public int NumObservations { get => _numObservations; @@ -74,7 +64,6 @@ public int NumObservations /// [XmlIgnore] [TypeConverter(typeof(TensorConverter))] - [Category("Optional Parameters")] public Tensor TransitionMatrix { get => _transitionMatrix; @@ -99,7 +88,6 @@ public string TransitionMatrixXml /// [XmlIgnore] [TypeConverter(typeof(TensorConverter))] - [Category("Optional Parameters")] public Tensor MeasurementFunction { get => _measurementFunction; @@ -124,7 +112,6 @@ public string MeasurementFunctionXml /// [XmlIgnore] [TypeConverter(typeof(TensorConverter))] - [Category("Optional Parameters")] public Tensor ProcessNoiseVariance { get => _processNoiseVariance; @@ -149,7 +136,6 @@ public string ProcessNoiseVarianceXml /// [XmlIgnore] [TypeConverter(typeof(TensorConverter))] - [Category("Optional Parameters")] public Tensor MeasurementNoiseVariance { get => _measurementNoiseVariance; @@ -174,7 +160,6 @@ public string MeasurementNoiseVarianceXml /// [XmlIgnore] [TypeConverter(typeof(TensorConverter))] - [Category("Optional Parameters")] public Tensor InitialMean { get => _initialMean; @@ -199,7 +184,6 @@ public string InitialMeanXml /// [XmlIgnore] [TypeConverter(typeof(TensorConverter))] - [Category("Optional Parameters")] public Tensor InitialCovariance { get => _initialCovariance; @@ -231,42 +215,34 @@ private void ConvertTensorsScalarType(ScalarType scalarType) /// /// Creates a Kalman filter model using the properties of this class. /// - public IObservable Process() + public IObservable Process() { - return Observable.Using(() => KalmanFilterModelManager.Reserve( - name: Name, - numStates: _numStates, - numObservations: _numObservations, - transitionMatrix: _transitionMatrix, - measurementFunction: _measurementFunction, - initialMean: _initialMean, - initialCovariance: _initialCovariance, - processNoiseVariance: _processNoiseVariance, - measurementNoiseVariance: _measurementNoiseVariance, - device: Device, - scalarType: _scalarType - ), resource => Observable.Return(resource.Model) - .Concat(Observable.Never(resource.Model)) - .Finally(resource.Dispose) - ); + return Observable.Return(new KalmanFilter( + numStates: NumStates, + numObservations: NumObservations, + transitionMatrix: TransitionMatrix, + measurementFunction: MeasurementFunction, + processNoiseVariance: ProcessNoiseVariance, + measurementNoiseVariance: MeasurementNoiseVariance, + initialMean: InitialMean, + initialCovariance: InitialCovariance, + device: Device, + scalarType: _scalarType + )); } /// /// Creates a Kalman filter model using the parameters provided in the input sequence. /// - public IObservable Process(IObservable source) + public IObservable Process(IObservable source) { return source.SelectMany(parameters => { - return Observable.Using(() => KalmanFilterModelManager.Reserve( - name: Name, + return Observable.Return(new KalmanFilter( parameters: parameters, device: Device, scalarType: _scalarType - ), resource => Observable.Return(resource.Model) - .Concat(Observable.Never(resource.Model)) - .Finally(resource.Dispose) - ); + )); }); } } diff --git a/src/Bonsai.ML.Lds.Torch/ExpectationMaximization.cs b/src/Bonsai.ML.Lds.Torch/ExpectationMaximization.cs index af105949..dd201fc4 100644 --- a/src/Bonsai.ML.Lds.Torch/ExpectationMaximization.cs +++ b/src/Bonsai.ML.Lds.Torch/ExpectationMaximization.cs @@ -1,6 +1,7 @@ using System; using System.ComponentModel; using System.Reactive; +using System.Xml.Serialization; using System.Linq; using System.Reactive.Linq; using TorchSharp; @@ -13,16 +14,39 @@ namespace Bonsai.ML.Lds.Torch; /// Learn the parameters of a kalman filter using the batch EM update algorithm. /// [Combinator] +[ResetCombinator] [Description("Learn the parameters of a kalman filter using the batch EM update algorithm.")] [WorkflowElementCategory(ElementCategory.Combinator)] public class ExpectationMaximization { + private int _numStates = 2; /// - /// The name of the Kalman filter model to be trained. + /// The number of states in the Kalman filter model. /// - [TypeConverter(typeof(KalmanFilterNameConverter))] - [Description("The name of the Kalman filter model to be trained.")] - public string ModelName { get; set; } = "KalmanFilter"; + [Description("The number of states in the Kalman filter model.")] + public int NumStates + { + get => _numStates; + set => _numStates = value > 0 ? value : throw new ArgumentOutOfRangeException(nameof(value), "Number of states must be greater than zero."); + } + + private int _numObservations = 10; + /// + /// The number of observations in the Kalman filter model. + /// + [Description("The number of observations in the Kalman filter model.")] + public int NumObservations + { + get => _numObservations; + set => _numObservations = value > 0 ? value : throw new ArgumentOutOfRangeException(nameof(value), "Number of observations must be greater than zero."); + } + + /// + /// The Kalman filter parameters used to initialize the model. + /// + [Description("The Kalman filter parameters used to initialize the model.")] + [XmlIgnore] + public KalmanFilterParameters ModelParameters { get; set; } = new(); private int _maxIterations = 10; /// @@ -104,9 +128,9 @@ public IObservable Process(IObservable so { return Task.Run(() => { - var model = KalmanFilterModelManager.GetKalmanFilter(ModelName); var previousLogLikelihood = double.NegativeInfinity; - var logLikelihood = zeros(new long[] { MaxIterations }, device: input.device); + var logLikelihood = zeros([MaxIterations], device: input.device); + var maxIterationsReached = false; var parametersToEstimate = new ParametersToEstimate( transitionMatrix: EstimateTransitionMatrix, @@ -116,6 +140,15 @@ public IObservable Process(IObservable so initialMean: EstimateInitialMean, initialCovariance: EstimateInitialCovariance); + var parameters = new KalmanFilterParameters( + transitionMatrix: ModelParameters.TransitionMatrix, + measurementFunction: ModelParameters.MeasurementFunction, + processNoiseCovariance: ModelParameters.ProcessNoiseCovariance, + measurementNoiseCovariance: ModelParameters.MeasurementNoiseCovariance, + initialMean: ModelParameters.InitialMean, + initialCovariance: ModelParameters.InitialCovariance + ); + for (int i = 0; i < MaxIterations; i++) { // Check for cancellation before each iteration @@ -125,7 +158,16 @@ public IObservable Process(IObservable so return System.Reactive.Disposables.Disposable.Empty; } - var result = model.ExpectationMaximization(input, 1, Tolerance, parametersToEstimate, false); + var result = KalmanFilter.ExpectationMaximization( + observation: input, + numStates: _numStates, + numObservations: _numObservations, + parameters: parameters, + maxIterations: MaxIterations, + tolerance: Tolerance, + parametersToEstimate: parametersToEstimate, + device: input.device, + scalarType: input.dtype); var logLikelihoodSum = result.LogLikelihood .cpu() @@ -140,6 +182,7 @@ public IObservable Process(IObservable so if (i == MaxIterations - 1) { Console.WriteLine("EM reached the maximum number of iterations."); + maxIterationsReached = true; } } @@ -149,21 +192,26 @@ public IObservable Process(IObservable so { Console.WriteLine("EM converged after " + (i + 1) + " iterations."); } - logLikelihood = logLikelihood[torch.TensorIndex.Slice(0, i + 1)]; + logLikelihood = logLikelihood[TensorIndex.Slice(0, i + 1)]; break; } - previousLogLikelihood = logLikelihoodSum; - model.UpdateParameters(result.Parameters); - observer.OnNext(new ExpectationMaximizationResult( - logLikelihood: logLikelihood[torch.TensorIndex.Slice(0, i + 1)], - parameters: model.Parameters, - finished: false)); + parameters = result.Parameters; + + if (!maxIterationsReached) + { + previousLogLikelihood = logLikelihoodSum; + + observer.OnNext(new ExpectationMaximizationResult( + logLikelihood: logLikelihood[TensorIndex.Slice(0, i + 1)], + parameters: parameters, + finished: false)); + } } observer.OnNext(new ExpectationMaximizationResult( logLikelihood: logLikelihood, - parameters: model.Parameters, + parameters: parameters, finished: true)); observer.OnCompleted(); diff --git a/src/Bonsai.ML.Lds.Torch/Filter.cs b/src/Bonsai.ML.Lds.Torch/Filter.cs index 4fcd7058..699d6dcc 100644 --- a/src/Bonsai.ML.Lds.Torch/Filter.cs +++ b/src/Bonsai.ML.Lds.Torch/Filter.cs @@ -1,6 +1,7 @@ using System; using System.ComponentModel; using System.Reactive.Linq; +using System.Xml.Serialization; using static TorchSharp.torch; namespace Bonsai.ML.Lds.Torch; @@ -9,26 +10,23 @@ namespace Bonsai.ML.Lds.Torch; /// Applies a Kalman filter to the input tensor sequence. /// [Combinator] +[ResetCombinator] [Description("Applies a Kalman filter to the input tensor sequence.")] [WorkflowElementCategory(ElementCategory.Transform)] public class Filter { /// - /// The name of the Kalman filter model to be used. + /// The Kalman filter model. /// - [TypeConverter(typeof(KalmanFilterNameConverter))] - [Description("The name of the Kalman filter model to be used.")] - public string ModelName { get; set; } = "KalmanFilter"; + [Description("The Kalman filter model.")] + [XmlIgnore] + public KalmanFilter Model { get; set; } /// /// Processes an observable sequence of input tensors, applying the Kalman filter to each tensor. /// public IObservable Process(IObservable source) { - return source.Select((input) => - { - var kalmanFilter = KalmanFilterModelManager.GetKalmanFilter(ModelName); - return kalmanFilter.Filter(input); - }); + return source.Select(Model.Filter); } } \ No newline at end of file diff --git a/src/Bonsai.ML.Lds.Torch/KalmanFilter.cs b/src/Bonsai.ML.Lds.Torch/KalmanFilter.cs index 039dce23..ec3ad9be 100644 --- a/src/Bonsai.ML.Lds.Torch/KalmanFilter.cs +++ b/src/Bonsai.ML.Lds.Torch/KalmanFilter.cs @@ -4,7 +4,10 @@ namespace Bonsai.ML.Lds.Torch; -internal class KalmanFilter : nn.Module +// disable missing XML comment warnings +# pragma warning disable CS1591 + +public class KalmanFilter : nn.Module { private readonly Tensor _transitionMatrix; private readonly Tensor _measurementFunction; @@ -57,7 +60,7 @@ public KalmanFilter( _initialCovariance = parameters.InitialCovariance?.clone().to_type(_scalarType).to(_device).requires_grad_(false) ?? eye(_numStates, dtype: _scalarType, device: _device).requires_grad_(false); ValidateMatrix(_initialCovariance, "Initial covariance", isSquare: true, expectedDimension1: _numStates); - + _processNoiseCovariance = parameters.ProcessNoiseCovariance ?? CreateCovarianceMatrix(tensor(1.0), _scalarType, _device, _numStates, "Process noise variance"); _measurementNoiseCovariance = parameters.MeasurementNoiseCovariance ?? CreateCovarianceMatrix(tensor(1.0), _scalarType, _device, _numObservations, "Measurement noise variance"); @@ -181,7 +184,7 @@ private static void ValidateMatrix(Tensor matrix, string name, bool isSquare = f { if (matrix.NumberOfElements == 0) throw new ArgumentException($"{name} must be a non-empty matrix."); - + if (matrix.Dimensions != 2) throw new ArgumentException($"{name} must be 2-dimensional."); @@ -754,8 +757,137 @@ public ExpectationMaximizationResult ExpectationMaximization( return new ExpectationMaximizationResult(logLikelihood, updatedParameters); } + public static ExpectationMaximizationResult ExpectationMaximization( + Tensor observation, + int numStates, + int numObservations, + KalmanFilterParameters parameters, + int maxIterations = 100, + double tolerance = 1e-4, + ParametersToEstimate parametersToEstimate = new(), + Device device = null, + ScalarType scalarType = ScalarType.Float32) + { + device ??= CPU; + + var timeBins = observation.size(0); + var logLikelihood = empty(maxIterations, dtype: ScalarType.Float32, device: device); + var previousLogLikelihood = double.NegativeInfinity; + var logLikelihoodConst = -0.5 * timeBins * numObservations * Math.Log(2.0 * Math.PI); + + var transitionMatrix = parameters.TransitionMatrix; + var measurementFunction = parameters.MeasurementFunction; + var processNoiseCovariance = parameters.ProcessNoiseCovariance; + var measurementNoiseCovariance = parameters.MeasurementNoiseCovariance; + var initialMean = parameters.InitialMean; + var initialCovariance = parameters.InitialCovariance; + + var identityStates = eye(numStates, dtype: scalarType, device: device); + + // Precompute constant observation terms reused across EM iterations + var observationT = observation.mT; + var autoCorrelationObservations = observationT.matmul(observation); + + using (var _ = no_grad()) + { + for (int iteration = 0; iteration < maxIterations; iteration++) + { + // Filter observations + var filteredState = Filter( + observation: observation, + timeBins: timeBins, + numStates: numStates, + numObservations: numObservations, + transitionMatrix: transitionMatrix, + measurementFunction: measurementFunction, + processNoiseCovariance: processNoiseCovariance, + measurementNoiseCovariance: measurementNoiseCovariance, + initialMean: initialMean, + initialCovariance: initialCovariance, + scalarType: scalarType, + device: device); + + // Compute log likelihood (avoid creating intermediate tensors) + var llSumDouble = filteredState.LogLikelihood.sum() + .to_type(ScalarType.Float64).item(); + var filteredLogLikelihoodSum = logLikelihoodConst + 0.5 * llSumDouble; + + logLikelihood[iteration] = filteredLogLikelihoodSum; + + // Check for convergence + if (filteredLogLikelihoodSum <= previousLogLikelihood) + { + Console.WriteLine($"Warning: Log likelihood decreased! New: {filteredLogLikelihoodSum}, Previous: {previousLogLikelihood}"); + break; + } + + if (filteredLogLikelihoodSum - previousLogLikelihood < tolerance) + break; + + previousLogLikelihood = filteredLogLikelihoodSum; + + // Smooth the filtered results + var smoothedState = Smooth( + filteredState: filteredState, + timeBins: timeBins, + numStates: numStates, + transitionMatrix: transitionMatrix, + measurementFunction: measurementFunction, + initialMean: initialMean, + initialCovariance: initialCovariance, + identityStates: identityStates, + scalarType: scalarType, + device: device); + + // Sufficient statistics + var S00 = smoothedState.S00.sum([0]); + var S11 = smoothedState.S11.sum([0]); + var S10 = smoothedState.S10.sum([0]); + + // Replace einsum with faster matmul + var crossCorrelationObservations = observationT.matmul(smoothedState.SmoothedMean); + + // Update parameters + if (parametersToEstimate.TransitionMatrix) + transitionMatrix = InverseCholesky(S10, S00); + + if (parametersToEstimate.MeasurementFunction) + measurementFunction = InverseCholesky(crossCorrelationObservations, S11); + + if (parametersToEstimate.ProcessNoiseCovariance) + processNoiseCovariance = WrappedTensorDisposeScope(() => + EnsureSymmetric((S11 - transitionMatrix.matmul(S10.mT)) / timeBins)); + + var explainedObservationCovariance = measurementFunction.matmul(crossCorrelationObservations.mT); + + if (parametersToEstimate.MeasurementNoiseCovariance) + measurementNoiseCovariance = WrappedTensorDisposeScope(() => + EnsureSymmetric((autoCorrelationObservations - explainedObservationCovariance - explainedObservationCovariance.mT + + measurementFunction.matmul(S11).matmul(measurementFunction.mT)) / timeBins)); + + if (parametersToEstimate.InitialMean) + initialMean = smoothedState.SmoothedInitialMean; + + if (parametersToEstimate.InitialCovariance) + initialCovariance = smoothedState.SmoothedInitialCovariance; + } + } + + var updatedParameters = new KalmanFilterParameters( + transitionMatrix: transitionMatrix, + measurementFunction: measurementFunction, + processNoiseCovariance: processNoiseCovariance, + measurementNoiseCovariance: measurementNoiseCovariance, + initialMean: initialMean, + initialCovariance: initialCovariance + ); + + return new ExpectationMaximizationResult(logLikelihood, updatedParameters); + } + public static StochasticSubspaceIdentificationResult StochasticSubspaceIdentification( Tensor observations, + int? targetNumStates = null, int maxLag = 20, double threshold = 0.01, ParametersToEstimate parametersToEstimate = new()) @@ -788,7 +920,8 @@ public static StochasticSubspaceIdentificationResult StochasticSubspaceIdentific // Compute the effective rank var effectiveRank = (S > (threshold * S[0])).to_type(ScalarType.Int64).sum().item(); - var effectiveStates = Math.Max(effectiveRank, 1); + + var effectiveStates = Math.Min(effectiveRank, targetNumStates ?? effectiveRank); var Ur = U[TensorIndex.Colon, TensorIndex.Slice(0, effectiveStates)]; var SrSqrt = S[TensorIndex.Slice(0, effectiveStates)].diag().sqrt(); @@ -806,7 +939,7 @@ public static StochasticSubspaceIdentificationResult StochasticSubspaceIdentific // Estimate transition matrix using shifted states var statesShifted = states[TensorIndex.Colon, TensorIndex.Slice(0, numCols - 1)]; var statesNext = states[TensorIndex.Colon, TensorIndex.Slice(1, numCols)]; - + var transitionMatrix = WrappedTensorDisposeScope(() => InverseCholesky( statesNext.matmul(statesShifted.mT), statesShifted.matmul(statesShifted.mT))); diff --git a/src/Bonsai.ML.Lds.Torch/KalmanFilterModelManager.cs b/src/Bonsai.ML.Lds.Torch/KalmanFilterModelManager.cs deleted file mode 100644 index 1196aa83..00000000 --- a/src/Bonsai.ML.Lds.Torch/KalmanFilterModelManager.cs +++ /dev/null @@ -1,106 +0,0 @@ -using System; -using System.Reactive.Disposables; -using System.Threading; -using System.Runtime.CompilerServices; -using System.Collections.Generic; -using static TorchSharp.torch; -using Bonsai.ML.Lds.Torch; - -// -// Manages instances of the Kalman Filter model with a thread-safe locking mechanism for reading state tensors and writing parameters. -// -internal sealed class KalmanFilterModelManager -{ - private static readonly Dictionary _models = new(); - - public static KalmanFilter GetKalmanFilter(string name) - { - return _models.TryGetValue(name, out var model) ? model : throw new InvalidOperationException($"Kalman filter with name {name} not found."); - } - - internal static KalmanFilterDisposable Reserve( - string name, - int numStates, - int numObservations, - Tensor transitionMatrix, - Tensor measurementFunction, - Tensor processNoiseVariance, - Tensor measurementNoiseVariance, - Tensor initialMean, - Tensor initialCovariance, - Device? device = null, - ScalarType? scalarType = null - ) - { - if (_models.ContainsKey(name)) - { - throw new InvalidOperationException($"A Kalman filter with name {name} already exists."); - } - - var kalmanFilter = new KalmanFilter( - numStates: numStates, - numObservations: numObservations, - transitionMatrix: transitionMatrix, - measurementFunction: measurementFunction, - processNoiseVariance: processNoiseVariance, - measurementNoiseVariance: measurementNoiseVariance, - initialMean: initialMean, - initialCovariance: initialCovariance, - device: device, - scalarType: scalarType ?? ScalarType.Float32 - ); - - _models.Add(name, kalmanFilter); - - return new KalmanFilterDisposable(kalmanFilter, Disposable.Create(() => - { - _models.Remove(name); - kalmanFilter.Dispose(); - })); - } - - internal static KalmanFilterDisposable Reserve( - string name, - KalmanFilterParameters parameters, - Device? device = null, - ScalarType? scalarType = null - ) - { - if (_models.ContainsKey(name)) - { - throw new InvalidOperationException($"A Kalman filter with name {name} already exists."); - } - - var kalmanFilter = new KalmanFilter( - parameters: parameters, - device: device, - scalarType: scalarType ?? ScalarType.Float32 - ); - - _models.Add(name, kalmanFilter); - - return new KalmanFilterDisposable(kalmanFilter, Disposable.Create(() => - { - _models.Remove(name); - kalmanFilter.Dispose(); - })); - } - - internal sealed class KalmanFilterDisposable(KalmanFilter model, IDisposable disposable) : IDisposable - { - private IDisposable? resource = disposable ?? throw new ArgumentNullException(nameof(disposable)); - - public bool IsDisposed => resource is null; - - private readonly KalmanFilter model = model ?? throw new ArgumentNullException(nameof(model)); - - public KalmanFilter Model => model; - - public void Dispose() - { - var disposable = Interlocked.Exchange(ref resource, null); - disposable?.Dispose(); - } - } -} - diff --git a/src/Bonsai.ML.Lds.Torch/KalmanFilterNameConverter.cs b/src/Bonsai.ML.Lds.Torch/KalmanFilterNameConverter.cs deleted file mode 100644 index cfb881ac..00000000 --- a/src/Bonsai.ML.Lds.Torch/KalmanFilterNameConverter.cs +++ /dev/null @@ -1,44 +0,0 @@ -using Bonsai.Expressions; -using System.Linq; -using System.ComponentModel; - -namespace Bonsai.ML.Lds.Torch; - -/// -/// Provides a type converter to select the name of an existing Kalman filter model in the workflow. -/// -public class KalmanFilterNameConverter : StringConverter -{ - /// - public override bool GetStandardValuesSupported(ITypeDescriptorContext context) - { - return true; - } - - /// - public override StandardValuesCollection GetStandardValues(ITypeDescriptorContext context) - { - if (context != null) - { - var workflowBuilder = (WorkflowBuilder)context.GetService(typeof(WorkflowBuilder)); - if (workflowBuilder != null) - { - var models = (from builder in workflowBuilder.Workflow.Descendants() - where builder.GetType() != typeof(DisableBuilder) - let managedModelNode = ExpressionBuilder.GetWorkflowElement(builder) - where managedModelNode != null && managedModelNode is CreateKalmanFilter - let createKalmanFilter = (CreateKalmanFilter)managedModelNode - where createKalmanFilter != null && !string.IsNullOrEmpty(createKalmanFilter.Name) - select createKalmanFilter.Name) - .Distinct() - .ToList(); - if (models.Count > 0) - { - return new StandardValuesCollection(models); - } - } - } - - return new StandardValuesCollection(new string[] { }); - } -} \ No newline at end of file diff --git a/src/Bonsai.ML.Lds.Torch/Orthogonalize.cs b/src/Bonsai.ML.Lds.Torch/Orthogonalize.cs index 6fa67e53..68019d3e 100644 --- a/src/Bonsai.ML.Lds.Torch/Orthogonalize.cs +++ b/src/Bonsai.ML.Lds.Torch/Orthogonalize.cs @@ -1,6 +1,7 @@ using System; using System.ComponentModel; using System.Reactive.Linq; +using System.Xml.Serialization; using static TorchSharp.torch; namespace Bonsai.ML.Lds.Torch; @@ -9,16 +10,17 @@ namespace Bonsai.ML.Lds.Torch; /// Orthogonalizes the state and covariance estimates from a Kalman filter or smoother. /// [Combinator] +[ResetCombinator] [Description("Orthogonalizes the state and covariance estimates from a Kalman filter or smoother.")] [WorkflowElementCategory(ElementCategory.Transform)] public class Orthogonalize { /// - /// The name of the Kalman filter model to be used. + /// The Kalman filter model. /// - [TypeConverter(typeof(KalmanFilterNameConverter))] - [Description("The name of the Kalman filter model to be used.")] - public string ModelName { get; set; } = "KalmanFilter"; + [Description("The Kalman filter model.")] + [XmlIgnore] + public KalmanFilter Model { get; set; } /// /// Processes an observable sequence of smoothed results, orthogonalizing the mean and covariance estimates. @@ -29,10 +31,9 @@ public IObservable Process(IObservable sourc { return source.Select(input => { - var kalmanFilter = KalmanFilterModelManager.GetKalmanFilter(ModelName); var smoothedMean = input.SmoothedMean; var smoothedCovariance = input.SmoothedCovariance; - return kalmanFilter.OrthogonalizeMeanAndCovariance(smoothedMean, smoothedCovariance); + return Model.OrthogonalizeMeanAndCovariance(smoothedMean, smoothedCovariance); }); } @@ -45,10 +46,9 @@ public IObservable Process(IObservable sourc { return source.Select(input => { - var kalmanFilter = KalmanFilterModelManager.GetKalmanFilter(ModelName); var filteredMean = input.UpdatedMean; var filteredCovariance = input.UpdatedCovariance; - return kalmanFilter.OrthogonalizeMeanAndCovariance(filteredMean, filteredCovariance); + return Model.OrthogonalizeMeanAndCovariance(filteredMean, filteredCovariance); }); } @@ -61,10 +61,9 @@ public IObservable Process(IObservable source) { return source.Select(input => { - var kalmanFilter = KalmanFilterModelManager.GetKalmanFilter(ModelName); var mean = input.Mean; var covariance = input.Covariance; - return kalmanFilter.OrthogonalizeMeanAndCovariance(mean, covariance); + return Model.OrthogonalizeMeanAndCovariance(mean, covariance); }); } @@ -77,10 +76,9 @@ public IObservable Process(IObservable source) { return source.Select(input => { - var kalmanFilter = KalmanFilterModelManager.GetKalmanFilter(ModelName); var mean = input.Mean; var covariance = input.Covariance; - return kalmanFilter.OrthogonalizeMeanAndCovariance(mean, covariance); + return Model.OrthogonalizeMeanAndCovariance(mean, covariance); }); } @@ -93,10 +91,9 @@ public IObservable Process(IObservable { - var kalmanFilter = KalmanFilterModelManager.GetKalmanFilter(ModelName); var mean = input.Item1; var covariance = input.Item2; - return kalmanFilter.OrthogonalizeMeanAndCovariance(mean, covariance); + return Model.OrthogonalizeMeanAndCovariance(mean, covariance); }); } } \ No newline at end of file diff --git a/src/Bonsai.ML.Lds.Torch/SaveKalmanFilterParameters.cs b/src/Bonsai.ML.Lds.Torch/SaveKalmanFilterParameters.cs index 198219f5..af8a2182 100644 --- a/src/Bonsai.ML.Lds.Torch/SaveKalmanFilterParameters.cs +++ b/src/Bonsai.ML.Lds.Torch/SaveKalmanFilterParameters.cs @@ -66,13 +66,6 @@ public class SaveKalmanFilterParameters [Editor("Bonsai.Design.SaveFileNameEditor, Bonsai.Design", DesignTypes.UITypeEditor)] public string InitialCovarianceFilePath { get; set; } = "initial_covariance.bin"; - /// - /// The name of the Kalman filter model to be used. - /// - [TypeConverter(typeof(KalmanFilterNameConverter))] - [Description("The name of the Kalman filter model to be used.")] - public string ModelName { get; set; } = "KalmanFilter"; - private void SaveTensorToFile(Tensor tensor, string filePath) { if (filePath != null) @@ -84,12 +77,11 @@ private void SaveTensorToFile(Tensor tensor, string filePath) /// /// Creates parameters for a Kalman filter model using the properties of this class. /// - public IObservable Process(IObservable source) + public IObservable Process(IObservable source) { - return source.Do(input => + return source.Do(model => { - var kalmanFilter = KalmanFilterModelManager.GetKalmanFilter(ModelName); - var parameters = kalmanFilter.Parameters; + var parameters = model.Parameters; SaveTensorToFile(parameters.TransitionMatrix, TransitionMatrixFilePath); SaveTensorToFile(parameters.MeasurementFunction, MeasurementFunctionFilePath); SaveTensorToFile(parameters.ProcessNoiseCovariance, ProcessNoiseCovarianceFilePath); diff --git a/src/Bonsai.ML.Lds.Torch/Smooth.cs b/src/Bonsai.ML.Lds.Torch/Smooth.cs index c37de3e6..10abd869 100644 --- a/src/Bonsai.ML.Lds.Torch/Smooth.cs +++ b/src/Bonsai.ML.Lds.Torch/Smooth.cs @@ -1,6 +1,7 @@ using System; using System.ComponentModel; using System.Reactive.Linq; +using System.Xml.Serialization; using static TorchSharp.torch; namespace Bonsai.ML.Lds.Torch; @@ -9,16 +10,17 @@ namespace Bonsai.ML.Lds.Torch; /// Applies a Kalman smoother to the input filtered result sequence. /// [Combinator] +[ResetCombinator] [Description("Applies a Kalman smoother to the input filtered result sequence.")] [WorkflowElementCategory(ElementCategory.Transform)] public class Smooth { /// - /// The name of the Kalman filter model to be used. + /// The Kalman filter model. /// - [TypeConverter(typeof(KalmanFilterNameConverter))] - [Description("The name of the Kalman filter model to be used.")] - public string ModelName { get; set; } = "KalmanFilter"; + [Description("The Kalman filter model.")] + [XmlIgnore] + public KalmanFilter Model { get; set; } /// /// Processes an observable sequence of filtered results, applying the Kalman smoother to each result. @@ -27,11 +29,7 @@ public class Smooth /// public IObservable Process(IObservable source) { - return source.Select((input) => - { - var kalmanFilter = KalmanFilterModelManager.GetKalmanFilter(ModelName); - return kalmanFilter.Smooth(input); - }); + return source.Select(Model.Smooth); } /// @@ -43,9 +41,8 @@ public IObservable Process(IObservable { - var kalmanFilter = KalmanFilterModelManager.GetKalmanFilter(ModelName); var filteredState = new FilteredState(input.Item1, input.Item2, input.Item3, input.Item4); - return kalmanFilter.Smooth(filteredState); + return Model.Smooth(filteredState); }); } } \ No newline at end of file diff --git a/src/Bonsai.ML.Lds.Torch/StochasticSubspaceIdentification.cs b/src/Bonsai.ML.Lds.Torch/StochasticSubspaceIdentification.cs index 7092442d..01d0c8e8 100644 --- a/src/Bonsai.ML.Lds.Torch/StochasticSubspaceIdentification.cs +++ b/src/Bonsai.ML.Lds.Torch/StochasticSubspaceIdentification.cs @@ -16,6 +16,17 @@ namespace Bonsai.ML.Lds.Torch; [WorkflowElementCategory(ElementCategory.Combinator)] public class StochasticSubspaceIdentification { + private int? _targetNumStates = 2; + /// + /// The target number of states in the Kalman filter model. + /// + [Description("The target number of states in the Kalman filter model.")] + public int? TargetNumStates + { + get => _targetNumStates; + set => _targetNumStates = value > 0 ? value : throw new ArgumentOutOfRangeException(nameof(value), "Number of states must be greater than zero."); + } + private int _maxLag = 20; /// /// The maximum lag to consider for the subspace identification. @@ -95,6 +106,7 @@ public IObservable Process(IObservable - /// The name of the Kalman filter model to be used. + /// The Kalman filter model. /// - [TypeConverter(typeof(KalmanFilterNameConverter))] - [Description("The name of the Kalman filter model to be used.")] - public string ModelName { get; set; } = "KalmanFilter"; + [XmlIgnore] + [Description("The Kalman filter model.")] + public KalmanFilter Model { get; set; } /// /// Updates the parameters of a Kalman filter model using elements from the input sequence. @@ -27,10 +28,6 @@ public class UpdateParameters /// public IObservable Process(IObservable source) { - return source.Do((input) => - { - var kalmanFilter = KalmanFilterModelManager.GetKalmanFilter(ModelName); - kalmanFilter.UpdateParameters(input); - }); + return source.Do(Model.UpdateParameters); } } \ No newline at end of file diff --git a/tests/Bonsai.ML.Lds.Torch.Tests/NeuralLatentsTest.cs b/tests/Bonsai.ML.Lds.Torch.Tests/NeuralLatentsTest.cs index 132564a9..9a3a86c3 100644 --- a/tests/Bonsai.ML.Lds.Torch.Tests/NeuralLatentsTest.cs +++ b/tests/Bonsai.ML.Lds.Torch.Tests/NeuralLatentsTest.cs @@ -4,6 +4,7 @@ using System.IO; using System.IO.Compression; using System.Net.Http; +using System.Reactive.Linq; using System.Runtime.InteropServices; using System.Threading.Tasks; using Bonsai.ML.Tests.Utilities; @@ -162,4 +163,27 @@ public void CompareTensorData() Assert.IsTrue(allclose(bonsaiMeans, pythonMeans)); Assert.IsTrue(allclose(bonsaiCovariances, pythonCovariances)); } + + [TestMethod] + public void TestSubspaceIdentification() + { + var observationsFileName = Path.Combine(basePath, "transformed_binned_spikes.pt"); + var observations = Tensor.Load(observationsFileName).T; + var stochasticSubspaceIdentification = new StochasticSubspaceIdentification + { + MaxLag = 20, + Threshold = 0.01, + EstimateTransitionMatrix = true, + EstimateMeasurementFunction = true, + EstimateProcessNoiseCovariance = true, + EstimateMeasurementNoiseCovariance = true, + EstimateInitialMean = true, + EstimateInitialCovariance = true + }; + + StochasticSubspaceIdentificationResult? result = null; + var subscription = stochasticSubspaceIdentification.Process(Observable.Return(observations)).Subscribe(r => result = r); + + Console.WriteLine($"Estimated effective states: {result?.EffectiveStates}"); + } } From c847d15c653254de504d21fa868a250a863b988d Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 17 Oct 2025 11:54:45 +0100 Subject: [PATCH 59/92] Ensure data are centered in SSID method --- src/Bonsai.ML.Lds.Torch/KalmanFilter.cs | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/Bonsai.ML.Lds.Torch/KalmanFilter.cs b/src/Bonsai.ML.Lds.Torch/KalmanFilter.cs index ec3ad9be..e9aecefa 100644 --- a/src/Bonsai.ML.Lds.Torch/KalmanFilter.cs +++ b/src/Bonsai.ML.Lds.Torch/KalmanFilter.cs @@ -676,7 +676,7 @@ public ExpectationMaximizationResult ExpectationMaximization( scalarType: _scalarType, device: _device); - // Compute log likelihood (avoid creating intermediate tensors) + // Compute log likelihood var llSumDouble = filteredState.LogLikelihood.sum() .to_type(ScalarType.Float64).item(); var filteredLogLikelihoodSum = logLikelihoodConst + 0.5 * llSumDouble; @@ -713,7 +713,7 @@ public ExpectationMaximizationResult ExpectationMaximization( var S11 = smoothedState.S11.sum([0]); var S10 = smoothedState.S10.sum([0]); - // Replace einsum with faster matmul + // Compute cross-correlation between observations and smoothed states var crossCorrelationObservations = observationT.matmul(smoothedState.SmoothedMean); // Update parameters @@ -896,6 +896,7 @@ public static StochasticSubspaceIdentificationResult StochasticSubspaceIdentific var timeBins = observations.size(0); var numObs = observations.size(1); + var centered = observations - observations.mean([0], keepdim: true); // Build Hankel matrices from observations var numCols = (int)(timeBins - 2 * maxLag + 1); @@ -903,11 +904,11 @@ public static StochasticSubspaceIdentificationResult StochasticSubspaceIdentific if (numCols <= 0) throw new ArgumentException($"Number of time bins ({timeBins}) must be greater than 2*maxLag ({2 * maxLag}) for subspace identification."); - var stride = observations.stride(); - var pastView = observations.as_strided([maxLag, numCols, numObs], [stride[0], stride[0], stride[1]]); + var stride = centered.stride(); + var pastView = centered.as_strided([maxLag, numCols, numObs], [stride[0], stride[0], stride[1]]); var past = pastView.permute(0, 2, 1).reshape(maxLag * numObs, numCols); - var futureView = observations.narrow(0, maxLag, timeBins - maxLag) + var futureView = centered.narrow(0, maxLag, timeBins - maxLag) .as_strided([maxLag, numCols, numObs], [stride[0], stride[0], stride[1]]); var future = futureView.permute(0, 2, 1).reshape(maxLag * numObs, numCols); @@ -920,8 +921,7 @@ public static StochasticSubspaceIdentificationResult StochasticSubspaceIdentific // Compute the effective rank var effectiveRank = (S > (threshold * S[0])).to_type(ScalarType.Int64).sum().item(); - - var effectiveStates = Math.Min(effectiveRank, targetNumStates ?? effectiveRank); + var effectiveStates = Math.Max(Math.Min(effectiveRank, targetNumStates ?? effectiveRank), 1); var Ur = U[TensorIndex.Colon, TensorIndex.Slice(0, effectiveStates)]; var SrSqrt = S[TensorIndex.Slice(0, effectiveStates)].diag().sqrt(); @@ -946,13 +946,13 @@ public static StochasticSubspaceIdentificationResult StochasticSubspaceIdentific // Estimate noise covariances using residuals var stateResiduals = statesNext - transitionMatrix.matmul(statesShifted); - var processNoiseCovariance = WrappedTensorDisposeScope(() => stateResiduals.matmul(stateResiduals.mT) / (numCols - 1)); + var processNoiseCovariance = WrappedTensorDisposeScope(() => EnsureSymmetric(stateResiduals.matmul(stateResiduals.mT) / (numCols - 1))); // Compute the observation residuals var observationPredictions = measurementFunction.matmul(states); - var observationWindow = observations[TensorIndex.Slice(maxLag, maxLag + numCols)].mT; + var observationWindow = centered[TensorIndex.Slice(maxLag, maxLag + numCols)].mT; var observationResiduals = observationWindow - observationPredictions; - var measurementNoiseCovariance = WrappedTensorDisposeScope(() => observationResiduals.matmul(observationResiduals.mT) / numCols); + var measurementNoiseCovariance = WrappedTensorDisposeScope(() => EnsureSymmetric(observationResiduals.matmul(observationResiduals.mT) / numCols)); // Initial state estimates var initialMean = states[TensorIndex.Colon, 0]; From 3f55f63dfc621c056e0e4b0da898e593c3ad8055 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 21 Oct 2025 12:01:04 +0100 Subject: [PATCH 60/92] Updated test workflow after changing packge to use explicit model property --- .../NeuralLatentsTest.bonsai | 97 ++++++++++++++----- .../NeuralLatentsTest.cs | 23 ----- 2 files changed, 72 insertions(+), 48 deletions(-) diff --git a/tests/Bonsai.ML.Lds.Torch.Tests/NeuralLatentsTest.bonsai b/tests/Bonsai.ML.Lds.Torch.Tests/NeuralLatentsTest.bonsai index 4711470a..eb0bf2fd 100644 --- a/tests/Bonsai.ML.Lds.Torch.Tests/NeuralLatentsTest.bonsai +++ b/tests/Bonsai.ML.Lds.Torch.Tests/NeuralLatentsTest.bonsai @@ -14,6 +14,7 @@ transformed_binned_spikes.pt + true @@ -34,6 +35,7 @@ python_V0_0.pt + true @@ -46,6 +48,7 @@ python_m0_0.pt + true @@ -58,6 +61,7 @@ python_Z0.pt + true @@ -70,6 +74,7 @@ python_R0.pt + true @@ -82,6 +87,7 @@ python_B0.pt + true @@ -94,6 +100,7 @@ python_Q0.pt + true @@ -171,10 +178,9 @@ - KalmanFilter + Float64 2 2 - Float64 @@ -210,9 +216,21 @@ ObservationT + + KalmanFilterModel + + + Parameters + + + + + + - KalmanFilter + 2 + 10 1 0.1 true @@ -235,10 +253,13 @@ - + - + + + + @@ -249,10 +270,16 @@ ObservationT + + KalmanFilterModel + + + + + + - - KalmanFilter - + UpdatedFilteredResult @@ -266,10 +293,16 @@ UpdatedFilteredResult + + KalmanFilterModel + + + + + + - - KalmanFilter - + UpdatedSmoothedResult @@ -277,10 +310,16 @@ UpdatedSmoothedResult + + KalmanFilterModel + + + + + + - - KalmanFilter - + OrthogonalizedResult @@ -294,6 +333,7 @@ bonsai_means.pt + true @@ -305,6 +345,7 @@ bonsai_covs.pt + true @@ -313,21 +354,27 @@ - + - - - - + + + + + - - - - - + + + + + - + + + + + + diff --git a/tests/Bonsai.ML.Lds.Torch.Tests/NeuralLatentsTest.cs b/tests/Bonsai.ML.Lds.Torch.Tests/NeuralLatentsTest.cs index 9a3a86c3..35fe6228 100644 --- a/tests/Bonsai.ML.Lds.Torch.Tests/NeuralLatentsTest.cs +++ b/tests/Bonsai.ML.Lds.Torch.Tests/NeuralLatentsTest.cs @@ -163,27 +163,4 @@ public void CompareTensorData() Assert.IsTrue(allclose(bonsaiMeans, pythonMeans)); Assert.IsTrue(allclose(bonsaiCovariances, pythonCovariances)); } - - [TestMethod] - public void TestSubspaceIdentification() - { - var observationsFileName = Path.Combine(basePath, "transformed_binned_spikes.pt"); - var observations = Tensor.Load(observationsFileName).T; - var stochasticSubspaceIdentification = new StochasticSubspaceIdentification - { - MaxLag = 20, - Threshold = 0.01, - EstimateTransitionMatrix = true, - EstimateMeasurementFunction = true, - EstimateProcessNoiseCovariance = true, - EstimateMeasurementNoiseCovariance = true, - EstimateInitialMean = true, - EstimateInitialCovariance = true - }; - - StochasticSubspaceIdentificationResult? result = null; - var subscription = stochasticSubspaceIdentification.Process(Observable.Return(observations)).Subscribe(r => result = r); - - Console.WriteLine($"Estimated effective states: {result?.EffectiveStates}"); - } } From 2f4c6fb7388616097f4222b41af8d3f95d99367d Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 21 Oct 2025 13:27:09 +0100 Subject: [PATCH 61/92] Refactored EM to avoid potential mismatch between numStates, numObservations and inferred states and observations from parameters --- src/Bonsai.ML.Lds.Torch/ExpectationMaximization.cs | 2 -- src/Bonsai.ML.Lds.Torch/KalmanFilter.cs | 5 +++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/Bonsai.ML.Lds.Torch/ExpectationMaximization.cs b/src/Bonsai.ML.Lds.Torch/ExpectationMaximization.cs index dd201fc4..99d4a5e1 100644 --- a/src/Bonsai.ML.Lds.Torch/ExpectationMaximization.cs +++ b/src/Bonsai.ML.Lds.Torch/ExpectationMaximization.cs @@ -160,8 +160,6 @@ public IObservable Process(IObservable so var result = KalmanFilter.ExpectationMaximization( observation: input, - numStates: _numStates, - numObservations: _numObservations, parameters: parameters, maxIterations: MaxIterations, tolerance: Tolerance, diff --git a/src/Bonsai.ML.Lds.Torch/KalmanFilter.cs b/src/Bonsai.ML.Lds.Torch/KalmanFilter.cs index e9aecefa..28a11f5d 100644 --- a/src/Bonsai.ML.Lds.Torch/KalmanFilter.cs +++ b/src/Bonsai.ML.Lds.Torch/KalmanFilter.cs @@ -759,8 +759,6 @@ public ExpectationMaximizationResult ExpectationMaximization( public static ExpectationMaximizationResult ExpectationMaximization( Tensor observation, - int numStates, - int numObservations, KalmanFilterParameters parameters, int maxIterations = 100, double tolerance = 1e-4, @@ -770,6 +768,9 @@ public static ExpectationMaximizationResult ExpectationMaximization( { device ??= CPU; + ValidateNumStates(parameters.TransitionMatrix, parameters.MeasurementFunction, parameters.InitialMean, parameters.InitialCovariance, parameters.ProcessNoiseCovariance, out var numStates); + ValidateNumObservations(parameters.MeasurementFunction, parameters.MeasurementNoiseCovariance, out var numObservations); + var timeBins = observation.size(0); var logLikelihood = empty(maxIterations, dtype: ScalarType.Float32, device: device); var previousLogLikelihood = double.NegativeInfinity; From 4123370f60f13ce141800b0a5ead8c4046e00f03 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 21 Oct 2025 13:27:52 +0100 Subject: [PATCH 62/92] Updated workflow to correctly update the parameters of the model --- .../NeuralLatentsTest.bonsai | 22 +++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/tests/Bonsai.ML.Lds.Torch.Tests/NeuralLatentsTest.bonsai b/tests/Bonsai.ML.Lds.Torch.Tests/NeuralLatentsTest.bonsai index eb0bf2fd..ab8f230e 100644 --- a/tests/Bonsai.ML.Lds.Torch.Tests/NeuralLatentsTest.bonsai +++ b/tests/Bonsai.ML.Lds.Torch.Tests/NeuralLatentsTest.bonsai @@ -242,6 +242,20 @@ true + + Parameters + + + KalmanFilterModel + + + + + + + + + ExpectationMaximizationResult @@ -258,8 +272,12 @@ - - + + + + + + From b026f286682c7aa220a39d1c87c8089de01ac082 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Thu, 23 Oct 2025 11:43:11 +0100 Subject: [PATCH 63/92] Modified python test script to install packages without caching packages locally to free up space on test runner --- .../bootstrap_test_environment.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/tests/Bonsai.ML.Lds.Torch.Tests/bootstrap_test_environment.py b/tests/Bonsai.ML.Lds.Torch.Tests/bootstrap_test_environment.py index 7340b4ea..08beff7c 100644 --- a/tests/Bonsai.ML.Lds.Torch.Tests/bootstrap_test_environment.py +++ b/tests/Bonsai.ML.Lds.Torch.Tests/bootstrap_test_environment.py @@ -52,12 +52,14 @@ def activate_venv(venv_path: str = None): os.environ["PATH"] = os.pathsep.join([bin_path, *os.environ.get("PATH", "").split(os.pathsep)]) sys.path.insert(0, os.path.join(venv_path, 'Lib', 'site-packages')) -def install(package: str, venv_path: str = None): +def install(venv_path: str = None, pip_args: list[str] = None): # function to install pip packages into a virtual environment if venv_path is None: venv_path = get_venv_path() pip_path = get_pip_path(venv_path) - subprocess.check_call([pip_path, "install", package]) + if pip_args is None: + raise ValueError("pip_args must be provided") + subprocess.check_call([pip_path, "install", *pip_args]) def install_requirements(requirements_file: str, venv_path: str = None): # function to install pip packages from a requirements file into a virtual environment @@ -74,11 +76,11 @@ def install_requirements(requirements_file: str, venv_path: str = None): venv_path = create_venv(base_dir) activate_venv(venv_path) -install("torch", venv_path) -install("plotly", venv_path) -install("remfile", venv_path) -install("dandi", venv_path) -install("ssm@git+https://github.com/ncguilbeault/lds_python@75e3e5e92ce6344009b62a5034db49b238db63ef", venv_path) +install(venv_path, ["--no-cache-dir", "torch"]) +install(venv_path, ["--no-cache-dir", "plotly"]) +install(venv_path, ["--no-cache-dir", "remfile"]) +install(venv_path, ["--no-cache-dir", "dandi"]) +install(venv_path, ["--no-cache-dir", "ssm@git+https://github.com/ncguilbeault/lds_python@75e3e5e92ce6344009b62a5034db49b238db63ef"]) python_path = get_python_path(venv_path) From 4cf1c760e01951612ed81b454347bcf6c1df7ca0 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Thu, 23 Oct 2025 18:24:59 +0100 Subject: [PATCH 64/92] Modified test case to download data and only run Bonsai script as opposed to running both the Python and Bonsai code to save memory --- .../Bonsai.ML.Lds.Torch.Tests.csproj | 2 +- .../NeuralLatentsTest.cs | 100 ++++-------------- 2 files changed, 22 insertions(+), 80 deletions(-) diff --git a/tests/Bonsai.ML.Lds.Torch.Tests/Bonsai.ML.Lds.Torch.Tests.csproj b/tests/Bonsai.ML.Lds.Torch.Tests/Bonsai.ML.Lds.Torch.Tests.csproj index 5c6cf23a..9893cf36 100644 --- a/tests/Bonsai.ML.Lds.Torch.Tests/Bonsai.ML.Lds.Torch.Tests.csproj +++ b/tests/Bonsai.ML.Lds.Torch.Tests/Bonsai.ML.Lds.Torch.Tests.csproj @@ -20,7 +20,7 @@ - + Always diff --git a/tests/Bonsai.ML.Lds.Torch.Tests/NeuralLatentsTest.cs b/tests/Bonsai.ML.Lds.Torch.Tests/NeuralLatentsTest.cs index 35fe6228..d997e3fb 100644 --- a/tests/Bonsai.ML.Lds.Torch.Tests/NeuralLatentsTest.cs +++ b/tests/Bonsai.ML.Lds.Torch.Tests/NeuralLatentsTest.cs @@ -21,71 +21,28 @@ public class NeuralLatentsTest { private readonly string basePath = Path.Combine(AppDomain.CurrentDomain.BaseDirectory); - private static void RunPythonScript(string basePath) + private static void DownloadData(string basePath) { - var pythonExec = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) - ? "python" - : "python3"; - var scriptPath = Path.Combine(basePath, "bootstrap_test_environment.py"); - ProcessHelper.RunProcess(pythonExec, $"\"{scriptPath}\" {basePath}"); + string zipFileUrl = "https://zenodo.org/records/17427805/files/Bonsai.ML.Lds.Torch.Tests.zip"; - Console.WriteLine("Run python script finished."); - } - - private static double[] ReadBinaryFile(string fileName) - { - Console.WriteLine($"Reading binary file: {fileName}"); - using var fileStream = new FileStream(fileName, FileMode.Open, FileAccess.Read); - using var binaryReader = new BinaryReader(fileStream); - var fileLength = fileStream.Length; - var numDoubles = fileLength / sizeof(double); - var data = new double[numDoubles]; - for (int i = 0; i < numDoubles; i++) + try { - data[i] = binaryReader.ReadDouble(); + byte[] responseBytes; + using (var httpClient = new HttpClient()) + { + responseBytes = httpClient.GetByteArrayAsync(zipFileUrl).Result; + Console.WriteLine("File downloaded successfully."); + } + + using MemoryStream zipStream = new(responseBytes); + using ZipArchive zip = new(zipStream, ZipArchiveMode.Read); + zip.ExtractToDirectory(basePath); + Console.WriteLine("File extracted successfully."); + } + catch (Exception ex) + { + Console.WriteLine($"An error occurred: {ex.Message}"); } - Console.WriteLine($"Read {numDoubles} doubles from {fileName}"); - return data; - } - - private static void WriteToTensor(string fileName, long[] shape) - { - Console.WriteLine($"Reading filename: {fileName} and creating tensor with shape [{string.Join(", ", shape)}]"); - var data = ReadBinaryFile(fileName); - var tensor = from_array(data).reshape(shape); - var outputFileName = Path.ChangeExtension(fileName, ".pt"); - tensor.Save(outputFileName); - Console.WriteLine($"Saved tensor to {outputFileName}"); - } - - private static void ConvertBinaryFiles(string basePath) - { - var transformedBinnedSpikesFileName = Path.Combine(basePath, "transformed_binned_spikes.bin"); - WriteToTensor(transformedBinnedSpikesFileName, [142, -1]); - - var transitionMatrixFileName = Path.Combine(basePath, "python_B0.bin"); - WriteToTensor(transitionMatrixFileName, [10, 10]); - - var measurementFunctionFileName = Path.Combine(basePath, "python_Z0.bin"); - WriteToTensor(measurementFunctionFileName, [142, 10]); - - var processNoiseFileName = Path.Combine(basePath, "python_Q0.bin"); - WriteToTensor(processNoiseFileName, [10, 10]); - - var observationNoiseFileName = Path.Combine(basePath, "python_R0.bin"); - WriteToTensor(observationNoiseFileName, [142, 142]); - - var initialStateFileName = Path.Combine(basePath, "python_m0_0.bin"); - WriteToTensor(initialStateFileName, [10]); - - var initialCovarianceFileName = Path.Combine(basePath, "python_V0_0.bin"); - WriteToTensor(initialCovarianceFileName, [10, 10]); - - var outputMeansFileName = Path.Combine(basePath, "python_means.bin"); - WriteToTensor(outputMeansFileName, [10, -1]); - - var outputCovariancesFileName = Path.Combine(basePath, "python_covs.bin"); - WriteToTensor(outputCovariancesFileName, [10, 10, -1]); } private static async Task RunBonsaiWorkflow(string basePath) @@ -108,14 +65,11 @@ await WorkflowHelper.RunWorkflow( /// Setup for the test. /// [TestInitialize] - [DeploymentItem("bootstrap_test_environment.py")] - [DeploymentItem("estimate_neural_latents.py")] [DeploymentItem("NeuralLatentsTest.bonsai")] public async Task TestSetup() { Directory.CreateDirectory(basePath); - RunPythonScript(basePath); - ConvertBinaryFiles(basePath); + DownloadData(basePath); await RunBonsaiWorkflow(basePath); } @@ -126,21 +80,9 @@ public async Task TestSetup() public void TestCleanup() { var ptFiles = Directory.GetFiles(basePath, "*.pt"); - var binFiles = Directory.GetFiles(basePath, "*.bin"); + var zipFiles = Directory.GetFiles(basePath, "*.zip"); foreach (var file in ptFiles) File.Delete(file); - foreach (var file in binFiles) File.Delete(file); - - var virtualEnvPath = Path.Combine(basePath, ".venv"); - if (Directory.Exists(virtualEnvPath)) - { - Directory.Delete(virtualEnvPath, true); - } - - var remfileCachePath = Path.Combine(basePath, "remfile_cache"); - if (Directory.Exists(remfileCachePath)) - { - Directory.Delete(remfileCachePath, true); - } + foreach (var file in zipFiles) File.Delete(file); } /// From 877d4a8b87be91f806e17995c255e7b7fbd8a5d9 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Thu, 23 Oct 2025 18:25:23 +0100 Subject: [PATCH 65/92] Removed redundant dependency --- tests/Bonsai.ML.Lds.Torch.Tests/estimate_neural_latents.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/Bonsai.ML.Lds.Torch.Tests/estimate_neural_latents.py b/tests/Bonsai.ML.Lds.Torch.Tests/estimate_neural_latents.py index 107dcb8f..9dc3666d 100644 --- a/tests/Bonsai.ML.Lds.Torch.Tests/estimate_neural_latents.py +++ b/tests/Bonsai.ML.Lds.Torch.Tests/estimate_neural_latents.py @@ -1,5 +1,4 @@ import numpy as np -import plotly.graph_objects as go import remfile, h5py from dandi.dandiapi import DandiAPIClient From 4ec5828b98a0e7ba4c176ab105821bbc41ce495f Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 28 Oct 2025 13:54:43 +0000 Subject: [PATCH 66/92] Updated documentation with correct package naming --- docs/README.md | 2 +- docs/articles/Lds.Torch/lds-torch-overview.md | 7 +++++++ docs/articles/Torch.LDS/torch-lds-overview.md | 7 ------- docs/articles/toc.yml | 4 ++-- 4 files changed, 10 insertions(+), 10 deletions(-) create mode 100644 docs/articles/Lds.Torch/lds-torch-overview.md delete mode 100644 docs/articles/Torch.LDS/torch-lds-overview.md diff --git a/docs/README.md b/docs/README.md index 419f9d6a..9dbc6ac9 100644 --- a/docs/README.md +++ b/docs/README.md @@ -51,7 +51,7 @@ Interfaces with the [TorchSharp](https://github.com/dotnet/TorchSharp) package, > [!NOTE] > Bonsai.ML packages can be installed through Bonsai's integrated package manager and are generally ready for immediate use. However, some packages may require additional installation steps. Refer to the specific package section for detailed installation guides and documentation. -### Bonsai.ML.Torch.LDS +### Bonsai.ML.Lds.Torch Linear dynamical systems implemented using the `Bonsai.ML.Torch` package for online filtering and smoothing of latent states and parameter estimation. ## Development Roadmap diff --git a/docs/articles/Lds.Torch/lds-torch-overview.md b/docs/articles/Lds.Torch/lds-torch-overview.md new file mode 100644 index 00000000..fc140755 --- /dev/null +++ b/docs/articles/Lds.Torch/lds-torch-overview.md @@ -0,0 +1,7 @@ +# Bonsai.ML.Lds.Torch - Overview + +This package provides an implementation of the Kalman filter, Rauch-Tung-Striebel (RTS) smoother, expectation maximization (EM) algorithm, and stochastic subspace identification, developed for online filtering, smoothing, and parameter estimation from data streams in Bonsai using the TorchSharp package. + +## Installation Guide + +Install the `Bonsai.ML.Lds.Torch` package from the Bonsai package manager. You will also need to follow the [instructions for setting up the Bonsai.ML.Torch package](../Torch/torch-overview.md) for running on the CPU or GPU. \ No newline at end of file diff --git a/docs/articles/Torch.LDS/torch-lds-overview.md b/docs/articles/Torch.LDS/torch-lds-overview.md deleted file mode 100644 index be2d8260..00000000 --- a/docs/articles/Torch.LDS/torch-lds-overview.md +++ /dev/null @@ -1,7 +0,0 @@ -# Bonsai.ML.Torch.LDS - Overview - -This package provides an implementation of the Kalman filter, Rauch-Tung-Striebel (RTS) smoother, and expectation maximization (EM) algorithm developed for online filtering, smoothing, and parameter estimation from data streams in Bonsai using the TorchSharp package. - -## Installation Guide - -Install the `Bonsai.ML.Torch.LDS` package from the Bonsai package manager. You will also need to follow the [instructions for setting up the Bonsai.ML.Torch package](../Torch/torch-overview.md) for running on the CPU or GPU. \ No newline at end of file diff --git a/docs/articles/toc.yml b/docs/articles/toc.yml index 7f42b6bd..59c5a874 100644 --- a/docs/articles/toc.yml +++ b/docs/articles/toc.yml @@ -20,5 +20,5 @@ - name: Torch - name: Overview href: Torch/torch-overview.md -- name: Torch.LDS - href: Torch.LDS/torch-lds-overview.md \ No newline at end of file +- name: Lds.Torch + href: Lds.Torch/lds-torch-overview.md \ No newline at end of file From a8f58c7f18abaea986cd02bfa95441e1a93860ac Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 28 Oct 2025 14:04:43 +0000 Subject: [PATCH 67/92] Changed struct and interface names from LdsState to full LinearDynamicalSystemState --- ...LdsState.cs => CreateLinearDynamicalSystemState.cs} | 10 +++++----- src/Bonsai.ML.Lds.Torch/FilteredState.cs | 2 +- .../{ILdsState.cs => ILinearDynamicalSystemState.cs} | 2 +- .../{LdsState.cs => LinearDynamicalSystemState.cs} | 2 +- 4 files changed, 8 insertions(+), 8 deletions(-) rename src/Bonsai.ML.Lds.Torch/{CreateLdsState.cs => CreateLinearDynamicalSystemState.cs} (91%) rename src/Bonsai.ML.Lds.Torch/{ILdsState.cs => ILinearDynamicalSystemState.cs} (88%) rename src/Bonsai.ML.Lds.Torch/{LdsState.cs => LinearDynamicalSystemState.cs} (77%) diff --git a/src/Bonsai.ML.Lds.Torch/CreateLdsState.cs b/src/Bonsai.ML.Lds.Torch/CreateLinearDynamicalSystemState.cs similarity index 91% rename from src/Bonsai.ML.Lds.Torch/CreateLdsState.cs rename to src/Bonsai.ML.Lds.Torch/CreateLinearDynamicalSystemState.cs index 3bfa5101..848e1b95 100644 --- a/src/Bonsai.ML.Lds.Torch/CreateLdsState.cs +++ b/src/Bonsai.ML.Lds.Torch/CreateLinearDynamicalSystemState.cs @@ -14,7 +14,7 @@ namespace Bonsai.ML.Lds.Torch; [ResetCombinator] [Description("Creates a new state for a linear gaussian dynamical system.")] [WorkflowElementCategory(ElementCategory.Source)] -public class CreateLdsState : IScalarTypeProvider +public class CreateLinearDynamicalSystemState : IScalarTypeProvider { private ScalarType _scalarType = ScalarType.Float32; /// @@ -95,14 +95,14 @@ public string CovarianceXml /// Creates an observable sequence and emits the state for a linear gaussian dynamical system. /// /// - public IObservable Process() + public IObservable Process() { return Observable.Defer(() => { var device = Device ?? CPU; var mean = Mean?.to(device) ?? throw new InvalidOperationException("The mean of the state must be specified."); var covariance = Covariance?.to(device) ?? throw new InvalidOperationException("The covariance of the state must be specified."); - return Observable.Return(new LdsState(mean, covariance)); + return Observable.Return(new LinearDynamicalSystemState(mean, covariance)); }); } @@ -113,14 +113,14 @@ public IObservable Process() /// /// /// - public IObservable Process(IObservable source) + public IObservable Process(IObservable source) { return source.Select(_ => { var device = Device ?? CPU; var mean = Mean?.to(device) ?? throw new InvalidOperationException("The mean of the state must be specified."); var covariance = Covariance?.to(device) ?? throw new InvalidOperationException("The covariance of the state must be specified."); - return new LdsState(mean, covariance); + return new LinearDynamicalSystemState(mean, covariance); }); } } \ No newline at end of file diff --git a/src/Bonsai.ML.Lds.Torch/FilteredState.cs b/src/Bonsai.ML.Lds.Torch/FilteredState.cs index c2ea7bc8..b2a88e44 100644 --- a/src/Bonsai.ML.Lds.Torch/FilteredState.cs +++ b/src/Bonsai.ML.Lds.Torch/FilteredState.cs @@ -13,7 +13,7 @@ public struct FilteredState( Tensor predictedMean, Tensor predictedCovariance, Tensor updatedMean, - Tensor updatedCovariance) : ILdsState + Tensor updatedCovariance) : ILinearDynamicalSystemState { /// /// The predicted mean after the prediction step. diff --git a/src/Bonsai.ML.Lds.Torch/ILdsState.cs b/src/Bonsai.ML.Lds.Torch/ILinearDynamicalSystemState.cs similarity index 88% rename from src/Bonsai.ML.Lds.Torch/ILdsState.cs rename to src/Bonsai.ML.Lds.Torch/ILinearDynamicalSystemState.cs index 6527cfe5..ee90a811 100644 --- a/src/Bonsai.ML.Lds.Torch/ILdsState.cs +++ b/src/Bonsai.ML.Lds.Torch/ILinearDynamicalSystemState.cs @@ -5,7 +5,7 @@ namespace Bonsai.ML.Lds.Torch; /// /// Represents the state of a linear gaussian dynamical system. /// -public interface ILdsState +public interface ILinearDynamicalSystemState { /// /// The mean of the state. diff --git a/src/Bonsai.ML.Lds.Torch/LdsState.cs b/src/Bonsai.ML.Lds.Torch/LinearDynamicalSystemState.cs similarity index 77% rename from src/Bonsai.ML.Lds.Torch/LdsState.cs rename to src/Bonsai.ML.Lds.Torch/LinearDynamicalSystemState.cs index 99dd4a08..f89c08ef 100644 --- a/src/Bonsai.ML.Lds.Torch/LdsState.cs +++ b/src/Bonsai.ML.Lds.Torch/LinearDynamicalSystemState.cs @@ -7,7 +7,7 @@ namespace Bonsai.ML.Lds.Torch; /// /// /// -public class LdsState(Tensor mean, Tensor covariance) : ILdsState +public class LinearDynamicalSystemState(Tensor mean, Tensor covariance) : ILinearDynamicalSystemState { /// public Tensor Mean => mean; From 4a9d64b7d79239d6d50fec295c92f0a2f7551604 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 28 Oct 2025 14:05:42 +0000 Subject: [PATCH 68/92] Removed redundant classes for orthogonalized and smoothed states in favor of using simpler LDS state --- src/Bonsai.ML.Lds.Torch/KalmanFilter.cs | 12 +++---- src/Bonsai.ML.Lds.Torch/Orthogonalize.cs | 35 ++++--------------- .../OrthogonalizedState.cs | 32 ----------------- src/Bonsai.ML.Lds.Torch/Smooth.cs | 4 +-- src/Bonsai.ML.Lds.Torch/SmoothedState.cs | 29 --------------- 5 files changed, 15 insertions(+), 97 deletions(-) delete mode 100644 src/Bonsai.ML.Lds.Torch/OrthogonalizedState.cs delete mode 100644 src/Bonsai.ML.Lds.Torch/SmoothedState.cs diff --git a/src/Bonsai.ML.Lds.Torch/KalmanFilter.cs b/src/Bonsai.ML.Lds.Torch/KalmanFilter.cs index 28a11f5d..db3b9ed3 100644 --- a/src/Bonsai.ML.Lds.Torch/KalmanFilter.cs +++ b/src/Bonsai.ML.Lds.Torch/KalmanFilter.cs @@ -452,7 +452,7 @@ private static FilteredStateWithAuxiliaryVariables Filter( ); } - public SmoothedState Smooth(FilteredState filteredState) + public LinearDynamicalSystemState Smooth(FilteredState filteredState) { using var g = no_grad(); @@ -492,7 +492,7 @@ public SmoothedState Smooth(FilteredState filteredState) ); } - return new SmoothedState( + return new LinearDynamicalSystemState( smoothedMean, smoothedCovariance ); @@ -976,7 +976,7 @@ public static StochasticSubspaceIdentificationResult StochasticSubspaceIdentific ); } - public OrthogonalizedState OrthogonalizeMeanAndCovariance(Tensor mean, Tensor covariance) + public LinearDynamicalSystemState OrthogonalizeMeanAndCovariance(Tensor mean, Tensor covariance) { var (_, S, Vt) = linalg.svd(_measurementFunction); var SVt = diag(S).matmul(Vt); @@ -992,9 +992,9 @@ public OrthogonalizedState OrthogonalizeMeanAndCovariance(Tensor mean, Tensor co orthogonalizedCovariance = matmul(auxilary, SVt.mT); } - return new OrthogonalizedState( - orthogonalizedMean: orthogonalizedMean, - orthogonalizedCovariance: orthogonalizedCovariance + return new LinearDynamicalSystemState( + orthogonalizedMean, + orthogonalizedCovariance ); } diff --git a/src/Bonsai.ML.Lds.Torch/Orthogonalize.cs b/src/Bonsai.ML.Lds.Torch/Orthogonalize.cs index 68019d3e..3afa276e 100644 --- a/src/Bonsai.ML.Lds.Torch/Orthogonalize.cs +++ b/src/Bonsai.ML.Lds.Torch/Orthogonalize.cs @@ -27,13 +27,11 @@ public class Orthogonalize /// /// /// - public IObservable Process(IObservable source) + public IObservable Process(IObservable source) { return source.Select(input => { - var smoothedMean = input.SmoothedMean; - var smoothedCovariance = input.SmoothedCovariance; - return Model.OrthogonalizeMeanAndCovariance(smoothedMean, smoothedCovariance); + return Model.OrthogonalizeMeanAndCovariance(input.Mean, input.Covariance); }); } @@ -42,13 +40,11 @@ public IObservable Process(IObservable sourc /// /// /// - public IObservable Process(IObservable source) + public IObservable Process(IObservable source) { return source.Select(input => { - var filteredMean = input.UpdatedMean; - var filteredCovariance = input.UpdatedCovariance; - return Model.OrthogonalizeMeanAndCovariance(filteredMean, filteredCovariance); + return Model.OrthogonalizeMeanAndCovariance(input.Mean, input.Covariance); }); } @@ -57,28 +53,11 @@ public IObservable Process(IObservable sourc /// /// /// - public IObservable Process(IObservable source) + public IObservable Process(IObservable source) { return source.Select(input => { - var mean = input.Mean; - var covariance = input.Covariance; - return Model.OrthogonalizeMeanAndCovariance(mean, covariance); - }); - } - - /// - /// Processes an observable sequence of LDS states, orthogonalizing the mean and covariance estimates. - /// - /// - /// - public IObservable Process(IObservable source) - { - return source.Select(input => - { - var mean = input.Mean; - var covariance = input.Covariance; - return Model.OrthogonalizeMeanAndCovariance(mean, covariance); + return Model.OrthogonalizeMeanAndCovariance(input.Mean, input.Covariance); }); } @@ -87,7 +66,7 @@ public IObservable Process(IObservable source) /// /// /// - public IObservable Process(IObservable> source) + public IObservable Process(IObservable> source) { return source.Select(input => { diff --git a/src/Bonsai.ML.Lds.Torch/OrthogonalizedState.cs b/src/Bonsai.ML.Lds.Torch/OrthogonalizedState.cs deleted file mode 100644 index f7dee6c1..00000000 --- a/src/Bonsai.ML.Lds.Torch/OrthogonalizedState.cs +++ /dev/null @@ -1,32 +0,0 @@ -using static TorchSharp.torch; - -namespace Bonsai.ML.Lds.Torch; - -/// -/// Represents the state of an LDS after orthogonalizing the state mean and covariance estimates. -/// -/// -/// Initializes a new instance of the struct. -/// -/// -/// -public struct OrthogonalizedState( - Tensor orthogonalizedMean, - Tensor orthogonalizedCovariance) : ILdsState -{ - /// - /// The orthogonalized mean estimate. - /// - public Tensor OrthogonalizedMean = orthogonalizedMean; - - /// - /// The orthogonalized covariance estimate. - /// - public Tensor OrthogonalizedCovariance = orthogonalizedCovariance; - - /// - public readonly Tensor Mean => OrthogonalizedMean; - - /// - public readonly Tensor Covariance => OrthogonalizedCovariance; -} \ No newline at end of file diff --git a/src/Bonsai.ML.Lds.Torch/Smooth.cs b/src/Bonsai.ML.Lds.Torch/Smooth.cs index 10abd869..a2780d0c 100644 --- a/src/Bonsai.ML.Lds.Torch/Smooth.cs +++ b/src/Bonsai.ML.Lds.Torch/Smooth.cs @@ -27,7 +27,7 @@ public class Smooth /// /// /// - public IObservable Process(IObservable source) + public IObservable Process(IObservable source) { return source.Select(Model.Smooth); } @@ -37,7 +37,7 @@ public IObservable Process(IObservable source) /// /// /// - public IObservable Process(IObservable> source) + public IObservable Process(IObservable> source) { return source.Select((input) => { diff --git a/src/Bonsai.ML.Lds.Torch/SmoothedState.cs b/src/Bonsai.ML.Lds.Torch/SmoothedState.cs deleted file mode 100644 index 70f5c24c..00000000 --- a/src/Bonsai.ML.Lds.Torch/SmoothedState.cs +++ /dev/null @@ -1,29 +0,0 @@ -using static TorchSharp.torch; - -namespace Bonsai.ML.Lds.Torch; - -/// -/// Represents the state of a Kalman smoother. -/// -/// -/// -public struct SmoothedState( - Tensor smoothedMean, - Tensor smoothedCovariance) : ILdsState -{ - /// - /// The smoothed state after the smoothing step. - /// - public Tensor SmoothedMean = smoothedMean; - - /// - /// The smoothed covariance after the smoothing step. - /// - public Tensor SmoothedCovariance = smoothedCovariance; - - /// - public readonly Tensor Mean => SmoothedMean; - - /// - public readonly Tensor Covariance => SmoothedCovariance; -} \ No newline at end of file From e3d51f62b461325122998f8875373bded4f1adaf Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 28 Oct 2025 14:06:34 +0000 Subject: [PATCH 69/92] Changed name from `UpdateParameters` to `UpdateKalmanFilterParameters` for improved clarity --- .../{UpdateParameters.cs => UpdateKalmanFilterParameters.cs} | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) rename src/Bonsai.ML.Lds.Torch/{UpdateParameters.cs => UpdateKalmanFilterParameters.cs} (96%) diff --git a/src/Bonsai.ML.Lds.Torch/UpdateParameters.cs b/src/Bonsai.ML.Lds.Torch/UpdateKalmanFilterParameters.cs similarity index 96% rename from src/Bonsai.ML.Lds.Torch/UpdateParameters.cs rename to src/Bonsai.ML.Lds.Torch/UpdateKalmanFilterParameters.cs index d165f860..edcf14c3 100644 --- a/src/Bonsai.ML.Lds.Torch/UpdateParameters.cs +++ b/src/Bonsai.ML.Lds.Torch/UpdateKalmanFilterParameters.cs @@ -12,7 +12,7 @@ namespace Bonsai.ML.Lds.Torch; [ResetCombinator] [Description("Updates the parameters of a Kalman filter model instance using the provided Kalman filter parameters.")] [WorkflowElementCategory(ElementCategory.Sink)] -public class UpdateParameters +public class UpdateKalmanFilterParameters { /// /// The Kalman filter model. From 12b9422d8b46db269af2ddcaf0bcc48583a9792f Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 28 Oct 2025 14:08:19 +0000 Subject: [PATCH 70/92] Updated `StateVisualizer` to match changes in naming --- src/Bonsai.ML.Lds.Torch.Design/StateVisualizer.cs | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/src/Bonsai.ML.Lds.Torch.Design/StateVisualizer.cs b/src/Bonsai.ML.Lds.Torch.Design/StateVisualizer.cs index 3f1837d9..d1aea0ca 100644 --- a/src/Bonsai.ML.Lds.Torch.Design/StateVisualizer.cs +++ b/src/Bonsai.ML.Lds.Torch.Design/StateVisualizer.cs @@ -16,11 +16,7 @@ [assembly: TypeVisualizer(typeof(Bonsai.ML.Lds.Torch.Design.StateVisualizer), Target = typeof(Bonsai.ML.Lds.Torch.FilteredState))] [assembly: TypeVisualizer(typeof(Bonsai.ML.Lds.Torch.Design.StateVisualizer), - Target = typeof(Bonsai.ML.Lds.Torch.SmoothedState))] -[assembly: TypeVisualizer(typeof(Bonsai.ML.Lds.Torch.Design.StateVisualizer), - Target = typeof(Bonsai.ML.Lds.Torch.OrthogonalizedState))] -[assembly: TypeVisualizer(typeof(Bonsai.ML.Lds.Torch.Design.StateVisualizer), - Target = typeof(Bonsai.ML.Lds.Torch.LdsState))] + Target = typeof(Bonsai.ML.Lds.Torch.LinearDynamicalSystemState))] namespace Bonsai.ML.Lds.Torch.Design; @@ -119,8 +115,8 @@ protected override void Show(DateTime time, object value) { if (value is null) return; - if (value is not ILdsState state) - throw new ArgumentException($"Expected value to be a type of {nameof(ILdsState)}.", nameof(value)); + if (value is not ILinearDynamicalSystemState state) + throw new ArgumentException($"Expected value to be a type of {nameof(ILinearDynamicalSystemState)}.", nameof(value)); var mean = state.Mean; var covariance = state.Covariance; From 7b009de2813a4ffa3643ba381ad2dda54979be25 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 28 Oct 2025 14:15:48 +0000 Subject: [PATCH 71/92] Fixed test workflow after making changes --- .../NeuralLatentsTest.bonsai | 15 +++------------ 1 file changed, 3 insertions(+), 12 deletions(-) diff --git a/tests/Bonsai.ML.Lds.Torch.Tests/NeuralLatentsTest.bonsai b/tests/Bonsai.ML.Lds.Torch.Tests/NeuralLatentsTest.bonsai index ab8f230e..ae0bc93b 100644 --- a/tests/Bonsai.ML.Lds.Torch.Tests/NeuralLatentsTest.bonsai +++ b/tests/Bonsai.ML.Lds.Torch.Tests/NeuralLatentsTest.bonsai @@ -14,7 +14,6 @@ transformed_binned_spikes.pt - true @@ -35,7 +34,6 @@ python_V0_0.pt - true @@ -48,7 +46,6 @@ python_m0_0.pt - true @@ -61,7 +58,6 @@ python_Z0.pt - true @@ -74,7 +70,6 @@ python_R0.pt - true @@ -87,7 +82,6 @@ python_B0.pt - true @@ -100,7 +94,6 @@ python_Q0.pt - true @@ -254,7 +247,7 @@ - + ExpectationMaximizationResult @@ -346,24 +339,22 @@ OrthogonalizedResult - OrthogonalizedMean + Mean bonsai_means.pt - true OrthogonalizedResult - OrthogonalizedCovariance + Covariance bonsai_covs.pt - true From a9d60795ed77436bc428ad6f9fbe6a7894b8a718 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 28 Oct 2025 17:55:59 +0000 Subject: [PATCH 72/92] Refactored `KalmanFilter` for improved parameter validation when `numObservations` and `numStates` are not provided --- src/Bonsai.ML.Lds.Torch/CreateKalmanFilter.cs | 16 +- .../CreateKalmanFilterParameters.cs | 30 +- .../ExpectationMaximization.cs | 40 +- src/Bonsai.ML.Lds.Torch/KalmanFilter.cs | 416 ++++++++---------- .../KalmanFilterParameters.cs | 14 + .../LoadKalmanFilterParameters.cs | 7 +- 6 files changed, 228 insertions(+), 295 deletions(-) diff --git a/src/Bonsai.ML.Lds.Torch/CreateKalmanFilter.cs b/src/Bonsai.ML.Lds.Torch/CreateKalmanFilter.cs index 8bf1745e..5e20376d 100644 --- a/src/Bonsai.ML.Lds.Torch/CreateKalmanFilter.cs +++ b/src/Bonsai.ML.Lds.Torch/CreateKalmanFilter.cs @@ -37,26 +37,16 @@ public ScalarType Type [XmlIgnore] public Device Device { get; set; } - private int _numStates = 2; /// /// The number of states in the Kalman filter model. /// - public int NumStates - { - get => _numStates; - set => _numStates = value > 0 ? value : throw new ArgumentOutOfRangeException(nameof(value), "Number of states must be greater than zero."); - } + public int? NumStates { get; set; } = null; - private int _numObservations = 2; /// /// The number of observations in the Kalman filter model. /// - public int NumObservations - { - get => _numObservations; - set => _numObservations = value > 0 ? value : throw new ArgumentOutOfRangeException(nameof(value), "Number of observations must be greater than zero."); - } - + public int? NumObservations { get; set; } = null; + // Tensor properties with XML serialization support private Tensor _transitionMatrix; /// diff --git a/src/Bonsai.ML.Lds.Torch/CreateKalmanFilterParameters.cs b/src/Bonsai.ML.Lds.Torch/CreateKalmanFilterParameters.cs index 02748620..81da4840 100644 --- a/src/Bonsai.ML.Lds.Torch/CreateKalmanFilterParameters.cs +++ b/src/Bonsai.ML.Lds.Torch/CreateKalmanFilterParameters.cs @@ -31,6 +31,16 @@ public ScalarType Type } private ScalarType _scalarType = ScalarType.Float32; + /// + /// The number of states in the Kalman filter model. + /// + public int? NumStates { get; set; } = null; + + /// + /// The number of observations in the Kalman filter model. + /// + public int? NumObservations { get; set; } = null; + private Tensor _transitionMatrix = null; /// /// The state transition matrix. @@ -190,16 +200,17 @@ private void ConvertTensorsScalarType(ScalarType scalarType) /// public IObservable Process() { - var parameters = new KalmanFilterParameters( + return Observable.Return(KalmanFilter.InitializeParameters( + numStates: NumStates, + numObservations: NumObservations, transitionMatrix: _transitionMatrix, measurementFunction: _measurementFunction, processNoiseCovariance: _processNoiseCovariance, measurementNoiseCovariance: _measurementNoiseCovariance, initialMean: _initialMean, - initialCovariance: _initialCovariance - ); - - return Observable.Return(parameters); + initialCovariance: _initialCovariance, + scalarType: _scalarType + )); } /// @@ -209,16 +220,17 @@ public IObservable Process(IObservable source) { return source.Select(_ => { - var parameters = new KalmanFilterParameters( + return KalmanFilter.InitializeParameters( + numStates: NumStates, + numObservations: NumObservations, transitionMatrix: _transitionMatrix, measurementFunction: _measurementFunction, processNoiseCovariance: _processNoiseCovariance, measurementNoiseCovariance: _measurementNoiseCovariance, initialMean: _initialMean, - initialCovariance: _initialCovariance + initialCovariance: _initialCovariance, + scalarType: _scalarType ); - - return parameters; }); } } \ No newline at end of file diff --git a/src/Bonsai.ML.Lds.Torch/ExpectationMaximization.cs b/src/Bonsai.ML.Lds.Torch/ExpectationMaximization.cs index 99d4a5e1..577fe3bc 100644 --- a/src/Bonsai.ML.Lds.Torch/ExpectationMaximization.cs +++ b/src/Bonsai.ML.Lds.Torch/ExpectationMaximization.cs @@ -19,34 +19,18 @@ namespace Bonsai.ML.Lds.Torch; [WorkflowElementCategory(ElementCategory.Combinator)] public class ExpectationMaximization { - private int _numStates = 2; /// /// The number of states in the Kalman filter model. /// [Description("The number of states in the Kalman filter model.")] - public int NumStates - { - get => _numStates; - set => _numStates = value > 0 ? value : throw new ArgumentOutOfRangeException(nameof(value), "Number of states must be greater than zero."); - } - - private int _numObservations = 10; - /// - /// The number of observations in the Kalman filter model. - /// - [Description("The number of observations in the Kalman filter model.")] - public int NumObservations - { - get => _numObservations; - set => _numObservations = value > 0 ? value : throw new ArgumentOutOfRangeException(nameof(value), "Number of observations must be greater than zero."); - } + public int? NumStates { get; set; } = null; /// /// The Kalman filter parameters used to initialize the model. /// [Description("The Kalman filter parameters used to initialize the model.")] [XmlIgnore] - public KalmanFilterParameters ModelParameters { get; set; } = new(); + public KalmanFilterParameters? ModelParameters { get; set; } = null; private int _maxIterations = 10; /// @@ -128,6 +112,7 @@ public IObservable Process(IObservable so { return Task.Run(() => { + var numObservations = (int)input.size(1); var previousLogLikelihood = double.NegativeInfinity; var logLikelihood = zeros([MaxIterations], device: input.device); var maxIterationsReached = false; @@ -140,14 +125,17 @@ public IObservable Process(IObservable so initialMean: EstimateInitialMean, initialCovariance: EstimateInitialCovariance); - var parameters = new KalmanFilterParameters( - transitionMatrix: ModelParameters.TransitionMatrix, - measurementFunction: ModelParameters.MeasurementFunction, - processNoiseCovariance: ModelParameters.ProcessNoiseCovariance, - measurementNoiseCovariance: ModelParameters.MeasurementNoiseCovariance, - initialMean: ModelParameters.InitialMean, - initialCovariance: ModelParameters.InitialCovariance - ); + var parameters = KalmanFilter.InitializeParameters( + numStates: NumStates, + numObservations: numObservations, + transitionMatrix: ModelParameters?.TransitionMatrix, + measurementFunction: ModelParameters?.MeasurementFunction, + processNoiseCovariance: ModelParameters?.ProcessNoiseCovariance, + measurementNoiseCovariance: ModelParameters?.MeasurementNoiseCovariance, + initialMean: ModelParameters?.InitialMean, + initialCovariance: ModelParameters?.InitialCovariance, + device: input.device, + scalarType: input.dtype); for (int i = 0; i < MaxIterations; i++) { diff --git a/src/Bonsai.ML.Lds.Torch/KalmanFilter.cs b/src/Bonsai.ML.Lds.Torch/KalmanFilter.cs index db3b9ed3..4145058b 100644 --- a/src/Bonsai.ML.Lds.Torch/KalmanFilter.cs +++ b/src/Bonsai.ML.Lds.Torch/KalmanFilter.cs @@ -9,28 +9,12 @@ namespace Bonsai.ML.Lds.Torch; public class KalmanFilter : nn.Module { - private readonly Tensor _transitionMatrix; - private readonly Tensor _measurementFunction; - private readonly Tensor _initialMean; - private readonly Tensor _initialCovariance; - private readonly Tensor _processNoiseCovariance; - private readonly Tensor _measurementNoiseCovariance; private readonly Tensor _identityStates; private readonly Tensor _mean; private readonly Tensor _covariance; - private readonly int _numStates; - private readonly int _numObservations; private readonly Device _device; private readonly ScalarType _scalarType; - - public KalmanFilterParameters Parameters => new( - _transitionMatrix, - _measurementFunction, - _processNoiseCovariance, - _measurementNoiseCovariance, - _initialMean, - _initialCovariance - ); + public KalmanFilterParameters Parameters { get; private set; } public KalmanFilter( KalmanFilterParameters parameters, @@ -40,30 +24,18 @@ public KalmanFilter( _device = device ?? CPU; _scalarType = scalarType; - ValidateNumStates(parameters.TransitionMatrix, parameters.MeasurementFunction, parameters.InitialMean, parameters.InitialCovariance, parameters.ProcessNoiseCovariance, out _numStates); - ValidateNumObservations(parameters.MeasurementFunction, parameters.MeasurementNoiseCovariance, out _numObservations); - - _identityStates = eye(_numStates, dtype: _scalarType, device: _device); - - _transitionMatrix = parameters.TransitionMatrix?.clone().to_type(_scalarType).to(_device).requires_grad_(false) - ?? eye(_numStates, dtype: _scalarType, device: _device).requires_grad_(false); - ValidateMatrix(_transitionMatrix, "Transition matrix", isSquare: true, expectedDimension1: _numStates); - - _measurementFunction = parameters.MeasurementFunction?.clone().to_type(_scalarType).to(_device).requires_grad_(false) - ?? eye(_numObservations, _numStates, dtype: _scalarType, device: _device).requires_grad_(false); - ValidateMatrix(_measurementFunction, "Measurement function", expectedDimension1: _numObservations, expectedDimension2: _numStates); - - _initialMean = parameters.InitialMean?.clone().to_type(_scalarType).to(_device).requires_grad_(false) - ?? zeros(_numStates, dtype: _scalarType, device: _device).requires_grad_(false); - ValidateVector(_initialMean, "Initial mean", expectedLength: _numStates); - - _initialCovariance = parameters.InitialCovariance?.clone().to_type(_scalarType).to(_device).requires_grad_(false) - ?? eye(_numStates, dtype: _scalarType, device: _device).requires_grad_(false); - ValidateMatrix(_initialCovariance, "Initial covariance", isSquare: true, expectedDimension1: _numStates); - - _processNoiseCovariance = parameters.ProcessNoiseCovariance ?? CreateCovarianceMatrix(tensor(1.0), _scalarType, _device, _numStates, "Process noise variance"); - _measurementNoiseCovariance = parameters.MeasurementNoiseCovariance ?? CreateCovarianceMatrix(tensor(1.0), _scalarType, _device, _numObservations, "Measurement noise variance"); + Parameters = InitializeParameters( + transitionMatrix: parameters.TransitionMatrix, + measurementFunction: parameters.MeasurementFunction, + processNoiseCovariance: parameters.ProcessNoiseCovariance, + measurementNoiseCovariance: parameters.MeasurementNoiseCovariance, + initialMean: parameters.InitialMean, + initialCovariance: parameters.InitialCovariance, + device: _device, + scalarType: _scalarType + ); + _identityStates = eye(Parameters.NumStates, dtype: _scalarType, device: _device).requires_grad_(false); _mean = empty(0, dtype: _scalarType, device: _device).requires_grad_(false); _covariance = empty(0, dtype: _scalarType, device: _device).requires_grad_(false); } @@ -83,50 +55,114 @@ public KalmanFilter( _device = device ?? CPU; _scalarType = scalarType; + Parameters = InitializeParameters( + numStates, + numObservations, + transitionMatrix, + measurementFunction, + processNoiseVariance, + measurementNoiseVariance, + initialMean, + initialCovariance, + _device, + _scalarType + ); + + _identityStates = eye(Parameters.NumStates, dtype: _scalarType, device: _device).requires_grad_(false); + _mean = empty(0, dtype: _scalarType, device: _device).requires_grad_(false); + _covariance = empty(0, dtype: _scalarType, device: _device).requires_grad_(false); + + RegisterComponents(); + } + + /// + /// Initializes the Kalman filter parameters. + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + public static KalmanFilterParameters InitializeParameters( + int? numStates = null, + int? numObservations = null, + Tensor transitionMatrix = null, + Tensor measurementFunction = null, + Tensor processNoiseCovariance = null, + Tensor measurementNoiseCovariance = null, + Tensor initialMean = null, + Tensor initialCovariance = null, + Device device = null, + ScalarType scalarType = ScalarType.Float32 + ) + { + var trueNumStates = numStates ?? -1; + var trueNumObservations = numObservations ?? -1; + device ??= CPU; + if (numStates is null) { - ValidateNumStates(transitionMatrix, measurementFunction, initialMean, initialCovariance, processNoiseVariance, out var inferredNumStates); - _numStates = inferredNumStates; + ValidateNumStates(transitionMatrix, measurementFunction, initialMean, initialCovariance, processNoiseCovariance, out trueNumStates); } - else - _numStates = numStates.Value > 0 ? numStates.Value : throw new ArgumentOutOfRangeException(nameof(numStates), "Number of states must be greater than zero."); + + if (trueNumStates <= 0) + throw new ArgumentOutOfRangeException(nameof(trueNumStates), "Number of states must be greater than zero."); if (numObservations is null) { - ValidateNumObservations(measurementFunction, measurementNoiseVariance, out var inferredNumObservations); - _numObservations = inferredNumObservations; + ValidateNumObservations(measurementFunction, measurementNoiseCovariance, out trueNumObservations); } - else - _numObservations = numObservations.Value > 0 ? numObservations.Value : throw new ArgumentOutOfRangeException(nameof(numObservations), "Number of observations must be greater than zero."); - _identityStates = eye(_numStates, dtype: _scalarType, device: _device); + if (trueNumObservations <= 0) + throw new ArgumentOutOfRangeException(nameof(numObservations), "Number of observations must be greater than zero."); - _transitionMatrix = transitionMatrix?.clone().to_type(_scalarType).to(_device).requires_grad_(false) - ?? eye(_numStates, dtype: _scalarType, device: _device).requires_grad_(false); - ValidateMatrix(_transitionMatrix, "Transition matrix", isSquare: true, expectedDimension1: _numStates); + transitionMatrix = transitionMatrix.clone().to_type(scalarType).to(device).requires_grad_(false) ?? eye(trueNumStates, dtype: scalarType, device: device).requires_grad_(false); - _measurementFunction = measurementFunction?.clone().to_type(_scalarType).to(_device).requires_grad_(false) - ?? eye(_numObservations, _numStates, dtype: _scalarType, device: _device).requires_grad_(false); - ValidateMatrix(_measurementFunction, "Measurement function", expectedDimension1: _numObservations, expectedDimension2: _numStates); + ValidateMatrix(transitionMatrix, "Transition matrix", isSquare: true, expectedDimension1: trueNumStates); - _initialMean = initialMean?.clone().to_type(_scalarType).to(_device).requires_grad_(false) - ?? zeros(_numStates, dtype: _scalarType, device: _device).requires_grad_(false); - ValidateVector(_initialMean, "Initial mean", _numStates); + measurementFunction = measurementFunction.clone().to_type(scalarType).to(device).requires_grad_(false) ?? eye(trueNumObservations, trueNumStates, dtype: scalarType, device: device).requires_grad_(false); - _initialCovariance = initialCovariance?.clone().to_type(_scalarType).to(_device).requires_grad_(false) - ?? eye(_numStates, dtype: _scalarType, device: _device).requires_grad_(false); - ValidateMatrix(_initialCovariance, "Initial covariance", isSquare: true, expectedDimension1: _numStates); + ValidateMatrix(measurementFunction, "Measurement function", expectedDimension1: trueNumObservations, expectedDimension2: trueNumStates); - processNoiseVariance ??= tensor(1.0, dtype: _scalarType, device: _device); - measurementNoiseVariance ??= tensor(1.0, dtype: _scalarType, device: _device); + initialMean = initialMean.clone().to_type(scalarType).to(device).requires_grad_(false) ?? zeros(trueNumStates, dtype: scalarType, device: device).requires_grad_(false); - _processNoiseCovariance = CreateCovarianceMatrix(processNoiseVariance, _scalarType, _device, _numStates, "Process noise variance"); - _measurementNoiseCovariance = CreateCovarianceMatrix(measurementNoiseVariance, _scalarType, _device, _numObservations, "Measurement noise variance"); + ValidateVector(initialMean, "Initial mean", trueNumStates); - _mean = empty(0, dtype: _scalarType, device: _device).requires_grad_(false); - _covariance = empty(0, dtype: _scalarType, device: _device).requires_grad_(false); + initialCovariance = initialCovariance.clone().to_type(scalarType).to(device).requires_grad_(false) ?? eye(trueNumStates, dtype: scalarType, device: device).requires_grad_(false); - RegisterComponents(); + ValidateMatrix(initialCovariance, "Initial covariance", isSquare: true, expectedDimension1: trueNumStates); + + processNoiseCovariance = processNoiseCovariance.NumberOfElements == 1 + ? CreateCovarianceMatrix(processNoiseCovariance, scalarType, device, trueNumStates, "Process noise variance") + : processNoiseCovariance.clone().to_type(scalarType).to(device).requires_grad_(false) + ?? CreateCovarianceMatrix(tensor(1.0), scalarType, device, trueNumStates, "Process noise variance"); + + ValidateMatrix(processNoiseCovariance, "Process noise covariance", isSquare: true, expectedDimension1: trueNumStates); + + measurementNoiseCovariance = measurementNoiseCovariance.NumberOfElements == 1 + ? CreateCovarianceMatrix(measurementNoiseCovariance, scalarType, device, trueNumObservations, "Measurement noise variance") + : measurementNoiseCovariance.clone().to_type(scalarType).to(device).requires_grad_(false) + ?? CreateCovarianceMatrix(tensor(1.0), scalarType, device, trueNumObservations, "Measurement noise variance"); + + ValidateMatrix(measurementNoiseCovariance, "Process noise covariance", isSquare: true, expectedDimension1: trueNumObservations); + + return new KalmanFilterParameters( + trueNumStates, + trueNumObservations, + transitionMatrix, + measurementFunction, + processNoiseCovariance, + measurementNoiseCovariance, + initialMean, + initialCovariance + ); } private static void ValidateNumStates(Tensor transitionMatrix, Tensor measurementFunction, Tensor initialMean, Tensor initialCovariance, Tensor processNoiseCovariance, out int numStates) @@ -216,7 +252,7 @@ private static void ValidateScalar(Tensor scalar, string name) throw new ArgumentException($"{name} must be a scalar."); } - private Tensor CreateCovarianceMatrix(Tensor variance, ScalarType scalarType, Device device, int dimension, string name) + private static Tensor CreateCovarianceMatrix(Tensor variance, ScalarType scalarType, Device device, int dimension, string name) { ValidateScalar(variance, name); var scalar = variance.clone().squeeze().to_type(scalarType).to(device); @@ -234,9 +270,9 @@ private readonly struct PredictedState( private PredictedState FilterPredict( Tensor mean, Tensor covariance) => - new(_transitionMatrix.matmul(mean), - _transitionMatrix.matmul(covariance) - .matmul(_transitionMatrix.mT) + _processNoiseCovariance); + new(Parameters.TransitionMatrix.matmul(mean), + Parameters.TransitionMatrix.matmul(covariance) + .matmul(Parameters.TransitionMatrix.mT) + Parameters.ProcessNoiseCovariance); private static PredictedState FilterPredict( Tensor mean, @@ -267,20 +303,20 @@ private UpdatedState FilterUpdate( Tensor observation) { // Innovation step - var innovation = observation - _measurementFunction.matmul(predictedMean); + var innovation = observation - Parameters.MeasurementFunction.matmul(predictedMean); var innovationCovariance = WrappedTensorDisposeScope(() => EnsureSymmetric( - _measurementFunction.matmul(predictedCovariance) - .matmul(_measurementFunction.mT) + _measurementNoiseCovariance)); + Parameters.MeasurementFunction.matmul(predictedCovariance) + .matmul(Parameters.MeasurementFunction.mT) + Parameters.MeasurementNoiseCovariance)); // Kalman gain var kalmanGain = WrappedTensorDisposeScope(() => InverseCholesky( - predictedCovariance.matmul(_measurementFunction.mT), + predictedCovariance.matmul(Parameters.MeasurementFunction.mT), innovationCovariance)); // Update step var updatedMean = predictedMean + kalmanGain.matmul(innovation); var updatedCovariance = WrappedTensorDisposeScope(() => predictedCovariance - - kalmanGain.matmul(_measurementFunction).matmul(predictedCovariance)); + - kalmanGain.matmul(Parameters.MeasurementFunction).matmul(predictedCovariance)); return new UpdatedState(updatedMean, updatedCovariance, innovation, innovationCovariance, kalmanGain); } @@ -324,15 +360,15 @@ public FilteredState Filter(Tensor observation) var obs = observation.atleast_2d(); var timeBins = obs.size(0); - var predictedMean = empty(new long[] { timeBins, _numStates }, dtype: _scalarType, device: _device); - var predictedCovariance = empty(new long[] { timeBins, _numStates, _numStates }, dtype: _scalarType, device: _device); - var updatedMean = empty(new long[] { timeBins, _numStates }, dtype: _scalarType, device: _device); - var updatedCovariance = empty(new long[] { timeBins, _numStates, _numStates }, dtype: _scalarType, device: _device); + var predictedMean = empty([timeBins, Parameters.NumStates], dtype: _scalarType, device: _device); + var predictedCovariance = empty([timeBins, Parameters.NumStates, Parameters.NumStates], dtype: _scalarType, device: _device); + var updatedMean = empty([timeBins, Parameters.NumStates], dtype: _scalarType, device: _device); + var updatedCovariance = empty([timeBins, Parameters.NumStates, Parameters.NumStates], dtype: _scalarType, device: _device); if (_mean.NumberOfElements == 0) - _mean.set_(_initialMean); + _mean.set_(Parameters.InitialMean); if (_covariance.NumberOfElements == 0) - _covariance.set_(_initialCovariance); + _covariance.set_(Parameters.InitialCovariance); for (long time = 0; time < timeBins; time++) { @@ -393,13 +429,13 @@ private static FilteredStateWithAuxiliaryVariables Filter( Device device) { var logLikelihood = empty(timeBins, dtype: scalarType, device: device); - var predictedMean = empty(new long[] { timeBins, numStates }, dtype: scalarType, device: device); - var predictedCovariance = empty(new long[] { timeBins, numStates, numStates }, dtype: scalarType, device: device); - var updatedMean = empty(new long[] { timeBins, numStates }, dtype: scalarType, device: device); - var updatedCovariance = empty(new long[] { timeBins, numStates, numStates }, dtype: scalarType, device: device); - var innovation = empty(new long[] { timeBins, numObservations }, dtype: scalarType, device: device); - var innovationCovariance = empty(new long[] { timeBins, numObservations, numObservations }, dtype: scalarType, device: device); - var kalmanGain = empty(new long[] { timeBins, numStates, numObservations }, dtype: scalarType, device: device); + var predictedMean = empty([timeBins, numStates], dtype: scalarType, device: device); + var predictedCovariance = empty([timeBins, numStates, numStates], dtype: scalarType, device: device); + var updatedMean = empty([timeBins, numStates], dtype: scalarType, device: device); + var updatedCovariance = empty([timeBins, numStates, numStates], dtype: scalarType, device: device); + var innovation = empty([timeBins, numObservations], dtype: scalarType, device: device); + var innovationCovariance = empty([timeBins, numObservations, numObservations], dtype: scalarType, device: device); + var kalmanGain = empty([timeBins, numStates, numObservations], dtype: scalarType, device: device); var mean = initialMean; var covariance = initialCovariance; @@ -469,14 +505,14 @@ public LinearDynamicalSystemState Smooth(FilteredState filteredState) smoothedMean[-1] = updatedMean[-1]; smoothedCovariance[-1] = updatedCovariance[-1]; - var smoothingGain = empty(new long[] { _numStates, _numStates }, dtype: _scalarType, device: _device); + var smoothingGain = empty([Parameters.NumStates, Parameters.NumStates], dtype: _scalarType, device: _device); // Backward pass for (long time = timeBins - 2; time >= 0; time--) { // Smoothing gain smoothingGain = WrappedTensorDisposeScope(() => updatedCovariance[time].matmul( - InverseCholesky(_transitionMatrix.mT, predictedCovariance[time + 1]) + InverseCholesky(Parameters.TransitionMatrix.mT, predictedCovariance[time + 1]) )); // Smoothed mean @@ -634,132 +670,9 @@ Device device ); } - public ExpectationMaximizationResult ExpectationMaximization( - Tensor observation, - int maxIterations = 100, - double tolerance = 1e-4, - ParametersToEstimate parametersToEstimate = new(), - bool updateParameters = true) - { - var timeBins = observation.size(0); - var logLikelihood = empty(maxIterations, dtype: ScalarType.Float32, device: _device); - var previousLogLikelihood = double.NegativeInfinity; - var logLikelihoodConst = -0.5 * timeBins * _numObservations * Math.Log(2.0 * Math.PI); - - var transitionMatrix = _transitionMatrix; - var measurementFunction = _measurementFunction; - var processNoiseCovariance = _processNoiseCovariance; - var measurementNoiseCovariance = _measurementNoiseCovariance; - var initialMean = _initialMean; - var initialCovariance = _initialCovariance; - - // Precompute constant observation terms reused across EM iterations - var observationT = observation.mT; - var autoCorrelationObservations = observationT.matmul(observation); - - using (var _ = no_grad()) - { - for (int iteration = 0; iteration < maxIterations; iteration++) - { - // Filter observations - var filteredState = Filter( - observation: observation, - timeBins: timeBins, - numStates: _numStates, - numObservations: _numObservations, - transitionMatrix: transitionMatrix, - measurementFunction: measurementFunction, - processNoiseCovariance: processNoiseCovariance, - measurementNoiseCovariance: measurementNoiseCovariance, - initialMean: initialMean, - initialCovariance: initialCovariance, - scalarType: _scalarType, - device: _device); - - // Compute log likelihood - var llSumDouble = filteredState.LogLikelihood.sum() - .to_type(ScalarType.Float64).item(); - var filteredLogLikelihoodSum = logLikelihoodConst + 0.5 * llSumDouble; - - logLikelihood[iteration] = filteredLogLikelihoodSum; - - // Check for convergence - if (filteredLogLikelihoodSum <= previousLogLikelihood) - { - Console.WriteLine($"Warning: Log likelihood decreased! New: {filteredLogLikelihoodSum}, Previous: {previousLogLikelihood}"); - break; - } - - if (filteredLogLikelihoodSum - previousLogLikelihood < tolerance) - break; - - previousLogLikelihood = filteredLogLikelihoodSum; - - // Smooth the filtered results - var smoothedState = Smooth( - filteredState: filteredState, - timeBins: timeBins, - numStates: _numStates, - transitionMatrix: transitionMatrix, - measurementFunction: measurementFunction, - initialMean: initialMean, - initialCovariance: initialCovariance, - identityStates: _identityStates, - scalarType: _scalarType, - device: _device); - - // Sufficient statistics - var S00 = smoothedState.S00.sum([0]); - var S11 = smoothedState.S11.sum([0]); - var S10 = smoothedState.S10.sum([0]); - - // Compute cross-correlation between observations and smoothed states - var crossCorrelationObservations = observationT.matmul(smoothedState.SmoothedMean); - - // Update parameters - if (parametersToEstimate.TransitionMatrix) - transitionMatrix = InverseCholesky(S10, S00); - - if (parametersToEstimate.MeasurementFunction) - measurementFunction = InverseCholesky(crossCorrelationObservations, S11); - - if (parametersToEstimate.ProcessNoiseCovariance) - processNoiseCovariance = WrappedTensorDisposeScope(() => - EnsureSymmetric((S11 - transitionMatrix.matmul(S10.mT)) / timeBins)); - - var explainedObservationCovariance = measurementFunction.matmul(crossCorrelationObservations.mT); - - if (parametersToEstimate.MeasurementNoiseCovariance) - measurementNoiseCovariance = WrappedTensorDisposeScope(() => - EnsureSymmetric((autoCorrelationObservations - explainedObservationCovariance - explainedObservationCovariance.mT - + measurementFunction.matmul(S11).matmul(measurementFunction.mT)) / timeBins)); - - if (parametersToEstimate.InitialMean) - initialMean = smoothedState.SmoothedInitialMean; - - if (parametersToEstimate.InitialCovariance) - initialCovariance = smoothedState.SmoothedInitialCovariance; - } - } - - var updatedParameters = new KalmanFilterParameters( - transitionMatrix: transitionMatrix, - measurementFunction: measurementFunction, - processNoiseCovariance: processNoiseCovariance, - measurementNoiseCovariance: measurementNoiseCovariance, - initialMean: initialMean, - initialCovariance: initialCovariance - ); - - if (updateParameters) - UpdateParameters(updatedParameters); - - return new ExpectationMaximizationResult(logLikelihood, updatedParameters); - } - public static ExpectationMaximizationResult ExpectationMaximization( Tensor observation, - KalmanFilterParameters parameters, + KalmanFilterParameters? parameters = null, int maxIterations = 100, double tolerance = 1e-4, ParametersToEstimate parametersToEstimate = new(), @@ -768,20 +681,31 @@ public static ExpectationMaximizationResult ExpectationMaximization( { device ??= CPU; - ValidateNumStates(parameters.TransitionMatrix, parameters.MeasurementFunction, parameters.InitialMean, parameters.InitialCovariance, parameters.ProcessNoiseCovariance, out var numStates); - ValidateNumObservations(parameters.MeasurementFunction, parameters.MeasurementNoiseCovariance, out var numObservations); - var timeBins = observation.size(0); + var numObservations = (int)observation.size(1); var logLikelihood = empty(maxIterations, dtype: ScalarType.Float32, device: device); var previousLogLikelihood = double.NegativeInfinity; var logLikelihoodConst = -0.5 * timeBins * numObservations * Math.Log(2.0 * Math.PI); - var transitionMatrix = parameters.TransitionMatrix; - var measurementFunction = parameters.MeasurementFunction; - var processNoiseCovariance = parameters.ProcessNoiseCovariance; - var measurementNoiseCovariance = parameters.MeasurementNoiseCovariance; - var initialMean = parameters.InitialMean; - var initialCovariance = parameters.InitialCovariance; + var kalmanFilterParameters = InitializeParameters( + numStates: parameters?.NumStates, + numObservations: numObservations, + transitionMatrix: parameters?.TransitionMatrix, + measurementFunction: parameters?.MeasurementFunction, + processNoiseCovariance: parameters?.ProcessNoiseCovariance, + measurementNoiseCovariance: parameters?.MeasurementNoiseCovariance, + initialMean: parameters?.InitialMean, + initialCovariance: parameters?.InitialCovariance, + device: device, + scalarType: scalarType); + + var numStates = kalmanFilterParameters.NumStates; + var transitionMatrix = kalmanFilterParameters.TransitionMatrix; + var measurementFunction = kalmanFilterParameters.MeasurementFunction; + var processNoiseCovariance = kalmanFilterParameters.ProcessNoiseCovariance; + var measurementNoiseCovariance = kalmanFilterParameters.MeasurementNoiseCovariance; + var initialMean = kalmanFilterParameters.InitialMean; + var initialCovariance = kalmanFilterParameters.InitialCovariance; var identityStates = eye(numStates, dtype: scalarType, device: device); @@ -875,12 +799,14 @@ public static ExpectationMaximizationResult ExpectationMaximization( } var updatedParameters = new KalmanFilterParameters( - transitionMatrix: transitionMatrix, - measurementFunction: measurementFunction, - processNoiseCovariance: processNoiseCovariance, - measurementNoiseCovariance: measurementNoiseCovariance, - initialMean: initialMean, - initialCovariance: initialCovariance + numStates: numStates, + numObservations: numObservations, + transitionMatrix: parametersToEstimate.TransitionMatrix ? transitionMatrix : null, + measurementFunction: parametersToEstimate.MeasurementFunction ? measurementFunction : null, + processNoiseCovariance: parametersToEstimate.ProcessNoiseCovariance ? processNoiseCovariance : null, + measurementNoiseCovariance: parametersToEstimate.MeasurementNoiseCovariance ? measurementNoiseCovariance : null, + initialMean: parametersToEstimate.InitialMean ? initialMean : null, + initialCovariance: parametersToEstimate.InitialCovariance ? initialCovariance : null ); return new ExpectationMaximizationResult(logLikelihood, updatedParameters); @@ -896,7 +822,7 @@ public static StochasticSubspaceIdentificationResult StochasticSubspaceIdentific using var g = no_grad(); var timeBins = observations.size(0); - var numObs = observations.size(1); + var numObservations = observations.size(1); var centered = observations - observations.mean([0], keepdim: true); // Build Hankel matrices from observations @@ -906,12 +832,12 @@ public static StochasticSubspaceIdentificationResult StochasticSubspaceIdentific throw new ArgumentException($"Number of time bins ({timeBins}) must be greater than 2*maxLag ({2 * maxLag}) for subspace identification."); var stride = centered.stride(); - var pastView = centered.as_strided([maxLag, numCols, numObs], [stride[0], stride[0], stride[1]]); - var past = pastView.permute(0, 2, 1).reshape(maxLag * numObs, numCols); + var pastView = centered.as_strided([maxLag, numCols, numObservations], [stride[0], stride[0], stride[1]]); + var past = pastView.permute(0, 2, 1).reshape(maxLag * numObservations, numCols); var futureView = centered.narrow(0, maxLag, timeBins - maxLag) - .as_strided([maxLag, numCols, numObs], [stride[0], stride[0], stride[1]]); - var future = futureView.permute(0, 2, 1).reshape(maxLag * numObs, numCols); + .as_strided([maxLag, numCols, numObservations], [stride[0], stride[0], stride[1]]); + var future = futureView.permute(0, 2, 1).reshape(maxLag * numObservations, numCols); // Compute the projection var Pp = past.matmul(past.mT); @@ -932,7 +858,7 @@ public static StochasticSubspaceIdentificationResult StochasticSubspaceIdentific var observability = Ur.matmul(SrSqrt); // Extract measurement function from first block of observability matrix - var measurementFunction = observability[TensorIndex.Slice(0, numObs)]; + var measurementFunction = observability[TensorIndex.Slice(0, numObservations)]; // Estimate state sequence var states = SrSqrt.matmul(Vrt); @@ -961,6 +887,8 @@ public static StochasticSubspaceIdentificationResult StochasticSubspaceIdentific states.matmul(states.mT) / numCols)); var parameters = new KalmanFilterParameters( + numStates: (int)effectiveStates, + numObservations: (int)numObservations, transitionMatrix: parametersToEstimate.TransitionMatrix ? transitionMatrix : null, measurementFunction: parametersToEstimate.MeasurementFunction ? measurementFunction : null, processNoiseCovariance: parametersToEstimate.ProcessNoiseCovariance ? processNoiseCovariance : null, @@ -978,7 +906,7 @@ public static StochasticSubspaceIdentificationResult StochasticSubspaceIdentific public LinearDynamicalSystemState OrthogonalizeMeanAndCovariance(Tensor mean, Tensor covariance) { - var (_, S, Vt) = linalg.svd(_measurementFunction); + var (_, S, Vt) = linalg.svd(Parameters.MeasurementFunction); var SVt = diag(S).matmul(Vt); Tensor orthogonalizedMean = null; @@ -1001,17 +929,17 @@ public LinearDynamicalSystemState OrthogonalizeMeanAndCovariance(Tensor mean, Te public void UpdateParameters(KalmanFilterParameters updatedParameters) { if (updatedParameters.TransitionMatrix is not null) - _transitionMatrix.set_(updatedParameters.TransitionMatrix); + Parameters.TransitionMatrix.set_(updatedParameters.TransitionMatrix); if (updatedParameters.MeasurementFunction is not null) - _measurementFunction.set_(updatedParameters.MeasurementFunction); + Parameters.MeasurementFunction.set_(updatedParameters.MeasurementFunction); if (updatedParameters.ProcessNoiseCovariance is not null) - _processNoiseCovariance.set_(updatedParameters.ProcessNoiseCovariance); + Parameters.ProcessNoiseCovariance.set_(updatedParameters.ProcessNoiseCovariance); if (updatedParameters.MeasurementNoiseCovariance is not null) - _measurementNoiseCovariance.set_(updatedParameters.MeasurementNoiseCovariance); + Parameters.MeasurementNoiseCovariance.set_(updatedParameters.MeasurementNoiseCovariance); if (updatedParameters.InitialMean is not null) - _initialMean.set_(updatedParameters.InitialMean); + Parameters.InitialMean.set_(updatedParameters.InitialMean); if (updatedParameters.InitialCovariance is not null) - _initialCovariance.set_(updatedParameters.InitialCovariance); + Parameters.InitialCovariance.set_(updatedParameters.InitialCovariance); } private static Tensor EnsureSymmetric(Tensor M) => 0.5 * (M + M.mT); diff --git a/src/Bonsai.ML.Lds.Torch/KalmanFilterParameters.cs b/src/Bonsai.ML.Lds.Torch/KalmanFilterParameters.cs index 1e5f8e82..a3483390 100644 --- a/src/Bonsai.ML.Lds.Torch/KalmanFilterParameters.cs +++ b/src/Bonsai.ML.Lds.Torch/KalmanFilterParameters.cs @@ -8,6 +8,8 @@ namespace Bonsai.ML.Lds.Torch; /// /// Initializes a new instance of the struct with the specified parameters. /// +/// +/// /// /// /// @@ -15,6 +17,8 @@ namespace Bonsai.ML.Lds.Torch; /// /// public struct KalmanFilterParameters( + int numStates, + int numObservations, Tensor transitionMatrix = null, Tensor measurementFunction = null, Tensor processNoiseCovariance = null, @@ -22,6 +26,16 @@ public struct KalmanFilterParameters( Tensor initialMean = null, Tensor initialCovariance = null) { + /// + /// The number of states in the system. + /// + public int NumStates = numStates; + + /// + /// The number of observations in the system. + /// + public int NumObservations = numObservations; + /// /// The state transition matrix. /// diff --git a/src/Bonsai.ML.Lds.Torch/LoadKalmanFilterParameters.cs b/src/Bonsai.ML.Lds.Torch/LoadKalmanFilterParameters.cs index 84a8a812..60f151ba 100644 --- a/src/Bonsai.ML.Lds.Torch/LoadKalmanFilterParameters.cs +++ b/src/Bonsai.ML.Lds.Torch/LoadKalmanFilterParameters.cs @@ -103,14 +103,15 @@ public IObservable Process() var initialMean = LoadTensorFromFile(InitialMeanFilePath); var initialCovariance = LoadTensorFromFile(InitialCovarianceFilePath); - var parameters = new KalmanFilterParameters( + var parameters = KalmanFilter.InitializeParameters( transitionMatrix: transitionMatrix, measurementFunction: measurementFunction, processNoiseCovariance: processNoiseCovariance, measurementNoiseCovariance: measurementNoiseCovariance, initialMean: initialMean, - initialCovariance: initialCovariance - ); + initialCovariance: initialCovariance, + device: Device, + scalarType: Type); return Observable.Return(parameters); } From 6b85e19016069d776650cc4a67dd33e0ee099c95 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 28 Oct 2025 18:25:04 +0000 Subject: [PATCH 73/92] Added overload to create a `LinearDynamicalSystemState` from a stream of a tuple of tensors --- .../CreateLinearDynamicalSystemState.cs | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/Bonsai.ML.Lds.Torch/CreateLinearDynamicalSystemState.cs b/src/Bonsai.ML.Lds.Torch/CreateLinearDynamicalSystemState.cs index 848e1b95..ac44e2c1 100644 --- a/src/Bonsai.ML.Lds.Torch/CreateLinearDynamicalSystemState.cs +++ b/src/Bonsai.ML.Lds.Torch/CreateLinearDynamicalSystemState.cs @@ -123,4 +123,17 @@ public IObservable Process(IObservable source) return new LinearDynamicalSystemState(mean, covariance); }); } + + /// + /// Processes an observable sequence of a tuple of tensors (mean and covariance) and emits a state for a linear gaussian dynamical system. + /// + /// + /// + public IObservable Process(IObservable> source) + { + return source.Select(input => + { + return new LinearDynamicalSystemState(input.Item1, input.Item2); + }); + } } \ No newline at end of file From 2b5c5ac9e8cffa2c27eeb4421c6831750ff8e0a4 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 28 Oct 2025 19:26:42 +0000 Subject: [PATCH 74/92] Refactored `SaveKalmanFilterParameters` class to save to a folder rather than individual files which is more consistent with other `Bonsai.ML` packages --- .../SaveKalmanFilterParameters.cs | 149 ++++++++++++------ 1 file changed, 98 insertions(+), 51 deletions(-) diff --git a/src/Bonsai.ML.Lds.Torch/SaveKalmanFilterParameters.cs b/src/Bonsai.ML.Lds.Torch/SaveKalmanFilterParameters.cs index af8a2182..e282afa1 100644 --- a/src/Bonsai.ML.Lds.Torch/SaveKalmanFilterParameters.cs +++ b/src/Bonsai.ML.Lds.Torch/SaveKalmanFilterParameters.cs @@ -3,8 +3,8 @@ using System.Linq; using System.Reactive.Linq; using System.Xml.Serialization; -using Bonsai.ML.Torch; using System.IO; +using Bonsai.ML.Torch; using TorchSharp; using static TorchSharp.torch; @@ -19,75 +19,122 @@ namespace Bonsai.ML.Lds.Torch; public class SaveKalmanFilterParameters { /// - /// Specifies the path to use for saving the transition matrix of a Kalman filter to a .bin file. + /// The path to the folder where the Kalman filter model parameters will be saved. /// - [Description("Specifies the path to use for saving the transition matrix of a Kalman filter to a .bin file.")] - [FileNameFilter("Binary Data (*.bin)|*.bin|All Files|*.*")] - [Editor("Bonsai.Design.SaveFileNameEditor, Bonsai.Design", DesignTypes.UITypeEditor)] - public string TransitionMatrixFilePath { get; set; } = "transition_matrix.bin"; + [Editor("Bonsai.Design.FolderNameEditor, Bonsai.Design", DesignTypes.UITypeEditor)] + [Description("The path to the folder where the Kalman filter model parameters will be saved.")] + public string Path { get; set; } = string.Empty; /// - /// Specifies the path to use for saving the measurement function of a Kalman filter to a .bin file. + /// If true, the contents of the folder will be overwritten if it already exists. /// - [Description("Specifies the path to use for saving the measurement function of a Kalman filter to a .bin file.")] - [FileNameFilter("Binary Data (*.bin)|*.bin|All Files|*.*")] - [Editor("Bonsai.Design.SaveFileNameEditor, Bonsai.Design", DesignTypes.UITypeEditor)] - public string MeasurementFunctionFilePath { get; set; } = "measurement_function.bin"; + [Description("If true, the contents of the folder will be overwritten if it already exists.")] + public bool Overwrite { get; set; } = false; /// - /// Specifies the path to use for saving the process noise covariance of a Kalman filter to a .bin file. + /// Specifies the type of suffix to add to the save path. + /// The suffix is added as a subfolder to the save path. + /// If DateTime is used, a suffix with the current date and time is added to the save path in the format, '[Path]/yyyyMMddHHmmss/...'. + /// If Guid is used, a suffix with a unique 128-bit identifier is added to the save path in the format, '[Path]/[128-bit identifier]/...'. /// - [Description("Specifies the path to use for saving the process noise covariance of a Kalman filter to a .bin file.")] - [FileNameFilter("Binary Data (*.bin)|*.bin|All Files|*.*")] - [Editor("Bonsai.Design.SaveFileNameEditor, Bonsai.Design", DesignTypes.UITypeEditor)] - public string ProcessNoiseCovarianceFilePath { get; set; } = "process_noise_covariance.bin"; + [Description("Specifies the type of suffix to add to the save path.")] + public SuffixType AddSuffix { get; set; } = SuffixType.None; - /// - /// Specifies the path to use for saving the measurement noise covariance of a Kalman filter to a .bin file. - /// - [Description("Specifies the path to use for saving the measurement noise covariance of a Kalman filter to a .bin file.")] - [FileNameFilter("Binary Data (*.bin)|*.bin|All Files|*.*")] - [Editor("Bonsai.Design.SaveFileNameEditor, Bonsai.Design", DesignTypes.UITypeEditor)] - public string MeasurementNoiseCovarianceFilePath { get; set; } = "measurement_noise_covariance.bin"; + private void SaveKalmanFilterParametersToDisk(KalmanFilterParameters parameters) + { + if (string.IsNullOrEmpty(Path)) + { + Path = Directory.GetCurrentDirectory(); + } + + var path = AddSuffix switch + { + SuffixType.DateTime => System.IO.Path.Combine(Path, $"{HighResolutionScheduler.Now:yyyyMMddHHmmss}"), + SuffixType.Guid => System.IO.Path.Combine(Path, $"{Guid.NewGuid()}"), + _ => Path + }; + + var transitionMatrixPath = System.IO.Path.Combine(path, "TransitionMatrix.bin"); + var measurementFunctionPath = System.IO.Path.Combine(path, "MeasurementFunction.bin"); + var processNoiseCovariancePath = System.IO.Path.Combine(path, "ProcessNoiseCovariance.bin"); + var measurementNoiseCovariancePath = System.IO.Path.Combine(path, "MeasurementNoiseCovariance.bin"); + var initialMeanPath = System.IO.Path.Combine(path, "InitialMean.bin"); + var initialCovariancePath = System.IO.Path.Combine(path, "InitialCovariance.bin"); + + if (Directory.Exists(path)) + { + if (!Overwrite && ( + File.Exists(transitionMatrixPath) || + File.Exists(measurementFunctionPath) || + File.Exists(processNoiseCovariancePath) || + File.Exists(measurementNoiseCovariancePath) || + File.Exists(initialMeanPath) || + File.Exists(initialCovariancePath)) + ) + { + throw new InvalidOperationException("The save path already exists."); + } + else + { + if (File.Exists(transitionMatrixPath)) + File.Delete(transitionMatrixPath); + if (File.Exists(measurementFunctionPath)) + File.Delete(measurementFunctionPath); + if (File.Exists(processNoiseCovariancePath)) + File.Delete(processNoiseCovariancePath); + if (File.Exists(measurementNoiseCovariancePath)) + File.Delete(measurementNoiseCovariancePath); + if (File.Exists(initialMeanPath)) + File.Delete(initialMeanPath); + if (File.Exists(initialCovariancePath)) + File.Delete(initialCovariancePath); + } + } + + Directory.CreateDirectory(path); + + parameters.TransitionMatrix.Save(transitionMatrixPath); + parameters.MeasurementFunction.Save(measurementFunctionPath); + parameters.ProcessNoiseCovariance.Save(processNoiseCovariancePath); + parameters.MeasurementNoiseCovariance.Save(measurementNoiseCovariancePath); + parameters.InitialMean.Save(initialMeanPath); + parameters.InitialCovariance.Save(initialCovariancePath); + } /// - /// Specifies the path to use for saving the initial mean of a Kalman filter to a .bin file. + /// Processes an observable sequence of Kalman filter parameters, saving to files. /// - [Description("Specifies the path to use for saving the initial mean of a Kalman filter to a .bin file.")] - [FileNameFilter("Binary Data (*.bin)|*.bin|All Files|*.*")] - [Editor("Bonsai.Design.SaveFileNameEditor, Bonsai.Design", DesignTypes.UITypeEditor)] - public string InitialMeanFilePath { get; set; } = "initial_mean.bin"; + public IObservable Process(IObservable source) + { + return source.Do(SaveKalmanFilterParametersToDisk); + } /// - /// Specifies the path to use for saving the initial covariance of a Kalman filter to a .bin file. + /// Processes an observable sequence of Kalman filter models, saving their parameters to files. /// - [Description("Specifies the path to use for saving the initial covariance of a Kalman filter to a .bin file.")] - [FileNameFilter("Binary Data (*.bin)|*.bin|All Files|*.*")] - [Editor("Bonsai.Design.SaveFileNameEditor, Bonsai.Design", DesignTypes.UITypeEditor)] - public string InitialCovarianceFilePath { get; set; } = "initial_covariance.bin"; - - private void SaveTensorToFile(Tensor tensor, string filePath) + public IObservable Process(IObservable source) { - if (filePath != null) - { - tensor.Save(filePath); - } + return source.Do(model => SaveKalmanFilterParametersToDisk(model.Parameters)); } /// - /// Creates parameters for a Kalman filter model using the properties of this class. + /// Specifies the type of suffix to add to the save path. /// - public IObservable Process(IObservable source) + public enum SuffixType { - return source.Do(model => - { - var parameters = model.Parameters; - SaveTensorToFile(parameters.TransitionMatrix, TransitionMatrixFilePath); - SaveTensorToFile(parameters.MeasurementFunction, MeasurementFunctionFilePath); - SaveTensorToFile(parameters.ProcessNoiseCovariance, ProcessNoiseCovarianceFilePath); - SaveTensorToFile(parameters.MeasurementNoiseCovariance, MeasurementNoiseCovarianceFilePath); - SaveTensorToFile(parameters.InitialMean, InitialMeanFilePath); - SaveTensorToFile(parameters.InitialCovariance, InitialCovarianceFilePath); - }); + /// + /// No suffix is added to the save path. + /// + None, + + /// + /// A suffix with the current date and time is added to the save path. + /// + DateTime, + + /// + /// A suffix with a unique identifier is added to the save path. + /// + Guid } } \ No newline at end of file From d05077ea0f474b2992c2a80e72b7575d713cc44e Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 28 Oct 2025 19:28:54 +0000 Subject: [PATCH 75/92] Refactored `KalmanFilterParameters` class to handle validation and moved validation logic outside of `KalmanFilter` module --- .../CreateKalmanFilterParameters.cs | 4 +- .../ExpectationMaximization.cs | 22 +- src/Bonsai.ML.Lds.Torch/KalmanFilter.cs | 226 ++--------------- .../KalmanFilterParameters.cs | 233 +++++++++++++++++- 4 files changed, 262 insertions(+), 223 deletions(-) diff --git a/src/Bonsai.ML.Lds.Torch/CreateKalmanFilterParameters.cs b/src/Bonsai.ML.Lds.Torch/CreateKalmanFilterParameters.cs index 81da4840..4fed8245 100644 --- a/src/Bonsai.ML.Lds.Torch/CreateKalmanFilterParameters.cs +++ b/src/Bonsai.ML.Lds.Torch/CreateKalmanFilterParameters.cs @@ -200,7 +200,7 @@ private void ConvertTensorsScalarType(ScalarType scalarType) /// public IObservable Process() { - return Observable.Return(KalmanFilter.InitializeParameters( + return Observable.Return(KalmanFilterParameters.Initialize( numStates: NumStates, numObservations: NumObservations, transitionMatrix: _transitionMatrix, @@ -220,7 +220,7 @@ public IObservable Process(IObservable source) { return source.Select(_ => { - return KalmanFilter.InitializeParameters( + return KalmanFilterParameters.Initialize( numStates: NumStates, numObservations: NumObservations, transitionMatrix: _transitionMatrix, diff --git a/src/Bonsai.ML.Lds.Torch/ExpectationMaximization.cs b/src/Bonsai.ML.Lds.Torch/ExpectationMaximization.cs index 577fe3bc..77e4c4fa 100644 --- a/src/Bonsai.ML.Lds.Torch/ExpectationMaximization.cs +++ b/src/Bonsai.ML.Lds.Torch/ExpectationMaximization.cs @@ -125,17 +125,17 @@ public IObservable Process(IObservable so initialMean: EstimateInitialMean, initialCovariance: EstimateInitialCovariance); - var parameters = KalmanFilter.InitializeParameters( - numStates: NumStates, - numObservations: numObservations, - transitionMatrix: ModelParameters?.TransitionMatrix, - measurementFunction: ModelParameters?.MeasurementFunction, - processNoiseCovariance: ModelParameters?.ProcessNoiseCovariance, - measurementNoiseCovariance: ModelParameters?.MeasurementNoiseCovariance, - initialMean: ModelParameters?.InitialMean, - initialCovariance: ModelParameters?.InitialCovariance, - device: input.device, - scalarType: input.dtype); + + var parameters = ModelParameters?.Copy() ?? KalmanFilterParameters.Initialize( + numStates: NumStates, + numObservations: numObservations, + scalarType: input.dtype, + device: input.device); + + if (!parameters.IsValidated) + { + parameters.Validate(); + } for (int i = 0; i < MaxIterations; i++) { diff --git a/src/Bonsai.ML.Lds.Torch/KalmanFilter.cs b/src/Bonsai.ML.Lds.Torch/KalmanFilter.cs index 4145058b..55460b18 100644 --- a/src/Bonsai.ML.Lds.Torch/KalmanFilter.cs +++ b/src/Bonsai.ML.Lds.Torch/KalmanFilter.cs @@ -9,7 +9,6 @@ namespace Bonsai.ML.Lds.Torch; public class KalmanFilter : nn.Module { - private readonly Tensor _identityStates; private readonly Tensor _mean; private readonly Tensor _covariance; private readonly Device _device; @@ -24,18 +23,12 @@ public KalmanFilter( _device = device ?? CPU; _scalarType = scalarType; - Parameters = InitializeParameters( - transitionMatrix: parameters.TransitionMatrix, - measurementFunction: parameters.MeasurementFunction, - processNoiseCovariance: parameters.ProcessNoiseCovariance, - measurementNoiseCovariance: parameters.MeasurementNoiseCovariance, - initialMean: parameters.InitialMean, - initialCovariance: parameters.InitialCovariance, - device: _device, - scalarType: _scalarType - ); + if (!parameters.IsValidated) + { + KalmanFilterParameters.Validate(parameters); + } - _identityStates = eye(Parameters.NumStates, dtype: _scalarType, device: _device).requires_grad_(false); + Parameters = parameters; _mean = empty(0, dtype: _scalarType, device: _device).requires_grad_(false); _covariance = empty(0, dtype: _scalarType, device: _device).requires_grad_(false); } @@ -55,7 +48,7 @@ public KalmanFilter( _device = device ?? CPU; _scalarType = scalarType; - Parameters = InitializeParameters( + Parameters = KalmanFilterParameters.Initialize( numStates, numObservations, transitionMatrix, @@ -68,197 +61,12 @@ public KalmanFilter( _scalarType ); - _identityStates = eye(Parameters.NumStates, dtype: _scalarType, device: _device).requires_grad_(false); _mean = empty(0, dtype: _scalarType, device: _device).requires_grad_(false); _covariance = empty(0, dtype: _scalarType, device: _device).requires_grad_(false); RegisterComponents(); } - /// - /// Initializes the Kalman filter parameters. - /// - /// - /// - /// - /// - /// - /// - /// - /// - /// - /// - /// - /// - public static KalmanFilterParameters InitializeParameters( - int? numStates = null, - int? numObservations = null, - Tensor transitionMatrix = null, - Tensor measurementFunction = null, - Tensor processNoiseCovariance = null, - Tensor measurementNoiseCovariance = null, - Tensor initialMean = null, - Tensor initialCovariance = null, - Device device = null, - ScalarType scalarType = ScalarType.Float32 - ) - { - var trueNumStates = numStates ?? -1; - var trueNumObservations = numObservations ?? -1; - device ??= CPU; - - if (numStates is null) - { - ValidateNumStates(transitionMatrix, measurementFunction, initialMean, initialCovariance, processNoiseCovariance, out trueNumStates); - } - - if (trueNumStates <= 0) - throw new ArgumentOutOfRangeException(nameof(trueNumStates), "Number of states must be greater than zero."); - - if (numObservations is null) - { - ValidateNumObservations(measurementFunction, measurementNoiseCovariance, out trueNumObservations); - } - - if (trueNumObservations <= 0) - throw new ArgumentOutOfRangeException(nameof(numObservations), "Number of observations must be greater than zero."); - - transitionMatrix = transitionMatrix.clone().to_type(scalarType).to(device).requires_grad_(false) ?? eye(trueNumStates, dtype: scalarType, device: device).requires_grad_(false); - - ValidateMatrix(transitionMatrix, "Transition matrix", isSquare: true, expectedDimension1: trueNumStates); - - measurementFunction = measurementFunction.clone().to_type(scalarType).to(device).requires_grad_(false) ?? eye(trueNumObservations, trueNumStates, dtype: scalarType, device: device).requires_grad_(false); - - ValidateMatrix(measurementFunction, "Measurement function", expectedDimension1: trueNumObservations, expectedDimension2: trueNumStates); - - initialMean = initialMean.clone().to_type(scalarType).to(device).requires_grad_(false) ?? zeros(trueNumStates, dtype: scalarType, device: device).requires_grad_(false); - - ValidateVector(initialMean, "Initial mean", trueNumStates); - - initialCovariance = initialCovariance.clone().to_type(scalarType).to(device).requires_grad_(false) ?? eye(trueNumStates, dtype: scalarType, device: device).requires_grad_(false); - - ValidateMatrix(initialCovariance, "Initial covariance", isSquare: true, expectedDimension1: trueNumStates); - - processNoiseCovariance = processNoiseCovariance.NumberOfElements == 1 - ? CreateCovarianceMatrix(processNoiseCovariance, scalarType, device, trueNumStates, "Process noise variance") - : processNoiseCovariance.clone().to_type(scalarType).to(device).requires_grad_(false) - ?? CreateCovarianceMatrix(tensor(1.0), scalarType, device, trueNumStates, "Process noise variance"); - - ValidateMatrix(processNoiseCovariance, "Process noise covariance", isSquare: true, expectedDimension1: trueNumStates); - - measurementNoiseCovariance = measurementNoiseCovariance.NumberOfElements == 1 - ? CreateCovarianceMatrix(measurementNoiseCovariance, scalarType, device, trueNumObservations, "Measurement noise variance") - : measurementNoiseCovariance.clone().to_type(scalarType).to(device).requires_grad_(false) - ?? CreateCovarianceMatrix(tensor(1.0), scalarType, device, trueNumObservations, "Measurement noise variance"); - - ValidateMatrix(measurementNoiseCovariance, "Process noise covariance", isSquare: true, expectedDimension1: trueNumObservations); - - return new KalmanFilterParameters( - trueNumStates, - trueNumObservations, - transitionMatrix, - measurementFunction, - processNoiseCovariance, - measurementNoiseCovariance, - initialMean, - initialCovariance - ); - } - - private static void ValidateNumStates(Tensor transitionMatrix, Tensor measurementFunction, Tensor initialMean, Tensor initialCovariance, Tensor processNoiseCovariance, out int numStates) - { - if (transitionMatrix is not null) - { - ValidateMatrix(transitionMatrix, "Transition matrix", isSquare: true); - numStates = (int)transitionMatrix.size(0); - } - else if (measurementFunction is not null) - { - ValidateMatrix(measurementFunction, "Measurement function"); - numStates = (int)measurementFunction.size(1); - } - else if (initialMean is not null) - { - ValidateVector(initialMean, "Initial mean"); - numStates = (int)initialMean.size(0); - } - else if (initialCovariance is not null) - { - ValidateMatrix(initialCovariance, "Initial covariance", isSquare: true); - numStates = (int)initialCovariance.size(0); - } - else if (processNoiseCovariance is not null) - { - ValidateMatrix(processNoiseCovariance, "Process noise covariance", isSquare: true); - numStates = (int)processNoiseCovariance.size(0); - } - else - { - throw new ArgumentException("At least one of the parameters must be provided to infer the number of states."); - } - } - - private static void ValidateNumObservations(Tensor measurementFunction, Tensor measurementNoiseCovariance, out int numObservations) - { - if (measurementFunction is not null) - { - ValidateMatrix(measurementFunction, "Measurement function"); - numObservations = (int)measurementFunction.size(0); - } - else if (measurementNoiseCovariance is not null) - { - ValidateMatrix(measurementNoiseCovariance, "Measurement noise covariance", isSquare: true); - numObservations = (int)measurementNoiseCovariance.size(0); - } - else - { - throw new ArgumentException("At least one of the measurement function or measurement noise covariance must be provided to infer the number of observations."); - } - } - - private static void ValidateMatrix(Tensor matrix, string name, bool isSquare = false, int? expectedDimension1 = null, int? expectedDimension2 = null) - { - if (matrix.NumberOfElements == 0) - throw new ArgumentException($"{name} must be a non-empty matrix."); - - if (matrix.Dimensions != 2) - throw new ArgumentException($"{name} must be 2-dimensional."); - - if (isSquare && matrix.size(0) != matrix.size(1)) - throw new ArgumentException($"{name} must be square."); - - if (expectedDimension1.HasValue && matrix.size(0) != expectedDimension1.Value) - throw new ArgumentException($"{name} must have {expectedDimension1.Value} rows."); - - if (expectedDimension2.HasValue && matrix.size(1) != expectedDimension2.Value) - throw new ArgumentException($"{name} must have {expectedDimension2.Value} columns."); - } - - private static void ValidateVector(Tensor vector, string name, int? expectedLength = null) - { - if (vector.NumberOfElements == 0) - throw new ArgumentException($"{name} must be a non-empty vector."); - - if (vector.Dimensions != 1) - throw new ArgumentException($"{name} must be a vector."); - - if (expectedLength.HasValue && vector.NumberOfElements != expectedLength.Value) - throw new ArgumentException($"{name} must be a vector with length equal to {expectedLength.Value}."); - } - - private static void ValidateScalar(Tensor scalar, string name) - { - if (scalar.NumberOfElements != 1) - throw new ArgumentException($"{name} must be a scalar."); - } - - private static Tensor CreateCovarianceMatrix(Tensor variance, ScalarType scalarType, Device device, int dimension, string name) - { - ValidateScalar(variance, name); - var scalar = variance.clone().squeeze().to_type(scalarType).to(device); - return (scalar * eye(dimension, dtype: scalarType, device: device)).requires_grad_(false); - } - private readonly struct PredictedState( Tensor predictedMean, Tensor predictedCovariance) @@ -687,17 +495,17 @@ public static ExpectationMaximizationResult ExpectationMaximization( var previousLogLikelihood = double.NegativeInfinity; var logLikelihoodConst = -0.5 * timeBins * numObservations * Math.Log(2.0 * Math.PI); - var kalmanFilterParameters = InitializeParameters( - numStates: parameters?.NumStates, - numObservations: numObservations, - transitionMatrix: parameters?.TransitionMatrix, - measurementFunction: parameters?.MeasurementFunction, - processNoiseCovariance: parameters?.ProcessNoiseCovariance, - measurementNoiseCovariance: parameters?.MeasurementNoiseCovariance, - initialMean: parameters?.InitialMean, - initialCovariance: parameters?.InitialCovariance, - device: device, - scalarType: scalarType); + var kalmanFilterParameters = parameters?.Copy() + ?? KalmanFilterParameters.Initialize( + numStates: null, + numObservations: numObservations, + scalarType: scalarType, + device: device); + + if (!kalmanFilterParameters.IsValidated) + { + KalmanFilterParameters.Validate(kalmanFilterParameters); + } var numStates = kalmanFilterParameters.NumStates; var transitionMatrix = kalmanFilterParameters.TransitionMatrix; diff --git a/src/Bonsai.ML.Lds.Torch/KalmanFilterParameters.cs b/src/Bonsai.ML.Lds.Torch/KalmanFilterParameters.cs index a3483390..3acda77b 100644 --- a/src/Bonsai.ML.Lds.Torch/KalmanFilterParameters.cs +++ b/src/Bonsai.ML.Lds.Torch/KalmanFilterParameters.cs @@ -1,3 +1,4 @@ +using System; using static TorchSharp.torch; namespace Bonsai.ML.Lds.Torch; @@ -16,6 +17,7 @@ namespace Bonsai.ML.Lds.Torch; /// /// /// +/// public struct KalmanFilterParameters( int numStates, int numObservations, @@ -24,7 +26,8 @@ public struct KalmanFilterParameters( Tensor processNoiseCovariance = null, Tensor measurementNoiseCovariance = null, Tensor initialMean = null, - Tensor initialCovariance = null) + Tensor initialCovariance = null, + bool isValidated = false) { /// /// The number of states in the system. @@ -65,4 +68,232 @@ public struct KalmanFilterParameters( /// The initial covariance. /// public Tensor InitialCovariance = initialCovariance; + + /// + /// Indicates whether the parameters have been validated. + /// + /// + /// This field is used to avoid redundant validation checks. + /// + public bool IsValidated = isValidated; + + /// + /// Initializes the Kalman filter parameters. + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + public static KalmanFilterParameters Initialize( + int? numStates = null, + int? numObservations = null, + Tensor transitionMatrix = null, + Tensor measurementFunction = null, + Tensor processNoiseCovariance = null, + Tensor measurementNoiseCovariance = null, + Tensor initialMean = null, + Tensor initialCovariance = null, + Device device = null, + ScalarType scalarType = ScalarType.Float32 + ) + { + var trueNumStates = numStates ?? -1; + var trueNumObservations = numObservations ?? -1; + device ??= CPU; + + if (numStates is null) + { + ValidateNumStates(transitionMatrix, measurementFunction, initialMean, initialCovariance, processNoiseCovariance, out trueNumStates); + } + + if (trueNumStates <= 0) + throw new ArgumentOutOfRangeException(nameof(trueNumStates), "Number of states must be greater than zero."); + + if (numObservations is null) + { + ValidateNumObservations(measurementFunction, measurementNoiseCovariance, out trueNumObservations); + } + + if (trueNumObservations <= 0) + throw new ArgumentOutOfRangeException(nameof(numObservations), "Number of observations must be greater than zero."); + + transitionMatrix = transitionMatrix?.clone().to_type(scalarType).to(device).requires_grad_(false) ?? eye(trueNumStates, dtype: scalarType, device: device).requires_grad_(false); + + measurementFunction = measurementFunction?.clone().to_type(scalarType).to(device).requires_grad_(false) ?? eye(trueNumObservations, trueNumStates, dtype: scalarType, device: device).requires_grad_(false); + + initialMean = initialMean?.clone().to_type(scalarType).to(device).requires_grad_(false) ?? zeros(trueNumStates, dtype: scalarType, device: device).requires_grad_(false); + + initialCovariance = initialCovariance?.clone().to_type(scalarType).to(device).requires_grad_(false) ?? eye(trueNumStates, dtype: scalarType, device: device).requires_grad_(false); + + processNoiseCovariance = processNoiseCovariance?.NumberOfElements == 1 + ? CreateCovarianceMatrix(processNoiseCovariance, scalarType, device, trueNumStates, "Process noise variance") + : processNoiseCovariance?.clone().to_type(scalarType).to(device).requires_grad_(false) + ?? CreateCovarianceMatrix(tensor(1.0), scalarType, device, trueNumStates, "Process noise variance"); + + measurementNoiseCovariance = measurementNoiseCovariance?.NumberOfElements == 1 + ? CreateCovarianceMatrix(measurementNoiseCovariance, scalarType, device, trueNumObservations, "Measurement noise variance") + : measurementNoiseCovariance?.clone().to_type(scalarType).to(device).requires_grad_(false) + ?? CreateCovarianceMatrix(tensor(1.0), scalarType, device, trueNumObservations, "Measurement noise variance"); + + var parameters = new KalmanFilterParameters( + trueNumStates, + trueNumObservations, + transitionMatrix, + measurementFunction, + processNoiseCovariance, + measurementNoiseCovariance, + initialMean, + initialCovariance + ); + + parameters.Validate(); + + return parameters; + } + + /// + /// Validates the Kalman filter parameters. + /// + public void Validate() + { + if (IsValidated) + return; + + ValidateNumStates(TransitionMatrix, MeasurementFunction, InitialMean, InitialCovariance, ProcessNoiseCovariance, out NumStates); + ValidateNumObservations(MeasurementFunction, MeasurementNoiseCovariance, out NumObservations); + ValidateMatrix(TransitionMatrix, "Transition matrix", isSquare: true, expectedDimension1: NumStates); + ValidateMatrix(MeasurementFunction, "Measurement function", expectedDimension1: NumObservations, expectedDimension2: NumStates); + ValidateMatrix(ProcessNoiseCovariance, "Process noise covariance", isSquare: true, expectedDimension1: NumStates); + ValidateMatrix(MeasurementNoiseCovariance, "Measurement noise covariance", isSquare: true, expectedDimension1: NumObservations); + ValidateVector(InitialMean, "Initial mean", NumStates); + ValidateMatrix(InitialCovariance, "Initial covariance", isSquare: true, expectedDimension1: NumStates); + IsValidated = true; + } + + /// + /// Validates the specified Kalman filter parameters. + /// + /// + public static void Validate(KalmanFilterParameters parameters) + { + parameters.Validate(); + } + + /// + /// Creates a copy of the current Kalman filter parameters. + /// + /// + public readonly KalmanFilterParameters Copy() => new( + NumStates, + NumObservations, + TransitionMatrix?.clone(), + MeasurementFunction?.clone(), + ProcessNoiseCovariance?.clone(), + MeasurementNoiseCovariance?.clone(), + InitialMean?.clone(), + InitialCovariance?.clone(), + IsValidated + ); + + private static void ValidateNumStates(Tensor transitionMatrix, Tensor measurementFunction, Tensor initialMean, Tensor initialCovariance, Tensor processNoiseCovariance, out int numStates) + { + if (transitionMatrix is not null) + { + ValidateMatrix(transitionMatrix, "Transition matrix", isSquare: true); + numStates = (int)transitionMatrix.size(0); + } + else if (measurementFunction is not null) + { + ValidateMatrix(measurementFunction, "Measurement function"); + numStates = (int)measurementFunction.size(1); + } + else if (initialMean is not null) + { + ValidateVector(initialMean, "Initial mean"); + numStates = (int)initialMean.size(0); + } + else if (initialCovariance is not null) + { + ValidateMatrix(initialCovariance, "Initial covariance", isSquare: true); + numStates = (int)initialCovariance.size(0); + } + else if (processNoiseCovariance is not null) + { + ValidateMatrix(processNoiseCovariance, "Process noise covariance", isSquare: true); + numStates = (int)processNoiseCovariance.size(0); + } + else + { + throw new ArgumentException("At least one of the parameters must be provided to infer the number of states."); + } + } + + private static void ValidateNumObservations(Tensor measurementFunction, Tensor measurementNoiseCovariance, out int numObservations) + { + if (measurementFunction is not null) + { + ValidateMatrix(measurementFunction, "Measurement function"); + numObservations = (int)measurementFunction.size(0); + } + else if (measurementNoiseCovariance is not null) + { + ValidateMatrix(measurementNoiseCovariance, "Measurement noise covariance", isSquare: true); + numObservations = (int)measurementNoiseCovariance.size(0); + } + else + { + throw new ArgumentException("At least one of the measurement function or measurement noise covariance must be provided to infer the number of observations."); + } + } + + private static void ValidateMatrix(Tensor matrix, string name, bool isSquare = false, int? expectedDimension1 = null, int? expectedDimension2 = null) + { + if (matrix.NumberOfElements == 0) + throw new ArgumentException($"{name} must be a non-empty matrix."); + + if (matrix.Dimensions != 2) + throw new ArgumentException($"{name} must be 2-dimensional."); + + if (isSquare && matrix.size(0) != matrix.size(1)) + throw new ArgumentException($"{name} must be square."); + + if (expectedDimension1.HasValue && matrix.size(0) != expectedDimension1.Value) + throw new ArgumentException($"{name} must have {expectedDimension1.Value} rows."); + + if (expectedDimension2.HasValue && matrix.size(1) != expectedDimension2.Value) + throw new ArgumentException($"{name} must have {expectedDimension2.Value} columns."); + } + + private static void ValidateVector(Tensor vector, string name, int? expectedLength = null) + { + if (vector.NumberOfElements == 0) + throw new ArgumentException($"{name} must be a non-empty vector."); + + if (vector.Dimensions != 1) + throw new ArgumentException($"{name} must be a vector."); + + if (expectedLength.HasValue && vector.NumberOfElements != expectedLength.Value) + throw new ArgumentException($"{name} must be a vector with length equal to {expectedLength.Value}."); + } + + private static void ValidateScalar(Tensor scalar, string name) + { + if (scalar.NumberOfElements != 1) + throw new ArgumentException($"{name} must be a scalar."); + } + + private static Tensor CreateCovarianceMatrix(Tensor variance, ScalarType scalarType, Device device, int dimension, string name) + { + ValidateScalar(variance, name); + var scalar = variance.clone().squeeze().to_type(scalarType).to(device); + return (scalar * eye(dimension, dtype: scalarType, device: device)).requires_grad_(false); + } } \ No newline at end of file From e68add5f43d44be4fa6606ccb23ae3d48ec40f9b Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 28 Oct 2025 19:29:36 +0000 Subject: [PATCH 76/92] Updated `LoadKalmanFilterParameters` operator to load parameters from folder instead of individual file paths --- .../LoadKalmanFilterParameters.cs | 77 ++++++------------- 1 file changed, 22 insertions(+), 55 deletions(-) diff --git a/src/Bonsai.ML.Lds.Torch/LoadKalmanFilterParameters.cs b/src/Bonsai.ML.Lds.Torch/LoadKalmanFilterParameters.cs index 60f151ba..b17b1691 100644 --- a/src/Bonsai.ML.Lds.Torch/LoadKalmanFilterParameters.cs +++ b/src/Bonsai.ML.Lds.Torch/LoadKalmanFilterParameters.cs @@ -19,52 +19,11 @@ namespace Bonsai.ML.Lds.Torch; public class LoadKalmanFilterParameters { /// - /// Reads the path to a .bin file containing the transition matrix. + /// The path to the folder where the Kalman filter model parameters were saved. /// - [Description("Reads the path to a .bin file containing the transition matrix.")] - [FileNameFilter("Binary Data (*.bin)|*.bin|All Files|*.*")] - [Editor("Bonsai.Design.OpenFileNameEditor, Bonsai.Design", DesignTypes.UITypeEditor)] - public string TransitionMatrixFilePath { get; set; } = "transition_matrix.bin"; - - /// - /// Reads the path to a .bin file containing the measurement function. - /// - [Description("Reads the path to a .bin file containing the measurement function.")] - [FileNameFilter("Binary Data (*.bin)|*.bin|All Files|*.*")] - [Editor("Bonsai.Design.OpenFileNameEditor, Bonsai.Design", DesignTypes.UITypeEditor)] - public string MeasurementFunctionFilePath { get; set; } = "measurement_function.bin"; - - /// - /// Reads the path to a .bin file containing the process noise covariance. - /// - [Description("Reads the path to a .bin file containing the process noise covariance.")] - [FileNameFilter("Binary Data (*.bin)|*.bin|All Files|*.*")] - [Editor("Bonsai.Design.OpenFileNameEditor, Bonsai.Design", DesignTypes.UITypeEditor)] - public string ProcessNoiseCovarianceFilePath { get; set; } = "process_noise_covariance.bin"; - - /// - /// Reads the path to a .bin file containing the measurement noise covariance. - /// - [Description("Reads the path to a .bin file containing the measurement noise covariance.")] - [FileNameFilter("Binary Data (*.bin)|*.bin|All Files|*.*")] - [Editor("Bonsai.Design.OpenFileNameEditor, Bonsai.Design", DesignTypes.UITypeEditor)] - public string MeasurementNoiseCovarianceFilePath { get; set; } = "measurement_noise_covariance.bin"; - - /// - /// Reads the path to a .bin file containing the initial mean. - /// - [Description("Reads the path to a .bin file containing the initial mean.")] - [FileNameFilter("Binary Data (*.bin)|*.bin|All Files|*.*")] - [Editor("Bonsai.Design.OpenFileNameEditor, Bonsai.Design", DesignTypes.UITypeEditor)] - public string InitialMeanFilePath { get; set; } = "initial_mean.bin"; - - /// - /// Reads the path to a .bin file containing the initial covariance. - /// - [Description("Reads the path to a .bin file containing the initial covariance.")] - [FileNameFilter("Binary Data (*.bin)|*.bin|All Files|*.*")] - [Editor("Bonsai.Design.OpenFileNameEditor, Bonsai.Design", DesignTypes.UITypeEditor)] - public string InitialCovarianceFilePath { get; set; } = "initial_covariance.bin"; + [Editor("Bonsai.Design.FolderNameEditor, Bonsai.Design", DesignTypes.UITypeEditor)] + [Description("The path to the folder where the Kalman filter model parameters were saved.")] + public string Path { get; set; } = string.Empty; /// /// Gets or sets the data type of the tensors. @@ -96,14 +55,24 @@ public IObservable Process() { Device ??= CPU; - var transitionMatrix = LoadTensorFromFile(TransitionMatrixFilePath); - var measurementFunction = LoadTensorFromFile(MeasurementFunctionFilePath); - var processNoiseCovariance = LoadTensorFromFile(ProcessNoiseCovarianceFilePath); - var measurementNoiseCovariance = LoadTensorFromFile(MeasurementNoiseCovarianceFilePath); - var initialMean = LoadTensorFromFile(InitialMeanFilePath); - var initialCovariance = LoadTensorFromFile(InitialCovarianceFilePath); + if (string.IsNullOrEmpty(Path)) + { + throw new InvalidOperationException("The save path is not specified."); + } - var parameters = KalmanFilter.InitializeParameters( + if (!Directory.Exists(Path)) + { + throw new InvalidOperationException("The save path does not exist."); + } + + var transitionMatrix = LoadTensorFromFile("TransitionMatrix.bin"); + var measurementFunction = LoadTensorFromFile("MeasurementFunction.bin"); + var processNoiseCovariance = LoadTensorFromFile("ProcessNoiseCovariance.bin"); + var measurementNoiseCovariance = LoadTensorFromFile("MeasurementNoiseCovariance.bin"); + var initialMean = LoadTensorFromFile("InitialMean.bin"); + var initialCovariance = LoadTensorFromFile("InitialCovariance.bin"); + + return Observable.Return(KalmanFilterParameters.Initialize( transitionMatrix: transitionMatrix, measurementFunction: measurementFunction, processNoiseCovariance: processNoiseCovariance, @@ -111,8 +80,6 @@ public IObservable Process() initialMean: initialMean, initialCovariance: initialCovariance, device: Device, - scalarType: Type); - - return Observable.Return(parameters); + scalarType: Type)); } } \ No newline at end of file From f89d7507d17ae97ca3b84db522ca654cbc4526b1 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 28 Oct 2025 19:30:03 +0000 Subject: [PATCH 77/92] Updated test to use `null` value in `NumStates` property --- tests/Bonsai.ML.Lds.Torch.Tests/NeuralLatentsTest.bonsai | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/Bonsai.ML.Lds.Torch.Tests/NeuralLatentsTest.bonsai b/tests/Bonsai.ML.Lds.Torch.Tests/NeuralLatentsTest.bonsai index ae0bc93b..84da3814 100644 --- a/tests/Bonsai.ML.Lds.Torch.Tests/NeuralLatentsTest.bonsai +++ b/tests/Bonsai.ML.Lds.Torch.Tests/NeuralLatentsTest.bonsai @@ -222,10 +222,9 @@ - 2 - 10 + 1 - 0.1 + 0.0001 true true true From 1b51ae487754390717680c54bf160f9c86c15835 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 29 Oct 2025 15:32:07 +0000 Subject: [PATCH 78/92] Removed unnecessary `PredictedState` struct and used `LinearDynamicalSystem` struct instead --- src/Bonsai.ML.Lds.Torch/KalmanFilter.cs | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/src/Bonsai.ML.Lds.Torch/KalmanFilter.cs b/src/Bonsai.ML.Lds.Torch/KalmanFilter.cs index 55460b18..3056b15a 100644 --- a/src/Bonsai.ML.Lds.Torch/KalmanFilter.cs +++ b/src/Bonsai.ML.Lds.Torch/KalmanFilter.cs @@ -67,22 +67,14 @@ public KalmanFilter( RegisterComponents(); } - private readonly struct PredictedState( - Tensor predictedMean, - Tensor predictedCovariance) - { - public readonly Tensor PredictedMean = predictedMean; - public readonly Tensor PredictedCovariance = predictedCovariance; - } - - private PredictedState FilterPredict( + private LinearDynamicalSystemState FilterPredict( Tensor mean, Tensor covariance) => new(Parameters.TransitionMatrix.matmul(mean), Parameters.TransitionMatrix.matmul(covariance) .matmul(Parameters.TransitionMatrix.mT) + Parameters.ProcessNoiseCovariance); - private static PredictedState FilterPredict( + private static LinearDynamicalSystemState FilterPredict( Tensor mean, Tensor covariance, Tensor transitionMatrix, From 7e9bba8014991ac7f2ff67b091a576acfb8cbaf8 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 29 Oct 2025 15:32:58 +0000 Subject: [PATCH 79/92] Removed extra check when calling `Validate` on parameters --- src/Bonsai.ML.Lds.Torch/ExpectationMaximization.cs | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/Bonsai.ML.Lds.Torch/ExpectationMaximization.cs b/src/Bonsai.ML.Lds.Torch/ExpectationMaximization.cs index 77e4c4fa..dafc14c8 100644 --- a/src/Bonsai.ML.Lds.Torch/ExpectationMaximization.cs +++ b/src/Bonsai.ML.Lds.Torch/ExpectationMaximization.cs @@ -125,17 +125,13 @@ public IObservable Process(IObservable so initialMean: EstimateInitialMean, initialCovariance: EstimateInitialCovariance); - var parameters = ModelParameters?.Copy() ?? KalmanFilterParameters.Initialize( numStates: NumStates, numObservations: numObservations, scalarType: input.dtype, device: input.device); - if (!parameters.IsValidated) - { parameters.Validate(); - } for (int i = 0; i < MaxIterations; i++) { From 92b36c6156267043d725d1e19cb27052d14403bf Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 29 Oct 2025 15:33:53 +0000 Subject: [PATCH 80/92] Fixed issue with setting the incorrect number of iterations from the Bonsai operator --- src/Bonsai.ML.Lds.Torch/ExpectationMaximization.cs | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/Bonsai.ML.Lds.Torch/ExpectationMaximization.cs b/src/Bonsai.ML.Lds.Torch/ExpectationMaximization.cs index dafc14c8..31cab28c 100644 --- a/src/Bonsai.ML.Lds.Torch/ExpectationMaximization.cs +++ b/src/Bonsai.ML.Lds.Torch/ExpectationMaximization.cs @@ -145,11 +145,9 @@ public IObservable Process(IObservable so var result = KalmanFilter.ExpectationMaximization( observation: input, parameters: parameters, - maxIterations: MaxIterations, + maxIterations: 1, tolerance: Tolerance, - parametersToEstimate: parametersToEstimate, - device: input.device, - scalarType: input.dtype); + parametersToEstimate: parametersToEstimate); var logLikelihoodSum = result.LogLikelihood .cpu() From 7140980e953dc32d8ef23b92089a3200289e5e48 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 29 Oct 2025 15:34:29 +0000 Subject: [PATCH 81/92] Used tensorhelper method to extract float instead of explicitly moving to cpu --- src/Bonsai.ML.Lds.Torch/ExpectationMaximization.cs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/Bonsai.ML.Lds.Torch/ExpectationMaximization.cs b/src/Bonsai.ML.Lds.Torch/ExpectationMaximization.cs index 31cab28c..53032b01 100644 --- a/src/Bonsai.ML.Lds.Torch/ExpectationMaximization.cs +++ b/src/Bonsai.ML.Lds.Torch/ExpectationMaximization.cs @@ -131,7 +131,7 @@ public IObservable Process(IObservable so scalarType: input.dtype, device: input.device); - parameters.Validate(); + parameters.Validate(); for (int i = 0; i < MaxIterations; i++) { @@ -150,9 +150,8 @@ public IObservable Process(IObservable so parametersToEstimate: parametersToEstimate); var logLikelihoodSum = result.LogLikelihood - .cpu() .to_type(ScalarType.Float32) - .ReadCpuSingle(0); + .item(); logLikelihood[i] = logLikelihoodSum; From f32b653fcbf6014e1f93ddf140a5b497d85a50b3 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 29 Oct 2025 15:35:55 +0000 Subject: [PATCH 82/92] Refactored `KalmanFilter` class to rely on device and scalar type provided by parameters --- src/Bonsai.ML.Lds.Torch/KalmanFilter.cs | 259 +++++++++++------------- 1 file changed, 115 insertions(+), 144 deletions(-) diff --git a/src/Bonsai.ML.Lds.Torch/KalmanFilter.cs b/src/Bonsai.ML.Lds.Torch/KalmanFilter.cs index 3056b15a..edac67e1 100644 --- a/src/Bonsai.ML.Lds.Torch/KalmanFilter.cs +++ b/src/Bonsai.ML.Lds.Torch/KalmanFilter.cs @@ -1,5 +1,4 @@ using System; -using System.Collections.Generic; using static TorchSharp.torch; namespace Bonsai.ML.Lds.Torch; @@ -13,24 +12,26 @@ public class KalmanFilter : nn.Module private readonly Tensor _covariance; private readonly Device _device; private readonly ScalarType _scalarType; - public KalmanFilterParameters Parameters { get; private set; } + public readonly KalmanFilterParameters Parameters; public KalmanFilter( KalmanFilterParameters parameters, Device device = null, - ScalarType scalarType = ScalarType.Float32) : base("KalmanFilter") + ScalarType? scalarType = null) : base("KalmanFilter") { - _device = device ?? CPU; - _scalarType = scalarType; + Parameters = parameters.Copy(); - if (!parameters.IsValidated) - { - KalmanFilterParameters.Validate(parameters); - } + Parameters.Validate(); + Parameters.ToScalarType(scalarType); + Parameters.ToDevice(device); + + _device = Parameters.Device; + _scalarType = Parameters.ScalarType; - Parameters = parameters; _mean = empty(0, dtype: _scalarType, device: _device).requires_grad_(false); _covariance = empty(0, dtype: _scalarType, device: _device).requires_grad_(false); + + RegisterComponents(); } public KalmanFilter( @@ -43,11 +44,8 @@ public KalmanFilter( Tensor processNoiseVariance = null, Tensor measurementNoiseVariance = null, Device device = null, - ScalarType scalarType = ScalarType.Float32) : base("KalmanFilter") + ScalarType? scalarType = null) : base("KalmanFilter") { - _device = device ?? CPU; - _scalarType = scalarType; - Parameters = KalmanFilterParameters.Initialize( numStates, numObservations, @@ -57,10 +55,13 @@ public KalmanFilter( measurementNoiseVariance, initialMean, initialCovariance, - _device, - _scalarType + device, + scalarType ); + _device = Parameters.Device; + _scalarType = Parameters.ScalarType; + _mean = empty(0, dtype: _scalarType, device: _device).requires_grad_(false); _covariance = empty(0, dtype: _scalarType, device: _device).requires_grad_(false); @@ -116,7 +117,7 @@ private UpdatedState FilterUpdate( // Update step var updatedMean = predictedMean + kalmanGain.matmul(innovation); var updatedCovariance = WrappedTensorDisposeScope(() => predictedCovariance - - kalmanGain.matmul(Parameters.MeasurementFunction).matmul(predictedCovariance)); + - kalmanGain.matmul(Parameters.MeasurementFunction).matmul(predictedCovariance)); return new UpdatedState(updatedMean, updatedCovariance, innovation, innovationCovariance, kalmanGain); } @@ -166,9 +167,9 @@ public FilteredState Filter(Tensor observation) var updatedCovariance = empty([timeBins, Parameters.NumStates, Parameters.NumStates], dtype: _scalarType, device: _device); if (_mean.NumberOfElements == 0) - _mean.set_(Parameters.InitialMean); + _mean.set_(Parameters.InitialMean.clone()); if (_covariance.NumberOfElements == 0) - _covariance.set_(Parameters.InitialCovariance); + _covariance.set_(Parameters.InitialCovariance.clone()); for (long time = 0; time < timeBins; time++) { @@ -176,10 +177,10 @@ public FilteredState Filter(Tensor observation) var prediction = FilterPredict(_mean, _covariance); // Update - var update = FilterUpdate(prediction.PredictedMean, prediction.PredictedCovariance, obs[time]); + var update = FilterUpdate(prediction.Mean, prediction.Covariance, obs[time]); - predictedMean[time] = prediction.PredictedMean; - predictedCovariance[time] = prediction.PredictedCovariance; + predictedMean[time] = prediction.Mean; + predictedCovariance[time] = prediction.Covariance; updatedMean[time] = update.UpdatedMean; updatedCovariance[time] = update.UpdatedCovariance; @@ -251,8 +252,8 @@ private static FilteredStateWithAuxiliaryVariables Filter( // Update var update = FilterUpdate( - predictedMean: prediction.PredictedMean, - predictedCovariance: prediction.PredictedCovariance, + predictedMean: prediction.Mean, + predictedCovariance: prediction.Covariance, observation: observation[time], measurementFunction: measurementFunction, measurementNoiseCovariance: measurementNoiseCovariance); @@ -264,8 +265,8 @@ private static FilteredStateWithAuxiliaryVariables Filter( // Detach and assign logLikelihood[time] = logLikelihoodData; - predictedMean[time] = prediction.PredictedMean; - predictedCovariance[time] = prediction.PredictedCovariance; + predictedMean[time] = prediction.Mean; + predictedCovariance[time] = prediction.Covariance; updatedMean[time] = update.UpdatedMean; updatedCovariance[time] = update.UpdatedCovariance; innovation[time] = update.Innovation; @@ -472,144 +473,114 @@ Device device public static ExpectationMaximizationResult ExpectationMaximization( Tensor observation, - KalmanFilterParameters? parameters = null, + KalmanFilterParameters parameters, int maxIterations = 100, double tolerance = 1e-4, - ParametersToEstimate parametersToEstimate = new(), - Device device = null, - ScalarType scalarType = ScalarType.Float32) + ParametersToEstimate parametersToEstimate = new()) { - device ??= CPU; + parameters = parameters.Copy(); + parameters.Validate(); var timeBins = observation.size(0); var numObservations = (int)observation.size(1); - var logLikelihood = empty(maxIterations, dtype: ScalarType.Float32, device: device); + var logLikelihood = empty(maxIterations, dtype: ScalarType.Float32, device: parameters.Device); var previousLogLikelihood = double.NegativeInfinity; var logLikelihoodConst = -0.5 * timeBins * numObservations * Math.Log(2.0 * Math.PI); - var kalmanFilterParameters = parameters?.Copy() - ?? KalmanFilterParameters.Initialize( - numStates: null, - numObservations: numObservations, - scalarType: scalarType, - device: device); - - if (!kalmanFilterParameters.IsValidated) - { - KalmanFilterParameters.Validate(kalmanFilterParameters); - } - - var numStates = kalmanFilterParameters.NumStates; - var transitionMatrix = kalmanFilterParameters.TransitionMatrix; - var measurementFunction = kalmanFilterParameters.MeasurementFunction; - var processNoiseCovariance = kalmanFilterParameters.ProcessNoiseCovariance; - var measurementNoiseCovariance = kalmanFilterParameters.MeasurementNoiseCovariance; - var initialMean = kalmanFilterParameters.InitialMean; - var initialCovariance = kalmanFilterParameters.InitialCovariance; + if (parameters.NumObservations != numObservations) + throw new ArgumentException($"The number of observation dimensions in the parameters ({parameters.NumObservations}) does not match the observations ({numObservations})."); - var identityStates = eye(numStates, dtype: scalarType, device: device); + var identityStates = eye(parameters.NumStates, dtype: parameters.ScalarType, device: parameters.Device); // Precompute constant observation terms reused across EM iterations var observationT = observation.mT; var autoCorrelationObservations = observationT.matmul(observation); - using (var _ = no_grad()) + using var g = no_grad(); + + for (int iteration = 0; iteration < maxIterations; iteration++) { - for (int iteration = 0; iteration < maxIterations; iteration++) + // Filter observations + var filteredState = Filter( + observation: observation, + timeBins: timeBins, + numStates: parameters.NumStates, + numObservations: numObservations, + transitionMatrix: parameters.TransitionMatrix, + measurementFunction: parameters.MeasurementFunction, + processNoiseCovariance: parameters.ProcessNoiseCovariance, + measurementNoiseCovariance: parameters.MeasurementNoiseCovariance, + initialMean: parameters.InitialMean, + initialCovariance: parameters.InitialCovariance, + scalarType: parameters.ScalarType, + device: parameters.Device); + + // Compute log likelihood (avoid creating intermediate tensors) + var llSumDouble = filteredState.LogLikelihood.sum() + .to_type(ScalarType.Float64).item(); + var filteredLogLikelihoodSum = logLikelihoodConst + 0.5 * llSumDouble; + + logLikelihood[iteration] = filteredLogLikelihoodSum; + + // Check for convergence + if (filteredLogLikelihoodSum <= previousLogLikelihood) { - // Filter observations - var filteredState = Filter( - observation: observation, - timeBins: timeBins, - numStates: numStates, - numObservations: numObservations, - transitionMatrix: transitionMatrix, - measurementFunction: measurementFunction, - processNoiseCovariance: processNoiseCovariance, - measurementNoiseCovariance: measurementNoiseCovariance, - initialMean: initialMean, - initialCovariance: initialCovariance, - scalarType: scalarType, - device: device); - - // Compute log likelihood (avoid creating intermediate tensors) - var llSumDouble = filteredState.LogLikelihood.sum() - .to_type(ScalarType.Float64).item(); - var filteredLogLikelihoodSum = logLikelihoodConst + 0.5 * llSumDouble; - - logLikelihood[iteration] = filteredLogLikelihoodSum; - - // Check for convergence - if (filteredLogLikelihoodSum <= previousLogLikelihood) - { - Console.WriteLine($"Warning: Log likelihood decreased! New: {filteredLogLikelihoodSum}, Previous: {previousLogLikelihood}"); - break; - } - - if (filteredLogLikelihoodSum - previousLogLikelihood < tolerance) - break; - - previousLogLikelihood = filteredLogLikelihoodSum; - - // Smooth the filtered results - var smoothedState = Smooth( - filteredState: filteredState, - timeBins: timeBins, - numStates: numStates, - transitionMatrix: transitionMatrix, - measurementFunction: measurementFunction, - initialMean: initialMean, - initialCovariance: initialCovariance, - identityStates: identityStates, - scalarType: scalarType, - device: device); - - // Sufficient statistics - var S00 = smoothedState.S00.sum([0]); - var S11 = smoothedState.S11.sum([0]); - var S10 = smoothedState.S10.sum([0]); - - // Replace einsum with faster matmul - var crossCorrelationObservations = observationT.matmul(smoothedState.SmoothedMean); - - // Update parameters - if (parametersToEstimate.TransitionMatrix) - transitionMatrix = InverseCholesky(S10, S00); - - if (parametersToEstimate.MeasurementFunction) - measurementFunction = InverseCholesky(crossCorrelationObservations, S11); - - if (parametersToEstimate.ProcessNoiseCovariance) - processNoiseCovariance = WrappedTensorDisposeScope(() => - EnsureSymmetric((S11 - transitionMatrix.matmul(S10.mT)) / timeBins)); - - var explainedObservationCovariance = measurementFunction.matmul(crossCorrelationObservations.mT); - - if (parametersToEstimate.MeasurementNoiseCovariance) - measurementNoiseCovariance = WrappedTensorDisposeScope(() => - EnsureSymmetric((autoCorrelationObservations - explainedObservationCovariance - explainedObservationCovariance.mT - + measurementFunction.matmul(S11).matmul(measurementFunction.mT)) / timeBins)); - - if (parametersToEstimate.InitialMean) - initialMean = smoothedState.SmoothedInitialMean; - - if (parametersToEstimate.InitialCovariance) - initialCovariance = smoothedState.SmoothedInitialCovariance; + Console.WriteLine($"Warning: Log likelihood decreased! New: {filteredLogLikelihoodSum}, Previous: {previousLogLikelihood}"); + break; } - } - var updatedParameters = new KalmanFilterParameters( - numStates: numStates, - numObservations: numObservations, - transitionMatrix: parametersToEstimate.TransitionMatrix ? transitionMatrix : null, - measurementFunction: parametersToEstimate.MeasurementFunction ? measurementFunction : null, - processNoiseCovariance: parametersToEstimate.ProcessNoiseCovariance ? processNoiseCovariance : null, - measurementNoiseCovariance: parametersToEstimate.MeasurementNoiseCovariance ? measurementNoiseCovariance : null, - initialMean: parametersToEstimate.InitialMean ? initialMean : null, - initialCovariance: parametersToEstimate.InitialCovariance ? initialCovariance : null - ); + if (filteredLogLikelihoodSum - previousLogLikelihood < tolerance) + break; + + previousLogLikelihood = filteredLogLikelihoodSum; + + // Smooth the filtered results + var smoothedState = Smooth( + filteredState: filteredState, + timeBins: timeBins, + numStates: parameters.NumStates, + transitionMatrix: parameters.TransitionMatrix, + measurementFunction: parameters.MeasurementFunction, + initialMean: parameters.InitialMean, + initialCovariance: parameters.InitialCovariance, + identityStates: identityStates, + scalarType: parameters.ScalarType, + device: parameters.Device); + + // Sufficient statistics + var S00 = smoothedState.S00.sum([0]); + var S11 = smoothedState.S11.sum([0]); + var S10 = smoothedState.S10.sum([0]); + + // Replace einsum with faster matmul + var crossCorrelationObservations = observationT.matmul(smoothedState.SmoothedMean); + + // Update parameters + if (parametersToEstimate.TransitionMatrix) + parameters.TransitionMatrix = InverseCholesky(S10, S00); + + if (parametersToEstimate.MeasurementFunction) + parameters.MeasurementFunction = InverseCholesky(crossCorrelationObservations, S11); + + if (parametersToEstimate.ProcessNoiseCovariance) + parameters.ProcessNoiseCovariance = WrappedTensorDisposeScope(() => + EnsureSymmetric((S11 - parameters.TransitionMatrix.matmul(S10.mT)) / timeBins)); + + var explainedObservationCovariance = parameters.MeasurementFunction.matmul(crossCorrelationObservations.mT); + + if (parametersToEstimate.MeasurementNoiseCovariance) + parameters.MeasurementNoiseCovariance = WrappedTensorDisposeScope(() => + EnsureSymmetric((autoCorrelationObservations - explainedObservationCovariance - explainedObservationCovariance.mT + + parameters.MeasurementFunction.matmul(S11).matmul(parameters.MeasurementFunction.mT)) / timeBins)); + + if (parametersToEstimate.InitialMean) + parameters.InitialMean = smoothedState.SmoothedInitialMean; + + if (parametersToEstimate.InitialCovariance) + parameters.InitialCovariance = smoothedState.SmoothedInitialCovariance; + } - return new ExpectationMaximizationResult(logLikelihood, updatedParameters); + return new ExpectationMaximizationResult(logLikelihood, parameters); } public static StochasticSubspaceIdentificationResult StochasticSubspaceIdentification( From 92a21754b472ab74e38688e651ec4c835769f592 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 29 Oct 2025 15:40:13 +0000 Subject: [PATCH 83/92] Refactored `KalmanFilterParameters` to manage operations on tensors, including moving devices, setting scalar types, and setting gradient tracking --- .../KalmanFilterParameters.cs | 151 ++++++++++++++---- 1 file changed, 118 insertions(+), 33 deletions(-) diff --git a/src/Bonsai.ML.Lds.Torch/KalmanFilterParameters.cs b/src/Bonsai.ML.Lds.Torch/KalmanFilterParameters.cs index 3acda77b..c6049332 100644 --- a/src/Bonsai.ML.Lds.Torch/KalmanFilterParameters.cs +++ b/src/Bonsai.ML.Lds.Torch/KalmanFilterParameters.cs @@ -17,7 +17,10 @@ namespace Bonsai.ML.Lds.Torch; /// /// /// -/// +/// +/// +/// +/// public struct KalmanFilterParameters( int numStates, int numObservations, @@ -27,7 +30,10 @@ public struct KalmanFilterParameters( Tensor measurementNoiseCovariance = null, Tensor initialMean = null, Tensor initialCovariance = null, - bool isValidated = false) + Device device = null, + ScalarType? scalarType = null, + bool requiresGrad = false, + bool validated = false) { /// /// The number of states in the system. @@ -69,13 +75,28 @@ public struct KalmanFilterParameters( /// public Tensor InitialCovariance = initialCovariance; + /// + /// The device to use for tensor operations. + /// + public Device Device = device ?? CPU; + + /// + /// The data type of the tensors. + /// + public ScalarType ScalarType = scalarType ?? ScalarType.Float32; + + /// + /// Indicates whether the tensors require gradient computation. + /// + public bool RequiresGrad = requiresGrad; + /// /// Indicates whether the parameters have been validated. /// /// /// This field is used to avoid redundant validation checks. /// - public bool IsValidated = isValidated; + public bool Validated = validated; /// /// Initializes the Kalman filter parameters. @@ -90,6 +111,7 @@ public struct KalmanFilterParameters( /// /// /// + /// /// /// public static KalmanFilterParameters Initialize( @@ -102,12 +124,14 @@ public static KalmanFilterParameters Initialize( Tensor initialMean = null, Tensor initialCovariance = null, Device device = null, - ScalarType scalarType = ScalarType.Float32 + ScalarType? scalarType = null, + bool requiresGrad = false ) { var trueNumStates = numStates ?? -1; var trueNumObservations = numObservations ?? -1; device ??= CPU; + scalarType ??= ScalarType.Float32; if (numStates is null) { @@ -125,23 +149,20 @@ public static KalmanFilterParameters Initialize( if (trueNumObservations <= 0) throw new ArgumentOutOfRangeException(nameof(numObservations), "Number of observations must be greater than zero."); - transitionMatrix = transitionMatrix?.clone().to_type(scalarType).to(device).requires_grad_(false) ?? eye(trueNumStates, dtype: scalarType, device: device).requires_grad_(false); - - measurementFunction = measurementFunction?.clone().to_type(scalarType).to(device).requires_grad_(false) ?? eye(trueNumObservations, trueNumStates, dtype: scalarType, device: device).requires_grad_(false); - - initialMean = initialMean?.clone().to_type(scalarType).to(device).requires_grad_(false) ?? zeros(trueNumStates, dtype: scalarType, device: device).requires_grad_(false); - - initialCovariance = initialCovariance?.clone().to_type(scalarType).to(device).requires_grad_(false) ?? eye(trueNumStates, dtype: scalarType, device: device).requires_grad_(false); + transitionMatrix = transitionMatrix?.clone() ?? eye(trueNumStates); + measurementFunction = measurementFunction?.clone() ?? eye(trueNumObservations, trueNumStates); + initialMean = initialMean?.clone() ?? zeros(trueNumStates); + initialCovariance = initialCovariance?.clone() ?? eye(trueNumStates); processNoiseCovariance = processNoiseCovariance?.NumberOfElements == 1 - ? CreateCovarianceMatrix(processNoiseCovariance, scalarType, device, trueNumStates, "Process noise variance") - : processNoiseCovariance?.clone().to_type(scalarType).to(device).requires_grad_(false) - ?? CreateCovarianceMatrix(tensor(1.0), scalarType, device, trueNumStates, "Process noise variance"); + ? CreateCovarianceMatrixFromScalar(processNoiseCovariance, trueNumStates, "Process noise variance") + : processNoiseCovariance?.clone() + ?? CreateCovarianceMatrixFromScalar(1.0, trueNumStates, "Process noise variance"); measurementNoiseCovariance = measurementNoiseCovariance?.NumberOfElements == 1 - ? CreateCovarianceMatrix(measurementNoiseCovariance, scalarType, device, trueNumObservations, "Measurement noise variance") - : measurementNoiseCovariance?.clone().to_type(scalarType).to(device).requires_grad_(false) - ?? CreateCovarianceMatrix(tensor(1.0), scalarType, device, trueNumObservations, "Measurement noise variance"); + ? CreateCovarianceMatrixFromScalar(measurementNoiseCovariance, trueNumObservations, "Measurement noise variance") + : measurementNoiseCovariance?.clone() + ?? CreateCovarianceMatrixFromScalar(1.0, trueNumObservations, "Measurement noise variance"); var parameters = new KalmanFilterParameters( trueNumStates, @@ -151,10 +172,16 @@ public static KalmanFilterParameters Initialize( processNoiseCovariance, measurementNoiseCovariance, initialMean, - initialCovariance + initialCovariance, + device, + scalarType, + requiresGrad ); parameters.Validate(); + parameters.ToScalarType(scalarType); + parameters.ToDevice(device); + parameters.SetGrad(requiresGrad); return parameters; } @@ -164,7 +191,7 @@ public static KalmanFilterParameters Initialize( /// public void Validate() { - if (IsValidated) + if (Validated) return; ValidateNumStates(TransitionMatrix, MeasurementFunction, InitialMean, InitialCovariance, ProcessNoiseCovariance, out NumStates); @@ -175,16 +202,7 @@ public void Validate() ValidateMatrix(MeasurementNoiseCovariance, "Measurement noise covariance", isSquare: true, expectedDimension1: NumObservations); ValidateVector(InitialMean, "Initial mean", NumStates); ValidateMatrix(InitialCovariance, "Initial covariance", isSquare: true, expectedDimension1: NumStates); - IsValidated = true; - } - - /// - /// Validates the specified Kalman filter parameters. - /// - /// - public static void Validate(KalmanFilterParameters parameters) - { - parameters.Validate(); + Validated = true; } /// @@ -200,9 +218,64 @@ public static void Validate(KalmanFilterParameters parameters) MeasurementNoiseCovariance?.clone(), InitialMean?.clone(), InitialCovariance?.clone(), - IsValidated + Device, + ScalarType, + RequiresGrad, + Validated ); + /// + /// Converts the tensors in the Kalman filter parameters to the specified scalar type. + /// + /// + public void ToScalarType(ScalarType? scalarType) + { + if (scalarType is not null) + { + TransitionMatrix = TransitionMatrix?.to_type(scalarType.Value); + MeasurementFunction = MeasurementFunction?.to_type(scalarType.Value); + ProcessNoiseCovariance = ProcessNoiseCovariance?.to_type(scalarType.Value); + MeasurementNoiseCovariance = MeasurementNoiseCovariance?.to_type(scalarType.Value); + InitialMean = InitialMean?.to_type(scalarType.Value); + InitialCovariance = InitialCovariance?.to_type(scalarType.Value); + } + } + + /// + /// Moves the tensors in the Kalman filter parameters to the specified device. + /// + /// + public void ToDevice(Device? device) + { + if (device is not null) + { + TransitionMatrix = TransitionMatrix?.to(device); + MeasurementFunction = MeasurementFunction?.to(device); + ProcessNoiseCovariance = ProcessNoiseCovariance?.to(device); + MeasurementNoiseCovariance = MeasurementNoiseCovariance?.to(device); + InitialMean = InitialMean?.to(device); + InitialCovariance = InitialCovariance?.to(device); + } + } + + /// + /// Sets the requires_grad flag for all tensors in the Kalman filter parameters. + /// + /// + public void SetGrad(bool requiresGrad) + { + if (RequiresGrad == requiresGrad) + return; + + TransitionMatrix = TransitionMatrix?.requires_grad_(requiresGrad); + MeasurementFunction = MeasurementFunction?.requires_grad_(requiresGrad); + ProcessNoiseCovariance = ProcessNoiseCovariance?.requires_grad_(requiresGrad); + MeasurementNoiseCovariance = MeasurementNoiseCovariance?.requires_grad_(requiresGrad); + InitialMean = InitialMean?.requires_grad_(requiresGrad); + InitialCovariance = InitialCovariance?.requires_grad_(requiresGrad); + RequiresGrad = requiresGrad; + } + private static void ValidateNumStates(Tensor transitionMatrix, Tensor measurementFunction, Tensor initialMean, Tensor initialCovariance, Tensor processNoiseCovariance, out int numStates) { if (transitionMatrix is not null) @@ -290,10 +363,22 @@ private static void ValidateScalar(Tensor scalar, string name) throw new ArgumentException($"{name} must be a scalar."); } - private static Tensor CreateCovarianceMatrix(Tensor variance, ScalarType scalarType, Device device, int dimension, string name) + private static Tensor CreateCovarianceMatrixFromScalar(Tensor variance, int dimension, string name) { ValidateScalar(variance, name); - var scalar = variance.clone().squeeze().to_type(scalarType).to(device); - return (scalar * eye(dimension, dtype: scalarType, device: device)).requires_grad_(false); + var scalar = variance.clone().squeeze(); + return scalar * eye(dimension); } + + /// + public override readonly string ToString() => + $"KalmanFilterParameters(NumStates={NumStates}, NumObservations={NumObservations}, TransitionMatrix={TransitionMatrix}, MeasurementFunction={MeasurementFunction}, ProcessNoiseCovariance={ProcessNoiseCovariance}, MeasurementNoiseCovariance={MeasurementNoiseCovariance}, InitialMean={InitialMean}, InitialCovariance={InitialCovariance})"; + + /// + /// Returns a string representation of the Kalman filter parameters with the specified tensor string style. + /// + /// + /// + public readonly string ToString(TorchSharp.TensorStringStyle tensorStringStyle) => + $"KalmanFilterParameters(NumStates={NumStates}, NumObservations={NumObservations}, TransitionMatrix={TransitionMatrix.ToString(tensorStringStyle)}, MeasurementFunction={MeasurementFunction.ToString(tensorStringStyle)}, ProcessNoiseCovariance={ProcessNoiseCovariance.ToString(tensorStringStyle)}, MeasurementNoiseCovariance={MeasurementNoiseCovariance.ToString(tensorStringStyle)}, InitialMean={InitialMean.ToString(tensorStringStyle)}, InitialCovariance={InitialCovariance.ToString(tensorStringStyle)})"; } \ No newline at end of file From 6898a64497f3454c099264556fc56a9ff74f033f Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 29 Oct 2025 19:15:18 +0000 Subject: [PATCH 84/92] Refactored `CreateKalmanFilter` operator to correctly pass in `Type` property during model creation --- src/Bonsai.ML.Lds.Torch/CreateKalmanFilter.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Bonsai.ML.Lds.Torch/CreateKalmanFilter.cs b/src/Bonsai.ML.Lds.Torch/CreateKalmanFilter.cs index 5e20376d..89eabfd6 100644 --- a/src/Bonsai.ML.Lds.Torch/CreateKalmanFilter.cs +++ b/src/Bonsai.ML.Lds.Torch/CreateKalmanFilter.cs @@ -217,7 +217,7 @@ public IObservable Process() initialMean: InitialMean, initialCovariance: InitialCovariance, device: Device, - scalarType: _scalarType + scalarType: Type )); } @@ -231,7 +231,7 @@ public IObservable Process(IObservable sou return Observable.Return(new KalmanFilter( parameters: parameters, device: Device, - scalarType: _scalarType + scalarType: Type )); }); } From 40c1f52bd284bbae33a6795dca9c9faee2ada406 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 29 Oct 2025 19:15:43 +0000 Subject: [PATCH 85/92] Updated to allow specifying `Device` property when loading parameters --- .../CreateKalmanFilterParameters.cs | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/Bonsai.ML.Lds.Torch/CreateKalmanFilterParameters.cs b/src/Bonsai.ML.Lds.Torch/CreateKalmanFilterParameters.cs index 4fed8245..30b29b9d 100644 --- a/src/Bonsai.ML.Lds.Torch/CreateKalmanFilterParameters.cs +++ b/src/Bonsai.ML.Lds.Torch/CreateKalmanFilterParameters.cs @@ -31,6 +31,13 @@ public ScalarType Type } private ScalarType _scalarType = ScalarType.Float32; + /// + /// The device on which to create the tensor. + /// + [XmlIgnore] + [Description("The device on which to create the tensor.")] + public Device Device { get; set; } + /// /// The number of states in the Kalman filter model. /// @@ -209,7 +216,8 @@ public IObservable Process() measurementNoiseCovariance: _measurementNoiseCovariance, initialMean: _initialMean, initialCovariance: _initialCovariance, - scalarType: _scalarType + scalarType: _scalarType, + device: Device )); } @@ -229,7 +237,8 @@ public IObservable Process(IObservable source) measurementNoiseCovariance: _measurementNoiseCovariance, initialMean: _initialMean, initialCovariance: _initialCovariance, - scalarType: _scalarType + scalarType: _scalarType, + device: Device ); }); } From 7407cdfec12c4d0a00b045a87d4194a239d6a6c9 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 29 Oct 2025 19:16:09 +0000 Subject: [PATCH 86/92] Removed device and scalartype overrides in `Initialize` method --- src/Bonsai.ML.Lds.Torch/KalmanFilterParameters.cs | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/Bonsai.ML.Lds.Torch/KalmanFilterParameters.cs b/src/Bonsai.ML.Lds.Torch/KalmanFilterParameters.cs index c6049332..e6b02ee5 100644 --- a/src/Bonsai.ML.Lds.Torch/KalmanFilterParameters.cs +++ b/src/Bonsai.ML.Lds.Torch/KalmanFilterParameters.cs @@ -130,8 +130,6 @@ public static KalmanFilterParameters Initialize( { var trueNumStates = numStates ?? -1; var trueNumObservations = numObservations ?? -1; - device ??= CPU; - scalarType ??= ScalarType.Float32; if (numStates is null) { From a12b8d58d25f25bd3311cda100b492bbb5c97d75 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Wed, 29 Oct 2025 19:17:45 +0000 Subject: [PATCH 87/92] Refactored `LoadKalmanFilterParameters` operator to allow tensors to be loaded without explicitly setting type and fixed path support --- .../LoadKalmanFilterParameters.cs | 31 ++++++++++--------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/src/Bonsai.ML.Lds.Torch/LoadKalmanFilterParameters.cs b/src/Bonsai.ML.Lds.Torch/LoadKalmanFilterParameters.cs index b17b1691..9d1ea40d 100644 --- a/src/Bonsai.ML.Lds.Torch/LoadKalmanFilterParameters.cs +++ b/src/Bonsai.ML.Lds.Torch/LoadKalmanFilterParameters.cs @@ -29,23 +29,26 @@ public class LoadKalmanFilterParameters /// Gets or sets the data type of the tensors. /// [Description("Gets or sets the data type of the tensors.")] - public ScalarType Type { get; set; } = ScalarType.Float32; + public ScalarType? Type { get; set; } = null; /// /// Gets or sets the device to use for tensor operations. /// [XmlIgnore] [Description("Gets or sets the device to use for tensor operations.")] - public Device Device { get; set; } + public Device Device { get; set; } = null; - private Tensor LoadTensorFromFile(string filePath) + private static Tensor LoadTensorFromFile(string basePath, string filePath) { if (filePath == null) return null; + + filePath = System.IO.Path.Combine(basePath, filePath); + if (!File.Exists(filePath)) { throw new FileNotFoundException($"The specified file was not found: {filePath}"); } - return Tensor.Load(filePath)?.to(Device).to_type(Type); + return Tensor.Load(filePath); } /// @@ -53,8 +56,6 @@ private Tensor LoadTensorFromFile(string filePath) /// public IObservable Process() { - Device ??= CPU; - if (string.IsNullOrEmpty(Path)) { throw new InvalidOperationException("The save path is not specified."); @@ -65,14 +66,14 @@ public IObservable Process() throw new InvalidOperationException("The save path does not exist."); } - var transitionMatrix = LoadTensorFromFile("TransitionMatrix.bin"); - var measurementFunction = LoadTensorFromFile("MeasurementFunction.bin"); - var processNoiseCovariance = LoadTensorFromFile("ProcessNoiseCovariance.bin"); - var measurementNoiseCovariance = LoadTensorFromFile("MeasurementNoiseCovariance.bin"); - var initialMean = LoadTensorFromFile("InitialMean.bin"); - var initialCovariance = LoadTensorFromFile("InitialCovariance.bin"); + var transitionMatrix = LoadTensorFromFile(Path, "TransitionMatrix.bin"); + var measurementFunction = LoadTensorFromFile(Path, "MeasurementFunction.bin"); + var processNoiseCovariance = LoadTensorFromFile(Path, "ProcessNoiseCovariance.bin"); + var measurementNoiseCovariance = LoadTensorFromFile(Path, "MeasurementNoiseCovariance.bin"); + var initialMean = LoadTensorFromFile(Path, "InitialMean.bin"); + var initialCovariance = LoadTensorFromFile(Path, "InitialCovariance.bin"); - return Observable.Return(KalmanFilterParameters.Initialize( + var parameters = KalmanFilterParameters.Initialize( transitionMatrix: transitionMatrix, measurementFunction: measurementFunction, processNoiseCovariance: processNoiseCovariance, @@ -80,6 +81,8 @@ public IObservable Process() initialMean: initialMean, initialCovariance: initialCovariance, device: Device, - scalarType: Type)); + scalarType: Type); + + return Observable.Return(parameters); } } \ No newline at end of file From 36c632863dae904ff96214dd7ae26edc0dd605f3 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 31 Oct 2025 18:41:12 +0000 Subject: [PATCH 88/92] Updated filtering step to support missing nan values --- src/Bonsai.ML.Lds.Torch/FilteredState.cs | 4 ++-- src/Bonsai.ML.Lds.Torch/KalmanFilter.cs | 24 ++++++++++++++++++++---- 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/src/Bonsai.ML.Lds.Torch/FilteredState.cs b/src/Bonsai.ML.Lds.Torch/FilteredState.cs index b2a88e44..cb128b81 100644 --- a/src/Bonsai.ML.Lds.Torch/FilteredState.cs +++ b/src/Bonsai.ML.Lds.Torch/FilteredState.cs @@ -36,8 +36,8 @@ public struct FilteredState( public Tensor UpdatedCovariance = updatedCovariance; /// - public readonly Tensor Mean => UpdatedMean; + public readonly Tensor Mean => UpdatedMean.isnan().any().item() ? PredictedMean : UpdatedMean; /// - public readonly Tensor Covariance => UpdatedCovariance; + public readonly Tensor Covariance => UpdatedCovariance.isnan().any().item() ? PredictedCovariance : UpdatedCovariance; } \ No newline at end of file diff --git a/src/Bonsai.ML.Lds.Torch/KalmanFilter.cs b/src/Bonsai.ML.Lds.Torch/KalmanFilter.cs index edac67e1..c12a7646 100644 --- a/src/Bonsai.ML.Lds.Torch/KalmanFilter.cs +++ b/src/Bonsai.ML.Lds.Torch/KalmanFilter.cs @@ -184,8 +184,16 @@ public FilteredState Filter(Tensor observation) updatedMean[time] = update.UpdatedMean; updatedCovariance[time] = update.UpdatedCovariance; - _mean.set_(update.UpdatedMean); - _covariance.set_(update.UpdatedCovariance); + if (!update.UpdatedMean.isnan().any().item()) + { + _mean.set_(update.UpdatedMean); + _covariance.set_(update.UpdatedCovariance); + } + else + { + _mean.set_(prediction.Mean); + _covariance.set_(prediction.Covariance); + } } return new FilteredState( @@ -273,8 +281,16 @@ private static FilteredStateWithAuxiliaryVariables Filter( innovationCovariance[time] = update.InnovationCovariance; kalmanGain[time] = update.KalmanGain; - mean = update.UpdatedMean; - covariance = update.UpdatedCovariance; + if (!update.UpdatedMean.isnan().any().item()) + { + mean = update.UpdatedMean; + covariance = update.UpdatedCovariance; + } + else + { + mean = prediction.Mean; + covariance = prediction.Covariance; + } } return new FilteredStateWithAuxiliaryVariables( From de2278aac93a1c9167d802d8ecda9349e74e5bad Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Mon, 10 Nov 2025 14:48:24 +0000 Subject: [PATCH 89/92] Updated neural latents test to use new load/save tensor method --- .../NeuralLatentsTest.bonsai | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/Bonsai.ML.Lds.Torch.Tests/NeuralLatentsTest.bonsai b/tests/Bonsai.ML.Lds.Torch.Tests/NeuralLatentsTest.bonsai index 84da3814..af3f1d26 100644 --- a/tests/Bonsai.ML.Lds.Torch.Tests/NeuralLatentsTest.bonsai +++ b/tests/Bonsai.ML.Lds.Torch.Tests/NeuralLatentsTest.bonsai @@ -14,6 +14,7 @@ transformed_binned_spikes.pt + true @@ -34,6 +35,7 @@ python_V0_0.pt + true @@ -46,6 +48,7 @@ python_m0_0.pt + true @@ -58,6 +61,7 @@ python_Z0.pt + true @@ -70,6 +74,7 @@ python_R0.pt + true @@ -82,6 +87,7 @@ python_B0.pt + true @@ -94,6 +100,7 @@ python_Q0.pt + true @@ -161,6 +168,8 @@ Float64 + + @@ -343,6 +352,7 @@ bonsai_means.pt + true @@ -354,6 +364,7 @@ bonsai_covs.pt + true From 50cc8c85cb75882bf3c70c1638fccac1ae082a36 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Thu, 13 Nov 2025 10:22:32 +0000 Subject: [PATCH 90/92] Updated to use `TensorOperatorConverter` class --- src/Bonsai.ML.Lds.Torch/CreateKalmanFilter.cs | 76 +++++++------------ .../CreateKalmanFilterParameters.cs | 74 +++++++----------- .../CreateLinearDynamicalSystemState.cs | 43 ++++------- .../ExpectationMaximization.cs | 7 +- .../StochasticSubspaceIdentification.cs | 7 +- 5 files changed, 81 insertions(+), 126 deletions(-) diff --git a/src/Bonsai.ML.Lds.Torch/CreateKalmanFilter.cs b/src/Bonsai.ML.Lds.Torch/CreateKalmanFilter.cs index 89eabfd6..fa503a5c 100644 --- a/src/Bonsai.ML.Lds.Torch/CreateKalmanFilter.cs +++ b/src/Bonsai.ML.Lds.Torch/CreateKalmanFilter.cs @@ -14,21 +14,13 @@ namespace Bonsai.ML.Lds.Torch; [ResetCombinator] [Description("Creates a Kalman filter model.")] [WorkflowElementCategory(ElementCategory.Source)] +[TypeConverter(typeof(TensorOperatorConverter))] public class CreateKalmanFilter : IScalarTypeProvider { - private ScalarType _scalarType = ScalarType.Float32; /// [Description("The data type of the tensor elements.")] [TypeConverter(typeof(ScalarTypeConverter))] - public ScalarType Type - { - get => _scalarType; - set - { - _scalarType = value; - ConvertTensorsScalarType(value); - } - } + public ScalarType Type { get; set; } = ScalarType.Float32; /// /// The device on which to create the tensor. @@ -47,8 +39,6 @@ public ScalarType Type /// public int? NumObservations { get; set; } = null; - // Tensor properties with XML serialization support - private Tensor _transitionMatrix; /// /// The state transition matrix. /// @@ -57,7 +47,7 @@ public ScalarType Type public Tensor TransitionMatrix { get => _transitionMatrix; - set => _transitionMatrix = value?.to_type(Type); + set => _transitionMatrix = value; } /// @@ -68,11 +58,10 @@ public Tensor TransitionMatrix [EditorBrowsable(EditorBrowsableState.Never)] public string TransitionMatrixXml { - get => TensorConverter.ConvertToString(TransitionMatrix, _scalarType); - set => TransitionMatrix = TensorConverter.ConvertFromString(value, _scalarType); + get => TensorConverter.ConvertToString(_transitionMatrix, Type); + set => _transitionMatrix = TensorConverter.ConvertFromString(value, Type); } - private Tensor _measurementFunction; /// /// The measurement function. /// @@ -81,7 +70,7 @@ public string TransitionMatrixXml public Tensor MeasurementFunction { get => _measurementFunction; - set => _measurementFunction = value?.to_type(Type); + set => _measurementFunction = value; } /// @@ -92,11 +81,10 @@ public Tensor MeasurementFunction [EditorBrowsable(EditorBrowsableState.Never)] public string MeasurementFunctionXml { - get => TensorConverter.ConvertToString(MeasurementFunction, _scalarType); - set => MeasurementFunction = TensorConverter.ConvertFromString(value, _scalarType); + get => TensorConverter.ConvertToString(_measurementFunction, Type); + set => _measurementFunction = TensorConverter.ConvertFromString(value, Type); } - private Tensor _processNoiseVariance; /// /// The process noise variance. /// @@ -105,7 +93,7 @@ public string MeasurementFunctionXml public Tensor ProcessNoiseVariance { get => _processNoiseVariance; - set => _processNoiseVariance = value?.to_type(Type); + set => _processNoiseVariance = value; } /// @@ -116,11 +104,10 @@ public Tensor ProcessNoiseVariance [EditorBrowsable(EditorBrowsableState.Never)] public string ProcessNoiseVarianceXml { - get => TensorConverter.ConvertToString(ProcessNoiseVariance, _scalarType); - set => ProcessNoiseVariance = TensorConverter.ConvertFromString(value, _scalarType); + get => TensorConverter.ConvertToString(_processNoiseVariance, Type); + set => _processNoiseVariance = TensorConverter.ConvertFromString(value, Type); } - private Tensor _measurementNoiseVariance; /// /// The measurement noise variance. /// @@ -129,7 +116,7 @@ public string ProcessNoiseVarianceXml public Tensor MeasurementNoiseVariance { get => _measurementNoiseVariance; - set => _measurementNoiseVariance = value?.to_type(Type); + set => _measurementNoiseVariance = value; } /// @@ -140,11 +127,10 @@ public Tensor MeasurementNoiseVariance [EditorBrowsable(EditorBrowsableState.Never)] public string MeasurementNoiseVarianceXml { - get => TensorConverter.ConvertToString(MeasurementNoiseVariance, _scalarType); - set => MeasurementNoiseVariance = TensorConverter.ConvertFromString(value, _scalarType); + get => TensorConverter.ConvertToString(_measurementNoiseVariance, Type); + set => _measurementNoiseVariance = TensorConverter.ConvertFromString(value, Type); } - private Tensor _initialMean; /// /// The initial mean. /// @@ -153,7 +139,7 @@ public string MeasurementNoiseVarianceXml public Tensor InitialMean { get => _initialMean; - set => _initialMean = value?.to_type(Type); + set => _initialMean = value; } /// @@ -164,11 +150,10 @@ public Tensor InitialMean [EditorBrowsable(EditorBrowsableState.Never)] public string InitialMeanXml { - get => TensorConverter.ConvertToString(InitialMean, _scalarType); - set => InitialMean = TensorConverter.ConvertFromString(value, _scalarType); + get => TensorConverter.ConvertToString(_initialMean, Type); + set => _initialMean = TensorConverter.ConvertFromString(value, Type); } - private Tensor _initialCovariance; /// /// The initial covariance. /// @@ -177,7 +162,7 @@ public string InitialMeanXml public Tensor InitialCovariance { get => _initialCovariance; - set => _initialCovariance = value?.to_type(Type); + set => _initialCovariance = value; } /// @@ -188,19 +173,16 @@ public Tensor InitialCovariance [EditorBrowsable(EditorBrowsableState.Never)] public string InitialCovarianceXml { - get => TensorConverter.ConvertToString(InitialCovariance, _scalarType); - set => InitialCovariance = TensorConverter.ConvertFromString(value, _scalarType); + get => TensorConverter.ConvertToString(_initialCovariance, Type); + set => _initialCovariance = TensorConverter.ConvertFromString(value, Type); } - private void ConvertTensorsScalarType(ScalarType scalarType) - { - _transitionMatrix = _transitionMatrix?.to_type(scalarType); - _measurementFunction = _measurementFunction?.to_type(scalarType); - _processNoiseVariance = _processNoiseVariance?.to_type(scalarType); - _measurementNoiseVariance = _measurementNoiseVariance?.to_type(scalarType); - _initialMean = _initialMean?.to_type(scalarType); - _initialCovariance = _initialCovariance?.to_type(scalarType); - } + private Tensor _transitionMatrix; + private Tensor _measurementFunction; + private Tensor _processNoiseVariance; + private Tensor _measurementNoiseVariance; + private Tensor _initialMean; + private Tensor _initialCovariance; /// /// Creates a Kalman filter model using the properties of this class. @@ -226,13 +208,13 @@ public IObservable Process() /// public IObservable Process(IObservable source) { - return source.SelectMany(parameters => + return source.Select(parameters => { - return Observable.Return(new KalmanFilter( + return new KalmanFilter( parameters: parameters, device: Device, scalarType: Type - )); + ); }); } } diff --git a/src/Bonsai.ML.Lds.Torch/CreateKalmanFilterParameters.cs b/src/Bonsai.ML.Lds.Torch/CreateKalmanFilterParameters.cs index 30b29b9d..be18195b 100644 --- a/src/Bonsai.ML.Lds.Torch/CreateKalmanFilterParameters.cs +++ b/src/Bonsai.ML.Lds.Torch/CreateKalmanFilterParameters.cs @@ -15,21 +15,13 @@ namespace Bonsai.ML.Lds.Torch; [ResetCombinator] [Description("Initializes the parameters for a new Kalman filter model.")] [WorkflowElementCategory(ElementCategory.Source)] +[TypeConverter(typeof(TensorOperatorConverter))] public class CreateKalmanFilterParameters : IScalarTypeProvider { /// [Description("The data type of the tensor elements.")] [TypeConverter(typeof(ScalarTypeConverter))] - public ScalarType Type - { - get => _scalarType; - set - { - _scalarType = value; - ConvertTensorsScalarType(value); - } - } - private ScalarType _scalarType = ScalarType.Float32; + public ScalarType Type { get; set; } = ScalarType.Float32; /// /// The device on which to create the tensor. @@ -48,7 +40,6 @@ public ScalarType Type /// public int? NumObservations { get; set; } = null; - private Tensor _transitionMatrix = null; /// /// The state transition matrix. /// @@ -57,7 +48,7 @@ public ScalarType Type public Tensor TransitionMatrix { get => _transitionMatrix; - set => _transitionMatrix = value?.to_type(Type); + set => _transitionMatrix = value; } /// @@ -68,11 +59,10 @@ public Tensor TransitionMatrix [EditorBrowsable(EditorBrowsableState.Never)] public string TransitionMatrixXml { - get => TensorConverter.ConvertToString(TransitionMatrix, _scalarType); - set => TransitionMatrix = TensorConverter.ConvertFromString(value, _scalarType); + get => TensorConverter.ConvertToString(_transitionMatrix, Type); + set => _transitionMatrix = TensorConverter.ConvertFromString(value, Type); } - private Tensor _measurementFunction = null; /// /// The measurement function. /// @@ -81,7 +71,7 @@ public string TransitionMatrixXml public Tensor MeasurementFunction { get => _measurementFunction; - set => _measurementFunction = value?.to_type(Type); + set => _measurementFunction = value; } /// @@ -92,11 +82,10 @@ public Tensor MeasurementFunction [EditorBrowsable(EditorBrowsableState.Never)] public string MeasurementFunctionXml { - get => TensorConverter.ConvertToString(MeasurementFunction, _scalarType); - set => MeasurementFunction = TensorConverter.ConvertFromString(value, _scalarType); + get => TensorConverter.ConvertToString(_measurementFunction, Type); + set => _measurementFunction = TensorConverter.ConvertFromString(value, Type); } - private Tensor _processNoiseCovariance = null; /// /// The process noise variance. /// @@ -105,7 +94,7 @@ public string MeasurementFunctionXml public Tensor ProcessNoiseCovariance { get => _processNoiseCovariance; - set => _processNoiseCovariance = value?.to_type(Type); + set => _processNoiseCovariance = value; } /// @@ -116,11 +105,10 @@ public Tensor ProcessNoiseCovariance [EditorBrowsable(EditorBrowsableState.Never)] public string ProcessNoiseCovarianceXml { - get => TensorConverter.ConvertToString(ProcessNoiseCovariance, _scalarType); - set => ProcessNoiseCovariance = TensorConverter.ConvertFromString(value, _scalarType); + get => TensorConverter.ConvertToString(_processNoiseCovariance, Type); + set => _processNoiseCovariance = TensorConverter.ConvertFromString(value, Type); } - private Tensor _measurementNoiseCovariance = null; /// /// The measurement noise covariance matrix. /// @@ -129,7 +117,7 @@ public string ProcessNoiseCovarianceXml public Tensor MeasurementNoiseCovariance { get => _measurementNoiseCovariance; - set => _measurementNoiseCovariance = value?.to_type(Type); + set => _measurementNoiseCovariance = value; } /// @@ -140,11 +128,10 @@ public Tensor MeasurementNoiseCovariance [EditorBrowsable(EditorBrowsableState.Never)] public string MeasurementNoiseCovarianceXml { - get => TensorConverter.ConvertToString(MeasurementNoiseCovariance, _scalarType); - set => MeasurementNoiseCovariance = TensorConverter.ConvertFromString(value, _scalarType); + get => TensorConverter.ConvertToString(_measurementNoiseCovariance, Type); + set => _measurementNoiseCovariance = TensorConverter.ConvertFromString(value, Type); } - private Tensor _initialMean = null; /// /// The initial mean. /// @@ -153,7 +140,7 @@ public string MeasurementNoiseCovarianceXml public Tensor InitialMean { get => _initialMean; - set => _initialMean = value?.to_type(Type); + set => _initialMean = value; } /// @@ -164,11 +151,10 @@ public Tensor InitialMean [EditorBrowsable(EditorBrowsableState.Never)] public string InitialMeanXml { - get => TensorConverter.ConvertToString(InitialMean, _scalarType); - set => InitialMean = TensorConverter.ConvertFromString(value, _scalarType); + get => TensorConverter.ConvertToString(_initialMean, Type); + set => _initialMean = TensorConverter.ConvertFromString(value, Type); } - private Tensor _initialCovariance = null; /// /// The initial covariance. /// @@ -177,7 +163,7 @@ public string InitialMeanXml public Tensor InitialCovariance { get => _initialCovariance; - set => _initialCovariance = value?.to_type(Type); + set => _initialCovariance = value; } /// @@ -188,19 +174,17 @@ public Tensor InitialCovariance [EditorBrowsable(EditorBrowsableState.Never)] public string InitialCovarianceXml { - get => TensorConverter.ConvertToString(InitialCovariance, _scalarType); - set => InitialCovariance = TensorConverter.ConvertFromString(value, _scalarType); + get => TensorConverter.ConvertToString(_initialCovariance, Type); + set => _initialCovariance = TensorConverter.ConvertFromString(value, Type); } - private void ConvertTensorsScalarType(ScalarType scalarType) - { - _transitionMatrix = _transitionMatrix?.to_type(scalarType); - _measurementFunction = _measurementFunction?.to_type(scalarType); - _processNoiseCovariance = _processNoiseCovariance?.to_type(scalarType); - _measurementNoiseCovariance = _measurementNoiseCovariance?.to_type(scalarType); - _initialMean = _initialMean?.to_type(scalarType); - _initialCovariance = _initialCovariance?.to_type(scalarType); - } + private Tensor _transitionMatrix = null; + private Tensor _measurementFunction = null; + private Tensor _processNoiseCovariance = null; + private Tensor _measurementNoiseCovariance = null; + private Tensor _initialMean = null; + private Tensor _initialCovariance = null; + /// /// Creates parameters for a Kalman filter model using the properties of this class. @@ -216,7 +200,7 @@ public IObservable Process() measurementNoiseCovariance: _measurementNoiseCovariance, initialMean: _initialMean, initialCovariance: _initialCovariance, - scalarType: _scalarType, + scalarType: Type, device: Device )); } @@ -237,7 +221,7 @@ public IObservable Process(IObservable source) measurementNoiseCovariance: _measurementNoiseCovariance, initialMean: _initialMean, initialCovariance: _initialCovariance, - scalarType: _scalarType, + scalarType: Type, device: Device ); }); diff --git a/src/Bonsai.ML.Lds.Torch/CreateLinearDynamicalSystemState.cs b/src/Bonsai.ML.Lds.Torch/CreateLinearDynamicalSystemState.cs index ac44e2c1..23983f9b 100644 --- a/src/Bonsai.ML.Lds.Torch/CreateLinearDynamicalSystemState.cs +++ b/src/Bonsai.ML.Lds.Torch/CreateLinearDynamicalSystemState.cs @@ -14,21 +14,13 @@ namespace Bonsai.ML.Lds.Torch; [ResetCombinator] [Description("Creates a new state for a linear gaussian dynamical system.")] [WorkflowElementCategory(ElementCategory.Source)] +[TypeConverter(typeof(TensorOperatorConverter))] public class CreateLinearDynamicalSystemState : IScalarTypeProvider { - private ScalarType _scalarType = ScalarType.Float32; /// [Description("The data type of the tensor elements.")] [TypeConverter(typeof(ScalarTypeConverter))] - public ScalarType Type - { - get => _scalarType; - set - { - _scalarType = value; - ConvertTensorsScalarType(value); - } - } + public ScalarType Type { get; set; } = ScalarType.Float32; /// /// The device on which to create the tensor. @@ -37,13 +29,6 @@ public ScalarType Type [XmlIgnore] public Device Device { get; set; } - private void ConvertTensorsScalarType(ScalarType scalarType) - { - _mean = _mean?.to_type(scalarType); - _covariance = _covariance?.to_type(scalarType); - } - - private Tensor _mean = null; /// /// The mean of the state. /// @@ -52,7 +37,7 @@ private void ConvertTensorsScalarType(ScalarType scalarType) public Tensor Mean { get => _mean; - set => _mean = value?.to_type(Type); + set => _mean = value; } /// @@ -63,11 +48,10 @@ public Tensor Mean [EditorBrowsable(EditorBrowsableState.Never)] public string MeanXml { - get => TensorConverter.ConvertToString(Mean, _scalarType); - set => Mean = TensorConverter.ConvertFromString(value, _scalarType); + get => TensorConverter.ConvertToString(_mean, Type); + set => _mean = TensorConverter.ConvertFromString(value, Type); } - private Tensor _covariance = null; /// /// The covariance of the state. /// @@ -76,7 +60,7 @@ public string MeanXml public Tensor Covariance { get => _covariance; - set => _covariance = value?.to_type(Type); + set => _covariance = value; } /// @@ -87,10 +71,13 @@ public Tensor Covariance [EditorBrowsable(EditorBrowsableState.Never)] public string CovarianceXml { - get => TensorConverter.ConvertToString(Covariance, _scalarType); - set => Covariance = TensorConverter.ConvertFromString(value, _scalarType); + get => TensorConverter.ConvertToString(_covariance, Type); + set => _covariance = TensorConverter.ConvertFromString(value, Type); } + private Tensor _mean = null; + private Tensor _covariance = null; + /// /// Creates an observable sequence and emits the state for a linear gaussian dynamical system. /// @@ -100,8 +87,8 @@ public IObservable Process() return Observable.Defer(() => { var device = Device ?? CPU; - var mean = Mean?.to(device) ?? throw new InvalidOperationException("The mean of the state must be specified."); - var covariance = Covariance?.to(device) ?? throw new InvalidOperationException("The covariance of the state must be specified."); + var mean = _mean?.to(device) ?? throw new InvalidOperationException("The mean of the state must be specified."); + var covariance = _covariance?.to(device) ?? throw new InvalidOperationException("The covariance of the state must be specified."); return Observable.Return(new LinearDynamicalSystemState(mean, covariance)); }); } @@ -118,8 +105,8 @@ public IObservable Process(IObservable source) return source.Select(_ => { var device = Device ?? CPU; - var mean = Mean?.to(device) ?? throw new InvalidOperationException("The mean of the state must be specified."); - var covariance = Covariance?.to(device) ?? throw new InvalidOperationException("The covariance of the state must be specified."); + var mean = _mean?.to(device) ?? throw new InvalidOperationException("The mean of the state must be specified."); + var covariance = _covariance?.to(device) ?? throw new InvalidOperationException("The covariance of the state must be specified."); return new LinearDynamicalSystemState(mean, covariance); }); } diff --git a/src/Bonsai.ML.Lds.Torch/ExpectationMaximization.cs b/src/Bonsai.ML.Lds.Torch/ExpectationMaximization.cs index 53032b01..1cc57235 100644 --- a/src/Bonsai.ML.Lds.Torch/ExpectationMaximization.cs +++ b/src/Bonsai.ML.Lds.Torch/ExpectationMaximization.cs @@ -32,7 +32,6 @@ public class ExpectationMaximization [XmlIgnore] public KalmanFilterParameters? ModelParameters { get; set; } = null; - private int _maxIterations = 10; /// /// The maximum number of EM iterations to perform. /// @@ -43,7 +42,6 @@ public int MaxIterations set => _maxIterations = value > 0 ? value : throw new ArgumentOutOfRangeException(nameof(MaxIterations), "Must be greater than zero."); } - private double _tolerance = 1e-4; /// /// The convergence tolerance for the EM algorithm. /// @@ -54,7 +52,6 @@ public double Tolerance set => _tolerance = value >= 0 ? value : throw new ArgumentOutOfRangeException(nameof(Tolerance), "Must be greater than or equal to zero."); } - private bool _verbose = true; /// /// If true, prints progress messages to the console. /// @@ -101,6 +98,10 @@ public bool Verbose [Description("If true, the initial covariance will be estimated during the EM algorithm.")] public bool EstimateInitialCovariance { get; set; } = true; + private int _maxIterations = 10; + private double _tolerance = 1e-4; + private bool _verbose = true; + /// /// Processes an observable sequence of input tensors, applying the Expectation-Maximization algorithm to learn the parameters of a Kalman filter model. /// diff --git a/src/Bonsai.ML.Lds.Torch/StochasticSubspaceIdentification.cs b/src/Bonsai.ML.Lds.Torch/StochasticSubspaceIdentification.cs index 01d0c8e8..6fb24d3e 100644 --- a/src/Bonsai.ML.Lds.Torch/StochasticSubspaceIdentification.cs +++ b/src/Bonsai.ML.Lds.Torch/StochasticSubspaceIdentification.cs @@ -16,7 +16,6 @@ namespace Bonsai.ML.Lds.Torch; [WorkflowElementCategory(ElementCategory.Combinator)] public class StochasticSubspaceIdentification { - private int? _targetNumStates = 2; /// /// The target number of states in the Kalman filter model. /// @@ -27,7 +26,6 @@ public int? TargetNumStates set => _targetNumStates = value > 0 ? value : throw new ArgumentOutOfRangeException(nameof(value), "Number of states must be greater than zero."); } - private int _maxLag = 20; /// /// The maximum lag to consider for the subspace identification. /// @@ -38,7 +36,6 @@ public int MaxLag set => _maxLag = value > 0 ? value : throw new ArgumentOutOfRangeException(nameof(MaxLag), "Must be greater than zero."); } - private double _threshold = 1e-4; /// /// The threshold for the singular values to determine the effective number of states. /// @@ -85,6 +82,10 @@ public double Threshold [Description("If true, the initial covariance will be estimated during the EM algorithm.")] public bool EstimateInitialCovariance { get; set; } = true; + private int? _targetNumStates = 2; + private int _maxLag = 20; + private double _threshold = 1e-4; + /// /// Processes an observable sequence of input tensors, applying the Expectation-Maximization algorithm to learn the parameters of a Kalman filter model. /// From e326a85bec38aa5c2e24139e676808d562479267 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Mon, 12 Jan 2026 18:45:13 +0000 Subject: [PATCH 91/92] Added default user-agent in request header when downloading test data from zenodo --- tests/Bonsai.ML.Lds.Torch.Tests/NeuralLatentsTest.cs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/Bonsai.ML.Lds.Torch.Tests/NeuralLatentsTest.cs b/tests/Bonsai.ML.Lds.Torch.Tests/NeuralLatentsTest.cs index d997e3fb..c641fe93 100644 --- a/tests/Bonsai.ML.Lds.Torch.Tests/NeuralLatentsTest.cs +++ b/tests/Bonsai.ML.Lds.Torch.Tests/NeuralLatentsTest.cs @@ -10,6 +10,7 @@ using Bonsai.ML.Tests.Utilities; using static TorchSharp.torch; using TorchSharp; +using System.Text; namespace Bonsai.ML.Lds.Torch.Tests; @@ -30,6 +31,7 @@ private static void DownloadData(string basePath) byte[] responseBytes; using (var httpClient = new HttpClient()) { + httpClient.DefaultRequestHeaders.Add("User-Agent", "Bonsai.ML.Tests"); responseBytes = httpClient.GetByteArrayAsync(zipFileUrl).Result; Console.WriteLine("File downloaded successfully."); } From d306bc266c4ea4118e2e9ddf4be6c29c460af1cb Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Tue, 13 Jan 2026 12:03:09 +0000 Subject: [PATCH 92/92] Added support in KF for estimating state and observation offset parameters --- src/Bonsai.ML.Lds.Torch/CreateKalmanFilter.cs | 69 +- .../CreateKalmanFilterParameters.cs | 69 +- .../CreateLinearDynamicalSystemState.cs | 8 +- .../ExpectationMaximization.cs | 28 +- src/Bonsai.ML.Lds.Torch/FilteredState.cs | 54 +- src/Bonsai.ML.Lds.Torch/KalmanFilter.cs | 659 +++++++++--------- .../KalmanFilterParameters.cs | 290 ++++---- .../LinearDynamicalSystemState.cs | 10 +- .../LoadKalmanFilterParameters.cs | 18 +- src/Bonsai.ML.Lds.Torch/Orthogonalize.cs | 29 +- .../ParametersToEstimate.cs | 18 +- .../SaveKalmanFilterParameters.cs | 30 +- src/Bonsai.ML.Lds.Torch/Smooth.cs | 7 +- .../StochasticSubspaceIdentification.cs | 50 +- .../ReceptiveFieldSimpleCellTest.cs | 8 +- 15 files changed, 741 insertions(+), 606 deletions(-) diff --git a/src/Bonsai.ML.Lds.Torch/CreateKalmanFilter.cs b/src/Bonsai.ML.Lds.Torch/CreateKalmanFilter.cs index fa503a5c..4b9dda07 100644 --- a/src/Bonsai.ML.Lds.Torch/CreateKalmanFilter.cs +++ b/src/Bonsai.ML.Lds.Torch/CreateKalmanFilter.cs @@ -17,6 +17,15 @@ namespace Bonsai.ML.Lds.Torch; [TypeConverter(typeof(TensorOperatorConverter))] public class CreateKalmanFilter : IScalarTypeProvider { + private Tensor _transitionMatrix; + private Tensor _measurementFunction; + private Tensor _processNoiseVariance; + private Tensor _measurementNoiseVariance; + private Tensor _initialMean; + private Tensor _initialCovariance; + private Tensor _stateOffset; + private Tensor _observationOffset; + /// [Description("The data type of the tensor elements.")] [TypeConverter(typeof(ScalarTypeConverter))] @@ -38,7 +47,7 @@ public class CreateKalmanFilter : IScalarTypeProvider /// The number of observations in the Kalman filter model. /// public int? NumObservations { get; set; } = null; - + /// /// The state transition matrix. /// @@ -177,12 +186,52 @@ public string InitialCovarianceXml set => _initialCovariance = TensorConverter.ConvertFromString(value, Type); } - private Tensor _transitionMatrix; - private Tensor _measurementFunction; - private Tensor _processNoiseVariance; - private Tensor _measurementNoiseVariance; - private Tensor _initialMean; - private Tensor _initialCovariance; + /// + /// The state offset. + /// + [XmlIgnore] + [TypeConverter(typeof(TensorConverter))] + public Tensor StateOffset + { + get => _stateOffset; + set => _stateOffset = value; + } + + /// + /// The XML string representation of the state offset for serialization. + /// + [Browsable(false)] + [XmlElement(nameof(StateOffset))] + [EditorBrowsable(EditorBrowsableState.Never)] + public string StateOffsetXml + { + get => TensorConverter.ConvertToString(_stateOffset, Type); + set => _stateOffset = TensorConverter.ConvertFromString(value, Type); + } + + /// + /// The observation offset. + /// + [XmlIgnore] + [TypeConverter(typeof(TensorConverter))] + public Tensor ObservationOffset + { + get => _observationOffset; + set => _observationOffset = value; + } + + /// + /// The XML string representation of the observation offset for serialization. + /// + [Browsable(false)] + [XmlElement(nameof(ObservationOffset))] + [EditorBrowsable(EditorBrowsableState.Never)] + public string ObservationOffsetXml + { + get => TensorConverter.ConvertToString(_observationOffset, Type); + set => _observationOffset = TensorConverter.ConvertFromString(value, Type); + } + /// /// Creates a Kalman filter model using the properties of this class. @@ -198,6 +247,8 @@ public IObservable Process() measurementNoiseVariance: MeasurementNoiseVariance, initialMean: InitialMean, initialCovariance: InitialCovariance, + stateOffset: StateOffset, + observationOffset: ObservationOffset, device: Device, scalarType: Type )); @@ -211,9 +262,7 @@ public IObservable Process(IObservable sou return source.Select(parameters => { return new KalmanFilter( - parameters: parameters, - device: Device, - scalarType: Type + parameters: parameters ); }); } diff --git a/src/Bonsai.ML.Lds.Torch/CreateKalmanFilterParameters.cs b/src/Bonsai.ML.Lds.Torch/CreateKalmanFilterParameters.cs index be18195b..b3aa2a09 100644 --- a/src/Bonsai.ML.Lds.Torch/CreateKalmanFilterParameters.cs +++ b/src/Bonsai.ML.Lds.Torch/CreateKalmanFilterParameters.cs @@ -18,6 +18,15 @@ namespace Bonsai.ML.Lds.Torch; [TypeConverter(typeof(TensorOperatorConverter))] public class CreateKalmanFilterParameters : IScalarTypeProvider { + private Tensor _transitionMatrix = null; + private Tensor _measurementFunction = null; + private Tensor _processNoiseCovariance = null; + private Tensor _measurementNoiseCovariance = null; + private Tensor _initialMean = null; + private Tensor _initialCovariance = null; + private Tensor _stateOffset = null; + private Tensor _observationOffset = null; + /// [Description("The data type of the tensor elements.")] [TypeConverter(typeof(ScalarTypeConverter))] @@ -178,20 +187,58 @@ public string InitialCovarianceXml set => _initialCovariance = TensorConverter.ConvertFromString(value, Type); } - private Tensor _transitionMatrix = null; - private Tensor _measurementFunction = null; - private Tensor _processNoiseCovariance = null; - private Tensor _measurementNoiseCovariance = null; - private Tensor _initialMean = null; - private Tensor _initialCovariance = null; + /// + /// The state offset. + /// + [XmlIgnore] + [TypeConverter(typeof(TensorConverter))] + public Tensor StateOffset + { + get => _stateOffset; + set => _stateOffset = value; + } + /// + /// The XML string representation of the state offset for serialization. + /// + [Browsable(false)] + [XmlElement(nameof(StateOffset))] + [EditorBrowsable(EditorBrowsableState.Never)] + public string StateOffsetXml + { + get => TensorConverter.ConvertToString(_stateOffset, Type); + set => _stateOffset = TensorConverter.ConvertFromString(value, Type); + } + + /// + /// The observation offset. + /// + [XmlIgnore] + [TypeConverter(typeof(TensorConverter))] + public Tensor ObservationOffset + { + get => _observationOffset; + set => _observationOffset = value; + } + + /// + /// The XML string representation of the observation offset for serialization. + /// + [Browsable(false)] + [XmlElement(nameof(ObservationOffset))] + [EditorBrowsable(EditorBrowsableState.Never)] + public string ObservationOffsetXml + { + get => TensorConverter.ConvertToString(_observationOffset, Type); + set => _observationOffset = TensorConverter.ConvertFromString(value, Type); + } /// /// Creates parameters for a Kalman filter model using the properties of this class. /// public IObservable Process() { - return Observable.Return(KalmanFilterParameters.Initialize( + return Observable.Return(new KalmanFilterParameters( numStates: NumStates, numObservations: NumObservations, transitionMatrix: _transitionMatrix, @@ -200,6 +247,8 @@ public IObservable Process() measurementNoiseCovariance: _measurementNoiseCovariance, initialMean: _initialMean, initialCovariance: _initialCovariance, + stateOffset: _stateOffset, + observationOffset: _observationOffset, scalarType: Type, device: Device )); @@ -212,7 +261,7 @@ public IObservable Process(IObservable source) { return source.Select(_ => { - return KalmanFilterParameters.Initialize( + return new KalmanFilterParameters( numStates: NumStates, numObservations: NumObservations, transitionMatrix: _transitionMatrix, @@ -221,9 +270,11 @@ public IObservable Process(IObservable source) measurementNoiseCovariance: _measurementNoiseCovariance, initialMean: _initialMean, initialCovariance: _initialCovariance, + stateOffset: _stateOffset, + observationOffset: _observationOffset, scalarType: Type, device: Device ); }); } -} \ No newline at end of file +} diff --git a/src/Bonsai.ML.Lds.Torch/CreateLinearDynamicalSystemState.cs b/src/Bonsai.ML.Lds.Torch/CreateLinearDynamicalSystemState.cs index 23983f9b..a1760724 100644 --- a/src/Bonsai.ML.Lds.Torch/CreateLinearDynamicalSystemState.cs +++ b/src/Bonsai.ML.Lds.Torch/CreateLinearDynamicalSystemState.cs @@ -17,6 +17,9 @@ namespace Bonsai.ML.Lds.Torch; [TypeConverter(typeof(TensorOperatorConverter))] public class CreateLinearDynamicalSystemState : IScalarTypeProvider { + private Tensor _mean = null; + private Tensor _covariance = null; + /// [Description("The data type of the tensor elements.")] [TypeConverter(typeof(ScalarTypeConverter))] @@ -75,9 +78,6 @@ public string CovarianceXml set => _covariance = TensorConverter.ConvertFromString(value, Type); } - private Tensor _mean = null; - private Tensor _covariance = null; - /// /// Creates an observable sequence and emits the state for a linear gaussian dynamical system. /// @@ -123,4 +123,4 @@ public IObservable Process(IObservable /// The number of states in the Kalman filter model. /// @@ -98,9 +102,17 @@ public bool Verbose [Description("If true, the initial covariance will be estimated during the EM algorithm.")] public bool EstimateInitialCovariance { get; set; } = true; - private int _maxIterations = 10; - private double _tolerance = 1e-4; - private bool _verbose = true; + /// + /// If true, the state offset will be estimated during the EM algorithm. + /// + [Description("If true, the state offset will be estimated during the EM algorithm.")] + public bool EstimateStateOffset { get; set; } = false; + + /// + /// If true, the observation offset will be estimated during the EM algorithm. + /// + [Description("If true, the observation offset will be estimated during the EM algorithm.")] + public bool EstimateObservationOffset { get; set; } = false; /// /// Processes an observable sequence of input tensors, applying the Expectation-Maximization algorithm to learn the parameters of a Kalman filter model. @@ -124,16 +136,16 @@ public IObservable Process(IObservable so processNoiseCovariance: EstimateProcessNoiseCovariance, measurementNoiseCovariance: EstimateMeasurementNoiseCovariance, initialMean: EstimateInitialMean, - initialCovariance: EstimateInitialCovariance); + initialCovariance: EstimateInitialCovariance, + stateOffset: EstimateStateOffset, + observationOffset: EstimateObservationOffset); - var parameters = ModelParameters?.Copy() ?? KalmanFilterParameters.Initialize( + var parameters = ModelParameters?.Copy() ?? new KalmanFilterParameters( numStates: NumStates, numObservations: numObservations, scalarType: input.dtype, device: input.device); - parameters.Validate(); - for (int i = 0; i < MaxIterations; i++) { // Check for cancellation before each iteration @@ -200,4 +212,4 @@ public IObservable Process(IObservable so cancellationToken); })).Concat(); } -} \ No newline at end of file +} diff --git a/src/Bonsai.ML.Lds.Torch/FilteredState.cs b/src/Bonsai.ML.Lds.Torch/FilteredState.cs index cb128b81..9b1d53a1 100644 --- a/src/Bonsai.ML.Lds.Torch/FilteredState.cs +++ b/src/Bonsai.ML.Lds.Torch/FilteredState.cs @@ -5,39 +5,53 @@ namespace Bonsai.ML.Lds.Torch; /// /// Represents the state of a Kalman filter. /// -/// -/// -/// -/// -public struct FilteredState( - Tensor predictedMean, - Tensor predictedCovariance, - Tensor updatedMean, - Tensor updatedCovariance) : ILinearDynamicalSystemState +/// +/// +/// +/// +/// +/// +public readonly struct FilteredState( + LinearDynamicalSystemState predictedState, + LinearDynamicalSystemState updatedState, + Tensor innovation = null, + Tensor innovationCovariance = null, + Tensor kalmanGain = null, + Tensor logLikelihood = null) : ILinearDynamicalSystemState { /// - /// The predicted mean after the prediction step. + /// The predicted state following the prediction step. /// - public Tensor PredictedMean = predictedMean; + public readonly LinearDynamicalSystemState PredictedState => predictedState; /// - /// The predicted covariance after the prediction step. + /// The updated state following the update step. /// - public Tensor PredictedCovariance = predictedCovariance; + public readonly LinearDynamicalSystemState UpdatedState => updatedState; /// - /// The updated mean after the update step. + /// The innovation (residual) between the observation and the prediction. /// - public Tensor UpdatedMean = updatedMean; + public readonly Tensor Innovation => innovation; /// - /// The updated covariance after the update step. + /// The innovation (residual) covariance. /// - public Tensor UpdatedCovariance = updatedCovariance; + public readonly Tensor InnovationCovariance => innovationCovariance; + + /// + /// The Kalman gain. + /// + public readonly Tensor KalmanGain => kalmanGain; + + /// + /// The log likelihood of the observation given the updated state. + /// + public readonly Tensor LogLikelihood => logLikelihood; /// - public readonly Tensor Mean => UpdatedMean.isnan().any().item() ? PredictedMean : UpdatedMean; + public readonly Tensor Mean => updatedState.Mean; /// - public readonly Tensor Covariance => UpdatedCovariance.isnan().any().item() ? PredictedCovariance : UpdatedCovariance; -} \ No newline at end of file + public readonly Tensor Covariance => updatedState.Covariance; +} diff --git a/src/Bonsai.ML.Lds.Torch/KalmanFilter.cs b/src/Bonsai.ML.Lds.Torch/KalmanFilter.cs index c12a7646..f6b9ba49 100644 --- a/src/Bonsai.ML.Lds.Torch/KalmanFilter.cs +++ b/src/Bonsai.ML.Lds.Torch/KalmanFilter.cs @@ -1,4 +1,5 @@ using System; +using TorchSharp.Modules; using static TorchSharp.torch; namespace Bonsai.ML.Lds.Torch; @@ -6,30 +7,19 @@ namespace Bonsai.ML.Lds.Torch; // disable missing XML comment warnings # pragma warning disable CS1591 -public class KalmanFilter : nn.Module +public class KalmanFilter : nn.Module { - private readonly Tensor _mean; - private readonly Tensor _covariance; - private readonly Device _device; - private readonly ScalarType _scalarType; + private LinearDynamicalSystemState _state; public readonly KalmanFilterParameters Parameters; public KalmanFilter( - KalmanFilterParameters parameters, - Device device = null, - ScalarType? scalarType = null) : base("KalmanFilter") + KalmanFilterParameters parameters) : base("KalmanFilter") { - Parameters = parameters.Copy(); - - Parameters.Validate(); - Parameters.ToScalarType(scalarType); - Parameters.ToDevice(device); + Parameters = parameters; - _device = Parameters.Device; - _scalarType = Parameters.ScalarType; - - _mean = empty(0, dtype: _scalarType, device: _device).requires_grad_(false); - _covariance = empty(0, dtype: _scalarType, device: _device).requires_grad_(false); + _state = new LinearDynamicalSystemState( + Parameters.InitialMean, + Parameters.InitialCovariance); RegisterComponents(); } @@ -43,10 +33,13 @@ public KalmanFilter( Tensor initialCovariance = null, Tensor processNoiseVariance = null, Tensor measurementNoiseVariance = null, + Tensor stateOffset = null, + Tensor observationOffset = null, Device device = null, - ScalarType? scalarType = null) : base("KalmanFilter") + ScalarType? scalarType = null, + bool requiresGrad = false) : base("KalmanFilter") { - Parameters = KalmanFilterParameters.Initialize( + Parameters = new KalmanFilterParameters( numStates, numObservations, transitionMatrix, @@ -55,99 +48,59 @@ public KalmanFilter( measurementNoiseVariance, initialMean, initialCovariance, + stateOffset, + observationOffset, device, - scalarType + scalarType, + requiresGrad ); - _device = Parameters.Device; - _scalarType = Parameters.ScalarType; - - _mean = empty(0, dtype: _scalarType, device: _device).requires_grad_(false); - _covariance = empty(0, dtype: _scalarType, device: _device).requires_grad_(false); + _state = new LinearDynamicalSystemState( + Parameters.InitialMean, + Parameters.InitialCovariance); RegisterComponents(); } - private LinearDynamicalSystemState FilterPredict( - Tensor mean, - Tensor covariance) => - new(Parameters.TransitionMatrix.matmul(mean), - Parameters.TransitionMatrix.matmul(covariance) - .matmul(Parameters.TransitionMatrix.mT) + Parameters.ProcessNoiseCovariance); - private static LinearDynamicalSystemState FilterPredict( - Tensor mean, - Tensor covariance, - Tensor transitionMatrix, - Tensor processNoiseCovariance) => - new(transitionMatrix.matmul(mean), - transitionMatrix.matmul(covariance) - .matmul(transitionMatrix.mT) + processNoiseCovariance); - - private readonly struct UpdatedState( - Tensor updatedMean, - Tensor updatedCovariance, - Tensor innovation, - Tensor innovationCovariance, - Tensor kalmanGain) - { - public readonly Tensor UpdatedMean = updatedMean; - public readonly Tensor UpdatedCovariance = updatedCovariance; - public readonly Tensor Innovation = innovation; - public readonly Tensor InnovationCovariance = innovationCovariance; - public readonly Tensor KalmanGain = kalmanGain; - } - - private UpdatedState FilterUpdate( - Tensor predictedMean, - Tensor predictedCovariance, - Tensor observation) - { - // Innovation step - var innovation = observation - Parameters.MeasurementFunction.matmul(predictedMean); - var innovationCovariance = WrappedTensorDisposeScope(() => EnsureSymmetric( - Parameters.MeasurementFunction.matmul(predictedCovariance) - .matmul(Parameters.MeasurementFunction.mT) + Parameters.MeasurementNoiseCovariance)); + LinearDynamicalSystemState state, + KalmanFilterParameters parameters) => + new(parameters.TransitionMatrix.matmul(state.Mean) + (parameters.OffsetsProvided ? parameters.StateOffset : 0), + parameters.TransitionMatrix.matmul(state.Covariance) + .matmul(parameters.TransitionMatrix.mT) + parameters.ProcessNoiseCovariance); - // Kalman gain - var kalmanGain = WrappedTensorDisposeScope(() => InverseCholesky( - predictedCovariance.matmul(Parameters.MeasurementFunction.mT), - innovationCovariance)); - - // Update step - var updatedMean = predictedMean + kalmanGain.matmul(innovation); - var updatedCovariance = WrappedTensorDisposeScope(() => predictedCovariance - - kalmanGain.matmul(Parameters.MeasurementFunction).matmul(predictedCovariance)); - - return new UpdatedState(updatedMean, updatedCovariance, innovation, innovationCovariance, kalmanGain); - } - - private static UpdatedState FilterUpdate( - Tensor predictedMean, - Tensor predictedCovariance, + private static FilteredState FilterUpdate( Tensor observation, - Tensor measurementFunction, - Tensor measurementNoiseCovariance) + LinearDynamicalSystemState state, + KalmanFilterParameters parameters) { + if (observation is null) + return new FilteredState( + predictedState: state, + updatedState: state + ); + // Innovation step - var innovation = observation - measurementFunction.matmul(predictedMean); + var innovation = observation - (parameters.MeasurementFunction.matmul(state.Mean) + (parameters.OffsetsProvided ? parameters.ObservationOffset : 0)); var innovationCovariance = WrappedTensorDisposeScope(() => EnsureSymmetric( - measurementFunction.matmul(predictedCovariance) - .matmul(measurementFunction.mT) + measurementNoiseCovariance)); + parameters.MeasurementFunction.matmul(state.Covariance) + .matmul(parameters.MeasurementFunction.mT) + parameters.MeasurementNoiseCovariance)); // Kalman gain var kalmanGain = WrappedTensorDisposeScope(() => InverseCholesky( - predictedCovariance.matmul(measurementFunction.mT), + state.Covariance.matmul(parameters.MeasurementFunction.mT), innovationCovariance)); // Update step - var updatedMean = predictedMean + kalmanGain.matmul(innovation); - var updatedCovariance = WrappedTensorDisposeScope(() => predictedCovariance - - kalmanGain.matmul(measurementFunction).matmul(predictedCovariance)); + var updatedMean = state.Mean + kalmanGain.matmul(innovation); + var updatedCovariance = WrappedTensorDisposeScope(() => state.Covariance + - kalmanGain.matmul(parameters.MeasurementFunction).matmul(state.Covariance)); + + var updatedState = new LinearDynamicalSystemState(updatedMean, updatedCovariance); - return new UpdatedState( - updatedMean: updatedMean, - updatedCovariance: updatedCovariance, + return new FilteredState( + predictedState: state, + updatedState: updatedState, innovation: innovation, innovationCovariance: innovationCovariance, kalmanGain: kalmanGain @@ -161,110 +114,85 @@ public FilteredState Filter(Tensor observation) var obs = observation.atleast_2d(); var timeBins = obs.size(0); - var predictedMean = empty([timeBins, Parameters.NumStates], dtype: _scalarType, device: _device); - var predictedCovariance = empty([timeBins, Parameters.NumStates, Parameters.NumStates], dtype: _scalarType, device: _device); - var updatedMean = empty([timeBins, Parameters.NumStates], dtype: _scalarType, device: _device); - var updatedCovariance = empty([timeBins, Parameters.NumStates, Parameters.NumStates], dtype: _scalarType, device: _device); + // We get the scalar type and device from the parameters in the event that the observations are null (e.g. during forecasting) + var scalarType = Parameters.ScalarType; + var device = Parameters.Device; - if (_mean.NumberOfElements == 0) - _mean.set_(Parameters.InitialMean.clone()); - if (_covariance.NumberOfElements == 0) - _covariance.set_(Parameters.InitialCovariance.clone()); + var predictedState = new LinearDynamicalSystemState( + empty([timeBins, Parameters.NumStates], dtype: scalarType, device: device), + empty([timeBins, Parameters.NumStates, Parameters.NumStates], dtype: scalarType, device: device)); + + var updatedState = new LinearDynamicalSystemState( + empty([timeBins, Parameters.NumStates], dtype: scalarType, device: device), + empty([timeBins, Parameters.NumStates, Parameters.NumStates], dtype: scalarType, device: device)); for (long time = 0; time < timeBins; time++) { // Predict - var prediction = FilterPredict(_mean, _covariance); + var prediction = FilterPredict(_state, Parameters); // Update - var update = FilterUpdate(prediction.Mean, prediction.Covariance, obs[time]); + var update = FilterUpdate(obs[time], prediction, Parameters); - predictedMean[time] = prediction.Mean; - predictedCovariance[time] = prediction.Covariance; - updatedMean[time] = update.UpdatedMean; - updatedCovariance[time] = update.UpdatedCovariance; + predictedState.Mean[time] = prediction.Mean; + predictedState.Covariance[time] = prediction.Covariance; + updatedState.Mean[time] = update.UpdatedState.Mean; + updatedState.Covariance[time] = update.UpdatedState.Covariance; - if (!update.UpdatedMean.isnan().any().item()) - { - _mean.set_(update.UpdatedMean); - _covariance.set_(update.UpdatedCovariance); - } - else - { - _mean.set_(prediction.Mean); - _covariance.set_(prediction.Covariance); - } + _state = update.UpdatedState; } return new FilteredState( - predictedMean: predictedMean, - predictedCovariance: predictedCovariance, - updatedMean: updatedMean, - updatedCovariance: updatedCovariance); + predictedState: predictedState, + updatedState: updatedState + ); } - private readonly struct FilteredStateWithAuxiliaryVariables( - Tensor predictedMean, - Tensor predictedCovariance, - Tensor updatedMean, - Tensor updatedCovariance, - Tensor innovation, - Tensor innovationCovariance, - Tensor logLikelihood, - Tensor kalmanGain) + public override LinearDynamicalSystemState forward(Tensor input) { - public readonly Tensor PredictedMean = predictedMean; - public readonly Tensor PredictedCovariance = predictedCovariance; - public readonly Tensor UpdatedMean = updatedMean; - public readonly Tensor UpdatedCovariance = updatedCovariance; - public readonly Tensor Innovation = innovation; - public readonly Tensor InnovationCovariance = innovationCovariance; - public readonly Tensor LogLikelihood = logLikelihood; - public readonly Tensor KalmanGain = kalmanGain; + var filteredState = Filter(input); + return filteredState.UpdatedState; } - private static FilteredStateWithAuxiliaryVariables Filter( - Tensor observation, + private static FilteredState Filter( long timeBins, - int numStates, - int numObservations, - Tensor transitionMatrix, - Tensor measurementFunction, - Tensor processNoiseCovariance, - Tensor measurementNoiseCovariance, - Tensor initialMean, - Tensor initialCovariance, - ScalarType scalarType, - Device device) + Tensor observation, + KalmanFilterParameters parameters) { - var logLikelihood = empty(timeBins, dtype: scalarType, device: device); - var predictedMean = empty([timeBins, numStates], dtype: scalarType, device: device); - var predictedCovariance = empty([timeBins, numStates, numStates], dtype: scalarType, device: device); - var updatedMean = empty([timeBins, numStates], dtype: scalarType, device: device); - var updatedCovariance = empty([timeBins, numStates, numStates], dtype: scalarType, device: device); - var innovation = empty([timeBins, numObservations], dtype: scalarType, device: device); - var innovationCovariance = empty([timeBins, numObservations, numObservations], dtype: scalarType, device: device); - var kalmanGain = empty([timeBins, numStates, numObservations], dtype: scalarType, device: device); - - var mean = initialMean; - var covariance = initialCovariance; + using var g = no_grad(); + + var timeBinsObservations = observation.size(0); + timeBins = Math.Min(timeBins, timeBinsObservations); + + var filteredState = new FilteredState( + predictedState: new LinearDynamicalSystemState( + empty([timeBins, parameters.NumStates], dtype: parameters.ScalarType, device: parameters.Device), + empty([timeBins, parameters.NumStates, parameters.NumStates], dtype: parameters.ScalarType, device: parameters.Device)), + updatedState: new LinearDynamicalSystemState( + empty([timeBins, parameters.NumStates], dtype: parameters.ScalarType, device: parameters.Device), + empty([timeBins, parameters.NumStates, parameters.NumStates], dtype: parameters.ScalarType, device: parameters.Device)), + innovation: empty([timeBins, parameters.NumObservations], dtype: parameters.ScalarType, device: parameters.Device), + innovationCovariance: empty([timeBins, parameters.NumObservations, parameters.NumObservations], dtype: parameters.ScalarType, device: parameters.Device), + kalmanGain: empty([timeBins, parameters.NumStates, parameters.NumObservations], dtype: parameters.ScalarType, device: parameters.Device), + logLikelihood: empty([timeBins], dtype: parameters.ScalarType, device: parameters.Device) + ); + + var state = new LinearDynamicalSystemState( + mean: parameters.InitialMean, + covariance: parameters.InitialCovariance); for (long time = 0; time < timeBins; time++) { // Predict var prediction = FilterPredict( - mean: mean, - covariance: covariance, - transitionMatrix: transitionMatrix, - processNoiseCovariance: processNoiseCovariance); + state: state, + parameters: parameters); // Update var update = FilterUpdate( - predictedMean: prediction.Mean, - predictedCovariance: prediction.Covariance, observation: observation[time], - measurementFunction: measurementFunction, - measurementNoiseCovariance: measurementNoiseCovariance); + state: prediction, + parameters: parameters); // Log Likelihood var logLikelihoodData = -(slogdet(update.InnovationCovariance).logabsdet @@ -272,47 +200,29 @@ private static FilteredStateWithAuxiliaryVariables Filter( .matmul(update.Innovation)).squeeze(); // Detach and assign - logLikelihood[time] = logLikelihoodData; - predictedMean[time] = prediction.Mean; - predictedCovariance[time] = prediction.Covariance; - updatedMean[time] = update.UpdatedMean; - updatedCovariance[time] = update.UpdatedCovariance; - innovation[time] = update.Innovation; - innovationCovariance[time] = update.InnovationCovariance; - kalmanGain[time] = update.KalmanGain; - - if (!update.UpdatedMean.isnan().any().item()) - { - mean = update.UpdatedMean; - covariance = update.UpdatedCovariance; - } - else - { - mean = prediction.Mean; - covariance = prediction.Covariance; - } + filteredState.LogLikelihood[time] = logLikelihoodData; + filteredState.PredictedState.Mean[time] = prediction.Mean; + filteredState.PredictedState.Covariance[time] = prediction.Covariance; + filteredState.UpdatedState.Mean[time] = update.UpdatedState.Mean; + filteredState.UpdatedState.Covariance[time] = update.UpdatedState.Covariance; + filteredState.Innovation[time] = update.Innovation; + filteredState.InnovationCovariance[time] = update.InnovationCovariance; + filteredState.KalmanGain[time] = update.KalmanGain; + + state = update.UpdatedState; } - return new FilteredStateWithAuxiliaryVariables( - predictedMean: predictedMean, - predictedCovariance: predictedCovariance, - updatedMean: updatedMean, - updatedCovariance: updatedCovariance, - innovation: innovation, - innovationCovariance: innovationCovariance, - logLikelihood: logLikelihood, - kalmanGain: kalmanGain - ); + return filteredState; } public LinearDynamicalSystemState Smooth(FilteredState filteredState) { using var g = no_grad(); - var predictedMean = filteredState.PredictedMean; - var predictedCovariance = filteredState.PredictedCovariance; - var updatedMean = filteredState.UpdatedMean; - var updatedCovariance = filteredState.UpdatedCovariance; + var predictedMean = filteredState.PredictedState.Mean; + var predictedCovariance = filteredState.PredictedState.Covariance; + var updatedMean = filteredState.UpdatedState.Mean; + var updatedCovariance = filteredState.UpdatedState.Covariance; var timeBins = predictedMean.size(0); var smoothedMean = empty_like(updatedMean); @@ -322,7 +232,7 @@ public LinearDynamicalSystemState Smooth(FilteredState filteredState) smoothedMean[-1] = updatedMean[-1]; smoothedCovariance[-1] = updatedCovariance[-1]; - var smoothingGain = empty([Parameters.NumStates, Parameters.NumStates], dtype: _scalarType, device: _device); + var smoothingGain = empty([Parameters.NumStates, Parameters.NumStates], dtype: Parameters.ScalarType, device: Parameters.Device); // Backward pass for (long time = timeBins - 2; time >= 0; time--) @@ -351,65 +261,91 @@ public LinearDynamicalSystemState Smooth(FilteredState filteredState) ); } + private readonly struct SummaryStatisticsEM( + Tensor sxx11, + Tensor sxx10, + Tensor sxx00, + Tensor tx1 = null, + Tensor tx0 = null, + Tensor ty1 = null, + Tensor tyx11 = null, + Tensor tyy11 = null + ) + { + public readonly Tensor Sxx11 => sxx11; + public readonly Tensor Sxx10 => sxx10; + public readonly Tensor Sxx00 => sxx00; + public readonly Tensor Tx1 => tx1; + public readonly Tensor Tx0 => tx0; + public readonly Tensor Ty1 => ty1; + public readonly Tensor Tyx11 => tyx11; + public readonly Tensor Tyy11 => tyy11; + } + private readonly struct SmoothedStateWithAuxiliaryVariables( - Tensor smoothedMean, - Tensor smoothedCovariance, + LinearDynamicalSystemState smoothedState, Tensor smoothedInitialMean, - Tensor smoothedInitialCovariance, - Tensor S00, - Tensor S10, - Tensor S11) + Tensor smoothedInitialCovariance) { - public readonly Tensor SmoothedMean = smoothedMean; - public readonly Tensor SmoothedCovariance = smoothedCovariance; - public readonly Tensor SmoothedInitialMean = smoothedInitialMean; - public readonly Tensor SmoothedInitialCovariance = smoothedInitialCovariance; - public readonly Tensor S00 = S00; - public readonly Tensor S10 = S10; - public readonly Tensor S11 = S11; + public readonly LinearDynamicalSystemState SmoothedState => smoothedState; + public readonly Tensor SmoothedInitialMean => smoothedInitialMean; + public readonly Tensor SmoothedInitialCovariance => smoothedInitialCovariance; } - private static SmoothedStateWithAuxiliaryVariables Smooth( - FilteredStateWithAuxiliaryVariables filteredState, - long timeBins, - int numStates, - Tensor transitionMatrix, - Tensor measurementFunction, - Tensor initialMean, - Tensor initialCovariance, - Tensor identityStates, - ScalarType scalarType, - Device device + private static (SmoothedStateWithAuxiliaryVariables state, SummaryStatisticsEM statistics) Smooth( + FilteredState filteredState, + Tensor observations, + KalmanFilterParameters parameters ) { + var timeBins = filteredState.PredictedState.Mean.size(0); + if (timeBins < 2) throw new ArgumentException("Smoothing requires at least two time bins."); - var predictedMean = filteredState.PredictedMean; - var predictedCovariance = filteredState.PredictedCovariance; - var updatedMean = filteredState.UpdatedMean; - var updatedCovariance = filteredState.UpdatedCovariance; + var predictedMean = filteredState.PredictedState.Mean; + var predictedCovariance = filteredState.PredictedState.Covariance; + var updatedMean = filteredState.UpdatedState.Mean; + var updatedCovariance = filteredState.UpdatedState.Covariance; var kalmanGain = filteredState.KalmanGain; - var smoothedMean = empty_like(updatedMean); - var smoothedCovariance = empty_like(updatedCovariance); + var smoothedState = new LinearDynamicalSystemState( + mean: empty_like(updatedMean), + covariance: empty_like(updatedCovariance) + ); + + var sxx00 = zeros_like(smoothedState.Covariance, dtype: parameters.ScalarType, device: parameters.Device); + var sxx10 = zeros_like(smoothedState.Covariance, dtype: parameters.ScalarType, device: parameters.Device); + var sxx11 = zeros_like(smoothedState.Covariance, dtype: parameters.ScalarType, device: parameters.Device); + + var tx1 = zeros([parameters.NumStates], dtype: parameters.ScalarType, device: parameters.Device); + var tx0 = zeros([parameters.NumStates], dtype: parameters.ScalarType, device: parameters.Device); + var ty1 = zeros([parameters.NumObservations], dtype: parameters.ScalarType, device: parameters.Device); + var tyx11 = zeros([parameters.NumObservations, parameters.NumStates], dtype: parameters.ScalarType, device: parameters.Device); + var tyy11 = zeros([parameters.NumObservations, parameters.NumObservations], dtype: parameters.ScalarType, device: parameters.Device); - var S00 = zeros_like(smoothedCovariance, dtype: scalarType, device: device); - var S10 = zeros_like(smoothedCovariance, dtype: scalarType, device: device); - var S11 = zeros_like(smoothedCovariance, dtype: scalarType, device: device); + var identityStates = eye(parameters.NumStates, dtype: parameters.ScalarType, device: parameters.Device); // Fix the last time point - smoothedMean[-1] = updatedMean[-1]; - smoothedCovariance[-1] = updatedCovariance[-1]; + smoothedState.Mean[-1] = updatedMean[-1]; + smoothedState.Covariance[-1] = updatedCovariance[-1]; var smoothedLagOneCovariance = WrappedTensorDisposeScope(() => (identityStates - kalmanGain[-1] - .matmul(measurementFunction)) - .matmul(transitionMatrix) + .matmul(parameters.MeasurementFunction)) + .matmul(parameters.TransitionMatrix) .matmul(updatedCovariance[-2])); - S11[-1] = outer(updatedMean[-1], updatedMean[-1]) + updatedCovariance[-1]; + sxx11[-1] = outer(smoothedState.Mean[-1], smoothedState.Mean[-1]) + smoothedState.Covariance[-1]; - var smoothingGain = empty([numStates, numStates], dtype: scalarType, device: device); + if (parameters.OffsetsProvided) + { + tx1 += smoothedState.Mean[-1]; + ty1 += observations[-1]; + tyx11 += outer(observations[-1], smoothedState.Mean[-1]); + tyy11 += outer(observations[-1], observations[-1]); + } + + var smoothingGain = empty([parameters.NumStates, parameters.NumStates], dtype: parameters.ScalarType, device: parameters.Device); var smoothingGainNext = null as Tensor; // Backward pass @@ -417,74 +353,99 @@ Device device { // Smoothing gain smoothingGain = smoothingGainNext ?? WrappedTensorDisposeScope(() => updatedCovariance[time].matmul( - InverseCholesky(transitionMatrix.mT, predictedCovariance[time + 1]) + InverseCholesky(parameters.TransitionMatrix.mT, predictedCovariance[time + 1]) )); // Smoothed mean - smoothedMean[time] = WrappedTensorDisposeScope(() => updatedMean[time] + smoothedState.Mean[time] = WrappedTensorDisposeScope(() => updatedMean[time] + smoothingGain.matmul( - (smoothedMean[time + 1] - predictedMean[time + 1]).unsqueeze(-1) + (smoothedState.Mean[time + 1] - predictedMean[time + 1]).unsqueeze(-1) ).squeeze(-1)); // Smoothed covariance - smoothedCovariance[time] = WrappedTensorDisposeScope(() => updatedCovariance[time] + smoothingGain - .matmul(smoothedCovariance[time + 1] - predictedCovariance[time + 1]) + smoothedState.Covariance[time] = WrappedTensorDisposeScope(() => updatedCovariance[time] + smoothingGain + .matmul(smoothedState.Covariance[time + 1] - predictedCovariance[time + 1]) .matmul(smoothingGain.mT) ); - var expectationUpdate = outer(smoothedMean[time], smoothedMean[time]) + smoothedCovariance[time]; - S11[time] = expectationUpdate; - S00[time + 1] = expectationUpdate; - S10[time + 1] = outer(smoothedMean[time + 1], smoothedMean[time]) + smoothedLagOneCovariance; + var expectationUpdate = outer(smoothedState.Mean[time], smoothedState.Mean[time]) + smoothedState.Covariance[time]; + sxx11[time] = expectationUpdate; + sxx00[time + 1] = expectationUpdate; + sxx10[time + 1] = outer(smoothedState.Mean[time + 1], smoothedState.Mean[time]) + smoothedLagOneCovariance; + + if (parameters.OffsetsProvided) + { + tx1 += smoothedState.Mean[time]; + tx0 += smoothedState.Mean[time + 1]; + ty1 += observations[time]; + tyx11 += outer(observations[time], smoothedState.Mean[time]); + tyy11 += outer(observations[time], observations[time]); + } // Compute next smoothing gain for lag one covariance if (time > 0) { smoothingGainNext = WrappedTensorDisposeScope(() => updatedCovariance[time - 1] - .matmul(InverseCholesky(transitionMatrix.mT, predictedCovariance[time]))); + .matmul(InverseCholesky(parameters.TransitionMatrix.mT, predictedCovariance[time]))); // Smoothed lag one covariance smoothedLagOneCovariance = WrappedTensorDisposeScope(() => updatedCovariance[time] .matmul(smoothingGainNext.mT) + smoothingGain.matmul(smoothedLagOneCovariance - - transitionMatrix.matmul(updatedCovariance[time])) + - parameters.TransitionMatrix.matmul(updatedCovariance[time])) .matmul(smoothingGainNext.mT)); } } - var smoothingGain0 = WrappedTensorDisposeScope(() => initialCovariance.matmul( - InverseCholesky(transitionMatrix.mT, predictedCovariance[0]) + var smoothingGain0 = WrappedTensorDisposeScope(() => parameters.InitialCovariance.matmul( + InverseCholesky(parameters.TransitionMatrix.mT, predictedCovariance[0]) )); // Smoothed initial mean - var smoothedInitialMean = WrappedTensorDisposeScope(() => initialMean + smoothingGain0.matmul( - (smoothedMean[0] - predictedMean[0]).unsqueeze(-1) + var smoothedInitialMean = WrappedTensorDisposeScope(() => parameters.InitialMean + smoothingGain0.matmul( + (smoothedState.Mean[0] - predictedMean[0]).unsqueeze(-1) ).squeeze(-1)); // Smoothed initial covariance - var smoothedInitialCovariance = WrappedTensorDisposeScope(() => initialCovariance + smoothingGain0 - .matmul(smoothedCovariance[0] - predictedCovariance[0]) + var smoothedInitialCovariance = WrappedTensorDisposeScope(() => parameters.InitialCovariance + smoothingGain0 + .matmul(smoothedState.Covariance[0] - predictedCovariance[0]) .matmul(smoothingGain0.mT)); // Smoothed lag one covariance at time 0 smoothedLagOneCovariance = WrappedTensorDisposeScope(() => updatedCovariance[0] .matmul(smoothingGain0.mT) + smoothingGain.matmul(smoothedLagOneCovariance - - transitionMatrix.matmul(updatedCovariance[0])) + - parameters.TransitionMatrix.matmul(updatedCovariance[0])) .matmul(smoothingGain0.mT)); - S10[0] = outer(smoothedMean[0], smoothedInitialMean) + smoothedLagOneCovariance; - S00[0] = outer(smoothedInitialMean, smoothedInitialMean) + smoothedInitialCovariance; + sxx10[0] = outer(smoothedState.Mean[0], smoothedInitialMean) + smoothedLagOneCovariance; + sxx00[0] = outer(smoothedInitialMean, smoothedInitialMean) + smoothedInitialCovariance; + + if (parameters.OffsetsProvided) + tx0 += smoothedInitialMean; - return new SmoothedStateWithAuxiliaryVariables( - smoothedMean: smoothedMean, - smoothedCovariance: smoothedCovariance, + var state = new SmoothedStateWithAuxiliaryVariables( + smoothedState: smoothedState, smoothedInitialMean: smoothedInitialMean, - smoothedInitialCovariance: smoothedInitialCovariance, - S00: S00, - S10: S10, - S11: S11 + smoothedInitialCovariance: smoothedInitialCovariance ); + + var stats = parameters.OffsetsProvided ? new SummaryStatisticsEM( + sxx11: sxx11, + sxx10: sxx10, + sxx00: sxx00, + tx1: tx1, + tx0: tx0, + ty1: ty1, + tyx11: tyx11, + tyy11: tyy11 + ) : new SummaryStatisticsEM( + sxx11: sxx11, + sxx10: sxx10, + sxx00: sxx00 + ); + + return (state, stats); } public static ExpectationMaximizationResult ExpectationMaximization( @@ -494,9 +455,6 @@ public static ExpectationMaximizationResult ExpectationMaximization( double tolerance = 1e-4, ParametersToEstimate parametersToEstimate = new()) { - parameters = parameters.Copy(); - parameters.Validate(); - var timeBins = observation.size(0); var numObservations = (int)observation.size(1); var logLikelihood = empty(maxIterations, dtype: ScalarType.Float32, device: parameters.Device); @@ -513,23 +471,14 @@ public static ExpectationMaximizationResult ExpectationMaximization( var autoCorrelationObservations = observationT.matmul(observation); using var g = no_grad(); - + for (int iteration = 0; iteration < maxIterations; iteration++) { // Filter observations var filteredState = Filter( - observation: observation, timeBins: timeBins, - numStates: parameters.NumStates, - numObservations: numObservations, - transitionMatrix: parameters.TransitionMatrix, - measurementFunction: parameters.MeasurementFunction, - processNoiseCovariance: parameters.ProcessNoiseCovariance, - measurementNoiseCovariance: parameters.MeasurementNoiseCovariance, - initialMean: parameters.InitialMean, - initialCovariance: parameters.InitialCovariance, - scalarType: parameters.ScalarType, - device: parameters.Device); + observation: observation, + parameters: parameters); // Compute log likelihood (avoid creating intermediate tensors) var llSumDouble = filteredState.LogLikelihood.sum() @@ -551,49 +500,127 @@ public static ExpectationMaximizationResult ExpectationMaximization( previousLogLikelihood = filteredLogLikelihoodSum; // Smooth the filtered results - var smoothedState = Smooth( + var (state, statistics) = Smooth( filteredState: filteredState, - timeBins: timeBins, - numStates: parameters.NumStates, - transitionMatrix: parameters.TransitionMatrix, - measurementFunction: parameters.MeasurementFunction, - initialMean: parameters.InitialMean, - initialCovariance: parameters.InitialCovariance, - identityStates: identityStates, - scalarType: parameters.ScalarType, - device: parameters.Device); + observations: observation, + parameters: parameters); // Sufficient statistics - var S00 = smoothedState.S00.sum([0]); - var S11 = smoothedState.S11.sum([0]); - var S10 = smoothedState.S10.sum([0]); + var sxx00 = statistics.Sxx00.sum([0]); + var sxx11 = statistics.Sxx11.sum([0]); + var sxx10 = statistics.Sxx10.sum([0]); // Replace einsum with faster matmul - var crossCorrelationObservations = observationT.matmul(smoothedState.SmoothedMean); + var crossCorrelationObservations = observationT.matmul(state.SmoothedState.Mean); // Update parameters if (parametersToEstimate.TransitionMatrix) - parameters.TransitionMatrix = InverseCholesky(S10, S00); + { + if (parameters.OffsetsProvided) + { + parameters.TransitionMatrix.set_( + InverseCholesky( + sxx10 - outer(statistics.Tx1, statistics.Tx0) / timeBins, + sxx00 - outer(statistics.Tx0, statistics.Tx0) / timeBins + ) + ); + } + else + { + parameters.TransitionMatrix.set_( + InverseCholesky( + sxx10, + sxx00 + ) + ); + } + } + + if (parametersToEstimate.StateOffset) + { + if (parameters.OffsetsProvided) + { + parameters.StateOffset.set_( + (statistics.Tx1 - parameters.TransitionMatrix.matmul(statistics.Tx0)) / timeBins + ); + } + } if (parametersToEstimate.MeasurementFunction) - parameters.MeasurementFunction = InverseCholesky(crossCorrelationObservations, S11); + { + if (parameters.OffsetsProvided) + { + parameters.MeasurementFunction.set_( + InverseCholesky( + statistics.Tyx11 - outer(statistics.Ty1, statistics.Tx1) / timeBins, + sxx11 - outer(statistics.Tx0, statistics.Tx0) / timeBins + ) + ); + } + else + { + parameters.MeasurementFunction.set_( + InverseCholesky( + crossCorrelationObservations, + sxx11 + ) + ); + } + } - if (parametersToEstimate.ProcessNoiseCovariance) - parameters.ProcessNoiseCovariance = WrappedTensorDisposeScope(() => - EnsureSymmetric((S11 - parameters.TransitionMatrix.matmul(S10.mT)) / timeBins)); + if (parametersToEstimate.ObservationOffset) + { + if (parameters.OffsetsProvided) + { + parameters.ObservationOffset.set_( + (statistics.Ty1 - parameters.MeasurementFunction.matmul(statistics.Tx1)) / timeBins + ); + } + } - var explainedObservationCovariance = parameters.MeasurementFunction.matmul(crossCorrelationObservations.mT); + if (parametersToEstimate.ProcessNoiseCovariance) + { + if (parameters.OffsetsProvided) + { + parameters.ProcessNoiseCovariance.set_(WrappedTensorDisposeScope(() => + EnsureSymmetric( + (sxx11 - outer(statistics.Tx1, parameters.StateOffset) - outer(parameters.StateOffset, statistics.Tx1) + timeBins * outer(parameters.StateOffset, parameters.StateOffset) - parameters.TransitionMatrix.matmul(sxx10.mT) - sxx10.matmul(parameters.TransitionMatrix.mT) + linalg.multi_dot([parameters.TransitionMatrix, sxx00, parameters.TransitionMatrix.mT]) + parameters.TransitionMatrix.matmul(outer(statistics.Tx0, parameters.StateOffset)) + outer(parameters.StateOffset, statistics.Tx0).matmul(parameters.TransitionMatrix.mT)) / timeBins + ) + )); + } + else + { + parameters.ProcessNoiseCovariance.set_(WrappedTensorDisposeScope(() => + EnsureSymmetric((sxx11 - parameters.TransitionMatrix.matmul(sxx10.mT)) / timeBins))); + } + } if (parametersToEstimate.MeasurementNoiseCovariance) - parameters.MeasurementNoiseCovariance = WrappedTensorDisposeScope(() => - EnsureSymmetric((autoCorrelationObservations - explainedObservationCovariance - explainedObservationCovariance.mT - + parameters.MeasurementFunction.matmul(S11).matmul(parameters.MeasurementFunction.mT)) / timeBins)); + { + if (parameters.OffsetsProvided) + { + parameters.MeasurementNoiseCovariance.set_(WrappedTensorDisposeScope(() => + EnsureSymmetric( + (statistics.Tyy11 - outer(statistics.Ty1, parameters.ObservationOffset) - outer(parameters.ObservationOffset, statistics.Ty1) + timeBins * outer(parameters.ObservationOffset, parameters.ObservationOffset) - parameters.MeasurementFunction.matmul(statistics.Tyx11.mT) - statistics.Tyx11.matmul(parameters.MeasurementFunction.mT) + parameters.MeasurementFunction.matmul(outer(statistics.Tx1, parameters.ObservationOffset)) + linalg.multi_dot([parameters.MeasurementFunction, sxx11, parameters.MeasurementFunction.mT]) + outer(parameters.ObservationOffset, statistics.Tx1).matmul(parameters.MeasurementFunction.mT)) / timeBins + ) + )); + } + else + { + var explainedObservationCovariance = parameters.MeasurementFunction.matmul(crossCorrelationObservations.mT); + + if (parametersToEstimate.MeasurementNoiseCovariance) + parameters.MeasurementNoiseCovariance.set_(WrappedTensorDisposeScope(() => + EnsureSymmetric((autoCorrelationObservations - explainedObservationCovariance - explainedObservationCovariance.mT + + parameters.MeasurementFunction.matmul(sxx11).matmul(parameters.MeasurementFunction.mT)) / timeBins))); + } + } if (parametersToEstimate.InitialMean) - parameters.InitialMean = smoothedState.SmoothedInitialMean; + parameters.InitialMean.set_(state.SmoothedInitialMean); if (parametersToEstimate.InitialCovariance) - parameters.InitialCovariance = smoothedState.SmoothedInitialCovariance; + parameters.InitialCovariance.set_(state.SmoothedInitialCovariance); } return new ExpectationMaximizationResult(logLikelihood, parameters); @@ -691,19 +718,19 @@ public static StochasticSubspaceIdentificationResult StochasticSubspaceIdentific ); } - public LinearDynamicalSystemState OrthogonalizeMeanAndCovariance(Tensor mean, Tensor covariance) + public LinearDynamicalSystemState OrthogonalizeMeanAndCovariance(LinearDynamicalSystemState state) { var (_, S, Vt) = linalg.svd(Parameters.MeasurementFunction); var SVt = diag(S).matmul(Vt); Tensor orthogonalizedMean = null; - if (mean is not null) - orthogonalizedMean = matmul(mean, SVt.mT); + if (state.Mean is not null) + orthogonalizedMean = matmul(state.Mean, SVt.mT); Tensor orthogonalizedCovariance = null; - if (covariance is not null) + if (state.Covariance is not null) { - var auxilary = matmul(SVt, covariance); + var auxilary = matmul(SVt, state.Covariance); orthogonalizedCovariance = matmul(auxilary, SVt.mT); } @@ -727,6 +754,10 @@ public void UpdateParameters(KalmanFilterParameters updatedParameters) Parameters.InitialMean.set_(updatedParameters.InitialMean); if (updatedParameters.InitialCovariance is not null) Parameters.InitialCovariance.set_(updatedParameters.InitialCovariance); + if (updatedParameters.StateOffset is not null) + Parameters.StateOffset.set_(updatedParameters.StateOffset); + if (updatedParameters.ObservationOffset is not null) + Parameters.ObservationOffset.set_(updatedParameters.ObservationOffset); } private static Tensor EnsureSymmetric(Tensor M) => 0.5 * (M + M.mT); diff --git a/src/Bonsai.ML.Lds.Torch/KalmanFilterParameters.cs b/src/Bonsai.ML.Lds.Torch/KalmanFilterParameters.cs index e6b02ee5..521d9336 100644 --- a/src/Bonsai.ML.Lds.Torch/KalmanFilterParameters.cs +++ b/src/Bonsai.ML.Lds.Torch/KalmanFilterParameters.cs @@ -1,4 +1,6 @@ using System; +using System.Text; +using TorchSharp; using static TorchSharp.torch; namespace Bonsai.ML.Lds.Torch; @@ -6,100 +8,79 @@ namespace Bonsai.ML.Lds.Torch; /// /// Represents the parameters of a Kalman filter model. /// -/// -/// Initializes a new instance of the struct with the specified parameters. -/// -/// -/// -/// -/// -/// -/// -/// -/// -/// -/// -/// -/// -public struct KalmanFilterParameters( - int numStates, - int numObservations, - Tensor transitionMatrix = null, - Tensor measurementFunction = null, - Tensor processNoiseCovariance = null, - Tensor measurementNoiseCovariance = null, - Tensor initialMean = null, - Tensor initialCovariance = null, - Device device = null, - ScalarType? scalarType = null, - bool requiresGrad = false, - bool validated = false) +public class KalmanFilterParameters : nn.Module { + private readonly static StringBuilder _sb = new(); + private readonly ScalarType _scalarType; + private readonly Device _device; + /// /// The number of states in the system. /// - public int NumStates = numStates; + public int NumStates { get; private set; } /// /// The number of observations in the system. /// - public int NumObservations = numObservations; + public int NumObservations { get; private set; } /// /// The state transition matrix. /// - public Tensor TransitionMatrix = transitionMatrix; + public Tensor TransitionMatrix { get; private set; } /// /// The measurement function. /// - public Tensor MeasurementFunction = measurementFunction; + public Tensor MeasurementFunction { get; private set; } /// /// The process noise covariance. /// - public Tensor ProcessNoiseCovariance = processNoiseCovariance; + public Tensor ProcessNoiseCovariance { get; private set; } /// /// The measurement noise covariance. /// - public Tensor MeasurementNoiseCovariance = measurementNoiseCovariance; + public Tensor MeasurementNoiseCovariance { get; private set; } /// /// The initial mean. /// - public Tensor InitialMean = initialMean; + public Tensor InitialMean { get; private set; } /// /// The initial covariance. /// - public Tensor InitialCovariance = initialCovariance; + public Tensor InitialCovariance { get; private set; } /// - /// The device to use for tensor operations. + /// The optional state offset. /// - public Device Device = device ?? CPU; + public Tensor StateOffset { get; private set; } /// - /// The data type of the tensors. + /// The optional observation offset. /// - public ScalarType ScalarType = scalarType ?? ScalarType.Float32; + public Tensor ObservationOffset { get; private set; } /// - /// Indicates whether the tensors require gradient computation. + /// Indicates whether any offsets have been provided. /// - public bool RequiresGrad = requiresGrad; + public bool OffsetsProvided => StateOffset is not null || ObservationOffset is not null; /// - /// Indicates whether the parameters have been validated. + /// The data type of the tensors. /// - /// - /// This field is used to avoid redundant validation checks. - /// - public bool Validated = validated; + public ScalarType ScalarType => _scalarType; /// - /// Initializes the Kalman filter parameters. + /// The device on which the tensors are allocated. + /// + public Device Device => _device; + + /// + /// Initializes a new instance of the class with the specified parameters. /// /// /// @@ -109,12 +90,12 @@ public struct KalmanFilterParameters( /// /// /// + /// + /// /// /// /// - /// - /// - public static KalmanFilterParameters Initialize( + public KalmanFilterParameters( int? numStates = null, int? numObservations = null, Tensor transitionMatrix = null, @@ -123,91 +104,94 @@ public static KalmanFilterParameters Initialize( Tensor measurementNoiseCovariance = null, Tensor initialMean = null, Tensor initialCovariance = null, + Tensor stateOffset = null, + Tensor observationOffset = null, Device device = null, ScalarType? scalarType = null, - bool requiresGrad = false - ) + bool requiresGrad = false) : base("KalmanFilterParameters") { - var trueNumStates = numStates ?? -1; - var trueNumObservations = numObservations ?? -1; + numStates ??= InferNumStates(transitionMatrix, measurementFunction, initialMean, initialCovariance, processNoiseCovariance, stateOffset); + numObservations ??= InferNumObservations(measurementFunction, measurementNoiseCovariance, observationOffset); if (numStates is null) - { - ValidateNumStates(transitionMatrix, measurementFunction, initialMean, initialCovariance, processNoiseCovariance, out trueNumStates); - } - - if (trueNumStates <= 0) - throw new ArgumentOutOfRangeException(nameof(trueNumStates), "Number of states must be greater than zero."); - + throw new ArgumentOutOfRangeException(nameof(numStates), "Number of states must be specified or inferred from the parameters."); if (numObservations is null) - { - ValidateNumObservations(measurementFunction, measurementNoiseCovariance, out trueNumObservations); - } - - if (trueNumObservations <= 0) - throw new ArgumentOutOfRangeException(nameof(numObservations), "Number of observations must be greater than zero."); + throw new ArgumentOutOfRangeException(nameof(numObservations), "Number of observations must be specified or inferred from the parameters."); - transitionMatrix = transitionMatrix?.clone() ?? eye(trueNumStates); - measurementFunction = measurementFunction?.clone() ?? eye(trueNumObservations, trueNumStates); - initialMean = initialMean?.clone() ?? zeros(trueNumStates); - initialCovariance = initialCovariance?.clone() ?? eye(trueNumStates); + transitionMatrix = transitionMatrix?.clone() ?? eye(numStates.Value); + measurementFunction = measurementFunction?.clone() ?? eye(numObservations.Value, numStates.Value); + initialMean = initialMean?.clone() ?? zeros(numStates.Value); + initialCovariance = initialCovariance?.clone() ?? eye(numStates.Value); processNoiseCovariance = processNoiseCovariance?.NumberOfElements == 1 - ? CreateCovarianceMatrixFromScalar(processNoiseCovariance, trueNumStates, "Process noise variance") + ? CreateCovarianceMatrixFromScalar(processNoiseCovariance, numStates.Value, "Process noise variance") : processNoiseCovariance?.clone() - ?? CreateCovarianceMatrixFromScalar(1.0, trueNumStates, "Process noise variance"); + ?? CreateCovarianceMatrixFromScalar(1.0, numStates.Value, "Process noise variance"); measurementNoiseCovariance = measurementNoiseCovariance?.NumberOfElements == 1 - ? CreateCovarianceMatrixFromScalar(measurementNoiseCovariance, trueNumObservations, "Measurement noise variance") + ? CreateCovarianceMatrixFromScalar(measurementNoiseCovariance, numObservations.Value, "Measurement noise variance") : measurementNoiseCovariance?.clone() - ?? CreateCovarianceMatrixFromScalar(1.0, trueNumObservations, "Measurement noise variance"); - - var parameters = new KalmanFilterParameters( - trueNumStates, - trueNumObservations, - transitionMatrix, - measurementFunction, - processNoiseCovariance, - measurementNoiseCovariance, - initialMean, - initialCovariance, - device, - scalarType, - requiresGrad - ); - - parameters.Validate(); - parameters.ToScalarType(scalarType); - parameters.ToDevice(device); - parameters.SetGrad(requiresGrad); - - return parameters; + ?? CreateCovarianceMatrixFromScalar(1.0, numObservations.Value, "Measurement noise variance"); + + NumStates = numStates.Value; + NumObservations = numObservations.Value; + TransitionMatrix = transitionMatrix; + MeasurementFunction = measurementFunction; + ProcessNoiseCovariance = processNoiseCovariance; + MeasurementNoiseCovariance = measurementNoiseCovariance; + InitialMean = initialMean; + InitialCovariance = initialCovariance; + StateOffset = stateOffset; + ObservationOffset = observationOffset; + + Validate(); + + if (device is not null) + this.to(device); + if (scalarType is not null) + this.to(scalarType.Value); + + _device = TransitionMatrix.device; + _scalarType = TransitionMatrix.dtype; + + SetGrad(requiresGrad); } /// /// Validates the Kalman filter parameters. /// - public void Validate() + private void Validate() { - if (Validated) - return; - - ValidateNumStates(TransitionMatrix, MeasurementFunction, InitialMean, InitialCovariance, ProcessNoiseCovariance, out NumStates); - ValidateNumObservations(MeasurementFunction, MeasurementNoiseCovariance, out NumObservations); - ValidateMatrix(TransitionMatrix, "Transition matrix", isSquare: true, expectedDimension1: NumStates); - ValidateMatrix(MeasurementFunction, "Measurement function", expectedDimension1: NumObservations, expectedDimension2: NumStates); - ValidateMatrix(ProcessNoiseCovariance, "Process noise covariance", isSquare: true, expectedDimension1: NumStates); - ValidateMatrix(MeasurementNoiseCovariance, "Measurement noise covariance", isSquare: true, expectedDimension1: NumObservations); - ValidateVector(InitialMean, "Initial mean", NumStates); - ValidateMatrix(InitialCovariance, "Initial covariance", isSquare: true, expectedDimension1: NumStates); - Validated = true; + int numStates = InferNumStates(TransitionMatrix, MeasurementFunction, InitialMean, InitialCovariance, ProcessNoiseCovariance, StateOffset); + + int numObservations = InferNumObservations(MeasurementFunction, MeasurementNoiseCovariance, ObservationOffset); + + ValidateMatrix(TransitionMatrix, "Transition matrix", isSquare: true, expectedDimension1: numStates); + + ValidateMatrix(MeasurementFunction, "Measurement function", expectedDimension1: numObservations, expectedDimension2: numStates); + + ValidateMatrix(ProcessNoiseCovariance, "Process noise covariance", isSquare: true, expectedDimension1: numStates); + + ValidateMatrix(MeasurementNoiseCovariance, "Measurement noise covariance", isSquare: true, expectedDimension1: numObservations); + + ValidateVector(InitialMean, "Initial mean", numStates); + ValidateMatrix(InitialCovariance, "Initial covariance", isSquare: true, expectedDimension1: numStates); + + if (StateOffset is not null) + ValidateVector(StateOffset, "State offset", numStates); + + if (ObservationOffset is not null) + ValidateVector(ObservationOffset, "Observation offset", numObservations); + + NumStates = numStates; + NumObservations = numObservations; } /// /// Creates a copy of the current Kalman filter parameters. /// /// - public readonly KalmanFilterParameters Copy() => new( + public KalmanFilterParameters Copy() => new( NumStates, NumObservations, TransitionMatrix?.clone(), @@ -216,90 +200,57 @@ public void Validate() MeasurementNoiseCovariance?.clone(), InitialMean?.clone(), InitialCovariance?.clone(), - Device, - ScalarType, - RequiresGrad, - Validated + StateOffset?.clone(), + ObservationOffset?.clone() ); - /// - /// Converts the tensors in the Kalman filter parameters to the specified scalar type. - /// - /// - public void ToScalarType(ScalarType? scalarType) - { - if (scalarType is not null) - { - TransitionMatrix = TransitionMatrix?.to_type(scalarType.Value); - MeasurementFunction = MeasurementFunction?.to_type(scalarType.Value); - ProcessNoiseCovariance = ProcessNoiseCovariance?.to_type(scalarType.Value); - MeasurementNoiseCovariance = MeasurementNoiseCovariance?.to_type(scalarType.Value); - InitialMean = InitialMean?.to_type(scalarType.Value); - InitialCovariance = InitialCovariance?.to_type(scalarType.Value); - } - } - - /// - /// Moves the tensors in the Kalman filter parameters to the specified device. - /// - /// - public void ToDevice(Device? device) - { - if (device is not null) - { - TransitionMatrix = TransitionMatrix?.to(device); - MeasurementFunction = MeasurementFunction?.to(device); - ProcessNoiseCovariance = ProcessNoiseCovariance?.to(device); - MeasurementNoiseCovariance = MeasurementNoiseCovariance?.to(device); - InitialMean = InitialMean?.to(device); - InitialCovariance = InitialCovariance?.to(device); - } - } - /// /// Sets the requires_grad flag for all tensors in the Kalman filter parameters. /// /// public void SetGrad(bool requiresGrad) { - if (RequiresGrad == requiresGrad) - return; - TransitionMatrix = TransitionMatrix?.requires_grad_(requiresGrad); MeasurementFunction = MeasurementFunction?.requires_grad_(requiresGrad); ProcessNoiseCovariance = ProcessNoiseCovariance?.requires_grad_(requiresGrad); MeasurementNoiseCovariance = MeasurementNoiseCovariance?.requires_grad_(requiresGrad); InitialMean = InitialMean?.requires_grad_(requiresGrad); InitialCovariance = InitialCovariance?.requires_grad_(requiresGrad); - RequiresGrad = requiresGrad; + StateOffset = StateOffset?.requires_grad_(requiresGrad); + ObservationOffset = ObservationOffset?.requires_grad_(requiresGrad); } - private static void ValidateNumStates(Tensor transitionMatrix, Tensor measurementFunction, Tensor initialMean, Tensor initialCovariance, Tensor processNoiseCovariance, out int numStates) + private static int InferNumStates(Tensor transitionMatrix, Tensor measurementFunction, Tensor initialMean, Tensor initialCovariance, Tensor processNoiseCovariance, Tensor stateOffset) { if (transitionMatrix is not null) { ValidateMatrix(transitionMatrix, "Transition matrix", isSquare: true); - numStates = (int)transitionMatrix.size(0); + return (int)transitionMatrix.size(0); } else if (measurementFunction is not null) { ValidateMatrix(measurementFunction, "Measurement function"); - numStates = (int)measurementFunction.size(1); + return (int)measurementFunction.size(1); } else if (initialMean is not null) { ValidateVector(initialMean, "Initial mean"); - numStates = (int)initialMean.size(0); + return (int)initialMean.size(0); } else if (initialCovariance is not null) { ValidateMatrix(initialCovariance, "Initial covariance", isSquare: true); - numStates = (int)initialCovariance.size(0); + return (int)initialCovariance.size(0); } else if (processNoiseCovariance is not null) { ValidateMatrix(processNoiseCovariance, "Process noise covariance", isSquare: true); - numStates = (int)processNoiseCovariance.size(0); + return (int)processNoiseCovariance.size(0); + } + else if (stateOffset is not null) + { + ValidateVector(stateOffset, "State offset"); + return (int)stateOffset.size(0); } else { @@ -307,17 +258,22 @@ private static void ValidateNumStates(Tensor transitionMatrix, Tensor measuremen } } - private static void ValidateNumObservations(Tensor measurementFunction, Tensor measurementNoiseCovariance, out int numObservations) + private static int InferNumObservations(Tensor measurementFunction, Tensor measurementNoiseCovariance, Tensor observationOffset) { if (measurementFunction is not null) { ValidateMatrix(measurementFunction, "Measurement function"); - numObservations = (int)measurementFunction.size(0); + return (int)measurementFunction.size(0); } else if (measurementNoiseCovariance is not null) { ValidateMatrix(measurementNoiseCovariance, "Measurement noise covariance", isSquare: true); - numObservations = (int)measurementNoiseCovariance.size(0); + return (int)measurementNoiseCovariance.size(0); + } + else if (observationOffset is not null) + { + ValidateVector(observationOffset, "Observation offset"); + return (int)observationOffset.size(0); } else { @@ -369,14 +325,14 @@ private static Tensor CreateCovarianceMatrixFromScalar(Tensor variance, int dime } /// - public override readonly string ToString() => - $"KalmanFilterParameters(NumStates={NumStates}, NumObservations={NumObservations}, TransitionMatrix={TransitionMatrix}, MeasurementFunction={MeasurementFunction}, ProcessNoiseCovariance={ProcessNoiseCovariance}, MeasurementNoiseCovariance={MeasurementNoiseCovariance}, InitialMean={InitialMean}, InitialCovariance={InitialCovariance})"; + public override string ToString() => _sb.Length == 0 ? _sb.Append( + $"KalmanFilterParameters(NumStates={NumStates}, NumObservations={NumObservations}, TransitionMatrix={TransitionMatrix}, MeasurementFunction={MeasurementFunction}, ProcessNoiseCovariance={ProcessNoiseCovariance}, MeasurementNoiseCovariance={MeasurementNoiseCovariance}, InitialMean={InitialMean}, InitialCovariance={InitialCovariance})" + (StateOffset is not null ? $", StateOffset={StateOffset}" : "") + (ObservationOffset is not null ? $", ObservationOffset={ObservationOffset}" : "")).ToString() : _sb.ToString(); /// /// Returns a string representation of the Kalman filter parameters with the specified tensor string style. /// /// /// - public readonly string ToString(TorchSharp.TensorStringStyle tensorStringStyle) => - $"KalmanFilterParameters(NumStates={NumStates}, NumObservations={NumObservations}, TransitionMatrix={TransitionMatrix.ToString(tensorStringStyle)}, MeasurementFunction={MeasurementFunction.ToString(tensorStringStyle)}, ProcessNoiseCovariance={ProcessNoiseCovariance.ToString(tensorStringStyle)}, MeasurementNoiseCovariance={MeasurementNoiseCovariance.ToString(tensorStringStyle)}, InitialMean={InitialMean.ToString(tensorStringStyle)}, InitialCovariance={InitialCovariance.ToString(tensorStringStyle)})"; -} \ No newline at end of file + public string ToString(TensorStringStyle tensorStringStyle) => _sb.Length == 0 ? _sb.Append( + $"KalmanFilterParameters(NumStates={NumStates}, NumObservations={NumObservations}, TransitionMatrix={TransitionMatrix.ToString(tensorStringStyle)}, MeasurementFunction={MeasurementFunction.ToString(tensorStringStyle)}, ProcessNoiseCovariance={ProcessNoiseCovariance.ToString(tensorStringStyle)}, MeasurementNoiseCovariance={MeasurementNoiseCovariance.ToString(tensorStringStyle)}, InitialMean={InitialMean.ToString(tensorStringStyle)}, InitialCovariance={InitialCovariance.ToString(tensorStringStyle)})" + (StateOffset is not null ? $", StateOffset={StateOffset.ToString(tensorStringStyle)}" : "") + (ObservationOffset is not null ? $", ObservationOffset={ObservationOffset.ToString(tensorStringStyle)}" : "")).ToString() : _sb.ToString(); +} diff --git a/src/Bonsai.ML.Lds.Torch/LinearDynamicalSystemState.cs b/src/Bonsai.ML.Lds.Torch/LinearDynamicalSystemState.cs index f89c08ef..60323c6c 100644 --- a/src/Bonsai.ML.Lds.Torch/LinearDynamicalSystemState.cs +++ b/src/Bonsai.ML.Lds.Torch/LinearDynamicalSystemState.cs @@ -7,11 +7,11 @@ namespace Bonsai.ML.Lds.Torch; /// /// /// -public class LinearDynamicalSystemState(Tensor mean, Tensor covariance) : ILinearDynamicalSystemState +public readonly struct LinearDynamicalSystemState(Tensor mean, Tensor covariance) : ILinearDynamicalSystemState { /// - public Tensor Mean => mean; - + public readonly Tensor Mean => mean; + /// - public Tensor Covariance => covariance; -} \ No newline at end of file + public readonly Tensor Covariance => covariance; +} diff --git a/src/Bonsai.ML.Lds.Torch/LoadKalmanFilterParameters.cs b/src/Bonsai.ML.Lds.Torch/LoadKalmanFilterParameters.cs index 9d1ea40d..bb6435b6 100644 --- a/src/Bonsai.ML.Lds.Torch/LoadKalmanFilterParameters.cs +++ b/src/Bonsai.ML.Lds.Torch/LoadKalmanFilterParameters.cs @@ -1,11 +1,8 @@ using System; using System.ComponentModel; -using System.Linq; using System.Reactive.Linq; using System.Xml.Serialization; -using Bonsai.ML.Torch; using System.IO; -using TorchSharp; using static TorchSharp.torch; namespace Bonsai.ML.Lds.Torch; @@ -44,11 +41,12 @@ private static Tensor LoadTensorFromFile(string basePath, string filePath) filePath = System.IO.Path.Combine(basePath, filePath); - if (!File.Exists(filePath)) + if (File.Exists(filePath)) { - throw new FileNotFoundException($"The specified file was not found: {filePath}"); + return Tensor.Load(filePath); } - return Tensor.Load(filePath); + + return null; } /// @@ -72,17 +70,21 @@ public IObservable Process() var measurementNoiseCovariance = LoadTensorFromFile(Path, "MeasurementNoiseCovariance.bin"); var initialMean = LoadTensorFromFile(Path, "InitialMean.bin"); var initialCovariance = LoadTensorFromFile(Path, "InitialCovariance.bin"); + var stateOffset = LoadTensorFromFile(Path, "StateOffset.bin"); + var observationOffset = LoadTensorFromFile(Path, "ObservationOffset.bin"); - var parameters = KalmanFilterParameters.Initialize( + var parameters = new KalmanFilterParameters( transitionMatrix: transitionMatrix, measurementFunction: measurementFunction, processNoiseCovariance: processNoiseCovariance, measurementNoiseCovariance: measurementNoiseCovariance, initialMean: initialMean, initialCovariance: initialCovariance, + stateOffset: stateOffset, + observationOffset: observationOffset, device: Device, scalarType: Type); return Observable.Return(parameters); } -} \ No newline at end of file +} diff --git a/src/Bonsai.ML.Lds.Torch/Orthogonalize.cs b/src/Bonsai.ML.Lds.Torch/Orthogonalize.cs index 3afa276e..061832ff 100644 --- a/src/Bonsai.ML.Lds.Torch/Orthogonalize.cs +++ b/src/Bonsai.ML.Lds.Torch/Orthogonalize.cs @@ -27,12 +27,9 @@ public class Orthogonalize /// /// /// - public IObservable Process(IObservable source) + public IObservable Process(IObservable source) { - return source.Select(input => - { - return Model.OrthogonalizeMeanAndCovariance(input.Mean, input.Covariance); - }); + return source.Select(Model.OrthogonalizeMeanAndCovariance); } /// @@ -44,20 +41,7 @@ public IObservable Process(IObservable { - return Model.OrthogonalizeMeanAndCovariance(input.Mean, input.Covariance); - }); - } - - /// - /// Processes an observable sequence of filtered results, orthogonalizing the mean and covariance estimates. - /// - /// - /// - public IObservable Process(IObservable source) - { - return source.Select(input => - { - return Model.OrthogonalizeMeanAndCovariance(input.Mean, input.Covariance); + return Model.OrthogonalizeMeanAndCovariance(input.UpdatedState); }); } @@ -70,9 +54,8 @@ public IObservable Process(IObservable { - var mean = input.Item1; - var covariance = input.Item2; - return Model.OrthogonalizeMeanAndCovariance(mean, covariance); + var state = new LinearDynamicalSystemState(input.Item1, input.Item2); + return Model.OrthogonalizeMeanAndCovariance(state); }); } -} \ No newline at end of file +} diff --git a/src/Bonsai.ML.Lds.Torch/ParametersToEstimate.cs b/src/Bonsai.ML.Lds.Torch/ParametersToEstimate.cs index 039c4018..2127ef6e 100644 --- a/src/Bonsai.ML.Lds.Torch/ParametersToEstimate.cs +++ b/src/Bonsai.ML.Lds.Torch/ParametersToEstimate.cs @@ -12,13 +12,17 @@ namespace Bonsai.ML.Lds.Torch; /// /// /// +/// +/// public struct ParametersToEstimate( bool transitionMatrix = true, bool measurementFunction = true, bool processNoiseCovariance = true, bool measurementNoiseCovariance = true, bool initialMean = true, - bool initialCovariance = true) + bool initialCovariance = true, + bool stateOffset = false, + bool observationOffset = false) { /// /// The state transition matrix. @@ -49,4 +53,14 @@ public struct ParametersToEstimate( /// The initial covariance. /// public bool InitialCovariance = initialCovariance; -} \ No newline at end of file + + /// + /// The state offset. + /// + public bool StateOffset = stateOffset; + + /// + /// The observation offset. + /// + public bool ObservationOffset = observationOffset; +} diff --git a/src/Bonsai.ML.Lds.Torch/SaveKalmanFilterParameters.cs b/src/Bonsai.ML.Lds.Torch/SaveKalmanFilterParameters.cs index e282afa1..83dff315 100644 --- a/src/Bonsai.ML.Lds.Torch/SaveKalmanFilterParameters.cs +++ b/src/Bonsai.ML.Lds.Torch/SaveKalmanFilterParameters.cs @@ -1,12 +1,8 @@ using System; using System.ComponentModel; -using System.Linq; using System.Reactive.Linq; -using System.Xml.Serialization; using System.IO; -using Bonsai.ML.Torch; using TorchSharp; -using static TorchSharp.torch; namespace Bonsai.ML.Lds.Torch; @@ -60,6 +56,8 @@ private void SaveKalmanFilterParametersToDisk(KalmanFilterParameters parameters) var measurementNoiseCovariancePath = System.IO.Path.Combine(path, "MeasurementNoiseCovariance.bin"); var initialMeanPath = System.IO.Path.Combine(path, "InitialMean.bin"); var initialCovariancePath = System.IO.Path.Combine(path, "InitialCovariance.bin"); + var stateOffsetPath = System.IO.Path.Combine(path, "StateOffset.bin"); + var observationOffsetPath = System.IO.Path.Combine(path, "ObservationOffset.bin"); if (Directory.Exists(path)) { @@ -69,7 +67,9 @@ private void SaveKalmanFilterParametersToDisk(KalmanFilterParameters parameters) File.Exists(processNoiseCovariancePath) || File.Exists(measurementNoiseCovariancePath) || File.Exists(initialMeanPath) || - File.Exists(initialCovariancePath)) + File.Exists(initialCovariancePath) || + File.Exists(stateOffsetPath) || + File.Exists(observationOffsetPath)) ) { throw new InvalidOperationException("The save path already exists."); @@ -88,17 +88,23 @@ private void SaveKalmanFilterParametersToDisk(KalmanFilterParameters parameters) File.Delete(initialMeanPath); if (File.Exists(initialCovariancePath)) File.Delete(initialCovariancePath); + if (File.Exists(stateOffsetPath)) + File.Delete(stateOffsetPath); + if (File.Exists(observationOffsetPath)) + File.Delete(observationOffsetPath); } } Directory.CreateDirectory(path); - parameters.TransitionMatrix.Save(transitionMatrixPath); - parameters.MeasurementFunction.Save(measurementFunctionPath); - parameters.ProcessNoiseCovariance.Save(processNoiseCovariancePath); - parameters.MeasurementNoiseCovariance.Save(measurementNoiseCovariancePath); - parameters.InitialMean.Save(initialMeanPath); - parameters.InitialCovariance.Save(initialCovariancePath); + parameters.TransitionMatrix?.Save(transitionMatrixPath); + parameters.MeasurementFunction?.Save(measurementFunctionPath); + parameters.ProcessNoiseCovariance?.Save(processNoiseCovariancePath); + parameters.MeasurementNoiseCovariance?.Save(measurementNoiseCovariancePath); + parameters.InitialMean?.Save(initialMeanPath); + parameters.InitialCovariance?.Save(initialCovariancePath); + parameters.StateOffset?.Save(stateOffsetPath); + parameters.ObservationOffset?.Save(observationOffsetPath); } /// @@ -137,4 +143,4 @@ public enum SuffixType /// Guid } -} \ No newline at end of file +} diff --git a/src/Bonsai.ML.Lds.Torch/Smooth.cs b/src/Bonsai.ML.Lds.Torch/Smooth.cs index a2780d0c..69985696 100644 --- a/src/Bonsai.ML.Lds.Torch/Smooth.cs +++ b/src/Bonsai.ML.Lds.Torch/Smooth.cs @@ -41,8 +41,11 @@ public IObservable Process(IObservable { - var filteredState = new FilteredState(input.Item1, input.Item2, input.Item3, input.Item4); + var filteredState = new FilteredState( + predictedState: new LinearDynamicalSystemState(input.Item1, input.Item2), + updatedState: new LinearDynamicalSystemState(input.Item3, input.Item4) + ); return Model.Smooth(filteredState); }); } -} \ No newline at end of file +} diff --git a/src/Bonsai.ML.Lds.Torch/StochasticSubspaceIdentification.cs b/src/Bonsai.ML.Lds.Torch/StochasticSubspaceIdentification.cs index 6fb24d3e..be041086 100644 --- a/src/Bonsai.ML.Lds.Torch/StochasticSubspaceIdentification.cs +++ b/src/Bonsai.ML.Lds.Torch/StochasticSubspaceIdentification.cs @@ -16,6 +16,10 @@ namespace Bonsai.ML.Lds.Torch; [WorkflowElementCategory(ElementCategory.Combinator)] public class StochasticSubspaceIdentification { + private int? _targetNumStates = 2; + private int _maxLag = 20; + private double _threshold = 1e-4; + /// /// The target number of states in the Kalman filter model. /// @@ -47,47 +51,55 @@ public double Threshold } /// - /// If true, the transition matrix will be estimated during the EM algorithm. + /// If true, the transition matrix will be estimated during the SSI algorithm. /// - [Description("If true, the transition matrix will be estimated during the EM algorithm.")] + [Description("If true, the transition matrix will be estimated during the SSI algorithm.")] public bool EstimateTransitionMatrix { get; set; } = true; /// - /// If true, the measurement function will be estimated during the EM algorithm. + /// If true, the measurement function will be estimated during the SSI algorithm. /// - [Description("If true, the measurement function will be estimated during the EM algorithm.")] + [Description("If true, the measurement function will be estimated during the SSI algorithm.")] public bool EstimateMeasurementFunction { get; set; } = true; /// - /// If true, the process noise covariance will be estimated during the EM algorithm. + /// If true, the process noise covariance will be estimated during the SSI algorithm. /// - [Description("If true, the process noise covariance will be estimated during the EM algorithm.")] + [Description("If true, the process noise covariance will be estimated during the SSI algorithm.")] public bool EstimateProcessNoiseCovariance { get; set; } = true; /// - /// If true, the measurement noise covariance will be estimated during the EM algorithm. + /// If true, the measurement noise covariance will be estimated during the SSI algorithm. /// - [Description("If true, the measurement noise covariance will be estimated during the EM algorithm.")] + [Description("If true, the measurement noise covariance will be estimated during the SSI algorithm.")] public bool EstimateMeasurementNoiseCovariance { get; set; } = true; /// - /// If true, the initial mean will be estimated during the EM algorithm. + /// If true, the initial mean will be estimated during the SSI algorithm. /// - [Description("If true, the initial mean will be estimated during the EM algorithm.")] + [Description("If true, the initial mean will be estimated during the SSI algorithm.")] public bool EstimateInitialMean { get; set; } = true; /// - /// If true, the initial covariance will be estimated during the EM algorithm. + /// If true, the initial covariance will be estimated during the SSI algorithm. /// - [Description("If true, the initial covariance will be estimated during the EM algorithm.")] + [Description("If true, the initial covariance will be estimated during the SSI algorithm.")] public bool EstimateInitialCovariance { get; set; } = true; - private int? _targetNumStates = 2; - private int _maxLag = 20; - private double _threshold = 1e-4; + /// + /// If true, the state offset will be estimated during the SSI algorithm. + /// + [Description("If true, the state offset will be estimated during the SSI algorithm.")] + public bool EstimateStateOffset { get; set; } = false; + + /// + /// If true, the observation offset will be estimated during the SSI algorithm. + /// + [Description("If true, the observation offset will be estimated during the SSI algorithm.")] + public bool EstimateObservationOffset { get; set; } = false; /// - /// Processes an observable sequence of input tensors, applying the Expectation-Maximization algorithm to learn the parameters of a Kalman filter model. + /// Processes an observable sequence of input tensors, applying the Stochastic Subspace Identification algorithm to learn the parameters of a Kalman filter model. /// /// /// @@ -103,7 +115,9 @@ public IObservable Process(IObservable Process(IObservable tolerance || Math.Abs(originalOutput.X[i,j] - pythonOutput.X[i,j]) > tolerance) + if (Math.Abs(bonsaiOutput.X[i, j] - pythonOutput.X[i, j]) > tolerance || Math.Abs(originalOutput.X[i, j] - pythonOutput.X[i, j]) > tolerance) { - Console.WriteLine($"Discrepency found comparing X at index ({i},{j}) with tolerance {tolerance}: bonsaiOutput = {bonsaiOutput.X[i,j]}, pythonOutput = {pythonOutput.X[i,j]}, originalOutput = {originalOutput.X[i,j]}."); + Console.WriteLine($"Discrepency found comparing X at index ({i},{j}) with tolerance {tolerance}: bonsaiOutput = {bonsaiOutput.X[i, j]}, pythonOutput = {pythonOutput.X[i, j]}, originalOutput = {originalOutput.X[i, j]}."); return false; } } @@ -104,9 +104,9 @@ private static bool CompareJSONData(string basePath, double tolerance = 1e-9) { for (int j = 0; j < bonsaiOutput.P.GetLength(1); j++) { - if (Math.Abs(bonsaiOutput.P[i,j] - pythonOutput.P[i,j]) > tolerance || Math.Abs(originalOutput.P[i,j] - pythonOutput.P[i,j]) > tolerance) + if (Math.Abs(bonsaiOutput.P[i, j] - pythonOutput.P[i, j]) > tolerance || Math.Abs(originalOutput.P[i, j] - pythonOutput.P[i, j]) > tolerance) { - Console.WriteLine($"Discrepency found comparing P at index ({i},{j}) with tolerance {tolerance}: bonsaiOutput = {bonsaiOutput.P[i,j]}, pythonOutput = {pythonOutput.P[i,j]}, originalOutput = {originalOutput.P[i,j]}."); + Console.WriteLine($"Discrepency found comparing P at index ({i},{j}) with tolerance {tolerance}: bonsaiOutput = {bonsaiOutput.P[i, j]}, pythonOutput = {pythonOutput.P[i, j]}, originalOutput = {originalOutput.P[i, j]}."); return false; } }