import jax.numpy as jnp
import jax,math
import numpy as np
import healpy as hp
from functools import partial
import scipy.constants as const
# Turn on JAX's 64-bit mode ("jax_enable_x64") at import time. This is a global JAX
# configuration switch, so it affects every JAX computation in the process, not only
# the functions defined in this module.
jax.config.update("jax_enable_x64", True)
# obtaining integration locations
@partial(jax.jit, static_argnums=(2, 3, 4, 5, 6, 7))
def obtain_positions(theta, phi, obs_coord: tuple[float, float, float] = (-8.3, 0., 0.006), x_length: float = 20, y_length: float = 20, z_length: float = 5, num_int_points: int = 512, epsilon: float = 1e-7):
    """
    Calculate the integration points along one line of sight, from the location of the observer to the box boundary.

    The box is ``[-x_length, x_length] x [-y_length, y_length] x [-z_length, z_length]``.
    The sightline from the observer to the point where it leaves the box is cut into
    ``num_int_points`` segments of equal length, and the midpoint of every segment is returned.

    Args:
        theta (float): In unit of rad. Polar angle (co-latitude) of the LoS, measured from the +z axis.
        phi (float): In unit of rad. Azimuthal angle (longitude) of the LoS, measured from +x towards +y.
            ``(theta, phi)`` are what ``healpy.pix2ang`` returns with ``lonlat=False``.
        obs_coord (tuple[float]): In unit of kpc. The location of the observer; must lie inside the box.
            It is a static argument of ``jax.jit``, so it must be a hashable tuple.
        x_length (float): In unit of kpc. Half of the box length along x-axis.
        y_length (float): In unit of kpc. Half of the box length along y-axis.
        z_length (float): In unit of kpc. Half of the box length along z-axis.
        num_int_points (int): The number of integration points along one LoS.
        epsilon (float): Unused; kept for backward compatibility.

    Returns:
        tuple:
            - pos (jnp.Array): In unit of kpc. 2D array of shape (3, ``num_int_points``);
              ``pos[0]``, ``pos[1]``, ``pos[2]`` are the x, y, z coordinates of the integration points.
            - dl (jnp.float): In unit of kpc. Length of each integration segment
              (path length to the boundary divided by ``num_int_points``).
            - nhat (jnp.Array): Dimensionless. 1D array of shape (3,), unit vector along this LoS.

    Raises:
        ValueError: If ``obs_coord`` lies outside the box.
    """
    # obs_coord and the half-lengths are static arguments, so this is a plain Python
    # check executed once at trace time. An observer outside the box would otherwise
    # yield negative face distances and meaningless sample points.
    for coord, half in zip(obs_coord, (x_length, y_length, z_length)):
        if abs(coord) > half:
            raise ValueError(
                f"obs_coord {obs_coord} lies outside the box with half-lengths "
                f"({x_length}, {y_length}, {z_length})"
            )

    nx = jnp.sin(theta) * jnp.cos(phi)
    ny = jnp.sin(theta) * jnp.sin(phi)
    nz = jnp.cos(theta)

    # Distance from the observer to the face that the ray heads towards on each axis:
    # half - obs for a positive direction component, half + obs for a negative one.
    # |n_i / dist_i| is the inverse of the path length at which that face is reached;
    # the largest value wins, so max_val = 1 / (path length to the box boundary).
    max_val = jnp.max(jnp.abs(jnp.array([
        nx / (x_length - obs_coord[0] * jnp.sign(nx)),
        ny / (y_length - obs_coord[1] * jnp.sign(ny)),
        nz / (z_length - obs_coord[2] * jnp.sign(nz)),
    ])))

    # Midpoints of num_int_points equal sub-intervals of [0, 1).
    int_points, step = jnp.linspace(0, 1, num_int_points, endpoint=False, retstep=True)
    int_points = int_points + step * 0.5

    xs = nx / max_val * int_points
    ys = ny / max_val * int_points
    zs = nz / max_val * int_points
    # Each segment covers a fraction `step` of the total path length 1 / max_val.
    dl = step / max_val
    return jnp.array([xs + obs_coord[0], ys + obs_coord[1], zs + obs_coord[2]]), dl, jnp.array([nx, ny, nz])
