
Commit de854e1

claude[bot] authored
Add challenge 85: LoRA Linear (Medium) (#222)
Co-authored-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com>
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent eb14609 commit de854e1

8 files changed

Lines changed: 379 additions & 0 deletions


Lines changed: 111 additions & 0 deletions
@@ -0,0 +1,111 @@
<p>
Implement a LoRA (Low-Rank Adaptation) linear layer forward pass. Given an input matrix
<code>x</code> of shape <code>batch &times; d_in</code>, a base weight matrix <code>W</code> of
shape <code>d_out &times; d_in</code>, a LoRA down-projection matrix <code>A</code> of shape
<code>rank &times; d_in</code>, and a LoRA up-projection matrix <code>B</code> of shape
<code>d_out &times; rank</code>, compute
<code>output = x &times; W<sup>T</sup> + lora_scale &times; (x &times; A<sup>T</sup>) &times; B<sup>T</sup></code>.
All tensors are <code>float32</code>.
</p>

<svg width="680" height="200" viewBox="0 0 680 200" xmlns="http://www.w3.org/2000/svg" style="display:block; margin:20px auto;">
  <rect width="680" height="200" fill="#222"/>

  <!-- x block -->
  <rect x="20" y="70" width="60" height="60" fill="#1a3a5c" stroke="#4a9eff" stroke-width="1.5"/>
  <text x="50" y="95" text-anchor="middle" fill="#ccc" font-size="13" font-family="monospace">x</text>
  <text x="50" y="112" text-anchor="middle" fill="#888" font-size="10" font-family="monospace">B&times;D_in</text>

  <!-- Arrow to W branch -->
  <line x1="80" y1="100" x2="110" y2="70" stroke="#888" stroke-width="1.5" marker-end="url(#arr)"/>
  <!-- Arrow to A branch -->
  <line x1="80" y1="100" x2="110" y2="145" stroke="#888" stroke-width="1.5" marker-end="url(#arr)"/>

  <!-- W block -->
  <rect x="112" y="40" width="70" height="55" fill="#1a3a5c" stroke="#4a9eff" stroke-width="1.5"/>
  <text x="147" y="63" text-anchor="middle" fill="#ccc" font-size="13" font-family="monospace">W</text>
  <text x="147" y="80" text-anchor="middle" fill="#888" font-size="10" font-family="monospace">D_out&times;D_in</text>

  <!-- base output: x@W^T -->
  <line x1="182" y1="67" x2="225" y2="90" stroke="#888" stroke-width="1.5" marker-end="url(#arr)"/>
  <rect x="227" y="70" width="80" height="55" fill="#1a4a2a" stroke="#4aff88" stroke-width="1.5"/>
  <text x="267" y="92" text-anchor="middle" fill="#ccc" font-size="11" font-family="monospace">x@W&#x1D57;</text>
  <text x="267" y="108" text-anchor="middle" fill="#888" font-size="10" font-family="monospace">B&times;D_out</text>

  <!-- A block -->
  <rect x="112" y="128" width="70" height="50" fill="#3a1a3a" stroke="#cc88ff" stroke-width="1.5"/>
  <text x="147" y="150" text-anchor="middle" fill="#ccc" font-size="13" font-family="monospace">A</text>
  <text x="147" y="167" text-anchor="middle" fill="#888" font-size="10" font-family="monospace">rank&times;D_in</text>

  <!-- hidden = x@A^T -->
  <line x1="182" y1="153" x2="225" y2="153" stroke="#888" stroke-width="1.5" marker-end="url(#arr)"/>
  <rect x="227" y="130" width="60" height="45" fill="#3a1a3a" stroke="#cc88ff" stroke-width="1.5"/>
  <text x="257" y="152" text-anchor="middle" fill="#ccc" font-size="10" font-family="monospace">x@A&#x1D57;</text>
  <text x="257" y="167" text-anchor="middle" fill="#888" font-size="10" font-family="monospace">B&times;rank</text>

  <!-- B block -->
  <rect x="304" y="128" width="70" height="50" fill="#3a1a3a" stroke="#cc88ff" stroke-width="1.5"/>
  <text x="339" y="150" text-anchor="middle" fill="#ccc" font-size="13" font-family="monospace">B</text>
  <text x="339" y="167" text-anchor="middle" fill="#888" font-size="10" font-family="monospace">D_out&times;rank</text>

  <!-- arrow from hidden to B -->
  <line x1="287" y1="153" x2="302" y2="153" stroke="#888" stroke-width="1.5" marker-end="url(#arr)"/>

  <!-- delta = (x@A^T)@B^T -->
  <line x1="374" y1="153" x2="415" y2="120" stroke="#888" stroke-width="1.5" marker-end="url(#arr)"/>
  <rect x="417" y="95" width="80" height="55" fill="#3a2a1a" stroke="#ffaa44" stroke-width="1.5"/>
  <text x="457" y="117" text-anchor="middle" fill="#ccc" font-size="10" font-family="monospace">&#x3B1;&times;(x@A&#x1D57;)@B&#x1D57;</text>
  <text x="457" y="133" text-anchor="middle" fill="#888" font-size="10" font-family="monospace">B&times;D_out</text>

  <!-- plus sign -->
  <line x1="307" y1="97" x2="415" y2="97" stroke="#888" stroke-width="1.5" marker-end="url(#arr)"/>
  <text x="385" y="88" text-anchor="middle" fill="#ffaa44" font-size="20" font-family="monospace">+</text>

  <!-- output -->
  <line x1="497" y1="122" x2="535" y2="122" stroke="#888" stroke-width="1.5" marker-end="url(#arr)"/>
  <rect x="537" y="95" width="80" height="55" fill="#1a4a2a" stroke="#4aff88" stroke-width="1.5"/>
  <text x="577" y="117" text-anchor="middle" fill="#ccc" font-size="12" font-family="monospace">output</text>
  <text x="577" y="135" text-anchor="middle" fill="#888" font-size="10" font-family="monospace">B&times;D_out</text>

  <defs>
    <marker id="arr" markerWidth="6" markerHeight="6" refX="5" refY="3" orient="auto">
      <path d="M0,0 L6,3 L0,6 Z" fill="#888"/>
    </marker>
  </defs>
</svg>

<h2>Implementation Requirements</h2>
<ul>
  <li>Implement the <code>solve</code> function; do not change its signature.</li>
  <li>Do not use external libraries beyond those provided.</li>
  <li>Write the result into <code>output</code>.</li>
</ul>

<h2>Examples</h2>
<p>Example 1:</p>
<p>
\[
x = \begin{bmatrix} 1 & 0 & -1 & 2 \\ 0 & 1 & 1 & -1 \end{bmatrix},\quad
W = \begin{bmatrix} 1 & 0 & 0 & 0 \\ 0 & 1 & 0 & 0 \\ 0 & 0 & 1 & 0 \end{bmatrix},\quad
A = \begin{bmatrix} 1 & 0 & 0 & 0 \\ 0 & 1 & 0 & 0 \end{bmatrix},\quad
B = \begin{bmatrix} 1 & 0 \\ 0 & 1 \\ 0 & 0 \end{bmatrix}
\]
</p>
<p>With <code>lora_scale</code> = 0.5:</p>
<p>
\[
\text{output} = x W^T + 0.5 \cdot (x A^T) B^T
= \begin{bmatrix} 1 & 0 & -1 \\ 0 & 1 & 1 \end{bmatrix}
+ 0.5 \cdot \begin{bmatrix} 1 & 0 \\ 0 & 1 \end{bmatrix} \begin{bmatrix} 1 & 0 & 0 \\ 0 & 1 & 0 \end{bmatrix}
= \begin{bmatrix} 1.5 & 0 & -1 \\ 0 & 1.5 & 1 \end{bmatrix}
\]
</p>

<h2>Constraints</h2>
<ul>
  <li>1 &le; <code>batch</code> &le; 1,024</li>
  <li>1 &le; <code>d_in</code>, <code>d_out</code> &le; 8,192</li>
  <li>1 &le; <code>rank</code> &le; 256; <code>rank</code> &lt; min(<code>d_in</code>, <code>d_out</code>)</li>
  <li>All tensors are <code>float32</code> on GPU.</li>
  <li>Performance is measured with <code>batch</code> = 256, <code>d_in</code> = 4,096, <code>d_out</code> = 4,096, <code>rank</code> = 64.</li>
</ul>
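
As a quick sanity check of Example 1, the expected output can be reproduced with a few lines of plain PyTorch (a minimal sketch; CPU is fine at these shapes):

import torch

# Operands copied from Example 1
x = torch.tensor([[1.0, 0.0, -1.0, 2.0], [0.0, 1.0, 1.0, -1.0]])
W = torch.tensor([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0]])
A = torch.tensor([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]])
B = torch.tensor([[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]])

out = x @ W.T + 0.5 * (x @ A.T) @ B.T  # lora_scale = 0.5
print(out)  # [[1.5, 0.0, -1.0], [0.0, 1.5, 1.0]]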
Lines changed: 171 additions & 0 deletions
@@ -0,0 +1,171 @@
import ctypes
from typing import Any, Dict, List

import torch
from core.challenge_base import ChallengeBase


class Challenge(ChallengeBase):
    def __init__(self):
        super().__init__(
            name="LoRA Linear",
            atol=1e-04,
            rtol=1e-04,
            num_gpus=1,
            access_tier="free",
        )

    def reference_impl(
        self,
        x: torch.Tensor,
        W: torch.Tensor,
        A: torch.Tensor,
        B: torch.Tensor,
        output: torch.Tensor,
        batch: int,
        d_in: int,
        d_out: int,
        rank: int,
        lora_scale: float,
    ):
        assert x.shape == (batch, d_in)
        assert W.shape == (d_out, d_in)
        assert A.shape == (rank, d_in)
        assert B.shape == (d_out, rank)
        assert output.shape == (batch, d_out)
        assert x.dtype == W.dtype == A.dtype == B.dtype == output.dtype == torch.float32
        assert x.device.type == "cuda"
        assert W.device.type == "cuda"
        assert A.device.type == "cuda"
        assert B.device.type == "cuda"
        assert output.device.type == "cuda"

        # Base linear: output = x @ W^T
        base = torch.mm(x, W.t())

        # LoRA path: delta = lora_scale * (x @ A^T) @ B^T
        lora_hidden = torch.mm(x, A.t())  # (batch, rank)
        delta = torch.mm(lora_hidden, B.t())  # (batch, d_out)

        output.copy_(base + lora_scale * delta)

    def get_solve_signature(self) -> Dict[str, tuple]:
        return {
            "x": (ctypes.POINTER(ctypes.c_float), "in"),
            "W": (ctypes.POINTER(ctypes.c_float), "in"),
            "A": (ctypes.POINTER(ctypes.c_float), "in"),
            "B": (ctypes.POINTER(ctypes.c_float), "in"),
            "output": (ctypes.POINTER(ctypes.c_float), "out"),
            "batch": (ctypes.c_int, "in"),
            "d_in": (ctypes.c_int, "in"),
            "d_out": (ctypes.c_int, "in"),
            "rank": (ctypes.c_int, "in"),
            "lora_scale": (ctypes.c_float, "in"),
        }

    def _make_test_case(self, batch, d_in, d_out, rank, lora_scale=0.5, zero_x=False):
        dtype = torch.float32
        device = "cuda"
        if zero_x:
            x = torch.zeros(batch, d_in, device=device, dtype=dtype)
        else:
            x = torch.randn(batch, d_in, device=device, dtype=dtype)
        W = torch.randn(d_out, d_in, device=device, dtype=dtype) * 0.02
        A = torch.randn(rank, d_in, device=device, dtype=dtype) * 0.02
        # Random B so the LoRA delta is nonzero and the low-rank path is
        # actually exercised (zero-init B, the usual training init, would make
        # lora_scale * (x @ A^T) @ B^T vanish in every generated test case).
        B = torch.randn(d_out, rank, device=device, dtype=dtype) * 0.02
        output = torch.zeros(batch, d_out, device=device, dtype=dtype)
        return {
            "x": x,
            "W": W,
            "A": A,
            "B": B,
            "output": output,
            "batch": batch,
            "d_in": d_in,
            "d_out": d_out,
            "rank": rank,
            "lora_scale": lora_scale,
        }

    def generate_example_test(self) -> Dict[str, Any]:
        dtype = torch.float32
        device = "cuda"
        x = torch.tensor([[1.0, 0.0, -1.0, 2.0], [0.0, 1.0, 1.0, -1.0]], device=device, dtype=dtype)
        W = torch.tensor(
            [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0]],
            device=device,
            dtype=dtype,
        )
        A = torch.tensor([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]], device=device, dtype=dtype)
        B = torch.tensor(
            [[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]],
            device=device,
            dtype=dtype,
        )
        output = torch.zeros(2, 3, device=device, dtype=dtype)
        return {
            "x": x,
            "W": W,
            "A": A,
            "B": B,
            "output": output,
            "batch": 2,
            "d_in": 4,
            "d_out": 3,
            "rank": 2,
            "lora_scale": 0.5,
        }

    def generate_functional_test(self) -> List[Dict[str, Any]]:
        torch.manual_seed(42)
        tests = []

        # Edge case: batch=1, tiny dimensions
        tests.append(self._make_test_case(1, 4, 4, 1))

        # Edge case: zero input
        tests.append(self._make_test_case(2, 8, 8, 2, zero_x=True))

        # Edge case: rank=1 (minimum LoRA rank)
        tests.append(self._make_test_case(4, 16, 16, 1))

        # Power-of-2 dimensions
        tests.append(self._make_test_case(16, 64, 64, 8))

        # Power-of-2, non-square
        tests.append(self._make_test_case(32, 128, 64, 16))

        # Non-power-of-2 dimensions
        tests.append(self._make_test_case(30, 100, 100, 4))

        # Non-power-of-2, mixed
        tests.append(self._make_test_case(7, 255, 128, 8))

        # Realistic small: LLM feed-forward style
        tests.append(self._make_test_case(64, 512, 512, 16, lora_scale=0.125))

        # Negative inputs
        tests.append(
            {
                "x": torch.full((4, 32), -1.0, device="cuda", dtype=torch.float32),
                "W": torch.randn(32, 32, device="cuda", dtype=torch.float32) * 0.02,
                "A": torch.randn(8, 32, device="cuda", dtype=torch.float32) * 0.02,
                "B": torch.randn(32, 8, device="cuda", dtype=torch.float32) * 0.02,
                "output": torch.zeros(4, 32, device="cuda", dtype=torch.float32),
                "batch": 4,
                "d_in": 32,
                "d_out": 32,
                "rank": 8,
                "lora_scale": 1.0,
            }
        )

        # Larger realistic: transformer hidden size
        tests.append(self._make_test_case(128, 1024, 1024, 32, lora_scale=0.0625))

        return tests

    def generate_performance_test(self) -> Dict[str, Any]:
        torch.manual_seed(0)
        # LLaMA-style: d_in=d_out=4096, rank=64, batch=256
        return self._make_test_case(256, 4096, 4096, 64, lora_scale=0.015625)
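
Back-of-envelope cost of the performance configuration above (batch = 256, d_in = d_out = 4,096, rank = 64), counting 2mnk FLOPs per m-by-k times k-by-n GEMM:

\[
\text{base: } 2 \cdot 256 \cdot 4096 \cdot 4096 \approx 8.59\ \text{GFLOP},\qquad
\text{LoRA: } 2 \cdot 256 \cdot 64 \cdot (4096 + 4096) \approx 0.27\ \text{GFLOP}
\]

The low-rank path adds only about 3% extra work, so a good solution should land close to the runtime of the single base GEMM.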
Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
#include <cuda_runtime.h>

// x, W, A, B, output are device pointers
extern "C" void solve(const float* x, const float* W, const float* A, const float* B, float* output,
                      int batch, int d_in, int d_out, int rank, float lora_scale) {}
Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
import cutlass
import cutlass.cute as cute


# x, W, A, B, output are tensors on the GPU
@cute.jit
def solve(
    x: cute.Tensor,
    W: cute.Tensor,
    A: cute.Tensor,
    B: cute.Tensor,
    output: cute.Tensor,
    batch: cute.Int32,
    d_in: cute.Int32,
    d_out: cute.Int32,
    rank: cute.Int32,
    lora_scale: cute.Float32,
):
    pass
Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
import jax
import jax.numpy as jnp


# x, W, A, B are tensors on GPU
@jax.jit
def solve(
    x: jax.Array,
    W: jax.Array,
    A: jax.Array,
    B: jax.Array,
    batch: int,
    d_in: int,
    d_out: int,
    rank: int,
    lora_scale: float,
) -> jax.Array:
    # return output tensor directly
    pass
Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
from gpu.host import DeviceContext
from memory import UnsafePointer

# x, W, A, B, output are device pointers
@export
def solve(
    x: UnsafePointer[Float32],
    W: UnsafePointer[Float32],
    A: UnsafePointer[Float32],
    B: UnsafePointer[Float32],
    output: UnsafePointer[Float32],
    batch: Int32,
    d_in: Int32,
    d_out: Int32,
    rank: Int32,
    lora_scale: Float32,
):
    pass
Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
import torch


# x, W, A, B, output are tensors on the GPU
def solve(
    x: torch.Tensor,
    W: torch.Tensor,
    A: torch.Tensor,
    B: torch.Tensor,
    output: torch.Tensor,
    batch: int,
    d_in: int,
    d_out: int,
    rank: int,
    lora_scale: float,
):
    pass
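
For this PyTorch variant, a minimal unfused baseline (a sketch that simply mirrors the reference math, assuming in-place writes to output are acceptable, as they are in the reference implementation) could be:

def solve(x, W, A, B, output, batch, d_in, d_out, rank, lora_scale):
    torch.mm(x, W.t(), out=output)                  # output = x @ W^T
    hidden = torch.mm(x, A.t())                     # (batch, rank) low-rank activations
    output.addmm_(hidden, B.t(), alpha=lora_scale)  # output += lora_scale * hidden @ B^T

Because rank is much smaller than d_in and d_out, the two LoRA matmuls are cheap; the base GEMM dominates the runtime.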
Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
import torch
import triton
import triton.language as tl


# x, W, A, B, output are tensors on the GPU
def solve(
    x: torch.Tensor,
    W: torch.Tensor,
    A: torch.Tensor,
    B: torch.Tensor,
    output: torch.Tensor,
    batch: int,
    d_in: int,
    d_out: int,
    rank: int,
    lora_scale: float,
):
    pass
