# synax/bfield.py
import jax,interpax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
from functools import partial
import scipy.constants as const
from typing import List, Tuple, Union,Dict
# [docs]  (Sphinx viewcode anchor left over from page extraction; not part of the module)
class B_jf12:
    """
    jf12 B field model (https://ui.adsabs.harvard.edu/abs/2012ApJ...757...14J/abstract).

    Args:
        coords (Union[jax.Array,List[jax.Array],Tuple[jax.Array]]): coordinates of all integration points. Should be of size (3,...), for example ``coords[0]`` is the x-coordinates.

    Returns:
        An instance of jf12 B field generator.
    """

    def __init__(self, coords: Union[jax.Array, List[jax.Array], Tuple[jax.Array]]):
        # Vectorised (per-point) versions of the static helpers, stored on the instance.
        self.get_index_vmap = jax.vmap(self.get_index)
        # jf12_params is shared by every point (None); all other arguments are mapped.
        self.B_calc_vmap = jax.vmap(self.B_calc, in_axes=(None, 0, 0, 0, 0, 0, 0, 0))

        x, y, z = coords[0], coords[1], coords[2]
        self.rho = (x**2 + y**2 + z**2)**(1/2)   # spherical radius
        self.r = (x**2 + y**2)**(1/2)            # cylindrical radius
        self.z = z
        self.shape = z.shape
        self.phi = jnp.arctan2(y, x)
        # Spiral-arm index of every (flattened) point; fixed by geometry only.
        self.indexs = self.get_index_vmap(self.r.reshape(-1), self.phi.reshape(-1), self.z.reshape(-1))

        self.Rmax = 20      # outer boundary of GMF
        self.rho_gc = 1.    # interior boundary of GMF
        self.rmin = 5.      # outer boundary of the molecular ring region
        self.rcent = 3.     # inner boundary of the molecular ring region (field is zero within this region)
        # fractions of circumference spanned by each spiral arm
        self.f = jnp.array([0.130, 0.165, 0.094, 0.122,
                            0.13, 0.118, 0.084, 0.156])
        # the radii where the spiral arm boundaries cross the negative x-axis
        self.rc_b = jnp.array([0, 5.1, 6.3, 7.1, 8.3, 9.8,
                               11.4, 12.7, 15.5])

        # Multiplicative masks: 1 where the component is active, 1e-16 (not exactly 0)
        # where it is switched off.
        unit = jnp.ones_like(x)
        self.total_mask = jnp.where((self.rho < self.rho_gc) | (self.r > self.Rmax), 1e-16, unit)
        self.rcent_mask = jnp.where(self.r < self.rcent, 1e-16, unit)
        self.rmin_mask = jnp.where(self.r < self.rmin, 1e-16, unit)

    @staticmethod
    def B_calc(jf12_params: Dict[str, float], r: float, phi: float, z: float,
               is_r_less_rmin=1., is_r_less_rcent=1., b_disk=1., mask=1.):
        """
        Calculate jf12 B-field at a given position ``(r,phi,z)``.

        The result is finite at ``z == 0`` and at ``r == 0``; for all other
        positions it matches the usual jf12 expressions.

        Args:
            jf12_params (Dict[str,float]): A dict contains all parameters of the jf12 model.
            r (float): r in cylindrical coordinates.
            phi (float): phi in cylindrical coordinates.
            z (float): z in cylindrical coordinates.
            is_r_less_rmin: multiplicative flag for the molecular ring's outer boundary. 1 when ``r`` is outside it (spiral-arm region), ~0 (``B_field`` passes 1e-16) when inside.
            is_r_less_rcent: multiplicative flag for the molecular ring's inner boundary. 1 when ``r`` is outside it, ~0 (``B_field`` passes 1e-16) when inside, where the disk field is switched off.
            b_disk: B field in this spiral arm. 0. for all non-spiral-arm positions.
            mask: multiplicative flag for the total field. 1 inside the modelled volume, ~0 (``B_field`` passes 1e-16) outside.

        Returns:
            (Bx,By,Bz) in this position.
        """
        inc = 11.5*jnp.pi/180.  # spiral-arm inclination of 11.5 degrees, in radians
        abs_z = jnp.abs(z)
        # On the z-axis 1/r is inf and 0/r is NaN, which the ~0 masks cannot
        # cancel (inf*1e-16 = inf). Divide by a harmless radius there instead.
        r_safe = jnp.where(r == 0, 1., r)

        # disk components
        b0 = 5./r_safe
        z_profile = 1/(1 + jnp.exp(-2/jf12_params['w_disk']*(abs_z - jf12_params['h_disk'])))
        # molecular ring: purely azimuthal, active where is_r_less_rmin ~ 0
        B_cyl_disk = jnp.array([0, b0*jf12_params['b_ring']*(1 - z_profile), 0])*(1 - is_r_less_rmin)
        # spiral arms: inclined by `inc`, active where is_r_less_rmin == 1
        B_cyl_disk += b_disk*is_r_less_rmin*jnp.array([jnp.sin(inc), jnp.cos(inc), 0])*b0*(1 - z_profile)

        # toroidal components (north parameters for z > 0, south for z < 0,
        # their mean at z == 0 because sign(0) == 0)
        z_sign = jnp.sign(z)
        b1 = (z_sign + 1.)/2*jf12_params['bn'] + (1. - z_sign)/2*jf12_params['bs']
        rh = (z_sign + 1.)/2*jf12_params['rn'] + (1. - z_sign)/2*jf12_params['rs']
        bh = b1*(1. - 1./(1. + jnp.exp(-2./jf12_params['wh']*(r - rh))))*jnp.exp(-abs_z/(jf12_params['z0']))
        B_cyl_h = jnp.array([0., bh*z_profile, 0.])

        # X-field
        theta = jf12_params['x_theta']
        tan_theta = jnp.tan(theta)
        z_over_tan = abs_z/tan_theta
        rc_X = jf12_params['rpc_x'] + z_over_tan
        rc_sign = jnp.sign(r - rc_X)
        w_outer = (rc_sign + 1.)/2   # 1 for r > rc_X, 0 for r < rc_X
        w_inner = (1 - rc_sign)/2    # 1 for r < rc_X, 0 for r > rc_X
        rp_X = (r - z_over_tan)*w_outer + w_inner*r*jf12_params['rpc_x']/rc_X
        # Inner-region elevation angle. arctan(|z|/(r - rp_X)) equals
        # arctan(rc_X*tan(theta)/r) algebraically, but the former is 0/0 = NaN
        # at z == 0 (and NaN*0 stays NaN in the blend). arctan2 also handles r == 0.
        # Assumes 0 < x_theta < pi/2 so that tan(theta) > 0.
        inner_angle = jnp.arctan2(rc_X*tan_theta, r)
        x_theta = theta*w_outer + w_inner*inner_angle
        radial_decay = jnp.exp(-rp_X/jf12_params['r0_x'])
        B_X = (jf12_params['b0_x']*rp_X/r_safe*radial_decay*w_outer
               + w_inner*jf12_params['b0_x']*(jf12_params['rpc_x']/rc_X)**2*radial_decay)
        B_cyl_X = jnp.array([B_X*jnp.cos(x_theta)*z_sign, 0, B_X*jnp.sin(x_theta)])

        B_cyl = B_cyl_disk*is_r_less_rcent + B_cyl_h + B_cyl_X
        # rotate (B_r, B_phi, B_z) into Cartesian (Bx, By, Bz)
        cos_phi, sin_phi = jnp.cos(phi), jnp.sin(phi)
        return jnp.array([B_cyl[0]*cos_phi - B_cyl[1]*sin_phi,
                          B_cyl[0]*sin_phi + B_cyl[1]*cos_phi,
                          B_cyl[2]])*mask

    @partial(jax.jit, static_argnums=(0,))
    def B_field(self, jf12_params):
        """
        Calculate jf12 B-field at all positions specified by ``coords``.

        Args:
            jf12_params (Dict[str,float]): A dict contains all parameters of the jf12 model.

        Returns:
            jnp.Array of shape (``coords[0].shape``,3), scaled by 1e-6. ``coords`` is the parameter of your B_jf12 instance.
        """
        # Slots 0-6: the seven free arm strengths; slot 7: derived below;
        # slot 8: value used for points outside every arm (0).
        bv_b = jnp.array([jf12_params[f'b_arm_{i}'] for i in range(1, 8)] + [0, 0])
        # The eighth arm is fixed so that the f-weighted sum over all eight arms vanishes.
        b8 = -1*(self.f[:8]*bv_b[:8]).sum()/self.f[7]
        bv_b = bv_b.at[7].set(b8)
        disk_values = jnp.take(bv_b, self.indexs)
        B_flat = self.B_calc_vmap(
            jf12_params,
            self.r.reshape(-1),
            self.phi.reshape(-1),
            self.z.reshape(-1),
            self.rmin_mask.reshape(-1),
            self.rcent_mask.reshape(-1),
            disk_values.reshape(-1),
            self.total_mask.reshape(-1),
        )
        return B_flat.reshape(self.shape + (3,))*1e-6

    @staticmethod
    def get_index(r: float, phi: float, z: float) -> int:
        """
        Calculate which spiral arm the position specified by ``(r,phi,z)`` is in.

        Args:
            r (float): r in cylindrical coordinates.
            phi (float): phi in cylindrical coordinates.
            z (float): z in cylindrical coordinates (not used by the calculation).

        Returns:
            Spiral arm index (0-7). 8 for not in any spiral arms.
        """
        rc_b = jnp.array([0, 5.1, 6.3, 7.1, 8.3, 9.8,
                          11.4, 12.7, 15.5])
        inc = 11.5*jnp.pi/180.  # spiral-arm inclination of 11.5 degrees, in radians
        cot_factor = jnp.tan(jnp.pi/2 - inc)
        # Follow the logarithmic spiral through (r, phi) to the negative x-axis,
        # trying successive windings until the crossing radius is inside the arms.
        r_negx1 = r*jnp.exp((jnp.pi - phi)/cot_factor)
        r_negx2 = r*jnp.exp((-1*jnp.pi - phi)/cot_factor)
        r_negx3 = r*jnp.exp((-3*jnp.pi - phi)/cot_factor)
        r_negx = jnp.where(r_negx1 <= rc_b[8], r_negx1,
                           jnp.where(r_negx2 <= rc_b[8], r_negx2,
                                     r_negx3))
        index = jnp.searchsorted(rc_b, r_negx)
        return index - 1

    def __str__(self):
        """
        String representation of the instance
        """
        return 'B_jf12'

    def __repr__(self):
        """
        Official string representation of the instance
        """
        return 'B_jf12'
