Skip to content

Commit 46b3175

Browse files
committed
✨ Support Flux model family mapping
1 parent d77348b commit 46b3175

6 files changed

Lines changed: 170 additions & 0 deletions

File tree

spring-boot-starter-replicate/src/main/java/io/graversen/replicate/common/TextToImageAspectRatio.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,18 @@ public class TextToImageAspectRatio {
88
@NonNull Integer height;
99
@NonNull AspectRatios aspectRatio;
1010

11+
public static TextToImageAspectRatio square() {
12+
return new TextToImageAspectRatio(1080, AspectRatios.RATIO_1_BY_1);
13+
}
14+
15+
public static TextToImageAspectRatio portrait() {
16+
return new TextToImageAspectRatio(1350, AspectRatios.RATIO_4_BY_5);
17+
}
18+
19+
public static TextToImageAspectRatio reel() {
20+
return new TextToImageAspectRatio(1920, AspectRatios.RATIO_9_BY_16);
21+
}
22+
1123
public Integer getWidth() {
1224
return (int) Math.round(height * ((double) aspectRatio.getWidthRatio() / aspectRatio.getHeightRatio()));
1325
}
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
package io.graversen.replicate.common;
2+
3+
import jakarta.annotation.Nullable;
4+
import lombok.NonNull;
5+
import lombok.Value;
6+
7+
@Value
8+
public class TextToImagePrompt {
9+
@NonNull String prompt;
10+
@Nullable String negativePrompt;
11+
@NonNull TextToImageAspectRatio aspectRatio;
12+
@Nullable Double promptStrength;
13+
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
package io.graversen.replicate.configuration;
2+
3+
import org.springframework.context.annotation.ComponentScan;
4+
import org.springframework.context.annotation.Configuration;
5+
6+
@Configuration
7+
@ComponentScan("io.graversen.replicate.flux")
8+
public class FluxConfiguration {
9+
10+
}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
package io.graversen.replicate.flux;
2+
3+
import io.graversen.replicate.common.ReplicateModel;
4+
import lombok.experimental.UtilityClass;
5+
6+
@UtilityClass
7+
public class FluxModels {
8+
public static final ReplicateModel FLUX_SCHNELL = new ReplicateModel("black-forest-labs", "flux-schnell");
9+
public static final ReplicateModel FLUX_DEV = new ReplicateModel("black-forest-labs", "flux-dev");
10+
public static final ReplicateModel FLUX_PRO = new ReplicateModel("black-forest-labs", "flux-pro");
11+
}
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
package io.graversen.replicate.flux;
2+
3+
import io.graversen.replicate.common.BasePredictionMapper;
4+
import io.graversen.replicate.common.PredictionTypes;
5+
import io.graversen.replicate.common.ReplicateModel;
6+
import io.graversen.replicate.models.*;
7+
import io.graversen.replicate.service.CreateImagePrediction;
8+
import lombok.NonNull;
9+
import org.springframework.stereotype.Component;
10+
11+
import java.util.Set;
12+
import java.util.function.Function;
13+
14+
@Component
15+
public class FluxPredictionMapper extends BasePredictionMapper<CreateImagePrediction, Object> {
16+
@Override
17+
protected Set<ReplicateModel> supportedModels() {
18+
return Set.of(
19+
FluxModels.FLUX_DEV,
20+
FluxModels.FLUX_SCHNELL,
21+
FluxModels.FLUX_PRO
22+
);
23+
}
24+
25+
@Override
26+
public Object apply(@NonNull ReplicateModel model, @NonNull CreateImagePrediction createPrediction) {
27+
if (model.equals(FluxModels.FLUX_DEV)) {
28+
return new BlackForestLabsFluxDevPredictionrequest(
29+
null,
30+
mapFluxDevInput().apply(createPrediction),
31+
null,
32+
null,
33+
null,
34+
null
35+
);
36+
} else if (model.equals(FluxModels.FLUX_SCHNELL)) {
37+
return new BlackForestLabsFluxSchnellPredictionrequest(
38+
null,
39+
mapFluxSchnellInput().apply(createPrediction),
40+
null,
41+
null,
42+
null,
43+
null
44+
);
45+
} else if (model.equals(FluxModels.FLUX_PRO)) {
46+
return new BlackForestLabsFluxProPredictionrequest(
47+
null,
48+
mapFluxProInput().apply(createPrediction),
49+
null,
50+
null,
51+
null,
52+
null
53+
);
54+
} else {
55+
throw new IllegalArgumentException("Unsupported Replicate Model: " + model);
56+
}
57+
}
58+
59+
@Override
60+
public boolean supportsType(@NonNull PredictionTypes type) {
61+
return PredictionTypes.IMAGE.equals(type);
62+
}
63+
64+
Function<CreateImagePrediction, BlackForestLabsFluxDevInput> mapFluxDevInput() {
65+
return createImagePrediction -> new BlackForestLabsFluxDevInput(
66+
createImagePrediction.getSeed(),
67+
null,
68+
createImagePrediction.getPrompt().getPrompt(),
69+
false,
70+
createImagePrediction.getPrompt().getPromptStrength(),
71+
null,
72+
createImagePrediction.getOutputs(),
73+
createImagePrediction.getPrompt().getAspectRatio().getAspectRatioAsString(),
74+
"png",
75+
100,
76+
null,
77+
createImagePrediction.getInferenceSteps(),
78+
true
79+
);
80+
}
81+
82+
Function<CreateImagePrediction, BlackForestLabsFluxSchnellInput> mapFluxSchnellInput() {
83+
return createImagePrediction -> new BlackForestLabsFluxSchnellInput(
84+
createImagePrediction.getSeed(),
85+
createImagePrediction.getPrompt().getPrompt(),
86+
false,
87+
null,
88+
createImagePrediction.getOutputs(),
89+
createImagePrediction.getPrompt().getAspectRatio().getAspectRatioAsString(),
90+
"png",
91+
100,
92+
true
93+
);
94+
}
95+
96+
Function<CreateImagePrediction, BlackForestLabsFluxProInput> mapFluxProInput() {
97+
return createImagePrediction -> new BlackForestLabsFluxProInput(
98+
createImagePrediction.getSeed(),
99+
createImagePrediction.getInferenceSteps(),
100+
createImagePrediction.getPrompt().getPrompt(),
101+
createImagePrediction.getPrompt().getPromptStrength(),
102+
null,
103+
createImagePrediction.getPrompt().getAspectRatio().getAspectRatioAsString(),
104+
5
105+
);
106+
}
107+
}
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
package io.graversen.replicate.service;
2+
3+
import io.graversen.replicate.common.TextToImagePrompt;
4+
import jakarta.annotation.Nullable;
5+
import lombok.Getter;
6+
import lombok.NonNull;
7+
import lombok.RequiredArgsConstructor;
8+
9+
@Getter
10+
@RequiredArgsConstructor
11+
public class CreateImagePrediction {
12+
private final @NonNull TextToImagePrompt prompt;
13+
private final @Nullable Integer outputs;
14+
private final @Nullable Integer inferenceSteps;
15+
private final @Nullable Integer seed;
16+
17+
}

0 commit comments

Comments
 (0)