-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsorting.py
More file actions
230 lines (180 loc) · 7.42 KB
/
sorting.py
File metadata and controls
230 lines (180 loc) · 7.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
"""
Branchless Sorting — Compare-and-swap without if-statements.
PATTERN: Sorting Networks
A sorting network is a fixed sequence of compare-and-swap operations.
Unlike quicksort or mergesort, the comparisons don't depend on the data —
the SAME swaps happen every time, regardless of input order.
This makes sorting networks:
- Branchless (no data-dependent if-statements)
- Parallelizable (independent swaps can run simultaneously)
- Predictable (constant execution time)
Trade-off: Only practical for small, fixed-size arrays (2-16 elements).
For larger arrays, traditional sorting algorithms are faster.
"""
def branchless_swap(arr, i, j):
    """
    Compare-and-swap: make sure arr[i] <= arr[j], without an if-statement.

    If arr[i] > arr[j] the two elements trade places; otherwise nothing
    changes. ``arr`` is modified in place and nothing is returned.

    TRICK:
        should_swap = int(a > b)  → 1 if out of order, 0 if already sorted
        pair = (a, b)
        new arr[i] = pair[should_swap]      → b when swapping, else a
        new arr[j] = pair[1 - should_swap]  → a when swapping, else b

    WHY SELECT INSTEAD OF ADD/SUBTRACT:
        The arithmetic form (a - diff * should_swap) is only exact for
        integers. With floats it silently corrupts data:
            a=1e20, b=1.0 → 1e20 - (1e20 - 1.0) = 0.0   (b is lost)
            a=inf,  b=3.0 → inf - inf           = nan
        Picking from a tuple moves the original objects untouched, so it
        works for any values that support ">".

    WORKED EXAMPLE (a=5, b=2 → needs swap):
        should_swap = int(5 > 2) = 1
        arr[i] = (5, 2)[1] = 2 ✓ (smaller value)
        arr[j] = (5, 2)[0] = 5 ✓ (larger value)

    WORKED EXAMPLE (a=2, b=5 → already sorted):
        should_swap = int(2 > 5) = 0
        arr[i] = (2, 5)[0] = 2 ✓ (unchanged)
        arr[j] = (2, 5)[1] = 5 ✓ (unchanged)

    Args:
        arr: Mutable sequence supporting item assignment.
        i: Index that should end up holding the smaller value.
        j: Index that should end up holding the larger value.
    """
    a, b = arr[i], arr[j]
    should_swap = int(a > b)  # 1 if a > b (out of order), else 0
    pair = (a, b)
    arr[i] = pair[should_swap]      # b when swapping, a otherwise
    arr[j] = pair[1 - should_swap]  # a when swapping, b otherwise
def branchless_min_max(a, b):
    """
    Return (min, max) of two values without branching.

    TRICK: Use int(a < b) as an index into a 2-tuple:
        is_a_smaller = int(a < b)  → 1 if a is smaller, 0 if b is smaller/equal
        min = (b, a)[is_a_smaller]
            When a < b:  index 1 → a ✓
            When a >= b: index 0 → b ✓
        max = (a, b)[is_a_smaller]  (the opposite pick)

    WHY INDEX INSTEAD OF MULTIPLY:
        The form a * sel + b * (1 - sel) breaks for infinities because
        inf * 0 is nan, e.g. min(inf, 1.0) would come out as nan.
        Indexing returns the original objects unchanged.

    WORKED EXAMPLE (a=5, b=3):
        is_a_smaller = int(5 < 3) = 0
        min = (3, 5)[0] = 3 ✓
        max = (5, 3)[0] = 5 ✓

    Args:
        a: First value; must support "<" against ``b``.
        b: Second value.

    Returns:
        Tuple ``(smaller, larger)``. When the values compare equal the
        result is ``(b, a)``.
    """
    is_a_smaller = int(a < b)
    min_val = (b, a)[is_a_smaller]
    max_val = (a, b)[is_a_smaller]
    return min_val, max_val
def sorting_network_4(arr):
    """
    Return a sorted copy of a 4-element sequence using a sorting network.

    The network is a fixed schedule of 5 compare-and-swap steps; the same
    index pairs are visited no matter what the input looks like:
        Step 1: (0, 1)
        Step 2: (2, 3)   ← independent of step 1
        Step 3: (0, 2)
        Step 4: (1, 3)   ← independent of step 3
        Step 5: (1, 2)

    Visual diagram (wires = values, X = swap point):
        ──[0]──X─────X─────── → smallest
        ──[1]──X───────X──X── → 2nd smallest
        ──[2]────X───X────X── → 3rd smallest
        ──[3]────X─────X───── → largest

    The input is never modified; the work happens on a slice copy.

    NOTE: if len(arr) is not exactly 4, no comparisons are performed and
    the result is simply an UNSORTED copy of the input.
    """
    # Wrong-sized input gets an empty schedule: it is copied but not sorted.
    schedule = ((0, 1), (2, 3), (0, 2), (1, 3), (1, 2)) if len(arr) == 4 else ()
    ordered = arr[:]
    # Each compare-and-swap leaves the smaller value at the lower index.
    for low, high in schedule:
        branchless_swap(ordered, low, high)
    return ordered
