Skip to content

Commit 3e7e25c

Browse files
committed
⚡ Emit PredictionCreatedEvent, response contains input object
1 parent efa769a commit 3e7e25c

6 files changed

Lines changed: 62 additions & 26 deletions

File tree

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
package io.graversen.replicate.facade;
2+
3+
import jakarta.annotation.Nullable;
4+
import lombok.NonNull;
5+
import lombok.RequiredArgsConstructor;
6+
import lombok.SneakyThrows;
7+
import org.springframework.context.ApplicationEventPublisher;
8+
import org.springframework.stereotype.Component;
9+
10+
import java.util.Optional;
11+
import java.util.concurrent.CompletionException;
12+
import java.util.function.Function;
13+
import java.util.function.Supplier;
14+
15+
@Component
16+
@RequiredArgsConstructor
17+
public class CheckAndEmitPredictionCreationTask implements Function<Optional<PredictionResponseAndModel>, PredictionResponseAndModel> {
18+
private final @NonNull ApplicationEventPublisher eventPublisher;
19+
20+
@Override
21+
@SneakyThrows
22+
public PredictionResponseAndModel apply(@Nullable Optional<PredictionResponseAndModel> predictionResponse) {
23+
final var response = predictionResponse.orElseThrow(predictionFailedError());
24+
final var event = mapPredictionCreatedEvent().apply(response);
25+
eventPublisher.publishEvent(event);
26+
return response;
27+
}
28+
29+
Function<PredictionResponseAndModel, PredictionCreatedEvent> mapPredictionCreatedEvent() {
30+
return predictionResponse -> new PredictionCreatedEvent(
31+
predictionResponse.getPredictionResponse().getId(),
32+
predictionResponse.getModel(),
33+
predictionResponse.getPredictionResponse()
34+
);
35+
}
36+
37+
private Supplier<Exception> predictionFailedError() {
38+
return () -> new CompletionException(new RuntimeException("Could not initiate prediction"));
39+
}
40+
}

spring-boot-starter-replicate/src/main/java/io/graversen/replicate/facade/CheckPredictionStateTask.java

Lines changed: 0 additions & 23 deletions
This file was deleted.
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package io.graversen.replicate.facade;
2+
3+
import io.graversen.replicate.common.ReplicateModel;
4+
import io.graversen.replicate.service.PredictionResponse;
5+
import lombok.AccessLevel;
6+
import lombok.NonNull;
7+
import lombok.RequiredArgsConstructor;
8+
import lombok.Value;
9+
10+
@Value
11+
@RequiredArgsConstructor(access = AccessLevel.PACKAGE)
12+
public class PredictionCreatedEvent {
13+
@NonNull String id;
14+
@NonNull ReplicateModel model;
15+
@NonNull PredictionResponse prediction;
16+
}

spring-boot-starter-replicate/src/main/java/io/graversen/replicate/facade/ReplicateFacade.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
public class ReplicateFacade {
2020
private final @NonNull ReplicateService replicateService;
2121
private final @NonNull ExecutorService executorService;
22-
private final @NonNull CheckPredictionStateTask checkPredictionStateTask;
22+
private final @NonNull CheckAndEmitPredictionCreationTask checkAndEmitPredictionCreationTask;
2323
private final @NonNull PollPredictionStatusTask pollPredictionStatusTask;
2424
private final @NonNull EmitPredictionResponseTask emitPredictionResponseTask;
2525

@@ -29,7 +29,7 @@ public CompletableFuture<PredictionResponseAndModel> createPrediction(
2929
) {
3030
return CompletableFuture
3131
.supplyAsync(doCreatePrediction(model, createPrediction), executorService)
32-
.thenApplyAsync(checkPredictionStateTask, executorService)
32+
.thenApplyAsync(checkAndEmitPredictionCreationTask, executorService)
3333
.thenApplyAsync(pollPredictionStatusTask, executorService)
3434
.whenCompleteAsync(emitPredictionResponseTask, executorService);
3535
}

spring-boot-starter-replicate/src/main/java/io/graversen/replicate/service/PredictionResponse.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ public class PredictionResponse {
2323
OffsetDateTime completedAt;
2424
String error;
2525
String status;
26+
Object input;
2627
Object output;
2728
PredictionUrls urls;
2829

spring-boot-starter-replicate/src/main/java/io/graversen/replicate/util/ReplicateUtils.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ public class ReplicateUtils {
2424
private static final String ATTRIBUTE_ERROR = "error";
2525
private static final String ATTRIBUTE_STATUS = "status";
2626
private static final String ATTRIBUTE_OUTPUT = "output";
27+
private static final String ATTRIBUTE_INPUT = "input";
2728
private static final String ATTRIBUTE_URLS = "urls";
2829
private static final String ATTRIBUTE_CANCEL_URL = "cancel";
2930
private static final String ATTRIBUTE_GET_URL = "get";
@@ -37,6 +38,7 @@ public static Optional<PredictionResponse> mapPredictionResponse(@NonNull Linked
3738
final var completedAt = parseOffsetDateTime((String) responseMap.get(ATTRIBUTE_COMPLETED_AT));
3839
final var error = (String) responseMap.get(ATTRIBUTE_ERROR);
3940
final var status = (String) responseMap.get(ATTRIBUTE_STATUS);
41+
final var input = responseMap.get(ATTRIBUTE_INPUT);
4042
final var output = responseMap.get(ATTRIBUTE_OUTPUT);
4143

4244
final var urlsMap = (LinkedHashMap<String, Object>) responseMap.get(ATTRIBUTE_URLS);
@@ -48,7 +50,7 @@ public static Optional<PredictionResponse> mapPredictionResponse(@NonNull Linked
4850
urls = new PredictionUrls(cancelUrl, getUrl);
4951
}
5052

51-
final var predictionResponse = new PredictionResponse(id, version, createdAt, startedAt, completedAt, error, status, output, urls);
53+
final var predictionResponse = new PredictionResponse(id, version, createdAt, startedAt, completedAt, error, status, input, output, urls);
5254
return Optional.of(predictionResponse);
5355
} catch (Exception e) {
5456
log.error(e.getMessage(), e);

0 commit comments

Comments
 (0)