-
Notifications
You must be signed in to change notification settings - Fork 631
Expand file tree
/
Copy pathgemma.h
More file actions
259 lines (227 loc) · 10.6 KB
/
gemma.h
File metadata and controls
259 lines (227 loc) · 10.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_H_
#include <stddef.h>  // size_t
#include <stdio.h>   // fprintf, stderr (used by TimingInfo)

#include <functional>
#include <random>
#include <string>
#include <vector>

// IWYU pragma: begin_exports
#include "compression/io.h" // Path
#include "gemma/activations.h"
#include "gemma/common.h"
#include "gemma/configs.h"
#include "gemma/kv_cache.h"
#include "gemma/tokenizer.h"
#include "gemma/weights.h"
#include "paligemma/image.h"
#include "util/allocator.h" // RowVectorBatch
#include "util/basics.h" // TokenAndProb
#include "util/threading.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/timer.h"
// IWYU pragma: end_exports

#include "hwy/aligned_allocator.h" // Span
namespace gcpp {
// Read-only span of the token IDs of a single prompt.
using PromptTokens = hwy::Span<const int>;
// Batches of independent queries have their own prompt, previous token,
// position in the sequence, and KVCache. Each of the following spans holds one
// entry per query in the batch.
using QueriesPromptTokens = hwy::Span<const PromptTokens>;
using QueriesToken = hwy::Span<const int>;
using QueriesPos = hwy::Span<const size_t>;
using KVCaches = hwy::Span<KVCache>;
// StreamFunc is called with (token, probability). For prompt tokens,
// probability is 0.0f. StreamFunc should return false to stop generation and
// true to continue generation.
using StreamFunc = std::function<bool(int, float)>;
// BatchStreamFunc is called with (query_idx, pos, token, probability).
// For prompt tokens, probability is 0.0f.
// BatchStreamFunc should return false to stop generation and true to continue.
using BatchStreamFunc = std::function<bool(size_t, size_t, int, float)>;
// If not empty, AcceptFunc is called with (token, float). The float is
// presumably the token's probability, matching StreamFunc; TODO confirm.
// It should return false for tokens you don't want to generate and true for
// tokens you want to generate.
using AcceptFunc = std::function<bool(int, float)>;
// If not empty, SampleFunc is called with the logits for the next token, which
// it may modify/overwrite, and its return value is the next generated token
// together with its probability. The size_t argument is presumably the number
// of logits (vocabulary size); TODO confirm.
using SampleFunc = std::function<TokenAndProb(float*, size_t)>;
// If not empty, LayersOutputFunc is called for layer outputs, specified with:
// - index of query within containing batch (if any); zero otherwise.
// - position in the tokens sequence
// - name of the data, e.g. "tokens" for token IDs
// - layer index (or -1 for global outputs)
// - pointer to the data array
// - size of the data array
using LayersOutputFunc = std::function<void(size_t, size_t, const std::string&,
int, const float*, size_t)>;
// If not empty, ActivationsObserverFunc is invoked after each layer with:
// - per-query position within the tokens sequence
// - layer index (or -1 for post-norm output)
// - activations
using ActivationsObserverFunc =
std::function<void(const QueriesPos& queries_pos, int, const Activations&)>;
// ImageTokens are represented as a RowVectorBatch, where each "batch" index
// corresponds to a token for an image patch as computed by the image encoder.
using ImageTokens = RowVectorBatch<float>;
// RuntimeConfig holds configuration for a single generation run.
//
// This is an aggregate (typically filled with designated initializers), so the
// member order below must not change. Every member has a default initializer
// so that a field the caller forgets to set has a well-defined value instead
// of an indeterminate one.
struct RuntimeConfig {
  // Forwards one token to the caller's callback. If `batch_stream_token` is
  // set, it is called for each token in the batch instead of `stream_token`.
  // Returns the callback's result: false stops generation, true continues.
  // NOTE: if neither callback is set, invoking the empty `stream_token` is an
  // error (std::function throws std::bad_function_call).
  bool StreamToken(size_t query_idx, size_t pos, int token, float prob) const {
    if (batch_stream_token) {
      return batch_stream_token(query_idx, pos, token, prob);
    }
    return stream_token(token, prob);
  }

  // Limit on the number of tokens generated. Callers are expected to set this.
  size_t max_generated_tokens = 0;

