Skip to content

Commit 42b09ce

Browse files
committed
⚡ Llama3 input mapping support
1 parent dcf5ad3 commit 42b09ce

2 files changed

Lines changed: 106 additions & 0 deletions

File tree

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
package io.graversen.replicate.configuration;
2+
3+
import org.springframework.context.annotation.ComponentScan;
4+
import org.springframework.context.annotation.Configuration;
5+
6+
@Configuration
7+
@ComponentScan("io.graversen.replicate.llama3")
8+
public class Llama3Configuration {
9+
10+
}
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
package io.graversen.replicate.llama3;
2+
3+
import io.graversen.replicate.common.BasePredictionMapper;
4+
import io.graversen.replicate.common.PredictionTypes;
5+
import io.graversen.replicate.common.ReplicateModel;
6+
import io.graversen.replicate.models.MetaMetaLlama370BInstructInput;
7+
import io.graversen.replicate.models.MetaMetaLlama370BInstructPredictionrequest;
8+
import io.graversen.replicate.models.MetaMetaLlama38BInstructInput;
9+
import io.graversen.replicate.models.MetaMetaLlama38BInstructPredictionrequest;
10+
import io.graversen.replicate.service.CreateTextPrediction;
11+
import lombok.NonNull;
12+
import org.springframework.stereotype.Component;
13+
14+
import java.util.Set;
15+
import java.util.function.Function;
16+
17+
@Component
18+
public class Llama3PredictionMapper extends BasePredictionMapper<CreateTextPrediction, Object> {
19+
@Override
20+
protected Set<ReplicateModel> supportedModels() {
21+
return Set.of(
22+
Llama3Models.LLAMA_3_70B_INSTRUCT,
23+
Llama3Models.LLAMA_3_8B_INSTRUCT
24+
);
25+
}
26+
27+
@Override
28+
public Object apply(@NonNull ReplicateModel model, @NonNull CreateTextPrediction createPrediction) {
29+
if (model.equals(Llama3Models.LLAMA_3_70B_INSTRUCT)) {
30+
return new MetaMetaLlama370BInstructPredictionrequest(
31+
null,
32+
mapLlama370bInput().apply(createPrediction),
33+
null,
34+
null,
35+
null,
36+
null
37+
);
38+
} else if (model.equals(Llama3Models.LLAMA_3_8B_INSTRUCT)) {
39+
return new MetaMetaLlama38BInstructPredictionrequest(
40+
null,
41+
mapLlama38bInput().apply(createPrediction),
42+
null,
43+
null,
44+
null,
45+
null
46+
);
47+
} else {
48+
throw new IllegalArgumentException("Unsupported Replicate Model: " + model);
49+
}
50+
}
51+
52+
@Override
53+
public boolean supportsType(@NonNull PredictionTypes type) {
54+
return PredictionTypes.TEXT.equals(type);
55+
}
56+
57+
Function<CreateTextPrediction, MetaMetaLlama370BInstructInput> mapLlama370bInput() {
58+
return createTextPrediction -> {
59+
final var conversation = createTextPrediction.getConversation();
60+
final var truncatedConversation = Llama3Tokenizer.fitToContextWindow(conversation, Llama3Tokenizer.DEFAULT_CONTEXT_WINDOW_SIZE);
61+
final var textCompletion = Llama3Tokenizer.generateTextCompletion(truncatedConversation);
62+
63+
return new MetaMetaLlama370BInstructInput(
64+
createTextPrediction.getTopK(),
65+
createTextPrediction.getTopP(),
66+
textCompletion.getText(),
67+
createTextPrediction.getMaxTokens(),
68+
createTextPrediction.getMinTokens(),
69+
createTextPrediction.getTemperature(),
70+
createTextPrediction.getPromptTemplate(),
71+
null,
72+
null
73+
);
74+
};
75+
}
76+
77+
Function<CreateTextPrediction, MetaMetaLlama38BInstructInput> mapLlama38bInput() {
78+
return createTextPrediction -> {
79+
final var conversation = createTextPrediction.getConversation();
80+
final var truncatedConversation = Llama3Tokenizer.fitToContextWindow(conversation, Llama3Tokenizer.DEFAULT_CONTEXT_WINDOW_SIZE);
81+
final var textCompletion = Llama3Tokenizer.generateTextCompletion(truncatedConversation);
82+
83+
return new MetaMetaLlama38BInstructInput(
84+
createTextPrediction.getTopK(),
85+
createTextPrediction.getTopP(),
86+
textCompletion.getText(),
87+
createTextPrediction.getMaxTokens(),
88+
createTextPrediction.getMinTokens(),
89+
createTextPrediction.getTemperature(),
90+
createTextPrediction.getPromptTemplate(),
91+
null,
92+
null
93+
);
94+
};
95+
}
96+
}

0 commit comments

Comments
 (0)