inference_context.hpp
#pragma once

#include "../cache_manager/opcache_manager.hpp"

#include <cassert>
#include <cstddef> // size_t
#include <cstdint> // uint32_t
#include <memory>  // std::shared_ptr
// Per-thread bundle of state needed to launch operators: the infiniop handle,
// a pool for temporary allocations, the operator-descriptor cache, the stream
// to enqueue work on, and a lazily grown shared workspace buffer.
struct InferenceContext {
    infiniopHandle_t op_handle;
    std::shared_ptr<MemoryPool> memory_pool;
    CacheManager *cache_manager;
    infinirtStream_t stream;
    std::shared_ptr<Storage> workspace_storage;
    size_t current_workspace_size = 0;

    InferenceContext(infiniopHandle_t op_handle,
                     std::shared_ptr<MemoryPool> memory_pool,
                     CacheManager *cache_manager,
                     infinirtStream_t stream);

    // Reallocate the shared workspace if an operator needs more than is held.
    void ensure_workspace(size_t required_size);

    // c = a + b (elementwise).
    void add(std::shared_ptr<Tensor> c,
             std::shared_ptr<Tensor> a,
             std::shared_ptr<Tensor> b);
    // RMS-normalize x with weight w into y.
    void rmsnorm(std::shared_ptr<Tensor> y,
                 std::shared_ptr<Tensor> x,
                 std::shared_ptr<Tensor> w,
                 float epsilon);
    // c = alpha * (a @ b) + beta * c.
    void gemm(std::shared_ptr<Tensor> c,
              std::shared_ptr<Tensor> a,
              std::shared_ptr<Tensor> b,
              float alpha, float beta);
    // Copy src into dst, which may have a different layout/strides.
    void rearrange(std::shared_ptr<Tensor> dst,
                   std::shared_ptr<Tensor> src);
    // Apply rotary position embedding to q and k with the given algorithm.
    void rope(std::shared_ptr<Tensor> q,
              std::shared_ptr<Tensor> k,
              std::shared_ptr<Tensor> pos,
              std::shared_ptr<Tensor> sin,
              std::shared_ptr<Tensor> cos,
              infiniopRoPEAlgo_t algo);
    // Softmax over the last dimension with a causal mask.
    void causalSoftmax(std::shared_ptr<Tensor> y,
                       std::shared_ptr<Tensor> x);
    void logSoftmax(std::shared_ptr<Tensor> y,
                    std::shared_ptr<Tensor> x);
    // MoE router: top-k expert weights and indices per token.
    void topkrouter(std::shared_ptr<Tensor> values,          // F32
                    std::shared_ptr<Tensor> indices,         // I32
                    std::shared_ptr<Tensor> x,
                    std::shared_ptr<Tensor> correction_bias, // F32
                    float routed_scaling_factor,
                    size_t topk);
    // SwiGLU activation combining the up and gate projections.
    void swiglu(std::shared_ptr<Tensor> out,
                std::shared_ptr<Tensor> up,
                std::shared_ptr<Tensor> gate);
    // Sample a token from prob with top-p/top-k/temperature filtering.
    void randomSample(std::shared_ptr<Tensor> out,
                      std::shared_ptr<Tensor> prob,
                      float random_val, float top_p, uint32_t top_k, float temperature);
    // GEMM with optional residual add and bias.
    void linear(std::shared_ptr<Tensor> c,
                std::shared_ptr<Tensor> a,
                std::shared_ptr<Tensor> b,
                float alpha, float beta,
                std::shared_ptr<Tensor> residual,
                std::shared_ptr<Tensor> bias);
    // Dequantize packed weights (values in_w, scales in_s, zero points in_z).
    void dequant(std::shared_ptr<Tensor> weight,
                 std::shared_ptr<Tensor> in_w,
                 std::shared_ptr<Tensor> in_s,
                 std::shared_ptr<Tensor> in_z);
};
// One active context per thread. Declared `inline thread_local` (C++17) so
// every translation unit shares the same per-thread pointer; wrapping it in
// an anonymous namespace in a header would give each TU its own copy behind
// the inline accessors below, an ODR hazard.
inline thread_local InferenceContext *tls_inference_context = nullptr;

inline InferenceContext &getInferenceContext() {
    assert(tls_inference_context != nullptr && "InferenceContext not set for this thread");
    return *tls_inference_context;
}

inline void setInferenceContext(InferenceContext *ctx) {
    tls_inference_context = ctx;
}
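// Usage sketch (illustrative only; the handle, pool, cache, and stream names
// are placeholders for objects the caller already owns, not API from this file):
//
//   InferenceContext ctx(handle, pool, &cache, stream);
//   setInferenceContext(&ctx);        // install for this thread
//   add(c, a, b);                     // wrappers below dispatch through it
//   setInferenceContext(nullptr);     // clear before ctx is destroyed
//
// The free functions below forward to the current thread's context.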
inline void add(std::shared_ptr<Tensor> c, std::shared_ptr<Tensor> a, std::shared_ptr<Tensor> b) {
    getInferenceContext().add(c, a, b);
}

inline void rmsnorm(std::shared_ptr<Tensor> y, std::shared_ptr<Tensor> x,
                    std::shared_ptr<Tensor> w, float epsilon) {
    getInferenceContext().rmsnorm(y, x, w, epsilon);
}

inline void gemm(std::shared_ptr<Tensor> c, std::shared_ptr<Tensor> a,
                 std::shared_ptr<Tensor> b, float alpha, float beta) {
    getInferenceContext().gemm(c, a, b, alpha, beta);
}

inline void rearrange(std::shared_ptr<Tensor> dst, std::shared_ptr<Tensor> src) {
    getInferenceContext().rearrange(dst, src);
}
// RoPE using the GPT-J algorithm.
inline void rope(std::shared_ptr<Tensor> q, std::shared_ptr<Tensor> k,
                 std::shared_ptr<Tensor> pos, std::shared_ptr<Tensor> sin,
                 std::shared_ptr<Tensor> cos) {
    getInferenceContext().rope(q, k, pos, sin, cos, INFINIOP_ROPE_ALGO_GPT_J);
}

// RoPE using the GPT-NeoX algorithm.
inline void rope_v2(std::shared_ptr<Tensor> q, std::shared_ptr<Tensor> k,
                    std::shared_ptr<Tensor> pos, std::shared_ptr<Tensor> sin,
                    std::shared_ptr<Tensor> cos) {
    getInferenceContext().rope(q, k, pos, sin, cos, INFINIOP_ROPE_ALGO_GPT_NEOX);
}
inline void causalSoftmax(std::shared_ptr<Tensor> y, std::shared_ptr<Tensor> x) {
    getInferenceContext().causalSoftmax(y, x);
}

inline void logSoftmax(std::shared_ptr<Tensor> y, std::shared_ptr<Tensor> x) {
    getInferenceContext().logSoftmax(y, x);
}
inline void topkrouter(std::shared_ptr<Tensor> values,          // F32
                       std::shared_ptr<Tensor> indices,         // I32
                       std::shared_ptr<Tensor> x,
                       std::shared_ptr<Tensor> correction_bias, // F32
                       float routed_scaling_factor,
                       size_t topk) {
    getInferenceContext().topkrouter(values, indices, x, correction_bias,
                                     routed_scaling_factor, topk);
}
inline void swiglu(std::shared_ptr<Tensor> out, std::shared_ptr<Tensor> up,
                   std::shared_ptr<Tensor> gate) {
    getInferenceContext().swiglu(out, up, gate);
}

inline void randomSample(std::shared_ptr<Tensor> out, std::shared_ptr<Tensor> prob,
                         float random_val, float top_p, uint32_t top_k, float temperature) {
    getInferenceContext().randomSample(out, prob, random_val, top_p, top_k, temperature);
}

inline void linear(std::shared_ptr<Tensor> c, std::shared_ptr<Tensor> a,
                   std::shared_ptr<Tensor> b, float alpha, float beta,
                   std::shared_ptr<Tensor> residual, std::shared_ptr<Tensor> bias) {
    getInferenceContext().linear(c, a, b, alpha, beta, residual, bias);
}
// Dequantize the packed weight (w_w/w_s/w_z) into a temporary buffer sized
// from the activation and output shapes, then run the fused linear.
inline void dequant_linear(std::shared_ptr<Tensor> out, std::shared_ptr<Tensor> x,
                           std::shared_ptr<Tensor> w_w, std::shared_ptr<Tensor> w_s, std::shared_ptr<Tensor> w_z,
                           float alpha, float beta, std::shared_ptr<Tensor> residual, std::shared_ptr<Tensor> bias) {
    auto w = Tensor::buffer(x->dtype(), {x->shape()[1], out->shape()[1]}, getInferenceContext().memory_pool);
    getInferenceContext().dequant(w, w_w, w_s, w_z);
    getInferenceContext().linear(out, x, w, alpha, beta, residual, bias);
}
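// Sketch of one MLP step composed from these wrappers (tensor names and the
// residual convention are illustrative assumptions, not taken from this codebase):
//
//   rmsnorm(normed, hidden, norm_w, 1e-6f);
//   gemm(gate, normed, w_gate, 1.0f, 0.0f);
//   gemm(up, normed, w_up, 1.0f, 0.0f);
//   swiglu(act, up, gate);
//   linear(hidden, act, w_down, 1.0f, 0.0f, hidden, nullptr); // add residual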