  // These defaults are overridden by InferenceArgs::CopyTo(*this):
  // Max tokens per batch during prefill.
  size_t prefill_tbatch_size = 32;
  // Max queries per batch (one token from each) during decode.
  size_t decode_qbatch_size = 16;

  // Sampling-related parameters. Callers are expected to set `temperature` and
  // `gen`; the defaults only replace previously-uninitialized values.
  float temperature = 1.0f;       // Temperature for sampling.
  size_t top_k = kTopK;           // Top-k for sampling.
  std::mt19937* gen = nullptr;    // Random number generator used for sampling.
  int verbosity = 0;              // Controls verbosity of printed messages.

  // Functions operating on the generated tokens.
  StreamFunc stream_token;
  BatchStreamFunc batch_stream_token;
  AcceptFunc accept_token;  // if empty, accepts all tokens.
  SampleFunc sample_func;   // if empty, uses SampleTopK.

  // Observer callbacks for intermediate data.
  LayersOutputFunc layers_output;  // if not empty, called after each layer.
  ActivationsObserverFunc activations_observer;  // if set, called per-layer.

  // If not empty, these point to the image tokens and are used in the
  // PaliGemma prefix-LM style attention.
  const ImageTokens* image_tokens = nullptr;

  // Whether to use thread spinning to reduce barrier synchronization latency.
  // Mutable so we can change kDefault to kTrue/kFalse during Generate, because
  // RuntimeConfig is const there and is not passed to the Gemma ctor. This
  // default decision is likely sufficient because it is based on whether
  // threads are successfully pinned.
  mutable Tristate use_spinning = Tristate::kDefault;

  // End-of-sequence token.
  int eos_id = EOS_ID;
};
// Collects timing statistics for one generation run and, depending on
// `verbosity`, prints them to stderr. Timestamps passed in are values of
// hwy::platform::Now(); durations are in seconds (they are printed as
// `* 1000` ms and used as the denominator of tokens / sec).
struct TimingInfo {
  // Records the end of prefill. `tokens` is the number of prompt tokens and
  // `start` the timestamp at which prefill began. Also resets the per-run
  // generation statistics so that a reused TimingInfo does not report stale
  // values from a previous run.
  void NotifyPrefill(size_t tokens, double start) {
    prefill_duration = hwy::platform::Now() - start;
    prefill_tokens = tokens;
    time_to_first_token = 0.0;
    generate_duration = 0.0;
    tokens_generated = 0;
  }

  // Records one generated token. On the first token, stores the time to first
  // token (measured from `prefill_start`) and, if verbosity >= 1, prints a
  // prefill summary. If verbosity >= 2, prints the average generation speed
  // (measured from `gen_start`) every 128 tokens.
  void NotifyGenerated(double prefill_start, double gen_start) {
    ++tokens_generated;
    if (HWY_UNLIKELY(tokens_generated == 1)) {
      time_to_first_token = hwy::platform::Now() - prefill_start;
      if (verbosity >= 1) {
        const double prefill_tok_sec =
            TokensPerSecond(prefill_tokens, prefill_duration);
        fprintf(stderr,
                "\n\n[ Timing info ] Prefill: %d ms for %zu prompt tokens "
                "(%.2f tokens / sec); Time to first token: %d ms\n",
                static_cast<int>(prefill_duration * 1000), prefill_tokens,
                prefill_tok_sec, static_cast<int>(time_to_first_token * 1000));
      }
    }
    if (verbosity >= 2 && tokens_generated % 128 == 0) {
      const double gen_tok_sec =
          TokensPerSecond(tokens_generated, hwy::platform::Now() - gen_start);
      fprintf(stderr,
              "\n\n[ Timing info ] %zu tokens generated "
              "(avg speed %.2f tokens / sec)\n\n",
              tokens_generated, gen_tok_sec);
    }
  }

  // Records the end of generation, where `gen_start` is the timestamp at which
  // generation began, and prints a summary if verbosity >= 1.
  void NotifyGenerateDone(double gen_start) {
    generate_duration = hwy::platform::Now() - gen_start;
    if (verbosity >= 1) {
      const double gen_tok_sec =
          TokensPerSecond(tokens_generated, generate_duration);
      fprintf(stderr,
              "\n[ Timing info ] Generate: %d ms for %zu tokens (%.2f tokens / "
              "sec)\n",
              static_cast<int>(generate_duration * 1000), tokens_generated,
              gen_tok_sec);
    }
  }

