Skip to content

Commit d90354b

Browse files
committed
finish2
1 parent 7a08cd2 commit d90354b

1 file changed

Lines changed: 38 additions & 26 deletions

File tree

docs/how_to/tutorials/mix_python_and_tvm_with_pymodule.py

Lines changed: 38 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -270,32 +270,44 @@ def main(
270270
# .. note::
271271
# ``R.call_py_func`` adds runtime overhead due to the Python-TVM boundary crossing.
272272
# Use it for prototyping or for ops that are not performance-critical.
273-
#
274-
# Here is an example using ``call_py_func`` inside a Relax function:
275-
#
276-
# .. code-block:: python
277-
#
278-
# @I.ir_module
279-
# class CallPyFuncModule(BasePyModule):
280-
# @I.pyfunc
281-
# def my_custom_op(self, x):
282-
# """Python fallback for a custom op."""
283-
# return torch.sigmoid(x) * x # SiLU / Swish activation
284-
#
285-
# @R.function
286-
# def main(x: R.Tensor((4,), "float32")) -> R.Tensor((4,), "float32"):
287-
# # Call the Python function from within Relax IR
288-
# result = R.call_py_func(
289-
# "my_custom_op", (x,), out_sinfo=R.Tensor((4,), "float32")
290-
# )
291-
# return result
292-
#
293-
# mod = CallPyFuncModule(device=tvm.cpu(0))
294-
# x = torch.tensor([1.0, -1.0, 2.0, -2.0])
295-
# result = mod.main(x)
296-
#
297-
# The VM executes the compiled Relax bytecode, and when it hits ``call_py_func``, it looks up
298-
# the registered Python function by name and calls it with DLPack-converted tensors.
273+
274+
if RUN_EXAMPLE:

    @I.ir_module
    class CallPyFuncModule(BasePyModule):
        """Module mixing compiled Relax code with Python (PyTorch) fallbacks."""

        @I.pyfunc
        def torch_relu(self, x):
            """Python fallback: PyTorch ReLU."""
            return torch.relu(x)

        @I.pyfunc
        def torch_softmax(self, x, dim=0):
            """Python fallback: PyTorch softmax."""
            return torch.softmax(x, dim=dim)

        @R.function
        def main(x: R.Tensor((10,), "float32")) -> R.Tensor((10,), "float32"):
            # At runtime the VM crosses back into Python for each call below.
            relu_result = R.call_py_func(
                "torch_relu", (x,), out_sinfo=R.Tensor((10,), "float32")
            )
            result = R.call_py_func(
                "torch_softmax", (relu_result,), out_sinfo=R.Tensor((10,), "float32")
            )
            return result

    # Instantiate the module on CPU and build a random test input.
    mod = CallPyFuncModule(device=tvm.cpu(0))
    inp = torch.randn(10, dtype=torch.float32)

    # The registered Python fallbacks are also directly invocable by name.
    relu_out = mod.call_py_func("torch_relu", [inp])
    softmax_out = mod.call_py_func("torch_softmax", [relu_out])

    # Reference result computed purely in PyTorch for comparison.
    reference = torch.relu(inp)
    reference = torch.softmax(reference, dim=0)

    print("R.call_py_func result:", softmax_out)
    print("PyTorch expected: ", reference)
    assert torch.allclose(torch.tensor(softmax_out.numpy()), reference, atol=1e-5)
299311

300312

301313
######################################################################

0 commit comments

Comments
 (0)