imodelsX
Interpret text data with LLMs (sklearn compatible).
Explainable modeling/steering
| Model | Reference | Output | Description |
| :-- | :-- | :-- | :-- |
| Tree-Prompt | 🗂️, 🔗, 📄, 📖 | Explanation + Steering | Generates a tree of prompts to steer an LLM (Official) |
| iPrompt | 🗂️, 🔗, 📄, 📖 | Explanation + Steering | Generates a prompt that explains patterns in data (Official) |
| AutoPrompt | 🗂️, 🔗, 📄 | Explanation + Steering | Finds a natural-language prompt using input gradients |
| D3 | 🗂️, 🔗, 📄, 📖 | Explanation | Explains the difference between two distributions |
| SASC | 🗂️, 🔗, 📄 | Explanation | Explains a black-box text module using an LLM (Official) |
| Aug-Linear | 🗂️, 🔗, 📄, 📖 | Linear model | Fits a better linear model by using an LLM to extract embeddings (Official) |
| Aug-Tree | 🗂️, 🔗, 📄, 📖 | Decision tree | Fits a better decision tree by using an LLM to expand features (Official) |
| QAEmb | 🗂️, 🔗, 📄, 📖 | Explainable embedding | Generates interpretable embeddings by asking LLMs questions (Official) |
| KAN | 🗂️, 🔗, 📄, 📖 | Small network | Fits a 2-layer Kolmogorov-Arnold network |
<p align="center">
<a href="https://github.com/csinva/imodelsX/tree/master/demo_notebooks">📖</a> Demo notebooks &emsp; <a href="https://csinva.io/imodelsX/">🗂️</a> Docs &emsp; 🔗 Reference code &emsp; 📄 Research paper
<br/>
⌛ We plan to support other interpretable algorithms like <a href="https://arxiv.org/abs/2205.12548">RLPrompt</a>, <a href="https://arxiv.org/abs/2007.04612">CBMs</a>, and <a href="https://arxiv.org/abs/2004.00221">NBDT</a>. If you want to contribute an algorithm, feel free to open a PR 😄
</p>

General utilities
| Utility | Description |
| :-- | :-- |
| 🗂️ LLM wrapper | Easily call different LLMs |
| 🗂️ Dataset wrapper | Download minimally processed huggingface datasets |
| 🗂️ Bag of Ngrams | Learn a linear model of ngrams |
| 🗂️ Linear Finetune | Finetune a single linear layer on top of LLM embeddings |
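The Bag of Ngrams utility fits a linear model over ngram counts. The same idea can be sketched with plain scikit-learn (a minimal illustration with hypothetical toy data, not the imodelsx implementation itself):

```python
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline

# hypothetical toy sentiment data (imodelsx would load a real dataset)
texts = ["a great movie", "a terrible movie", "great acting", "terrible plot"]
labels = [1, 0, 1, 0]

# bag-of-ngrams linear model: count unigrams + bigrams, fit a linear classifier
clf = make_pipeline(
    CountVectorizer(ngram_range=(1, 2)),
    LogisticRegression(),
)
clf.fit(texts, labels)
print(clf.predict(["great plot", "terrible acting"]))
```

Because the model is linear in interpretable ngram features, its coefficients can be read off directly, which is the same property the Aug-Linear models below preserve.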
Quickstart
Installation: `pip install imodelsx` (or, for more control, clone and install from source)
Demos: see the demo notebooks
Natural-language explanations
Tree-prompt
```python
from imodelsx import TreePromptClassifier
import datasets
import numpy as np
from sklearn.tree import plot_tree
import matplotlib.pyplot as plt

# set up data
rng = np.random.default_rng(seed=42)
dset_train = datasets.load_dataset('rotten_tomatoes')['train']
dset_train = dset_train.select(rng.choice(
    len(dset_train), size=100, replace=False))
dset_val = datasets.load_dataset('rotten_tomatoes')['validation']
dset_val = dset_val.select(rng.choice(
    len(dset_val), size=100, replace=False))

# set up arguments
prompts = [
    "This movie is",
    " Positive or Negative? The movie was",
    " The sentiment of the movie was",
    " The plot of the movie was really",
    " The acting in the movie was",
]
verbalizer = {0: " Negative.", 1: " Positive."}
checkpoint = "gpt2"

# fit model
m = TreePromptClassifier(
    checkpoint=checkpoint,
    prompts=prompts,
    verbalizer=verbalizer,
    cache_prompt_features_dir=None,  # 'cache_prompt_features_dir/gpt2',
)
m.fit(dset_train["text"], dset_train["label"])

# compute accuracy
preds = m.predict(dset_val['text'])
print('\nTree-Prompt acc (val) ->',
      np.mean(preds == dset_val['label']))  # -> 0.7

# compare to accuracy for individual prompts
for i, prompt in enumerate(prompts):
    print(i, prompt, '->', m.prompt_accs_[i])  # -> 0.65, 0.5, 0.5, 0.56, 0.51

# visualize decision tree
plot_tree(
    m.clf_,
    fontsize=10,
    feature_names=m.feature_names_,
    class_names=list(verbalizer.values()),
    filled=True,
)
plt.show()
```
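Conceptually, Tree-Prompt treats each prompt's verbalized output as a binary feature and fits a decision tree that routes examples through prompts. A minimal sketch of that idea with hypothetical stand-in prompt outputs (in the real model these come from querying the LLM, not hand-written arrays):

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

# hypothetical: each column is one prompt's binarized output (1 = " Positive.")
# for 6 examples; TreePromptClassifier computes these by querying the LLM
prompt_outputs = np.array([
    [1, 1, 0],
    [1, 0, 1],
    [0, 0, 0],
    [0, 1, 0],
    [1, 1, 1],
    [0, 0, 1],
])
labels = np.array([1, 1, 0, 0, 1, 0])

# fit a shallow tree over the prompt features
tree = DecisionTreeClassifier(max_depth=2).fit(prompt_outputs, labels)
print(tree.predict(prompt_outputs))
```

The fitted tree's splits tell you which prompts carry signal and in what order they are consulted, which is what `plot_tree(m.clf_, ...)` visualizes above.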
iPrompt
```python
from imodelsx import explain_dataset_iprompt, get_add_two_numbers_dataset

# get a simple dataset of adding two numbers
input_strings, output_strings = get_add_two_numbers_dataset(num_examples=100)
for i in range(5):
    print(repr(input_strings[i]), repr(output_strings[i]))

# explain the relationship between the inputs and outputs
# with a natural-language prompt string
prompts, metadata = explain_dataset_iprompt(
    input_strings=input_strings,
    output_strings=output_strings,
    checkpoint='EleutherAI/gpt-j-6B',  # which language model to use
    num_learned_tokens=3,  # how long of a prompt to learn
    n_shots=3,  # shots per example
    n_epochs=15,  # how many epochs to search
    verbose=0,  # how much to print
    llm_float16=True,  # whether to load the model in float16
)
```
`prompts` is a list of the found natural-language prompt strings.
D3 (DescribeDistributionalDifferences)
```python
from imodelsx import explain_dataset_d3

hypotheses, hypothesis_scores = explain_dataset_d3(
    pos=positive_samples,  # List[str] of positive examples
    neg=negative_samples,  # another List[str]
    num_steps=100,
    num_folds=2,
    batch_size=64,
)
```
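The core idea can be sketched simply: a candidate hypothesis (here, just "contains a given word") is scored by how much more often it holds on the positive samples than on the negative ones. This is a toy illustration of the scoring step only, not D3's LLM-based hypothesis proposal:

```python
# hypothetical toy distributions (restaurant reviews)
positive_samples = ["the food was delicious", "delicious pasta", "a delicious meal"]
negative_samples = ["the service was slow", "slow waiter", "food arrived slow"]

def score_hypothesis(word, pos, neg):
    """Fraction of pos samples containing `word` minus fraction of neg samples."""
    frac = lambda samples: sum(word in s for s in samples) / len(samples)
    return frac(pos) - frac(neg)

candidates = ["delicious", "slow", "food"]
scores = {w: score_hypothesis(w, positive_samples, negative_samples) for w in candidates}
print(max(scores, key=scores.get))  # -> 'delicious'
```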
SASC
Here, we explain a module rather than a dataset.
```python
from imodelsx import explain_module_sasc
import numpy as np

# a toy module that responds to the length of a string
mod = lambda str_list: np.array([len(s) for s in str_list])

# a toy dataset where the longest strings are animals
text_str_list = ["red", "blue", "x", "1", "2",
                 "hippopotamus", "elephant", "rhinoceros"]
explanation_dict = explain_module_sasc(
    text_str_list,
    mod,
    ngrams=1,
)
```
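At its core, SASC first finds the ngrams that most activate the module, then asks an LLM to summarize what those ngrams have in common. The first step can be sketched without an LLM, reusing the toy module above:

```python
import numpy as np

# same toy module: responds to string length
mod = lambda str_list: np.array([len(s) for s in str_list])
text_str_list = ["red", "blue", "x", "1", "2",
                 "hippopotamus", "elephant", "rhinoceros"]

# score each unigram by the module's response to it, keep the top few
scores = mod(text_str_list)
top = [text_str_list[i] for i in np.argsort(scores)[::-1][:3]]
print(top)  # the long animal names activate the module most
```

A summarizer would then describe `top` (here, all animals) in natural language, producing the explanation.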
Aug-imodels
Use these just like a scikit-learn model. During training, they fit better features via LLMs, but at test time they are extremely fast and completely transparent.
```python
from imodelsx import AugLinearClassifier, AugTreeClassifier, AugLinearRegressor, AugTreeRegressor
import datasets
import numpy as np

# set up data
dset = datasets.load_dataset('rotten_tomatoes')['train']
dset = dset.select(np.random.choice(len(dset), size=300, replace=False))
dset_val = datasets.load_dataset('rotten_tomatoes')['validation']
dset_val = dset_val.select(np.random.choice(len(dset_val), size=300, replace=False))

# fit model
m = AugLinearClassifier(
    checkpoint='textattack/distilbert-base-uncased-rotten-tomatoes',
    ngrams=2,  # use bigrams
)
m.fit(dset['text'], dset['label'])

# predict
preds = m.predict(dset_val['text'])
print('acc_val', np.mean(preds == dset_val['label']))

# interpret
print('Total ngram coefficients:', len(m.coefs_dict_))
print('Most positive ngrams')
for k, v in sorted(m.coefs_dict_.items(), key=lambda item: item[1], reverse=True)[:8]:
    print('\t', k, round(v, 2))
print('Most negative ngrams')
for k, v in sorted(m.coefs_dict_.items(), key=lambda item: item[1])[:8]:
    print('\t', k, round(v, 2))
```