# obtaining integration locations
@partial(jax.jit, static_argnums=(2, 3, 4, 5, 6, 7))
def obtain_positions_hammurabi(theta, phi, obs_coord: tuple[float, float, float] = (-8.3, 0., 0.006), x_length: float = 4, y_length: float = 4, z_length: float = 4, num_int_points: int = 256, epsilon: float = 1e-7):
    """
    Calculate the integration points along one line of sight in the hammurabi way.

    Instead of integrating to the box boundary, the LoS is integrated out to a fixed
    distance from the observer. The end point lies on the observer-centred ellipsoid with
    semi-axes ``(x_length, y_length, z_length)``. When the three lengths are equal (the
    default), this is a sphere and the integration length is simply ``x_length``. All
    points lie on the sightline given by ``nhat``.

    Args:
        theta (float): In unit of rad. Polar angle (co-latitude) of the LoS, measured from the +z axis.
        phi (float): In unit of rad. Azimuthal angle (longitude) of the LoS, measured from +x towards +y.
            ``(theta, phi)`` are what ``healpy.pix2ang`` returns with ``lonlat=False``.
        obs_coord (tuple[float]): In unit of kpc. The location of the observer (static under ``jax.jit``).
        x_length (float): In unit of kpc. Semi-axis of the integration ellipsoid along x-axis.
        y_length (float): In unit of kpc. Semi-axis of the integration ellipsoid along y-axis.
        z_length (float): In unit of kpc. Semi-axis of the integration ellipsoid along z-axis.
        num_int_points (int): The number of integration points along one LoS.
        epsilon (float): Unused; kept for backward compatibility.

    Returns:
        tuple:
            - pos (jnp.Array): In unit of kpc. 2D array of shape (3, ``num_int_points``);
              ``pos[0]``, ``pos[1]``, ``pos[2]`` are the x, y, z coordinates of the integration points.
            - dl (jnp.float): In unit of kpc. Length of each integration segment.
            - nhat (jnp.Array): Dimensionless. 1D array of shape (3,), unit vector along this LoS.
    """
    nx = jnp.sin(theta) * jnp.cos(phi)
    ny = jnp.sin(theta) * jnp.sin(phi)
    nz = jnp.cos(theta)

    # Distance along nhat to the ellipsoid (x/x_length)^2 + (y/y_length)^2 + (z/z_length)^2 = 1.
    # Scaling nhat itself (rather than each axis separately) keeps every point on the
    # sightline even when the three lengths differ.
    path_length = 1.0 / jnp.sqrt((nx / x_length) ** 2 + (ny / y_length) ** 2 + (nz / z_length) ** 2)

    # Midpoints of num_int_points equal sub-intervals of [0, 1).
    int_points, step = jnp.linspace(0, 1, num_int_points, endpoint=False, retstep=True)
    int_points = int_points + step * 0.5

    distances = path_length * int_points
    xs = nx * distances
    ys = ny * distances
    zs = nz * distances
    dl = step * path_length
    return jnp.array([xs + obs_coord[0], ys + obs_coord[1], zs + obs_coord[2]]), dl, jnp.array([nx, ny, nz])
def get_healpix_positions(nside: int = 64, obs_coord: tuple[float, float, float] = (-8.3, 0., 0.006), x_length: float = 20, y_length: float = 20, z_length: float = 5, num_int_points: int = 512, epsilon: float = 1e-7):
    """
    Calculate the integration points along each line of sight of a ``HEALPix`` map.

    Each sightline runs from the location of the observer to the box boundary (see
    ``obtain_positions``). A ``HEALPix`` map with a given ``nside`` contains
    ``npix = 12*nside**2`` pixels.

    Args:
        nside (int): ``NSIDE`` of the ``HEALPix`` map.
        obs_coord (tuple[float]): In unit of kpc. The location of the observer.
        x_length (float): In unit of kpc. Half of the box length along x-axis.
        y_length (float): In unit of kpc. Half of the box length along y-axis.
        z_length (float): In unit of kpc. Half of the box length along z-axis.
        num_int_points (int): The number of integration points along one LoS.
        epsilon (float): Forwarded to ``obtain_positions``.

    Returns:
        tuple:
            - poss (jnp.Array): In unit of kpc. 3D array of shape (3, ``npix``, ``num_int_points``);
              ``poss[0]``, ``poss[1]``, ``poss[2]`` hold the x, y, z coordinates of the
              integration points of every pixel's sightline.
            - dls (jnp.Array): In unit of kpc. 1D array of shape (``npix``,), length of the integration segment for every sightline.
            - nhats (jnp.Array): Dimensionless. 2D array of shape (``npix``, 3), unit vector of the LoS for every pixel.
    """
    def _single_los(theta, phi):
        # One pixel's sightline; the geometry arguments are shared by all pixels.
        return obtain_positions(theta, phi, obs_coord=obs_coord, x_length=x_length, y_length=y_length,
                                z_length=z_length, num_int_points=num_int_points, epsilon=epsilon)

    npix = hp.nside2npix(nside)
    # pix2ang (lonlat=False) returns co-latitude theta and longitude phi for each pixel.
    theta, phi = hp.pix2ang(nside, np.arange(npix))
    poss, dls, nhats = jax.vmap(_single_los)(theta, phi)
    # vmap stacks pixels on axis 0, giving (npix, 3, num_int_points); put the xyz axis first.
    return poss.transpose((1, 0, 2)), dls, nhats
