Lox
Logging library for JAX that is compatible with transformations and primitives such as vmap and scan.
Logging in JAX
lox is a lightweight and flexible logging library for JAX.
All you need to do is add lox.log statements to your code, and lox does the rest.
Using JAX's intermediate function representation (the jaxpr), lox can either dynamically insert callbacks that log your data as it is produced, or collect the logs that would have been generated during execution and return them as part of your function's output.
While it's possible to implement this functionality yourself, lox provides a simple and efficient way to do so without carrying boilerplate code around in your functions.
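For comparison, here is a minimal sketch of the boilerplate the paragraph above refers to, written in plain JAX without lox. Every value you want to log has to be returned explicitly, in this case as the stacked per-iteration output of jax.lax.scan. The function name f_manual is hypothetical, and this is not how lox is implemented internally.

```python
import jax
import jax.numpy as jnp


def f_manual(xs):
    """Hypothetical hand-written logging: logs are threaded out as outputs."""

    def step(carry, x):
        carry = carry + x
        # The per-iteration "log" becomes scan's stacked output.
        return carry, {"carry": carry}

    init = jnp.zeros((), dtype=xs.dtype)
    y, logs = jax.lax.scan(step, init, xs)
    logs["xs"] = xs  # values from outside the loop must be added by hand too
    return y, logs


y, logs = f_manual(jnp.arange(3))
print(y, logs["carry"])
```

Every new log entry changes the function's return signature, which is exactly the bookkeeping that lox.log and lox.spool are meant to take off your hands.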
Features
- 🔌 Plug-and-Play: Simply add lox.log statements where you need them. lox handles all the complex boilerplate of plumbing data through JAX's transformations, keeping your function signatures clean and focused on the logic.
- 📦 Automatic Extraction: Instead of explicitly returning data from your functions, lox.spool automatically "spools up" all logs generated during a function's execution. It collects them and returns them as a single logdict alongside the function's original output.
- 📡 Dynamic Callbacks: Using lox.tap, you can "tap into" a JAX-transformed function using custom callbacks. This is ideal for live monitoring and debugging without halting execution.
- ✅ vmap over Seeds: Built on its own JAX primitive, lox works effortlessly with core transformations like jit, scan, and vmap.
- 📊 Experiment Loggers: Includes built-in loggers that seamlessly pipe your metrics to popular experiment tracking platforms including wandb and neptune, which are also fully compatible with vmap.
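The vmap-over-seeds idea can be illustrated in plain JAX: when per-step logs are returned as outputs, vmap batches them along with everything else, so each seed ends up with its own log trace. This is a hypothetical sketch that does not use lox; the function run and its noise model are made up for illustration.

```python
import jax
import jax.numpy as jnp


def run(key, xs):
    """Hypothetical noisy cumulative sum that records its carry each step."""
    noise = jax.random.normal(key, xs.shape)

    def step(carry, inputs):
        x, n = inputs
        carry = carry + x + n
        return carry, {"carry": carry}

    return jax.lax.scan(step, jnp.zeros(()), (xs, noise))


keys = jax.random.split(jax.random.key(0), 4)  # four seeds
# Map over the seed axis only; xs is shared by all runs.
ys, logs = jax.vmap(run, in_axes=(0, None))(keys, jnp.arange(3.0))
print(ys.shape, logs["carry"].shape)
```

The logs gain a leading seed axis of size 4, next to the 3 recorded steps per run.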
Quick Start
Basic API
At its core, lox is built around two central function transformations called tap and spool.
They work by traversing the function's jaxpr (JAX's internal intermediate representation of a function) and dynamically altering it to produce the desired behavior.
To use them with your function, all you need to do is specify what you want to log using lox.log.
>>> import jax
>>> import jax.numpy as jnp
>>> import lox
>>> def f(xs):
...     lox.log({"xs": xs})
...     def step(carry, x):
...         carry += x
...         lox.log({"carry": carry})
...         return carry, x
...     y, _ = jax.lax.scan(step, 0, xs)
...     return y
>>> xs = jnp.arange(3)
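To get a feel for what "traversing the jaxpr" means, jax.make_jaxpr prints the intermediate representation of a function and exposes its equations. The sketch below walks those equations recursively, descending into sub-jaxprs such as a scan body, and lists the primitives it finds. It only illustrates the kind of traversal described above; the helper primitive_names and the function g are hypothetical and not part of lox.

```python
import jax
import jax.numpy as jnp


def primitive_names(jaxpr):
    """Hypothetical helper: collect primitive names, recursing into sub-jaxprs."""
    names = []
    for eqn in jaxpr.eqns:
        names.append(eqn.primitive.name)
        for param in eqn.params.values():
            # Control-flow primitives such as scan store their body as a
            # ClosedJaxpr parameter, whose .jaxpr holds the inner equations.
            inner = getattr(param, "jaxpr", None)
            if inner is not None and hasattr(inner, "eqns"):
                names.extend(primitive_names(inner))
    return names


def g(xs):
    def step(carry, x):
        return carry + x, x * 2

    return jax.lax.scan(step, jnp.zeros((), xs.dtype), xs)


closed = jax.make_jaxpr(g)(jnp.arange(3))
print(closed)  # human-readable jaxpr
names = primitive_names(closed.jaxpr)
print(names)
```

A transformation like tap or spool would go one step further and rewrite the equations it finds, rather than just listing them.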
The first transformation, lox.tap, lets you "tap into" function execution by attaching a callback that receives logs as they're generated.
It streams logs in real time, making it great for debugging or live monitoring.
In the following example we use a simple callback that writes all logs to the console.
>>> def callback(logs):
...     print("Logging:", logs)
>>> y = lox.tap(f, callback=callback)(xs)
Logging: {'xs': [0, 1, 2]}
Logging: {'carry': 0}
Logging: {'carry': 1}
Logging: {'carry': 3}
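A callback-based logger can be approximated in plain JAX with jax.debug.callback, which calls a Python function on the host from inside traced code, including inside jit and scan. The sketch below inserts such a callback by hand. lox.tap is described as doing this for you at every lox.log call, but this is not its actual implementation. The names record and f_with_callback are hypothetical.

```python
import functools

import jax
import jax.numpy as jnp

received = []


def record(name, value):
    # Runs on the host each time execution reaches the callback.
    received.append((name, int(value)))


def f_with_callback(xs):
    def step(carry, x):
        carry = carry + x
        # ordered=True keeps the callbacks in program order.
        jax.debug.callback(functools.partial(record, "carry"), carry, ordered=True)
        return carry, x

    y, _ = jax.lax.scan(step, jnp.zeros((), xs.dtype), xs)
    return y


y = jax.jit(f_with_callback)(jnp.arange(3))
jax.effects_barrier()  # wait until all pending callbacks have run
print(received)
```

The label "carry" is bound with functools.partial because the positional arguments of jax.debug.callback must be arrays, not strings.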
The second transformation, lox.spool, "spools up" all logs during execution and returns them alongside the function's output.
This is especially useful when frequent callbacks would be too expensive.
For instance, instead of logging on every iteration, you can collect all logs for a training step and emit them in a single call.
spool also makes it easy to collect logs over multiple steps and then apply a reduction such as jnp.mean to them.
>>> y, logs = lox.spool(f)(xs)
>>> print("Collected Logs:", logs)
Collected Logs: {'xs': [0, 1, 2], 'carry': [0, 1, 3]}
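The batching idea behind spool can be sketched in plain JAX. Per-iteration values are collected as scan outputs and reduced, here with jnp.mean as suggested above. The summary is then emitted with a single host callback instead of one callback per iteration. This is a hypothetical illustration, not lox's implementation; emitted and f_summarized are made-up names.

```python
import jax
import jax.numpy as jnp

emitted = []


def f_summarized(xs):
    def step(carry, x):
        carry = carry + x
        return carry, {"carry": carry}  # collected, not emitted yet

    y, logs = jax.lax.scan(step, jnp.zeros((), xs.dtype), xs)
    # Reduce every collected series to its mean ...
    summary = jax.tree_util.tree_map(jnp.mean, logs)
    # ... and cross over to the host only once per call.
    jax.debug.callback(
        lambda s: emitted.append({k: float(v) for k, v in s.items()}), summary
    )
    return y


y = jax.jit(f_summarized)(jnp.arange(3))
jax.effects_barrier()  # make sure the callback has run
print(emitted)
```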
Logdicts
Lox provides its own internal data structure for logs called logdict, which is a subclass of Python's built-in dict.
To the naked eye, it behaves like a regular dictionary, but it comes with some additional features that make it easier to work with logs.
In addition to the raw data, a logdict also contains the steps at which the logs were recorded.
The following example demonstrates how to log data along with additional step information.
>>> def f(xs):
...     def body(i, carry):
...         carry += xs[i]
...         lox.log({"carry": carry}, step=i, episode=i//2)
...         return carry
...     y = jax.lax.fori_loop(0, len(xs), body, 0)
...     return y
>>> y, logs = lox.spool(f)(xs)
In the example above, we log the carry value at each iteration of the loop, along with the current step and episode.
This step information is stored as attributes of the logdict and can be accessed via logs.step and logs.episode.
lox.log accepts any number of additional keyword arguments, all of which are treated as step information.
>>> print("Collected Logs:", logs["carry"])
Collected Logs: [0, 1, 3]
>>> print("Corresponding Steps:", logs.step['carry'])
Corresponding Steps: [0, 1, 2]
>>> print("Corresponding Episodes:", logs.episode['carry'])
Corresponding Episodes: [0, 0, 1]
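To make the structure concrete, here is a hypothetical stand-in for logdict: a plain dict subclass that stores values under their keys and exposes the recorded step information as attributes. The values are gathered in plain JAX by returning them from a scan, mirroring the fori_loop example above. StepDict and f_with_steps are made-up names, and lox's actual logdict may be organized differently.

```python
import jax
import jax.numpy as jnp


class StepDict(dict):
    """Hypothetical logdict-like container; not lox's implementation."""

    def __init__(self, data, **steps):
        super().__init__(data)
        self._steps = steps  # e.g. {"step": {"carry": ...}, "episode": {...}}

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails.
        try:
            return self.__dict__["_steps"][name]
        except KeyError:
            raise AttributeError(name) from None


def f_with_steps(xs):
    def body(carry, i):
        carry = carry + xs[i]
        # Record the value together with its step information.
        return carry, {"carry": carry, "step": i, "episode": i // 2}

    steps = jnp.arange(xs.shape[0])
    y, rec = jax.lax.scan(body, jnp.zeros((), xs.dtype), steps)
    logs = StepDict(
        {"carry": rec["carry"]},
        step={"carry": rec["step"]},
        episode={"carry": rec["episode"]},
    )
    return y, logs


y, logs = f_with_steps(jnp.arange(3))
print(logs["carry"], logs.step["carry"], logs.episode["carry"])
```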
Loggers
Lox comes with built-in loggers for common use cases.
Loggers support both the tap and spool transformations and make it easy to log to different backends.
One example is lox.loggers.SaveLogger, which saves logs to a specified directory in a structured format for later use. Loggers are instantiated with any necessary configuration and then initialized with a random key via init, which produces a logger state. This state is then passed, together with the function to be logged, to the logger's tap or spool method.
>>> import lox.loggers
>>> key = jax.random.key(0)
>>> logger = lox.loggers.SaveLogger("./.lox/")
>>> logger_state = logger.init(key)
>>> y = logger.spool(f, logger_state)(xs)
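The construct-then-init pattern can be sketched with a small hypothetical logger class. Keeping the state a pytree of arrays is what lets init be mapped over several seeds with vmap; that this is how lox structures logger state is an assumption, not something stated above. CountingLogger and its update method are invented for illustration and are not part of lox.loggers.

```python
import jax
import jax.numpy as jnp


class CountingLogger:
    """Hypothetical logger illustrating construct -> init(key) -> state."""

    def __init__(self, name):
        self.name = name  # static configuration, fixed at construction

    def init(self, key):
        # Assumption: the state is a pytree of arrays, so it can be batched
        # with vmap. Here: a run id drawn from the key plus a counter.
        run_id = jax.random.randint(key, (), 0, 1_000_000)
        return {"run_id": run_id, "num_logs": jnp.zeros((), jnp.int32)}

    def update(self, state, logs):
        # Pure state transition: count the entries in one logged dict.
        return {**state, "num_logs": state["num_logs"] + len(logs)}


logger = CountingLogger("demo")
keys = jax.random.split(jax.random.key(0), 3)  # one key per seed
states = jax.vmap(logger.init)(keys)
# Log the same dict into every run; the dict itself is not batched.
states = jax.vmap(logger.update, in_axes=(0, None))(states, {"carry": 1.0})
print(states["run_id"].shape, states["num_logs"])
```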
Loggers can also be combined with lox.loggers.MultiLogger to log to multiple backends simultaneously. The distinction between tap and spool is preserved, so MultiLogger works with either transformation: with spool, logs are written once at the end of the function's execution, whereas with tap, they are written every time a log is encountered.
>>> from lox.loggers import ConsoleLogger, MultiLogger, SaveLogger
>>> console_logger = ConsoleLogger()
>>> save_logger = SaveLogger(".lox/")
>>> multi_logger = MultiLogger(console_logger, save_logger)
>>> multi_logger_state = multi_logger.init(key)
>>> y = multi_logger.tap(f, multi_logger_state)(xs)
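Combining loggers amounts to fanning the same logs out to several children, each with its own state. Below is a minimal hypothetical sketch of that composition. ListLogger, FanOutLogger and their write methods are invented names, and splitting the key once per child is an assumption rather than documented MultiLogger behavior.

```python
import jax


class ListLogger:
    """Hypothetical child logger that appends every log to a Python list."""

    def __init__(self):
        self.records = []

    def init(self, key):
        return {"key": key}  # this toy logger only keeps its key as state

    def write(self, state, logs):
        self.records.append(dict(logs))


class FanOutLogger:
    """Hypothetical combinator in the spirit of MultiLogger."""

    def __init__(self, *loggers):
        self.loggers = loggers

    def init(self, key):
        # Assumption: each child gets its own key, so states stay independent.
        keys = jax.random.split(key, len(self.loggers))
        return tuple(lg.init(keys[i]) for i, lg in enumerate(self.loggers))

    def write(self, state, logs):
        # Forward the same logs to every child together with its own state.
        for lg, child_state in zip(self.loggers, state):
            lg.write(child_state, logs)


a, b = ListLogger(), ListLogger()
multi = FanOutLogger(a, b)
state = multi.init(jax.random.key(0))
multi.write(state, {"carry": 3})
print(a.records, b.records)
```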
Installation
lox can be installed directly from this GitHub repository.
pip install git+https://github.com/huterguier/lox
By default, lox is installed without the external experiment loggers. To use them, include the corresponding optional dependencies as shown below, or install the packages manually.
pip install "lox[wandb,neptune] @ git+https://github.com/huterguier/lox"