Skip to content

Commit efa769a

Browse files
committed
✨ Async Replicate prediction generation
With event-driven notifications
1 parent 42b09ce commit efa769a

16 files changed

Lines changed: 405 additions & 0 deletions
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
package io.graversen.replicate.common;
2+
3+
import lombok.NonNull;
4+
5+
import java.util.Objects;
6+
import java.util.Set;
7+
8+
public abstract class BasePredictionMapper<T, R> implements PredictionMapper<T, R> {
9+
protected abstract Set<ReplicateModel> supportedModels();
10+
11+
@Override
12+
public boolean supportsModel(@NonNull ReplicateModel model) {
13+
return Objects.requireNonNullElseGet(supportedModels(), Set::of).stream().anyMatch(model::equals);
14+
}
15+
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
package io.graversen.replicate.common;
2+
3+
import lombok.NonNull;
4+
5+
import java.util.function.BiFunction;
6+
7+
public interface PredictionMapper<T, R> extends BiFunction<ReplicateModel, T, R> {
8+
@Override
9+
R apply(@NonNull ReplicateModel model, @NonNull T createPrediction);
10+
11+
boolean supportsType(@NonNull PredictionTypes type);
12+
13+
boolean supportsModel(@NonNull ReplicateModel model);
14+
}
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
package io.graversen.replicate.common;
2+
3+
/**
 * Kinds of prediction a {@code PredictionMapper} can be asked about via {@code supportsType}.
 */
public enum PredictionTypes {
    /** Text prediction. */
    TEXT,