def get_rotated_box_vertices(x, y, z, theta, phi):
    """
    Return the eight corners of the box ``[-x, x] x [-y, y] x [-z, z]`` after a rotation.

    The rotation matrix is ``R = Ry(phi) @ Rz(theta)``. Each corner is first rotated by
    ``theta`` about the Z-axis and then by ``phi`` about the Y-axis.

    Args:
        x, y, z: Half dimensions of the box along the X, Y and Z axes; e.g. (-x, x) is the box extent along X.
        theta: Rotation angle around the Z-axis (in radians).
        phi: Rotation angle around the Y-axis (in radians).

    Returns:
        - vertices_rotated: A (8, 3) array of the rotated vertices. Row k is the corner with
          sign pattern (sx, sy, sz) in the order (+,+,+), (+,+,-), (+,-,+), (+,-,-),
          (-,+,+), (-,+,-), (-,-,+), (-,-,-).
    """
    corner_signs = jnp.array([
        [ 1,  1,  1],
        [ 1,  1, -1],
        [ 1, -1,  1],
        [ 1, -1, -1],
        [-1,  1,  1],
        [-1,  1, -1],
        [-1, -1,  1],
        [-1, -1, -1],
    ])
    corners = corner_signs * jnp.array([x, y, z])

    ct, st = jnp.cos(theta), jnp.sin(theta)
    cp, sp = jnp.cos(phi), jnp.sin(phi)
    rot_about_z = jnp.array([[ct, -st, 0],
                             [st,  ct, 0],
                             [0,    0, 1]])
    rot_about_y = jnp.array([[ cp, 0, sp],
                             [  0, 1,  0],
                             [-sp, 0, cp]])
    rotation = rot_about_y @ rot_about_z

    # Corners are row vectors, so applying R means multiplying by R^T on the right.
    return corners @ rotation.T
@jax.jit
def find_min_rectangle_xy(vertices_rotated):
    """
    Return the smallest axis-aligned rectangle in the X-Y plane that contains all vertices.

    Only the X and Y coordinates are used; Z is ignored (orthogonal projection onto X-Y).

    Args:
        vertices_rotated: A (8, 3) array of the rotated vertices.

    Returns:
        tuple:
            - min_x, max_x, min_y, max_y
    """
    planar = vertices_rotated[:, :2]
    lower = jnp.min(planar, axis=0)  # (min_x, min_y)
    upper = jnp.max(planar, axis=0)  # (max_x, max_y)
    return lower[0], upper[0], lower[1], upper[1]
