Simpletransformers
Transformers for Information Retrieval, Text Classification, NER, QA, Language Modelling, Language Generation, T5, Multi-Modal, and Conversational AI
Simple Transformers
This library is based on the Transformers library by HuggingFace. Simple Transformers lets you quickly train and evaluate Transformer models. Only 3 lines of code are needed to initialize, train, and evaluate a model.
Supported Tasks:
- Information Retrieval (Dense Retrieval)
- (Large) Language Models (Training, Fine-tuning, and Generation)
- Encoder Model Training and Fine-tuning
- Sequence Classification
- Token Classification (NER)
- Question Answering
- Language Generation
- T5 Model
- Seq2Seq Tasks
- Multi-Modal Classification
- Conversational AI
Citation
If you use Simple Transformers in your work, please cite:
@inproceedings{Rajapakse2024SimpleTransformers,
author = {Rajapakse, Thilina C. and Yates, Andrew and de Rijke, Maarten},
title = {Simple Transformers: Open-source for All},
booktitle = {Proceedings of the 2024 Annual International ACM SIGIR
Conference on Research and Development in Information
Retrieval in the Asia Pacific Region},
series = {SIGIR-AP 2024},
pages = {209--215},
year = {2024},
doi = {10.1145/3673791.3698412},
url = {https://doi.org/10.1145/3673791.3698412},
location = {Tokyo, Japan}
}
Setup
With Conda
- Install the Anaconda or Miniconda package manager.
- Create a new virtual environment and install packages.
$ conda create -n st python pandas tqdm
$ conda activate st
Using CUDA:
$ conda install "pytorch>=1.6" cudatoolkit=11.0 -c pytorch
Without CUDA:
$ conda install pytorch cpuonly -c pytorch
- Install simpletransformers.
$ pip install simpletransformers
Optional
- Install Weights and Biases (wandb) for tracking and visualizing training in a web browser.
$ pip install wandb
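After the steps above, a quick way to confirm the environment is to check that the interpreter can find each package. The sketch below is not part of Simple Transformers: `check_packages` is a hypothetical helper built on the standard library's `importlib.util.find_spec`. Note that the conda package `pytorch` is imported as `torch`, and the CUDA check only runs when PyTorch is present.

```python
# Post-install sanity check: a sketch, not part of Simple Transformers.
# check_packages is a hypothetical helper; it looks top-level packages up without importing them.
import importlib.util


def check_packages(names):
    """Map each top-level package name to True if the interpreter can find it."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


if __name__ == "__main__":
    required = ["torch", "pandas", "tqdm", "simpletransformers"]  # conda's "pytorch" imports as "torch"
    optional = ["wandb"]
    status = check_packages(required + optional)
    for name, found in status.items():
        kind = "optional" if name in optional else "required"
        print(f"{name:<20}{kind:<10}{'ok' if found else 'MISSING'}")
    if status["torch"]:
        import torch  # imported only when installed

        # False on the cpuonly install or when no CUDA-capable GPU is visible
        print("CUDA available:", torch.cuda.is_available())
```

Run it inside the activated `st` environment; a `MISSING` entry for a required package points back to the corresponding install step.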
Usage
All documentation is now live at simpletransformers.ai
Each Simple Transformers model is built with a particular Natural Language Processing (NLP) task in mind, and comes with features and functionality designed to best fit the task it is intended to perform. The high-level process of using a Simple Transformers model follows the same pattern for every task:
- Initialize a task-specific model
- Train the model with train_model()
- Evaluate the model with eval_model()
- Make predictions on (unlabelled) data with predict()
However, the models do differ where needed so that each is well suited to its intended task. The key differences are typically the input/output data formats and any task-specific features or configuration options. These are all covered in the documentation section for each task.
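The shared pattern described above (initialize, then `train_model()`, `eval_model()`, and `predict()`) is what lets code written for one task carry over to another. The sketch below models that pattern as a `typing.Protocol`. `TaskModel`, `MajorityClassModel`, and `run` are hypothetical names used only for illustration, and the toy model works on plain `(text, label)` tuples rather than the DataFrames and transformer outputs the real models use. The three-part return of `eval_model()` mirrors the `result, model_outputs, wrong_predictions` unpacking in the quick example.

```python
# Hypothetical sketch of the shared train/eval/predict workflow; not Simple Transformers code.
from collections import Counter
from typing import Any, Dict, List, Protocol, Sequence, Tuple

Example = Tuple[str, int]  # (text, label), mirroring the "text"/"labels" columns


class TaskModel(Protocol):
    def train_model(self, train_data: Sequence[Example]) -> None: ...

    def eval_model(
        self, eval_data: Sequence[Example]
    ) -> Tuple[Dict[str, float], List[Any], List[Example]]: ...

    def predict(self, to_predict: Sequence[str]) -> Tuple[List[int], List[Any]]: ...


class MajorityClassModel:
    """Toy stand-in: always predicts the most frequent training label."""

    def __init__(self) -> None:
        self.label = 0

    def train_model(self, train_data: Sequence[Example]) -> None:
        self.label = Counter(label for _, label in train_data).most_common(1)[0][0]

    def predict(self, to_predict: Sequence[str]) -> Tuple[List[int], List[Any]]:
        predictions = [self.label for _ in to_predict]
        return predictions, list(predictions)  # (predictions, raw_outputs); identical for this toy

    def eval_model(self, eval_data: Sequence[Example]):
        predictions, raw_outputs = self.predict([text for text, _ in eval_data])
        wrong = [ex for ex, p in zip(eval_data, predictions) if ex[1] != p]
        result = {"accuracy": 1 - len(wrong) / len(eval_data)}
        return result, raw_outputs, wrong


def run(model: TaskModel, train, evaluate, unlabelled):
    """Initialize -> train_model() -> eval_model() -> predict(), for any TaskModel."""
    model.train_model(train)
    result, _, wrong = model.eval_model(evaluate)
    predictions, _ = model.predict(unlabelled)
    return result, wrong, predictions
```

Because `run` only depends on the three method names, swapping in a different task model leaves the driver code unchanged; only the data passed to each step has to change.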
The currently implemented task-specific Simple Transformers models and their tasks are listed below.
| Task | Model |
| --------------------------------------------------------- | ------------------------------- |
| Binary and multi-class text classification | ClassificationModel |
| Conversational AI (chatbot training) | ConvAIModel |
| Language generation | LanguageGenerationModel |
| Language model training/fine-tuning | LanguageModelingModel |
| Multi-label text classification | MultiLabelClassificationModel |
| Multi-modal classification (text and image data combined) | MultiModalClassificationModel |
| Named entity recognition | NERModel |
| Question answering | QuestionAnsweringModel |
| Regression | ClassificationModel |
| Sentence-pair classification | ClassificationModel |
| Text representation generation | RepresentationModel |
| Document retrieval | RetrievalModel |
- Please refer to the relevant section in the docs for more information on how to use these models.
- Example scripts can be found in the examples directory.
- See the Changelog for up-to-date changes to the project.
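Several rows of the table above share `ClassificationModel`, and what distinguishes them is mostly the shape of the training data. Following the Simple Transformers documentation conventions (worth double-checking per task), single-sentence tasks use `text` and `labels` columns, sentence-pair tasks use `text_a`, `text_b`, and `labels`, regression uses float labels together with `regression=True` and `num_labels=1`, and `MultiLabelClassificationModel` expects one 0/1 flag per label in each row. The sketch checks plain Python rows against those layouts before they become a DataFrame; `EXPECTED_COLUMNS` and `check_rows` are hypothetical helpers, and the sentence pairs are made-up examples.

```python
# Hypothetical pre-flight check for the data layouts used by ClassificationModel-style tasks.
# Rows are plain dicts; pandas.DataFrame(rows) would turn them into the expected columns.
EXPECTED_COLUMNS = {
    "classification": {"text", "labels"},  # ClassificationModel, integer class ids
    "sentence_pair": {"text_a", "text_b", "labels"},  # ClassificationModel, integer class ids
    "regression": {"text", "labels"},  # ClassificationModel with regression=True, num_labels=1
    "multilabel": {"text", "labels"},  # MultiLabelClassificationModel, one 0/1 flag per label
}


def check_rows(task, rows, num_labels=2):
    """Raise ValueError on the first row that does not match the layout for `task`."""
    expected = EXPECTED_COLUMNS[task]
    for i, row in enumerate(rows):
        if set(row) != expected:
            raise ValueError(f"row {i}: expected columns {sorted(expected)}, got {sorted(row)}")
        label = row["labels"]
        if task in ("classification", "sentence_pair"):
            ok = isinstance(label, int) and 0 <= label < num_labels
        elif task == "regression":
            ok = isinstance(label, (int, float))
        else:  # multilabel
            ok = isinstance(label, list) and len(label) == num_labels and all(v in (0, 1) for v in label)
        if not ok:
            raise ValueError(f"row {i}: invalid label {label!r} for task {task!r}")
    return len(rows)


sentence_pairs = [
    {"text_a": "Aragorn was the heir of Isildur", "text_b": "Isildur's heir was Aragorn", "labels": 1},
    {"text_a": "Frodo was the heir of Isildur", "text_b": "Isildur's heir was Aragorn", "labels": 0},
]
regression_rows = [{"text": "Theoden was the king of Rohan", "labels": 0.9}]

print(check_rows("sentence_pair", sentence_pairs))  # 2
print(check_rows("regression", regression_rows))  # 1
```

Once the rows pass, `pd.DataFrame(rows)` gives the column layout used in the quick example. For regression, the model would then be created with something like `ClassificationModel("roberta", "roberta-base", num_labels=1, args={"regression": True})`.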
A quick example
```python
from simpletransformers.classification import ClassificationModel, ClassificationArgs
import pandas as pd
import logging


logging.basicConfig(level=logging.INFO)
transformers_logger = logging.getLogger("transformers")
transformers_logger.setLevel(logging.WARNING)

# Preparing train data
train_data = [
    ["Aragorn was the heir of Isildur", 1],
    ["Frodo was the heir of Isildur", 0],
]
train_df = pd.DataFrame(train_data)
train_df.columns = ["text", "labels"]

# Preparing eval data
eval_data = [
    ["Theoden was the king of Rohan", 1],
    ["Merry was the king of Rohan", 0],
]
eval_df = pd.DataFrame(eval_data)
eval_df.columns = ["text", "labels"]

# Optional model configuration
model_args = ClassificationArgs(num_train_epochs=1)

# Create a ClassificationModel
# (pass use_cuda=False on a machine without a CUDA-capable GPU)
model = ClassificationModel(
    "roberta", "roberta-base", args=model_args
)

# Train the model
model.train_model(train_df)

# Evaluate the model
result, model_outputs, wrong_predictions = model.eval_model(eval_df)

# Make predictions with the model
predictions, raw_outputs = model.predict(["Sam was a Wizard"])
```
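In the quick example, `predict()` returns both `predictions` and `raw_outputs`. For single-label classification, the raw outputs should be per-class scores (logits), with each prediction being the index of the highest score; a softmax turns those scores into probabilities if you need them. For binary tasks, the `result` dictionary from `eval_model()` should also include the Matthews correlation coefficient (`mcc`) alongside the confusion counts. The standard-library sketch below reproduces these calculations for illustration only: `softmax`, `argmax`, and `binary_mcc` are hypothetical helpers rather than library functions, and the example scores are made up.

```python
import math
from typing import List, Sequence


def softmax(scores: Sequence[float]) -> List[float]:
    """Turn one row of raw scores (logits) into probabilities."""
    m = max(scores)  # shift by the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def argmax(scores: Sequence[float]) -> int:
    """Index of the highest score, i.e. the predicted class id."""
    return max(range(len(scores)), key=lambda i: scores[i])


def binary_mcc(y_true: Sequence[int], y_pred: Sequence[int]) -> float:
    """Matthews correlation coefficient from binary labels (0.0 when undefined)."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)
    tn = sum(1 for t, p in pairs if t == 0 and p == 0)
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)
    denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    return 0.0 if denom == 0 else (tp * tn - fp * fn) / denom


raw_outputs = [[-1.2, 2.3], [0.8, -0.4]]  # made-up scores for two inputs
predictions = [argmax(row) for row in raw_outputs]
print(predictions)  # [1, 0]
print(binary_mcc([1, 0], predictions))  # 1.0
```

MCC ranges from -1 (every prediction wrong) through 0 (no better than chance) to 1 (perfect), which makes it a more informative single number than accuracy on imbalanced binary data.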
Experiment Tracking with Weights and Biases
- Weights and Biases makes it easy to keep track of all your experiments.
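Simple Transformers reads its W&B settings from the model args: `wandb_project` names the project that receives the run, and `wandb_kwargs` holds extra keyword arguments for `wandb.init()`. The snippet sketches this with a plain args dictionary (model args can be given as a dict or as a `ClassificationArgs` object). The project and run names are hypothetical placeholders, and `wandb login` is assumed to have been run already.

```python
# Sketch: W&B logging configured through Simple Transformers model args (passed as a dict).
# "simpletransformers-quick-example" and "roberta-base-run-1" are hypothetical names.
model_args = {
    "num_train_epochs": 1,
    "wandb_project": "simpletransformers-quick-example",  # W&B project that receives the run
    "wandb_kwargs": {"name": "roberta-base-run-1"},  # forwarded to wandb.init()
}

# With simpletransformers and wandb installed, the quick example would then use:
# model = ClassificationModel("roberta", "roberta-base", args=model_args)
# model.train_model(train_df)
```

Leaving `wandb_project` unset keeps training local, so the same script can be run with or without tracking by changing a single key.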
Current Pretrained Models
For a list of pretrained models, see the Hugging Face docs.
The model_types available for each task can be found under their respective sections. Any pretrained model of that type found in the Hugging Face docs should work. To use one, pass the matching model_type and model_name as the first two arguments when creating the model, as in ClassificationModel("roberta", "roberta-base", args=model_args) in the quick example.
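When choosing a model_type for a checkpoint, one place to look is the checkpoint's `config.json`, where Hugging Face checkpoints record a `model_type` field. That value usually matches the Simple Transformers model_type, but not always (for example, Simple Transformers spells XLM-RoBERTa as `xlmroberta`). The sketch below reads a local checkpoint directory; `guess_model_type` and the `RENAMES` table are hypothetical and deliberately incomplete, so treat the task documentation as authoritative. For a Hub model name that is not on disk, `transformers.AutoConfig.from_pretrained(name).model_type` exposes the same field.

```python
import json
from pathlib import Path

# Known spellings where Simple Transformers differs from the Hugging Face "model_type" field.
# Deliberately incomplete; extend after checking the docs for your task.
RENAMES = {"xlm-roberta": "xlmroberta"}


def guess_model_type(checkpoint_dir: str) -> str:
    """Read config.json from a local checkpoint and map its model_type to a Simple Transformers name."""
    config_path = Path(checkpoint_dir) / "config.json"
    config = json.loads(config_path.read_text(encoding="utf-8"))
    hf_type = config["model_type"]
    return RENAMES.get(hf_type, hf_type)


if __name__ == "__main__":
    import tempfile

    # Demonstrate with a throwaway directory holding a minimal config.json.
    with tempfile.TemporaryDirectory() as tmp:
        Path(tmp, "config.json").write_text(json.dumps({"model_type": "roberta"}), encoding="utf-8")
        print(guess_model_type(tmp))  # roberta
```

The returned string and the checkpoint path then go into the model constructor as model_type and model_name respectively.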
Contributors ✨
Thanks go to these wonderful people (emoji key):
<!-- ALL-CONTRIBUTORS-LIST:START - Do not remove or modify this section --> <!-- prettier-ignore-start --> <!-- markdownlint-disable --> <table> <tbody> <tr> <td align="center"><a href="https://github.com/hawktang"><img src="https://avatars0.githubusercontent.com/u/2004071?v=4?s=100" width="100px;" alt=""/><br /><sub><b>hawktang</b></sub></a><br /><a href="https://github.com/ThilinaRajapakse/simpletransformers/commits?author=hawktang" title="Code">💻</a></td> <td align="center"><a href="http://datawizzards.io"><img src="https://avatars0.githubusercontent.com/u/22409996?v=4?s=100" width="100px;" alt=""/><br /><sub><b>Mabu Manaileng</b></sub></a><br /><a href="https://github.com/ThilinaRajapakse/simpletransformers/commits?author=mabu-dev" title="Code">💻</a></td> <td align="center"><a href="https://www.facebook.com/aliosm97"><img src="https://avatars3.githubusercontent.com/u/7662492?v=4?s=100" width="100px;" alt=""/><br /><sub><b>Ali Hamdi Ali Fadel</b></sub></a><br /><a href="https://github.com/ThilinaRajapakse/simpletransformers/commits?author=AliOsm" title="Code">💻</a></td> <td align="center"><a href="http://tovly.co"><img src="https://avatars0.githubusercontent.com/u/12242351?v=4?s=100" width="100px;" alt=""/><br /><sub><b>Tovly Deutsch</b></sub></a><br /><a href="https://github.com/ThilinaRajapakse/simpletransformers/commits?author=TovlyDeutsch" title="Code">💻</a></td> <td align="center"><a href="https://github.com/hlo-world"><img src="https://avatars0.githubusercontent.com/u/9633055?v=4?s=100" width="100px;" alt=""/><br /><sub><b>hlo-world</b></sub></a><br /><a href="https://github.com/ThilinaRajapakse/simpletransformers/commits?author=hlo-world" title="Code">💻</a></td> <td align="center"><a href="https://github.com/huntertl"><img src="https://avatars1.githubusercontent.com/u/15113885?v=4?s=100" width="100px;" alt=""/><br /><sub><b>huntertl</b></sub></a><br /><a href="https://github.com/ThilinaRajapakse/simpletransformers/commits?author=huntertl" 
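title="Code">💻</a></td> <td align="center"><a href="https://whattheshot.com"><img src="https://avatars2.githubusercontent.com/u/623763?v=4?s=100" width="100px;" alt=""/><br /></a></td> </tr> </tbody> </table> <!-- markdownlint-restore --> <!-- prettier-ignore-end --> <!-- ALL-CONTRIBUTORS-LIST:END -->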