def branchless_clamp(value, min_val, max_val):
    """
    Clamp value to [min_val, max_val] without branching.

    Traditional (branching):
        if value < min_val: return min_val
        if value > max_val: return max_val
        return value

    Branchless:
        Turn each comparison into a 0/1 index and pick from a 2-tuple:
            value = (value, min_val)[int(value < min_val)]
            value = (value, max_val)[int(value > max_val)]

    WHY INDEX INSTEAD OF "value - diff * flag":
        The subtract-the-difference form is only exact for integers.
        With floats it returns wrong answers:
            clamp(1e20, 0.0, 1.0) → 1e20 - (1e20 - 1.0) = 0.0  (should be 1.0)
            clamp(-inf, 0, 10)    → -inf - (-inf)       = nan  (should be 0)
        Indexing hands back one of the original arguments untouched.

    WORKED EXAMPLE (value=-5, min=0, max=10):
        Step 1 — clamp to min:
            below_min = int(-5 < 0) = 1
            value = (-5, 0)[1] = 0 ✓ (snapped to min)
        Step 2 — clamp to max:
            above_max = int(0 > 10) = 0
            value = (0, 10)[0] = 0 ✓ (no change, already in range)

    WORKED EXAMPLE (value=15, min=0, max=10):
        Step 1: below_min = 0 → value stays 15
        Step 2: above_max = int(15 > 10) = 1
            value = (15, 10)[1] = 10 ✓ (snapped to max)

    Args:
        value: Value to clamp.
        min_val: Lower bound (inclusive).
        max_val: Upper bound (inclusive).

    Returns:
        ``value`` if it lies within the bounds, otherwise the bound it
        crossed. The lower bound is applied first, then the upper bound,
        so if min_val > max_val the result is max_val.
    """
    # Clamp to minimum
    below_min = int(value < min_val)
    value = (value, min_val)[below_min]  # Snap to min_val if below
    # Clamp to maximum
    above_max = int(value > max_val)
    value = (value, max_val)[above_max]  # Snap to max_val if above
    return value
def branchless_abs(x):
    """
    Absolute value of an integer without branching.

    TRICK: Build an all-0s / all-1s mask from the sign and use it as an
    XOR mask.
        mask = -int(x < 0)
            If x >= 0: mask =  0 (all 0 bits)
            If x < 0:  mask = -1 (all 1 bits in two's complement)
        (x ^ mask) - mask
            If x >= 0: (x ^ 0) - 0 = x ✓
            If x < 0:  (x ^ (-1)) - (-1) = ~x + 1 = -x ✓

    WHY NOT "x >> 31":
        In C, an arithmetic right shift by 31 produces this mask for
        32-bit signed integers. Python integers have unlimited width, so
        x >> 31 only yields 0 / -1 while x fits in 32 bits. Outside that
        range the shift leaves other bits behind and the answer is wrong,
        e.g. for x = 2**31 + 1 the shift gives 1 and the result would be
        2**31 - 1. Deriving the mask from the comparison works for every
        integer.

    WHY (~x + 1) equals -x:
        In two's complement, flipping all bits (~x) gives -(x+1).
        Adding 1 gives -x. This is how negative numbers work in binary!

    WORKED EXAMPLE (x = -7):
        mask = -int(-7 < 0) = -1
        x ^ mask = -7 ^ -1 = 6 (flips all bits: ~(-7) = 6)
        6 - (-1) = 6 + 1 = 7 ✓

    WORKED EXAMPLE (x = 5):
        mask = -int(5 < 0) = 0
        x ^ 0 = 5
        5 - 0 = 5 ✓

    Args:
        x: An integer (the XOR step does not accept floats).

    Returns:
        The non-negative magnitude of ``x``.
    """
    mask = -int(x < 0)  # -1 (all 1s) if negative, 0 (all 0s) otherwise
    return (x ^ mask) - mask
# ──────────────────────────────────────────────
# Demo
# ──────────────────────────────────────────────
if __name__ == "__main__":
    # Command-line demo: run each helper on a few fixed inputs and print
    # the results so the behavior can be checked by eye.
    print("Branchless Sorting Demo")
    print("=" * 45)

    # Sorting network: original list on the left, sorted copy on the right.
    print("\nSorting Network (4 elements):")
    for sample in (
        [4, 2, 3, 1],
        [1, 2, 3, 4],
        [4, 3, 2, 1],
        [2, 4, 1, 3],
    ):
        print(f" {sample} → {sorting_network_4(sample)}")

    # Min/max: includes an equal pair and a negative operand.
    print("\nBranchless Min/Max:")
    for left, right in ((5, 3), (2, 8), (4, 4), (-3, 2)):
        low, high = branchless_min_max(left, right)
        print(f" min({left:+d}, {right:+d}) = {low:+d} max({left:+d}, {right:+d}) = {high:+d}")

    # Clamp: values below, on, inside, and above the 0–10 range.
    print("\nBranchless Clamp (range 0–10):")
    for candidate in (-5, 0, 5, 10, 15):
        print(f" clamp({candidate:+3d}) = {branchless_clamp(candidate, 0, 10)}")

    # Absolute value: negative, zero, and positive inputs.
    print("\nBranchless Absolute Value:")
    for number in (-7, 0, 5, -100):
        print(f" abs({number:+4d}) = {branchless_abs(number)}")

    # Closing summary of the technique being demonstrated.
    print()
    for summary_line in (
        "KEY TECHNIQUE: Sorting networks",
        " Fixed compare-and-swap sequences",
        " → same operations run regardless of input data",
        " → no data-dependent branches",
    ):
        print(summary_line)