Beyond Guesswork: The Rise of Retrieval Augmented Generation
How to curb your AI's hallucinations
How are you this week? I hope you had a relaxing weekend. It's not like a leading AI company decided to implode, leaving its employees worrying about their future. Jokes aside, I promise this week's deep dive has nothing to do with the soap opera drama unfolding at OpenAI. Instead, we'll focus on how to tackle hallucinating LLMs. If you have feedback, questions, or comments, please let me know through the form below or the comment section. Let's go!
RAGs to Riches
I hated tests as a kid. Capital H hated them. Cram a bunch of material in a super short time window and regurgitate it verbatim on D-day. Fun right? I hoped undergrad would be different, but I was in for a shock. Not only was I expected to parrot my memorized meanderings, but I was also expected to do so neatly. Unlike Steve Jobs, I didn't sit in on a calligraphy class, and my scribbles rarely passed muster.
In hindsight, I realized that there was a common pattern in my responses. When asked a question that I a) didn't know the answer to or b) didn't entirely remember because of exhaustion, I resorted to hallucinating details. But I was a sneaky little bugger. You had to go through my work with a fine-tooth comb to know that some details were murky. I hated doing this. Therefore, I hated tests.
Laughing at my plight, are you? Try answering this question without searching the internet.
Who's laughing now? Did you make an informed guess, too? As you can see, when we are asked something, we use our priors to respond. We tend to guesstimate or say we don't know when we don't have the answer stored in our heads.
It wasn't until grad school that I started enjoying tests. Why? Open-book exams. Some of these lasted a week! The instructions were simple. Use any approved resource under the sun to answer these questions as best you can. Turn your work in within the deadline. It's an understatement to say this changed the way I approached tests. I could now refer to actual books and research papers or, in some cases, look up syntax on the internet before responding.
My tendency to fudge details or make "informed guesses" reduced significantly. I could now use references to reason, build logic and intuition, and answer with an enhanced context. My answers (and grades) naturally improved.
See where I'm going with this?
Large Language Models (LLMs) face this type of unbearable questioning from us each day.
"Hey, ChatGPT, can you write an essay on why Ms. Strabinsky is my hero? I need to get my grades up in ancient History somehow."
"Hey Claude, can you write my legal briefs? I need to meet a deadline by Monday."
Like us, they look at the information they've learned from their training data and respond as best as they can.
However, despite being trained on the open internet, they still don't have all the answers. Think of subjects grounded in evolving facts, like journalism, or those that require significant experience, like medicine. Doesn't your doctor refer you to a specialist if they feel your condition is beyond their expertise?
What's worse? Change is the only constant. Facts change. Information constantly evolves. New events happen. You get the drift.
Constantly retraining LLMs on updated data is computationally intensive, financially impractical, and environmentally irresponsible.
Faced with questions they don't have answers to and the rotting odor of their training data, they pursue the most logical option: guessing. But guessing, a.k.a. hallucination, has repercussions.
If we use LLM responses as gospel, we're in trouble. ChatGPT 3.5 once gave me a response backed by research papers that don't exist. Like me, it's a sneaky little bugger (maybe not so little). The authors of the papers it cited were real people. The papers weren't. How many people would bother fact-checking?
Instead, if we could empower these models with an open-book system, that would help them and us tremendously. Imagine a situation where an LLM can look up information on the internet, corroborate sources, and then provide responses grounded in fact versus pursuing a career in fiction.
That's exactly what Retrieval Augmented Generation (RAG) does!
Just as open-book exams allowed me to supplement my knowledge with external information, RAG enables LLMs to dynamically pull information from vast external databases, enriching their responses with up-to-date, accurate data.
RAG solves multiple problems at once. First, it eliminates the need to retrain LLMs frequently. When an LLM needs information outside of its training data, RAG allows it to query this externally. As new information is added or if existing information changes, only these external data sources need to be updated. The LLM can now be up-to-date by simply querying this updated source instead.
Second, RAG reduces hallucinations. In fact, RAG compels an LLM to use the retrieved information in its responses (we'll see why and how soon). This reduces the possibility that the model makes a wrong guess. If you want, you can ask a model to cite its sources.
Third, it's easy and fast to use. Developers can implement RAG in as few as five lines of code. Thus, RAG offers a cost-efficient and easy way to ensure that LLMs produce the correct outputs while eliminating the need to retrain them frequently.
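To make "a few lines of code" concrete, here's a minimal sketch of the whole loop. The retriever and llm objects and their methods are placeholders rather than any particular library's API:

```python
# Minimal RAG loop (schematic; `retriever` and `llm` are hypothetical objects).
def answer(question, retriever, llm, k=5):
    docs = retriever.search(question, top_k=k)              # 1. fetch the most relevant documents
    context = "\n\n".join(doc.text for doc in docs)         # 2. stitch them into a context block
    prompt = f"Context:\n{context}\n\nQuestion: {question}\nAnswer:"  # 3. augment the prompt
    return llm.generate(prompt)                             # 4. generate a grounded response
```

Frameworks such as LangChain wrap steps like these behind higher-level abstractions, which is where claims like "five lines of code" come from.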
Over the rest of the piece, we'll look at the components of a RAG system, how it’s trained, strategies to improve RAG, and more. It's time to ground LLMs in context.
Ready, Set, RAG
How does a RAG system help LLMs provide correct outputs? The answer is in the name. Retrieval Augmented Generation. Before we look at the retrieval augmented bit, let's look at generation first.
Generation
This step is what all LLMs do. Given a sequence of words, they predict the next word in the sequence. If you're curious about how they learn to do this, I've covered that in a past issue. These models rely on vast training data to encode information into their weights. This is called parametric memory, i.e., knowledge encoded in the model's weights.
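If it helps to see "predict the next word" spelled out, here's a schematic greedy decoding loop. The model and tokenizer objects are placeholders for any causal language model; every bit of knowledge it draws on lives in its weights:

```python
# Greedy next-token generation (schematic; `model` and `tokenizer` are placeholders).
def generate(model, tokenizer, prompt, max_new_tokens=50):
    tokens = tokenizer.encode(prompt)
    for _ in range(max_new_tokens):
        probs = model.next_token_probabilities(tokens)  # parametric memory: weights in, probabilities out
        next_token = max(range(len(probs)), key=probs.__getitem__)  # pick the most likely token
        tokens.append(next_token)
    return tokenizer.decode(tokens)
```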
Techniques like Reinforcement Learning from Human Feedback (RLHF), instruction tuning, and chain-of-thought prompting help these models generate more coherent, contextually appropriate, and sometimes logically correct responses. But, despite the size of the training dataset and these fine-tuning techniques, LLMs can still hallucinate. This is a disservice to users who want to dive deeper into specifics or discuss a current topic. That's where augmentation comes in.
Augmentation
While generation lays the foundation, it's the augmentation phase where RAG really starts to differentiate itself from traditional LLMs. In this step, RAG augments the initial prompt (sequence/query/whatever you want to call it) with externally retrieved information.
We'll look at how retrieval works shortly, but for now, assume that our LLM has a magic wand and can instantly retrieve information from external sources that are related to the prompt at hand.
Let's say the user's query is, "Explain the recent advancements in battery technology for electric vehicles." A traditional LLM will start the generation step right away. Its response would be entirely dependent on its parametric memory. So, if the model were trained two years ago, the user would get information from that time and before. Further, it could still hallucinate specific details.
A RAG-based LLM, on the other hand, will first retrieve information from external documents. This could be the latest research papers published within the last year that explain new battery compositions and their efficiency. It could also be press releases from companies, blog posts by technology experts, etc. Once it has these documents, it augments the user's prompt using them.
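Schematically, that augmentation is just prompt construction. Here's what it might look like for the battery question, with made-up snippets standing in for the retrieved documents:

```python
# Augmenting the prompt with retrieved context (the snippets below are placeholders, not real sources).
query = "Explain the recent advancements in battery technology for electric vehicles."
retrieved_docs = [
    "Snippet 1: excerpt from a recent paper on higher-density cell chemistries ...",
    "Snippet 2: excerpt from a press release about a solid-state battery pilot line ...",
]
augmented_prompt = (
    "Use the following context to answer the question.\n\n"
    + "\n\n".join(retrieved_docs)
    + f"\n\nQuestion: {query}\nAnswer:"
)
# The generator now conditions on both its parametric memory and this retrieved context.
```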
Concretely, these retrieved documents are concatenated to the original prompt as context. The LLM then starts the generation step using both its parametric knowledge and these retrieved results. The user, in turn, receives a more current and accurate response. Hence the name Retrieval Augmented Generation. But how does the RAG system know what information to retrieve? Let's look at the retriever next.
Retrieval
So far, we've assumed that a RAG system can magically retrieve the most appropriate information from external sources. But how do we know which documents are most relevant to a given prompt? How can we ensure we provide rich context to the augmentation step? If only there were a way to measure how similar two pieces of text are. Hang on, we can use embeddings! I know. They seem to appear just about everywhere in machine learning.
Embeddings are high-dimensional vectors that can represent words, sentences, or even entire documents. The numbers in an embedding vector aren't random. They capture the meaning of the item they represent. Thus, LLMs can use embeddings to semantically and contextually compare text.
Imagine if we could summarize an entire encyclopedia into a tiny sticky note.
That’s sort of what embeddings do for text. Cool, right?
Naturally, embeddings play a crucial role in the retrieval phase of RAG systems. So, how are they used in this context? First, another model, called a Dense Passage Retrieval (DPR) system and itself built from smaller language models, converts both the user's query and the external documents into dense embedding vectors. The DPR system takes the query embedding and compares it to the embeddings of documents using methods like cosine similarity, dot product, Manhattan distance, etc.1
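Here's what that comparison looks like numerically, using tiny made-up embeddings:

```python
import numpy as np

# Comparing a query embedding against two document embeddings (toy 4-dimensional vectors).
query = np.array([0.2, 0.7, 0.1, 0.5])
docs = np.array([
    [0.1, 0.8, 0.0, 0.4],   # points in roughly the same direction as the query
    [0.9, 0.0, 0.7, 0.1],   # points somewhere else entirely
])

dot_scores = docs @ query
cosine_scores = dot_scores / (np.linalg.norm(docs, axis=1) * np.linalg.norm(query))
best_doc = int(np.argmax(cosine_scores))   # 0: the first document is the closest semantic match
```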
To make this easier to visualize, think of a library's indexing system. Each book (document) has a unique code (embedding) representing its contents. When a patron asks the librarian (DPR) for suggestions (query), the librarian can retrieve the most relevant books for the patron.
What's really cool is that the comparison process looks for semantic similarity. Thus, DPR can find documents that match the query even if they don't have matching keywords! Additionally, the comparison and retrieval happen in real-time, ensuring that the augmentation phase has the most up-to-date and relevant information.
In practice, these documents are pre-converted into embeddings and stored in special databases called vector databases. These databases are optimized to retrieve the most relevant documents quickly and are updated whenever new information becomes available or existing information changes. Thus, the DPR system can access up-to-date information just by swapping out these vectors regularly. Compare that with retraining an LLM every time its knowledge needs an update. Isn't this way easier and more efficient?
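Real vector databases add a lot on top of this (approximate-nearest-neighbor indexes, filtering, persistence), but the core idea fits in a toy class. Everything below is illustrative:

```python
import numpy as np

class TinyVectorStore:
    """A toy stand-in for a vector database: store embeddings, return the top-k matches."""

    def __init__(self):
        self.ids, self.vectors, self.texts = [], [], []

    def upsert(self, doc_id, vector, text):
        # Updating knowledge means swapping a vector, not retraining a model.
        if doc_id in self.ids:
            i = self.ids.index(doc_id)
            self.vectors[i], self.texts[i] = vector, text
        else:
            self.ids.append(doc_id)
            self.vectors.append(vector)
            self.texts.append(text)

    def query(self, query_vector, k=5):
        scores = np.array(self.vectors) @ np.asarray(query_vector)   # dot-product similarity
        top = np.argsort(scores)[::-1][:k]                           # best matches first
        return [self.texts[i] for i in top]
```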
If you have a hard time visualizing what a DPR might look like, recall the IBM Watson computer that won Jeopardy. It's kind of like that, only way more advanced.
The DPR within a RAG system is called non-parametric memory since it retrieves information stored in external documents, not encoded within its parameters. Thus, in RAG, we combine the benefits of both closed-book (LLM Generation) and open-book (DPR) systems. The DPR "cues" the LLM to generate correct responses.
Post-Processing
We skipped a step in between. Before the generator gets to work on the augmented prompt, the retrieved documents have to be post-processed. This is essential to ensure the diversity of sources and their recency. At the end of the day, we want only the best information to be augmented to the prompt. While explaining these in depth is beyond the scope of this piece, here is an excellent article on popular post-processing strategies.
Here's what a RAG workflow looks like:
Generation, Augmentation, and Retrieval collectively contribute to the RAG's ability to provide accurate and current responses. Let's explore how the entire system is trained next.
Training a RAG System
To train a RAG system, we simply take a pretrained language model to generate text (Generator in the figure below) and a pretrained DPR model (Retriever in the figure) and fine-tune them together end to end. Since we've already discussed how LLMs are pretrained in the past, let's briefly cover the DPR model and how it's pretrained.
Pretraining the DPR Model
A DPR model typically consists of two encoders. One of these is called the query encoder and converts the user query into a dense embedding. The other is called the passage encoder and maps any text passage (document) into a dense embedding, creating an index for all the external documents that will be used for retrieval.
At run-time, DPR measures the similarity between the query embedding and the document embeddings it indexed to return the top "K" most relevant documents as results. Here, "K" is user-defined and is usually set to values like 5 or 10.
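Put together, DPR-style retrieval looks roughly like this. The query_encoder and passage_encoder callables are placeholders for the two pretrained encoders:

```python
import numpy as np

# DPR-style top-K retrieval (schematic; `query_encoder` and `passage_encoder` are placeholders).
def retrieve_top_k(query, documents, query_encoder, passage_encoder, k=5):
    q = query_encoder(query)                                   # dense embedding of the query
    index = np.stack([passage_encoder(d) for d in documents])  # in practice, pre-computed and stored
    scores = index @ q                                         # similarity of the query to every document
    top = np.argsort(scores)[::-1][:k]                         # indices of the K most relevant documents
    return [documents[i] for i in top]
```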
Pretraining the DPR model means teaching it to be a good ranking function for retrieval. Imagine that you have to write an essay on Ancient History but have no idea where to start and no time to read all the material in the world. A good librarian could help you rank the best books to read in a fraction of the time it would have taken you to do it yourself. How can we do the same for document retrieval?
The goal is to create an embedding space where relevant query-document pairs end up closer together (higher similarity) than irrelevant pairs. So, to build a training example, we take a query, the most relevant document for it, and a few irrelevant documents. Collect enough of these examples, and we can train the model to do just that!
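In practice, that turns into a contrastive objective: given a query, the relevant document should outscore the irrelevant ones. Here's a schematic version of the loss, assuming the embeddings already exist:

```python
import numpy as np

# Contrastive retrieval loss (schematic): reward the model when the relevant document
# outscores the irrelevant ones for a given query.
def retrieval_loss(query_vec, positive_vec, negative_vecs):
    candidates = np.vstack([positive_vec] + list(negative_vecs))
    scores = candidates @ query_vec                       # similarity of the query to each candidate
    log_probs = scores - np.log(np.sum(np.exp(scores)))   # log-softmax over all candidates
    return -log_probs[0]                                  # small when the relevant document wins
```

Minimizing this pulls relevant query-document pairs together in the embedding space and pushes irrelevant ones apart.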
Once we have both a pretrained generator and a pretrained retriever, we're ready to fine-tune them jointly.
End-to-End Fine Tuning
Conveniently, a RAG system can be fine-tuned either on a per-output basis (a single retrieved document is used to generate all output tokens) or on a per-token basis (different retrieved documents are responsible for different output tokens). In the original RAG paper, the authors jointly trained the retriever and generator components without directly supervising them. That is, they didn't enforce constraints on which documents had to be retrieved. Instead, they had pairs of queries and responses and minimized the error of the joint system over these pairs. To reduce training time further, they only fine-tuned the query encoder in DPR and the generator model while keeping the document encoder fixed.
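On a per-output basis, the quantity being optimized is roughly a retrieval-weighted average: the probability of a response is summed over the top-K retrieved documents, weighted by how likely each document was to be retrieved. A sketch, with both arrays assumed to come from the retriever and the generator:

```python
import numpy as np

# Per-output (sequence-level) marginalization, schematically:
# p(response | query) ≈ sum over top-K docs of p(doc | query) * p(response | query, doc)
def response_probability(p_doc_given_query, p_response_given_query_and_doc):
    p_docs = np.asarray(p_doc_given_query)               # shape (K,), retriever scores
    p_resp = np.asarray(p_response_given_query_and_doc)  # shape (K,), generator likelihoods
    return float(np.sum(p_docs * p_resp))
```

Training then minimizes the negative log of this quantity over the query-response pairs, so gradients flow into both the generator and the query encoder while the document encoder stays fixed.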
Now that we have a trained RAG system, let's look at how we can use it, its advantages, and what the open challenges are.
Applications and Challenges
We've gone through all this trouble to understand the nuts and bolts of RAG, but what is it good for?
For starters, you could use a RAG system in your workplace. It would be an upgrade over the static wiki pages and the archaic chatbot that helps you (read: confuses you) find answers to specific questions. Information within a company frequently changes. So, using a RAG system where you only need to update the document vectors is a low-cost solution to improve employee productivity.
RAG is also great for automated customer support. Whether you run a SaaS business, restaurant, or other service, you can have a RAG system help address customer questions and provide up-to-date and correct answers based on documents you maintain. For example, your restaurant RAG-bot could inform customers of the hours of operation on holidays, whether you support special dietary needs, or how they can reserve the venue for special events.
Finally, RAG systems are great research assistants. Whether you need to find sources for an article you're writing or need to get up to speed with current developments, RAG has you covered. Since RAG systems can retrieve information from books, scientific papers, news articles, and other massive databases, they are potent research tools. Plus, whenever you're in doubt, you can ask them to cite their sources and verify that they aren't hallucinating.
Secretly (well, not so secretly anymore), I'd love to build a RAG system that's trained on my notes, thoughts, and saved content. It would be incredibly useful for my writing, research, and thinking in general.
But it's not that simple. There are a few challenges RAG systems pose, and some of these are still open questions.
Single Point of Failure: A RAG system has two core components: a retriever and a generator. However, the generator has to use the context that the retriever provides. So, if the retriever isn't good at identifying and returning relevant information, the generation will suffer too. For example, the retriever could be skewed towards returning only certain types of content. The generator would suffer as a result. Even if the retriever does an excellent job, external sources could have hidden biases in them, which negatively impact generation.
Computation Costs: While RAGs are "efficient" since they leverage information retrieval instead of retraining, they are still language models at their core. This means that all the baggage that comes with LLMs gets checked in for free when RAGs are used, too. In addition, practical use cases need vast quantities of external information that can easily swallow terabytes of storage. This introduces another layer of complexity when deploying RAGs. Currently, researchers are trying to reduce these costs.
Ambiguity: The quality of results depends on how well the generator and retriever understand the user's prompt. The retriever might retrieve irrelevant or random documents if the query is ambiguous, leading to poor results.
RAG is a ridiculously simple and elegant solution to help LLMs produce better results by combining the benefits of parametric knowledge and carefully retrieved information. It's easy to update and provides a great platform for building use-case-specific language models that delight their users. Over a decade ago, access to books and research papers thrilled me as I feverishly sought answers for my open-book exams. I wonder how I'd deal with those same tests if I had access to an all-knowing RAG.
I know you're not wondering about that. You're wondering if a wombat can really poop cubes. Yes, it can. A RAG told me so.
Glossary:
This is a new thing I'm trying for deep dives, so let me know if you find this helpful.
Hallucination: When a language model generates incorrect or fabricated information, often due to gaps in its training.
Retrieval Augmented Generation (RAG): A system that enhances LLM responses by dynamically integrating information from external databases, providing updated and accurate data.
Parametric Memory: Knowledge encoded in the weights of a language model derived from its training data.
Non-parametric Memory: In RAG, this refers to retrieving information stored in external documents and not encoded within the model's parameters.
Dense Passage Retrieval (DPR) System: A key component of RAG responsible for retrieving information by converting user queries and external documents into dense embeddings and returning the top K most relevant ones.
Embeddings: High-dimensional vectors representing words, sentences, or documents, capturing their semantic and contextual meanings.
Query Encoder and Passage Encoder: Within the DPR model, the query encoder converts user queries into embeddings, while the passage encoder does the same for text passages, aiding in retrieval.
Cosine Similarity, Dot Product, Manhattan Distance: Methods used in DPR to compare query embeddings with document embeddings to find the most relevant documents.
Vector Databases: Specialized databases optimized for quick retrieval of relevant documents, where these documents are stored as embeddings.
Resources To Consider:
RAG Literature
Here are some useful papers to understand RAG and related research better:
Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks: https://arxiv.org/abs/2005.11401
Atlas: Few-shot Learning with Retrieval Augmented Language Models: https://arxiv.org/abs/2208.03299
Improving language models by retrieving from trillions of tokens: https://arxiv.org/abs/2112.04426
Dense Passage Retrieval for Open-Domain Question Answering: https://arxiv.org/abs/2004.04906
Patrick Lewis Explains RAG
Understand Retrieval Augmented Generation straight from its creator's mouth. In this podcast episode, Patrick Lewis walks us through what RAG is and how it works. If you enjoy videos more than text, this might be useful for you.
Retrieval Augmented Generation with LangChain
In this code-based webinar, you'll learn about the components of RAG workflows and how to use LangChain to reduce complexity while increasing development velocity for building GenAI applications.
1. For example, if DPR uses cosine similarity, it would measure the angle between the query vector and each document vector. A smaller angle implies a higher similarity, thus identifying documents that are most relevant to the query.