Finax
High Performance Quantitative Finance on JAX
Introduction
Finax is a quantitative finance library built on JAX. It provides high-performance financial computation by leveraging JAX's hardware acceleration, parallel scientific computing, and automatic differentiation.
Finax is a research project under active development, and some APIs may change in the future. We welcome suggestions and contributions.
With Finax you can:
- run financial workloads on CPU/GPU/TPU with XLA acceleration
- compute derivatives of financial models (the Greeks) through automatic differentiation
- distribute workloads across multiple devices
The examples directory contains several demonstrations of how to use Finax.
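To illustrate how automatic differentiation yields Greeks, here is a minimal sketch in plain JAX. It does not use the Finax API: `bs_call_price` is a hypothetical helper written for this example, implementing the textbook Black-Scholes formula for a European call, and the input values are arbitrary.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def bs_call_price(spot, strike, rate, vol, expiry):
    """Black-Scholes price of a European call (hypothetical helper, not Finax API)."""
    sqrt_t = jnp.sqrt(expiry)
    d1 = (jnp.log(spot / strike) + (rate + 0.5 * vol**2) * expiry) / (vol * sqrt_t)
    d2 = d1 - vol * sqrt_t
    return spot * norm.cdf(d1) - strike * jnp.exp(-rate * expiry) * norm.cdf(d2)


# Differentiate with respect to spot (argument 0) and volatility (argument 3),
# then JIT-compile the resulting gradient functions.
delta_fn = jax.jit(jax.grad(bs_call_price, argnums=0))
vega_fn = jax.jit(jax.grad(bs_call_price, argnums=3))

# Arbitrary illustrative inputs: an at-the-money one-year call.
spot, strike, rate, vol, expiry = 100.0, 100.0, 0.05, 0.2, 1.0
price = bs_call_price(spot, strike, rate, vol, expiry)
delta = delta_fn(spot, strike, rate, vol, expiry)
vega = vega_fn(spot, strike, rate, vol, expiry)
print(price, delta, vega)
```

Because `jax.grad` differentiates the pricing function itself, no closed-form expression for each Greek has to be written by hand. The same pattern should apply to any pricing function that is written in JAX.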
Install
JAX installation
You must first follow JAX's installation guide to install JAX based on your device architecture (CPU/GPU/TPU).
Finax
pip install finax --upgrade
Usage
Getting started
Here is an example of option pricing with the Black-Scholes model.
import jax
import jax.numpy as jnp
from jax import jit

# Enable 64-bit precision before importing finax.
jax.config.update("jax_enable_x64", True)

from finax.black_sholes.vanilla_prices import option_price

# JIT-compile the pricing function with XLA.
option_price_fn = jit(option_price)

dtype = jnp.float64
forwards = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=dtype)
strikes = jnp.array([3.0, 3.0, 3.0, 3.0, 3.0], dtype=dtype)
volatilities = jnp.array([0.0001, 102.0, 2.0, 0.1, 0.4], dtype=dtype)
# A single scalar expiry, passed alongside the five-element arrays above.
expiries = jnp.array(1.0, dtype=dtype)

computed_prices = option_price_fn(
    volatilities=volatilities,
    strikes=strikes,
    expiries=expiries,
    forwards=forwards)
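For readers who want to sanity-check the result, here is a reference sketch of the undiscounted Black formula for call prices on forwards, written in plain JAX with the same inputs. It assumes that `option_price` prices calls by default, applies no discounting, and follows this standard formula; this README does not confirm any of that. `black_call_price` is a hypothetical helper, not part of Finax.

```python
import jax.numpy as jnp
from jax import jit
from jax.scipy.stats import norm


@jit
def black_call_price(forwards, strikes, volatilities, expiries):
    """Undiscounted Black call price on forwards (reference sketch, not Finax code)."""
    total_vol = volatilities * jnp.sqrt(expiries)
    d1 = jnp.log(forwards / strikes) / total_vol + 0.5 * total_vol
    d2 = d1 - total_vol
    return forwards * norm.cdf(d1) - strikes * norm.cdf(d2)


forwards = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
strikes = jnp.array([3.0, 3.0, 3.0, 3.0, 3.0])
volatilities = jnp.array([0.0001, 102.0, 2.0, 0.1, 0.4])
expiries = jnp.array(1.0)  # scalar expiry, broadcast across all five options

reference_prices = black_call_price(forwards, strikes, volatilities, expiries)
print(reference_prices)
```

The inputs cover two limiting cases. With near-zero volatility and a forward below the strike, the call is essentially worthless. With very large volatility, the price approaches the forward itself.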
64-bit precision
JAX uses 32-bit precision by default. To enable 64-bit precision, set the respective JAX flag before importing finax (see the JAX guide), for example:
import jax
jax.config.update("jax_enable_x64", True)
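As a quick check that the flag took effect, the following sketch inspects array dtypes after enabling it. It uses only plain JAX, and the variable names are illustrative.

```python
import jax

# Set the flag at startup, before any arrays are created.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

# With the flag enabled, an explicit float64 request is honored
# and Python floats default to float64.
x = jnp.array([0.1, 0.2], dtype=jnp.float64)
y = jnp.asarray(1.0)
print(x.dtype, y.dtype)

# float64 can resolve a 1e-12 increment that float32 would round away.
resolves_small_step = bool(jnp.float64(1.0) + 1e-12 > 1.0)
print(resolves_small_step)
```

If the dtypes still print as float32, the flag was most likely set after JAX had already created arrays, or it was not set at all.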
