File tree Expand file tree Collapse file tree
folx/experimental/pallas/attention Expand file tree Collapse file tree Original file line number Diff line number Diff line change 55import jax
66import jax .numpy as jnp
77from jax .experimental import pallas as pl
8- from jax .experimental .pallas import gpu as plgpu
8+
9+ try :
10+ from jax .experimental .pallas import triton as plgpu
11+ except ImportError :
12+ from jax .experimental .pallas import gpu as plgpu
913
1014from .mhsa import mhsa_kernel , reference_mhsa_kernel
1115from .mhsea import mhsea_kernel , reference_mhsea_kernel
Original file line number Diff line number Diff line change 55import jax
66import jax .numpy as jnp
77from jax .experimental import pallas as pl
8- from jax .experimental .pallas import gpu as plgpu
8+
9+ try :
10+ from jax .experimental .pallas import triton as plgpu
11+ except ImportError :
12+ from jax .experimental .pallas import gpu as plgpu
913
1014from folx import forward_laplacian
1115from folx .api import FwdJacobian , FwdLaplArray
Original file line number Diff line number Diff line change 55import jax
66import jax .numpy as jnp
77from jax .experimental import pallas as pl
8- from jax .experimental .pallas import gpu as plgpu
8+
9+ try :
10+ from jax .experimental .pallas import triton as plgpu
11+ except ImportError :
12+ from jax .experimental .pallas import gpu as plgpu
913
1014from .utils import (
1115 big_number ,
Original file line number Diff line number Diff line change 55import jax
66import jax .numpy as jnp
77from jax .experimental import pallas as pl
8- from jax .experimental .pallas import gpu as plgpu
8+
9+ try :
10+ from jax .experimental .pallas import triton as plgpu
11+ except ImportError :
12+ from jax .experimental .pallas import gpu as plgpu
913
1014from .utils import (
1115 big_number ,
You can’t perform that action at this time.
0 commit comments