Optimizing to reproduce the Haslam map

Here we briefly show how to optimize a gridded \(\mathbf{B}\)-field model so that it reproduces the Haslam 408 MHz map.

[1]:
# IPython magics that set JAX/XLA memory-related environment variables.
# The preallocation switch is left disabled (commented out); only the
# allocator is changed to "platform".
#%env XLA_PYTHON_CLIENT_PREALLOCATE=false
%env XLA_PYTHON_CLIENT_ALLOCATOR=platform

env: XLA_PYTHON_CLIENT_ALLOCATOR=platform
[2]:
import os
import sys
import jax,blackjax
jax.config.update("jax_enable_x64", True)
#sys.path.append('../synax/')

import synax,importlib
import jax.numpy as jnp
import interpax
import healpy as hp
import numpy as np
import matplotlib.pyplot as plt
from functools import partial
import scipy.constants as const

Read observations

Remember to replace the file name with the actual path on your machine!

We then spatially average the map to obtain an NSIDE = 64 map.

[15]:
from astropy.io import fits

# Location of the Haslam 408 MHz FITS file; adjust to your local copy.
fits_image_filename = '../../SyncEmiss/obs/haslam408_dsds_Remazeilles2014.fits'

# Load the full-resolution map as a 1-D pixel array.
data = hp.read_map(fits_image_filename)

# Convert RING -> NESTED ordering before block-averaging.
# presumably so that each run of 64 consecutive pixels forms one coarser pixel — TODO confirm
data = hp.reorder(data,r2n=True)

# Average every 64 consecutive pixels into one; the printed shape below,
# (49152,) = 12 * 64**2, corresponds to an NSIDE = 64 map.
data = data.reshape((-1,64)).mean(axis=-1)

# Convert the averaged map back NESTED -> RING.
data = hp.reorder(data,n2r=True)

# Quick look at the downgraded map and its pixel count.
hp.mollview(data,norm='hist',cmap='coolwarm')
print(data.shape)
(49152,)
../_images/nb_Optimising_haslam_4_1.png
[3]:
#for debug
def reload_package(package):
    """Reload ``package`` and then each of its own submodules (debug helper).

    The package itself is reloaded first. Afterwards every attribute listed
    by ``dir(package)`` that is a module *belonging to the package* (its
    ``__name__`` starts with ``package.__name__ + '.'``) is reloaded as well.

    Modules that are merely imported into the package namespace (for example
    third-party dependencies) are deliberately left alone: reloading them
    re-executes foreign code and can reset their global state.

    Args:
        package: An already-imported package/module object.
    """
    import types  # local import keeps this notebook helper self-contained

    importlib.reload(package)

    own_prefix = package.__name__ + '.'
    for attribute_name in dir(package):
        attribute = getattr(package, attribute_name)
        if not isinstance(attribute, types.ModuleType):
            continue
        # Only touch submodules of this package, never its dependencies.
        if attribute.__name__.startswith(own_prefix):
            importlib.reload(attribute)

# Pick up local edits to synax without restarting the kernel (debug only).
reload_package(synax)

Set up coordinates and fields

During inference, some quantities stay constant, such as the coordinates, the \(n_e\) field and the \(C\) field. We can precompute them once to save time during inference.

First let’s generate the coordinates.

[4]:
# HEALPix resolution of the simulated map and the number of integration
# points requested from synax.
nside = 64
num_int_points = 512

# poss has shape (3, 49152, 512) (see the output below): 49152 = 12 * nside**2
# pixels and 512 = num_int_points.
# presumably axis 0 holds the x, y, z coordinates — TODO confirm
# dls and nhats are passed unchanged to simer.sim() later on.
poss,dls,nhats = synax.coords.get_healpix_positions(nside=nside,num_int_points=num_int_points)
# Scatter the first two coordinate components of every 10th pixel at
# integration-point index 500 as a visual sanity check.
plt.scatter(poss[0,::10,500],poss[1,::10,500],alpha=0.05)
poss.shape
[4]:
(3, 49152, 512)
../_images/nb_Optimising_haslam_7_1.png

Then generate the \(C\) field.

[5]:
# Build the C-field generator on the precomputed positions.
C_generator = synax.cfield.C_WMAP(poss)

# Evaluate it once; the result is reused unchanged in every simulation.
C_field = C_generator.C_field()

Next comes the gridded \(n_e\) field: read the grid in, construct the generator, and then interpolate.

[6]:
# Number of cells of the n_e grid along x, y and z.
nx,ny,nz = 256,256,64

# linspace(..., endpoint=False) yields the left cell edges; adding half a
# step shifts them to the cell centres. x and y span [-20, 20], z spans [-5, 5].
xs,step = jnp.linspace(-20,20,nx,endpoint=False,retstep=True)
xs = xs + step*0.5

ys,step = jnp.linspace(-20,20,ny,endpoint=False,retstep=True)
ys = ys + step*0.5

zs,step = jnp.linspace(-5,5,nz,endpoint=False,retstep=True)
zs = zs + step*0.5

