VardaGPT

Associative memory-enhanced GPT-2 model

VardaGPT is a memory-enhanced GPT-2 model built on Hugging Face Transformers and FAISS. Inspired by J.R.R. Tolkien's The Silmarillion, VardaGPT aims to provide guidance and knowledge through memory-augmented text generation.

TL;DR: Training

The VardaGPTAssociative model combines GPT-2 with an associative memory to improve context retrieval. This repository includes a script to train this model on the WikiText-2 dataset.

Requirements

  • Python 3.7+
  • PyTorch 1.8.1+
  • torchtext 0.9.1
  • transformers 4.10.0
  • rich 10.3.0
  • faiss-cpu 1.7.1

To install the required packages, you can use the following command:

pip install -r requirements.txt

Usage

To train the VardaGPTAssociative model on the WikiText-2 dataset, use the provided training script (train_varda_gpt_associative.py). You can customize the training settings by passing command-line arguments. Here's a basic example:

python train_varda_gpt_associative.py --epochs 5 --learning_rate 1e-4 --use_gpu

Available command-line arguments:

  • --epochs: Number of epochs to train the model (default: 5).
  • --learning_rate: Learning rate for the optimizer (default: 1e-4).
  • --memory_size: Maximum number of items the associative memory can store (default: 10000).
  • --memory_dim: Dimensionality of the embeddings stored in the associative memory (default: 768).
  • --index_type: Type of FAISS index used for the associative memory, either "flat" or "ivf" (default: "flat").
  • --num_clusters: Number of clusters to use for the memory if the index type is "ivf" (default: 1024).
  • --num_search_results: Number of search results to return from the associative memory (default: 5).
  • --use_gpu: Whether to use the GPU for the model if available (default: False).
  • --batch_size: Batch size for training (default: 1).
  • --forgetfulness_factor: Forgetfulness factor for the associative memory (default: 0.001).
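The flags above map naturally onto Python's argparse. The sketch below is a hypothetical parser that mirrors the documented names and defaults; it is not the training script's actual code. Restricting --index_type to "flat" and "ivf", and the hypothetical helper faiss_factory_string that turns --index_type and --num_clusters into a description string for faiss.index_factory, are assumptions made for illustration.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the documented flags and defaults."""
    p = argparse.ArgumentParser(description="Train VardaGPTAssociative on WikiText-2")
    p.add_argument("--epochs", type=int, default=5)
    p.add_argument("--learning_rate", type=float, default=1e-4)
    p.add_argument("--memory_size", type=int, default=10000)
    p.add_argument("--memory_dim", type=int, default=768)
    # Assumption: only the two index types named in this README are accepted.
    p.add_argument("--index_type", choices=["flat", "ivf"], default="flat")
    p.add_argument("--num_clusters", type=int, default=1024)
    p.add_argument("--num_search_results", type=int, default=5)
    p.add_argument("--use_gpu", action="store_true")  # False unless the flag is given
    p.add_argument("--batch_size", type=int, default=1)
    p.add_argument("--forgetfulness_factor", type=float, default=0.001)
    return p


def faiss_factory_string(index_type: str, num_clusters: int) -> str:
    """Hypothetical helper: build a description string for faiss.index_factory."""
    if index_type == "flat":
        return "Flat"  # exact brute-force search, no training needed
    if index_type == "ivf":
        # Inverted-file index with num_clusters cells; it must be trained
        # on sample vectors before any vectors can be added.
        return f"IVF{num_clusters},Flat"
    raise ValueError(f"unsupported index_type: {index_type!r}")


# Parse an explicit argument list so the example does not depend on sys.argv.
args = build_parser().parse_args(["--epochs", "3", "--index_type", "ivf", "--use_gpu"])
spec = faiss_factory_string(args.index_type, args.num_clusters)
```

With the faiss package installed, such a string could be passed as faiss.index_factory(args.memory_dim, spec); the training script itself may construct its index differently.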

During training, the script will periodically print the training loss, validation loss, and elapsed time for each epoch, along with a progress bar for each training step.

After training, you can use the trained model for your specific use case, such as text generation or fine-tuning for a particular task.

Overview

<details> <summary>Click me</summary>
@startuml
!define AWSPUML https://raw.githubusercontent.com/awslabs/aws-icons-for-plantuml/v14.0

actor User

skinparam component {
  BackgroundColor<<Data Preparation>> LightSkyBlue
  BackgroundColor<<FAISS Memory>> Plum
  BackgroundColor<<GPT-2 Adaptation>> LightGreen
  BackgroundColor<<Training>> LightSalmon
  BackgroundColor<<Inference>> LightCoral
  BorderColor Black
  FontName Arial
}

package "VardaGPT" {
  [Data Preparation]<<Data Preparation>> --> [FAISS Memory]<<FAISS Memory>>
  [Data Preparation]<<Data Preparation>> --> [GPT-2 Adaptation]<<GPT-2 Adaptation>>

  [FAISS Memory]<<FAISS Memory>> --> [GPT-2 Adaptation]<<GPT-2 Adaptation>>
  [GPT-2 Adaptation]<<GPT-2 Adaptation>> --> [Training]<<Training>>

  [Training]<<Training>> --> [Inference]<<Inference>>
  [FAISS Memory]<<FAISS Memory>> --> [Inference]<<Inference>>

  User --> [Data Preparation]<<Data Preparation>> : Dataset
  User --> [Inference]<<Inference>> : Prompts
}

@enduml
</details>

This diagram shows the main components of the VardaGPT project and how they interact. The Data Preparation component processes the dataset and feeds it to both the FAISS Memory and GPT-2 Adaptation components. FAISS Memory stores and retrieves the embeddings that GPT-2 Adaptation uses to build a modified GPT-2 model. That model is then trained and evaluated, and the trained model is used by the Inference component, which also queries FAISS Memory. The user supplies the dataset for training and the prompts for text generation.

Models

The associative memory model:

<details> <summary>Click me</summary>
@startuml

rectangle "Input Vectors" as input #b3e0ff
rectangle "Memory" as memory #f2d7b9
rectangle "Concatenated Input" as concatenated_input #f6e3c6
rectangle "Fully Connected Layer (fc)" as fc #e5ebf0
rectangle "GPT-2 Transformer" as transformer #c6e0b4
rectangle "GPT-2 LM Head" as lm_head #c9daf8
rectangle "Fully Connected Layer\n(fc_storable_vector)" as fc_storable_vector #c9daf8
rectangle "Fully Connected Layer\n(fc_store_decision)" as fc_store_decision #c9daf8

input -down-> memory : Perform search in memory
memory -down-> concatenated_input : Concatenate search results with input vectors
concatenated_input -down-> fc : Apply fully connected layer (fc)
fc -down-> transformer : Pass through GPT-2 transformer
transformer -down-> lm_head : Apply GPT-2 lm_head
transformer -right-> fc_storable_vector : Apply fully connected layer (fc_storable_vector)
transformer -right-> fc_store_decision : Apply fully connected layer (fc_store_decision)

note right of fc_storable_vector: Calculate storable vector\n and store decision
note right of fc_store_decision: Store the storable_vector in\n the associative memory if\n the store_decision is affirmative
note bottom of lm_head: Return logits

@enduml

</details>
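The memory in this diagram needs three operations: search, add, and remove. The NumPy sketch below approximates them with a hypothetical FlatL2Memory class that performs exact nearest-neighbour search by squared L2 distance (the metric a FAISS IndexFlatL2 reports) instead of calling FAISS. Two details are assumptions: the oldest entries are evicted once memory_size is exceeded, and the hypothetical store_if_affirmative helper treats the store decision as affirmative when sigmoid(store_decision) > 0.5. The diagram only states that vectors are stored when the decision is affirmative.

```python
import numpy as np


class FlatL2Memory:
    """Hypothetical stand-in for a FAISS IndexFlatL2-backed associative memory."""

    def __init__(self, memory_size: int, memory_dim: int):
        self.memory_size = memory_size
        self.memory_dim = memory_dim
        self.embeddings = np.empty((0, memory_dim), dtype=np.float32)

    def add(self, vectors) -> None:
        vectors = np.asarray(vectors, dtype=np.float32).reshape(-1, self.memory_dim)
        self.embeddings = np.concatenate([self.embeddings, vectors])
        # Assumption: once capacity is exceeded, drop the oldest entries (FIFO).
        if len(self.embeddings) > self.memory_size:
            self.embeddings = self.embeddings[-self.memory_size:]

    def remove(self, indices) -> None:
        # Indices of later entries shift down after removal.
        keep = np.ones(len(self.embeddings), dtype=bool)
        keep[np.asarray(indices, dtype=int)] = False
        self.embeddings = self.embeddings[keep]

    def search(self, queries, k: int):
        """Return (distances, indices), each (num_queries, k), nearest first.

        Unlike FAISS, which pads with -1, this returns fewer than k columns
        when the memory holds fewer than k vectors.
        """
        queries = np.asarray(queries, dtype=np.float32).reshape(-1, self.memory_dim)
        # Squared L2 distance between every query and every stored vector.
        d = ((queries[:, None, :] - self.embeddings[None, :, :]) ** 2).sum(-1)
        idx = np.argsort(d, axis=1)[:, :k]
        return np.take_along_axis(d, idx, axis=1), idx

    def get_all_embeddings(self) -> np.ndarray:
        return self.embeddings


def store_if_affirmative(memory, storable_vector, store_logit, threshold=0.5):
    """Add each storable vector whose sigmoid(store_logit) exceeds threshold.

    storable_vector: (batch, seq, memory_dim); store_logit: (batch, seq, 1).
    Returns the number of vectors stored.
    """
    prob = 1.0 / (1.0 + np.exp(-store_logit))
    mask = prob.reshape(-1) > threshold
    memory.add(storable_vector.reshape(-1, memory.memory_dim)[mask])
    return int(mask.sum())
```

Returning squared distances keeps the search cheap and matches FAISS's flat L2 index; taking a square root would not change the ranking.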

<details> <summary>Click me</summary>
@startuml
title Forward Function

!define Tensor(t,d) t + " (" + d + ")"
!define DEVICE "device"

actor "input_vectors" as input_vectors
actor "memory_input" as memory_input

note right of input_vectors
  Tensor:
  (batch_size, seq_len, embedding_dim)
end note

note right of memory_input
  Tensor (optional):
  (batch_size, seq_len, embedding_dim)
end note

input_vectors -> DEVICE
memory_input -> DEVICE

DEVICE -> "search(memory_input)" as search
search --> "indices, distances" as search_result
note right of search_result
  Tensors:
  indices: (batch_size, seq_len, num_search_results)
  distances: (batch_size, seq_len, num_search_results)
end note

search_result -> "get_all_embeddings()" as all_embeddings
note right of all_embeddings
  Tensor:
  (memory_size, embedding_dim)
end note

all_embeddings -> "search_results" as search_results
note right of search_results
  Tensor:
  (batch_size, seq_len, search_results_dim)
end note

search_results --> "concatenate(input_vectors, search_results)" as concatenated_input
note right of concatenated_input
  Tensor:
  (batch_size, seq_len, embedding_dim + search_results_dim)
end note

concatenated_input --> "self.fc(concatenated_input)" as fc_output
note right of fc_output
  Tensor:
  (batch_size, seq_len, embedding_dim)
end note

fc_output --> "self.gpt2_model.transformer(inputs_embeds=input_vectors)" as transformer_outputs
transformer_outputs --> "hidden_states" as hidden_states
note right of hidden_states
  Tensor:
  (batch_size, seq_len, embedding_dim)
end note

hidden_states --> "self.gpt2_model.lm_head(hidden_states)" as logits
note right of logits
  Tensor:
  (batch_size, seq_len, vocab_size)
end note

hidden_states --> "self.fc_storable_vector(hidden_states)" as storable_vector
note right of storable_vector
  Tensor:
  (batch_size, seq_len, memory_dim)
end note

hidden_states --> "self.fc_store_decision(hidden_states)" as store_decision
note right of store_decision
  Tensor:
  (batch_size, seq_len, 1)
end note

hidden_states --> "self.fc_delete_decision(hidden_states)" as delete_decision
note right of delete_decision
  Tensor:
  (batch_size, seq_len, num_search_results)
end note

hidden_states --> "self.fc_deletable_vector(hidden_states)" as deletable_vector
note right of deletable_vector
  Tensor:
  (batch_size, seq_len, memory_dim)
end note

storable_vector --> "self.memory.add(storable_vector_to_store)" as add_memory

deletable_vector --> "calculate L2 distances" as l2_distances
note right of l2_distances
  Tensor:
  (batch_size, num_search_results)
end note

l2_distances --> "threshold comparison" as threshold_comparison
note right of threshold_comparison
  Tensor (bool):
  (batch_size, num_search_results)
end note

threshold_comparison --> "self.memory.remove(indices_to_delete_flat)" as remove_memory

logits --> "return logits" as return_logits

@enduml
</details>
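To make the tensor shapes in the forward diagram concrete, here is a NumPy walk-through with random, untrained weights. It is a sketch under stated assumptions, not the model's code: forward_sketch and make_linear are hypothetical names, an identity map stands in for the GPT-2 transformer, search_results_dim is assumed to equal num_search_results * memory_dim (the retrieved vectors flattened), and the deletion step is assumed to compare only the last sequence position against a fixed delete_threshold. fc_delete_decision is left out because the diagram does not show how its output is combined.

```python
import numpy as np


def make_linear(rng, in_dim, out_dim):
    """Random affine layer x @ W + b standing in for a trained nn.Linear."""
    w = rng.standard_normal((in_dim, out_dim)).astype(np.float32) * 0.02
    b = np.zeros(out_dim, dtype=np.float32)
    return lambda x: x @ w + b


def forward_sketch(input_vectors, memory_embeddings, indices, vocab_size=50,
                   delete_threshold=1.0, seed=0):
    """Shape-level walk through the forward diagram (hypothetical, untrained).

    input_vectors:     (batch, seq, emb_dim)
    memory_embeddings: (memory_size, mem_dim), e.g. get_all_embeddings()
    indices:           (batch, seq, k) returned by searching the memory
    """
    rng = np.random.default_rng(seed)
    batch, seq, emb_dim = input_vectors.shape
    k = indices.shape[-1]
    mem_dim = memory_embeddings.shape[-1]

    # Gather retrieved vectors and flatten them (assumed search_results_dim = k * mem_dim).
    retrieved = memory_embeddings[indices]                       # (batch, seq, k, mem_dim)
    search_results = retrieved.reshape(batch, seq, k * mem_dim)
    concatenated = np.concatenate([input_vectors, search_results], axis=-1)

    fc = make_linear(rng, emb_dim + k * mem_dim, emb_dim)
    # Assumption: identity map in place of self.gpt2_model.transformer.
    hidden_states = fc(concatenated)                             # (batch, seq, emb_dim)

    logits = make_linear(rng, emb_dim, vocab_size)(hidden_states)      # lm_head
    storable_vector = make_linear(rng, emb_dim, mem_dim)(hidden_states)
    store_decision = make_linear(rng, emb_dim, 1)(hidden_states)
    deletable_vector = make_linear(rng, emb_dim, mem_dim)(hidden_states)

    # Assumption: compare the last position's deletable vector with its k retrieved vectors.
    diff = retrieved[:, -1] - deletable_vector[:, -1, None, :]   # (batch, k, mem_dim)
    l2_distances = np.sqrt((diff ** 2).sum(-1))                  # (batch, k)
    delete_mask = l2_distances < delete_threshold                # (batch, k), bool
    indices_to_delete_flat = np.unique(indices[:, -1][delete_mask])

    return {
        "logits": logits,
        "storable_vector": storable_vector,
        "store_decision": store_decision,
        "l2_distances": l2_distances,
        "indices_to_delete_flat": indices_to_delete_flat,
    }
```

In the real model the returned indices_to_delete_flat would be handed to self.memory.remove and the storable vectors gated by store_decision, as in the diagram; here they are only returned so their shapes can be inspected.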

Training, Evaluation, and Fine-tuning Process

<details> <summary>Click me</summary>
@startuml

skinparam activity {
  BackgroundColor LightSkyBlue
  BorderColor Black
  FontName Arial
}

start

:Data Preparation;

partition "FAISS Memory Model" {
  :Create FAISS Index;
  :Encode and Decode Text Data;
  :Test FAISS Index;
}

partition "GPT-2 Model Adaptation" {
  :Load Pre-trained GPT-2 Model;
  :Modify GPT-2 Architecture;
  :Define Custom Loss Function;
}

partition "Training" {
  :Train Adapted GPT-2 Model;
  :Save Model Checkpoints;
}

partition "Evaluation" {
  :Evaluate