-
Notifications
You must be signed in to change notification settings - Fork 631
Expand file tree
/
Copy pathforward-inl.h
More file actions
292 lines (259 loc) · 11.3 KB
/
forward-inl.h
File metadata and controls
292 lines (259 loc) · 11.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Include guard for non-SIMD code.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_INL_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_INL_H_
#include <stddef.h>
#include <stdint.h>
#include <cmath>
#include <vector>
#include "backprop/activations.h"
#include "gemma/common.h"
#include "gemma/configs.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_INL_H_
// Include guard for (potentially) SIMD code.
#if defined(THIRD_PARTY_GEMMA_CPP_FORWARD_TOGGLE) == defined(HWY_TARGET_TOGGLE)
#ifdef THIRD_PARTY_GEMMA_CPP_FORWARD_TOGGLE
#undef THIRD_PARTY_GEMMA_CPP_FORWARD_TOGGLE
#else
#define THIRD_PARTY_GEMMA_CPP_FORWARD_TOGGLE
#endif
#include "gemma/ops.h"
#include "hwy/highway.h"
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
// Embeds every prompt token except the last into `output`.
//
// Row i of `output` (model_dim floats starting at output + i * model_dim)
// receives the decompressed embedding row of prompt[i] from `weights`. That row
// is then multiplied by `scaling`. Only prompt.size() - 1 rows are written; in
// this file the final token is only ever a prediction target (see
// CrossEntropyLoss).
//
// Aborts (HWY_ASSERT) if `prompt` is empty.
template <typename ArrayT>
void InputEmbedding(const ArrayT& weights, const std::vector<int>& prompt,
                    const float scaling, float* HWY_RESTRICT output,
                    size_t model_dim) {
  HWY_ASSERT(!prompt.empty());
  const size_t num_inputs = prompt.size() - 1;
  float* row = output;
  for (size_t i = 0; i < num_inputs; ++i, row += model_dim) {
    const int token = prompt[i];
    // The embedding table is addressed by element offset: token * model_dim.
    Decompress(weights, token * model_dim, row, model_dim);
    MulByConst(scaling, row, model_dim);
  }
}
// Applies RMSNorm to each of `num_tokens` consecutive rows of `model_dim`
// elements.
//
// Every row of `x` is normalized with the shared scale vector `weights`. The
// result goes to the row at the same offset in `output`. Rows are processed
// sequentially on the calling thread; `pool` is currently unused.
template<typename WT, typename XT, typename OutT>
void ApplyRMSNorm(const WT* HWY_RESTRICT weights, const XT* HWY_RESTRICT x,
                  size_t model_dim, size_t num_tokens,
                  OutT* HWY_RESTRICT output,
                  hwy::ThreadPool& pool) {
  const XT* in_row = x;
  OutT* out_row = output;
  for (size_t remaining = num_tokens; remaining != 0; --remaining) {
    RMSNorm(in_row, weights, out_row, model_dim);
    in_row += model_dim;
    out_row += model_dim;
  }
}
// Returns the summed cross-entropy, in bits, of the prompt tokens that follow
// the context.
//
// `probs` holds one row of `vocab_size` probabilities per position. Row p is
// the predicted distribution for prompt[p + 1]. A target token at index t is
// scored only when t >= context_size, so tokens inside the context are never
// predicted. The natural-log likelihood is negated and divided by ln(2), which
// converts it to bits.
//
// Aborts (HWY_ASSERT) if `prompt` is empty. `pool` is unused.
static HWY_NOINLINE float CrossEntropyLoss(const float* HWY_RESTRICT probs,
                                           const std::vector<int>& prompt,
                                           size_t context_size,
                                           size_t vocab_size,
                                           hwy::ThreadPool& pool) {
  HWY_ASSERT(!prompt.empty());
  // Index 0 has no preceding position to predict it, so the earliest possible
  // target is index 1. Targets inside the context window are skipped as well.
  const size_t first_target = context_size > 1 ? context_size : 1;
  float log_likelihood = 0.0f;
  for (size_t target = first_target; target < prompt.size(); ++target) {
    const float* dist = probs + (target - 1) * vocab_size;
    log_likelihood += std::log(dist[prompt[target]]);
  }
  const float neg_inv_ln2 = static_cast<float>(-1.0 / std::log(2.0));
  return log_likelihood * neg_inv_ln2;
}
// Forward pass of one attention + feed-forward layer over the first
// `num_tokens` positions, computed for all positions together.
//
// Reads `activations.input` (num_tokens rows of kModelDim floats). Writes the
// layer result to `output` (num_tokens rows of kModelDim floats). Every
// intermediate value (norm outputs, q/k/v, attention scores and outputs, FFW
// hidden states) is kept in `activations`. Presumably this is so a backward
// pass can reuse them; TODO confirm.
//
// Per-position qkv layout, taken from the indexing below: (kHeads + 2) vectors
// of kQKVDim floats each. They are kHeads query vectors, then one key vector,
// then one value vector. The key and value are shared by all heads.
//
// Steps:
//   1. RMSNorm(input) -> pre_att_rms_out.
//   2. QKV projection -> qkv.
//   3. Rope() on keys, then Rope() on queries, with queries scaled by
//      1 / sqrt(kQKVDim).
//   4. Causal attention. Scores q.k for pos2 <= pos -> att, then Softmax over
//      those pos + 1 scores, then a weighted sum of values -> att_out.
//   5. Per-head output projection summed into attention_out, plus the
//      residual `input`.
//   6. RMSNorm(attention_out) -> bf_pre_ffw_rms_out, gating projection ->
//      ffw_hidden, GELU gating -> ffw_hidden_gated.
//   7. Down projection into `output`, plus the residual attention_out.
//
// `pool` is passed to MatVec and runs the per-(position, head) attention tasks.
// Aborts (HWY_ASSERT) if num_tokens > kSeqLen.
template <typename TConfig, template<typename> typename LayerT>
void ApplyForwardLayer(const LayerT<TConfig>& weights,
                       ForwardLayer<float, TConfig>& activations,
                       size_t num_tokens,
                       float* HWY_RESTRICT output,
                       hwy::ThreadPool& pool) {
  static constexpr size_t kModelDim = TConfig::kModelDim;
  static constexpr size_t kSeqLen = TConfig::kSeqLen;
  static constexpr size_t kQKVDim = TConfig::kQKVDim;
  static constexpr size_t kHeads = TConfig::kHeads;
  // 1 / sqrt(kQKVDim). This is a function-local static, so it is computed once
  // per template instantiation.
  static const float kQueryScale =
      static_cast<float>(1.0 / sqrt(static_cast<double>(kQKVDim)));
  HWY_ASSERT(num_tokens <= kSeqLen);

  // Step 1: pre-attention RMSNorm.
  ApplyRMSNorm(weights.pre_attention_norm_scale.data(),
               activations.input.data(), kModelDim, num_tokens,
               activations.pre_att_rms_out.data(), pool);

  // Step 2: project each normalized position to (kHeads + 2) * kQKVDim values,
  // i.e. the queries, the shared key and the shared value.
  for (size_t pos = 0; pos < num_tokens; ++pos) {
    MatVec<(kHeads + 2) * kQKVDim, kModelDim>(
        weights.qkv_einsum_w, 0,
        activations.pre_att_rms_out.data() + pos * kModelDim, nullptr,
        activations.qkv.data() + pos * (kHeads + 2) * kQKVDim, pool);
  }
  // One parallel task per (head, position) pair: task = pos * kHeads + head.
  const size_t num_tasks = kHeads * num_tokens;

  // Step 3a: all heads share one key per position, so keys are rotated once
  // per position, serially.
  for (size_t pos = 0; pos < num_tokens; ++pos) {
    float* HWY_RESTRICT k =
        activations.qkv.data() + (pos * (kHeads + 2) + kHeads) * kQKVDim;
    Rope(k, kQKVDim, pos);
  }
  // Step 3b: rotate and scale each head's query in place.
  pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
    const size_t head = task % kHeads;
    const size_t pos = task / kHeads;
    float* HWY_RESTRICT q =
        activations.qkv.data() + (pos * (kHeads + 2) + head) * kQKVDim;
    Rope(q, kQKVDim, pos);
    MulByConst(kQueryScale, q, kQKVDim);
  });

  // Step 4a: raw attention scores. Each (pos, head) owns a kSeqLen-wide row of
  // `att`. Only entries 0..pos are written (causal mask); entries past pos are
  // left untouched.
  pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
    const size_t head = task % kHeads;
    const size_t pos = task / kHeads;
    const float* HWY_RESTRICT q =
        activations.qkv.data() + (pos * (kHeads + 2) + head) * kQKVDim;
    float* HWY_RESTRICT head_att =
        activations.att.data() + (pos * kHeads + head) * kSeqLen;
    for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
      const float* HWY_RESTRICT k2 =
          activations.qkv.data() + (pos2 * (kHeads + 2) + kHeads) * kQKVDim;
      const float score = Dot(q, k2, kQKVDim);
      head_att[pos2] = score;
    }
  });

  // Step 4b: normalize the pos + 1 valid scores of each row.
  pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
    const size_t head = task % kHeads;
    const size_t pos = task / kHeads;
    float* HWY_RESTRICT head_att =
        activations.att.data() + (pos * kHeads + head) * kSeqLen;
    Softmax(head_att, pos + 1);
  });

  // Step 4c: att_out[pos, head] = sum over pos2 <= pos of att * value[pos2].
  pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
    const size_t head = task % kHeads;
    const size_t pos = task / kHeads;
    const float* HWY_RESTRICT head_att =
        activations.att.data() + (pos * kHeads + head) * kSeqLen;
    float* HWY_RESTRICT att_out =
        activations.att_out.data() + (pos * kHeads + head) * kQKVDim;
    hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out));
    for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
      float* HWY_RESTRICT v2 =
          activations.qkv.data() + (pos2 * (kHeads + 2) + kHeads + 1) * kQKVDim;
      MulByConstAndAdd(head_att[pos2], v2, att_out, kQKVDim);
    }
  });

  // Step 5a: output projection. Each head's kQKVDim slice of att_out is mapped
  // through its own kModelDim x kQKVDim block of attn_vec_einsum_w, at element
  // offset head * kModelDim * kQKVDim. The per-head results are summed into
  // the zeroed attention_out. att_post1 is scratch, so after the loop it holds
  // only the last head's projection. AddFrom is presumed to add its first
  // argument into its second.
  hwy::ZeroBytes(activations.attention_out.data(),
                 num_tokens * kModelDim * sizeof(activations.attention_out[0]));
  for (size_t pos = 0; pos < num_tokens; ++pos) {
    for (size_t head = 0; head < kHeads; ++head) {
      MatVec<kModelDim, kQKVDim>(
          weights.attn_vec_einsum_w, head * kModelDim * kQKVDim,
          activations.att_out.data() + pos * kHeads * kQKVDim + head * kQKVDim,
          nullptr, activations.att_post1.data() + pos * kModelDim, pool);
      AddFrom(activations.att_post1.data() + pos * kModelDim,
              activations.attention_out.data() + pos * kModelDim, kModelDim);
    }
  }

  // Step 5b: residual connection around the attention block.
  for (size_t pos = 0; pos < num_tokens; ++pos) {
    AddFrom(activations.input.data() + pos * kModelDim,
            activations.attention_out.data() + pos * kModelDim, kModelDim);
  }

  // Step 6a: pre-FFW RMSNorm of the post-attention residual stream.
  ApplyRMSNorm(weights.pre_ffw_norm_scale.data(),
               activations.attention_out.data(), kModelDim, num_tokens,
               activations.bf_pre_ffw_rms_out.data(), pool);
  static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
  // Step 6b: gating projection to 2 * kFFHiddenDim values per position.
  for (size_t pos = 0; pos < num_tokens; ++pos) {
    MatVec<kFFHiddenDim * 2, kModelDim>(
        weights.gating_einsum_w, 0,
        activations.bf_pre_ffw_rms_out.data() + pos * kModelDim, nullptr,
        activations.ffw_hidden.data() + pos * kFFHiddenDim * 2, pool);
  }
  // Step 6c: gated GELU. The first half of each ffw_hidden row goes through
  // Gelu and is multiplied elementwise by the second half.
  // NOTE(review): the loop advances by whole vectors with no remainder
  // handling. It therefore assumes kFFHiddenDim is a multiple of Lanes(df);
  // confirm this holds for every config.
  for (size_t pos = 0; pos < num_tokens; ++pos) {
    const size_t hidden_offset = pos * kFFHiddenDim * 2;
    const float* HWY_RESTRICT out =
        activations.ffw_hidden.data() + hidden_offset;
    const float* HWY_RESTRICT out_mul = out + kFFHiddenDim;
    float* HWY_RESTRICT out_gated =
        activations.ffw_hidden_gated.data() + pos * kFFHiddenDim;
    namespace hn = hwy::HWY_NAMESPACE;
    using DF = hn::ScalableTag<float>;
    DF df;
    for (size_t i = 0; i < kFFHiddenDim; i += Lanes(df)) {
      const auto y = hn::Load(df, out + i);
      const auto x = hn::Load(df, out_mul + i);
      hn::Store(hn::Mul(x, Gelu(df, y)), df, out_gated + i);
    }
  }
  // Step 7a: down projection back to kModelDim, written to the caller's buffer.
  for (size_t pos = 0; pos < num_tokens; ++pos) {
    MatVec<kModelDim, kFFHiddenDim>(
        weights.linear_w, 0,
        activations.ffw_hidden_gated.data() + pos * kFFHiddenDim,
        nullptr, output + pos * kModelDim, pool);
  }
  // Step 7b: residual connection around the FFW block.
  for (size_t pos = 0; pos < num_tokens; ++pos) {
    AddFrom(activations.attention_out.data() + pos * kModelDim,
            output + pos * kModelDim, kModelDim);
  }
}
template <typename TConfig, template<typename...> typename WeightsT,
template<typename> typename LayerT>
float CrossEntropyLossForwardPass(const std::vector<int>& prompt,
size_t context_size,
const WeightsT<TConfig>& weights,
ForwardPass<float, TConfig>& forward,
hwy::ThreadPool& pool) {
static constexpr size_t kVocabSize = TConfig::kVocabSize;
static constexpr size_t kModelDim = TConfig::kModelDim;
static constexpr size_t kLayers = TConfig::kLayers;
const float kEmbScaling = EmbeddingScaling<TConfig>();
static_assert(!TConfig::kAbsolutePE);
static_assert(TConfig::kPostNorm == PostNormType::None);
static_assert(TConfig::kKVHeads == 1);
HWY_DASSERT(context_size > 0);
HWY_DASSERT(context_size < prompt.size());
const size_t num_tokens = prompt.size() - 1;
InputEmbedding(weights.embedder_input_embedding, prompt, kEmbScaling,
forward.layers[0].input.data(), kModelDim);
for (size_t layer = 0; layer < kLayers; ++layer) {
auto type = TConfig::kLayerConfig[layer];
// TODO(szabadka) Implement Griffin layer.
HWY_ASSERT(type == LayerAttentionType::kGemma);
float* HWY_RESTRICT output = layer + 1 < kLayers ?
forward.layers[layer + 1].input.data() :
forward.final_layer_output.data();
ApplyForwardLayer<TConfig, LayerT>(
*weights.GetLayer(layer), forward.layers[layer],
num_tokens, output, pool);
}
ApplyRMSNorm(weights.final_norm_scale.data(),
forward.final_layer_output.data(),
kModelDim, num_tokens, forward.final_norm_output.data(), pool);
for (size_t pos = 0; pos < num_tokens; ++pos) {
MatVec<kVocabSize, kModelDim>(
weights.embedder_input_embedding, 0,
forward.final_norm_output.data() + pos * kModelDim, nullptr,
forward.logits.data() + pos * kVocabSize, pool);
}
if constexpr (TConfig::kFinalCap > 0.0f) {
for (size_t pos = 0; pos < num_tokens; ++pos) {
LogitsSoftCap(TConfig::kFinalCap,
forward.logits.data() + pos * kVocabSize, kVocabSize);
}
}
hwy::CopyBytes(forward.logits.data(), forward.probs.data(),
num_tokens * kVocabSize * sizeof(forward.logits[0]));
for (size_t pos = 0; pos < num_tokens; ++pos) {
Softmax(forward.probs.data() + pos * kVocabSize, kVocabSize);
}
return CrossEntropyLoss(forward.probs.data(), prompt, context_size,
kVocabSize, pool);
}
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();
#endif // NOLINT