Xspex

JAX interface for XSPEC spectral models.

License: GPL v3

Installation

NOTE: HEASoft and XSPEC v12.12.1 or later must be installed on your system before installing xspex. HEASoft can be downloaded from the HEASARC website or installed from conda.

Once the HEADAS environment has been initialized, xspex can be installed directly from PyPI using:

```shell
pip install xspex
```

xspex currently supports Python 3.11 through 3.14.
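Before installing, it can help to confirm both prerequisites from Python. The helper below is hypothetical and not part of xspex: it checks the interpreter against the 3.11 to 3.14 range stated above and treats a set `HEADAS` variable pointing at an existing directory as a rough proxy for an initialized HEASoft environment (HEASoft's setup uses `HEADAS` to locate the installation). That proxy is a heuristic, not a guarantee that initialization succeeded.

```python
import os
import sys


def check_xspex_prerequisites(environ=os.environ, version_info=sys.version_info):
    """Hypothetical pre-install check; returns a list of problems (empty if none)."""
    problems = []
    # xspex states support for Python 3.11 through 3.14
    if not ((3, 11) <= tuple(version_info[:2]) <= (3, 14)):
        problems.append(
            f"Python {version_info[0]}.{version_info[1]} is outside 3.11-3.14"
        )
    # Heuristic: HEASoft setup points HEADAS at the installation directory
    headas = environ.get("HEADAS")
    if not headas:
        problems.append("HEADAS is not set; initialize HEASoft first")
    elif not os.path.isdir(headas):
        problems.append(f"HEADAS points to a missing directory: {headas}")
    return problems


if __name__ == "__main__":
    for msg in check_xspex_prerequisites():
        print("warning:", msg)
```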

Examples

Basic Usage

```python
import os

# Expose two host (CPU) devices for the pmap example; JAX reads this flag
# when it initializes, so set it before importing jax
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=2'

import jax
import jax.numpy as jnp
import numpy as np
import xspex as xx

# Double precision is required for XSPEC models
jax.config.update('jax_enable_x64', True)

# Get the APEC model function (fn) and its model information (info)
fn, info = xx.get_model('apec')

# Parameters: kT (keV), abundance, redshift
params = jnp.array([1.0, 1.0, 0.0])
# 6 energy bin edges (keV) define 5 bins
egrid = jnp.linspace(0.1, 0.2, 6)

# Evaluate the model function
value = fn(params, egrid)
print(value)
# output:
# [1.29398149 0.43733974 0.11998129 0.08725635 0.07757024]
```
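The energy grid holds bin edges, so `n + 1` edges give `n` model values. Assuming xspex follows XSPEC's convention of returning the model integrated over each bin (photons cm^-2 s^-1 per bin), dividing by the bin widths gives a density per keV that can be plotted against the bin centers. This sketch reuses the printed output above as plain NumPy data, so it runs without xspex.

```python
import numpy as np

# Bin edges (keV) and the APEC values printed above
egrid = np.linspace(0.1, 0.2, 6)
value = np.array([1.29398149, 0.43733974, 0.11998129, 0.08725635, 0.07757024])

widths = np.diff(egrid)                   # 5 bin widths, each 0.02 keV
centers = 0.5 * (egrid[:-1] + egrid[1:])  # bin midpoints in keV
density = value / widths                  # per-keV density (assumes per-bin integrated values)

print(centers)
print(density)
```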

JAX Transformations

xspex provides automatic differentiation support for XSPEC models through finite-difference approximation: XSPEC models are compiled routines that JAX cannot trace, so their derivatives are estimated numerically. This allows them to be used with JAX transformations such as grad, jacfwd and jacrev.
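To illustrate the general idea (this is not xspex's actual implementation), the sketch below wraps an opaque NumPy function with `jax.pure_callback` so JAX can call it, and attaches a central-difference derivative through `jax.custom_jvp`. The toy model `_toy_model_np` (an exponential integrated over each bin) is a hypothetical stand-in for a compiled XSPEC routine, and the relative step of `1e-6` with a floor for small parameters is an assumption.

```python
import jax
import jax.numpy as jnp
import numpy as np

jax.config.update('jax_enable_x64', True)


def _toy_model_np(params, egrid):
    """Hypothetical opaque model: norm * exp(-E / kT) integrated over each bin."""
    norm, kT = params
    lo, hi = egrid[:-1], egrid[1:]
    return norm * kT * (np.exp(-lo / kT) - np.exp(-hi / kT))


def _call_model(params, egrid):
    # JAX cannot trace into NumPy code, so call it as a host callback
    out = jax.ShapeDtypeStruct((egrid.shape[0] - 1,), egrid.dtype)
    return jax.pure_callback(_toy_model_np, out, params, egrid)


@jax.custom_jvp
def toy_model(params, egrid):
    return _call_model(params, egrid)


@toy_model.defjvp
def _toy_model_jvp(primals, tangents):
    params, egrid = primals
    dparams, _ = tangents  # the energy grid is treated as non-differentiable
    value = _call_model(params, egrid)
    # Relative step with a minimum of 1e-6 so zero-valued parameters still move
    h = 1e-6 * jnp.maximum(jnp.abs(params), 1.0)
    columns = []
    for i in range(params.shape[0]):
        step = jnp.zeros_like(params).at[i].set(h[i])
        upper = _call_model(params + step, egrid)
        lower = _call_model(params - step, egrid)
        columns.append((upper - lower) / (2.0 * h[i]))  # central difference
    jac = jnp.stack(columns, axis=-1)  # shape (n_bins, n_params)
    return value, jac @ dparams
```

Because the JVP rule returns `jac @ dparams`, which is linear in the tangent, JAX can transpose it, so reverse-mode `jax.grad` and `jax.jacrev` work as well as `jax.jacfwd`.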

Computing Gradients

```python
# Get the gradient function with respect to the parameters
grad_fn = jax.grad(lambda p, e: jnp.sum(fn(p, e)))

# Compute the gradient; abundance and redshift are fixed by default,
# so their entries are zero
grad = grad_fn(params, egrid)
print(grad)
# output:
# [-3.1665168  0.         0.       ]
```

Computing Jacobian

```python
# Get the Jacobian function (jax.jacrev or jax.jacobian also work)
jac_fn = jax.jacfwd(fn)

# Compute the Jacobian matrix: one row per energy bin, one column per parameter
jacobian = jac_fn(params, egrid)
print(jacobian)
# output:
# [[-2.01717805 -0.         -0.        ]
#  [-1.05626962 -0.         -0.        ]
#  [-0.03252301 -0.         -0.        ]
#  [-0.02018553 -0.         -0.        ]
#  [-0.0403606  -0.         -0.        ]]
```
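The gradient and Jacobian outputs are consistent with each other: the gradient of `jnp.sum(fn(p, e))` is the vector of Jacobian column sums, since d(sum_i f_i)/dp_j = sum_i J_ij. A quick NumPy check using the printed numbers:

```python
import numpy as np

# Jacobian printed above: rows are energy bins, columns are (kT, abundance, redshift)
jacobian = np.array([
    [-2.01717805, 0.0, 0.0],
    [-1.05626962, 0.0, 0.0],
    [-0.03252301, 0.0, 0.0],
    [-0.02018553, 0.0, 0.0],
    [-0.0403606, 0.0, 0.0],
])

# Gradient of the summed model = ones-vector times Jacobian (column sums)
grad_from_jac = np.ones(jacobian.shape[0]) @ jacobian
print(grad_from_jac)  # agrees with the printed gradient [-3.1665168, 0, 0]
```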

Vectorization with vmap

```python
# Create multiple parameter sets (kT varies; abundance and redshift unchanged)
param_sets = jnp.array([
    [0.5, 1.0, 0.0],
    [1.0, 1.0, 0.0],
    [2.0, 1.0, 0.0],
])

# Vectorize over parameter sets (axis 0); the energy grid is shared (None)
vmapped_fn = jax.vmap(fn, in_axes=(0, None))
results = vmapped_fn(param_sets, egrid)
print(results)
# output:
# [[0.52477309 0.56379027 0.13421626 0.11663016 0.17570166]
#  [1.29398149 0.43733974 0.11998129 0.08725635 0.07757024]
#  [0.18578006 0.10865312 0.08878279 0.07398064 0.06353859]]
```
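`in_axes=(0, None)` maps over the first axis of the parameter array while passing the same energy grid to every call, so the result has one row per parameter set. The sketch below shows the same pattern on a pure-JAX stand-in (the hypothetical `toy_model`, not an xspex model) and checks it against an explicit loop. It also maps over per-set energy grids with `in_axes=(0, 0)`; that is standard JAX behavior, but whether xspex model functions support batching over the energy grid is not stated here.

```python
import jax
import jax.numpy as jnp

jax.config.update('jax_enable_x64', True)


def toy_model(params, egrid):
    # Hypothetical stand-in: norm * exp(-E / kT) integrated over each bin
    norm, kT = params[0], params[1]
    return norm * kT * (jnp.exp(-egrid[:-1] / kT) - jnp.exp(-egrid[1:] / kT))


param_sets = jnp.array([[1.0, 0.5], [1.0, 1.0], [2.0, 1.0]])
egrid = jnp.linspace(0.1, 2.0, 6)

# Shared grid: batch over parameters only
batched = jax.vmap(toy_model, in_axes=(0, None))(param_sets, egrid)
looped = jnp.stack([toy_model(p, egrid) for p in param_sets])

# One grid per parameter set: batch over both arguments
egrids = jnp.stack([egrid, 2.0 * egrid, 3.0 * egrid])
per_grid = jax.vmap(toy_model, in_axes=(0, 0))(param_sets, egrids)

print(batched.shape, per_grid.shape)  # (3, 5) (3, 5)
```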

Parallel evaluation with pmap

```python
# One parameter set per device (two host devices were requested via XLA_FLAGS)
param_sets = jnp.array([
    [1.0, 1.0, 0.0],
    [2.0, 1.0, 0.0],
])

pmapped_fn = jax.pmap(fn, in_axes=(0, None))
results = pmapped_fn(param_sets, egrid)
print(results)
# output:
# [[1.29398149 0.43733974 0.11998129 0.08725635 0.07757024]
#  [0.18578006 0.10865312 0.08878279 0.07398064 0.06353859]]
```
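`jax.pmap` maps the leading axis onto devices, so that axis can be no longer than the number of available devices (two here, from the `XLA_FLAGS` setting at the top). For more parameter sets than devices, one common pattern is to pad, reshape to `(n_devices, per_device, n_params)`, and nest `vmap` inside `pmap`. The sketch below applies this to a hypothetical pure-JAX `toy_model`; the helper `pmap_over_devices` is illustrative and not part of xspex.

```python
import jax
import jax.numpy as jnp

jax.config.update('jax_enable_x64', True)


def toy_model(params, egrid):
    # Hypothetical stand-in: norm * exp(-E / kT) integrated over each bin
    norm, kT = params[0], params[1]
    return norm * kT * (jnp.exp(-egrid[:-1] / kT) - jnp.exp(-egrid[1:] / kT))


def pmap_over_devices(fn, param_sets, egrid):
    """Evaluate fn for every row of param_sets, spread across local devices."""
    n_dev = jax.local_device_count()
    n = param_sets.shape[0]
    per_dev = -(-n // n_dev)  # ceiling division
    pad = per_dev * n_dev - n
    # Pad with copies of the last row so the rows split evenly across devices
    padded = jnp.concatenate([param_sets, jnp.repeat(param_sets[-1:], pad, axis=0)])
    chunks = padded.reshape(n_dev, per_dev, param_sets.shape[1])
    # pmap over devices, vmap over the rows assigned to each device
    per_device_fn = jax.vmap(fn, in_axes=(0, None))
    out = jax.pmap(per_device_fn, in_axes=(0, None))(chunks, egrid)
    # Flatten back to one row per parameter set and drop the padding
    return out.reshape(n_dev * per_dev, -1)[:n]


param_sets = jnp.array([[1.0, 0.5], [1.0, 1.0], [2.0, 1.0], [0.5, 2.0], [3.0, 0.8]])
egrid = jnp.linspace(0.1, 2.0, 6)
results = pmap_over_devices(toy_model, param_sets, egrid)
print(results.shape)  # (5, 5)
```

Padding with a repeated row wastes a little work on the last device but keeps every device's input the same shape, which `pmap` requires.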

Custom Finite Difference Automatic Differentiation

```python
# Create a model with custom finite-difference settings
fn, info = xx.get_model('powerlaw')
fn2 = xx.define_fdjvp(  # see the docstring for more details
    fn,
    info,
    delta=1e-6,  # Custom step size (relative to parameter value)
    method='central',  # 'central' or 'forward' finite differences
    fixed=None,  # Optional: specify which parameters to keep fixed
)
```
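The `method` and `delta` options trade accuracy against cost. A central difference, (f(p + h) - f(p - h)) / (2h), has error of order h^2 but needs two model evaluations per parameter; a forward difference, (f(p + h) - f(p)) / h, has error of order h and can reuse the unperturbed model value. The NumPy sketch below compares both on `np.exp`, whose derivative is known exactly. It takes the step relative to the parameter value, mirroring the description of `delta`; how xspex handles a parameter equal to zero is not stated here, so the `max(|p|, 1)` floor in the hypothetical `fd_derivative` is an assumption.

```python
import numpy as np


def fd_derivative(f, p, delta=1e-6, method='central'):
    """Finite-difference df/dp with step h = delta * max(|p|, 1) (floor is an assumption)."""
    h = delta * max(abs(p), 1.0)
    if method == 'central':
        return (f(p + h) - f(p - h)) / (2.0 * h)
    if method == 'forward':
        return (f(p + h) - f(p)) / h
    raise ValueError(f"unknown method: {method!r}")


p = 2.0
exact = np.exp(p)  # d/dp exp(p) = exp(p)
for method in ('forward', 'central'):
    # Relative error of each method for a fairly large step
    err = abs(fd_derivative(np.exp, p, delta=1e-4, method=method) - exact) / exact
    print(method, err)
```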