# RAG
A RAG connects a list of `Resource`s with a generation chain: it generates a response based on the list of retrieved documents.
```mermaid
---
title: RAG class principles
---
graph LR
input(User input) -->|inject in 'question'| g
input --> r1
input --> r2
input --> r3
subgraph resources
r1[(Resource 1)] --> ag((Aggregate <br>all documents))
r2[(Resource 2)] --> ag
r3[(Resource 3)] --> ag
end
subgraph generation_chain
ag -->|inject in 'context'| g[Answer generation chain]
end
subgraph extras
g --> p[extras processing]
end
p -->|output| eog((" "))
```
> **Warning: a RAG can't be a chat**
>
> A RAG is stateless: it keeps no memory of past interactions or retrievals. It cannot power a chat application without significant improvements.
## Langchain architecture
A RAG is composed of 2 main chains and a collection of extras.
```mermaid
graph LR
input(Input) --> pass[RunnablePassthrough]
input --> r1
input --> r2
pass --> |question| gen
subgraph resources_retrieval_chain
r1[Retriever 1] --> ag["RunnableLambda <br>(Aggregate)"]
r2[Retriever 2] --> ag
end
subgraph generation_chain
ag -->|context| gen["Runnable <br> (Generation chain)"]
end
subgraph extras
gen -->|answer| xtra["RunnableLambda<br> (Extras) "]
end
xtra --> output("Outputs (dict) <br> {'question': ...<br>'answer':...<br>'context':...<br> 'extra1':...<br> 'extra2':...}")
```
- `resources_retrieval_chain`: a property automatically generated from the list of resources. Outputs a list of documents similar to the user query (by default, cosine similarity).
- `generation_chain`: a class attribute. The chain generating the response from the instructions defined in the prompt. Outputs a string.
- `extras`: a class attribute. A dictionary of `Callable`s (i.e. plain Python functions). Langchain automatically converts each `Callable` into a `RunnableLambda` at build time. The output of each function is added to the output dictionary.
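
To make this wiring concrete, here is a minimal, self-contained sketch of how such a chain can be composed from Langchain primitives. The two retrievers, the generation step, and the extra below are stand-in `RunnableLambda`s, not the actual `RAG` internals:

```python
from langchain_core.runnables import (
    RunnableLambda,
    RunnableParallel,
    RunnablePassthrough,
)

# Stand-ins for two Resource retrievers (normally real vector-store retrievers).
retriever_1 = RunnableLambda(lambda q: [f"doc about '{q}' from resource 1"])
retriever_2 = RunnableLambda(lambda q: [f"doc about '{q}' from resource 2"])

# resources_retrieval_chain: query all resources in parallel, then aggregate.
resources_retrieval_chain = RunnableParallel(
    r1=retriever_1, r2=retriever_2
) | RunnableLambda(lambda x: [doc for docs in x.values() for doc in docs])

# generation_chain: stand-in for the real `prompt | llm | parser` chain.
generation_chain = RunnableLambda(
    lambda d: f"Answer generated from {len(d['context'])} documents."
)

# extras: plain Callables, converted to RunnableLambda at build time.
extras = {"count_docs": lambda d: len(d["context"])}

chain = (
    RunnableParallel(question=RunnablePassthrough(), context=resources_retrieval_chain)
    | RunnablePassthrough.assign(answer=generation_chain)
    | RunnablePassthrough.assign(**extras)
)

chain.invoke("Who has been the strongest man ever?")
# -> {'question': ..., 'context': [...], 'answer': ..., 'count_docs': 2}
```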
## Example
### Chain
`RAG` implements a property named `chain` that exposes the consolidated `Runnable`.
It can be called as a regular Langchain chain:
```python
>>> rag.chain().invoke("Who has been the strongest man ever?")
{
    "question": "Who has been the strongest man ever?",
    "answer": "Chuck Norris by far...",
    "context": [Document(page_content=..., metadata=...), ...],
    "count_docs": 4
}
```
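
The `count_docs` key above is produced by an extra. An extra is a plain function receiving the output dictionary; a minimal sketch, assuming extras are passed the consolidated output dict shown above:

```python
def count_docs(output: dict) -> int:
    # Receives the chain's output dict ("question", "answer", "context", ...)
    # and returns a value merged into the final output under "count_docs".
    return len(output["context"])
```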
### Configure the chain
You can customize some parameters of your chains by providing arguments to the `get_chain()` method.
The method accepts two configuration arguments:
- `retrieval_chain`: accepts a dictionary where keys are resource names and values are dictionaries containing the kwargs for the `Resource` retriever.
- `generation_chain`: accepts a dictionary of kwargs used to configure the generation chain.
```python
rag.get_chain(retrieval_chain={
    "my_resource_name": {
        "search_kwargs": {"k": 10}  # Retrieve the top 10 documents
    }
})
```
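
Both arguments can be combined in a single call. The kwargs accepted by `generation_chain` depend on how your generation chain is built; the `temperature` key below is purely illustrative:

```python
rag.get_chain(
    retrieval_chain={"my_resource_name": {"search_kwargs": {"k": 10}}},
    generation_chain={"temperature": 0.0},  # illustrative kwarg, adapt to your chain
)
```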
### Metadata filtering
One can filter retrieved documents based on metadata using the `search_kwargs` argument. The supported filter operators depend on the underlying vector store; the `$eq`/`$gt` syntax below follows the convention used by stores such as Chroma.
```python
rag.get_chain(retrieval_chain={
    "my_resource_name": {
        "search_kwargs": {"filter": {
            "type": {"$eq": "edito"},
            "year": {"$gt": 2024}
        }}
    }
})
```
> **Note**
>
> Metadata filters are defined at instantiation: all inferences executed with a given `RAG` instance will be performed with the same filters.
> Since the API instantiates a new `RAG` at every call, this is not expected to be a limitation in production.
> One should however consider this behavior if the `RAG` Python object ends up being persisted (e.g. with pickle).
### Custom retrieval strategy
`resources_retrieval_chain` can be overridden to customize the way the context is injected into your generation prompt.
For instance, you may want to rerank or filter your documents before passing them to the LLM. Let's see how to do so quite easily.
The example below is inspired by the contextual compression documentation.
```python
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import LLMChainFilter
from langchain_core.language_models.llms import LLM
from langchain_core.runnables import RunnableLambda, RunnableParallel

from rag.core.generator import RAG


class RerankedRAG(RAG):
    llm: LLM

    @property
    def resources_retrieval_chain(self) -> ContextualCompressionRetriever:
        """Retrieve documents from all resources in parallel, then aggregate
        them and finally filter them using LLMChainFilter."""
        _filter = LLMChainFilter.from_llm(self.llm)
        # Query every resource in parallel, then flatten the per-resource
        # lists of documents into a single list.
        retriever = RunnableParallel(
            {resource.name: resource.retriever for resource in self.resources}
        ) | RunnableLambda(lambda x: [doc for docs in x.values() for doc in docs])
        return ContextualCompressionRetriever(
            base_compressor=_filter,
            base_retriever=retriever,
        )
```
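
Usage stays unchanged, since the overridden property plugs into the same pipeline. A minimal sketch, assuming `reranked_rag` has been instantiated like any other `RAG` (the exact constructor arguments depend on the base class):

```python
# `reranked_rag` is a RerankedRAG instance; construction is omitted as it
# depends on the RAG base class signature (e.g. resources and llm).
docs = reranked_rag.resources_retrieval_chain.invoke(
    "Who has been the strongest man ever?"
)
# Only the documents judged relevant by LLMChainFilter remain in `docs`.
```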