# [docs]  (Sphinx viewcode anchor left over from page extraction; not part of the module)
class B_lsa():
    """
    lsa B field model. See Synax paper for more details.

    Args:
        coords (Union[jax.Array,List[jax.Array],Tuple[jax.Array]]): coordinates of all integration points. Should be of size (3,...), for example ``coords[0]`` is the x-coordinates.

    Returns:
        An instance of lsa B field generator.
    """

    def __init__(self, coords: Union[jax.Array, List[jax.Array], Tuple[jax.Array]]):
        x, y, z = coords[0], coords[1], coords[2]
        self.rho = (x**2 + y**2 + z**2)**(1/2)   # spherical radius
        self.r = (x**2 + y**2)**(1/2)            # cylindrical radius
        # On the z-axis x/r is 0/0 = NaN, and the ~0 mask cannot remove a NaN.
        # Divide by 1 there instead; those points lie in the masked r < 3 region.
        r_safe = jnp.where(self.r == 0, 1., self.r)
        self.cos_p = x/r_safe
        self.sin_p = y/r_safe
        self.z = z
        self.shape = z.shape
        # Multiplicative mask: 1 for 3 <= r <= 20, 1e-16 (not exactly 0) elsewhere.
        self.total_mask = jnp.where((self.r < 3) | (self.r > 20.), 1e-16, jnp.ones_like(x))
        # lsa_params is shared by every point (None); all other arguments are mapped.
        self.B_calc_vmap = jax.vmap(self.B_calc, in_axes=(None, 0, 0, 0, 0, 0))

    @staticmethod
    def B_calc(lsa_params, r, z, cos_p, sin_p, mask):
        """
        Calculate lsa B-field at a given position ``(r,phi,z)``.

        Args:
            lsa_params (Dict[str,float]): A dict contains all parameters of the lsa model (keys ``psi0``, ``psi1``, ``chi0``, ``b0``).
            r (float): r in cylindrical coordinates.
            z (float): z in cylindrical coordinates.
            cos_p (float): cos(phi), phi is the polar angle in cylindrical coordinates.
            sin_p (float): sin(phi), phi is the polar angle in cylindrical coordinates.
            mask: multiplicative flag for the total field. 1 inside the modelled region, ~0 (``B_field`` passes 1e-16) outside.

        Returns:
            (Bx,By,Bz) in this position.
        """
        # log(0) = -inf would make sin/cos of psi NaN at r == 0; use a finite
        # radius there (the point is expected to be masked out anyway).
        r_safe = jnp.where(r == 0, 1., r)
        psi = lsa_params["psi0"] + lsa_params["psi1"]*jnp.log(r_safe/8)
        chi = lsa_params["chi0"]*jnp.tanh(z)
        sin_psi, cos_psi = jnp.sin(psi), jnp.cos(psi)
        cos_chi = jnp.cos(chi)
        # (sin(psi)cos(chi), cos(psi)cos(chi), sin(chi)) in (r, phi, z), rotated to Cartesian
        direction = jnp.array([
            sin_psi*cos_chi*cos_p - cos_psi*cos_chi*sin_p,
            sin_psi*cos_chi*sin_p + cos_psi*cos_chi*cos_p,
            jnp.sin(chi),
        ])
        return direction*mask*lsa_params["b0"]

    @partial(jax.jit, static_argnums=(0,))
    def B_field(self, lsa_params):
        """
        Calculate lsa B-field at all positions specified by ``coords``.

        Args:
            lsa_params (Dict[str,float]): A dict contains all parameters of the lsa model.

        Returns:
            jnp.Array of shape (``coords[0].shape``,3), scaled by 1e-6. ``coords`` is the parameter of your B_lsa instance.
        """
        B_flat = self.B_calc_vmap(
            lsa_params,
            self.r.reshape(-1),
            self.z.reshape(-1),
            self.cos_p.reshape(-1),
            self.sin_p.reshape(-1),
            self.total_mask.reshape(-1),
        )
        return (B_flat*1e-6).reshape(self.shape + (3,))

    def __str__(self):
        """
        String representation of the instance
        """
        return 'B_lsa'

    def __repr__(self):
        """
        Official string representation of the instance
        """
        return 'B_lsa'
