{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/ratulb/mojo_programming/blob/main/gpu_puzzles/broadcast_add_layout.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
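    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "GPU puzzle: broadcast addition with `LayoutTensor`. A 1×3 row vector `a` and a 3×1 column vector `b` are combined into a 3×3 matrix with `out[i, j] = a[i] + b[j]`, using one GPU thread per output element. The host fills the inputs, computes the same result on the CPU, and asserts that the two agree."
      ]
    },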
    {
      "cell_type": "code",
      "source": [
16+ " !curl -ssL https://magic.modular.com/ | bash"
      ],
      "metadata": {
        "id": "oghVhc-plDnV"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
27+ " import os\n " ,
28+ " os.environ['PATH'] += ':/root/.modular/bin'"
      ],
      "metadata": {
        "id": "bo0LqVellMRb"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
39+ " !magic init gpu_puzzles --format mojoproject"
      ],
      "metadata": {
        "id": "orDFbNYOlmVj"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
50+ " %cd gpu_puzzles/"
      ],
      "metadata": {
        "id": "I_II8JVmluuj"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "%%writefile broadcast_add_layout.mojo\n",
        "\n",
63+ " ### Broadcast Addiotion\n " ,
64+ " ### Add 2 vectors\n " ,
65+ " \n " ,
66+ " from gpu import thread_idx\n " ,
67+ " from gpu.host import DeviceContext\n " ,
68+ " from layout import Layout, LayoutTensor\n " ,
69+ " from testing import assert_equal\n " ,
70+ " \n " ,
71+ " \n " ,
72+ " alias SIZE = 3\n " ,
73+ " alias dtype = DType.float32\n " ,
74+ " alias BLOCKS_PER_GRID = 1\n " ,
75+ " alias THREADS_PER_BLOCK = (3, 3)\n " ,
76+ " \n " ,
77+ " alias layout_out = Layout.row_major(SIZE, SIZE)\n " ,
78+ " alias layout_a = Layout.row_major(1, SIZE)\n " ,
79+ " alias layout_b = Layout.row_major(SIZE, 1)\n " ,
80+ " \n " ,
81+ " \n " ,
82+ " \n " ,
83+ " fn broadcast_add_layout[layout_out: Layout, layout_a: Layout, layout_b: Layout](\n " ,
84+ " out: LayoutTensor[mut=True, dtype, layout_out],\n " ,
85+ " a: LayoutTensor[mut=True, dtype, layout_a],\n " ,
86+ " b: LayoutTensor[mut=True, dtype, layout_b],\n " ,
87+ " ):\n " ,
88+ " row = thread_idx.y\n " ,
89+ " col = thread_idx.x\n " ,
90+ " if row < SIZE and col < SIZE:\n " ,
91+ " out[row, col] = a[0, row] + b[col, 0]\n " ,
92+ " \n " ,
93+ " \n " ,
94+ " fn main() raises:\n " ,
95+ " with DeviceContext() as ctx:\n " ,
96+ " out_buffer = ctx.enqueue_create_buffer[dtype](SIZE * SIZE).enqueue_fill(0)\n " ,
97+ " expected_buffer = ctx.enqueue_create_host_buffer[dtype](\n " ,
98+ " SIZE * SIZE\n " ,
99+ " ).enqueue_fill(0)\n " ,
100+ " a_buffer = ctx.enqueue_create_buffer[dtype](SIZE).enqueue_fill(0)\n " ,
101+ " b_buffer = ctx.enqueue_create_buffer[dtype](SIZE).enqueue_fill(0)\n " ,
102+ " \n " ,
103+ " with a_buffer.map_to_host() as a_buffer_host, b_buffer.map_to_host() as b_buffer_host:\n " ,
104+ " for i in range(SIZE):\n " ,
105+ " a_buffer_host[i] = i\n " ,
106+ " b_buffer_host[i] = i\n " ,
107+ " print(a_buffer)\n " ,
108+ " print(b_buffer)\n " ,
109+ " for i in range(SIZE):\n " ,
110+ " for j in range(SIZE):\n " ,
111+ " expected_buffer[i * SIZE + j] = a_buffer_host[i] + b_buffer_host[j]\n " ,
112+ " print(expected_buffer)\n " ,
113+ " \n " ,
114+ " out = LayoutTensor[mut=True, dtype, layout_out](out_buffer.unsafe_ptr())\n " ,
115+ " a = LayoutTensor[mut=True, dtype, layout_a](a_buffer.unsafe_ptr())\n " ,
116+ " b = LayoutTensor[mut=True, dtype, layout_b](b_buffer.unsafe_ptr())\n " ,
117+ " expected = LayoutTensor[mut=True, dtype, layout_out](expected_buffer.unsafe_ptr())\n " ,
118+ " \n " ,
119+ " ctx.enqueue_function[broadcast_add_layout[layout_out, layout_a, layout_b]](\n " ,
120+ " out,\n " ,
121+ " a,\n " ,
122+ " b,\n " ,
123+ " SIZE,\n " ,
124+ " grid_dim=BLOCKS_PER_GRID,\n " ,
125+ " block_dim=THREADS_PER_BLOCK,\n " ,
126+ " )\n " ,
127+ " ctx.synchronize()\n " ,
128+ " \n " ,
129+ " with out_buffer.map_to_host() as out_buffer_host:\n " ,
130+ " print(out_buffer_host)\n " ,
131+ " for i in range(SIZE):\n " ,
132+ " for j in range(SIZE):\n " ,
133+ " assert_equal(out_buffer_host[i * SIZE + j], expected_buffer[i * SIZE + j])\n "
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "r8TtOuGcmo7L",
        "outputId": "80443e36-bc03-42ef-a22a-4d7472efc586"
      },
      "execution_count": 14,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Overwriting broadcast_add_layout.mojo\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
156+ " !magic run mojo broadcast_add_layout.mojo"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "2heIJSH7lxPj",
        "outputId": "a963ec38-d02b-4869-c069-f120c4b370ad"
      },
      "execution_count": 15,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
171+ " \u001b [32m⠁\u001b [0m \r \u001b [2K\u001b [32m⠁\u001b [0m activating environment \r \u001b [2K\u001b [32m⠁\u001b [0m activating environment \r \u001b [2KDeviceBuffer([0.0, 0.0, 0.0])\n " ,
172+ " DeviceBuffer([0.0, 0.0, 0.0])\n " ,
173+ " HostBuffer([0.0, 1.0, 2.0, 1.0, 2.0, 3.0, 2.0, 3.0, 4.0])\n " ,
174+ " HostBuffer([0.0, 1.0, 2.0, 1.0, 2.0, 3.0, 2.0, 3.0, 4.0])\n "
175+ ]
176+ }
177+ ]
178+ },
179+ {
180+ "cell_type" : " code" ,
181+ "source" : [
182+ " !magic run mojo format broadcast_add_layout.mojo"
      ],
      "metadata": {
        "id": "2KeEPNK2GYKV"
      },
      "execution_count": null,
      "outputs": []
    }
  ],
  "metadata": {
    "colab": {
      "name": "Welcome To Colab",
      "provenance": [],
      "gpuType": "T4",
      "include_colab_link": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "accelerator": "GPU"
  },
  "nbformat": 4,
  "nbformat_minor": 0
}