# Cell-centre coordinates in 'ij' indexing, so each array is (nx, ny, nz).
coords = jnp.meshgrid(xs,ys,zs,indexing='ij')
coords[0].shape
[6]:
(256, 256, 64)
[7]:

# Load the pre-computed n_e grid from disk.
# presumably shaped (256, 256, 64) to match the grid axes defined above — TODO confirm
tereg = np.load('te.npy')# read it in.
[8]:
# The generator is tied to the integration positions and the grid axes.
TE_generator = synax.tefield.TE_grid(poss,(xs,ys,zs))# set up the generator
# Evaluate the gridded values once; TE_field stays fixed during optimization.
TE_field = TE_generator.TE_field(tereg)# do the interpolation

Carry out optimization

Now let’s optimize. The \(\mathbf{B}\)-field generator also stays fixed during inference, so we can construct it in advance.

First we need to set up the coordinate system for this gridded \(\mathbf{B}\) field.

Then we call synax to get the `B_generator`.

[9]:
# The B grid is coarser than the n_e grid. Note that this overwrites
# nx, ny, nz, xs, ys, zs and coords from the n_e cell above.
nx,ny,nz = 128,128,32

# Cell-centre coordinates over the same extent as before:
# x and y in [-20, 20], z in [-5, 5].
xs,step = jnp.linspace(-20,20,nx,endpoint=False,retstep=True)
xs = xs + step*0.5

ys,step = jnp.linspace(-20,20,ny,endpoint=False,retstep=True)
ys = ys + step*0.5

zs,step = jnp.linspace(-5,5,nz,endpoint=False,retstep=True)
zs = zs + step*0.5

# These coords (shape (nx, ny, nz)) are reused later to build the radial mask.
coords = jnp.meshgrid(xs,ys,zs,indexing='ij')
coords[0].shape

# Generator that maps a (nx, ny, nz, 3) grid to values at the integration
# positions; the uniform test field below yields shape (49152, 512, 3).
B_generator = synax.bfield.B_grid(poss,(xs,ys,zs))
B_field = jnp.ones((nx,ny,nz,3))*1e-6
B_generator.B_field(B_field).shape
[9]:
(49152, 512, 3)

Set up some constants, including the frequency, the simulator (`simer`) and the spectral index.

[10]:
# NOTE(review): `date` and `rng_key` are not used in any cell visible here.
from datetime import date
rng_key = jax.random.key(42)
# Simulator configured with sim_I enabled and sim_P disabled.
simer = synax.synax.Synax(sim_I = True,sim_P=False)
# presumably in GHz, i.e. the 408 MHz of the Haslam map — TODO confirm
freq = 0.408
spectral_index = 3.

Here our loss function is simply the sum of squared residuals between the simulated and observed maps (the MSE up to a constant factor).

[11]:
@jax.jit
def grid_model(B_field_grid,freq):
    """Simulate the synchrotron I map produced by a gridded B field.

    The grid is evaluated with the pre-built ``B_generator`` and handed to
    ``simer.sim`` together with the fixed C field, TE field, ``nhats``,
    ``dls`` and ``spectral_index``. Only the ``'I'`` entry of the simulator
    output is returned.
    """
    sim_output = simer.sim(
        freq,
        B_generator.B_field(B_field_grid),
        C_field,
        TE_field,
        nhats,
        dls,
        spectral_index,
    )
    return sim_output['I']


def logdensity_fn(B_field):
    """Return minus the sum of squared residuals for a gridded B field.

    The residual is the simulated I map ``grid_model(B_field, freq)`` minus
    the observed map ``data``. Larger (less negative) values mean a closer
    match.
    """
    residual = grid_model(B_field,freq) - data
    return -1*jnp.sum(residual**2)


# Random starting field: standard-normal draws scaled by 1e-6, with the
# third (z) component reduced by a further factor of 10.
B_field = (np.random.randn(nx,ny,nz,3)*1e-6)*jnp.array([1.,1.,0.1])
# Simulate once with the random field (this first call also triggers jit compilation).
Sync_I = grid_model(B_field,freq)
[12]:
# Map simulated from the random, not yet optimized field.
hp.mollview(Sync_I,norm='hist',cmap='coolwarm')
../_images/nb_Optimising_haslam_20_0.png
[13]:
# Negating logdensity_fn gives the positive sum of squared residuals;
# logp_grad(x) returns (loss, d loss / d x) for minimization.
logp_grad = jax.value_and_grad(lambda x:-1*logdensity_fn(x))

Now let’s optimize! Here we use the Yogi optimizer from optax.

[20]:
import optax
from tqdm import tqdm

# Fresh random starting point, with the z component scaled down by 10.
B_field = (np.random.randn(*(nx,ny,nz,3))*1e-6)*jnp.array([1.,1.,0.1])

B_opt = B_field

solver = optax.yogi(learning_rate=1e-6)

opt_state = solver.init(B_opt)# initialize the optimizer state

# Loss value recorded at every iteration.
loss = []

# Build a multiplicative mask on the B grid.
# `coords` must be the meshgrid of the B grid so its shape matches ones_field.
ones_field = np.ones((nx,ny,nz))