# [docs]  (Sphinx viewcode anchor left over from page extraction; not part of the module)
class B_grid():
    """
    grid B field model. See Synax paper for more details.

    Args:
        coords (Union[jax.Array,List[jax.Array],Tuple[jax.Array]]): coordinates of all integration points. Should be of size (3,...), for example ``coords[0]`` is the x-coordinates.
        coords_field (Union[jax.Array,List[jax.Array],Tuple[jax.Array]]): ``coords_field[i]`` is the 1D vector of coordinates along i-th axis. Since the grid is a regular 3D grid, 1D vectors are sufficient to represent the coordinates.

    Returns:
        An instance of grid B field generator.
    """

    def __init__(self, coords: Union[jax.Array, List[jax.Array], Tuple[jax.Array]],
                 coords_field: Union[jax.Array, List[jax.Array], Tuple[jax.Array]]):
        self.x = coords[0]
        self.y = coords[1]
        self.z = coords[2]
        self.pos = coords
        self.xf = coords_field[0]
        self.yf = coords_field[1]
        self.zf = coords_field[2]
        self.shape = coords[2].shape

        def field_calc(pos, field):
            # Linear interpolation (with extrapolation outside the grid) of
            # `field`, sampled on (xf, yf, zf), at the flattened query points.
            return interpax.interp3d(
                pos[0].reshape(-1), pos[1].reshape(-1), pos[2].reshape(-1),
                self.xf, self.yf, self.zf,
                field, method='linear', extrap=True,
            )

        # Kept as an instance attribute so it stays callable as obj.field_calc(pos, field).
        self.field_calc = field_calc

    @partial(jax.jit, static_argnums=(0,))
    def B_field(self, B_field_grid):
        """
        Calculate grid B-field at all positions specified by ``coords``.

        Args:
            B_field_grid: your field sampled on the regular 3D grid given by ``coords_field``. Presumably an array of shape ``(len(xf), len(yf), len(zf), 3)`` -- confirm against callers.

        Returns:
            jnp.Array of shape (``coords[0].shape``,3). ``coords`` is the parameter of your B_grid instance.
        """
        interpolated = self.field_calc(self.pos, B_field_grid)
        # self.shape is a tuple: append the vector axis by concatenation
        # (calling it, as in ``self.shape(3,)``, raises TypeError).
        return interpolated.reshape(self.shape + (3,))

    def __str__(self):
        """
        String representation of the instance
        """
        return 'B_grid'

    def __repr__(self):
        """
        Official string representation of the instance
        """
        return 'B_grid'