-
Notifications
You must be signed in to change notification settings - Fork 1.4k
Rag and Tex-to-SQL Agent done by Safni Usman #102
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
# **RAG and Text-to-SQL Agent** | ||
|
||
This project implements a custom agent capable of querying either: | ||
|
||
- A **LlamaCloud index** for RAG-based retrieval. | ||
- A **SQL query engine** as a tool. | ||
|
||
### **Use Case** | ||
|
||
We use PDFs of Wikipedia pages of US cities and a SQL database containing their populations and states as documents. | ||
Find the pages here: | ||
- [New York City](https://en.wikipedia.org/wiki/New_York_City) | ||
- [Los Angeles](https://en.wikipedia.org/wiki/Los_Angeles) | ||
- [Chicago](https://en.wikipedia.org/wiki/Chicago) | ||
- [Houston](https://en.wikipedia.org/wiki/Houston) | ||
- [Miami](https://en.wikipedia.org/wiki/Miami) | ||
- [Seattle](https://en.wikipedia.org/wiki/Seattle) | ||
|
||
## **Technologies Used** | ||
|
||
- **LlamaIndex** – for orchestrating the agent. | ||
- **LlamaTrace(Phoenix-Arize)** – for observability. | ||
- **Streamlit** – to build the UI. | ||
- **GPT 3.5 turbo** – as the LLM. | ||
|
||
## **Demo** | ||
|
||
- [**Video Demo**](demo.mp4) | ||
- [**Hugging Face Space**](https://huggingface.co/spaces/Safni/RAG_SGL_APP) | ||
|
||
|
||
|
||
## **Installation and Setup** | ||
|
||
### **1. Set up LlamaCloud API** | ||
|
||
Get an API key from [**LlamaCloud**](https://cloud.llamaindex.ai/) and add it to the `.env` file: | ||
|
||
```ini | ||
LLAMA_CLOUD_API_KEY=<YOUR_API_KEY> | ||
``` | ||
|
||
### **2. Set up Observability** | ||
|
||
Integrate **LlamaTrace** for observability. Obtain an API key from [**LlamaTrace**](https://llamatrace.com/login) and add it to the `.env` file: | ||
|
||
```ini | ||
PHOENIX_API_KEY=<YOUR_API_KEY> | ||
``` | ||
|
||
### **3. Set up OpenAI API** | ||
|
||
Get an API key from [**OpenAI**](https://platform.openai.com/) and add it to the `.env` file: | ||
|
||
```ini | ||
OPENAI_API_KEY=<YOUR_API_KEY> | ||
``` | ||
|
||
### **4. Install Dependencies** | ||
|
||
Ensure you have **Python 3.11+** installed. Then, install dependencies: | ||
|
||
```bash | ||
pip install -r requirements.txt | ||
``` | ||
|
||
### **5. Run the App** | ||
|
||
Launch the application using: | ||
|
||
```bash | ||
streamlit run app.py | ||
``` | ||
|
||
|
||
## 📬 Stay Updated with Our Newsletter! | ||
**Get a FREE Data Science eBook** 📖 with 150+ essential lessons in Data Science when you subscribe to our newsletter! Stay in the loop with the latest tutorials, insights, and exclusive resources. [Subscribe now!](https://join.dailydoseofds.com) | ||
|
||
[](https://join.dailydoseofds.com) | ||
|
||
--- | ||
|
||
## Contribution | ||
|
||
Contributions are welcome! Please fork the repository and submit a pull request with your improvements. |
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,82 @@ | ||||||||||||||||||||||||||||||||||
import streamlit as st | ||||||||||||||||||||||||||||||||||
import asyncio | ||||||||||||||||||||||||||||||||||
from query_agent import RouterOutputAgentWorkflow | ||||||||||||||||||||||||||||||||||
from tool_setup import QueryEngineTools | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
class RAGSQLAgentApp: | ||||||||||||||||||||||||||||||||||
"""Streamlit UI for querying SQL databases using an agent workflow with continuous chat history.""" | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
def __init__(self): | ||||||||||||||||||||||||||||||||||
"""Initialize the application and workflow.""" | ||||||||||||||||||||||||||||||||||
self.tools = QueryEngineTools() | ||||||||||||||||||||||||||||||||||
self.workflow = RouterOutputAgentWorkflow( | ||||||||||||||||||||||||||||||||||
tools=[self.tools.get_sql_tool(), self.tools.get_llama_cloud_tool()], | ||||||||||||||||||||||||||||||||||
verbose=True, | ||||||||||||||||||||||||||||||||||
timeout=120 | ||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||
self._setup_ui() | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
async def run_query(self, query): | ||||||||||||||||||||||||||||||||||
"""Runs the query asynchronously using the workflow.""" | ||||||||||||||||||||||||||||||||||
return await self.workflow.run(message=query) | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
def _setup_ui(self): | ||||||||||||||||||||||||||||||||||
"""Setup Streamlit UI components.""" | ||||||||||||||||||||||||||||||||||
with st.sidebar: | ||||||||||||||||||||||||||||||||||
st.image("./assets/my_image.png", width=300) | ||||||||||||||||||||||||||||||||||
st.markdown("## How to Use") | ||||||||||||||||||||||||||||||||||
st.write( | ||||||||||||||||||||||||||||||||||
"1. Enter your query about the US cities in the chat box given and press Enter.\n" | ||||||||||||||||||||||||||||||||||
"2. The assistant will process your query and respond.\n" | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||
st.markdown("## Powered By") | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
col1, col2, col3 = st.columns(3) | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
# Display images in each column | ||||||||||||||||||||||||||||||||||
with col1: | ||||||||||||||||||||||||||||||||||
st.image("./assets/image1.png", width=80) | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
with col2: | ||||||||||||||||||||||||||||||||||
st.image("./assets/image2.png", width=80) | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
with col3: | ||||||||||||||||||||||||||||||||||
st.image("./assets/image3.png", width=80) | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
# Main UI Title | ||||||||||||||||||||||||||||||||||
st.image("./assets/cover.png", width=800) | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
# Initialize session state for chat history | ||||||||||||||||||||||||||||||||||
if "chat_history" not in st.session_state: | ||||||||||||||||||||||||||||||||||
st.session_state.chat_history = [] | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
# Display chat history | ||||||||||||||||||||||||||||||||||
for chat in st.session_state.chat_history: | ||||||||||||||||||||||||||||||||||
with st.chat_message("user"): | ||||||||||||||||||||||||||||||||||
st.write(chat["query"]) | ||||||||||||||||||||||||||||||||||
with st.chat_message("assistant"): | ||||||||||||||||||||||||||||||||||
st.write(chat["response"]) | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
# User input box for new queries | ||||||||||||||||||||||||||||||||||
query = st.chat_input("Enter your query:") | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
if query: | ||||||||||||||||||||||||||||||||||
with st.chat_message("user"): | ||||||||||||||||||||||||||||||||||
st.write(query) | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
with st.spinner("Processing..."): | ||||||||||||||||||||||||||||||||||
loop = asyncio.new_event_loop() | ||||||||||||||||||||||||||||||||||
asyncio.set_event_loop(loop) | ||||||||||||||||||||||||||||||||||
result = loop.run_until_complete(self.run_query(query)) | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
with st.chat_message("assistant"): | ||||||||||||||||||||||||||||||||||
st.write(result) | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
Comment on lines
+70
to
+76
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Add exception handling for failed queries. If - result = loop.run_until_complete(self.run_query(query))
+ try:
+ result = loop.run_until_complete(self.run_query(query))
+ except Exception as e:
+ st.error(f"An error occurred while processing your query: {e}")
+ return 📝 Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||||||
# Store chat in session history | ||||||||||||||||||||||||||||||||||
st.session_state.chat_history.append({"query": query, "response": result}) | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
# Run the app | ||||||||||||||||||||||||||||||||||
if __name__ == "__main__": | ||||||||||||||||||||||||||||||||||
RAGSQLAgentApp() |
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,36 @@ | ||||||||||||||||||||||||||
from llama_index.indices.managed.llama_cloud import LlamaCloudIndex | ||||||||||||||||||||||||||
import llama_index.core | ||||||||||||||||||||||||||
import os | ||||||||||||||||||||||||||
from dotenv import load_dotenv | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
class LlamaCloudSetup: | ||||||||||||||||||||||||||
def __init__(self): | ||||||||||||||||||||||||||
load_dotenv() | ||||||||||||||||||||||||||
self.llama_cloud_api_key = os.getenv('LLAMA_CLOUD_API_KEY') | ||||||||||||||||||||||||||
self.phoenix_api_key = os.getenv('PHOENIX_API_KEY') | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
self.index = LlamaCloudIndex( | ||||||||||||||||||||||||||
name="city_pdfs_index", | ||||||||||||||||||||||||||
project_name="Default", | ||||||||||||||||||||||||||
organization_id="8c1b26ae-5c98-4873-bce9-28d0c67a3dce", | ||||||||||||||||||||||||||
api_key=self.llama_cloud_api_key, | ||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||
Comment on lines
+12
to
+17
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Move hardcoded values to environment variables The organization_id and other configuration parameters are hardcoded in the class. This makes the code less flexible and potentially poses a security risk. self.index = LlamaCloudIndex(
- name="city_pdfs_index",
- project_name="Default",
- organization_id="8c1b26ae-5c98-4873-bce9-28d0c67a3dce",
+ name=os.getenv('LLAMA_CLOUD_INDEX_NAME', 'city_pdfs_index'),
+ project_name=os.getenv('LLAMA_CLOUD_PROJECT_NAME', 'Default'),
+ organization_id=os.getenv('LLAMA_CLOUD_ORGANIZATION_ID'),
api_key=self.llama_cloud_api_key,
) 📝 Committable suggestion
Suggested change
|
||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
self._setup_observability() | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
def _setup_observability(self): | ||||||||||||||||||||||||||
"""Set up observability using Arize Phoenix.""" | ||||||||||||||||||||||||||
os.environ["OTEL_EXPORTER_OTLP_HEADERS"] = f"api_key={self.phoenix_api_key}" | ||||||||||||||||||||||||||
llama_index.core.set_global_handler( | ||||||||||||||||||||||||||
"arize_phoenix", endpoint="https://llamatrace.com/v1/traces" | ||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
def get_query_engine(self): | ||||||||||||||||||||||||||
"""Returns the query engine instance.""" | ||||||||||||||||||||||||||
return self.index.as_query_engine() | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
# Usage example: | ||||||||||||||||||||||||||
if __name__ == "__main__": | ||||||||||||||||||||||||||
llama_cloud_setup = LlamaCloudSetup() | ||||||||||||||||||||||||||
query_engine = llama_cloud_setup.get_query_engine() |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,155 @@ | ||
from typing import Dict, List, Any, Optional | ||
from IPython.display import display, Markdown | ||
import asyncio | ||
from llama_index.core.workflow import Context # <-- Add this | ||
from llama_index.core.tools import BaseTool | ||
from llama_index.core.llms import ChatMessage | ||
from llama_index.core.llms.llm import ToolSelection, LLM | ||
from llama_index.core.workflow import ( | ||
Workflow, | ||
Event, | ||
StartEvent, | ||
StopEvent, | ||
step, | ||
) | ||
from llama_index.core.base.response.schema import Response | ||
from llama_index.core.tools import FunctionTool | ||
from llama_index.llms.openai import OpenAI | ||
|
||
|
||
class InputEvent(Event): | ||
"""Input event.""" | ||
|
||
class GatherToolsEvent(Event): | ||
"""Gather Tools Event""" | ||
|
||
tool_calls: Any | ||
|
||
class ToolCallEvent(Event): | ||
"""Tool Call event""" | ||
|
||
tool_call: ToolSelection | ||
|
||
class ToolCallEventResult(Event): | ||
"""Tool call event result.""" | ||
|
||
msg: ChatMessage | ||
|
||
class RouterOutputAgentWorkflow(Workflow): | ||
"""Custom router output agent workflow.""" | ||
|
||
def __init__(self, | ||
tools: List[BaseTool], | ||
timeout: Optional[float] = 10.0, | ||
disable_validation: bool = False, | ||
verbose: bool = False, | ||
llm: Optional[LLM] = None, | ||
chat_history: Optional[List[ChatMessage]] = None, | ||
): | ||
"""Constructor.""" | ||
|
||
super().__init__(timeout=timeout, disable_validation=disable_validation, verbose=verbose) | ||
|
||
self.tools: List[BaseTool] = tools | ||
self.tools_dict: Optional[Dict[str, BaseTool]] = {tool.metadata.name: tool for tool in self.tools} | ||
self.llm: LLM = llm or OpenAI(temperature=0, model="gpt-3.5-turbo") | ||
self.chat_history: List[ChatMessage] = chat_history or [] | ||
|
||
|
||
def reset(self) -> None: | ||
"""Resets Chat History""" | ||
|
||
self.chat_history = [] | ||
|
||
@step() | ||
async def prepare_chat(self, ev: StartEvent) -> InputEvent: | ||
message = ev.get("message") | ||
if message is None: | ||
raise ValueError("'message' field is required.") | ||
|
||
# add msg to chat history | ||
chat_history = self.chat_history | ||
chat_history.append(ChatMessage(role="user", content=message)) | ||
return InputEvent() | ||
|
||
@step() | ||
async def chat(self, ev: InputEvent) -> GatherToolsEvent | StopEvent: | ||
"""Appends msg to chat history, then gets tool calls.""" | ||
|
||
# Put msg into LLM with tools included | ||
chat_res = await self.llm.achat_with_tools( | ||
self.tools, | ||
chat_history=self.chat_history, | ||
verbose=self._verbose, | ||
allow_parallel_tool_calls=True | ||
) | ||
tool_calls = self.llm.get_tool_calls_from_response(chat_res, error_on_no_tool_call=False) | ||
|
||
ai_message = chat_res.message | ||
self.chat_history.append(ai_message) | ||
if self._verbose: | ||
print(f"Chat message: {ai_message.content}") | ||
|
||
# no tool calls, return chat message. | ||
if not tool_calls: | ||
return StopEvent(result=ai_message.content) | ||
|
||
return GatherToolsEvent(tool_calls=tool_calls) | ||
|
||
@step(pass_context=True) | ||
async def dispatch_calls(self, ctx: Context, ev: GatherToolsEvent) -> ToolCallEvent: | ||
"""Dispatches calls.""" | ||
|
||
tool_calls = ev.tool_calls | ||
await ctx.set("num_tool_calls", len(tool_calls)) | ||
|
||
# trigger tool call events | ||
for tool_call in tool_calls: | ||
ctx.send_event(ToolCallEvent(tool_call=tool_call)) | ||
|
||
return None | ||
|
||
@step() | ||
async def call_tool(self, ev: ToolCallEvent) -> ToolCallEventResult: | ||
"""Calls tool.""" | ||
|
||
tool_call = ev.tool_call | ||
|
||
# get tool ID and function call | ||
id_ = tool_call.tool_id | ||
|
||
if self._verbose: | ||
print(f"Calling function {tool_call.tool_name} with msg {tool_call.tool_kwargs}") | ||
|
||
# call function and put result into a chat message | ||
tool = self.tools_dict[tool_call.tool_name] | ||
output = await tool.acall(**tool_call.tool_kwargs) | ||
msg = ChatMessage( | ||
name=tool_call.tool_name, | ||
content=str(output), | ||
role="tool", | ||
additional_kwargs={ | ||
"tool_call_id": id_, | ||
"name": tool_call.tool_name | ||
} | ||
) | ||
|
||
return ToolCallEventResult(msg=msg) | ||
|
||
@step(pass_context=True) | ||
async def gather(self, ctx: Context, ev: ToolCallEventResult) -> StopEvent | None: | ||
"""Gathers tool calls.""" | ||
# wait for all tool call events to finish. | ||
tool_events = ctx.collect_events(ev, [ToolCallEventResult] * await ctx.get("num_tool_calls")) | ||
if not tool_events: | ||
return None | ||
|
||
for tool_event in tool_events: | ||
# append tool call chat messages to history | ||
self.chat_history.append(tool_event.msg) | ||
|
||
# # after all tool calls finish, pass input event back, restart agent loop | ||
return InputEvent() | ||
|
||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
streamlit | ||
llama-index | ||
sqlalchemy | ||
pyvis | ||
python-dotenv | ||
arize-phoenix | ||
llama-index-callbacks-arize-phoenix |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🛠️ Refactor suggestion
Consider error handling for missing or failing tools.
The constructor sets up the workflow with
get_sql_tool()
andget_llama_cloud_tool()
. If either fails or isn't configured properly, the workflow may break. You might wrap tool initialization (and environment checks) in try-catch logic to report clearer error messages to users.