MetaTree
Official implementation of MetaTree: Learning a Decision Tree Algorithm with Transformers
Quickstart -- use MetaTree to generate decision tree models
The pretrained model is available at https://huggingface.co/yzhuang/MetaTree
- Install metatreelib:
pip install metatreelib
# Alternatively, clone the repository and run:
# pip install -e .
# or install directly from GitHub:
# pip install git+https://github.com/EvanZhuang/MetaTree
- Use MetaTree on your datasets to generate a decision tree model
from metatree.model_metatree import LlamaForMetaTree as MetaTree
from metatree.decision_tree_class import DecisionTree, DecisionTreeForest
from metatree.run_train import preprocess_dimension_patch
from transformers import AutoConfig
import imodels # pip install imodels
import random
import numpy as np
import torch
from sklearn.model_selection import train_test_split
# Fix the random seed (any integer works) so the split and subsampling are reproducible
seed = 0
random.seed(seed)
np.random.seed(seed)
# Initialize Model
model_name_or_path = "yzhuang/MetaTree"
config = AutoConfig.from_pretrained(model_name_or_path)
model = MetaTree.from_pretrained(
    model_name_or_path,
    config=config,
)
decision_tree_forest = DecisionTreeForest()
# Load Datasets
X, y, feature_names = imodels.get_clean_dataset('fico', data_source='imodels')
print("Dataset Shapes X={}, y={}, Num of Classes={}".format(X.shape, y.shape, len(set(y))))
train_idx, test_idx = train_test_split(range(X.shape[0]), test_size=0.3, random_state=seed)
# Dimension Subsampling: keep 10 randomly chosen features
feature_idx = np.random.choice(X.shape[1], 10, replace=False)
X = X[:, feature_idx]
test_X, test_y = X[test_idx], y[test_idx]
# Sample Train and Test Data: 256 training rows form the model's input
subset_idx = random.sample(train_idx, 256)
train_X, train_y = X[subset_idx], y[subset_idx]
input_x = torch.tensor(train_X, dtype=torch.float32)
# one_hot requires integer class labels
input_y = torch.nn.functional.one_hot(torch.tensor(train_y).long()).float()
batch = {"input_x": input_x, "input_y": input_y, "input_y_clean": input_y}
# Adapt the batch to the feature/class dimensions the model expects (n_feature, n_class)
batch = preprocess_dimension_patch(batch, n_feature=10, n_class=10)
# Generate a depth-2 tree and add it to the forest
model.depth = 2
outputs = model.generate_decision_tree(batch['input_x'], batch['input_y'], depth=model.depth)
decision_tree_forest.add_tree(DecisionTree(auto_dims=outputs.metatree_dimensions, auto_thresholds=outputs.tentative_splits, input_x=batch['input_x'], input_y=batch['input_y'], depth=model.depth))
# argmax over each entry of metatree_dimensions gives the selected feature index
print("Decision Tree Features: ", [x.argmax(dim=-1) for x in outputs.metatree_dimensions])
print("Decision Tree Thresholds: ", outputs.tentative_splits)
- Run inference with the decision tree model
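from sklearn.metrics import accuracy_score
# Predict class scores on the held-out test set; argmax picks the predicted class
tree_pred = decision_tree_forest.predict(torch.tensor(test_X, dtype=torch.float32))
accuracy = accuracy_score(test_y, tree_pred.argmax(dim=-1).squeeze(0))
print("MetaTree Test Accuracy: ", accuracy)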
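The quickstart above relies on MetaTree's own DecisionTree and DecisionTreeForest classes. In the quickstart, the feature at each split is the argmax over metatree_dimensions and the split values come from tentative_splits. To make those outputs easier to interpret, here is a minimal, library-independent NumPy sketch of what a depth-d axis-aligned tree and a small forest compute. It rests on several assumptions that are not taken from the MetaTree code:

- Nodes use a heap layout, so node i has children 2i+1 and 2i+2.
- A row goes left when its feature value is <= the threshold.
- Each leaf predicts the class frequencies of the training rows that reach it. The DecisionTree constructor also receives input_x and input_y, which suggests training data is used for the leaves.
- The forest averages the trees' class probabilities.

The helpers route_to_leaves, leaf_class_probs and forest_predict are hypothetical. They illustrate the idea only and do not show how metatree.decision_tree_class stores or evaluates trees.

```python
import numpy as np


def route_to_leaves(X, features, thresholds, depth):
    """Return the leaf index (0 .. 2**depth - 1) that each row of X reaches.

    Hypothetical heap layout: internal node i sends a row left when
    X[row, features[i]] <= thresholds[i]; its children are 2*i + 1 and 2*i + 2.
    """
    features = np.asarray(features)
    thresholds = np.asarray(thresholds, dtype=float)
    rows = np.arange(X.shape[0])
    node = np.zeros(X.shape[0], dtype=int)  # every row starts at the root
    for _ in range(depth):
        go_left = X[rows, features[node]] <= thresholds[node]
        node = np.where(go_left, 2 * node + 1, 2 * node + 2)
    # Leaves occupy heap positions 2**depth - 1 .. 2**(depth + 1) - 2
    return node - (2 ** depth - 1)


def leaf_class_probs(leaves, y, depth, n_class):
    """Class frequencies of the training rows in each leaf (uniform if a leaf is empty)."""
    counts = np.zeros((2 ** depth, n_class))
    np.add.at(counts, (leaves, y), 1)
    totals = counts.sum(axis=1, keepdims=True)
    return np.where(totals > 0, counts / np.maximum(totals, 1), 1.0 / n_class)


def forest_predict(trees, X):
    """Average the per-leaf class probabilities of several trees, shape (n_samples, n_class)."""
    return np.mean([probs[route_to_leaves(X, f, t, d)] for f, t, d, probs in trees], axis=0)


# Toy training data (made up for illustration): 4 rows, 2 features, binary labels
X_train = np.array([[0.1, 5.0], [0.2, 1.0], [0.9, 5.0], [0.8, 1.0]])
y_train = np.array([0, 1, 1, 0])

trees = []
# Each entry: (feature index per internal node, threshold per internal node, depth)
for features, thresholds, depth in [([0, 1, 1], [0.5, 3.0, 3.0], 2), ([0], [0.5], 1)]:
    leaves = route_to_leaves(X_train, features, thresholds, depth)
    probs = leaf_class_probs(leaves, y_train, depth, n_class=2)
    trees.append((features, thresholds, depth, probs))

X_test = np.array([[0.3, 4.0], [0.7, 2.0]])
y_test = np.array([0, 1])
scores = forest_predict(trees, X_test)
acc = float(np.mean(scores.argmax(axis=-1) == y_test))
print(scores, acc)
```

In this toy run, the depth-2 tree separates the training rows perfectly. The depth-1 tree's leaves are evenly split between the two classes. Both test rows therefore average to class scores of [0.75, 0.25], the forest predicts class 0 for both, and the accuracy is 0.5. As in the quickstart, the predicted class is the argmax over the last dimension of the scores.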
Example Usage
A complete example of using MetaTree is provided in the example notebook.
Questions?
If you have any questions related to the code or the paper, feel free to reach out to us at y5zhuang@ucsd.edu.
Citation
If you find our paper and code useful, please cite us:
@misc{zhuang2024learning,
title={Learning a Decision Tree Algorithm with Transformers},
author={Yufan Zhuang and Liyuan Liu and Chandan Singh and Jingbo Shang and Jianfeng Gao},
year={2024},
eprint={2402.03774},
archivePrefix={arXiv},
primaryClass={cs.LG}
}