Source code for synax

# synax/synax.py

import jax.numpy as jnp
import jax
import numpy as np

from functools import partial
import scipy.constants as const
from typing import List, Tuple, Union,Dict
jax.config.update("jax_enable_x64", True)

q_converter = 1/(4*np.pi*const.epsilon_0)**0.5

B_converter = (4*np.pi/const.mu_0)**0.5

freq_irrelavent_const = (const.e*q_converter)**3/(const.electron_mass*const.speed_of_light**2)*(np.sqrt(3)/(8*np.pi))*1e19 # moves kpc = 1e16 km here.

elect_combi = 2/3*const.electron_mass*const.speed_of_light/(const.e*q_converter)

kpc = 3.08567758

temp_covert = (const.hbar*1e9)/(const.Boltzmann*2.725)

rm_freq_irrelavent_const = (const.e*q_converter)**3/(const.electron_mass**2*const.speed_of_light**4)/(2*np.pi)*1e6*1e-4*B_converter*1e19*3.08 # moves 1/cm^3 = 1e6 1/m^3 1 gauss = 1e-4 tesla 1 kpc = 3.08e19 m here.

#return _t*(np.exp(p)-1.)**2/(p**2*np.exp(p))

@jax.jit
def sync_I_const(freq,spectral_index: float=3.):
    """
    calculating the constant irrelavent to b_perp and C in the synchrotron emissivity.

    Args:
       freq (float): frequency to be computed. In GHz.
       spectral_index (float or jax.Array): spectrum of cosmic ray electron spectrum.

    Returns:
       (jax.Array): parallel emissivity constant for the synchrotron emission.
    """
    
    gamma_func_1 = jax.scipy.special.gamma(spectral_index/4.-1/12.)
    
    gamma_func_2_process = (2e-4*B_converter)**(spectral_index/2.+0.5)/(spectral_index+1)*jax.scipy.special.gamma(spectral_index/4+19/12.)# the transition from micro-Gauss to tesla is here.
    
    omega = 2*jnp.pi*freq*1e9
    
    freq_irrelavent = freq_irrelavent_const/(2*const.Boltzmann*freq**2*1e18/(const.speed_of_light**2))
    
    consts = freq_irrelavent*(omega*elect_combi)**(0.5-spectral_index/2)*gamma_func_1*gamma_func_2_process
    
    #p = freq*temp_covert
    
    return consts*kpc#*(jnp.exp(p)-1.)**2/(p**2*jnp.exp(p))

@jax.jit
def sync_P_const(freq,spectral_index: float=3.):
    """
    calculating the constant irrelavent to b_perp and C in the polarized synchrotron emissivity.

    Args:
       freq (float): frequency to be computed. In GHz.
       spectral_index (float or jax.Array): spectrum of cosmic ray electron spectrum.

    Returns:
       (jax.Array): perpenndicular emissivity constant for the polarized synchrotron emission.
    """
    
    gamma_func_1 = jax.scipy.special.gamma(spectral_index/4.-1/12.)
    
    gamma_func_2_process = (2e-4*B_converter)**(spectral_index/2.+0.5)/(4.)*jax.scipy.special.gamma(spectral_index/4+7/12.)# the transition from micro-Gauss to tesla is here.
    
    omega = 2*jnp.pi*freq*1e9
    
    freq_irrelavent = freq_irrelavent_const/(2*const.Boltzmann*freq**2*1e18/(const.speed_of_light**2))
    
    consts = freq_irrelavent*(omega*elect_combi)**(0.5-spectral_index/2)*gamma_func_1*gamma_func_2_process
    
    #p = freq*temp_covert
    
    return consts*kpc#*(jnp.exp(p)-1.)**2/(p**2*jnp.exp(p))

@jax.jit
def sync_emiss_I(freq:float, b_perp: jax.Array,C:jax.Array,spectral_index: float=3.):
    """
    Calculating the  synchrotron emissivity.

    Args:
       freq (float): frequency to be computed. In GHz.
       b_perp (jax.Array): 3D magnetic field ($B_t$) perpendicular to the LOS.
       C (jax.Array): 3D field, defined by $N(\gamma)d\gamma = C\gamma^{-p}d\gamma$. Varied at different locations.
       spectral_index (float): spectrum of cosmic ray electron spectrum.

    Returns:
       (jax.Array): parallel emissivity for the synchrotron emission.
    """
    

    return b_perp**(0.5+spectral_index*0.5)*C*sync_I_const(freq,spectral_index=spectral_index)

@jax.jit
def sync_emiss_P(freq:float, b_perp: jax.Array,C:jax.Array,spectral_index: float=3.):
    """
    Calculating the polarized synchrotron emissivity.

    Args:
       freq (float): frequency to be computed. In GHz.
       b_perp (jax.Array): 3D magnetic field ($B_t$) perpendicular to the LOS.
       C (jax.Array): 3D field, defined by $N(\gamma)d\gamma = C\gamma^{-p}d\gamma$. Varied at different locations.
       spectral_index (float): spectrum of cosmic ray electron spectrum.
    Returns:
       (jax.Array): perpendicular emissivity for the polarized synchrotron emission.
    """
    

    return b_perp**(0.5+spectral_index*0.5)*C*sync_P_const(freq,spectral_index=spectral_index)



