SGPT: GPT Sentence Embeddings for Semantic Search
This repository contains code, results & pre-trained models for the paper SGPT: GPT Sentence Embeddings for Semantic Search.
**************************** Updates ****************************
- 2024-02: We released GRIT & GritLM - these models unify SGPT Bi-Encoders, Cross-Encoders, symmetric, asymmetric, and regular GPT (i.e. generation) in a single model at much better performance on all counts. We recommend switching to these new models :)
- 2022-09: SGPT Bi-Encoders are now easy to use with Sentence Transformers, see new scripts
- 2022-08: Multilingual BLOOM SGPT models were released: Asymmetric, 7.1B parameters & Symmetric, 1.7B parameters. Feel free to open an issue if you need a different model.
- 2022-06: OpenAI released the mechanism of their Search Endpoint, which we compared to SGPT Cross-Encoders in the paper. Our methods are very similar. Feel free to test their prompt as seen in crossencoder/beir/openai_search_endpoint_functionality.py!
- 2022-03: 5.8B Bi-Encoder models are now 4% & 1% better on USEB & BEIR, respectively. Paper & models on HF have been updated. This was done by using larger batch sizes with GradCache; see the paper for more info. If you have previously downloaded them, we recommend replacing them with the new version.
- 2022-02: We released our paper. Check it out! :)
Quick Links
- Overview
- Structure
- Use SGPT with Huggingface
- Use SGPT with Sentence Transformers
- Acknowledgements
- Citation
Overview
We present SGPT-BE and SGPT-CE for applying GPT models as Bi-Encoders or Cross-Encoders to symmetric or asymmetric search. SGPT-BE produces semantically meaningful sentence embeddings by contrastively fine-tuning only bias tensors and using position-weighted mean pooling. SGPT-CE uses log probabilities from GPT models without any fine-tuning. An illustration of both methods follows.
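As a toy numeric illustration of position-weighted mean pooling (the values and variable names below are made up for illustration, not taken from the repository's code): token i receives a weight proportional to its position, so later tokens, which have attended to more context in a causal model, contribute more to the embedding.

```python
import torch

# Toy hidden states for one sequence: 4 tokens, hidden dim 2 (values made up)
hidden = torch.tensor([[1.0, 0.0],
                       [0.0, 1.0],
                       [1.0, 1.0],
                       [2.0, 0.0]])

# Position weights 1..seq_len: later tokens weigh more in a causal LM
weights = torch.arange(1, hidden.shape[0] + 1, dtype=torch.float32).unsqueeze(-1)

# Weighted mean pooling: sum_i(w_i * h_i) / sum_i(w_i)
embedding = (hidden * weights).sum(dim=0) / weights.sum()
print(embedding)  # tensor([1.2000, 0.5000])
```

The full Huggingface example below applies the same formula batched, with an attention mask to exclude padding tokens.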

Feel free to open an issue should you have any questions~
Structure
.
├── biencoder # Training & Inference of Bi-Encoders
│ ├── beir
│ │ ├── custommodels # Directory providing BEIR compatibility for asymmetric models & models with special tokens
│ │ │ └── ...
│ │ ├── io_utils # Exclusively used for beir_openai_embeddings_batched_parallel.py
│ │ │ └── ...
│ │ ├── parallelizer # Exclusively used for beir_openai_embeddings_batched_parallel.py
│ │ │ └── ...
│ │ ├── beir_dense_retriever.py
│ │ ├── beir_openai_embeddings_batched_parallel.py
│ │ ├── requirements.txt
│ │ ├── *.bash # Bash scripts to run multiple experiments
│ │ └── README.md
│ ├── nli_msmarco
│ │ ├── sentence-transformers # An adapted version of sentence-transformers - Install this version for all biencoder experiments
│ │ │ └── ...
│ │ └── README.md
│ └── useb
│ ├── useb
│ │ └── ...
│ ├── *.bash # Bash scripts to run multiple experiments
│ ├── useb_dense_retriever.py
│ └── README.md
├── crossencoder # Inference of Cross-Encoders
│ └── beir
│ ├── *.ipynb # Notebooks explained in the README
│ └── README.md
├── other
│ ├── sgpt_graphic.png
│ └── sgpt_utils.ipynb # Code for creating the graphs in the paper & other
├── requirements.txt
└── README.md
Each data sub-directory provides its own README with an overview of its structure, downloads (datasets, models) & the commands used to produce the datasets, models & other artifacts. Generally, you can find all models at https://huggingface.co/Muennighoff and JSON results for various datasets at https://www.kaggle.com/muennighoff/datasets. Model names are explained in their Huggingface READMEs. Dataset names are explained in the sub-folders of this repository.
Use SGPT with Huggingface
Below we provide Python examples for using the pre-trained models in your own semantic search use case.
We highly recommend replacing the model names with larger models, e.g. Muennighoff/SGPT-5.8B-weightedmean-nli-bitfit for biencoder/symmetric.
Bi-Encoder
Symmetric Semantic Search BE
import torch
from transformers import AutoModel, AutoTokenizer
from scipy.spatial.distance import cosine

# Get our models - The package will take care of downloading the models automatically
# For best performance: Muennighoff/SGPT-5.8B-weightedmean-nli-bitfit
tokenizer = AutoTokenizer.from_pretrained("Muennighoff/SGPT-125M-weightedmean-nli-bitfit")
model = AutoModel.from_pretrained("Muennighoff/SGPT-125M-weightedmean-nli-bitfit")
# Deactivate Dropout (there is no dropout in the above models, so it makes no difference here, but other SGPT models may have dropout)
model.eval()

# Tokenize input texts
texts = [
    "deep learning",
    "artificial intelligence",
    "deep diving",
    "artificial snow",
]
batch_tokens = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")

# Get the embeddings
with torch.no_grad():
    # Get hidden state of shape [bs, seq_len, hid_dim]
    last_hidden_state = model(**batch_tokens, output_hidden_states=True, return_dict=True).last_hidden_state

# Get weights of shape [bs, seq_len, hid_dim]
weights = (
    torch.arange(start=1, end=last_hidden_state.shape[1] + 1)
    .unsqueeze(0)
    .unsqueeze(-1)
    .expand(last_hidden_state.size())
    .float().to(last_hidden_state.device)
)

# Get attn mask of shape [bs, seq_len, hid_dim]
input_mask_expanded = (
    batch_tokens["attention_mask"]
    .unsqueeze(-1)
    .expand(last_hidden_state.size())
    .float()
)

# Perform weighted mean pooling across seq_len: bs, seq_len, hidden_dim -> bs, hidden_dim
sum_embeddings = torch.sum(last_hidden_state * input_mask_expanded * weights, dim=1)
sum_mask = torch.sum(input_mask_expanded * weights, dim=1)
embeddings = sum_embeddings / sum_mask

# Calculate cosine similarities
# Cosine similarities are in [-1, 1]. Higher means more similar
cosine_sim_0_1 = 1 - cosine(embeddings[0], embeddings[1])
cosine_sim_0_2 = 1 - cosine(embeddings[0], embeddings[2])
cosine_sim_0_3 = 1 - cosine(embeddings[0], embeddings[3])

print("Cosine similarity between \"%s\" and \"%s\" is: %.3f" % (texts[0], texts[1], cosine_sim_0_1))
print("Cosine similarity between \"%s\" and \"%s\" is: %.3f" % (texts[0], texts[2], cosine_sim_0_2))
print("Cosine similarity between \"%s\" and \"%s\" is: %.3f" % (texts[0], texts[3], cosine_sim_0_3))
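When embedding many texts, scoring pairs one at a time with scipy becomes slow. A common alternative (a sketch using made-up toy vectors, not code from this repository) is to L2-normalize the embeddings once, after which a single matrix multiplication yields all pairwise cosine similarities:

```python
import torch

# Toy embeddings standing in for the pooled outputs above (values made up)
embeddings = torch.tensor([[1.0, 0.0],
                           [0.8, 0.6],
                           [0.0, 1.0]])

# After L2 normalization, a dot product equals cosine similarity,
# so one matmul produces the full similarity matrix
normed = torch.nn.functional.normalize(embeddings, p=2, dim=1)
cosine_sims = normed @ normed.T
print(cosine_sims)
```

For query-vs-document search, normalize both sets and compute `queries_normed @ docs_normed.T` instead of the self-similarity matrix.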
Asymmetric Semantic Search BE
import torch
from transformers import AutoModel, AutoTokenizer
from scipy.spatial.distance import cosine
# Get our models - The package will take care of downloading the models automatically
# For best performance: Muennighoff/SGPT-5.8B-weightedmean-msmarco-specb-bitfit
tokenizer = AutoTokenizer.from_pretrained("Muennighoff/SGPT-125M-weightedmean-msmarco-specb-bitfit")
model = AutoModel.from_pretrained("Muennighoff/SGPT-125M-weightedmean-msmarco-specb-bitfit")
# Deactivate Dropout (there is no dropout in the above models, so it makes no difference here, but other SGPT models may have dropout)
model.eval()
queries = [
    "I'm searching for a planet not too far from Earth.",
]
docs = [
    "Neptune is the eighth and farthest-known Solar planet from the Sun. In the Solar System, it is the fourth-largest planet by diameter, the third-most-massive planet, and the densest giant planet. It is 17 times the mass of Earth, slightly more massive than its near-twin Uranus.",
    "TRAPPIST-1d, also designated as 2MASS J23062928-0502285 d, is a small exoplanet (about 30% the mass of the earth), which orbits on the inner edge of the habitable zone of the ultracool dwarf star TRAPPIST-1 approximately 40 light-years (12.1 parsecs, or nearly 3.7336×10^14 km) away from Earth in the constellation of Aquarius.",
    "A harsh desert world orbiting twin suns in the galaxy's Outer Rim, Tatooine is a lawless place ruled by Hutt gangsters. Many settlers scratch out a living on moisture farms, while spaceport cities such as Mos Eisley and Mos Espa serve as home base for smugglers, criminals, and other rogues.",
]
SPECB_QUE_BOS = tokenizer.encode("[", add_special_tokens=False)[0]
SPECB_QUE_EOS = tokenizer.encode("]", add_special_tokens=False)[0]
SPECB_DOC_BOS = tokenizer.encode("{", add_special_tokens=False)[0]
SPECB_DOC_EOS = tokenizer.encode("}", add_special_tokens=False)[0]
def tokenize_with_specb(texts, is_query):
    # Tokenize without padding
    batch_tokens = tokenizer(texts, padding=False, truncation=True)
    # Add special brackets & pay attention to them
    for seq, att in zip(batch_tokens["input_ids"], batch_tokens["attention_mask"]):
        if is_query:
            seq.insert(0, SPECB_QUE_BOS)
            seq.append(SPECB_QUE_EOS)
        else:
            seq.insert(0, SPECB_DOC_BOS)
            seq.append(SPECB_DOC_EOS)
        att.insert(0, 1)
        att.append(1)
    # Add padding
    batch_tokens = tokenizer.pad(batch_tokens, padding=True, return_tensors="pt")
    return batch_tokens

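For illustration, the bracket-wrapping idea can be sketched without a real tokenizer, using plain Python lists and made-up token ids (`wrap_with_specb` and the ids below are ours for this sketch, not from the repository):

```python
# Stand-in ids for the special-bracket tokens; real ids come from tokenizer.encode
SPECB_QUE_BOS, SPECB_QUE_EOS = 91, 93    # "[" and "]"
SPECB_DOC_BOS, SPECB_DOC_EOS = 123, 125  # "{" and "}"

def wrap_with_specb(input_ids, attention_mask, is_query):
    # Queries are wrapped in [...], documents in {...};
    # the attention mask is extended so the brackets are attended to
    bos, eos = (SPECB_QUE_BOS, SPECB_QUE_EOS) if is_query else (SPECB_DOC_BOS, SPECB_DOC_EOS)
    return [bos] + input_ids + [eos], [1] + attention_mask + [1]

ids, mask = wrap_with_specb([10, 11, 12], [1, 1, 1], is_query=True)
print(ids, mask)  # [91, 10, 11, 12, 93] [1, 1, 1, 1, 1]
```

Marking queries and documents with different brackets lets a single model learn distinct representations for the two roles in asymmetric search.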