    /** Image prediction. */
    IMAGE
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
package io.graversen.replicate.common;
2+
3+
import lombok.NonNull;
4+
import lombok.Value;
5+
6+
@Value
7+
public class ReplicateModel {
8+
@NonNull String owner;
9+
@NonNull String name;
10+
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
package io.graversen.replicate.configuration;
2+
3+
import org.springframework.context.annotation.Configuration;
4+
import org.springframework.context.annotation.Import;
5+
6+
@Configuration
7+
@Import(ReplicateConfiguration.class)
8+
public class ReplicateAutoConfiguration {
9+
10+
}
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
package io.graversen.replicate.configuration;
2+
3+
import io.graversen.replicate.client.configuration.ReplicateClientProperties;
4+
import io.graversen.replicate.client.configuration.ReplicateClients;
5+
import io.graversen.replicate.client.feign.FeignUtils;
6+
import io.graversen.replicate.client.replicate.Replicate;
7+
import io.graversen.replicate.facade.PredictionRetryPolicy;
8+
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
9+
import org.springframework.boot.context.properties.EnableConfigurationProperties;
10+
import org.springframework.context.annotation.Bean;
11+
import org.springframework.context.annotation.ComponentScan;
12+
import org.springframework.context.annotation.Configuration;
13+
import org.springframework.context.annotation.Import;
14+
15+
import java.util.Objects;
16+
import java.util.concurrent.ExecutorService;
17+
import java.util.concurrent.Executors;
18+
import java.util.function.Function;
19+
20+
@Configuration
21+
@EnableConfigurationProperties(ReplicateProperties.class)
22+
@Import({Llama3Configuration.class})
23+
@ComponentScan({"io.graversen.replicate.service", "io.graversen.replicate.facade"})
24+
public class ReplicateConfiguration {
25+
@Bean
26+
public Replicate replicate(ReplicateProperties properties) {
27+
final var replicateClientProperties = mapReplicateClientProperties().apply(properties);
28+
return ReplicateClients.v1(replicateClientProperties);
29+
}
30+
31+
@Bean
32+
public FeignUtils feignUtils() {
33+
return new FeignUtils(ReplicateClients.objectMapper());
34+
}
35+
36+
@Bean
37+
@ConditionalOnMissingBean
38+
public ExecutorService defaultReplicateExecutor() {
39+
return Executors.newVirtualThreadPerTaskExecutor();
40+
}
41+
42+
@Bean
43+
public PredictionRetryPolicy predictionRetryPolicy(ReplicateProperties properties) {
44+
final var defaultRetryPolicy = PredictionRetryPolicy.defaultPolicy();
45+
final var pollDelay = Objects.requireNonNullElse(properties.getPredictionPollDelay(), defaultRetryPolicy.getDelay());
46+
final var pollAttempts = Objects.requireNonNullElse(properties.getPredictionPollAttempts(), defaultRetryPolicy.getMaxAttempts());
47+
return new PredictionRetryPolicy(pollAttempts, pollDelay);
48+
}
49+
50+
Function<ReplicateProperties, ReplicateClientProperties> mapReplicateClientProperties() {
51+
return replicateProperties -> new ReplicateClientProperties(
52+
replicateProperties.getToken(),
53+
replicateProperties.getApiUrl()
54+
);
55+
}
56+
}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
package io.graversen.replicate.configuration;
2+
3+
import lombok.Getter;
4+
import lombok.NonNull;
5+
import lombok.RequiredArgsConstructor;
6+
import org.springframework.boot.context.properties.ConfigurationProperties;
7+
8+
import java.time.Duration;
9+
10+
@Getter
11+
@RequiredArgsConstructor
12+
@ConfigurationProperties(prefix = "replicate")
13+
public class ReplicateProperties {
14+
private final @NonNull String token;
15+
private final String apiUrl;
16+
private final Duration predictionPollDelay;
17+
private final Integer predictionPollAttempts;
18+
}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
package io.graversen.replicate.facade;
2+
3+
import jakarta.annotation.Nullable;
4+
import lombok.SneakyThrows;
5+
import org.springframework.stereotype.Component;
6+
7+
import java.util.Optional;
8+
import java.util.concurrent.CompletionException;
9+
import java.util.function.Function;
10+
import java.util.function.Supplier;
11+
12+
@Component
13+
public class CheckPredictionStateTask implements Function<Optional<PredictionResponseAndModel>, PredictionResponseAndModel> {
14+
@Override
15+
@SneakyThrows
16+
public PredictionResponseAndModel apply(@Nullable Optional<PredictionResponseAndModel> predictionResponse) {
17+
return predictionResponse.orElseThrow(predictionFailedError());
18+
}
19+
20+
private Supplier<Exception> predictionFailedError() {
21+
return () -> new CompletionException(new RuntimeException("Could not initiate prediction"));
22+
}
23+
}
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
package io.graversen.replicate.facade;
2+
3+
import jakarta.annotation.Nullable;
4+
import lombok.NonNull;
5+
import lombok.RequiredArgsConstructor;
6+
import org.springframework.context.ApplicationEventPublisher;
7+
import org.springframework.stereotype.Component;
8+
9+
import java.util.function.BiConsumer;
10+
import java.util.function.Function;
11+
12+
@Component
13+
@RequiredArgsConstructor
14+
public class EmitPredictionResponseTask implements BiConsumer<PredictionResponseAndModel, Throwable> {
15+
private final @NonNull ApplicationEventPublisher eventPublisher;
16+
17+
@Override
18+
public void accept(@Nullable PredictionResponseAndModel predictionResponse, @Nullable Throwable throwable) {
19+
if (throwable == null) {
20+
final var event = mapPredictionUpdatedEvent().apply(predictionResponse);
21+
eventPublisher.publishEvent(event);
22+
} else {
23+
final var event = mapPredictionFailedEvent(throwable).apply(predictionResponse);
24+
eventPublisher.publishEvent(event);
25+
}
26+
}
27+
28+
Function<PredictionResponseAndModel, PredictionUpdatedEvent> mapPredictionUpdatedEvent() {
29+
return predictionResponse -> {
30+
final var status = PredictionStatus.fromString(predictionResponse.getPredictionResponse().getStatus());
31+
return new PredictionUpdatedEvent(
32+
predictionResponse.getPredictionResponse().getId(),
33+
predictionResponse.getModel(),
34+
status,
35+
predictionResponse.getPredictionResponse()
36+
);
37+
};
38+
}
39+
40+
Function<PredictionResponseAndModel, PredictionFailedEvent> mapPredictionFailedEvent(@NonNull Throwable throwable) {
41+
return predictionResponse -> new PredictionFailedEvent(
42+
predictionResponse != null ? predictionResponse.getPredictionResponse().getId() : null,
43+
predictionResponse != null ? predictionResponse.getModel() : null,
44+
throwable
45+
);
46+
}
47+
}
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
package io.graversen.replicate.facade;
2+
3+
import io.graversen.replicate.client.feign.FeignUtils;
4+
import io.graversen.replicate.client.replicate.Replicate;
5+
import io.graversen.replicate.service.PredictionResponse;
6+
import io.graversen.replicate.util.ReplicateUtils;
7+
import jakarta.annotation.Nullable;
8+
import lombok.NonNull;
9+
import lombok.RequiredArgsConstructor;
10+
import lombok.SneakyThrows;
11+
import lombok.extern.slf4j.Slf4j;
12+
import org.springframework.stereotype.Component;
13+
14+
import java.util.LinkedHashMap;
15+
import java.util.function.Function;
16+
17+
@Slf4j
18+
@Component
19+
@RequiredArgsConstructor
20+
public class PollPredictionStatusTask implements Function<PredictionResponseAndModel, PredictionResponseAndModel> {
21+
private final @NonNull PredictionRetryPolicy retryPolicy;
22+
private final @NonNull Replicate replicate;
23+
private final @NonNull FeignUtils feignUtils;
24+
25+
@Override
26+
@SneakyThrows
27+
public PredictionResponseAndModel apply(@Nullable PredictionResponseAndModel predictionResponse) {
28+
if (predictionResponse == null) {
29+
return null;
30+
}
31+
32+
final var predictionId = predictionResponse.getPredictionResponse().getId();
33+
log.debug("[Prediction: {}]: Prediction retrieved for processing", predictionId);
34+
35+
for (int attempt = 1; attempt <= retryPolicy.getMaxAttempts(); attempt++) {
36+
log.debug("[Prediction: {}]: Attempt {} of {} to retrieve prediction status", predictionId, attempt, retryPolicy.getMaxAttempts());
37+
final var getPredictionResponse = replicate.getPrediction(predictionId);
38+
final LinkedHashMap<String, Object> convertedResponse = feignUtils.convert(getPredictionResponse, LinkedHashMap.class);
39+
final var mappedResponse = ReplicateUtils.mapPredictionResponse(convertedResponse);
40+
41+
if (mappedResponse.isPresent() && PredictionStatus.SUCCEEDED.asString().equals(mappedResponse.get().getStatus())) {
42+
log.info("[Prediction: {}]: Prediction succeeded on attempt {}/{}", predictionId, attempt, retryPolicy.getMaxAttempts());
43+
return new PredictionResponseAndModel(mappedResponse.get(), predictionResponse.getModel());
44+
}
45+
46+
if (mappedResponse.isPresent() && PredictionStatus.FAILED.asString().equals(mappedResponse.get().getStatus())) {
47+
log.error("[Prediction: {}]: Prediction failed on attempt {}/{}", predictionId, attempt, retryPolicy.getMaxAttempts());
48+
throw new IllegalStateException(String.format("Prediction '%s' failed: %s", predictionId, mappedResponse.get().getStatus()));
49+
}
50+
51+
try {
52+
log.debug("[Prediction: {}]: Prediction not yet succeeded (status: {}), retrying after {} ms", predictionId, mappedResponse.map(PredictionResponse::getStatus).orElse("unknown"), retryPolicy.getDelay().toMillis());
53+
Thread.sleep(retryPolicy.getDelay().toMillis());
54+
} catch (InterruptedException e) {
55+
Thread.currentThread().interrupt();
56+
throw new IllegalStateException("Thread was interrupted during retry delay", e);
57+
}
58+
}
59+
60+
log.error("[Prediction: {}]: Max retry attempts reached without success", predictionId);
61+
throw new IllegalStateException(String.format("Max retry attempts reached for prediction '%s' without success", predictionId));
62+
}
63+
}

0 commit comments

Comments
 (0)