RAG and Text-to-SQL Agent by Safni Usman #102

Open · wants to merge 1 commit into base: main
85 changes: 85 additions & 0 deletions rag-sql-agent/README.md
@@ -0,0 +1,85 @@
# **RAG and Text-to-SQL Agent**

This project implements a custom agent that routes each query to either:

- A **LlamaCloud index** for RAG-based retrieval over documents.
- A **SQL query engine** for structured lookups.

### **Use Case**

The documents are PDFs of the Wikipedia pages of several US cities, paired with a SQL database containing each city's population and state.
Find the pages here:
- [New York City](https://en.wikipedia.org/wiki/New_York_City)
- [Los Angeles](https://en.wikipedia.org/wiki/Los_Angeles)
- [Chicago](https://en.wikipedia.org/wiki/Chicago)
- [Houston](https://en.wikipedia.org/wiki/Houston)
- [Miami](https://en.wikipedia.org/wiki/Miami)
- [Seattle](https://en.wikipedia.org/wiki/Seattle)
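
The population and state data can live in a small SQL table. As a rough sketch (the actual schema and figures used in this project may differ), a hypothetical `city_stats` table could be built like this:

```python
import sqlite3

# Hypothetical schema; column names and figures are illustrative only.
conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE city_stats (city_name TEXT PRIMARY KEY, population INTEGER, state TEXT)"
)
rows = [
    ("New York City", 8336000, "New York"),
    ("Los Angeles", 3822000, "California"),
    ("Chicago", 2665000, "Illinois"),
]
conn.executemany("INSERT INTO city_stats VALUES (?, ?, ?)", rows)
conn.commit()

# A Text-to-SQL tool would translate a question like
# "Which city has the largest population?" into a query such as:
result = conn.execute(
    "SELECT city_name FROM city_stats ORDER BY population DESC LIMIT 1"
).fetchone()
print(result[0])  # New York City
```

The agent's SQL tool issues queries of this shape against the real database.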

## **Technologies Used**

- **LlamaIndex** – for orchestrating the agent.
- **LlamaTrace (Arize Phoenix)** – for observability.
- **Streamlit** – to build the UI.
- **GPT-3.5 Turbo** – as the LLM.

## **Demo**

- [**Video Demo**](demo.mp4)
- [**Hugging Face Space**](https://huggingface.co/spaces/Safni/RAG_SGL_APP)



## **Installation and Setup**

### **1. Set up LlamaCloud API**

Get an API key from [**LlamaCloud**](https://cloud.llamaindex.ai/) and add it to the `.env` file:

```ini
LLAMA_CLOUD_API_KEY=<YOUR_API_KEY>
```

### **2. Set up Observability**

Integrate **LlamaTrace** for observability. Obtain an API key from [**LlamaTrace**](https://llamatrace.com/login) and add it to the `.env` file:

```ini
PHOENIX_API_KEY=<YOUR_API_KEY>
```

### **3. Set up OpenAI API**

Get an API key from [**OpenAI**](https://platform.openai.com/) and add it to the `.env` file:

```ini
OPENAI_API_KEY=<YOUR_API_KEY>
```
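
With all three keys in `.env`, a quick sanity check can catch missing configuration before the app starts. This is an optional sketch (stdlib only; in the app itself, python-dotenv's `load_dotenv()` populates `os.environ` first):

```python
import os

# The required keys, matching the three setup steps above.
REQUIRED = ("LLAMA_CLOUD_API_KEY", "PHOENIX_API_KEY", "OPENAI_API_KEY")

def missing_keys(env=os.environ):
    """Return the names of required keys that are unset or empty."""
    return [name for name in REQUIRED if not env.get(name)]

# Example: fail fast at startup
# if missing_keys():
#     raise SystemExit(f"Missing environment variables: {', '.join(missing_keys())}")
```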

### **4. Install Dependencies**

Ensure you have **Python 3.11+** installed. Then, install dependencies:

```bash
pip install -r requirements.txt
```

### **5. Run the App**

Launch the application using:

```bash
streamlit run app.py
```


## 📬 Stay Updated with Our Newsletter!
**Get a FREE Data Science eBook** 📖 with 150+ essential lessons in Data Science when you subscribe to our newsletter! Stay in the loop with the latest tutorials, insights, and exclusive resources. [Subscribe now!](https://join.dailydoseofds.com)

[![Daily Dose of Data Science Newsletter](https://github.com/patchy631/ai-engineering/blob/main/resources/join_ddods.png)](https://join.dailydoseofds.com)

---

## Contribution

Contributions are welcome! Please fork the repository and submit a pull request with your improvements.
82 changes: 82 additions & 0 deletions rag-sql-agent/app.py
@@ -0,0 +1,82 @@
import streamlit as st
import asyncio
from query_agent import RouterOutputAgentWorkflow
from tool_setup import QueryEngineTools

class RAGSQLAgentApp:
    """Streamlit UI for querying SQL databases using an agent workflow with continuous chat history."""

    def __init__(self):
        """Initialize the application and workflow."""
        self.tools = QueryEngineTools()
        self.workflow = RouterOutputAgentWorkflow(
            tools=[self.tools.get_sql_tool(), self.tools.get_llama_cloud_tool()],
            verbose=True,
            timeout=120
        )
        self._setup_ui()
Comment on lines +9 to +17

🛠️ Refactor suggestion

Consider error handling for missing or failing tools.

The constructor sets up the workflow with `get_sql_tool()` and `get_llama_cloud_tool()`. If either fails or isn't configured properly, the workflow may break. You might wrap tool initialization (and environment checks) in try/except logic to report clearer error messages to users.
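
One way to follow this suggestion (a sketch only; `init_tools` and its error-reporting behavior are hypothetical, not part of the PR):

```python
def init_tools(factories):
    """Build tools from (name, factory) pairs, collecting readable error
    messages instead of letting one failure crash the whole app."""
    tools, errors = [], []
    for name, factory in factories:
        try:
            tools.append(factory())
        except Exception as exc:
            errors.append(f"{name}: {exc}")
    return tools, errors

# Hypothetical usage inside __init__:
# tools, errors = init_tools([
#     ("sql", self.tools.get_sql_tool),
#     ("llama_cloud", self.tools.get_llama_cloud_tool),
# ])
# for err in errors:
#     st.error(f"Tool failed to initialize: {err}")
```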


    async def run_query(self, query):
        """Runs the query asynchronously using the workflow."""
        return await self.workflow.run(message=query)

    def _setup_ui(self):
        """Set up Streamlit UI components."""
        with st.sidebar:
            st.image("./assets/my_image.png", width=300)
            st.markdown("## How to Use")
            st.write(
                "1. Enter your query about the US cities in the chat box and press Enter.\n"
                "2. The assistant will process your query and respond.\n"
            )
            st.markdown("## Powered By")

            col1, col2, col3 = st.columns(3)

            # Display images in each column
            with col1:
                st.image("./assets/image1.png", width=80)
            with col2:
                st.image("./assets/image2.png", width=80)
            with col3:
                st.image("./assets/image3.png", width=80)

        # Main UI title
        st.image("./assets/cover.png", width=800)

        # Initialize session state for chat history
        if "chat_history" not in st.session_state:
            st.session_state.chat_history = []

        # Display chat history
        for chat in st.session_state.chat_history:
            with st.chat_message("user"):
                st.write(chat["query"])
            with st.chat_message("assistant"):
                st.write(chat["response"])

        # User input box for new queries
        query = st.chat_input("Enter your query:")

        if query:
            with st.chat_message("user"):
                st.write(query)

            with st.spinner("Processing..."):
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)
                result = loop.run_until_complete(self.run_query(query))

            with st.chat_message("assistant"):
                st.write(result)
Comment on lines +70 to +76

🛠️ Refactor suggestion

Add exception handling for failed queries.

If workflow.run(message=query) raises an exception (e.g., a tool error or a network failure), the app might crash. Wrap this call in a try-except block to show a user-friendly error message instead of an internal traceback.

- result = loop.run_until_complete(self.run_query(query))
+ try:
+     result = loop.run_until_complete(self.run_query(query))
+ except Exception as e:
+     st.error(f"An error occurred while processing your query: {e}")
+     return

            # Store chat in session history
            st.session_state.chat_history.append({"query": query, "response": result})

# Run the app
if __name__ == "__main__":
    RAGSQLAgentApp()
Binary file added rag-sql-agent/assets/cover.png
Binary file added rag-sql-agent/assets/image1.png
Binary file added rag-sql-agent/assets/image2.png
Binary file added rag-sql-agent/assets/image3.png
Binary file added rag-sql-agent/assets/my_image.png
Binary file added rag-sql-agent/demo.mp4
36 changes: 36 additions & 0 deletions rag-sql-agent/llama_cloud_setup.py
@@ -0,0 +1,36 @@
from llama_index.indices.managed.llama_cloud import LlamaCloudIndex
import llama_index.core
import os
from dotenv import load_dotenv

class LlamaCloudSetup:
    def __init__(self):
        load_dotenv()
        self.llama_cloud_api_key = os.getenv('LLAMA_CLOUD_API_KEY')
        self.phoenix_api_key = os.getenv('PHOENIX_API_KEY')

        self.index = LlamaCloudIndex(
            name="city_pdfs_index",
            project_name="Default",
            organization_id="8c1b26ae-5c98-4873-bce9-28d0c67a3dce",
            api_key=self.llama_cloud_api_key,
        )
Comment on lines +12 to +17

🛠️ Refactor suggestion

Move hardcoded values to environment variables

The organization_id and other configuration parameters are hardcoded in the class. This makes the code less flexible and potentially poses a security risk.

 self.index = LlamaCloudIndex(
-    name="city_pdfs_index",
-    project_name="Default",
-    organization_id="8c1b26ae-5c98-4873-bce9-28d0c67a3dce",
+    name=os.getenv('LLAMA_CLOUD_INDEX_NAME', 'city_pdfs_index'),
+    project_name=os.getenv('LLAMA_CLOUD_PROJECT_NAME', 'Default'),
+    organization_id=os.getenv('LLAMA_CLOUD_ORGANIZATION_ID'),
     api_key=self.llama_cloud_api_key,
 )


        self._setup_observability()

    def _setup_observability(self):
        """Set up observability using Arize Phoenix."""
        os.environ["OTEL_EXPORTER_OTLP_HEADERS"] = f"api_key={self.phoenix_api_key}"
        llama_index.core.set_global_handler(
            "arize_phoenix", endpoint="https://llamatrace.com/v1/traces"
        )

    def get_query_engine(self):
        """Returns the query engine instance."""
        return self.index.as_query_engine()


# Usage example:
if __name__ == "__main__":
    llama_cloud_setup = LlamaCloudSetup()
    query_engine = llama_cloud_setup.get_query_engine()
155 changes: 155 additions & 0 deletions rag-sql-agent/query_agent.py
@@ -0,0 +1,155 @@
from typing import Dict, List, Any, Optional

from llama_index.core.tools import BaseTool
from llama_index.core.llms import ChatMessage
from llama_index.core.llms.llm import ToolSelection, LLM
from llama_index.core.workflow import (
    Workflow,
    Event,
    StartEvent,
    StopEvent,
    Context,
    step,
)
from llama_index.llms.openai import OpenAI


class InputEvent(Event):
    """Input event."""

class GatherToolsEvent(Event):
    """Gather Tools Event"""

    tool_calls: Any

class ToolCallEvent(Event):
    """Tool Call event"""

    tool_call: ToolSelection

class ToolCallEventResult(Event):
    """Tool call event result."""

    msg: ChatMessage

class RouterOutputAgentWorkflow(Workflow):
    """Custom router output agent workflow."""

    def __init__(self,
        tools: List[BaseTool],
        timeout: Optional[float] = 10.0,
        disable_validation: bool = False,
        verbose: bool = False,
        llm: Optional[LLM] = None,
        chat_history: Optional[List[ChatMessage]] = None,
    ):
        """Constructor."""

        super().__init__(timeout=timeout, disable_validation=disable_validation, verbose=verbose)

        self.tools: List[BaseTool] = tools
        self.tools_dict: Dict[str, BaseTool] = {tool.metadata.name: tool for tool in self.tools}
        self.llm: LLM = llm or OpenAI(temperature=0, model="gpt-3.5-turbo")
        self.chat_history: List[ChatMessage] = chat_history or []

    def reset(self) -> None:
        """Resets chat history."""

        self.chat_history = []

    @step()
    async def prepare_chat(self, ev: StartEvent) -> InputEvent:
        message = ev.get("message")
        if message is None:
            raise ValueError("'message' field is required.")

        # Add the user message to chat history
        self.chat_history.append(ChatMessage(role="user", content=message))
        return InputEvent()

    @step()
    async def chat(self, ev: InputEvent) -> GatherToolsEvent | StopEvent:
        """Appends msg to chat history, then gets tool calls."""

        # Send the chat history to the LLM with the tools included
        chat_res = await self.llm.achat_with_tools(
            self.tools,
            chat_history=self.chat_history,
            verbose=self._verbose,
            allow_parallel_tool_calls=True
        )
        tool_calls = self.llm.get_tool_calls_from_response(chat_res, error_on_no_tool_call=False)

        ai_message = chat_res.message
        self.chat_history.append(ai_message)
        if self._verbose:
            print(f"Chat message: {ai_message.content}")

        # No tool calls: return the chat message directly
        if not tool_calls:
            return StopEvent(result=ai_message.content)

        return GatherToolsEvent(tool_calls=tool_calls)

    @step(pass_context=True)
    async def dispatch_calls(self, ctx: Context, ev: GatherToolsEvent) -> ToolCallEvent:
        """Dispatches calls."""

        tool_calls = ev.tool_calls
        await ctx.set("num_tool_calls", len(tool_calls))

        # Trigger one tool call event per requested call
        for tool_call in tool_calls:
            ctx.send_event(ToolCallEvent(tool_call=tool_call))

        return None

    @step()
    async def call_tool(self, ev: ToolCallEvent) -> ToolCallEventResult:
        """Calls tool."""

        tool_call = ev.tool_call

        # Get the tool call ID
        id_ = tool_call.tool_id

        if self._verbose:
            print(f"Calling function {tool_call.tool_name} with msg {tool_call.tool_kwargs}")

        # Call the tool and wrap its output in a chat message
        tool = self.tools_dict[tool_call.tool_name]
        output = await tool.acall(**tool_call.tool_kwargs)
        msg = ChatMessage(
            name=tool_call.tool_name,
            content=str(output),
            role="tool",
            additional_kwargs={
                "tool_call_id": id_,
                "name": tool_call.tool_name
            }
        )

        return ToolCallEventResult(msg=msg)

    @step(pass_context=True)
    async def gather(self, ctx: Context, ev: ToolCallEventResult) -> InputEvent | None:
        """Gathers tool call results."""
        # Wait for all tool call events to finish
        tool_events = ctx.collect_events(ev, [ToolCallEventResult] * await ctx.get("num_tool_calls"))
        if not tool_events:
            return None

        for tool_event in tool_events:
            # Append tool call chat messages to history
            self.chat_history.append(tool_event.msg)

        # After all tool calls finish, pass an input event back to restart the agent loop
        return InputEvent()
7 changes: 7 additions & 0 deletions rag-sql-agent/requirements.txt
@@ -0,0 +1,7 @@
streamlit
llama-index
sqlalchemy
pyvis
python-dotenv
arize-phoenix
llama-index-callbacks-arize-phoenix