  // Returns `tokens / seconds`, or 0 when `seconds` is not positive (e.g. a
  // duration shorter than the timer resolution), so that the reports above
  // never print inf/nan rates.
  static double TokensPerSecond(size_t tokens, double seconds) {
    return seconds > 0.0 ? static_cast<double>(tokens) / seconds : 0.0;
  }

  int verbosity = 0;  // >= 1: print summaries; >= 2: also periodic progress.
  double prefill_duration = 0;     // Seconds spent in prefill.
  size_t prefill_tokens = 0;       // Number of prompt tokens prefilled.
  double time_to_first_token = 0;  // Seconds from prefill start to 1st token.
  double generate_duration = 0;    // Seconds from generation start to done.
  size_t tokens_generated = 0;     // Tokens generated so far in this run.
};
class Gemma {
public:
Gemma(const Path& tokenizer_path, const Path& weights, const ModelInfo& info,
NestedPools& pools);
// Allocates weights, caller is responsible for filling them.
Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info, NestedPools& pools);
~Gemma();
const ModelConfig& GetModelConfig() const { return model_.Config(); }
ModelConfig& GetMutableModelConfig() { return model_.MutableConfig(); }
const ModelInfo& Info() const { return info_; }
const GemmaTokenizer& Tokenizer() const { return tokenizer_; }
const ModelWeightsStorage& Weights() const { return model_; }
ModelWeightsStorage& MutableWeights() { return model_; }
// `pos` is the position in the KV cache. Users are responsible for
// incrementing it in the `*StreamFunc`, or setting to zero for single-turn.
void Generate(const RuntimeConfig& runtime_config, const PromptTokens& prompt,
size_t pos, KVCache& kv_cache, TimingInfo& timing_info) {
Generate(runtime_config, prompt, pos, /*prefix_end=*/0, kv_cache,
timing_info);
}
// For prefix-LM style attention, we can pass the end of the prefix.
void Generate(const RuntimeConfig& runtime_config, const PromptTokens& prompt,
size_t pos, size_t prefix_end, KVCache& kv_cache,
TimingInfo& timing_info);
// `queries_pos` are the positions in the KV cache. Users are responsible for
// incrementing them in `BatchStreamFunc`, or setting to zero for single-turn.
void GenerateBatch(const RuntimeConfig& runtime_config,
const QueriesPromptTokens& queries_prompt,
const QueriesPos& queries_pos, const KVCaches& kv_caches,
TimingInfo& timing_info) {
GenerateBatch(runtime_config, queries_prompt, queries_pos,
/*queries_prefix_end=*/{}, kv_caches, timing_info);
}
// For prefix-LM style attention, we can pass the ends of the prefixes.
void GenerateBatch(const RuntimeConfig& runtime_config,
const QueriesPromptTokens& queries_prompt,
const QueriesPos& queries_pos,
const QueriesPos& queries_prefix_end,
const KVCaches& kv_caches, TimingInfo& timing_info);
// Generates the image tokens by running the image encoder ViT.
void GenerateImageTokens(const RuntimeConfig& runtime_config,
const Image& image, ImageTokens& image_tokens);
private:
NestedPools& pools_;
GemmaTokenizer tokenizer_;
// Type-erased so that this can be defined in the header.
ModelWeightsStorage model_;
ModelInfo info_;
};
// Tokenizes `prompt` with `tokenizer`, adding the BOS token and possibly
// 'turn' annotations. Whether annotations are added depends on `info`
// (presumably its training / instruction-tuned variant; TODO confirm) and on
// `pos`, the number of tokens decoded so far. `prompt` is taken by non-const
// reference, so it may be modified in place. Returns the resulting token IDs.
// Asserts that tokenization is successful.
std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
const ModelInfo& info, size_t pos,
std::string& prompt);
// Checks generation limits against the model configuration. Presumably
// validates `prompt_size` and limits `max_generated_tokens` to what
// `weights_config` supports; TODO confirm against the definition. Takes
// `max_generated_tokens` by non-const reference, so it may adjust it.
void RangeChecks(const ModelConfig& weights_config,
size_t& max_generated_tokens, size_t prompt_size);
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_H_