import os
import sys
import jax,interpax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
from functools import partial
from typing import List, Tuple, Union,Dict
[docs]
class C_WMAP():
"""
WMAP C field model. (https://iopscience.iop.org/article/10.1086/513699)
Args:
coords (Union[jax.Array,List[jax.Array],Tuple[jax.Array]]): coordinates of all integration points. Should be of size (3,...), for example ``coords[0]`` is the x-coordinates.
Returns:
A instance of WMAP C field generator.
"""
def __init__(self, coords:Union[jax.Array,List[jax.Array],Tuple[jax.Array]]):
self.r = (coords[0]**2+coords[1]**2)**(1/2)
self.z = coords[2]
self.shape = coords[2].shape
C_calc_vmap = jax.vmap(self.C_calc,in_axes=(None, 0,0))
setattr(self, 'C_calc_vmap', C_calc_vmap)
[docs]
@staticmethod
def C_calc(WMAP_params, r:float,z:float):
"""
Calculate wmap C-field at a given position ``(r,phi,z)``.
Args:
r (float): r in cylindrical coordinates.
z (float): z in cylindrical coordinates.
Returns:
cosmic ray electron spectrum constant C in this position.
"""
return WMAP_params['C0']*jnp.exp(-r/WMAP_params['hr'])/jnp.cosh(z/WMAP_params['hd'])**2#*(1-jnp.floor(c))
@partial(jax.jit, static_argnums=(0,))
def C_field(self,WMAP_params = {'C0':211.13068378473076,'hr':5.,'hd':1.}):
"""
Calculate WMAP C-field at all positions specified by ``coords``.
Args:
WMAP_params (Dict[str,float]): A dict contains all parameters of the WMAP model.
Returns:
jnp.Array of shape (``coords[0].shape``). ``coords`` is the parameter of your C_WMAP instance.
"""
return self.C_calc_vmap(WMAP_params,self.r.reshape(-1),self.z.reshape(-1)).reshape(self.shape)
def __str__(self):
"""
String representation of the instance
"""
return f'C_WMAP'
def __repr__(self):
"""
Official string representation of the instance
"""
return f'C_WMAP'
[docs]
class C_uni():
"""
Uniform C field model cenntered at ``center``
Args:
coords (Union[jax.Array,List[jax.Array],Tuple[jax.Array]]): coordinates of all integration points. Should be of size (3,...), for example ``coords[0]`` is the x-coordinates.
center: center of uniform C field. default to be the earth.
Returns:
A instance of uniform C field generator.
"""
def __init__(self, coords:Union[jax.Array,List[jax.Array],Tuple[jax.Array]],center = (-8.3,0.0,0.006)):
self.x = coords[0] - center[0]
self.y = coords[1] - center[1]
self.z = coords[2] - center[2]
self.shape = coords[2].shape
C_calc_vmap = jax.vmap(self.C_calc,in_axes=(None, 0,0,0))
setattr(self, 'C_calc_vmap', C_calc_vmap)
[docs]
@staticmethod
def C_calc(Uni_params,x:float,y:float,z:float):
"""
Calculate uniform C-field at a given position ``(x,y,z)``.
Args:
x (float): x in cartisian coordinates.
y (float): y in cartisian coordinates.
z (float): z in cartisian coordinates.
Returns:
cosmic ray electron spectrum constant C in this position.
"""
c = (x**2+y**2+z**2)/jnp.max(jnp.array([x**2+y**2+z**2,Uni_params['rho0']**2]))#+1e-7
return (1-jnp.floor(c))*Uni_params['C0']
@partial(jax.jit, static_argnums=(0,))
def C_field(self,Uni_params = {'C0':1.0,'rho0':4.,}):
"""
Calculate uniform C-field at all positions specified by ``coords``.
Args:
Uni_params (Dict[str,float]): A dict contains all parameters of the uniform model.
Returns:
jnp.Array of shape (``coords[0].shape``). ``coords`` is the parameter of your C_uni instance.
"""
return self.C_calc_vmap(Uni_params,self.x.reshape(-1),self.y.reshape(-1),self.z.reshape(-1)).reshape(self.shape)
def __str__(self):
"""
String representation of the instance
"""
return f'C_uni'
def __repr__(self):
"""
Official string representation of the instance
"""
return f'C_uni'
[docs]
class C_grid():
"""
grid C field model. See Synax paper for more details.
Args:
coords (Union[jax.Array,List[jax.Array],Tuple[jax.Array]]): coordinates of all integration points. Should be of size (3,...), for example ``coords[0]`` is the x-coordinates.
coords_field (Union[jax.Array,List[jax.Array],Tuple[jax.Array]]): coords[i] is the 1D vector of coordinates along i-th axis. Since the grid is a regular 3D grid, 1D vectors are sufficient to represents the coordinates.
Returns:
A instance of grid C field generator.
"""
def __init__(self, coords:Union[jax.Array,List[jax.Array],Tuple[jax.Array]],coords_field:Union[jax.Array,List[jax.Array],Tuple[jax.Array]]):
self.x = coords[0]
self.y = coords[1]
self.z = coords[2]
self.pos = coords
self.xf = coords_field[0]
self.yf = coords_field[1]
self.zf = coords_field[2]
self.shape = coords[2].shape
field_calc = lambda pos,field: interpax.interp3d(pos[0].reshape(-1),pos[1].reshape(-1),pos[2].reshape(-1),self.xf,self.yf,self.zf,field,method='linear',extrap=True)
setattr(self, 'field_calc', field_calc)
@partial(jax.jit, static_argnums=(0,))
def C_field(self,C_field_grid):
"""
Calculate grid C-field at all positions specified by ``coords``.
Args:
C_field_grid (Dict[str,float]): your field in a regular 3D grid.
Returns:
jnp.Array of shape (``coords[0].shape``). ``coords`` is the parameter of your C_grid instance.
"""
return self.field_calc(self.pos,C_field_grid).reshape(self.shape+(3,))
def __str__(self):
"""
String representation of the instance
"""
return f'C_grid'
def __repr__(self):
"""
Official string representation of the instance
"""
return f'C_grid'