33import io .graversen .replicate .common .BasePredictionMapper ;
44import io .graversen .replicate .common .PredictionTypes ;
55import io .graversen .replicate .common .ReplicateModel ;
6- import io .graversen .replicate .models .MetaMetaLlama370BInstructInput ;
7- import io .graversen .replicate .models .MetaMetaLlama370BInstructPredictionrequest ;
8- import io .graversen .replicate .models .MetaMetaLlama38BInstructInput ;
9- import io .graversen .replicate .models .MetaMetaLlama38BInstructPredictionrequest ;
6+ import io .graversen .replicate .models .*;
107import io .graversen .replicate .service .CreateTextPrediction ;
118import lombok .NonNull ;
129import org .springframework .stereotype .Component ;
@@ -19,6 +16,7 @@ public class Llama3PredictionMapper extends BasePredictionMapper<CreateTextPredi
1916 @ Override
2017 protected Set <ReplicateModel > supportedModels () {
2118 return Set .of (
19+ Llama3Models .LLAMA_3_1_405B_INSTRUCT ,
2220 Llama3Models .LLAMA_3_70B_INSTRUCT ,
2321 Llama3Models .LLAMA_3_8B_INSTRUCT
2422 );
@@ -44,6 +42,15 @@ public Object apply(@NonNull ReplicateModel model, @NonNull CreateTextPrediction
4442 null ,
4543 null
4644 );
45+ } else if (model .equals (Llama3Models .LLAMA_3_1_405B_INSTRUCT )) {
46+ return new MetaMetaLlama31405BInstructPredictionrequest (
47+ null ,
48+ mapLlama31405bInput ().apply (createPrediction ),
49+ null ,
50+ null ,
51+ null ,
52+ null
53+ );
4754 } else {
4855 throw new IllegalArgumentException ("Unsupported Replicate Model: " + model );
4956 }
@@ -54,6 +61,27 @@ public boolean supportsType(@NonNull PredictionTypes type) {
5461 return PredictionTypes .TEXT .equals (type );
5562 }
5663
64+ Function <CreateTextPrediction , MetaMetaLlama31405BInstructInput > mapLlama31405bInput () {
65+ return createTextPrediction -> {
66+ final var conversation = createTextPrediction .getConversation ();
67+ final var truncatedConversation = Llama3Tokenizer .fitToContextWindow (conversation , Llama3Tokenizer .DEFAULT_CONTEXT_WINDOW_SIZE );
68+ final var textCompletion = Llama3Tokenizer .generateTextCompletion (truncatedConversation );
69+
70+ return new MetaMetaLlama31405BInstructInput (
71+ createTextPrediction .getTopK (),
72+ createTextPrediction .getTopP (),
73+ textCompletion .getText (),
74+ createTextPrediction .getMaxTokens (),
75+ createTextPrediction .getMinTokens (),
76+ createTextPrediction .getTemperature (),
77+ truncatedConversation .getSystemMessage (),
78+ null ,
79+ null ,
80+ null
81+ );
82+ };
83+ }
84+
5785 Function <CreateTextPrediction , MetaMetaLlama370BInstructInput > mapLlama370bInput () {
5886 return createTextPrediction -> {
5987 final var conversation = createTextPrediction .getConversation ();
0 commit comments