Cde
Code for training and evaluating Contextual Document Embedding models (repository: @jxmorris12/Cde).
Contextual Document Embeddings (CDE)
This repository contains the training and evaluation code we used to produce cde-small-v1, our state-of-the-art small text embedding model. This includes the code for:
- efficient state-of-the-art contrastive training of retrieval models
- our custom two-stage model architecture that embeds contextual tokens and uses them in downstream embeddings
- a two-stage gradient caching technique that enables training our two-headed model efficiently
- clustering large datasets and caching the clusters
- packing clusters and sampling from them, even in distributed settings
- on-the-fly filtering of clusters based on a pretrained model
- and more!
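To illustrate the idea behind packing clusters into batches, here is a minimal stdlib-only sketch. It assumes every training example already has a cluster ID (for example from k-means over embeddings) and packs examples so that each batch comes from a single cluster; the intuition is that in-batch negatives drawn from the same cluster are harder. The function `pack_clusters_into_batches` and its choice to drop incomplete tail batches are hypothetical illustrations, not this repository's API.

```python
import random
from collections import defaultdict


def pack_clusters_into_batches(cluster_ids, batch_size, seed=0):
    """Group example indices by cluster and cut each cluster into full batches.

    cluster_ids[i] is the (hypothetical) cluster assignment of example i.
    Leftover examples that don't fill a whole batch are dropped here for
    simplicity; a real implementation might merge small clusters instead.
    """
    rng = random.Random(seed)
    by_cluster = defaultdict(list)
    for idx, cid in enumerate(cluster_ids):
        by_cluster[cid].append(idx)

    batches = []
    for members in by_cluster.values():
        rng.shuffle(members)  # randomize order within each cluster
        for start in range(0, len(members) - batch_size + 1, batch_size):
            batches.append(members[start:start + batch_size])

    rng.shuffle(batches)  # randomize the order in which clusters are visited
    return batches


# Example: 10 examples in two clusters, batch size 2.
cluster_ids = [0, 1, 0, 1, 0, 1, 0, 1, 0, 0]
for batch in pack_clusters_into_batches(cluster_ids, batch_size=2):
    print(batch, [cluster_ids[i] for i in batch])
```

Every printed batch contains indices from only one cluster. Distributed training would additionally need to split these batches across workers, which this sketch does not attempt.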
cde naturally integrates "context tokens" into the embedding process. As of October 1st, 2024, cde-small-v1 is the best small model (under 400M params) on the MTEB leaderboard for text embedding models, with an average score of 65.00.
👉 <b><a href="https://colab.research.google.com/drive/1r8xwbp7_ySL9lP-ve4XMJAHjidB9UkbL?usp=sharing">Try on Colab</a></b> <br> 👉 <b><a href="https://arxiv.org/abs/2410.02525">Contextual Document Embeddings (ArXiv)</a></b>

Install
Install PyTorch with CUDA support, then install the requirements:
uv pip install -r requirements.txt
Then install FlashAttention:
uv pip install --no-cache-dir flash-attn --no-build-isolation git+https://github.com/HazyResearch/flash-attention.git#subdirectory=csrc/rotary git+https://github.com/HazyResearch/flash-attention.git#subdirectory=csrc/layer_norm git+https://github.com/HazyResearch/flash-attention.git#subdirectory=csrc/fused_dense_lib git+https://github.com/HazyResearch/flash-attention.git#subdirectory=csrc/xentropy
Make sure ninja is installed before running the command above (uv pip install ninja); it makes the FlashAttention build roughly 50x faster.
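After installing, a quick way to confirm that the packages are visible to Python is sketched below. It only checks that the modules can be found (via `importlib.util.find_spec`, which locates a module without importing it); it does not verify that CUDA works. The module names `torch`, `flash_attn`, and `ninja` are assumptions about how these packages are imported.

```python
import importlib.util


def check_packages(names):
    """Return {name: True/False} depending on whether each top-level module can be found.

    find_spec does not import the module, so this is cheap even for heavy
    packages like torch.
    """
    return {name: importlib.util.find_spec(name) is not None for name in names}


if __name__ == "__main__":
    # Assumed import names for PyTorch, FlashAttention, and ninja.
    for name, found in check_packages(["torch", "flash_attn", "ninja"]).items():
        print(f"{name}: {'ok' if found else 'MISSING'}")
```

To check the GPU itself, `torch.cuda.is_available()` can be called once torch imports successfully.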
Example command
Here's an example command for pretraining a biencoder:
python finetune.py --per_device_train_batch_size 1024 --per_device_eval_batch_size 256 --use_wandb 1 --dataset nomic_unsupervised --sampling_strategy domain --num_train_epochs 3 --learning_rate 2e-05 --embedder nomic-ai/nomic-bert-2048 --clustering_model gtr_base --clustering_query_to_doc 1 --ddp_find_unused_parameters 0 --eval_rerank_topk 32 --lr_scheduler_type linear --warmup_steps 5600 --disable_dropout 1 --max_seq_length 512 --logging_steps 4 --use_prefix 1 --save_steps 99999999999 --logit_scale 50 --max_eval_batches 16 --ddp_share_negatives_between_gpus 0 --torch_compile 0 --use_gc 1 --fp16 0 --bf16 1 --eval_steps 400000 --limit_layers 6 --max_batch_size_fits_in_memory 64 --arch biencoder --train_cluster_size 2048 --exp_group 2024-10-30-biencoder-test --exp_name 2024-10-30-biencoder-pretrain-example
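With this many flags it is easy to pass one twice by accident (with argparse-style parsers, the last occurrence usually wins silently). One option, sketched here as an assumption rather than part of the repository, is to keep the arguments in a Python dict, which cannot hold duplicate keys, and build the command from it. `build_command` is a hypothetical helper, and only a subset of the flags from the example command is shown.

```python
import shlex


def build_command(script, args):
    """Turn {"flag_name": value} into ["python", script, "--flag_name", "value", ...]."""
    cmd = ["python", script]
    for name, value in args.items():
        cmd += [f"--{name}", str(value)]
    return cmd


# A subset of the flags from the example command above.
args = {
    "arch": "biencoder",
    "embedder": "nomic-ai/nomic-bert-2048",
    "dataset": "nomic_unsupervised",
    "per_device_train_batch_size": 1024,
    "max_seq_length": 512,
    "learning_rate": 2e-05,
}
cmd = build_command("finetune.py", args)
print(shlex.join(cmd))
# To launch training: import subprocess; subprocess.run(cmd, check=True)
```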
How to use cde-small-v1
Our embedding model is used in two stages. The first stage gathers dataset information by embedding a subset of the corpus with our "first-stage" model. The second stage embeds the actual queries and documents, conditioned on the corpus information from the first stage. Note that the first stage can be run offline, so only the second-stage weights are needed at inference time.
</details>
With Transformers
<details> <summary>Click to learn how to use cde-small-v1 with Transformers</summary>
Loading the model
Our model can be loaded with transformers out of the box by passing trust_remote_code=True. We use the default BERT uncased tokenizer (bert-base-uncased):
import transformers
model = transformers.AutoModel.from_pretrained("jxm/cde-small-v1", trust_remote_code=True)
tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-uncased")
Note on prefixes
Nota bene: Like most state-of-the-art embedding models, our model was trained with task-specific prefixes. For retrieval, prepend the following strings to queries and documents, respectively:
query_prefix = "search_query: "
document_prefix = "search_document: "
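A small helper can keep the prefixes consistent between the first and second stages. This is a hypothetical convenience function, not part of the model's API; it simply prepends the strings above.

```python
QUERY_PREFIX = "search_query: "
DOCUMENT_PREFIX = "search_document: "


def add_prefix(texts, kind):
    """Prepend the task prefix for kind ("query" or "document") to each text."""
    prefixes = {"query": QUERY_PREFIX, "document": DOCUMENT_PREFIX}
    if kind not in prefixes:
        raise ValueError(f"kind must be 'query' or 'document', got {kind!r}")
    return [prefixes[kind] + text for text in texts]


print(add_prefix(["what is a contextual embedding?"], "query"))
# ['search_query: what is a contextual embedding?']
```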
First stage
minicorpus_size = model.config.transductive_corpus_size
minicorpus_docs = [ ... ] # Put some strings here that are representative of your corpus, for example by calling random.sample(corpus, k=minicorpus_size)
assert len(minicorpus_docs) == minicorpus_size # You must use exactly this many documents in the minicorpus. You can oversample if your corpus is smaller.
minicorpus_docs = tokenizer(
    [document_prefix + doc for doc in minicorpus_docs],
    truncation=True,
    padding=True,
    max_length=512,
    return_tensors="pt"
).to(model.device)
import torch
from tqdm.autonotebook import tqdm
# Embed the minicorpus in batches with the first-stage model (no gradients needed).
batch_size = 32
dataset_embeddings = []
for i in tqdm(range(0, len(minicorpus_docs["input_ids"]), batch_size)):
    minicorpus_docs_batch = {k: v[i:i+batch_size] for k, v in minicorpus_docs.items()}
    with torch.no_grad():
        dataset_embeddings.append(
            model.first_stage_model(**minicorpus_docs_batch)
        )
dataset_embeddings = torch.cat(dataset_embeddings)
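The first stage needs exactly minicorpus_size documents. Below is a minimal sketch of one way to build that list, assuming `corpus` is a plain Python list of strings: sample without replacement when the corpus is large enough, otherwise oversample with replacement, as the comment in the first-stage code suggests. `build_minicorpus` is a hypothetical helper; in this README's setup, `k` would be `model.config.transductive_corpus_size`.

```python
import random


def build_minicorpus(corpus, k, seed=0):
    """Return exactly k documents drawn from corpus."""
    if not corpus:
        raise ValueError("corpus must contain at least one document")
    rng = random.Random(seed)
    if len(corpus) >= k:
        return rng.sample(corpus, k)  # without replacement
    # Corpus is smaller than k: keep every document once, then top up with
    # random repeats so the total is exactly k.
    return list(corpus) + rng.choices(corpus, k=k - len(corpus))


minicorpus_docs = build_minicorpus(["doc a", "doc b", "doc c"], k=8)
assert len(minicorpus_docs) == 8
```

Fixing the seed makes the minicorpus, and therefore the dataset embeddings, reproducible across runs.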
Running the second stage
Now that we have the dataset embeddings, we can embed documents and queries as usual. Remember to use the document prefix for documents:
docs = [...]  # Put the document strings you want to embed here
docs = tokenizer(
    [document_prefix + doc for doc in docs],
    truncation=True,
    padding=True,
    max_length=512,
    return_tensors="pt"
).to(model.device)
with torch.no_grad():
    doc_embeddings = model.second_stage_model(
        input_ids=docs["input_ids"],
        attention_mask=docs["attention_mask"],
        dataset_embeddings=dataset_embeddings,
    )
# L2-normalize so that dot products equal cosine similarities
doc_embeddings /= doc_embeddings.norm(p=2, dim=1, keepdim=True)
And use the query prefix for queries:
# Put the query strings you want to embed here. For example, if `queries` is a
# Hugging Face Dataset with a "text" column: queries = queries.select(range(16))["text"]
queries = [...]
queries = tokenizer(
    [query_prefix + query for query in queries],
    truncation=True,
    padding=True,
    max_length=512,
    return_tensors="pt"
).to(model.device)
with torch.no_grad():
    query_embeddings = model.second_stage_model(
        input_ids=queries["input_ids"],
        attention_mask=queries["attention_mask"],
        dataset_embeddings=dataset_embeddings,
    )
query_embeddings /= query_embeddings.norm(p=2, dim=1, keepdim=True)
Since these embeddings are L2-normalized, they can be compared with a dot product, which then equals cosine similarity.
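To see why the dot product is enough here: once two vectors are L2-normalized, their dot product equals their cosine similarity. The stdlib-only sketch below checks this on toy vectors (not real model outputs). With the PyTorch tensors above, the full query-by-document score matrix would be `query_embeddings @ doc_embeddings.T`.

```python
import math


def l2_normalize(v):
    """Scale v to unit Euclidean length."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def cosine(a, b):
    return dot(a, b) / (math.sqrt(dot(a, a)) * math.sqrt(dot(b, b)))


q = [3.0, 4.0]
d = [4.0, 3.0]
print(cosine(q, d))                           # cosine similarity of the raw vectors
print(dot(l2_normalize(q), l2_normalize(d)))  # same value (up to rounding) after normalizing
```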
</details>
What if I don't know what my corpus will be ahead of time?
If you can't obtain corpus information ahead of time, you still have to pass something as the dataset embeddings. The model still works in this case, just not quite as well: without corpus information, its average MTEB score drops from 65.0 to 63.8. We provide some random strings that worked well for us as a substitute for corpus sampling.
With Sentence Transformers
<details open=""> <summary>Click to learn how to use cde-small-v1 with Sentence Transformers</summary>
Loading the model
Our model can be loaded with sentence-transformers out of the box by passing trust_remote_code=True:
from sentence_transformers import SentenceTransformer
model = SentenceTransformer("jxm/cde-small-v1", trust_remote_code=True)
Note on prefixes
Nota bene: Like most state-of-the-art embedding models, our model was trained with task-specific prefixes. For retrieval, pass prompt_name="query" to the model's encode method when embedding queries and prompt_name="document" when embedding documents.
First stage
minicorpus_size = model[0].config.transductive_corpus_size
minicorpus_docs = [ ... ] # Put some strings here that are representative of your corpus, for example by calling random.sample(corpus, k=minicorpus_size)
assert len(minicorpus_docs) == minicorpus_size # You must use exactly this many documents in the minicorpus. You can oversample if your corpus is smaller.
dataset_embeddings = model.encode(
minicorpus_docs,
prompt_name="document",
convert_to_tensor=True
)
Running the second stage
Now that we have the dataset embeddings, we can embed documents and queries as usual. Remember to use the document prompt for documents:
docs = [...]
queries = [...]
doc_embeddings = model.encode(
docs,
prompt_name="document",
dataset_embeddings=dataset_embeddings,
convert_to_tensor=True,
)
query_embeddings = model.encode(
queries,
prompt_name="query",
dataset_embeddings=dataset_embeddings,
convert_to_tensor=True,
)
These embeddings can be compared with cosine similarity via model.similarity:
similarities = model.similarity(query_embeddings, doc_embeddings)
topk_values, topk_indices = similarities.topk(5)
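Here `topk_indices` holds, for each query, the positions of its best-scoring documents in `docs`; with the tensors above, the documents for query `q` are `[docs[i] for i in topk_indices[q].tolist()]`. The stdlib-only sketch below performs the same ranking step over a plain list-of-lists score matrix and maps indices back to document strings. `top_k_docs` and the toy scores are illustrative only.

```python
import heapq


def top_k_docs(similarities, docs, k=5):
    """For each row of scores, return the k (score, doc) pairs with the highest score."""
    results = []
    for row in similarities:
        best = heapq.nlargest(k, range(len(row)), key=row.__getitem__)
        results.append([(row[i], docs[i]) for i in best])
    return results


docs = ["doc about cats", "doc about dogs", "doc about cars"]
similarities = [
    [0.9, 0.7, 0.1],  # scores of query 0 against each doc
    [0.2, 0.3, 0.8],  # scores of query 1 against each doc
]
for query_idx, hits in enumerate(top_k_docs(similarities, docs, k=2)):
    print(query_idx, hits)
```

If `k` exceeds the number of documents, `heapq.nlargest` simply returns all of them, ranked.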
<details>
<summary>Click here for a full copy-paste ready example</summary>
from sentence_transformers import SentenceTransformer
from datasets import load_dataset
# 1. Load the Sentence Transformer model
model = SentenceTransformer("jxm/cde-small-v1", trust_remote_code=True)
context_docs_size = model[0].config.transductive_corpus_size # 512
# 2. Load the dataset: context dataset, docs, and queries
dataset = load_dataset("sentence-transformers/natural-questions", split="train")
dataset = dataset.shuffle(seed=42)  # shuffle returns a new dataset
# 10 queries, 512 context docs, 2000 docs
queries = dataset["query"][:10]
docs = dataset["answer"][:2000]
context_docs = dataset["answer"][-context_docs_size:] # Last 512 docs
# 3. First stage: embed the context docs
dataset_embeddings = model.encode(
context_docs,
prompt_name="document",
convert_to_tensor=True,
)
# 4. Second stage: embed the docs and queries
doc_embeddings = model.encode(
docs,
prompt_name="document",
dataset_embeddings=dataset_embeddings,
convert_to_tensor=True,
)
quer
