Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 161 additions & 0 deletions src/Bonsai.ML.PointProcessDecoder.Design/PosteriorImageOverlay.cs
Comment thread
glopesdev marked this conversation as resolved.
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
using Bonsai;
using Bonsai.Design;
using Bonsai.Expressions;
using System;
using OxyPlot.Series;
using OxyPlot.Axes;
using OxyPlot;
using Bonsai.Vision.Design;
using Bonsai.ML.Torch;
using OpenCV.Net;
using static TorchSharp.torch;
using PointProcessDecoder.Core;
using System.Drawing.Imaging;
using System.Linq;
using TorchSharp;

[assembly: TypeVisualizer(typeof(Bonsai.ML.PointProcessDecoder.Design.PosteriorImageOverlay),
Target = typeof(MashupSource<ImageMashupVisualizer, Bonsai.ML.PointProcessDecoder.Design.PosteriorVisualizer>))]

namespace Bonsai.ML.PointProcessDecoder.Design
{
/// <summary>
/// Class that overlays the true posterior distribution on the input image.
/// </summary>
public class PosteriorImageOverlay : DialogTypeVisualizer
{

private ImageMashupVisualizer imageVisualizer;
private int[] _stateSpaceMin;
private int[] _stateSpaceMax;
private int _height;
private int _width;
private PointProcessModel _model;
private string _modelName;
private static Func<object, Tensor> _extractPosterior;

/// <inheritdoc/>
public override void Load(IServiceProvider provider)
{
imageVisualizer = (ImageMashupVisualizer)provider.GetService(typeof(MashupVisualizer));
}

/// <inheritdoc/>
public override void Show(object value)
{
if (value is not DecoderDataFrame && value is not ClassifierDataFrame)
{
return;
}

_modelName ??= value switch
{
DecoderDataFrame decoderDataFrame => decoderDataFrame.Name,
ClassifierDataFrame classifierDataFrame => classifierDataFrame.Name,
_ => throw new InvalidOperationException("The input value is invalid.")
};

_extractPosterior ??= value switch
{
DecoderDataFrame _ => input => ((DecoderDataFrame)input).DecoderData.Posterior,
ClassifierDataFrame _ => input => ((ClassifierDataFrame)input).ClassifierData.DecoderData.Posterior,
_ => throw new InvalidOperationException("The node is invalid.")
};

if (_model is null)
{
_model = PointProcessModelManager.GetModel(_modelName);

_stateSpaceMin = [.. _model.StateSpace.Points
.min(dim: 0)
.values
.to_type(ScalarType.Int32)
.data<int>()
];

_stateSpaceMax = [.. _model.StateSpace.Points
.max(dim: 0)
.values
.to_type(ScalarType.Int32)
.data<int>()
];

_width = _stateSpaceMax[0] - _stateSpaceMin[0];
_height = _stateSpaceMax[1] - _stateSpaceMin[1];
}

var image = imageVisualizer.VisualizerImage;

var posterior = _extractPosterior(value)[-1].T.unsqueeze(0);

var posteriorScaled = torchvision.transforms.functional.resize(posterior, _height, _width);
posteriorScaled -= posteriorScaled.min();
posteriorScaled /= posteriorScaled.max();
posteriorScaled *= 255.0;

var fullPosterior = zeros([1, image.Height, image.Width], dtype: ScalarType.Byte, device: posterior.device);

fullPosterior[0, torch.TensorIndex.Slice(_stateSpaceMin[1], _stateSpaceMax[1]), torch.TensorIndex.Slice(_stateSpaceMin[0], _stateSpaceMax[0])] = posteriorScaled.to_type(ScalarType.Byte);

var posteriorImage = OpenCVHelper.ToImage(fullPosterior.cpu().permute(1, 2, 0), CPU);

var posteriorOverlay = new IplImage(posteriorImage.Size, posteriorImage.Depth, 3);

CV.CvtColor(posteriorImage, posteriorOverlay, ColorConversion.Gray2Rgb);

CV.LUT(posteriorOverlay, posteriorOverlay, ColormapExtensions.HotLut);

CV.AddWeighted(image, 0.8, posteriorOverlay, 0.5, 0, image);
}

/// <inheritdoc/>
public override void Unload()
{
imageVisualizer = null;
_model = null;
_stateSpaceMin = null;
_stateSpaceMax = null;
_modelName = null;
_extractPosterior = null;
}
}

internal static class ColormapExtensions
{
public static Mat HotLut => _hotLut ??= EnsureHotLut();
private static Mat _hotLut;
private static Mat EnsureHotLut()
{
_hotLut = new Mat(1, 256, Depth.U8, 3);
for (int i = 0; i < 256; i++)
{
double t = i / 255.0;
double r, g, b;
if (t < 1.0 / 3.0)
{
r = 3 * t;
g = 0;
b = 0;
}
else if (t < 2.0 / 3.0)
{
r = 1;
g = 3 * t - 1;
b = 0;
}
else
{
r = 1;
g = 1;
b = 3 * t - 2;
}
// Clamp and convert to bytes (BGR order)
byte R = (byte)Math.Round(r * 255);
byte G = (byte)Math.Round(g * 255);
byte B = (byte)Math.Round(b * 255);
_hotLut[i] = new OpenCV.Net.Scalar(B, G, R);
}
return _hotLut;
}
}
}
13 changes: 9 additions & 4 deletions src/Bonsai.ML.PointProcessDecoder.Design/PosteriorVisualizer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,12 @@
Target = typeof(Bonsai.ML.PointProcessDecoder.Decode))]
[assembly: TypeVisualizer(typeof(Bonsai.ML.PointProcessDecoder.Design.PosteriorVisualizer),
Target = typeof(Bonsai.ML.PointProcessDecoder.GetDecoderData))]
[assembly: TypeVisualizer(typeof(Bonsai.ML.PointProcessDecoder.Design.PosteriorVisualizer),
[assembly: TypeVisualizer(typeof(Bonsai.ML.PointProcessDecoder.Design.PosteriorVisualizer),
Comment thread
glopesdev marked this conversation as resolved.
Target = typeof(Bonsai.ML.PointProcessDecoder.GetClassifierData))]
[assembly: TypeVisualizer(typeof(Bonsai.ML.PointProcessDecoder.Design.PosteriorVisualizer),
Target = typeof(Bonsai.ML.PointProcessDecoder.DecoderDataFrame))]
[assembly: TypeVisualizer(typeof(Bonsai.ML.PointProcessDecoder.Design.PosteriorVisualizer),
Target = typeof(Bonsai.ML.PointProcessDecoder.ClassifierDataFrame))]

namespace Bonsai.ML.PointProcessDecoder.Design
{
Expand Down Expand Up @@ -85,16 +89,17 @@ public override void Load(IServiceProvider provider)
}
}

if (node == null)
if (node is null)
{
throw new InvalidOperationException("The decode node is invalid.");
}

_convertInputData = node switch
{
Decode _ => input => (Tensor)input,
GetDecoderData _ => input => ((DecoderData)input).Posterior,
GetClassifierData _ => input => ((ClassifierData)input).DecoderData.Posterior,
GetDecoderData _ => input => ((DecoderDataFrame)input).DecoderData.Posterior,
GetClassifierData _ => input => ((ClassifierDataFrame)input).ClassifierData.DecoderData.Posterior,
DecoderDataFrame _ => input => ((DecoderDataFrame)input).DecoderData.Posterior,
_ => throw new InvalidOperationException("The node is invalid.")
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
<ProjectReference Include="..\Bonsai.ML.Torch\Bonsai.ML.Torch.csproj" />
</ItemGroup>
<ItemGroup>
<PackageReference Include="PointProcessDecoder.Core" Version="0.4.0" />
<PackageReference Include="PointProcessDecoder.Core" Version="0.5.0" />
</ItemGroup>
<ItemGroup>
<InternalsVisibleTo Include="Bonsai.ML.PointProcessDecoder.Design" />
Expand Down
23 changes: 23 additions & 0 deletions src/Bonsai.ML.PointProcessDecoder/ClassifierDataFrame.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
using PointProcessDecoder.Core.Decoder;

namespace Bonsai.ML.PointProcessDecoder;

/// <summary>
/// Represents a packaged data frame containing the output of a point process classifier model.
/// </summary>
/// <param name="classifierData"></param>
/// <param name="name"></param>
public readonly struct ClassifierDataFrame(
ClassifierData classifierData,
string name) : IPointProcessModelReference
{
/// <summary>
/// The packaged classifier data.
/// </summary>
public ClassifierData ClassifierData => classifierData;

/// <summary>
/// The name of the point process model.
/// </summary>
public string Name => name;
}
34 changes: 17 additions & 17 deletions src/Bonsai.ML.PointProcessDecoder/CreatePointProcessModel.cs
Comment thread
glopesdev marked this conversation as resolved.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using System;
using System;
using System.ComponentModel;
using System.Xml.Serialization;
using System.Linq;
Expand Down Expand Up @@ -87,7 +87,7 @@ public ScalarType? ScalarType
/// </summary>
[Category("Covariate Parameters")]
[Description("The number of dimensions of the covariate.")]
public int Dimensions
public int CovariateDimensions
{
get
{
Expand Down Expand Up @@ -163,7 +163,7 @@ public long[] Steps
[Category("Covariate Parameters")]
[Description("The kernel bandwidth used to estimate the probability density over the covariate dimensions. Must be the same length as the covariate dimensions.")]
[TypeConverter(typeof(UnidimensionalArrayConverter))]
public double[] Bandwidth
public double[] CovariateBandwidth
{
get
{
Expand Down Expand Up @@ -195,7 +195,7 @@ public EncoderType EncoderType

private int? kernelLimit = null;
/// <summary>
/// Gets or sets the maximum number of kernels maintained in memory for each probability density estimation made by the encoder.
/// Gets or sets the maximum number of kernels maintained in memory for each probability density estimation made by the encoder.
/// </summary>
/// <remarks>
/// In the case of sorted spikes, there is an estimate for the full covariate distribution and an estimate for each unit. #
Expand All @@ -215,22 +215,22 @@ public int? KernelLimit
}
}

private int? nUnits = null;
private int? numUnits = null;
/// <summary>
/// Gets or sets the number of sorted spiking units.
/// Only used when the encoder type is set to <see cref="EncoderType.SortedSpikes"/>.
/// </summary>
[Category("Encoder Parameters")]
[Description("The number of sorted spiking units. Only used when the encoder type is set to SortedSpikeEncoder.")]
public int? NUnits
public int? NumUnits
{
get
{
return nUnits;
return numUnits;
}
set
{
nUnits = value;
numUnits = value;
}
}

Expand All @@ -253,22 +253,22 @@ public int? MarkDimensions
}
}

private int? markChannels = null;
private int? numChannels = null;
/// <summary>
/// Gets or sets the number of mark recording channels.
/// Only used when the encoder type is set to <see cref="EncoderType.ClusterlessMarks"/>.
/// </summary>
[Category("Encoder Parameters")]
[Description("The number of mark recording channels. Only used when the encoder type is set to ClusterlessMarkEncoder.")]
public int? MarkChannels
public int? NumChannels
{
get
{
return markChannels;
return numChannels;
}
set
{
markChannels = value;
numChannels = value;
}
}

Expand Down Expand Up @@ -352,7 +352,7 @@ public TransitionsType TransitionsType
private double? sigmaRandomWalk = null;
/// <summary>
/// Gets or sets the standard deviation of the random walk transitions model.
/// Only used when the transitions type is set to <see cref="TransitionsType.RandomWalk"/> or when the decoder type is set to <see cref="DecoderType.HybridStateSpaceClassifier"/>
/// Only used when the transitions type is set to <see cref="TransitionsType.RandomWalk"/> or when the decoder type is set to <see cref="DecoderType.HybridStateSpaceClassifier"/>
/// </summary>
[Category("Decoder Parameters")]
[Description("The standard deviation of the random walk transitions model. Only used when the transitions type is set to RandomWalk or when the decoder type is set to HybridStateSpaceClassifier.")]
Expand Down Expand Up @@ -433,12 +433,12 @@ public IObservable<PointProcessModel> Process()
minStateSpace: minCovariateRange,
maxStateSpace: maxCovariateRange,
stepsStateSpace: stepsCovariateRange,
observationBandwidth: covariateBandwidth,
covariateBandwidth: covariateBandwidth,
stateSpaceDimensions: covariateDimensions,
markDimensions: markDimensions,
markChannels: markChannels,
numChannels: numChannels,
markBandwidth: markBandwidth,
nUnits: nUnits,
numUnits: numUnits,
distanceThreshold: distanceThreshold,
sigmaRandomWalk: sigmaRandomWalk,
kernelLimit: kernelLimit,
Expand All @@ -449,4 +449,4 @@ public IObservable<PointProcessModel> Process()
.Concat(Observable.Never(resource.Model))
.Finally(resource.Dispose));
}
}
}
14 changes: 7 additions & 7 deletions src/Bonsai.ML.PointProcessDecoder/Decode.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using System;
using System;
using System.ComponentModel;
using System.Reactive.Linq;
using System.Text;
Expand Down Expand Up @@ -39,23 +39,23 @@ public bool IgnoreNoSpikes
}

/// <summary>
/// Decodes the input neural data into a posterior state estimate using a point process model.
/// Decodes the observations into a posterior state estimate using a point process model.
/// </summary>
/// <param name="source"></param>
/// <returns></returns>
public IObservable<Tensor> Process(IObservable<Tensor> source)
{
var modelName = Name;
return source.Select(input =>
return source.Select(observations =>
{
var model = PointProcessModelManager.GetModel(modelName);
if (_updateIgnoreNoSpikes)
if (_updateIgnoreNoSpikes)
{
model.Likelihood.IgnoreNoSpikes = _ignoreNoSpikes;
_updateIgnoreNoSpikes = false;
}
return model.Decode(input);

return model.Decode(observations);
});
}
}
}
Loading
Loading