Skip to content

Commit 0189332

Browse files
committed
Mext try
1 parent b3cb776 commit 0189332

3 files changed

Lines changed: 63 additions & 20 deletions

File tree

pyrecest/distributions/abstract_manifold_specific_distribution.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,8 @@ def sample_metropolis_hastings_jax(
156156
n: number of samples to return (after burn-in and thinning)
157157
"""
158158
import jax.numpy as _jnp
159-
from jax import _random, _lax
159+
from jax import lax as _lax
160+
from jax import random as _random
160161

161162

162163
start_point = _jnp.asarray(start_point)

pyrecest/distributions/hypersphere_subset/abstract_hyperhemispherical_distribution.py

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -69,28 +69,39 @@ def sample_metropolis_hastings(
6969
def proposal(_):
7070
return HyperhemisphericalUniformDistribution(self.dim).sample(1)
7171
else:
72+
# JAX backend: proposal(key, x) -> x_prop
7273
import jax as _jax
7374
import jax.numpy as _jnp
74-
def proposal(key: _jax.Array, _) -> tuple[_jax.Array, _jnp.ndarray]:
75-
# Sample on full sphere
76-
if self.dim == 2:
77-
key, key_phi = _jnp.random.split(key)
78-
key, key_sz = _jnp.random.split(key)
79-
80-
phi = 2.0 * pi * _jnp.random.uniform(key_phi, shape=(1,))
81-
sz = 2.0 * _jnp.random.uniform(key_sz, shape=(1,)) - 1.0
82-
r = sqrt(1.0 - sz**2)
8375

84-
s = stack([r * cos(phi), r * sin(phi), sz], axis=1)
76+
def proposal(key, _):
77+
"""JAX independence proposal: uniform on upper hemisphere."""
78+
if self.dim == 2:
79+
# Explicit S² sampling
80+
key, key_phi = _jax.random.split(key)
81+
key, key_sz = _jax.random.split(key)
82+
83+
phi = 2.0 * _jnp.pi * _jax.random.uniform(key_phi, shape=(1,))
84+
sz = 2.0 * _jax.random.uniform(key_sz, shape=(1,)) - 1.0
85+
r = _jnp.sqrt(1.0 - sz**2)
86+
87+
# Shape (1, 3)
88+
s = _jnp.stack(
89+
[r * _jnp.cos(phi), r * _jnp.sin(phi), sz],
90+
axis=1,
91+
)
8592
else:
86-
key, subkey = random.split(key)
87-
samples_unnorm = random.normal(subkey, shape=(1, self.dim + 1))
88-
norms = linalg.norm(samples_unnorm, axis=1, keepdims=True)
93+
# General S^d: sample N(0, I) in R^{d+1} and normalize
94+
key, subkey = _jax.random.split(key)
95+
samples_unnorm = _jax.random.normal(subkey, shape=(1, self.dim + 1))
96+
norms = _jnp.linalg.norm(samples_unnorm, axis=1, keepdims=True)
8997
s = samples_unnorm / norms
9098

91-
# To upper hemisphere
92-
s = (1 - 2 * (s[-1, :] < 0)) * s
93-
return key, s
99+
# Project to upper hemisphere: last coordinate >= 0
100+
# s shape: (1, dim+1); last coord is s[..., -1:]
101+
sign = _jnp.where(s[..., -1:] < 0.0, -1.0, 1.0)
102+
s = sign * s
103+
104+
return s
94105

95106

96107
if start_point is None:

pyrecest/distributions/hypersphere_subset/abstract_hyperspherical_distribution.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@
2121
sin,
2222
vstack,
2323
zeros,
24+
sqrt,
25+
stack,
26+
linalg,
2427
)
2528
from scipy.optimize import minimize
2629

@@ -74,9 +77,37 @@ def sample_metropolis_hastings(
7477
HypersphericalUniformDistribution,
7578
)
7679

77-
def proposal(_):
78-
return HypersphericalUniformDistribution(self.dim).sample(1)
79-
80+
if pyrecest.backend.__backend_name__ in ("numpy", "pytorch"):
81+
def proposal(_):
82+
return HypersphericalUniformDistribution(self.dim).sample(1)
83+
else:
84+
import jax as _jax
85+
import jax.numpy as _jnp
86+
def proposal(key, _):
87+
"""JAX independence proposal: uniform on hypersphere."""
88+
if self.dim == 2:
89+
# Explicit S² sampling
90+
key, key_phi = _jax.random.split(key)
91+
key, key_sz = _jax.random.split(key)
92+
93+
phi = 2.0 * _jnp.pi * _jax.random.uniform(key_phi, shape=(1,))
94+
sz = 2.0 * _jax.random.uniform(key_sz, shape=(1,)) - 1.0
95+
r = _jnp.sqrt(1.0 - sz**2)
96+
97+
# Shape (1, 3)
98+
s = _jnp.stack(
99+
[r * _jnp.cos(phi), r * _jnp.sin(phi), sz],
100+
axis=1,
101+
)
102+
else:
103+
# General S^d: sample N(0, I) in R^{d+1} and normalize
104+
key, subkey = _jax.random.split(key)
105+
samples_unnorm = _jax.random.normal(subkey, shape=(1, self.dim + 1))
106+
norms = _jnp.linalg.norm(samples_unnorm, axis=1, keepdims=True)
107+
s = samples_unnorm / norms
108+
109+
return s
110+
80111
if start_point is None:
81112
start_point = HypersphericalUniformDistribution(self.dim).sample(1)
82113
# Call the sample_metropolis_hastings method of AbstractDistribution

0 commit comments

Comments
 (0)