@@ -270,32 +270,44 @@ def main(
270270# .. note::
271271# ``R.call_py_func`` adds runtime overhead due to the Python-TVM boundary crossing.
272272# Use it for prototyping or for ops that are not performance-critical.
273- #
274- # Here is an example using ``call_py_func`` inside a Relax function:
275- #
276- # .. code-block:: python
277- #
278- # @I.ir_module
279- # class CallPyFuncModule(BasePyModule):
280- # @I.pyfunc
281- # def my_custom_op(self, x):
282- # """Python fallback for a custom op."""
283- # return torch.sigmoid(x) * x # SiLU / Swish activation
284- #
285- # @R.function
286- # def main(x: R.Tensor((4,), "float32")) -> R.Tensor((4,), "float32"):
287- # # Call the Python function from within Relax IR
288- # result = R.call_py_func(
289- # "my_custom_op", (x,), out_sinfo=R.Tensor((4,), "float32")
290- # )
291- # return result
292- #
293- # mod = CallPyFuncModule(device=tvm.cpu(0))
294- # x = torch.tensor([1.0, -1.0, 2.0, -2.0])
295- # result = mod.main(x)
296- #
297- # The VM executes the compiled Relax bytecode, and when it hits ``call_py_func``, it looks up
298- # the registered Python function by name and calls it with DLPack-converted tensors.
273+
274+ if RUN_EXAMPLE :
275+
276+ @I .ir_module
277+ class CallPyFuncModule (BasePyModule ):
278+ @I .pyfunc
279+ def torch_relu (self , x ):
280+ """Python fallback: PyTorch ReLU."""
281+ return torch .relu (x )
282+
283+ @I .pyfunc
284+ def torch_softmax (self , x , dim = 0 ):
285+ """Python fallback: PyTorch softmax."""
286+ return torch .softmax (x , dim = dim )
287+
288+ @R .function
289+ def main (x : R .Tensor ((10 ,), "float32" )) -> R .Tensor ((10 ,), "float32" ):
290+ # The VM calls back into Python for these ops at runtime
291+ relu_result = R .call_py_func (
292+ "torch_relu" , (x ,), out_sinfo = R .Tensor ((10 ,), "float32" )
293+ )
294+ result = R .call_py_func (
295+ "torch_softmax" , (relu_result ,), out_sinfo = R .Tensor ((10 ,), "float32" )
296+ )
297+ return result
298+
299+ mod = CallPyFuncModule (device = tvm .cpu (0 ))
300+
301+ x = torch .randn (10 , dtype = torch .float32 )
302+
303+ # call_py_func can be called directly from Python as well
304+ relu_result = mod .call_py_func ("torch_relu" , [x ])
305+ result = mod .call_py_func ("torch_softmax" , [relu_result ])
306+
307+ expected = torch .softmax (torch .relu (x ), dim = 0 )
308+ print ("R.call_py_func result:" , result )
309+ print ("PyTorch expected: " , expected )
310+ assert torch .allclose (torch .tensor (result .numpy ()), expected , atol = 1e-5 )
299311
300312
301313######################################################################
0 commit comments