-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Expand file tree
/
Copy pathPhi2RotaryEmbedding.cs
More file actions
49 lines (42 loc) · 1.72 KB
/
Phi2RotaryEmbedding.cs
File metadata and controls
49 lines (42 loc) · 1.72 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using TorchSharp;
using static TorchSharp.torch;
namespace Microsoft.ML.GenAI.Phi.Module;
/// <summary>
/// Produces the cosine / sine lookup tables for rotary position embeddings.
/// </summary>
/// <remarks>
/// The constructor registers a non-persistent buffer named <c>inv_freq</c> holding
/// <c>baseValue ^ (-(i / dim))</c> for <c>i = 0, 2, 4, ... &lt; dim</c>.
/// <see cref="forward"/> takes the outer product of the positions <c>0 .. seqLen - 1</c> with that
/// buffer, duplicates the result along the last dimension and returns its cosine and sine.
/// </remarks>
internal class Phi2RotaryEmbedding : nn.Module<
    Tensor, // input
    int,    // seq_len
    (
        Tensor, // cos
        Tensor  // sin
    )>
{
    private readonly double _base;

    // Stored for reference only; nothing in this class reads it back.
    private readonly int _maxPositionEmbeddings;
    private readonly int _dim;

    /// <summary>
    /// Creates the module and pre-computes the inverse-frequency buffer.
    /// </summary>
    /// <param name="baseValue">Base of the exponential used to build the inverse frequencies.</param>
    /// <param name="maxPositionEmbeddings">Maximum number of positions; recorded but not used by this class.</param>
    /// <param name="dim">Rotary dimension; one inverse frequency is produced for every even index below it.</param>
    public Phi2RotaryEmbedding(double baseValue, int maxPositionEmbeddings, int dim)
        : base(nameof(Phi2RotaryEmbedding))
    {
        _base = baseValue;
        _maxPositionEmbeddings = maxPositionEmbeddings;
        _dim = dim;

        // Even indices 0, 2, 4, ... < dim, converted to float32 so the division below is fractional.
        var thetaNumerator = torch.arange(0, _dim, 2, dtype: ScalarType.Int64).to(torch.float32);

        // inv_freq[k] = baseValue ^ (-(2k / dim)); registered as non-persistent.
        this.register_buffer("inv_freq", torch.pow(baseValue, -1.0f * (thetaNumerator / dim)), persistent: false);
    }

    /// <summary>Rotary dimension passed to the constructor.</summary>
    public int Dim => _dim;

    /// <summary>
    /// Builds the cosine and sine tables for positions <c>0 .. seqLen - 1</c>.
    /// </summary>
    /// <param name="x">Only its device and dtype are read; its values are not used.</param>
    /// <param name="seqLen">Number of positions (rows) to generate.</param>
    /// <returns>
    /// A <c>(cos, sin)</c> pair, each with <paramref name="seqLen"/> rows and twice as many columns as
    /// <c>inv_freq</c> has entries, placed on <paramref name="x"/>'s device and cast to its dtype.
    /// </returns>
#pragma warning disable MSML_GeneralName // This name should be PascalCased
    public override (Tensor, Tensor) forward(Tensor x, int seqLen)
#pragma warning restore MSML_GeneralName // This name should be PascalCased
    {
        // Every tensor created below wraps native memory. The scope releases all intermediates
        // deterministically when the method exits instead of waiting for GC finalizers.
        using var disposeScope = torch.NewDisposeScope();

        // TODO
        // can be calculated once and cached
        var invFreq = this.get_buffer("inv_freq")!.to(x.device);
        var t = torch.arange(seqLen, dtype: invFreq.dtype, device: invFreq.device);

        // freqs[p, k] = p * inv_freq[k]; forced to float32 before the trigonometric functions.
        var freqs = torch.outer(t, invFreq).to(torch.float32);
        var emb = torch.cat([freqs, freqs], dim: -1);

        // emb already has exactly seqLen rows (it is built from arange(seqLen)),
        // so no additional [..seqLen] slice is required.
        var cos = torch.cos(emb).to_type(x.dtype);
        var sin = torch.sin(emb).to_type(x.dtype);

        // Hand ownership of the two results to the caller's scope so they survive this method.
        return (cos.MoveToOuterDisposeScope(), sin.MoveToOuterDisposeScope());
    }
}