Skip to content

Commit 39f0949

Browse files
committed
feat: add number providers to module/number
1 parent 09862f4 commit 39f0949

5 files changed

Lines changed: 215 additions & 0 deletions

File tree

modules/number/build.gradle.kts

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
plugins {
2+
alias(libs.plugins.kotlin.serialization)
3+
}
4+
5+
dependencies {
6+
compileOnly(libs.kotlinx.serialization.core)
7+
}
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
package net.azisaba.data
2+
3+
import kotlinx.serialization.SerialName
4+
import kotlinx.serialization.Serializable
5+
import kotlin.math.PI
6+
import kotlin.math.cos
7+
import kotlin.math.ln
8+
import kotlin.math.sqrt
9+
import kotlin.random.Random
10+
11+
@Serializable
12+
sealed interface FloatProvider : NumberProvider<Float> {
13+
@Serializable
14+
@SerialName("Constant")
15+
data class Constant(val value: Float) : FloatProvider {
16+
override fun sample(random: Random): Float = value
17+
}
18+
19+
@Serializable
20+
@SerialName("Uniform")
21+
data class Uniform(val minInclusive: Float, val maxExclusive: Float) : FloatProvider {
22+
init {
23+
require(maxExclusive > minInclusive) {
24+
"maxExclusive must be greater than minInclusive: [$minInclusive, $maxExclusive]"
25+
}
26+
}
27+
28+
override fun sample(random: Random): Float = random.nextFloat() * (maxExclusive - minInclusive) + minInclusive
29+
}
30+
31+
@Serializable
32+
@SerialName("ClampedNormal")
33+
data class ClampedNormal(
34+
val mean: Float,
35+
val deviation: Float,
36+
val min: Float,
37+
val max: Float,
38+
) : FloatProvider {
39+
init {
40+
require(max >= min) {
41+
"max must be greater than or equal to min: [$min, $max]"
42+
}
43+
}
44+
45+
override fun sample(random: Random): Float =
46+
(random.nextGaussian().toFloat() * deviation + mean).coerceIn(min, max)
47+
48+
private fun Random.nextGaussian(): Double {
49+
val u1 = nextDouble()
50+
val u2 = nextDouble()
51+
return sqrt(-2.0 * ln(u1)) * cos(2.0 * PI * u2)
52+
}
53+
}
54+
55+
@Serializable
56+
@SerialName("Trapezoid")
57+
data class Trapezoid(val min: Float, val max: Float, val plateau: Float) : FloatProvider {
58+
init {
59+
require(max >= min) {
60+
"max must be greater than or equal to min: [$min, $max]"
61+
}
62+
require(plateau <= max - min) {
63+
"plateau must be less than or equal to the full span: [$min, $max]"
64+
}
65+
}
66+
67+
override fun sample(random: Random): Float {
68+
val span = max - min
69+
val slope = (span - plateau) / 2.0f
70+
val base = span - slope
71+
return min + random.nextFloat() * base + random.nextFloat() * slope
72+
}
73+
}
74+
75+
@Serializable
76+
@SerialName("Multiplied")
77+
data class Multiplied(val values: List<FloatProvider>) : FloatProvider {
78+
init {
79+
require(values.isNotEmpty()) {
80+
"values must not be empty"
81+
}
82+
}
83+
84+
override fun sample(random: Random): Float {
85+
var product = 1.0f
86+
for (value in values) {
87+
product *= value.sample(random)
88+
}
89+
return product
90+
}
91+
}
92+
}
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
package net.azisaba.data
2+
3+
import kotlinx.serialization.SerialName
4+
import kotlinx.serialization.Serializable
5+
import kotlin.math.PI
6+
import kotlin.math.cos
7+
import kotlin.math.ln
8+
import kotlin.math.sqrt
9+
import kotlin.random.Random
10+
11+
@Serializable
12+
sealed interface IntProvider : NumberProvider<Int> {
13+
@Serializable
14+
@SerialName("Constant")
15+
data class Constant(val value: Int) : IntProvider {
16+
override fun sample(random: Random): Int = value
17+
}
18+
19+
@Serializable
20+
@SerialName("Uniform")
21+
data class Uniform(val minInclusive: Int, val maxInclusive: Int) : IntProvider {
22+
init {
23+
require(maxInclusive >= minInclusive) {
24+
"maxInclusive must be greater than or equal to minInclusive: [$minInclusive, $maxInclusive]"
25+
}
26+
}
27+
28+
override fun sample(random: Random): Int = random.nextInt(minInclusive, maxInclusive + 1)
29+
}
30+
31+
@Serializable
32+
@SerialName("BiasedToBottom")
33+
data class BiasedToBottom(val minInclusive: Int, val maxInclusive: Int) : IntProvider {
34+
init {
35+
require(maxInclusive >= minInclusive) {
36+
"maxInclusive must be greater than or equal to minInclusive: [$minInclusive, $maxInclusive]"
37+
}
38+
}
39+
40+
override fun sample(random: Random): Int =
41+
minInclusive + random.nextInt(random.nextInt(maxInclusive - minInclusive + 1) + 1)
42+
}
43+
44+
@Serializable
45+
@SerialName("Clamped")
46+
data class Clamped(val source: IntProvider, val minInclusive: Int, val maxInclusive: Int) : IntProvider {
47+
init {
48+
require(maxInclusive >= minInclusive) {
49+
"maxInclusive must be greater than or equal to minInclusive: [$minInclusive, $maxInclusive]"
50+
}
51+
}
52+
53+
override fun sample(random: Random): Int = source.sample(random).coerceIn(minInclusive, maxInclusive)
54+
}
55+
56+
@Serializable
57+
@SerialName("ClampedNormal")
58+
data class ClampedNormal(
59+
val mean: Float, val deviation: Float, val minInclusive: Int, val maxInclusive: Int,
60+
) : IntProvider {
61+
init {
62+
require(maxInclusive >= minInclusive) {
63+
"maxInclusive must be greater than or equal to minInclusive: [$minInclusive, $maxInclusive]"
64+
}
65+
}
66+
67+
override fun sample(random: Random): Int =
68+
(random.nextGaussian().toFloat() * deviation + mean)
69+
.coerceIn(minInclusive.toFloat(), maxInclusive.toFloat())
70+
.toInt()
71+
72+
private fun Random.nextGaussian(): Double {
73+
val u1 = nextDouble()
74+
val u2 = nextDouble()
75+
return sqrt(-2.0 * ln(u1)) * cos(2.0 * PI * u2)
76+
}
77+
}
78+
79+
@Serializable
80+
@SerialName("WeightedList")
81+
data class WeightedList(val distribution: List<Entry>) : IntProvider {
82+
init {
83+
require(distribution.isNotEmpty()) {
84+
"distribution must not be empty"
85+
}
86+
require(distribution.all { it.weight > 0 }) {
87+
"distribution weights must be greater than 0"
88+
}
89+
}
90+
91+
override fun sample(random: Random): Int {
92+
val totalWeight = distribution.sumOf(Entry::weight)
93+
val roll = random.nextInt(totalWeight)
94+
var accumulatedWeight = 0
95+
for (entry in distribution) {
96+
accumulatedWeight += entry.weight
97+
if (roll < accumulatedWeight) {
98+
return entry.provider.sample(random)
99+
}
100+
}
101+
102+
return distribution.last().provider.sample(random)
103+
}
104+
105+
@Serializable
106+
data class Entry(val provider: IntProvider, val weight: Int)
107+
}
108+
}
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
package net.azisaba.data
2+
3+
import kotlin.random.Random
4+
5+
interface NumberProvider<T : Number> {
6+
fun sample(random: Random): T
7+
}

settings.gradle.kts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,5 @@ rootProject.name = "data-driven"
66

77
include(":modules:core")
88
include(":modules:json")
9+
include(":modules:number")
910
include(":modules:yaml")

0 commit comments

Comments
 (0)