Xspex
JAX interface for XSPEC spectral models.
Installation
NOTE: HEASoft and XSPEC v12.12.1 or later must be installed on your
system before installing xspex. HEASoft can be downloaded from the
HEASARC website or installed from conda.
Once the HEADAS environment has been initialized, xspex can be
installed directly from PyPI using:
pip install xspex
xspex currently supports Python 3.11 through 3.14.
Examples
Basic Usage
import os
# Expose two CPU devices for the pmap example; this must be set before importing jax
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=2'
import jax
import jax.numpy as jnp
import numpy as np
import xspex as xx
# Double precision is required for XSPEC models
jax.config.update('jax_enable_x64', True)
# Get APEC model function
fn, info = xx.get_model('apec')
# Define parameters (kT in keV, abundance, redshift) and the energy grid
# (6 bin edges in keV, i.e. 5 energy bins)
params = jnp.array([1.0, 1.0, 0.0])
egrid = jnp.linspace(0.1, 0.2, 6)
# Evaluate the model function
value = fn(params, egrid)
print(value)
# output:
# [1.29398149 0.43733974 0.11998129 0.08725635 0.07757024]
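The energy grid holds bin edges, so the 6 edges above define 5 bins and the model returns one value per bin. The pure-JAX sketch below shows how to build linear or logarithmic grids and how to turn per-bin values into a per-keV density for plotting. It assumes, as in XSPEC, that energies are in keV and that each returned value is integrated over its bin; the numbers are the ones printed above.

```python
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

# 6 edges in keV -> 5 contiguous bins, as in the Basic Usage example
egrid = jnp.linspace(0.1, 0.2, 6)
n_bins = egrid.size - 1

# A logarithmic grid suits broadband spectra: 101 edges -> 100 bins
log_grid = jnp.geomspace(0.3, 10.0, 101)

# Bin widths and centres derived from the edges
widths = jnp.diff(egrid)
centres = 0.5 * (egrid[:-1] + egrid[1:])

# Values printed in the Basic Usage example (assumed to be bin-integrated)
value = jnp.array([1.29398149, 0.43733974, 0.11998129, 0.08725635, 0.07757024])

# Dividing by the bin width gives a per-keV density
density = value / widths
print(n_bins, log_grid.size - 1)  # 5 100
```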
JAX Transformations
xspex supports JAX automatic differentiation of XSPEC models by approximating
their derivatives with finite differences, so the models can be used with JAX
transformations such as grad, jacfwd, and jacrev.
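A common way to give a black-box numerical function finite-difference derivatives in JAX is to evaluate it through jax.pure_callback and attach a jax.custom_jvp rule that multiplies a central-difference Jacobian by the tangent. Because that rule is linear in the tangent, JAX can transpose it, so jax.grad and jax.jacrev work as well as jax.jacfwd. The sketch below is hypothetical: black_box_model is a NumPy stand-in for a compiled XSPEC model, and this is not necessarily how xspex is implemented.

```python
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import numpy as np


def black_box_model(params, egrid):
    """Hypothetical stand-in for a compiled model (plain NumPy):
    E**-alpha integrated over each bin; alpha must not equal 1."""
    beta = 1.0 - params[0]
    return (egrid[1:] ** beta - egrid[:-1] ** beta) / beta


def fd_jacobian(params, egrid, delta=1e-6):
    """Central-difference Jacobian with shape (n_bins, n_params)."""
    params = np.asarray(params, dtype=np.float64)
    egrid = np.asarray(egrid, dtype=np.float64)
    cols = []
    for i in range(params.size):
        h = delta * max(abs(params[i]), 1.0)  # step relative to the value
        up, down = params.copy(), params.copy()
        up[i] += h
        down[i] -= h
        cols.append((black_box_model(up, egrid) - black_box_model(down, egrid)) / (2.0 * h))
    return np.stack(cols, axis=-1)


def _callback(fun, shape, *args):
    # Run NumPy code from inside JAX; JAX cannot trace through it
    spec = jax.ShapeDtypeStruct(shape, jnp.float64)
    return jax.pure_callback(lambda *a: np.asarray(fun(*a), dtype=np.float64), spec, *args)


@jax.custom_jvp
def model(params, egrid):
    return _callback(black_box_model, (egrid.shape[0] - 1,), params, egrid)


@model.defjvp
def model_jvp(primals, tangents):
    params, egrid = primals
    dparams, _ = tangents  # the energy grid is treated as a constant
    out = model(params, egrid)
    jac = _callback(fd_jacobian, (egrid.shape[0] - 1, params.shape[0]), params, egrid)
    # Linear in dparams, so JAX can also transpose it for reverse mode
    return out, jac @ dparams


params = jnp.array([2.0])
egrid = jnp.linspace(0.5, 2.0, 6)
grad = jax.grad(lambda p: jnp.sum(model(p, egrid)))(params)
jac_f = jax.jacfwd(model)(params, egrid)
jac_r = jax.jacrev(model)(params, egrid)
print(grad)
```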
Computing Gradients
# Get gradient function with respect to parameters
grad_fn = jax.grad(lambda p, e: jnp.sum(fn(p, e)))
# Compute the gradient; abundance and redshift are fixed by default, so their entries are zero
grad = grad_fn(params, egrid)
print(grad)
# output:
# [-3.1665168 0. 0. ]
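A zero gradient entry is what holding a parameter fixed looks like. The pure-JAX sketch below illustrates one way to express that, by routing fixed entries through jax.lax.stop_gradient; toy_model and freeze are hypothetical helpers, and this is not necessarily how xspex handles fixed parameters.

```python
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp


def toy_model(params, egrid):
    """Hypothetical 3-parameter model evaluated at bin centres."""
    centres = 0.5 * (egrid[:-1] + egrid[1:])
    amplitude, slope, shift = params
    return amplitude * jnp.exp(-slope * (centres - shift))


def freeze(params, fixed):
    """Keep the values but block gradient flow where `fixed` is True."""
    return jnp.where(fixed, jax.lax.stop_gradient(params), params)


# Mirror the default above: only the first parameter is free
fixed = jnp.array([False, True, True])
params = jnp.array([1.0, 1.0, 0.0])
egrid = jnp.linspace(0.1, 0.2, 6)

grad = jax.grad(lambda p: jnp.sum(toy_model(freeze(p, fixed), egrid)))(params)
print(grad)  # the last two entries are exactly zero
```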
Computing Jacobian
# Get Jacobian function
jac_fn = jax.jacfwd(fn)  # differentiates w.r.t. params; or jax.jacrev, jax.jacobian
# Compute Jacobian matrix
jacobian = jac_fn(params, egrid)
print(jacobian)
# output:
# [[-2.01717805 -0. -0. ]
# [-1.05626962 -0. -0. ]
# [-0.03252301 -0. -0. ]
# [-0.02018553 -0. -0. ]
# [-0.0403606 -0. -0. ]]
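Each row of the Jacobian holds the derivatives for one energy bin, so its shape is (number of bins, number of parameters), and summing over bins reproduces the gradient of the summed model: the first column above sums to about -3.1665168, the value printed by jax.grad. The pure-JAX check below uses a hypothetical toy model in place of an XSPEC model.

```python
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp


def toy_model(params, egrid):
    """Hypothetical smooth model: one value per energy bin."""
    centres = 0.5 * (egrid[:-1] + egrid[1:])
    norm, index, cutoff = params
    return norm * centres ** (-index) * jnp.exp(-centres / cutoff)


params = jnp.array([1.0, 1.5, 5.0])
egrid = jnp.linspace(0.1, 0.2, 6)

jac_fwd = jax.jacfwd(toy_model)(params, egrid)  # one JVP per parameter
jac_rev = jax.jacrev(toy_model)(params, egrid)  # one VJP per energy bin
grad = jax.grad(lambda p: jnp.sum(toy_model(p, egrid)))(params)

print(jac_fwd.shape)  # (5, 3): bins x parameters
```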
Vectorization with vmap
# Create multiple parameter sets
param_sets = jnp.array([
[0.5, 1.0, 0.0],
[1.0, 1.0, 0.0],
[2.0, 1.0, 0.0],
])
# Vectorize the function
vmapped_fn = jax.vmap(fn, in_axes=(0, None))
results = vmapped_fn(param_sets, egrid)
print(results)
# output:
# [[0.52477309 0.56379027 0.13421626 0.11663016 0.17570166]
# [1.29398149 0.43733974 0.11998129 0.08725635 0.07757024]
# [0.18578006 0.10865312 0.08878279 0.07398064 0.06353859]]
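With in_axes=(0, None), vmap maps over the leading axis of the parameter array while every call shares the same energy grid, giving one row per parameter set. The sketch below, using a hypothetical toy model, checks this against an explicit Python loop and also shows the opposite mapping, in_axes=(None, 0), which evaluates one parameter set on a stack of grids with the same number of edges.

```python
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp


def toy_model(params, egrid):
    """Hypothetical 3-parameter model, one value per energy bin."""
    centres = 0.5 * (egrid[:-1] + egrid[1:])
    a, b, c = params
    return a * centres ** (-b) * jnp.exp(-c * centres)


param_sets = jnp.array([[0.5, 1.0, 0.0], [1.0, 1.0, 0.0], [2.0, 1.0, 0.0]])
egrid = jnp.linspace(0.1, 0.2, 6)

# Batch over parameter sets, share the energy grid
batched = jax.vmap(toy_model, in_axes=(0, None))(param_sets, egrid)
looped = jnp.stack([toy_model(p, egrid) for p in param_sets])

# Batch over energy grids, share one parameter set
grids = jnp.stack([egrid, 2.0 * egrid])
per_grid = jax.vmap(toy_model, in_axes=(None, 0))(param_sets[0], grids)

print(batched.shape, per_grid.shape)  # (3, 5) (2, 5)
```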
Parallel evaluation with pmap
# Send one parameter set to each device (two here, via XLA_FLAGS in Basic Usage)
param_sets = jnp.array([
[1.0, 1.0, 0.0],
[2.0, 1.0, 0.0],
])
pmapped_fn = jax.pmap(fn, in_axes=(0, None))
results = pmapped_fn(param_sets, egrid)
print(results)
# output:
# [[1.29398149 0.43733974 0.11998129 0.08725635 0.07757024]
# [0.18578006 0.10865312 0.08878279 0.07398064 0.06353859]]
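pmap maps the leading axis onto devices, so that axis cannot be longer than the number of available devices (two here, from the XLA_FLAGS setting in Basic Usage). To evaluate more parameter sets than there are devices, a common JAX pattern is to reshape the batch to (devices, sets per device, parameters) and nest vmap inside pmap. The sketch below uses a hypothetical toy model and adapts to whatever jax.local_device_count() reports.

```python
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp


def toy_model(params, egrid):
    """Hypothetical 3-parameter model, one value per energy bin."""
    centres = 0.5 * (egrid[:-1] + egrid[1:])
    a, b, c = params
    return a * centres ** (-b) * jnp.exp(-c * centres)


n_dev = jax.local_device_count()
per_device = 2
egrid = jnp.linspace(0.1, 0.2, 6)

# Hypothetical batch: vary the first parameter across n_dev * per_device sets
first = jnp.linspace(0.5, 2.0, n_dev * per_device)
param_sets = jnp.stack([first, jnp.ones_like(first), jnp.zeros_like(first)], axis=1)

# (n_dev * per_device, 3) -> (n_dev, per_device, 3)
sharded = param_sets.reshape(n_dev, per_device, 3)

# pmap over devices, vmap over the parameter sets held by each device
sharded_fn = jax.pmap(jax.vmap(toy_model, in_axes=(0, None)), in_axes=(0, None))
results = sharded_fn(sharded, egrid).reshape(-1, egrid.size - 1)
print(results.shape)  # (n_dev * per_device, 5)
```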
Custom Finite Difference Automatic Differentiation
# Wrap a model with custom finite-difference settings
fn, info = xx.get_model('powerlaw')
fn2 = xx.define_fdjvp( # see the docstring for more details
fn,
info,
delta=1e-6, # Custom step size (relative to parameter value)
method='central', # 'central' or 'forward' finite differences
fixed=None # Optional: specify which parameters to keep fixed
)
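The method option chooses between two standard finite-difference formulas: forward, (f(p + h) - f(p)) / h, whose error shrinks like h, and central, (f(p + h) - f(p - h)) / (2h), whose error shrinks like h squared at the cost of roughly twice as many model evaluations. The sketch below compares them on exp, whose derivative is known. Scaling the step by the parameter value follows the comment on delta above; falling back to delta itself when the parameter is zero is an assumption made here, not documented xspex behaviour.

```python
import math


def step_size(p, delta):
    # Relative step; absolute fallback when p == 0 (assumption)
    return delta * abs(p) if p != 0 else delta


def forward_diff(f, p, delta):
    h = step_size(p, delta)
    return (f(p + h) - f(p)) / h


def central_diff(f, p, delta):
    h = step_size(p, delta)
    return (f(p + h) - f(p - h)) / (2.0 * h)


f = math.exp  # test function: d/dp exp(p) = exp(p)
p = 1.5
exact = math.exp(p)

errors = []
for delta in (1e-2, 1e-4, 1e-6):
    fwd_err = abs(forward_diff(f, p, delta) - exact)
    cen_err = abs(central_diff(f, p, delta) - exact)
    errors.append((delta, fwd_err, cen_err))
    print(f"delta={delta:g}  forward error={fwd_err:.1e}  central error={cen_err:.1e}")
```

A smaller delta reduces truncation error but amplifies floating-point round-off, so very small steps eventually make both estimates worse.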