[docs] class Synax(): """ Synax simulator Args: sim_I (bool): whether sim synchrotron intensity. sim_P (bool): whether sim polarized synchrotron intensity. Returns: A instance of Synax simulator. """ def __init__(self, sim_I = True,sim_P = True): self.sim_I = sim_I self.sim_P = sim_P @staticmethod @jax.jit def RM(freq,B_field,TE_field,nhats,dls,B_los): """ Calculate rotation measure. Args: freq (float): frequency to be computed. In GHz. B_field (jax.Array): 3D magnetic field ($B_t$) perpendicular to the LOS at different places. TE_field (jax.Array): 3D electron density field, defined by $N(\gamma)d\gamma = C\gamma^{-p}d\gamma$. Varied at different locations. nhats (jnp.Array): In unit of rad. unit vector of different LoS. dls (jnp.float): In unit of kpc. length of each integration segment for every LoS. B_los (jax.Array): LoS B-field magnitude. Returns: tuple: - fd (jnp.Array): rotation measure for each LoS integration point. - fd_q (jnp.Array): cos(2*polarized angle), for Q map calculation. - fd_u (jnp.Array): sin(2*polarized angle), for U map calculation. """ phis = rm_freq_irrelavent_const*TE_field*B_los sinb = nhats[...,2] cosb = jnp.sqrt(1-sinb**2) cosl = nhats[...,0]/cosb sinl = nhats[...,1]/cosb Bz = B_field[...,2] By = B_field[...,1] Bx = B_field[...,0] tanchi0 = (Bz*cosb[:,jnp.newaxis]-sinb[:,jnp.newaxis]*(cosl[:,jnp.newaxis]*Bx+By*sinl[:,jnp.newaxis]))/(Bx*sinl[:,jnp.newaxis]-By*cosl[:,jnp.newaxis]+1e-16) chi0 = jnp.arctan(tanchi0) phi_int = jnp.cumsum(phis,axis=1)*dls[:,jnp.newaxis] fd = phi_int*const.c**2/(freq**2*1e18) fd_q = jnp.cos(2*fd+2*chi0) fd_u = jnp.sin(2*fd+2*chi0) return fd,fd_q,fd_u @staticmethod @jax.jit def B_los(B_field,nhats): """ Calculate LoS B-field. Args: B_field (jax.Array): 3D magnetic field ($B_t$) perpendicular to the LOS at different places. nhats (jnp.Array): In unit of rad. unit vector of different LoS. Returns: (jnp.Array): LoS B-field magnitude. """ return -1*((nhats[:,jnp.newaxis,:]*B_field)).sum(axis=-1) @partial(jax.jit, static_argnums=(0,)) def sim(self,freq,B_field,C_field,TE_field,nhats,dls,spectral_index): """ Calculate sychrotron map. Args: freq (float): frequency to be computed. In GHz. B_field (jax.Array): 3D magnetic field ($B_t$) perpendicular to the LOS at different places. C_field (jax.Array): 3D field, defined by $N(\gamma)d\gamma = C\gamma^{-p}d\gamma$. Varied at different locations. TE_field (jax.Array): 3D electron density field, defined by $N(\gamma)d\gamma = C\gamma^{-p}d\gamma$. Varied at different locations. nhats (jnp.Array): In unit of rad. unit vector of different LoS. dls (jnp.float): In unit of kpc. length of each integration segment for every LoS. spectral_index (float): spectrum of cosmic ray electron spectrum. Returns: dict: - dict['I'](jnp.Array): Sychrotron I map. return 0 if ``sim_I=False`` - dict['Q'](jnp.Array): Sychrotron Q map. return 0 if ``sim_P=False`` - dict['U'](jnp.Array): Sychrotron U map. return 0 if ``sim_P=False`` """ B_los = self.B_los(B_field,nhats) B_trans = ((B_field**2).sum(axis=-1)-B_los**2)**0.5 Sync_I = 0. Sync_Q = 0. Sync_U = 0. if self.sim_I: emiss = sync_emiss_I(freq,B_trans,C_field,spectral_index=spectral_index) Sync_I = emiss.sum(axis=-1)*dls if self.sim_P: fd,fd_q,fd_u = self.RM(freq,B_field,TE_field,nhats,dls,B_los) emiss = sync_emiss_P(freq,B_trans,C_field,spectral_index=spectral_index) Sync_Q = (emiss*fd_q).sum(axis=-1)*dls Sync_U = (emiss*fd_u).sum(axis=-1)*dls return {'I':Sync_I,'Q':Sync_Q,'U':Sync_U} def __str__(self): """ String representation of the instance """ return f'Synax' def __repr__(self): """ Official string representation of the instance """ return f'Synax'