Sampling with NUTS
Here we demonstrate how to sample with the No-U-Turn Sampler (NUTS). We can use it because JAX gives us access to the gradient of the likelihood.
Suppose we have a mock observation in which the thermal electron density field \(n_e\) and the cosmic-ray electron distribution \(C\) are both known. The magnetic field \(\mathbf{B}\) follows the LSA model, but we do not know its parameters.
We therefore run NUTS to obtain the posterior distribution of these parameters.
In this example we use blackjax, but other packages in the JAX ecosystem, such as NumPyro and PyMC3, can be used as well.
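NUTS is an adaptive variant of Hamiltonian Monte Carlo (HMC). Both simulate Hamiltonian dynamics with a leapfrog integrator, and that integrator is where the gradient of the log density enters. Below is a minimal numpy sketch of one plain HMC step with a fixed number of leapfrog steps. It uses a toy 2D Gaussian target in place of the synchrotron likelihood. It is not NUTS, since it has no U-turn criterion or tree building, and it is not blackjax's implementation. The names logdensity, leapfrog and hmc_step are illustrative only.

```python
import numpy as np

def logdensity(x):
    # Toy target: standard 2D Gaussian, log p(x) = -|x|^2 / 2 up to a constant
    return -0.5 * np.dot(x, x)

def grad_logdensity(x):
    # Analytic gradient of the toy log density (JAX would provide this automatically)
    return -x

def leapfrog(x, p, step_size, num_steps):
    # Leapfrog integration of Hamilton's equations with a unit mass matrix
    p = p + 0.5 * step_size * grad_logdensity(x)
    for i in range(num_steps):
        x = x + step_size * p
        if i < num_steps - 1:
            p = p + step_size * grad_logdensity(x)
    p = p + 0.5 * step_size * grad_logdensity(x)
    return x, p

def hmc_step(rng, x, step_size=0.2, num_steps=10):
    p0 = rng.standard_normal(x.shape)  # resample the momentum
    x_new, p_new = leapfrog(x, p0, step_size, num_steps)
    # Metropolis correction on the total energy H = -log p(x) + |p|^2 / 2
    h_old = -logdensity(x) + 0.5 * np.dot(p0, p0)
    h_new = -logdensity(x_new) + 0.5 * np.dot(p_new, p_new)
    if np.log(rng.uniform()) < h_old - h_new:
        return x_new, True
    return x, False

rng = np.random.default_rng(0)
x = np.zeros(2)
samples, accepted = [], 0
for _ in range(5000):
    x, acc = hmc_step(rng, x)
    samples.append(x)
    accepted += acc
samples = np.array(samples)
print(samples.mean(axis=0), samples.std(axis=0), accepted / 5000)
```

The sample mean should be close to 0 and the standard deviation close to 1. NUTS removes the need to hand-pick num_steps by growing the trajectory until it starts to turn back on itself.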
[1]:
#%env XLA_PYTHON_CLIENT_PREALLOCATE=false
%env XLA_PYTHON_CLIENT_ALLOCATOR=platform
env: XLA_PYTHON_CLIENT_ALLOCATOR=platform
[2]:
import os
import sys
import jax,blackjax
jax.config.update("jax_enable_x64", True)
#sys.path.append('../synax/')
import synax,importlib
import jax.numpy as jnp
import interpax
import healpy as hp
import numpy as np
import matplotlib.pyplot as plt
from functools import partial
import scipy.constants as const
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
[3]:
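# For debugging: reload synax and its submodules after editing the source
def reload_package(package):
    importlib.reload(package)
    for attribute_name in dir(package):
        attribute = getattr(package, attribute_name)
        # Reload attributes that are themselves modules (e.g. synax.bfield)
        if type(attribute) == type(package):
            importlib.reload(attribute)
reload_package(synax)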
Construct mock observation
We construct a mock observation by adding Gaussian noise with a standard deviation of 0.001 K (1 mK) to the simulated signal. The simulated signal is provided in the examples/ directory of the repository.
[4]:
obs_maps = np.load('Sim_lsa.npy')
np.random.seed(42)
obs_Q = obs_maps[0] + np.random.randn(obs_maps[0].shape[0])*0.001
obs_U = obs_maps[1] + np.random.randn(obs_maps[0].shape[0])*0.001
hp.mollview(obs_Q,norm='hist',cmap='coolwarm')
Set up coordinates and fields
Some quantities stay constant during inference: the coordinates, the \(n_e\) field and the \(C\) field. We precompute them once to save time during inference.
First, let's generate the coordinates.
[5]:
nside = 64
num_int_points = 512
poss,dls,nhats = synax.coords.get_healpix_positions(nside=nside,num_int_points=num_int_points)
plt.scatter(poss[0,::10,500],poss[1,::10,500],alpha=0.05)
poss.shape
[5]:
(3, 49152, 512)
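The middle dimension of poss is the number of HEALPix pixels, \(N_{\rm pix} = 12\,N_{\rm side}^2\), which is 49152 for nside = 64. healpy's hp.nside2npix(64) returns the same number. We read the shape as (3 Cartesian components, pixels, integration points per line of sight). The interpretation of the first and last axes is our assumption, based on the arguments passed to get_healpix_positions. The arithmetic also shows why memory matters here: assuming float64 storage, poss alone takes about 0.6 GB.

```python
nside = 64
num_int_points = 512
npix = 12 * nside**2                        # number of HEALPix pixels at this resolution
expected_shape = (3, npix, num_int_points)  # assumed (xyz, pixel, integration point)
num_bytes = 3 * npix * num_int_points * 8   # memory footprint if stored as float64
print(npix, expected_shape, num_bytes / 1e9)
```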
Next, we generate the \(C\) field from the true model (the WMAP model).
[6]:
C_generator = synax.cfield.C_WMAP(poss)
C_field = C_generator.C_field()
Our \(n_e\) field is interpolated from a field defined on a regular 3D grid. That gridded field is also provided in examples/.
[7]:
nx,ny,nz = 256,256,64
xs,step = jnp.linspace(-20,20,nx,endpoint=False,retstep=True)
xs = xs + step*0.5
ys,step = jnp.linspace(-20,20,ny,endpoint=False,retstep=True)
ys = ys + step*0.5
zs,step = jnp.linspace(-5,5,nz,endpoint=False,retstep=True)
zs = zs + step*0.5
coords = jnp.meshgrid(xs,ys,zs,indexing='ij')
coords[0].shape
[7]:
(256, 256, 64)
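Shifting the endpoint=False grid by half a step puts the coordinates at the centres of the cells of a regular box. In x, for example, there are 256 cells of width 40/256 = 0.15625 spanning [-20, 20). A quick check with numpy, whose linspace behaves the same as jax.numpy.linspace for these arguments:

```python
import numpy as np

def cell_centres(lo, hi, n):
    # Left edges of n equal cells on [lo, hi), shifted by half a cell width
    edges, step = np.linspace(lo, hi, n, endpoint=False, retstep=True)
    return edges + 0.5 * step, step

xs, dx = cell_centres(-20, 20, 256)
zs, dz = cell_centres(-5, 5, 64)
print(dx, xs[0], xs[-1])
print(dz, zs[0], zs[-1])
```

Both axes therefore have the same resolution of 0.15625 (in the grid's length units). The x centres run from -19.921875 to 19.921875, and the z centres from -4.921875 to 4.921875.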
[8]:
tereg = np.load('te.npy')
[9]:
plt.imshow(tereg[:,:,32],vmax=0.2)
plt.colorbar()
[10]:
TE_generator = synax.tefield.TE_grid(poss,(xs,ys,zs))
[11]:
%%time
TE_field = TE_generator.TE_field(tereg)
TE_field.shape
CPU times: user 3.96 s, sys: 9.45 s, total: 13.4 s
Wall time: 13.4 s
[11]:
(49152, 512)
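TE_field holds the gridded \(n_e\) evaluated at every line-of-sight point, one value per (pixel, integration point), which gives the shape (49152, 512). The interpolation scheme inside synax.tefield.TE_grid is not shown in this notebook. The sketch below uses scipy's RegularGridInterpolator to perform the same kind of lookup on a cell-centred grid. It assumes linear interpolation and zero density outside the box; both are our assumptions, not synax's documented behaviour. The small grid, the linear test field and the random points are illustrative only.

```python
import numpy as np
from scipy.interpolate import RegularGridInterpolator

def centres(lo, hi, n):
    # Cell-centred coordinates, built the same way as xs, ys, zs above
    edges, step = np.linspace(lo, hi, n, endpoint=False, retstep=True)
    return edges + 0.5 * step

xs, ys, zs = centres(-20, 20, 32), centres(-20, 20, 32), centres(-5, 5, 8)
X, Y, Z = np.meshgrid(xs, ys, zs, indexing="ij")
field = 0.01 * X + 0.02 * Y - 0.03 * Z + 0.1  # linear field: linear interpolation is exact

interp = RegularGridInterpolator((xs, ys, zs), field, method="linear",
                                 bounds_error=False, fill_value=0.0)

# Hypothetical line-of-sight sample points inside the grid, shape (num_points, 3)
rng = np.random.default_rng(1)
pts = np.column_stack([rng.uniform(-19, 19, 1000),
                       rng.uniform(-19, 19, 1000),
                       rng.uniform(-4, 4, 1000)])
vals = interp(pts)
outside = interp([[25.0, 0.0, 0.0]])  # a point outside the box gets fill_value
```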
Carry out sampling
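Now let's sample. The B field generator does not change during inference either, so we define it in advance. Here lsa_params holds the true LSA parameters of the mock observation, with angles converted from degrees to radians.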
[12]:
lsa_params = {"b0":1.2,
"psi0":27.0*np.pi/180,
"psi1":0.9*np.pi/180,
"chi0":25.0*np.pi/180}
B_generator = synax.bfield.B_lsa(poss)
B_field = B_generator.B_field(lsa_params)
[13]:
rng_key = jax.random.key(42)
simer = synax.synax.Synax(sim_I = False)
freq = 2.4
spectral_index = 3.
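Here we construct the likelihood. For a given parameter set \(\theta\) we first simulate the Stokes \(Q\) and \(U\) maps. With noise level \(\sigma = 0.001\) K, the log-likelihood as implemented below reads
\[\ln\mathcal{L}(\theta) = -\sum_p\left(\frac{Q_p(\theta)-Q^{\rm obs}_p}{2\sigma}\right)^2-\sum_p\left(\frac{U_p(\theta)-U^{\rm obs}_p}{2\sigma}\right)^2+\text{const},\]
where \(p\) runs over the HEALPix pixels. We drop the constant because it does not depend on \(\theta\) and therefore does not affect sampling. No prior term is added, which corresponds to a flat prior on the parameters.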
[14]:
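@jax.jit
def lsa_model(lsa_params):
    # Simulate Q and U for the given LSA parameters; C_field, TE_field, nhats and dls are precomputed
    B_field = B_generator.B_field(lsa_params)
    sync = simer.sim(freq,B_field,C_field,TE_field,nhats,dls,spectral_index)
    return sync['Q'],sync['U']
@jax.jit
def logdensity_fn(lsa_params):
    # Log-likelihood of the mock Q and U maps (noise 0.001 K), constant dropped; no prior term
    Sync_Q,Sync_U = lsa_model(lsa_params)
    return -1*jnp.sum(((Sync_Q-obs_Q)/0.001/2)**2) - jnp.sum(((Sync_U-obs_U)/0.001/2)**2)
# Value and gradient with respect to the parameter dictionary
value_grad = jax.value_and_grad(logdensity_fn)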
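For reference, independent Gaussian noise with standard deviation \(\sigma\) in each pixel gives \(\ln\mathcal{L} = -\sum_p (d_p - m_p)^2/(2\sigma^2) + \text{const}\). logdensity_fn divides by \(2\sigma\) before squaring, so it uses \((d_p - m_p)^2/(4\sigma^2)\). That is the Gaussian log-likelihood for a noise standard deviation of \(\sqrt{2}\,\sigma\), so in the near-Gaussian regime posterior widths come out roughly \(\sqrt{2}\) times wider than for \(\sigma = 1\) mK. The numpy sketch below checks this equivalence on pure-noise residuals, which stand in for the synax maps. With this scaling, pure noise gives about \(-N/4\) for \(N = 2 \times 49152\) residuals, i.e. about -24576. This is consistent with the -24628 printed in the next cell's output at the true parameters, assuming the simulated signal matches the model there.

```python
import numpy as np

def gaussian_loglike(residual, sigma):
    # Standard Gaussian log-likelihood with the constant dropped
    return -np.sum(residual**2) / (2 * sigma**2)

def notebook_form(residual, sigma):
    # Same expression as logdensity_fn: divide by sigma, then by 2, then square
    return -np.sum((residual / sigma / 2) ** 2)

sigma = 0.001
rng = np.random.default_rng(0)
residual = rng.normal(0.0, sigma, size=2 * 49152)  # pure-noise Q and U residuals

r_notebook = notebook_form(residual, sigma)
print(r_notebook / gaussian_loglike(residual, sigma))  # about 0.5
print(np.isclose(r_notebook, gaussian_loglike(residual, np.sqrt(2) * sigma)))
print(r_notebook / residual.size)  # close to -0.25
```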
We start the sampler from an arbitrarily chosen initial position.
[33]:
initial_position = {"b0":1.2,
"psi0":45.0*np.pi/180,
"psi1":45*np.pi/180,
"chi0":45.0*np.pi/180}
%time value_grad(initial_position)
%time logdensity_fn(lsa_params)
CPU times: user 54.3 ms, sys: 43.7 ms, total: 97.9 ms
Wall time: 94.5 ms
CPU times: user 13 ms, sys: 4.19 ms, total: 17.2 ms
Wall time: 15.8 ms
[33]:
Array(-24627.87520317, dtype=float64)
Next we set up blackjax. blackjax.window_adaptation tunes the NUTS parameters (the step size and the inverse mass matrix). We run the adaptation for 200 steps.
[16]:
warmup = blackjax.window_adaptation(blackjax.nuts, logdensity_fn)
rng_key, warmup_key, sample_key = jax.random.split(rng_key, 3)
%time (state, parameters), _ = warmup.run(warmup_key, initial_position, num_steps=200)
CPU times: user 3min 42s, sys: 28.4 s, total: 4min 11s
Wall time: 4min 5s
[17]:
state, parameters
[17]:
(HMCState(position={'b0': Array(1.20070619, dtype=float64, weak_type=True), 'chi0': Array(0.43587975, dtype=float64, weak_type=True), 'psi0': Array(0.47015432, dtype=float64, weak_type=True), 'psi1': Array(0.01567667, dtype=float64, weak_type=True)}, logdensity=Array(-24631.68532029, dtype=float64), logdensity_grad={'b0': Array(-6944.07374439, dtype=float64, weak_type=True), 'chi0': Array(469.31222052, dtype=float64, weak_type=True), 'psi0': Array(6582.81530091, dtype=float64, weak_type=True), 'psi1': Array(-1910.32828168, dtype=float64, weak_type=True)}),
{'step_size': Array(0.02905872, dtype=float64, weak_type=True),
'inverse_mass_matrix': Array([9.09935240e-05, 9.21521877e-05, 9.10850627e-05, 9.25374521e-05], dtype=float64)})
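The adapted inverse_mass_matrix is diagonal, with one entry per parameter in the order of the flattened parameter dictionary (JAX sorts dictionary keys: b0, chi0, psi0, psi1). In Stan-style window adaptation, each diagonal entry is the per-parameter sample variance of the draws in a slow adaptation window, shrunk towards a small constant for regularization. Below is a minimal numpy sketch of that estimator. The shrinkage constants (5 and 1e-3) are Stan's, and we assume, without having verified it here, that blackjax uses the same rule. Because the posterior variances in this problem are far below 1e-3, such a floor term would dominate, which may explain why the four adapted entries above are nearly equal.

```python
import numpy as np

def regularized_diag_inverse_mass(window_draws, shrink_n=5.0, floor=1e-3):
    # window_draws: array of shape (n_draws, n_params) collected in one adaptation window
    n = window_draws.shape[0]
    var = np.var(window_draws, axis=0, ddof=1)  # per-parameter sample variance
    # Shrink towards `floor`; the weight on the data grows with the number of draws
    return (n / (n + shrink_n)) * var + floor * (shrink_n / (n + shrink_n))

rng = np.random.default_rng(0)
true_std = np.array([1.0, 2.0, 3.0, 4.0])           # illustrative posterior widths
draws = rng.normal(0.0, true_std, size=(2000, 4))
inv_mass = regularized_diag_inverse_mass(draws)      # close to true_std**2

# Very narrow draws: the floor term 1e-3 * 5 / 15 dominates every entry
inv_mass_narrow = regularized_diag_inverse_mass(1e-6 * draws[:10])
print(inv_mass, inv_mass_narrow)
```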
The sampling loop below follows the inference loop in the blackjax documentation. We draw 500 samples.
[24]:
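def inference_loop(rng_key, kernel, initial_state, num_samples):
    @jax.jit
    def one_step(state, rng_key):
        # One NUTS transition; the new state is both the carry and the recorded output
        state, _ = kernel(rng_key, state)
        return state, state
    keys = jax.random.split(rng_key, num_samples)
    # scan threads the state through num_samples steps and stacks the outputs
    _, states = jax.lax.scan(one_step, initial_state, keys)
    return states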
kernel = blackjax.nuts(logdensity_fn, **parameters,max_num_doublings = 10).step
%time states = inference_loop(sample_key, kernel, state, 500)
mcmc_samples = states.position
CPU times: user 5min 48s, sys: 32.2 s, total: 6min 20s
Wall time: 6min 14s
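inference_loop relies on jax.lax.scan, which calls one_step once per key, threads the state through, and stacks the per-step outputs. As a result, states.position contains all 500 states along a leading axis. The JAX documentation describes scan's semantics with a pure-Python reference loop. Below is a sketch of that loop in which a hypothetical toy_step stands in for the NUTS kernel.

```python
import numpy as np

def scan_reference(f, init, xs):
    # Pure-Python equivalent of jax.lax.scan(f, init, xs) for array inputs
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, np.stack(ys)

def toy_step(state, key):
    # Stand-in for one_step: the new state is both the carry and the output
    new_state = state + key
    return new_state, new_state

final, history = scan_reference(toy_step, 0.0, np.arange(1.0, 6.0))
print(final, history)
```

Here final is 15.0 and history is the running sum [1, 3, 6, 10, 15]. In the notebook, the discarded first return value of scan is the last NUTS state, and states plays the role of history.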
The trace plot of chi0 looks well mixed.
[25]:
plt.plot(mcmc_samples["chi0"])
Post analysis
First, let's plot the posterior. It is consistent with the true parameter values, which are shown as markers.
[21]:
from getdist import plots, MCSamples
names = ["b0","psi0","psi1","chi0"]
labels = [r"b_0",r"\psi_0",r"\psi_1",r"\chi_{0,B}"]
samps = np.array([mcmc_samples[names[i]] for i in range(4)]).T
samples = MCSamples(samples=samps,names = names, labels = labels)
Removed no burn in
[29]:
%matplotlib inline
true_params = {"b0":1.2,
"psi0":27.0*np.pi/180,
"psi1":0.9*np.pi/180,
"chi0":25.0*np.pi/180}
g = plots.get_subplot_plotter(rc_sizes=20)
g.settings.lab_fontsize = 16
g.settings.axes_fontsize = 16
g.triangle_plot([samples],names, filled=True,markers=true_params)
plt.suptitle('Noise RMS = 1mK',x=0.88,y=0.88,ha='right',va='top',fontsize = 24)
#plt.suptitle('Noise RMS = 1mK',va = 'bottom',fontsize = 20);
#plt.savefig('../figures/posterior_lsa_1mk.pdf',dpi=500,bbox_inches='tight')
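To check convergence, we compute the potential scale reduction factor \(\hat{R}\) for each parameter. The single chain of 500 draws is split into four consecutive segments. All values are within 0.003 of 1, well below the commonly used threshold of 1.01, so the chain appears converged.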
[22]:
[blackjax.diagnostics.potential_scale_reduction(samps.T[i].reshape([4,-1])) for i in range(4)]
[22]:
[Array(1.00035458, dtype=float64),
Array(1.00266913, dtype=float64),
Array(1.00216394, dtype=float64),
Array(1.00098933, dtype=float64)]
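\(\hat{R}\) compares the variance within segments to the variance between segment means. If the segments sample the same distribution, the ratio is close to 1. Below is a minimal numpy sketch of the classic Gelman-Rubin estimator applied to four segments of 125 draws, as in the call above. blackjax's potential_scale_reduction may differ in details, so treat this as illustrative. The "stuck" example artificially offsets two segments to mimic a chain that has not converged.

```python
import numpy as np

def gelman_rubin(chains):
    # chains: array of shape (num_chains, num_draws)
    m, n = chains.shape
    within = chains.var(axis=1, ddof=1).mean()       # W: mean within-chain variance
    between_over_n = chains.mean(axis=1).var(ddof=1)  # B / n: variance of chain means
    var_hat = (n - 1) / n * within + between_over_n   # pooled variance estimate
    return np.sqrt(var_hat / within)

rng = np.random.default_rng(0)
mixed = rng.standard_normal(500).reshape(4, -1)           # well-mixed draws
stuck = mixed + np.array([0.0, 0.0, 2.0, 2.0])[:, None]   # two segments offset in mean
r_mixed, r_stuck = gelman_rubin(mixed), gelman_rubin(stuck)
print(r_mixed, r_stuck)
```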
The effective sample sizes (ESS), between about 190 and 680, are of the same order of magnitude as the 500 sampling iterations.
[23]:
[blackjax.diagnostics.effective_sample_size(samps.T[i].reshape([4,-1])) for i in range(4)]
[23]:
[Array(683.00025544, dtype=float64),
Array(447.13332335, dtype=float64),
Array(276.4008167, dtype=float64),
Array(193.83849754, dtype=float64)]
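The ESS estimates how many independent draws a correlated chain is worth: \(N_{\rm eff} = N / (1 + 2\sum_{t \ge 1}\rho_t)\), where \(\rho_t\) is the autocorrelation at lag \(t\). The sketch below truncates the sum at the first negative autocorrelation. This rule is simpler than the one used by production estimators such as blackjax's effective_sample_size, which also works across multiple chains. An ESS larger than \(N\), as for b0 above, can occur when successive draws are anticorrelated, which NUTS often produces. The truncation rule in this sketch cannot capture that case.

```python
import numpy as np

def autocorr(x, lag):
    # Sample autocorrelation at a given lag (biased estimator)
    x = x - x.mean()
    return np.dot(x[:-lag], x[lag:]) / np.dot(x, x)

def effective_sample_size(x):
    n = len(x)
    rho_sum = 0.0
    for lag in range(1, n):
        rho = autocorr(x, lag)
        if rho < 0:  # truncate at the first negative autocorrelation
            break
        rho_sum += rho
    return n / (1 + 2 * rho_sum)

rng = np.random.default_rng(0)
n = 20000
iid = rng.standard_normal(n)
ar1 = np.empty(n)  # AR(1) chain with coefficient 0.9: theoretical ESS about n * 0.1 / 1.9
ar1[0] = rng.standard_normal()
for t in range(1, n):
    ar1[t] = 0.9 * ar1[t - 1] + rng.standard_normal()
ess_iid, ess_ar1 = effective_sample_size(iid), effective_sample_size(ar1)
print(ess_iid, ess_ar1)
```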
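Finally, we summarize each parameter. acc is the relative deviation of the posterior mean from the true value, \(1 - \bar{\theta}/\theta_{\rm true}\). mean is the posterior mean. std gives the 16th and 84th percentiles relative to the mean, i.e. an approximate 1σ interval.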
[25]:
for name in names:
    mean = np.mean(mcmc_samples[name])
    # Relative deviation of the posterior mean from the true value
    rel_dev = 1 - mean/true_params[name]
    print(name+" acc:" + str(rel_dev))
    print(name+" mean:" + str(mean))
    # 16th and 84th percentiles relative to the mean (approximate 1-sigma interval)
    interval = np.percentile(mcmc_samples[name],[16,84]) - mean
    print(name+" std:" + str(interval[0])+", "+str(interval[1]))
b0 acc:-0.0001738643826538766
b0 mean:1.2002086372591845
b0 std:-0.0003372701586705773, 0.0003392296976092446
psi0 acc:0.0011506644537284672
psi0 mean:0.4706966601892819
psi0 std:-0.0004436898594020744, 0.0003889098435341798
psi1 acc:0.06484060577279827
psi1 mean:0.014689449414198292
psi1 std:-0.0011748403015462167, 0.001154116877975028
chi0 acc:-0.0008735126447161345
chi0 mean:0.43671345479128487
chi0 std:-0.001444809572399064, 0.0014276393331947301
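To put these accuracies in context, we can divide each posterior mean's offset from the true value by the half-width of its 16-84 percentile interval. The sketch below does this using the numbers printed above, copied by hand.

```python
import numpy as np

true_params = {"b0": 1.2, "psi0": 27.0 * np.pi / 180,
               "psi1": 0.9 * np.pi / 180, "chi0": 25.0 * np.pi / 180}
# (posterior mean, lower offset, upper offset) from the printed output above
summary = {
    "b0":   (1.2002086372591845, -0.0003372701586705773, 0.0003392296976092446),
    "psi0": (0.4706966601892819, -0.0004436898594020744, 0.0003889098435341798),
    "psi1": (0.014689449414198292, -0.0011748403015462167, 0.001154116877975028),
    "chi0": (0.43671345479128487, -0.001444809572399064, 0.0014276393331947301),
}
zscores = {}
for name, (mean, lo, hi) in summary.items():
    half_width = 0.5 * (hi - lo)                           # approximate 1-sigma width
    zscores[name] = (mean - true_params[name]) / half_width  # offset in units of sigma
    print(f"{name}: {zscores[name]:+.2f} sigma")
```

The offsets are about +0.62, -1.30, -0.87 and +0.27 sigma for b0, psi0, psi1 and chi0. All four true values therefore lie within roughly 1.3 sigma of the posterior means.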