# Select cells whose cylindrical radius sqrt(x^2 + y^2) is above 20 or below 3.
mask = ((coords[0]**2+coords[1]**2)>400)|((coords[0]**2+coords[1]**2)<9)# mask the inner and the outer region

ones_field[mask] = 1e-10 # must not be exactly 0: a zero field would cause NaN polarization angles.

# Add a trailing axis so the mask broadcasts over the 3 field components.
mask = jnp.array(ones_field)[:,:,:,jnp.newaxis]

# Suppress the field in the masked regions before the first step.
B_opt = B_opt*mask

progress_bar = tqdm(range(200))
for i in progress_bar:
    # value is the positive sum of squared residuals, grad its gradient w.r.t. B_opt.
    value,grad = logp_grad(B_opt)
    # Stop early (keeping the current B_opt) if the loss became NaN.
    if jnp.isnan(value):
        break
    loss.append(value)
    updates, opt_state = solver.update(grad, opt_state, B_opt)
    B_opt = optax.apply_updates(B_opt, updates)
    # Re-apply the mask after every update so masked cells stay suppressed.
    B_opt = B_opt*mask

    info = { 'loss': loss[-1]}

    # Update the postfix with the current info
    progress_bar.set_postfix(info)
100%|██████████| 200/200 [01:11<00:00,  2.79it/s, loss=4319.830813641992]

visualize the results

First let’s take a look at the three components of the \(\mathbf{B}\) field in the central \(z\) slice (\(z \approx 0\) kpc).

[21]:
# One panel per field component, all taken at index 16 along the third grid
# axis (the middle of the nz = 32 cells of the B grid).
plt.figure(dpi=200,figsize=(12,3))

# Component 0 (x).
plt.subplot(131)
plt.imshow(B_opt[:,:,16,0])
plt.title(r'$B_x$')
plt.colorbar()

# Component 1 (y).
plt.subplot(132)
plt.imshow(B_opt[:,:,16,1])
plt.title(r'$B_y$')
plt.colorbar()

# Component 2 (z).
plt.subplot(133)
plt.imshow(B_opt[:,:,16,2])
plt.title(r'$B_z$')
plt.colorbar()
[21]:
<matplotlib.colorbar.Colorbar at 0x7f5768696d50>
../_images/nb_Optimising_haslam_25_1.png

Re-simulate with this optimized \(\mathbf{B}\) field.

[22]:
# Simulated I map for the optimized field; overwrites the earlier Sync_I.
Sync_I = grid_model(B_opt,freq)

Looks pretty! We did reproduce the Haslam map. However, the number of degrees of freedom is far too large to be constrained by a single map, so this optimized \(\mathbf{B}\) field is not necessarily close to the truth.

With more constraints, we can hope to obtain a better \(\mathbf{B}\) field in the future.

[23]:
# Three panels side by side: observed map, optimized simulation, residuals.
plt.figure(dpi=200,figsize=(10,3),)
# NOTE(review): nothing in this cell visibly draws random numbers — confirm whether this seed is needed.
np.random.seed(42)
# hold=True makes mollview draw into the subplot selected just before it.
plt.subplot(131)
hp.mollview(data,format='%.2g',norm='hist',cmap='coolwarm',hold=True,title='Haslam Map')

plt.subplot(132)
hp.mollview(Sync_I,format='%.2g',norm='hist',cmap='coolwarm',hold=True,title='Optimized Synax I')

# Residuals use a linear colour scale limited to [-1, 1].
plt.subplot(133)
hp.mollview(Sync_I-data,format='%.2g',cmap='coolwarm',hold=True,title='Residuals',max=1,min=-1)


#plt.savefig("../figures/haslam_opt.pdf",bbox_inches='tight',dpi=200)
../_images/nb_Optimising_haslam_29_0.png

Let’s look at the magnitude of the optimized \(\mathbf{B}\) field. It seems too high, roughly 10 times current estimates.

[26]:
plt.figure(dpi=200,figsize=(7,3),)

# Right panel: Euclidean norm over the component axis, shown at index 16
# along the third grid axis, with the colour scale limited to [0, 3e-5].
plt.subplot(122)
plt.imshow(((B_opt**2).sum(axis=-1)**(1/2))[:,:,16],vmax=3e-5, vmin=0)
plt.colorbar(label='gauss')
plt.title(r'$\Vert B_{opt} \Vert$')

# Left panel: histogram of the per-pixel residuals; values outside [-1, 1]
# fall outside the bins and are not shown.
plt.subplot(121)
plt.hist(Sync_I-data,bins=np.linspace(-1,1,100))
plt.title('Residuals')
plt.tight_layout()
#plt.savefig("../figures/haslam_B.pdf",bbox_inches='tight',dpi=200)
../_images/nb_Optimising_haslam_31_0.png
[25]:
# Standard deviation of the per-pixel residuals (simulation minus observation).
np.std(Sync_I-data)
[25]:
Array(0.29502564, dtype=float64)
[ ]: