Retrieval Augmented Generation (RAG)

Limitations of LLMs:

  • Trained on a static dataset, LLMs lack real-time information and updates; they can only generate responses based on the knowledge available up to their training cut-off date.

  • Limited by the size and coverage of their training data, which can lead to niche or less frequently discussed information being missed.

  • Struggle with long context management due to fixed input lengths. They can’t retain or recall long sequences of information effectively.

  • Prone to generating plausible-sounding but incorrect or misleading information, a phenomenon known as hallucination.

  • May lack specialized knowledge for niche domains, as they are trained on a broad range of topics.

  • Can be computationally expensive to fine-tune or retrain for specific tasks or updated information.

How does RAG solve these issues?

  • RAG incorporates real-time data retrieval, allowing the model to access up-to-date information from external sources. This dynamic retrieval enhances the relevance and accuracy of the responses.

  • It uses external databases and documents to augment the model’s memory, thus providing information beyond the model’s internal knowledge.

  • It retrieves relevant documents or data chunks on-demand, effectively bypassing the context length limitation by dynamically providing pertinent information as needed.

  • RAG reduces hallucinations by grounding responses in retrieved documents, making the generated information more reliable and accurate.

  • It allows for the incorporation of domain-specific databases and documents, enhancing the model’s capability to handle specialized queries accurately.

  • It avoids the need for frequent retraining by using retrieval mechanisms to update knowledge, which is more cost-effective.

General RAG Architecture

How does it work?

1. Documents: These are the input documents (e.g., PDFs, web pages) that contain the information the system will use to answer questions. Example: a collection of scientific papers on climate change.

2. Split the documents into smaller chunks: Large documents are divided into smaller, manageable pieces to make processing and retrieval easier. Example: a 50-page paper might be split into 25 two-page chunks.

3. Use LLMs to convert documents to vector embeddings: Each document chunk is converted into a numerical vector representation using an LLM-based embedding model. Example: a paragraph about rising sea levels might be converted into a 768-dimensional vector.

4. Vector database: The vector embeddings are stored in a specialized database optimized for similarity search. Example: Pinecone, FAISS, or a similar vector database could be used to store and quickly retrieve similar vectors.

5. Query (the question to ask): This is the user’s input, typically a question they want answered. Example: “What are the projected effects of climate change on coastal cities by 2050?”

6. Use LLMs to convert the query to a vector embedding: The user’s question is converted into a vector embedding using the same process as the documents. Example: the question about coastal cities would be transformed into a vector representation.

7. Search for the top-K similar/relevant documents: The system searches the vector database for the document chunks most similar to the query vector. Example: it might find 5 document chunks discussing sea level rise and urban planning.

8. Retrieved context based on similarity: The most relevant document chunks are retrieved from the database. Example: text excerpts from climate reports discussing coastal city impacts.

9. Query + context from documents (prompt template): The original query is combined with the retrieved context into a prompt for the LLM (see the minimal sketch after this list). Example: a prompt combining the question about coastal cities with relevant excerpts from the retrieved documents.

10. LLM: The language model processes the combined query and context to generate an answer. Example: GPT-4 or a similar model could be used to synthesize the information and answer the question.

11. Answer from the LLM: The system returns a response grounded in the retrieved context and the LLM’s understanding.
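
As a minimal sketch of step 9, the snippet below shows one way the query and the retrieved chunks could be combined into a prompt. The template wording and the build_prompt helper are illustrative assumptions, not part of any specific library.

PROMPT_TEMPLATE = """Answer the question based only on the context below.
If the context does not contain the answer, say that you don't know.

Context:
{context}

Question: {question}
"""

def build_prompt(question, retrieved_chunks):
    # Join the top-K retrieved chunks into a single context string
    context = "\n\n".join(retrieved_chunks)
    return PROMPT_TEMPLATE.format(context=context, question=question)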

Code Example

I have created a simple RAG application to chat with the “Constitution of Nepal” using Llama 3.1, a recently released open-source model from Meta.
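
The snippets below assume roughly the following imports. The module paths follow the LangChain community packages at the time of writing and may differ slightly between versions.

import numpy as np
import faiss
from typing import Any, List
from pydantic import BaseModel, Field
from langchain_community.llms import Ollama
from langchain_community.embeddings import OllamaEmbeddings
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_core.retrievers import BaseRetriever
from langchain_core.documents import Document
from langchain.chains import RetrievalQA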

  1. Initializing Language Model and Embeddings with Ollama: The code initializes a language model and its embeddings using the Ollama library with a specified model.
MODEL = "llama3.1:8b"

# Local Llama 3.1 (8B) served by Ollama, used for both generation and embeddings
model = Ollama(model=MODEL)
embeddings = OllamaEmbeddings(model=MODEL)
  2. Loading PDF Document into Text Documents with PyPDFLoader: The code loads a PDF document titled “constitution.pdf” into a collection of text documents using the PyPDFLoader class.
# Load the PDF; each page becomes a separate Document
loader = PyPDFLoader("constitution.pdf")
documents = loader.load()
  3. Splitting Documents into Chunks with RecursiveCharacterTextSplitter: The code configures a RecursiveCharacterTextSplitter with a chunk size of 1000 characters and an overlap of 200 characters, then splits the loaded PDF documents into these text chunks.
chunk_size = 1000
chunk_overlap = 200

# Split the documents into overlapping chunks so context is not lost at chunk boundaries
text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
chunks = text_splitter.split_documents(documents)
    
  4. Generating and Validating Embeddings for Document Chunks: The code iterates through text chunks, embedding each chunk using the OllamaEmbeddings model, and stores the embeddings in a list if they are valid. It handles exceptions and prints errors for invalid chunks. Finally, it stacks the embeddings into a NumPy array and prints the shape of the resulting array, or prints a message if no embeddings were generated.
chunk_embeddings = []

for chunk in chunks:
    try:
        # Ensure each chunk is a string
        if isinstance(chunk.page_content, str):
            embedding = embeddings.embed_query(chunk.page_content)
            if embedding and isinstance(embedding, list) and len(embedding) > 0:
                embedding = np.array(embedding, dtype='float32')
                chunk_embeddings.append(embedding)
            else:
                print(f"Empty or invalid embedding for chunk: {chunk.page_content[:50]}...")
        else:
            print(f"Chunk content is not a string: {type(chunk.page_content)}")
    except Exception as e:
        print(f"Error embedding chunk: {e}")

if chunk_embeddings:
    embeddings_array = np.vstack(chunk_embeddings)
    print(f"Embeddings array shape: {embeddings_array.shape}")
else:
    print("No embeddings generated. Check the embedding process.")
    embeddings_array = np.empty((0, 0), dtype='float32')
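As an optional variation (not part of the original code), the per-chunk loop above could be replaced by a single batched call, assuming the standard LangChain embed_documents method that OllamaEmbeddings implements:
# Optional batched alternative to the per-chunk loop above
texts = [chunk.page_content for chunk in chunks]
vectors = embeddings.embed_documents(texts)  # one call for all chunks
embeddings_array = np.array(vectors, dtype='float32')
print(f"Embeddings array shape: {embeddings_array.shape}")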
  5. Initializing FAISS Index with GPU or CPU Fallback: The code initializes variables for a FAISS index and checks if the embeddings array is non-empty. It attempts to create a GPU-based FAISS index using faiss.GpuIndexFlatL2. If successful, it adds the embeddings in batches to the GPU index and sets is_gpu_index to True. If GPU indexing fails, it falls back to creating a CPU-based FAISS index using faiss.IndexFlatL2 and adds the embeddings to this index instead. Any exceptions during these processes are caught and logged.
# Initialize FAISS index
faiss_index = None
is_gpu_index = False

if embeddings_array.size > 0:
    try:
        dim = embeddings_array.shape[1]  # Dimension of embeddings
        
        # Try GPU first
        try:
            res = faiss.StandardGpuResources()
            config = faiss.GpuIndexFlatConfig()
            config.device = 0  # Use GPU 0
            gpu_index = faiss.GpuIndexFlatL2(res, dim, config)
            
            # Add embeddings in batches
            batch_size = 100  # Adjust this based on your GPU memory
            for i in range(0, embeddings_array.shape[0], batch_size):
                batch = embeddings_array[i:i+batch_size]
                gpu_index.add(batch)
            
            faiss_index = gpu_index
            is_gpu_index = True
            print("FAISS GPU index initialized and embeddings added.")
        except Exception as e:
            print(f"GPU indexing failed: {e}")
            print("Falling back to CPU indexing.")
            
            # CPU indexing
            cpu_index = faiss.IndexFlatL2(dim)
            cpu_index.add(embeddings_array)
            faiss_index = cpu_index
            print("FAISS CPU index initialized and embeddings added.")
        
    except Exception as e:
        print(f"Error initializing FAISS index: {e}")
else:
    print("Empty embeddings array. Cannot initialize FAISS index.")
  6. Creating a FAISS-based Document Retriever Class: The FaissRetriever class integrates with FAISS for document retrieval. It is initialized with a FAISS index, a list of document chunks, and an embedding model. The get_relevant_documents method calls the private _retrieve method to fetch documents. In _retrieve, the query is embedded into a vector, and FAISS is used to find the top k closest documents. The results are then assembled into Document objects using the indices returned by FAISS. Errors during retrieval are handled and logged. The asynchronous method aget_relevant_documents simply wraps the synchronous get_relevant_documents method for asynchronous use.
class FaissRetriever(BaseRetriever, BaseModel):
    index: Any = Field(description="FAISS index")
    chunks: List[Any] = Field(description="List of document chunks")
    embedding_model: Any = Field(description="Embedding model")

    class Config:
        arbitrary_types_allowed = True

    def __init__(self, **data):
        super().__init__(**data)
        print(f"FaissRetriever initialized with index: {self.index is not None}, chunks: {len(self.chunks)}, embedding_model: {self.embedding_model is not None}")

    def get_relevant_documents(self, query: str):
        return self._retrieve(query)

    def _retrieve(self, query, k=5):
        try:
            print(f"Query: {query}")
            query_embedding = self.embedding_model.embed_query(query)
            query_embedding = np.array(query_embedding).astype('float32').reshape(1, -1)
            print(f"Query embedding shape: {query_embedding.shape}")
            distances, indices = self.index.search(query_embedding, k)
            results = [Document(page_content=self.chunks[i].page_content, metadata=self.chunks[i].metadata) for i in indices[0]]
            return results
        except Exception as e:
            print(f"Error retrieving documents: {e}")
            return []

    async def aget_relevant_documents(self, query: str):
        return self.get_relevant_documents(query)
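Before wiring the retriever into a chain, it can be sanity-checked on its own; the query string below is only an illustrative placeholder:
retriever = FaissRetriever(index=faiss_index, chunks=chunks, embedding_model=embeddings)
docs = retriever.get_relevant_documents("fundamental rights of citizens")  # illustrative test query
print(f"{len(docs)} chunks retrieved")
if docs:
    print(docs[0].page_content[:100])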
  7. Running QA Chain with FAISS Index for Query Processing: The code checks if the FAISS index (faiss_index) is initialized. If so, it creates a FaissRetriever instance with the FAISS index, document chunks, and embedding model. It then sets up a RetrievalQA chain using this retriever and a language model. The QA chain is used to query information about the office of the Prime Minister, and the resulting answer is printed. If an error occurs during execution, it is caught, logged, and the full traceback is printed. If the FAISS index is not initialized, a message is displayed indicating that the QA chain cannot be run.
if faiss_index is not None:
    try:
        retriever = FaissRetriever(index=faiss_index, chunks=chunks, embedding_model=embeddings)
        print(f"Using {'GPU' if is_gpu_index else 'CPU'} FAISS index")
        
        qa_chain = RetrievalQA.from_chain_type(
            llm=model,
            chain_type="stuff",
            retriever=retriever,
            return_source_documents=True
        )

        # Query the system
        query = "Under what circumstances do the office of prime minister be vacant?"
        result = qa_chain.invoke({"query": query})
        print("Answer:", result['result'])
        # print("\nSource Documents:")
        # for doc in result['source_documents']:
        #     print(doc.page_content[:100] + "...")  # Print first 100 characters of each source document
    except Exception as e:
        print(f"Error running QA chain: {e}")
        import traceback
        traceback.print_exc()  # This will print the full error traceback
else:
    print("Cannot run QA chain as FAISS index is not initialized.")

Full code can be found here: LocalRAG

Resources

https://arxiv.org/pdf/2005.11401

https://blogs.nvidia.com/blog/what-is-retrieval-augmented-generation/

https://www.databricks.com/glossary/retrieval-augmented-generation-rag

https://arxiv.org/pdf/2312.10997

https://www.youtube.com/watch?v=rhZgXNdhWDY&t=2551s&ab_channel=UmarJamil

https://www.youtube.com/watch?v=HRvyei7vFSM&ab_channel=Underfitted