Skip to content

Commit cda8bb5

Browse files
committed
🚧 Flux 2 support and better image generation abstraction
1 parent ab544a7 commit cda8bb5

13 files changed

Lines changed: 285 additions & 116 deletions

spring-boot-starter-replicate/src/main/java/io/graversen/replicate/common/AspectRatios.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,4 +34,8 @@ public static Optional<AspectRatios> fromValue(String value) {
3434
public String getAspectRatio() {
3535
return widthRatio + ":" + heightRatio;
3636
}
37+
38+
public static AspectRatios defaultValue() {
39+
return RATIO_1_BY_1;
40+
}
3741
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
package io.graversen.replicate.common;
2+
3+
public enum ModerationLevels {
4+
HIGH,
5+
LOW;
6+
7+
public static ModerationLevels defaultValue() {
8+
return HIGH;
9+
}
10+
}
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
package io.graversen.replicate.common;
2+
3+
public enum OutputFormats {
4+
WEBP,
5+
JPG,
6+
PNG
7+
}
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
package io.graversen.replicate.common;
2+
3+
public enum QualityModes {
4+
HIGHER_QUALITY,
5+
HIGHER_PERFORMANCE
6+
}

spring-boot-starter-replicate/src/main/java/io/graversen/replicate/facade/PredictionRetryPolicy.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,6 @@ public PredictionRetryPolicy(@NonNull Integer maxAttempts, @NonNull Duration del
2222
}
2323

2424
public static PredictionRetryPolicy defaultPolicy() {
25-
return new PredictionRetryPolicy(50, Duration.ofMillis(200));
25+
return new PredictionRetryPolicy(50, Duration.ofMillis(250));
2626
}
2727
}

spring-boot-starter-replicate/src/main/java/io/graversen/replicate/facade/ReplicateFacade.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
package io.graversen.replicate.facade;
22

