Retrieval-Augmented Generation (RAG)
Limitations of LLMs:
Trained on a static dataset, LLMs lack real-time information and updates. They can only generate responses based on the knowledge available up to the last training cut-off date.
Limited by their training data size and capacity, which can lead to missing niche or less frequent information.
Struggle with long context management due to fixed input lengths. They can’t retain or recall long sequences of information effectively.
Prone to generating plausible-sounding but incorrect or misleading information, a phenomenon known as hallucination.
May lack specialized knowledge for niche domains, as they are trained on a broad range of topics.
Can be computationally expensive to fine-tune or retrain for specific tasks or updated information.
How does RAG solve these issues?
RAG incorporates real-time data retrieval, allowing the model to access up-to-date information from external sources. This dynamic retrieval enhances the relevance and accuracy of the responses.
It uses external databases and documents to augment the model’s memory, thus providing information beyond the model’s internal knowledge.
It retrieves relevant documents or data chunks on-demand, effectively bypassing the context length limitation by dynamically providing pertinent information as needed.
RAG reduces hallucinations by grounding responses in retrieved documents, ensuring that the information is more reliable and accurate.
It allows for the incorporation of domain-specific databases and documents, enhancing the model’s capability to handle specialized queries accurately.
It avoids the need for frequent retraining by using retrieval mechanisms to update knowledge, which is more cost-effective.
General RAG Architecture
How does it work?
1. Documents: These are the input documents (e.g., PDFs, web pages) that contain the information the system will use to answer questions. Example: a collection of scientific papers on climate change.
2. Split the documents into smaller chunks: Large documents are divided into smaller, manageable pieces to facilitate processing and retrieval. Example: a 50-page paper might be split into 25 two-page chunks.
3. Use LLMs to convert documents to vector embeddings: Each document chunk is converted into a numerical vector representation using an LLM-based embedding model. Example: a paragraph about rising sea levels might be converted into a 768-dimensional vector.
4. Vector database: The vector embeddings are stored in a specialized database optimized for similarity search. Example: Pinecone, FAISS, or a similar vector database could be used to store and quickly retrieve similar vectors.
5. Query (the question to ask): This is the user's input, typically a question they want answered. Example: "What are the projected effects of climate change on coastal cities by 2050?"
6. Use LLMs to convert the query to a vector embedding: The user's question is converted into a vector embedding using the same process as the documents. Example: the question about coastal cities is transformed into its own vector representation.
7. Search for the top-K similar/relevant documents: The system searches the vector database for the document chunks most similar to the query vector. Example: it might find 5 document chunks discussing sea level rise and urban planning.
8. Retrieved context based on similarity: The most relevant document chunks are retrieved from the database. Example: text excerpts from climate reports discussing coastal city impacts.
9. Query + context from documents (prompt template): The original query is combined with the retrieved context into a prompt for the LLM. Example: a prompt combining the question about coastal cities with relevant excerpts from the retrieved documents. A minimal sketch of such a prompt template is shown right after this list.
10. LLM: The final large language model processes the combined query and context to generate an answer. Example: GPT-4 or a similar model could be used to synthesize the information and answer the question.
11. Answer from the LLM: The system provides a response based on the retrieved context and the LLM's understanding.
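To make step 9 concrete, here is a minimal sketch of how the retrieved chunks and the user's query might be combined into a single prompt. The template wording, the build_prompt helper, and the placeholder excerpts are illustrative assumptions rather than a fixed standard; in the code example later in this post, this assembly is handled internally by LangChain's RetrievalQA "stuff" chain, so the template is not written out explicitly.

# A minimal, illustrative prompt template for step 9 (query + retrieved context).
# The wording of the template is an assumption; real systems tune this heavily.
PROMPT_TEMPLATE = """Answer the question using only the context below.
If the answer is not in the context, say that you don't know.

Context:
{context}

Question: {question}

Answer:"""

def build_prompt(question, retrieved_chunks):
    # Join the top-K retrieved chunks into a single context string.
    context = "\n\n".join(retrieved_chunks)
    return PROMPT_TEMPLATE.format(context=context, question=question)

# Hypothetical usage with placeholder excerpts standing in for retrieved chunks.
prompt = build_prompt(
    "What are the projected effects of climate change on coastal cities by 2050?",
    ["[excerpt from climate report 1]", "[excerpt from climate report 2]"],
)
print(prompt)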
Code Example
I have created a simple RAG application to chat with the "Constitution of Nepal" using Llama 3.1, Meta's recently released open-source model.
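The snippets below assume imports roughly along the following lines. The exact module paths are an assumption and vary between LangChain versions (some classes have moved between langchain, langchain_core, and langchain_community), so adjust them to your installed version.

# Assumed imports for the snippets below; module paths may differ by LangChain version.
import numpy as np
import faiss
from typing import Any, List
from pydantic import BaseModel, Field
from langchain_community.llms import Ollama
from langchain_community.embeddings import OllamaEmbeddings
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains import RetrievalQA
from langchain_core.retrievers import BaseRetriever
from langchain_core.documents import Document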
- Initializing Language Model and Embeddings with Ollama: The code initializes a language model and an embedding model through Ollama, using the specified model name.
MODEL = "llama3.1:8b"
model = Ollama(model=MODEL)
embeddings = OllamaEmbeddings(model=MODEL)
- Loading PDF Document into Text Documents with PyPDFLoader: The code loads a PDF document titled “constitution.pdf” into a collection of text documents using the PyPDFLoader class.
loader = PyPDFLoader("constitution.pdf")
documents = loader.load()
- Splitting Documents into Chunks with RecursiveCharacterTextSplitter: The code configures a RecursiveCharacterTextSplitter with a chunk size of 1000 characters and an overlap of 200 characters, then splits the loaded PDF documents into these text chunks.
chunk_size = 1000
chunk_overlap = 200
text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
chunks = text_splitter.split_documents(documents)
- Generating and Validating Embeddings for Document Chunks: The code iterates through text chunks, embedding each chunk using the OllamaEmbeddings model, and stores the embeddings in a list if they are valid. It handles exceptions and prints errors for invalid chunks. Finally, it stacks the embeddings into a NumPy array and prints the shape of the resulting array, or prints a message if no embeddings were generated.
chunk_embeddings = []
for chunk in chunks:
    try:
        # Ensure each chunk is a string
        if isinstance(chunk.page_content, str):
            embedding = embeddings.embed_query(chunk.page_content)
            if embedding and isinstance(embedding, list) and len(embedding) > 0:
                embedding = np.array(embedding, dtype='float32')
                chunk_embeddings.append(embedding)
            else:
                print(f"Empty or invalid embedding for chunk: {chunk.page_content[:50]}...")
        else:
            print(f"Chunk content is not a string: {type(chunk.page_content)}")
    except Exception as e:
        print(f"Error embedding chunk: {e}")

if chunk_embeddings:
    embeddings_array = np.vstack(chunk_embeddings)
    print(f"Embeddings array shape: {embeddings_array.shape}")
else:
    print("No embeddings generated. Check the embedding process.")
    embeddings_array = np.empty((0, 0), dtype='float32')
- Initializing FAISS Index with GPU or CPU Fallback: The code initializes variables for a FAISS index and checks if the embeddings array is non-empty. It attempts to create a GPU-based FAISS index using faiss.GpuIndexFlatL2. If successful, it adds the embeddings in batches to the GPU index and sets is_gpu_index to True. If GPU indexing fails, it falls back to creating a CPU-based FAISS index using faiss.IndexFlatL2 and adds the embeddings to this index instead. Any exceptions during these processes are caught and logged.
# Initialize FAISS index
faiss_index = None
is_gpu_index = False

if embeddings_array.size > 0:
    try:
        dim = embeddings_array.shape[1]  # Dimension of embeddings

        # Try GPU first
        try:
            res = faiss.StandardGpuResources()
            config = faiss.GpuIndexFlatConfig()
            config.device = 0  # Use GPU 0
            gpu_index = faiss.GpuIndexFlatL2(res, dim, config)

            # Add embeddings in batches
            batch_size = 100  # Adjust this based on your GPU memory
            for i in range(0, embeddings_array.shape[0], batch_size):
                batch = embeddings_array[i:i+batch_size]
                gpu_index.add(batch)

            faiss_index = gpu_index
            is_gpu_index = True
            print("FAISS GPU index initialized and embeddings added.")
        except Exception as e:
            print(f"GPU indexing failed: {e}")
            print("Falling back to CPU indexing.")
            # CPU indexing
            cpu_index = faiss.IndexFlatL2(dim)
            cpu_index.add(embeddings_array)
            faiss_index = cpu_index
            print("FAISS CPU index initialized and embeddings added.")
    except Exception as e:
        print(f"Error initializing FAISS index: {e}")
else:
    print("Empty embeddings array. Cannot initialize FAISS index.")
- Creating a FAISS-based Document Retriever Class: The FaissRetriever class integrates with FAISS for document retrieval. It is initialized with a FAISS index, a list of document chunks, and an embedding model. The get_relevant_documents method calls the private _retrieve method to fetch documents. In _retrieve, the query is embedded into a vector, and FAISS is used to find the top k closest documents. The results are then assembled into Document objects using the indices returned by FAISS. Errors during retrieval are handled and logged. The asynchronous method aget_relevant_documents simply wraps the synchronous get_relevant_documents method for asynchronous use.
class FaissRetriever(BaseRetriever, BaseModel):
    index: Any = Field(description="FAISS index")
    chunks: List[Any] = Field(description="List of document chunks")
    embedding_model: Any = Field(description="Embedding model")

    class Config:
        arbitrary_types_allowed = True

    def __init__(self, **data):
        super().__init__(**data)
        print(f"FaissRetriever initialized with index: {self.index is not None}, chunks: {len(self.chunks)}, embedding_model: {self.embedding_model is not None}")

    def get_relevant_documents(self, query: str):
        return self._retrieve(query)

    def _retrieve(self, query, k=5):
        try:
            print(f"Query: {query}")
            query_embedding = self.embedding_model.embed_query(query)
            query_embedding = np.array(query_embedding).astype('float32').reshape(1, -1)
            print(f"Query embedding shape: {query_embedding.shape}")
            distances, indices = self.index.search(query_embedding, k)
            results = [Document(page_content=self.chunks[i].page_content, metadata=self.chunks[i].metadata) for i in indices[0]]
            return results
        except Exception as e:
            print(f"Error retrieving documents: {e}")
            return []

    async def aget_relevant_documents(self, query: str):
        return self.get_relevant_documents(query)
- Running QA Chain with FAISS Index for Query Processing: The code checks if the FAISS index (faiss_index) is initialized. If so, it creates a FaissRetriever instance with the FAISS index, document chunks, and embedding model. It then sets up a RetrievalQA chain using this retriever and a language model. The QA chain is used to query information about the office of the prime minister, and the resulting answer is printed. If an error occurs during execution, it is caught, logged, and the full traceback is printed. If the FAISS index is not initialized, a message is displayed indicating that the QA chain cannot be run.
if faiss_index is not None:
    try:
        retriever = FaissRetriever(index=faiss_index, chunks=chunks, embedding_model=embeddings)
        print(f"Using {'GPU' if is_gpu_index else 'CPU'} FAISS index")

        qa_chain = RetrievalQA.from_chain_type(
            llm=model,
            chain_type="stuff",
            retriever=retriever,
            return_source_documents=True
        )

        # Query the system
        query = "Under what circumstances does the office of the prime minister become vacant?"
        result = qa_chain.invoke({"query": query})
        print("Answer:", result['result'])

        # print("\nSource Documents:")
        # for doc in result['source_documents']:
        #     print(doc.page_content[:100] + "...")  # Print first 100 characters of each source document
    except Exception as e:
        print(f"Error running QA chain: {e}")
        import traceback
        traceback.print_exc()  # This will print the full error traceback
else:
    print("Cannot run QA chain as FAISS index is not initialized.")
Full code can be found here: LocalRAG
Resources
https://arxiv.org/pdf/2005.11401
https://blogs.nvidia.com/blog/what-is-retrieval-augmented-generation/
https://www.databricks.com/glossary/retrieval-augmented-generation-rag
https://arxiv.org/pdf/2312.10997
https://www.youtube.com/watch?v=rhZgXNdhWDY&t=2551s&ab_channel=UmarJamil
https://www.youtube.com/watch?v=HRvyei7vFSM&ab_channel=Underfitted