@@ -23,7 +23,8 @@ void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
2323 }
2424
2525 if (ubatch->embd ) {
26- const int64_t n_embd = embd->ne [0 ];
26+ GGML_ASSERT (n_embd == embd->ne [0 ]);
27+
2728 const int64_t n_tokens = ubatch->n_tokens ;
2829
2930 ggml_backend_tensor_set (embd, ubatch->embd , 0 , n_tokens*n_embd*ggml_element_size (embd));
@@ -33,8 +34,8 @@ void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
3334bool llm_graph_input_embd::can_reuse (const llm_graph_params & params) {
3435 bool res = true ;
3536
36- res &= (!tokens && ! params.ubatch .token ) || (tokens && tokens->ne [0 ] == params.ubatch .n_tokens );
37- res &= (!embd && ! params.ubatch .embd ) || (embd && embd->ne [1 ] == params.ubatch .n_tokens );
37+ res &= (!params.ubatch .token ) || (tokens && tokens->ne [0 ] == params.ubatch .n_tokens );
38+ res &= (!params.ubatch .embd ) || (embd && embd->ne [1 ] == params.ubatch .n_tokens );
3839
3940 return res;
4041}
@@ -634,7 +635,8 @@ int64_t llm_graph_result::get_max_nodes() const {
634635}
635636
636637void llm_graph_result::reset () {
637- t_tokens = nullptr ;
638+ t_inp_tokens = nullptr ;
639+ t_inp_embd = nullptr ;
638640 t_logits = nullptr ;
639641 t_embd = nullptr ;
640642 t_embd_pooled = nullptr ;
@@ -1338,17 +1340,29 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
13381340
13391341// input embeddings with optional lora
13401342ggml_tensor * llm_graph_context::build_inp_embd (ggml_tensor * tok_embd) const {
1341- const int64_t n_embd = hparams.n_embd_inp ();
1343+ const int64_t n_embd_inp = hparams.n_embd_inp ();
1344+ const int64_t n_embd = hparams.n_embd ;
1345+
1346+ assert (n_embd_inp >= n_embd);
1347+
1348+ auto inp = std::make_unique<llm_graph_input_embd>(n_embd_inp);
1349+
1350+ inp->tokens = ggml_new_tensor_1d (ctx0, GGML_TYPE_I32, ubatch.n_tokens );
1351+ cb (inp->tokens , " inp_tokens" , -1 );
1352+ ggml_set_input (inp->tokens );
1353+ res->t_inp_tokens = inp->tokens ;
13421354
1343- auto inp = std::make_unique<llm_graph_input_embd>();
1355+ inp->embd = ggml_new_tensor_2d (ctx0, GGML_TYPE_F32, n_embd_inp, ubatch.n_tokens );
1356+ cb (inp->embd , " inp_embd" , -1 );
1357+ ggml_set_input (inp->embd );
13441358
1345- ggml_tensor * cur = nullptr ;
1359+ // select one of the 2 inputs, based on the batch contents
1360+ // ref: https://github.com/ggml-org/llama.cpp/pull/18550
1361+ std::array<ggml_tensor *, 2 > inps;
13461362
1347- if (ubatch.token ) {
1348- inp->tokens = ggml_new_tensor_1d (ctx0, GGML_TYPE_I32, ubatch.n_tokens );
1349- // cb(inp->tokens, "inp_tokens", -1);
1350- ggml_set_input (inp->tokens );
1351- res->t_tokens = inp->tokens ;
1363+ // token embeddings path (ubatch.token != nullptr)
1364+ {
1365+ auto & cur = inps[0 ];
13521366
13531367 cur = ggml_get_rows (ctx0, tok_embd, inp->tokens );
13541368
@@ -1369,19 +1383,36 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
13691383
13701384 cur = ggml_add (ctx0, cur, inpL_delta);
13711385 }
1372- } else {
1373- inp->embd = ggml_new_tensor_2d (ctx0, GGML_TYPE_F32, n_embd, ubatch.n_tokens );
1374- ggml_set_input (inp->embd );
1386+
1387+ if (n_embd_inp != n_embd) {
1388+ cur = ggml_pad (ctx0, cur, hparams.n_embd_inp () - n_embd, 0 , 0 , 0 );
1389+ }
1390+ }
1391+
1392+ // vector embeddings path (ubatch.embd != nullptr)
1393+ {
1394+ auto & cur = inps[1 ];
13751395
13761396 cur = inp->embd ;
13771397 }
13781398
1399+ assert (ggml_are_same_shape (inps[0 ], inps[1 ]));
1400+ assert (ggml_are_same_stride (inps[0 ], inps[1 ]));
1401+
1402+ ggml_tensor * cur = ggml_build_forward_select (gf, inps.data (), inps.size (), ubatch.token ? 0 : 1 );
1403+
1404+ if (n_embd_inp != n_embd) {
1405+ cur = ggml_view_2d (ctx0, cur, n_embd, n_tokens, cur->nb [1 ], 0 );
1406+ }
1407+
1408+ res->t_inp_embd = cur;
1409+
13791410 // For Granite architecture
13801411 if (hparams.f_embedding_scale != 0 .0f ) {
13811412 cur = ggml_scale (ctx0, cur, hparams.f_embedding_scale );
13821413 }
13831414
1384- cb (cur, " inp_embd " , -1 );
1415+ cb (cur, " embd " , -1 );
13851416
13861417 res->add_input (std::move (inp));
13871418
@@ -1480,7 +1511,7 @@ ggml_tensor * llm_graph_context::build_inp_cross_embd() const {
14801511 // }
14811512
14821513 const auto n_embd = !cross->v_embd .empty () ? cross->n_embd : hparams.n_embd_inp ();
1483- const auto n_enc = !cross->v_embd .empty () ? cross->n_enc : hparams.n_ctx_train ;
1514+ const auto n_enc = !cross->v_embd .empty () ? cross->n_enc : hparams.n_ctx_train ;
14841515
14851516 cur = ggml_new_tensor_2d (ctx0, GGML_TYPE_F32, n_embd, n_enc);
14861517 ggml_set_input (cur);
0 commit comments