Skip to content

Commit dcf5ad3

Browse files
committed
✨ Synchronous prediction integration with Replicate
1 parent 059f323 commit dcf5ad3

7 files changed

Lines changed: 295 additions & 3 deletions

File tree

replicate-client/src/main/java/io/graversen/replicate/client/configuration/ReplicateClients.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ public static Replicate v1(@NonNull ReplicateClientProperties properties) {
3131
final UnaryOperator<Feign.Builder> feignCustomizer = feign -> feign
3232
.encoder(new JacksonEncoder(objectMapper()))
3333
.decoder(new JacksonDecoder())
34-
.logLevel(Logger.Level.FULL)
34+
.logLevel(Logger.Level.BASIC)
3535
.logger(new Slf4jDebugLogger());
3636

3737
return ReplicateClients.v1(properties, feignCustomizer);
@@ -56,7 +56,6 @@ public static ObjectMapper objectMapper() {
5656
}
5757

5858
private static Feign.Builder feignBuilder() {
59-
return Feign.builder()
60-
.client(new Http2Client());
59+
return Feign.builder().client(new Http2Client());
6160
}
6261
}
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
package io.graversen.replicate.service;
2+
3+
import io.graversen.replicate.common.TextConversation;
4+
import lombok.Getter;
5+
import lombok.NonNull;
6+
import lombok.RequiredArgsConstructor;
7+
8+
@Getter
9+
@RequiredArgsConstructor
10+
public class CreateTextPrediction {
11+
private final @NonNull TextConversation conversation;
12+
private final Double temperature;
13+
private final Integer maxTokens;
14+
private final Integer minTokens;
15+
private final String promptTemplate;
16+
private final Double topP;
17+
private final Integer topK;
18+
19+
public static CreateTextPrediction fromOneMessage(@NonNull String systemPrompt, @NonNull String userMessage) {
20+
return CreateTextPrediction.fromConversation(TextConversation.of(systemPrompt, userMessage));
21+
}
22+
23+
public static CreateTextPrediction fromConversation(@NonNull TextConversation conversation) {
24+
return new CreateTextPrediction(conversation, null, null, null, null, null, null);
25+
}
26+
}
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
package io.graversen.replicate.service;
2+
3+
import lombok.NonNull;
4+
import lombok.RequiredArgsConstructor;
5+
import lombok.ToString;
6+
import lombok.Value;
7+
8+
import java.time.Duration;
9+
import java.time.OffsetDateTime;
10+
import java.util.ArrayList;
11+
import java.util.List;
12+
import java.util.Optional;
13+
import java.util.function.Function;
14+
import java.util.stream.Collectors;
15+
16+
@Value
17+
@RequiredArgsConstructor
18+
public class PredictionResponse {
19+
@NonNull String id;
20+
String version;
21+
OffsetDateTime createdAt;
22+
OffsetDateTime startedAt;
23+
OffsetDateTime completedAt;
24+
String error;
25+
String status;
26+
Object output;
27+
PredictionUrls urls;
28+
29+
@ToString.Include
30+
public Optional<Duration> getQueueLatency() {
31+
if (createdAt == null || startedAt == null) {
32+
return Optional.empty();
33+
}
34+
35+
return Optional.ofNullable(Duration.between(createdAt, startedAt));
36+
}
37+
38+
@ToString.Include
39+
public Optional<Duration> getProcessingLatency() {
40+
if (startedAt == null || completedAt == null) {
41+
return Optional.empty();
42+
}
43+
44+
return Optional.ofNullable(Duration.between(startedAt, completedAt));
45+
}
46+
47+
public <T> Optional<T> getOutput(@NonNull Function<Object, T> outputMapper) {
48+
try {
49+
if (output != null) {
50+
return Optional.of(outputMapper.apply(output));
51+
} else {
52+
return Optional.empty();
53+
}
54+
} catch (Exception e) {
55+
return Optional.empty();
56+
}
57+
}
58+
59+
public <T> Optional<T> getOutput(@NonNull Class<T> outputType) {
60+
try {
61+
return Optional.ofNullable(outputType.cast(output));
62+
} catch (Exception e) {
63+
return Optional.empty();
64+
}
65+
}
66+
67+
public Optional<String> getTextOutput() {
68+
try {
69+
return getOutput(output -> (ArrayList<String>) output).map(composeTextResponse());
70+
} catch (Exception e) {
71+
return Optional.empty();
72+
}
73+
}
74+
75+
private Function<List<String>, String> composeTextResponse() {
76+
return strings -> strings.stream().filter(string -> !string.isBlank()).collect(Collectors.joining());
77+
}
78+
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
package io.graversen.replicate.service;
2+
3+
import lombok.NonNull;
4+
import lombok.RequiredArgsConstructor;
5+
import lombok.Value;
6+
7+
@Value
8+
@RequiredArgsConstructor
9+
public class PredictionUrls {
10+
@NonNull String cancelUrl;
11+
@NonNull String getUrl;
12+
}
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
package io.graversen.replicate.service;
2+
3+
import io.graversen.replicate.client.feign.FeignUtils;
4+
import io.graversen.replicate.client.replicate.Replicate;
5+
import io.graversen.replicate.common.PredictionMapper;
6+
import io.graversen.replicate.common.PredictionTypes;
7+
import io.graversen.replicate.common.ReplicateModel;
8+
import io.graversen.replicate.util.ReplicateUtils;
9+
import lombok.NonNull;
10+
import lombok.RequiredArgsConstructor;
11+
import lombok.SneakyThrows;
12+
import lombok.extern.slf4j.Slf4j;
13+
import org.springframework.stereotype.Component;
14+
15+
import java.util.LinkedHashMap;
16+
import java.util.List;
17+
import java.util.Optional;
18+
import java.util.function.Predicate;
19+
import java.util.function.Supplier;
20+
21+
@Slf4j
22+
@Component
23+
@RequiredArgsConstructor
24+
public class ReplicateService {
25+
private final @NonNull Replicate replicate;
26+
private final @NonNull FeignUtils feignUtils;
27+
private final @NonNull List<PredictionMapper> predictionMappers;
28+
29+
public Optional<PredictionResponse> createPrediction(@NonNull ReplicateModel model, @NonNull Object createPrediction) {
30+
return doCreatePrediction(model, createPrediction);
31+
}
32+
33+
@SneakyThrows
34+
public Optional<PredictionResponse> createPrediction(@NonNull ReplicateModel model, @NonNull CreateTextPrediction createPrediction) {
35+
final var predictionMapper = predictionMappers.stream()
36+
.filter(supportsTextPredictions())
37+
.filter(supportsModel(model))
38+
.findFirst()
39+
.orElseThrow(unsupportedModelError(model));
40+
41+
final var mappedPrediction = predictionMapper.apply(model, createPrediction);
42+
return doCreatePrediction(model, mappedPrediction);
43+
}
44+
45+
Predicate<PredictionMapper> supportsTextPredictions() {
46+
return predictionMapper -> predictionMapper.supportsType(PredictionTypes.TEXT);
47+
}
48+
49+
Predicate<PredictionMapper> supportsModel(@NonNull ReplicateModel model) {
50+
return predictionMapper -> predictionMapper.supportsModel(model);
51+
}
52+
53+
private Optional<PredictionResponse> doCreatePrediction(@NonNull ReplicateModel model, @NonNull Object createPrediction) {
54+
final var replicateResponse = replicate.createPrediction(model.getOwner(), model.getName(), createPrediction);
55+
56+
if (replicateResponse.status() >= 200 && replicateResponse.status() < 300) {
57+
final LinkedHashMap<String, Object> convertedResponse = feignUtils.convert(replicateResponse, LinkedHashMap.class);
58+
return ReplicateUtils.mapPredictionResponse(convertedResponse);
59+
} else {
60+
final LinkedHashMap<String, Object> convertedResponse = feignUtils.convert(replicateResponse, LinkedHashMap.class);
61+
final var errorMessage = String.format(
62+
"Replicate API error: %s (HTTP Status %s). Model: %s",
63+
convertedResponse.get("detail"),
64+
convertedResponse.get("status"),
65+
model
66+
);
67+
log.error(errorMessage);
68+
throw new RuntimeException(errorMessage);
69+
}
70+
}
71+
72+
private Supplier<Throwable> unsupportedModelError(@NonNull ReplicateModel model) {
73+
return () -> new IllegalArgumentException("Unsupported model: " + model);
74+
}
75+
}
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
package io.graversen.replicate.util;
2+
3+
import lombok.NonNull;
4+
import lombok.experimental.UtilityClass;
5+
import lombok.extern.slf4j.Slf4j;
6+
7+
import java.util.Map;
8+
9+
@Slf4j
10+
@UtilityClass
11+
public class JsonSchemaUtils {
12+
private static final String SET_ADDITIONAL_PROPERTY_METHOD = "setAdditionalProperty";
13+
14+
public static void setAdditionalProperties(@NonNull Object object, @NonNull Map<String, Object> keysAndValues) {
15+
keysAndValues.forEach((key, value) -> setAdditionalProperties(object, key, value));
16+
}
17+
18+
public static void setAdditionalProperties(@NonNull Object object, @NonNull String key, @NonNull Object value) {
19+
try {
20+
final var method = object.getClass().getMethod(SET_ADDITIONAL_PROPERTY_METHOD, String.class, Object.class);
21+
method.invoke(object, key, value);
22+
} catch (NoSuchMethodException e) {
23+
log.error("Could not invoke {}#{} because there is no such method", object.getClass().getSimpleName(), SET_ADDITIONAL_PROPERTY_METHOD);
24+
} catch (Exception e) {
25+
log.error(e.getMessage(), e);
26+
}
27+
}
28+
}
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
package io.graversen.replicate.util;
2+
3+
import io.graversen.replicate.service.PredictionResponse;
4+
import io.graversen.replicate.service.PredictionUrls;
5+
import jakarta.annotation.Nullable;
6+
import lombok.NonNull;
7+
import lombok.SneakyThrows;
8+
import lombok.experimental.UtilityClass;
9+
import lombok.extern.slf4j.Slf4j;
10+
11+
import java.time.OffsetDateTime;
12+
import java.time.format.DateTimeFormatter;
13+
import java.util.LinkedHashMap;
14+
import java.util.Optional;
15+
16+
@Slf4j
17+
@UtilityClass
18+
public class ReplicateUtils {
19+
private static final String ATTRIBUTE_ID = "id";
20+
private static final String ATTRIBUTE_VERSION = "version";
21+
private static final String ATTRIBUTE_CREATED_AT = "created_at";
22+
private static final String ATTRIBUTE_STARTED_AT = "started_at";
23+
private static final String ATTRIBUTE_COMPLETED_AT = "completed_at";
24+
private static final String ATTRIBUTE_ERROR = "error";
25+
private static final String ATTRIBUTE_STATUS = "status";
26+
private static final String ATTRIBUTE_OUTPUT = "output";
27+
private static final String ATTRIBUTE_URLS = "urls";
28+
private static final String ATTRIBUTE_CANCEL_URL = "cancel";
29+
private static final String ATTRIBUTE_GET_URL = "get";
30+
31+
public static Optional<PredictionResponse> mapPredictionResponse(@NonNull LinkedHashMap<String, Object> responseMap) {
32+
try {
33+
final var id = (String) responseMap.get(ATTRIBUTE_ID);
34+
final var version = (String) responseMap.get(ATTRIBUTE_VERSION);
35+
final var createdAt = parseOffsetDateTime((String) responseMap.get(ATTRIBUTE_CREATED_AT));
36+
final var startedAt = parseOffsetDateTime((String) responseMap.get(ATTRIBUTE_STARTED_AT));
37+
final var completedAt = parseOffsetDateTime((String) responseMap.get(ATTRIBUTE_COMPLETED_AT));
38+
final var error = (String) responseMap.get(ATTRIBUTE_ERROR);
39+
final var status = (String) responseMap.get(ATTRIBUTE_STATUS);
40+
final var output = responseMap.get(ATTRIBUTE_OUTPUT);
41+
42+
final var urlsMap = (LinkedHashMap<String, Object>) responseMap.get(ATTRIBUTE_URLS);
43+
final var cancelUrl = (String) urlsMap.get(ATTRIBUTE_CANCEL_URL);
44+
final var getUrl = (String) urlsMap.get(ATTRIBUTE_GET_URL);
45+
46+
PredictionUrls urls = null;
47+
if (cancelUrl != null && getUrl != null) {
48+
urls = new PredictionUrls(cancelUrl, getUrl);
49+
}
50+
51+
final var predictionResponse = new PredictionResponse(id, version, createdAt, startedAt, completedAt, error, status, output, urls);
52+
return Optional.of(predictionResponse);
53+
} catch (Exception e) {
54+
log.error(e.getMessage(), e);
55+
return Optional.empty();
56+
}
57+
}
58+
59+
@SneakyThrows
60+
private static String getStringField(@NonNull Class<?> responseObjectClass, @NonNull Object instance, @NonNull String methodName) {
61+
final var getter = responseObjectClass.getMethod(methodName);
62+
return getter.invoke(instance) != null ? (String) getter.invoke(instance) : null;
63+
}
64+
65+
@SneakyThrows
66+
private static OffsetDateTime getOffsetDateTimeField(@NonNull Class<?> responseObjectClass, @NonNull Object instance, @NonNull String methodName) {
67+
final var getter = responseObjectClass.getMethod(methodName);
68+
return getter.invoke(instance) != null ? parseOffsetDateTime((String) getter.invoke(instance)) : null;
69+
}
70+
71+
private static OffsetDateTime parseOffsetDateTime(@Nullable String isoDateTime) {
72+
return isoDateTime != null ? OffsetDateTime.parse(isoDateTime, DateTimeFormatter.ISO_OFFSET_DATE_TIME) : null;
73+
}
74+
}

0 commit comments

Comments
 (0)