Logix
AI Logging for Interpretability and Explainability🔬
<a href="https://github.com/psf/black"><img src="https://img.shields.io/badge/code%20style-black-000000.svg" alt="code style: black"></a>
<a href="https://arxiv.org/abs/2405.13954">arXiv</a>
[!WARNING] This repository is under active development. If you have suggestions or find bugs in LogIX, please open a GitHub issue or reach out.
Introduction
With a few additional lines of code, (traditional) logging tracks losses, hyperparameters, and similar quantities, providing basic insights into users' AI/ML experiments. But can we also enable an in-depth understanding of large-scale training data, the most important ingredient in AI/ML, through a similar logging interface? Try out LogIX, which is built on our cutting-edge data valuation/attribution research and supports HuggingFace Transformers and PyTorch Lightning integrations!
Installation
- PyPI
```bash
pip install logix-ai
```
- From source (latest, recommended)
```bash
git clone https://github.com/logix-project/logix.git
cd logix
pip install -e .
```
Easy to Integrate
Our software design allows seamless integration with popular high-level frameworks, including HuggingFace Transformers and PyTorch Lightning, which conveniently handle distributed training, data loading, etc. Advanced users who don't use high-level frameworks can still integrate LogIX into their existing training code, much like any traditional logging software (see our Tutorial).
🤗 HuggingFace Integration
A full example can be found here.
```python
from transformers import Trainer, Seq2SeqTrainer
from logix.huggingface import patch_trainer, LogIXArguments

# Define LogIX arguments
logix_args = LogIXArguments(project="myproject",
                            config="config.yaml",
                            lora=True,
                            hessian="raw",
                            save="grad")

# Patch HF Trainer
LogIXTrainer = patch_trainer(Trainer)

# Pass LogIXArguments to the patched Trainer alongside the usual Trainer arguments
trainer = LogIXTrainer(*args,
                       logix_args=logix_args,
                       model=model,
                       train_dataset=train_dataset,
                       **kwargs)

# Instead of trainer.train(), use
trainer.extract_log()
trainer.influence()
trainer.self_influence()
```
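The `patch_trainer` call above returns a Trainer class that accepts `logix_args` and exposes extra methods such as `extract_log()`. As a rough illustration of how this kind of class patching can work in plain Python, the sketch below builds a subclass at runtime. Everything in it (`BaseTrainer`, `patch_trainer_sketch`, the returned summary dict) is hypothetical and stands in for the real classes; it is not LogIX's implementation.

```python
from types import SimpleNamespace


class BaseTrainer:
    """Hypothetical stand-in for a framework Trainer (e.g., transformers.Trainer)."""

    def __init__(self, model=None, train_dataset=None):
        self.model = model
        self.train_dataset = train_dataset

    def train(self):
        return "trained"


def patch_trainer_sketch(trainer_cls):
    """Hypothetical: return a subclass of `trainer_cls` with a logging entry point."""

    class PatchedTrainer(trainer_cls):
        def __init__(self, *args, logix_args=None, **kwargs):
            # Keep the extra config; forward all other arguments unchanged.
            self.logix_args = logix_args
            super().__init__(*args, **kwargs)

        def extract_log(self):
            # A real implementation would run the data loop with hooks attached;
            # this sketch only reports what it would log.
            return {"project": getattr(self.logix_args, "project", None),
                    "num_examples": len(self.train_dataset or [])}

    PatchedTrainer.__name__ = f"Patched{trainer_cls.__name__}"
    return PatchedTrainer


# Usage: the patched class still behaves like the original Trainer.
LogIXStyleTrainer = patch_trainer_sketch(BaseTrainer)
trainer = LogIXStyleTrainer(logix_args=SimpleNamespace(project="myproject"),
                            model="my-model", train_dataset=[1, 2, 3])
summary = trainer.extract_log()
```

Subclassing keeps every existing Trainer method (`train()` still works) while adding new ones, which is why a patched Trainer can be used as a drop-in replacement.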
⚡ PyTorch Lightning Integration
A full example can be found here.
```python
from lightning import LightningModule, Trainer
from logix.lightning import patch, LogIXArguments

class MyLitModule(LightningModule):
    ...

# Map a batch to identifiers for its examples (here: the decoded input text)
def data_id_extractor(batch):
    return tokenizer.batch_decode(batch["input_ids"])

# Define LogIX arguments
logix_args = LogIXArguments(project="myproject",
                            config="config.yaml",
                            lora=True,
                            hessian="raw",
                            save="grad")

# Patch Lightning Module and Trainer
LogIXModule, LogIXTrainer = patch(MyLitModule,
                                  Trainer,
                                  logix_args=logix_args,
                                  data_id_extractor=data_id_extractor)

# Use patched Module and Trainer as before
module = LogIXModule(user_args)
trainer = LogIXTrainer(user_args)

# Instead of trainer.fit(module, train_loader), use
trainer.extract_log(module, train_loader)
trainer.influence(module, train_loader)
```
Getting Started
Logging
Training log extraction with LogIX is as simple as adding a single `with` statement to existing
training code. LogIX automatically extracts user-specified logs using PyTorch hooks and stores
them as a tuple of `([data_ids], log[module_name][log_type])`. If needed, LogIX efficiently writes
these logs to disk using memory-mapped files.
```python
import torch.nn as nn
import logix

# `model` and `data_loader` come from your existing training code
# Initialize LogIX
run = logix.init(project="my_project")

# Specify modules to be tracked for logging
run.watch(model, name_filter=["mlp"], type_filter=[nn.Linear])

# Specify plugins to be used in logging
run.setup({"grad": ["log", "covariance"]})
run.save(True)

for batch in data_loader:
    # Set `data_id` (and optionally `mask`) for the current batch
    with run(data_id=batch["input_ids"], mask=batch["attention_mask"]):
        model.zero_grad()
        loss = model(batch)
        loss.backward()

# Synchronize statistics (e.g. covariance) and write logs to disk
run.finalize()
```
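The logging loop above stores each batch's logs as a `([data_ids], log[module_name][log_type])` tuple, accumulates statistics such as the gradient covariance, and can write logs to memory-mapped files. The NumPy sketch below mimics that flow on random data so the data layout is concrete. The module name `mlp`, the file name, the flat gradient size, and the use of `np.memmap` are illustrative assumptions, not LogIX's actual on-disk format.

```python
import os
import tempfile

import numpy as np

# Illustrative only: mimic LogIX-style logs with NumPy. Module name, file name,
# and sizes are assumptions, not LogIX's actual storage format.
rng = np.random.default_rng(0)
dim = 8  # size of one (flattened, possibly compressed) per-example gradient

# Two batches of (data_ids, per-example gradients for a module named "mlp")
batches = [(["a", "b"], rng.normal(size=(2, dim))),
           (["c", "d", "e"], rng.normal(size=(3, dim)))]
num_examples = sum(len(ids) for ids, _ in batches)

# Pre-allocate an on-disk array; writes go through the OS page cache.
path = os.path.join(tempfile.mkdtemp(), "mlp_grad.mmap")
store = np.memmap(path, dtype=np.float32, mode="w+", shape=(num_examples, dim))

entries = []                        # one ([data_ids], log) tuple per batch
covariance = np.zeros((dim, dim))   # running sum of g g^T over examples
offset = 0
for data_ids, grads in batches:
    log = {"mlp": {"grad": grads}}  # log[module_name][log_type]
    entries.append((data_ids, log))
    g = log["mlp"]["grad"]
    store[offset:offset + len(data_ids)] = g   # persist per-example gradients
    covariance += g.T @ g                      # accumulate second-moment statistic
    offset += len(data_ids)

store.flush()                 # push pending writes to disk
covariance /= num_examples    # uncentered covariance estimate

# Reopen read-only: rows are paged in on access instead of loaded up front.
reloaded = np.memmap(path, dtype=np.float32, mode="r", shape=(num_examples, dim))
```

Keeping the per-example logs on disk while holding only the small `dim x dim` statistic in memory is what lets this pattern scale past RAM.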
Training Data Attribution
As part of our initial research, we implemented influence functions using LogIX. We plan to provide more pre-implemented interpretability algorithms if there is demand.
```python
# Build PyTorch DataLoader from saved log data
log_loader = run.build_log_dataloader()

with run(data_id=test_batch["input_ids"]):
    test_loss = model(test_batch)
    test_loss.backward()

test_log = run.get_log()
run.influence.compute_influence_all(test_log, log_loader)  # Data attribution
run.influence.compute_self_influence(test_log)  # Uncertainty estimation
```
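Influence functions score how much a training example affects a test prediction, roughly as `g_test^T H^{-1} g_train`, with self-influence `g^T H^{-1} g` for a single example. The sketch below computes both from gradient matrices with NumPy, approximating `H` by the damped uncentered gradient covariance (empirical Fisher). It follows the textbook formulation up to sign convention; the damping value and the helper names are assumptions, and it is not how `run.influence` is implemented internally.

```python
import numpy as np

# Textbook influence-function sketch (up to sign convention), NOT LogIX's code.
# H is approximated by the damped uncentered gradient covariance (empirical Fisher).

def damped_fisher(train_grads, damping=1e-3):
    """Return (G^T G / n + damping * I) for an (n, d) gradient matrix G."""
    n, d = train_grads.shape
    return train_grads.T @ train_grads / n + damping * np.eye(d)


def influence_scores(test_grads, train_grads, damping=1e-3):
    """(num_test, num_train) matrix of g_test^T H^{-1} g_train."""
    h = damped_fisher(train_grads, damping)
    # Solve H X = G_train^T rather than forming H^{-1} explicitly.
    preconditioned = np.linalg.solve(h, train_grads.T)  # (d, num_train)
    return test_grads @ preconditioned


def self_influence(grads, h):
    """g^T H^{-1} g for each row g of `grads`."""
    preconditioned = np.linalg.solve(h, grads.T)        # (d, num)
    return np.einsum("nd,dn->n", grads, preconditioned)


# Usage on random stand-in gradients
rng = np.random.default_rng(0)
train_grads = rng.normal(size=(100, 16))
test_grads = rng.normal(size=(4, 16))

scores = influence_scores(test_grads, train_grads)
top5 = np.argsort(-scores, axis=1)[:, :5]  # highest-scoring training examples per test example
uncertainty = self_influence(test_grads, damped_fisher(train_grads))
```

The damping term keeps `H` positive definite, so the linear solve is well-posed and every self-influence value is positive.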
Please check out our Examples for more detailed usage!
Features
Logs from neural networks are difficult to handle because of their size. For example, the gradient of each training data point has as many entries as the model has parameters. We therefore provide several kinds of systems support to scale neural network analysis efficiently to billion-scale models. LogIX currently supports the following features:
- Gradient compression (compression ratio: 1,000-100,000x)
- Memory-map-based data IO
- CPU offloading of statistics
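To make the scale concrete: a per-example gradient has one value per parameter, so for an assumed 7B-parameter model stored in fp32 it takes 28 GB per training example. One way to compress a linear layer's gradient without ever materializing it is a Kronecker-structured random projection: since the weight gradient is a sum of outer products `sum_t delta_t x_t^T`, projecting it as `P_out G P_in^T` equals projecting the layer inputs `x_t` and output gradients `delta_t` separately and then summing their outer products. The sketch below checks this identity on random data; the layer sizes, projection dimensions, and Gaussian projection matrices are illustrative assumptions, not LogIX's settings.

```python
import numpy as np

# Back-of-envelope: raw per-example gradient size for an assumed 7B-parameter fp32 model.
num_params = 7_000_000_000
bytes_per_value = 4
per_example_gb = num_params * bytes_per_value / 1e9   # 28.0 GB per training example

# Kronecker-structured random projection of one linear layer's gradient.
# Sizes and Gaussian projections are illustrative assumptions.
rng = np.random.default_rng(0)
d_in, d_out, seq_len = 256, 128, 10
k_in, k_out = 16, 8

x = rng.normal(size=(seq_len, d_in))        # layer inputs, one row per token
delta = rng.normal(size=(seq_len, d_out))   # loss gradients w.r.t. layer outputs

# Full weight gradient sum_t delta_t x_t^T, shape (d_out, d_in)
full_grad = delta.T @ x

p_in = rng.normal(size=(k_in, d_in)) / np.sqrt(k_in)
p_out = rng.normal(size=(k_out, d_out)) / np.sqrt(k_out)

# Route 1: materialize the full gradient, then project it
proj_from_full = p_out @ full_grad @ p_in.T          # (k_out, k_in)

# Route 2: project inputs and output gradients first; full_grad is never needed
proj_direct = (delta @ p_out.T).T @ (x @ p_in.T)     # (k_out, k_in)

compression = (d_out * d_in) / (k_out * k_in)        # 256.0x for these sizes
```

The ratio here is only 256x because the toy layer is small; for a fixed projection size, it grows with the layer's input and output widths.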
Compatibility
| DistributedDataParallel | Mixed Precision | Gradient Checkpointing | torch.compile | FSDP |
|:-----------------------:|:---------------:|:----------------------:|:-------------:|:----:|
| ✅ | ✅ | ✅ | ✅ | ✅ |
Contributing
We welcome contributions from the community. Please see our contributing guidelines for details on how to contribute to LogIX.
Citation
To cite this repository:
@article{choe2024your,
title={What is Your Data Worth to GPT? LLM-Scale Data Valuation with Influence Functions},
author={Choe, Sang Keun and Ahn, Hwijeen and Bae, Juhan and Zhao, Kewen and Kang, Minsoo and Chung, Youngseog and Pratapa, Adithya and Neiswanger, Willie and Strubell, Emma and Mitamura, Teruko and others},
journal={arXiv preprint arXiv:2405.13954},
year={2024}
}
License
LogIX is licensed under the Apache 2.0 License.