1+ package io .graversen .replicate .llama3 ;
2+
3+ import io .graversen .replicate .common .BasePredictionMapper ;
4+ import io .graversen .replicate .common .PredictionTypes ;
5+ import io .graversen .replicate .common .ReplicateModel ;
6+ import io .graversen .replicate .models .MetaMetaLlama370BInstructInput ;
7+ import io .graversen .replicate .models .MetaMetaLlama370BInstructPredictionrequest ;
8+ import io .graversen .replicate .models .MetaMetaLlama38BInstructInput ;
9+ import io .graversen .replicate .models .MetaMetaLlama38BInstructPredictionrequest ;
10+ import io .graversen .replicate .service .CreateTextPrediction ;
11+ import lombok .NonNull ;
12+ import org .springframework .stereotype .Component ;
13+
14+ import java .util .Set ;
15+ import java .util .function .Function ;
16+
17+ @ Component
18+ public class Llama3PredictionMapper extends BasePredictionMapper <CreateTextPrediction , Object > {
19+ @ Override
20+ protected Set <ReplicateModel > supportedModels () {
21+ return Set .of (
22+ Llama3Models .LLAMA_3_70B_INSTRUCT ,
23+ Llama3Models .LLAMA_3_8B_INSTRUCT
24+ );
25+ }
26+
27+ @ Override
28+ public Object apply (@ NonNull ReplicateModel model , @ NonNull CreateTextPrediction createPrediction ) {
29+ if (model .equals (Llama3Models .LLAMA_3_70B_INSTRUCT )) {
30+ return new MetaMetaLlama370BInstructPredictionrequest (
31+ null ,
32+ mapLlama370bInput ().apply (createPrediction ),
33+ null ,
34+ null ,
35+ null ,
36+ null
37+ );
38+ } else if (model .equals (Llama3Models .LLAMA_3_8B_INSTRUCT )) {
39+ return new MetaMetaLlama38BInstructPredictionrequest (
40+ null ,
41+ mapLlama38bInput ().apply (createPrediction ),
42+ null ,
43+ null ,
44+ null ,
45+ null
46+ );
47+ } else {
48+ throw new IllegalArgumentException ("Unsupported Replicate Model: " + model );
49+ }
50+ }
51+
52+ @ Override
53+ public boolean supportsType (@ NonNull PredictionTypes type ) {
54+ return PredictionTypes .TEXT .equals (type );
55+ }
56+
57+ Function <CreateTextPrediction , MetaMetaLlama370BInstructInput > mapLlama370bInput () {
58+ return createTextPrediction -> {
59+ final var conversation = createTextPrediction .getConversation ();
60+ final var truncatedConversation = Llama3Tokenizer .fitToContextWindow (conversation , Llama3Tokenizer .DEFAULT_CONTEXT_WINDOW_SIZE );
61+ final var textCompletion = Llama3Tokenizer .generateTextCompletion (truncatedConversation );
62+
63+ return new MetaMetaLlama370BInstructInput (
64+ createTextPrediction .getTopK (),
65+ createTextPrediction .getTopP (),
66+ textCompletion .getText (),
67+ createTextPrediction .getMaxTokens (),
68+ createTextPrediction .getMinTokens (),
69+ createTextPrediction .getTemperature (),
70+ createTextPrediction .getPromptTemplate (),
71+ null ,
72+ null
73+ );
74+ };
75+ }
76+
77+ Function <CreateTextPrediction , MetaMetaLlama38BInstructInput > mapLlama38bInput () {
78+ return createTextPrediction -> {
79+ final var conversation = createTextPrediction .getConversation ();
80+ final var truncatedConversation = Llama3Tokenizer .fitToContextWindow (conversation , Llama3Tokenizer .DEFAULT_CONTEXT_WINDOW_SIZE );
81+ final var textCompletion = Llama3Tokenizer .generateTextCompletion (truncatedConversation );
82+
83+ return new MetaMetaLlama38BInstructInput (
84+ createTextPrediction .getTopK (),
85+ createTextPrediction .getTopP (),
86+ textCompletion .getText (),
87+ createTextPrediction .getMaxTokens (),
88+ createTextPrediction .getMinTokens (),
89+ createTextPrediction .getTemperature (),
90+ createTextPrediction .getPromptTemplate (),
91+ null ,
92+ null
93+ );
94+ };
95+ }
96+ }
// NOTE(review): removed non-code artifact ("0 commit comments") left over from a web-page scrape; it followed the class's closing brace and broke compilation.