Skip to content

Commit 3c82818

Browse files
committed
⚡ Support for model meta-llama-3.1-405b-instruct
1 parent e94771a commit 3c82818

2 files changed

Lines changed: 33 additions & 4 deletions

File tree

spring-boot-starter-replicate/src/main/java/io/graversen/replicate/llama3/Llama3Models.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
@UtilityClass
77
public class Llama3Models {
8+
public static final ReplicateModel LLAMA_3_1_405B_INSTRUCT = new ReplicateModel("meta", "meta-llama-3.1-405b-instruct");
89
public static final ReplicateModel LLAMA_3_70B_INSTRUCT = new ReplicateModel("meta", "meta-llama-3-70b-instruct");
910
public static final ReplicateModel LLAMA_3_8B_INSTRUCT = new ReplicateModel("meta", "meta-llama-3-8b-instruct");
1011
}

spring-boot-starter-replicate/src/main/java/io/graversen/replicate/llama3/Llama3PredictionMapper.java

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,7 @@
33
import io.graversen.replicate.common.BasePredictionMapper;
44
import io.graversen.replicate.common.PredictionTypes;
55
import io.graversen.replicate.common.ReplicateModel;
6-
import io.graversen.replicate.models.MetaMetaLlama370BInstructInput;
7-
import io.graversen.replicate.models.MetaMetaLlama370BInstructPredictionrequest;
8-
import io.graversen.replicate.models.MetaMetaLlama38BInstructInput;
9-
import io.graversen.replicate.models.MetaMetaLlama38BInstructPredictionrequest;
6+
import io.graversen.replicate.models.*;
107
import io.graversen.replicate.service.CreateTextPrediction;
118
import lombok.NonNull;
129
import org.springframework.stereotype.Component;
@@ -19,6 +16,7 @@ public class Llama3PredictionMapper extends BasePredictionMapper<CreateTextPredi
1916
@Override
2017
protected Set<ReplicateModel> supportedModels() {
2118
return Set.of(
19+
Llama3Models.LLAMA_3_1_405B_INSTRUCT,
2220
Llama3Models.LLAMA_3_70B_INSTRUCT,
2321
Llama3Models.LLAMA_3_8B_INSTRUCT
2422
);
@@ -44,6 +42,15 @@ public Object apply(@NonNull ReplicateModel model, @NonNull CreateTextPrediction
4442
null,
4543
null
4644
);
45+
} else if (model.equals(Llama3Models.LLAMA_3_1_405B_INSTRUCT)) {
46+
return new MetaMetaLlama31405BInstructPredictionrequest(
47+
null,
48+
mapLlama31405bInput().apply(createPrediction),
49+
null,
50+
null,
51+
null,
52+
null
53+
);
4754
} else {
4855
throw new IllegalArgumentException("Unsupported Replicate Model: " + model);
4956
}
@@ -54,6 +61,27 @@ public boolean supportsType(@NonNull PredictionTypes type) {
5461
return PredictionTypes.TEXT.equals(type);
5562
}
5663

64+
Function<CreateTextPrediction, MetaMetaLlama31405BInstructInput> mapLlama31405bInput() {
65+
return createTextPrediction -> {
66+
final var conversation = createTextPrediction.getConversation();
67+
final var truncatedConversation = Llama3Tokenizer.fitToContextWindow(conversation, Llama3Tokenizer.DEFAULT_CONTEXT_WINDOW_SIZE);
68+
final var textCompletion = Llama3Tokenizer.generateTextCompletion(truncatedConversation);
69+
70+
return new MetaMetaLlama31405BInstructInput(
71+
createTextPrediction.getTopK(),
72+
createTextPrediction.getTopP(),
73+
textCompletion.getText(),
74+
createTextPrediction.getMaxTokens(),
75+
createTextPrediction.getMinTokens(),
76+
createTextPrediction.getTemperature(),
77+
truncatedConversation.getSystemMessage(),
78+
null,
79+
null,
80+
null
81+
);
82+
};
83+
}
84+
5785
Function<CreateTextPrediction, MetaMetaLlama370BInstructInput> mapLlama370bInput() {
5886
return createTextPrediction -> {
5987
final var conversation = createTextPrediction.getConversation();

0 commit comments

Comments
 (0)