33
import io.graversen.replicate.common.ReplicateModel;
4-
import io.graversen.replicate.service.CreateImagePrediction;
4+
import io.graversen.replicate.service.CreateImagePrediction2;
55
import io.graversen.replicate.service.CreateTextPrediction;
66
import io.graversen.replicate.service.ReplicateService;
77
import lombok.NonNull;
@@ -37,7 +37,7 @@ public CompletableFuture<PredictionResponseAndModel> createPrediction(
3737

3838
public CompletableFuture<PredictionResponseAndModel> createPrediction(
3939
@NonNull ReplicateModel model,
40-
@NonNull CreateImagePrediction createPrediction
40+
@NonNull CreateImagePrediction2 createPrediction
4141
) {
4242
return CompletableFuture
4343
.supplyAsync(doCreatePrediction(model, createPrediction), executorService)
@@ -56,7 +56,7 @@ Supplier<Optional<PredictionResponseAndModel>> doCreatePrediction(
5656

5757
Supplier<Optional<PredictionResponseAndModel>> doCreatePrediction(
5858
@NonNull ReplicateModel model,
59-
@NonNull CreateImagePrediction createPrediction
59+
@NonNull CreateImagePrediction2 createPrediction
6060
) {
6161
return () -> replicateService.createPrediction(model, createPrediction)
6262
.map(predictionResponse -> new PredictionResponseAndModel(predictionResponse, model));
Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
package io.graversen.replicate.flux;
2+
3+
import io.graversen.replicate.common.*;
4+
import io.graversen.replicate.models.*;
5+
import io.graversen.replicate.service.CreateImagePrediction2;
6+
import lombok.NonNull;
7+
import lombok.RequiredArgsConstructor;
8+
import lombok.SneakyThrows;
9+
import org.springframework.stereotype.Component;
10+
11+
import java.net.URI;
12+
import java.util.List;
13+
import java.util.Objects;
14+
import java.util.Set;
15+
import java.util.function.Function;
16+
17+
@Component
18+
@RequiredArgsConstructor
19+
public class Flux2PredictionMapper extends BasePredictionMapper<CreateImagePrediction2, Object> {
20+
21+
@Override
22+
protected Set<ReplicateModel> supportedModels() {
23+
return Set.of(
24+
FluxModels.FLUX_2_DEV,
25+
FluxModels.FLUX_2_FLEX,
26+
FluxModels.FLUX_2_PRO,
27+
FluxModels.FLUX_2_MAX
28+
);
29+
}
30+
31+
@Override
32+
public Object apply(@NonNull ReplicateModel model, @NonNull CreateImagePrediction2 createPrediction) {
33+
if (model.equals(FluxModels.FLUX_2_DEV)) {
34+
return new BlackForestLabsFlux2DevPredictionrequest(
35+
null,
36+
mapFluxDevInput().apply(createPrediction),
37+
null,
38+
null,
39+
null,
40+
null,
41+
null
42+
);
43+
} else if (model.equals(FluxModels.FLUX_2_FLEX)) {
44+
return new BlackForestLabsFlux2FlexPredictionrequest(
45+
null,
46+
mapFluxFlexInput().apply(createPrediction),
47+
null,
48+
null,
49+
null,
50+
null,
51+
null
52+
);
53+
} else if (model.equals(FluxModels.FLUX_2_PRO)) {
54+
return new BlackForestLabsFlux2ProPredictionrequest(
55+
null,
56+
mapFluxProInput().apply(createPrediction),
57+
null,
58+
null,
59+
null,
60+
null,
61+
null
62+
);
63+
} else if (model.equals(FluxModels.FLUX_2_MAX)) {
64+
return new BlackForestLabsFlux2MaxPredictionrequest(
65+
null,
66+
mapFluxMaxInput().apply(createPrediction),
67+
null,
68+
null,
69+
null,
70+
null,
71+
null
72+
);
73+
} else {
74+
throw new IllegalArgumentException("Unsupported Replicate Model: " + model);
75+
}
76+
77+
}
78+
79+
@Override
80+
public boolean supportsType(@NonNull PredictionTypes type) {
81+
return PredictionTypes.IMAGE.equals(type);
82+
}
83+
84+
Function<CreateImagePrediction2, BlackForestLabsFlux2DevInput> mapFluxDevInput() {
85+
return createImagePrediction -> new BlackForestLabsFlux2DevInput(
86+
null,
87+
null,
88+
null,
89+
createImagePrediction.getPrompt(),
90+
mapGoFast(createImagePrediction),
91+
mapAspectRatio(createImagePrediction),
92+
mapInputImages(createImagePrediction),
93+
mapOutputFormat(createImagePrediction),
94+
100,
95+
mapDisableSafetyChecker(createImagePrediction)
96+
);
97+
}
98+
99+
Function<CreateImagePrediction2, BlackForestLabsFlux2FlexInput> mapFluxFlexInput() {
100+
return createImagePrediction -> new BlackForestLabsFlux2FlexInput(
101+
null,
102+
null,
103+
null,
104+
null,
105+
createImagePrediction.getPrompt(),
106+
null,
107+
null,
108+
mapAspectRatio(createImagePrediction),
109+
mapInputImages(createImagePrediction),
110+
mapOutputFormat(createImagePrediction),
111+
100,
112+
mapSafetyTolerance(createImagePrediction),
113+
true
114+
);
115+
}
116+
117+
Function<CreateImagePrediction2, BlackForestLabsFlux2MaxInput> mapFluxMaxInput() {
118+
return createImagePrediction -> new BlackForestLabsFlux2MaxInput(
119+
null,
120+
null,
121+
null,
122+
createImagePrediction.getPrompt(),
123+
null,
124+
mapAspectRatio(createImagePrediction),
125+
mapInputImages(createImagePrediction),
126+
mapOutputFormat(createImagePrediction),
127+
100,
128+
mapSafetyTolerance(createImagePrediction)
129+
);
130+
}
131+
132+
Function<CreateImagePrediction2, BlackForestLabsFlux2ProInput> mapFluxProInput() {
133+
return createImagePrediction -> new BlackForestLabsFlux2ProInput(
134+
null,
135+
null,
136+
null,
137+
createImagePrediction.getPrompt(),
138+
null,
139+
mapAspectRatio(createImagePrediction),
140+
mapInputImages(createImagePrediction),
141+
mapOutputFormat(createImagePrediction),
142+
100,
143+
mapSafetyTolerance(createImagePrediction)
144+
);
145+
}
146+
147+
private Boolean mapDisableSafetyChecker(@NonNull CreateImagePrediction2 createImagePrediction) {
148+
if (createImagePrediction.getModerationLevel() != null) {
149+
return createImagePrediction.getModerationLevel() == ModerationLevels.HIGH
150+
? Boolean.FALSE
151+
: Boolean.TRUE;
152+
} else {
153+
return Boolean.TRUE;
154+
}
155+
}
156+
157+
private Integer mapSafetyTolerance(@NonNull CreateImagePrediction2 createImagePrediction) {
158+
if (createImagePrediction.getModerationLevel() != null) {
159+
return createImagePrediction.getModerationLevel() == ModerationLevels.HIGH
160+
? 2
161+
: 5;
162+
} else {
163+
return 5;
164+
}
165+
}
166+
167+
private String mapOutputFormat(@NonNull CreateImagePrediction2 createImagePrediction) {
168+
return Objects.requireNonNullElse(createImagePrediction.getOutputFormat(), OutputFormats.PNG).toString().toLowerCase();
169+
}
170+
171+
private List<URI> mapInputImages(@NonNull CreateImagePrediction2 createImagePrediction) {
172+
return Objects.requireNonNullElseGet(createImagePrediction.getInputImages(), Set::<String>of).stream()
173+
.map(URI::create)
174+
.toList();
175+
}
176+
177+
private String mapAspectRatio(@NonNull CreateImagePrediction2 createImagePrediction) {
178+
return Objects.requireNonNullElse(createImagePrediction.getAspectRatio(), AspectRatios.defaultValue()).getAspectRatio();
179+
}
180+
181+
private Boolean mapGoFast(@NonNull CreateImagePrediction2 createImagePrediction) {
182+
if (createImagePrediction.getQuality() != null) {
183+
return createImagePrediction.getQuality() == QualityModes.HIGHER_PERFORMANCE
184+
? Boolean.TRUE
185+
: Boolean.FALSE;
186+
} else {
187+
return Boolean.TRUE;
188+
}
189+
}
190+
}

spring-boot-starter-replicate/src/main/java/io/graversen/replicate/flux/FluxModels.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55

66
@UtilityClass
77
public class FluxModels {
8-
public static final ReplicateModel FLUX_SCHNELL = new ReplicateModel("black-forest-labs", "flux-schnell");
9-
public static final ReplicateModel FLUX_DEV = new ReplicateModel("black-forest-labs", "flux-dev");
10-
public static final ReplicateModel FLUX_PRO = new ReplicateModel("black-forest-labs", "flux-pro");
8+
public static final ReplicateModel FLUX_2_DEV = new ReplicateModel("black-forest-labs", "flux-2-dev");
9+
public static final ReplicateModel FLUX_2_FLEX = new ReplicateModel("black-forest-labs", "flux-2-flex");
10+
public static final ReplicateModel FLUX_2_PRO = new ReplicateModel("black-forest-labs", "flux-2-pro");
11+
public static final ReplicateModel FLUX_2_MAX = new ReplicateModel("black-forest-labs", "flux-2-max");
1112
}

spring-boot-starter-replicate/src/main/java/io/graversen/replicate/flux/FluxPredictionMapper.java

Lines changed: 0 additions & 107 deletions
This file was deleted.

spring-boot-starter-replicate/src/main/java/io/graversen/replicate/service/CreateImagePrediction.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,5 +13,4 @@ public class CreateImagePrediction {
1313
private final @Nullable Integer outputs;
1414
private final @Nullable Integer inferenceSteps;
1515
private final @Nullable Integer seed;
16-
1716
}

0 commit comments

Comments
 (0)