def vertical_line_intersects_box(vertices_rotated, x0, y0, tol: float = 1e-8):
    """
    Determine whether, and where, vertical lines at (x0, y0) cross a (rotated) box.

    The box surface is split into 12 triangles (two per face). Each vertical line is
    intersected with the plane of every triangle, and a crossing is kept only if it lies
    inside that triangle (barycentric test). The degeneracy checks are relative, so the
    result does not depend on the physical size of the box.

    Args:
        vertices_rotated: A (8, 3) array of the box vertices, in the order produced by
            ``get_rotated_box_vertices``: index k has sign pattern (+,+,+), (+,+,-), (+,-,+),
            (+,-,-), (-,+,+), (-,+,-), (-,-,+), (-,-,-) for k = 0..7.
        x0, y0: Coordinates of the point(s) in the XY plane. Can be scalars or 1D arrays
            (broadcast against each other; N points after broadcasting).
        tol (float): Relative tolerance used to reject (nearly) vertical or degenerate
            triangles and as the slack of the inside-triangle test.

    Returns:
        tuple:
            - intersects: Boolean array of shape (N,), whether each line crosses the box.
            - min_z: Array of shape (N,), lowest Z where the line crosses the box (NaN if it does not).
            - max_z: Array of shape (N,), highest Z where the line crosses the box (NaN if it does not).
    """
    # Ensure x0 and y0 are arrays for broadcasting.
    x0 = jnp.atleast_1d(x0)
    y0 = jnp.atleast_1d(y0)

    # Faces as cyclic quads of vertex indices (face names refer to the unrotated box).
    faces = jnp.array([
        [0, 2, 6, 4],  # +z face
        [1, 3, 7, 5],  # -z face
        [0, 1, 3, 2],  # +x face
        [4, 5, 7, 6],  # -x face
        [0, 1, 5, 4],  # +y face
        [2, 3, 7, 6],  # -y face
    ], dtype=int)
    # Split every quad into two triangles sharing the diagonal q0-q2.
    triangles = jnp.concatenate([
        faces[:, [0, 1, 2]],
        faces[:, [0, 2, 3]],
    ], axis=0)  # (12, 3)

    p0 = vertices_rotated[triangles[:, 0]]  # (12, 3)
    p1 = vertices_rotated[triangles[:, 1]]
    p2 = vertices_rotated[triangles[:, 2]]

    # Plane of each triangle: A x + B y + C z + D = 0.
    v0 = p1 - p0  # (12, 3)
    v1 = p2 - p0
    normals = jnp.cross(v0, v1)  # (12, 3)
    A = normals[:, 0]
    B = normals[:, 1]
    C = normals[:, 2]
    D = -jnp.einsum('ij,ij->i', normals, p0)  # (12,)

    # A vertical line only crosses planes that are not (nearly) vertical. Comparing |C|
    # with |normal| makes this scale-free; degenerate triangles (zero normal) fail it.
    C_nonzero = jnp.abs(C) > tol * jnp.linalg.norm(normals, axis=1)
    C_safe = jnp.where(C_nonzero, C, 1.0)

    # z where each line crosses each triangle's plane.
    x0_b = x0[:, None]  # (N, 1)
    y0_b = y0[:, None]  # (N, 1)
    z = jnp.where(C_nonzero, -(A * x0_b + B * y0_b + D) / C_safe, jnp.nan)  # (N, 12)

    # Crossing points and their offsets from each triangle's first vertex.
    p = jnp.stack([jnp.broadcast_to(x0_b, z.shape),
                   jnp.broadcast_to(y0_b, z.shape),
                   z], axis=-1)  # (N, 12, 3)
    v2 = p - p0  # (N, 12, 3)

    # Barycentric coordinates (u along v0, v along v1) of each crossing point.
    dot00 = jnp.einsum('ij,ij->i', v0, v0)  # (12,)
    dot01 = jnp.einsum('ij,ij->i', v0, v1)
    dot11 = jnp.einsum('ij,ij->i', v1, v1)
    dot02 = jnp.einsum('nij,ij->ni', v2, v0)  # (N, 12)
    dot12 = jnp.einsum('nij,ij->ni', v2, v1)

    # denom = |v0|^2 |v1|^2 sin^2(angle between edges); relative test rejects slivers
    # without depending on the box size (an absolute 1e-8 fails for small boxes).
    denom = dot00 * dot11 - dot01 * dot01  # (12,)
    denom_nonzero = jnp.abs(denom) > tol * dot00 * dot11
    denom_safe = jnp.where(denom_nonzero, denom, 1.0)
    u = (dot11 * dot02 - dot01 * dot12) / denom_safe  # (N, 12)
    v = (dot00 * dot12 - dot01 * dot02) / denom_safe
    w = 1 - u - v

    # Inside-triangle test with a small slack so points on shared edges are kept.
    inside = (u >= -tol) & (v >= -tol) & (w >= -tol) & \
             (u <= 1 + tol) & (v <= 1 + tol) & (u + v <= 1 + tol)
    valid = inside & C_nonzero[None, :] & denom_nonzero[None, :]

    z_values = jnp.where(valid, z, jnp.nan)  # (N, 12)
    intersects = jnp.any(valid, axis=1)  # (N,)
    # All-NaN rows (no crossing) reduce to NaN.
    min_z = jnp.nanmin(z_values, axis=1)  # (N,)
    max_z = jnp.nanmax(z_values, axis=1)
    return intersects, min_z, max_z
import jax.numpy as jnp
# Batched form of vertical_line_intersects_box: the box vertices are shared (in_axes None)
# and x0, y0 are mapped over their leading axis. The wrapped function applies
# jnp.atleast_1d to each scalar x0/y0, so for 1-D x0, y0 of length N every output has
# shape (N, 1).
vertical_line_intersects_box_vmap = jax.vmap(vertical_line_intersects_box, in_axes=(None, 0, 0))
# Backward-compatible alias for the original (misspelled) name.
vertical_line_intersects_box_vamp = vertical_line_intersects_box_vmap