
RAG

A RAG connects a list of Resources with a generation chain.

It generates a response based on the documents retrieved from these resources.

---
title: RAG class principles
---
graph LR
input(User input) -->|inject in 'question'| g
input --> r1
input --> r2
input --> r3
subgraph resources
r1[(Resource 1)] --> ag((Aggregate <br>all documents))
r2[(Resource 2)] --> ag
r3[(Resource 3)] --> ag
end
subgraph generation_chain
ag -->|inject in 'context'| g[Answer generation chain]
end
subgraph extras
g --> p[extras processing]
end
p -->|output| eog((" "))
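
The flow in the diagram can be mimicked in plain Python. This is a minimal sketch with several assumptions: a resource is treated as any callable that maps the user input to a list of document strings, and the generation chain as any callable that receives a dict with 'question' and 'context'. `run_rag`, `faq`, `wiki` and `toy_generation` are hypothetical names, not part of the rag package, and extras processing is left out.

```python
from typing import Callable

# Hypothetical stand-ins, not the rag package API:
# a resource maps the user input to a list of documents,
# the generation chain maps {"question", "context"} to an answer.
Resource = Callable[[str], list[str]]
Generation = Callable[[dict], str]


def run_rag(question: str, resources: list[Resource], generate: Generation) -> dict:
    # Send the user input to every resource and aggregate all documents
    context = [doc for resource in resources for doc in resource(question)]
    # Inject the input as 'question' and the documents as 'context'
    inputs = {"question": question, "context": context}
    return {**inputs, "answer": generate(inputs)}


def faq(query: str) -> list[str]:
    return ["doc about pricing"]


def wiki(query: str) -> list[str]:
    return ["doc about history", "doc about pricing v2"]


def toy_generation(inputs: dict) -> str:
    # A real generation chain would prompt an LLM with these two fields
    return f"{len(inputs['context'])} documents about: {inputs['question']}"


result = run_rag("pricing?", [faq, wiki], toy_generation)
print(result["answer"])  # 3 documents about: pricing?
```

Each resource only sees the user input, never the other resources' results: aggregation happens once all of them have answered.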

RAG can't be a chat

A RAG is stateless: it keeps no memory of past interactions or retrievals. It cannot be used for a chat application without significant improvements.

Langchain architecture

A RAG is composed of 2 main chains and a collection of extras.

graph LR
input(Input) --> pass[RunnablePassthrough]
input --> r1
input --> r2
pass --> |question| gen
subgraph resources_retrieval_chain
r1[Retriever 1] --> ag["RunnableLambda <br>(Aggregate)"]
r2[Retriever 2] --> ag
end
subgraph generation_chain
ag -->|context| gen["Runnable <br> (Generation chain)"]
end
subgraph extras
gen -->|answer| xtra["RunnableLambda<br> (Extras) "]
end
xtra --> output("Outputs (dict) <br> {'question': ...<br>'answer':...<br>'context':...<br> 'extra1':...<br> 'extra2':...}")

  • resources_retrieval_chain: property automatically generated from the list of resources.
    Outputs a list of documents similar to the user query (cosine similarity by default).
  • generation_chain: class attribute. The chain generating the response from the instructions defined in the prompt.
    Outputs a string.
  • extras: class attribute. A dictionary of Callables (i.e. Python functions). Langchain automatically converts each Callable into a RunnableLambda at build time.
    The output of each function is added to the output dictionary under its key.
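
The extras step can be pictured as follows: every callable receives the dict produced by the generation step, and its return value is stored under its own key. This plain-Python sketch assumes exactly that (`apply_extras` and `answer_length` are hypothetical); the actual class wraps the callables in RunnableLambda objects instead. Whether one extra can see another's result is not specified here, so the sketch computes them all from the same input.

```python
from typing import Any, Callable

Extra = Callable[[dict], Any]


def apply_extras(outputs: dict, extras: dict[str, Extra]) -> dict:
    """Return a copy of outputs with one additional key per extra."""
    # All extras receive the same dict, so none sees another's result
    computed = {name: fn(outputs) for name, fn in extras.items()}
    return {**outputs, **computed}


generation_outputs = {
    "question": "Who has been the strongest man ever?",
    "answer": "Chuck Norris by far...",
    "context": ["doc 1", "doc 2", "doc 3", "doc 4"],
}
extras = {
    "count_docs": lambda x: len(x["context"]),  # extra used in the Example section
    "answer_length": lambda x: len(x["answer"]),  # hypothetical extra
}
final = apply_extras(generation_outputs, extras)
print(final["count_docs"], final["answer_length"])  # 4 22
```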

Example

from rag.core.generator import RAG
from rag.generation.rag_with_source import rag_chain_from_docs

rag = RAG(
    resources=[my_first_resource, my_other_resource],  # Resource objects defined elsewhere
    generation_chain=rag_chain_from_docs(),
    extras={"count_docs": lambda x: len(x["context"])},  # adds "count_docs" to the output dict
)

Chain

RAG implements a method named chain exposing the consolidated Runnable:

>>> type(rag.chain())
<class 'langchain_core.runnables.base.RunnableSequence'>

It can be used as any regular Langchain Runnable:
>>> rag.chain().invoke("Who has been the strongest man ever?")
{
    "question": "Who has been the strongest man ever?",
    "answer": "Chuck Norris by far...",
    "context":[Document(page_content= ..., metadata = ...), ...],
    "count_docs": 4
}

Configure the chain

You can customize some parameters of your chains by passing arguments to the .chain() method. The method accepts two configuration arguments:

  • retrieval_chain: a dictionary whose keys are resource names and whose values are dictionaries of kwargs for that Resource's retriever.
  • generation_chain: a dictionary of kwargs used to configure the generation chain.

rag.get_chain(retrieval_chain= {
    "my_resource_name": {
        "search_kwargs":{"k": 10} # Retrieve the top 10 documents
    }
})
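
With a similarity-based retriever, the k entry of search_kwargs sets how many of the most similar documents a resource returns. The plain-Python sketch below mimics such a top-k search with cosine similarity over hand-made vectors; in practice the vector store behind a Resource performs the search itself, and the default of k=4 used here is an assumption.

```python
import math


def cosine(a: list[float], b: list[float]) -> float:
    """Cosine similarity between two vectors (0.0 if either is all zeros)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def top_k(query_vec: list[float], docs: list[tuple[str, list[float]]], k: int = 4) -> list[str]:
    """Return the k document texts most similar to query_vec."""
    ranked = sorted(docs, key=lambda doc: cosine(query_vec, doc[1]), reverse=True)
    return [text for text, _ in ranked[:k]]


# Hand-made 2-D "embeddings"; a real resource embeds text with a model
docs = [
    ("opposite", [-1.0, 0.0]),
    ("same", [1.0, 0.0]),
    ("orthogonal", [0.0, 1.0]),
    ("close", [0.7, 0.7]),
]
query = [1.0, 0.0]
print(top_k(query, docs, k=2))   # ['same', 'close']
print(top_k(query, docs, k=10))  # all 4 documents: k exceeds the corpus size
```

Raising k gives the LLM more context at the cost of a longer prompt and possibly less relevant documents.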

Metadata filtering

You can filter retrieved documents on their metadata using the search_kwargs argument.

rag.get_chain(retrieval_chain= {
    "my_resource_name": {
        "search_kwargs":{"filter": {
            "type": {"$eq": "edito"},
            "year": {"$gt": 2024}
        }}
    }
})
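
The filter above keeps documents whose type equals "edito" and whose year is strictly greater than 2024. The sketch below evaluates such a filter in plain Python to make those semantics explicit. It assumes the operators behave like the usual comparisons and that several top-level keys are combined with a logical AND, as in the example; the real evaluation happens inside the vector store, and some stores require an explicit $and to combine conditions. `matches` is a hypothetical helper.

```python
import operator
from typing import Any

# Assumed meaning of the comparison operators
OPERATORS = {
    "$eq": operator.eq,
    "$ne": operator.ne,
    "$gt": operator.gt,
    "$gte": operator.ge,
    "$lt": operator.lt,
    "$lte": operator.le,
}


def matches(metadata: dict[str, Any], filter_: dict[str, dict[str, Any]]) -> bool:
    """True if metadata satisfies every condition in filter_ (implicit AND)."""
    for field, condition in filter_.items():
        if field not in metadata:
            return False  # a missing field cannot satisfy a condition
        for op_name, expected in condition.items():
            if not OPERATORS[op_name](metadata[field], expected):
                return False
    return True


flt = {"type": {"$eq": "edito"}, "year": {"$gt": 2024}}
metadatas = [
    {"type": "edito", "year": 2025},  # kept
    {"type": "edito", "year": 2024},  # dropped: $gt is strict
    {"type": "news", "year": 2025},   # dropped: wrong type
    {"type": "edito"},                # dropped: no year
]
kept = [m for m in metadatas if matches(m, flt)]
print(kept)  # [{'type': 'edito', 'year': 2025}]
```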

Note

Metadata filters are defined at instantiation: all inferences executed with a given RAG instance use the same filters.

As the API instantiates a new RAG on every call, this is not expected to be a limitation in production.

However, keep this behavior in mind if a RAG Python object ends up being persisted (e.g. with pickle).
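
To see why persistence matters, here is a sketch with a hypothetical FilteredRAG dataclass standing in for a RAG instance whose filters are fixed at creation. A pickled copy restores exactly the same filters, so it keeps applying them to every inference, even once they are outdated.

```python
import pickle
from dataclasses import dataclass, field


@dataclass
class FilteredRAG:
    """Hypothetical stand-in for a RAG instance holding metadata filters."""

    filters: dict = field(default_factory=dict)

    def search_kwargs(self) -> dict:
        # Every inference made with this instance reuses the same filters
        return {"filter": self.filters}


original = FilteredRAG(filters={"year": {"$gt": 2024}})

# Persist and reload the object, as a cache or a job queue might do
restored = pickle.loads(pickle.dumps(original))
print(restored.search_kwargs())  # {'filter': {'year': {'$gt': 2024}}}
```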

Custom retrieval strategy

resources_retrieval_chain can be overridden to customize how the context is injected into your generation prompt.

For instance, you may want to rerank or filter your documents before passing them to the LLM. Here is how to do so.

Example inspired by the Langchain contextual compression documentation:

from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import LLMChainFilter
from langchain_core.language_models import BaseLanguageModel
from langchain_core.runnables import RunnableLambda, RunnableParallel

from rag.core.generator import RAG


class RerankedRAG(RAG):

    llm: BaseLanguageModel  # model used by LLMChainFilter to judge relevance

    @property
    def resources_retrieval_chain(self) -> ContextualCompressionRetriever:
        """Retrieve documents from all resources in parallel, aggregate them,
        then filter them with LLMChainFilter."""
        _filter = LLMChainFilter.from_llm(self.llm)

        # Query every resource retriever in parallel, then flatten the
        # per-resource lists into a single list of documents
        retriever = RunnableParallel(
            {resource.name: resource.retriever for resource in self.resources}
        ) | RunnableLambda(lambda x: [doc for docs in x.values() for doc in docs])

        # Drop the documents the LLM considers irrelevant to the query
        return ContextualCompressionRetriever(
            base_compressor=_filter,
            base_retriever=retriever,
        )

That's it: the rest of your code remains unchanged. Such a strategy can reduce hallucinations and approximate answers by removing irrelevant context before it reaches the LLM.
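
Conceptually, the overridden chain aggregates the per-resource results, then keeps only the documents judged relevant to the query. The plain-Python sketch below illustrates these steps. LLMChainFilter asks an LLM whether each document is relevant; here a toy word-overlap judge stands in for the LLM, and all function names are hypothetical. LLMChainFilter only keeps or drops documents, so the rerank step is an extra illustration of the reranking mentioned above.

```python
from typing import Callable

# Hypothetical judge: (query, document) -> is the document relevant?
Judge = Callable[[str, str], bool]


def aggregate(results_per_resource: dict[str, list[str]]) -> list[str]:
    # Same flattening as the RunnableLambda in RerankedRAG
    return [doc for docs in results_per_resource.values() for doc in docs]


def overlap(query: str, doc: str) -> int:
    # Toy relevance score: number of lowercase words shared with the query
    return len(set(query.lower().split()) & set(doc.lower().split()))


def keyword_judge(query: str, doc: str) -> bool:
    # Stand-in for the LLM's yes/no relevance answer
    return overlap(query, doc) > 0


def filter_docs(query: str, docs: list[str], is_relevant: Judge) -> list[str]:
    return [doc for doc in docs if is_relevant(query, doc)]


def rerank(query: str, docs: list[str]) -> list[str]:
    # Optional extra step: most relevant documents first
    return sorted(docs, key=lambda doc: overlap(query, doc), reverse=True)


results = {
    "wiki": ["Strongest man records", "Cooking pasta"],
    "news": ["Who is the strongest man alive", "Weather today"],
}
query = "the strongest man"
docs = aggregate(results)
relevant = filter_docs(query, docs, keyword_judge)
print(rerank(query, relevant))
# ['Who is the strongest man alive', 'Strongest man records']
```

Swapping keyword_judge for a call to an LLM or a cross-encoder changes the quality of the judgement, not the shape of the pipeline.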