Nishan Jain - Technical Writer Task (RAG and Text-to-SQL) #89

Open · wants to merge 1 commit into base: main
86 changes: 86 additions & 0 deletions Task1-sql-router-assessment/Readme.md
# SQL-RAG LlamaIndex Agent

## Overview
This application is an AI-powered query agent, built on LlamaIndex and OpenAI's GPT models, that retrieves data from both a SQL database and a file-based semantic index. An agent router decides which source should answer each user query.

## Features
- **Hybrid Querying**: Supports querying structured SQL databases and unstructured text-based data.
- **Agent Router**: Determines whether to query the database or use a semantic search engine.
- **LlamaIndex Integration**: Uses LlamaCloudIndex for text-based queries.
- **OpenAI-powered LLM**: Uses GPT models for query interpretation and decision-making.
- **Arize Phoenix Logging**: Enables trace logging for query execution.
- **Streamlit UI**: Provides an interactive web interface for users to enter queries and view responses.

## Installation

### Prerequisites
- Python 3.8+
- OpenAI API Key
- Llama Cloud API Key
- Phoenix API Key

### Steps
```sh
# Clone the repository
git clone <repo-url>
cd sql-rag-streamlit

# Install dependencies
pip install -r requirements.txt

# Set up environment variables
export OPENAI_API_KEY="your_openai_api_key"
export LLAMA_CLOUD_API_KEY="your_llama_cloud_api_key"
export PHOENIX_API_KEY="your_phoenix_api_key"
export ORGANIZATION_ID="your_organization_id"

# Run the application
streamlit run sql_rag.py
```

## Components

### 1. **Database Setup**
- Uses an in-memory SQLite database with a `city_stats` table containing city population and state information.
- Populates the table with sample data.
- Uses SQLAlchemy to manage the database connection.
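The schema and seed data described above can be condensed into a self-contained SQLAlchemy sketch (same table definition as `sql_rag.py`, with two of the sample rows for brevity):

```python
from sqlalchemy import create_engine, MetaData, Table, Column, String, Integer, insert, select

# In-memory SQLite database, as in setup_database()
engine = create_engine("sqlite:///:memory:", future=True)
metadata = MetaData()
city_stats = Table(
    "city_stats",
    metadata,
    Column("city_name", String(16), primary_key=True),
    Column("population", Integer),
    Column("state", String(16), nullable=False),
)
metadata.create_all(engine)

# Seed a subset of the sample rows
with engine.begin() as conn:
    conn.execute(insert(city_stats), [
        {"city_name": "New York City", "population": 8336000, "state": "New York"},
        {"city_name": "Chicago", "population": 2665000, "state": "Illinois"},
    ])

# Verify the data is queryable
with engine.connect() as conn:
    row = conn.execute(
        select(city_stats.c.population).where(city_stats.c.city_name == "New York City")
    ).scalar_one()
print(row)  # 8336000
```

The `NLSQLTableQueryEngine` later translates natural-language questions into SQL against this same table.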

### 2. **Query Engines**
- **SQL Query Engine**: Executes SQL queries against the `city_stats` table.
- **LlamaCloud Query Engine**: Handles semantic queries on text data.
- These tools are managed by `QueryEngineTool` instances.

### 3. **Agent Router Workflow**
- The `RouterOutputAgentWorkflow` (from `tool_workflow.py`) decides which query engine to use.
- It interacts with OpenAI's GPT model to analyze queries and select the appropriate tool.
- Implements multiple workflow steps such as preparing chat input, selecting tools, and dispatching calls.
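The dispatch and gather steps form a fan-out/fan-in pattern: each tool call is fired as its own event, and the workflow resumes only after every result arrives. A toy `asyncio` sketch (not the Workflow API; the tool names and arguments are illustrative) captures the idea:

```python
import asyncio

async def call_tool(name: str, arg: str) -> str:
    # Stand-in for QueryEngineTool.acall(); returns a fake tool result.
    await asyncio.sleep(0)
    return f"{name}({arg})"

async def dispatch_and_gather(tool_calls):
    # Fan out: one task per selected tool call (mirrors dispatch_calls).
    tasks = [asyncio.create_task(call_tool(n, a)) for n, a in tool_calls]
    # Fan in: wait for all results before resuming the chat loop (mirrors gather).
    return await asyncio.gather(*tasks)

results = asyncio.run(dispatch_and_gather([
    ("sql_tool", "population"),
    ("llama_cloud_tool", "history"),
]))
print(results)  # ['sql_tool(population)', 'llama_cloud_tool(history)']
```

In the real workflow the fan-in is done with `ctx.collect_events`, which buffers `ToolCallEventResult` events until the expected count is reached.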

### 4. **Streamlit UI**
- A simple interface where users enter queries.
- Displays the response after executing the query via the selected engine.

## Usage
1. Open the Streamlit UI.
2. Enter a query, e.g., "What is the population of New York City?"
3. The agent determines whether to query the SQL database or use the LlamaCloudIndex.
4. The response is displayed in the UI.
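The routing decision in step 3 is made by the GPT model via tool selection; a hypothetical keyword-based stand-in illustrates the kind of choice it makes (this is not how the agent actually routes, just a sketch of the behavior):

```python
def route(query: str) -> str:
    # Toy heuristic: structured "fact" queries go to SQL, everything
    # else to semantic search. The real agent delegates this decision
    # to the LLM through achat_with_tools.
    structured_terms = ("population", "state", "how many")
    if any(term in query.lower() for term in structured_terms):
        return "sql_tool"
    return "llama_cloud_tool"

print(route("What is the population of New York City?"))   # sql_tool
print(route("Describe the cultural history of Seattle"))   # llama_cloud_tool
```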

## Dependencies
The project relies on the following Python packages:
- `llama_index`
- `arize-phoenix`
- `nest-asyncio`
- `sqlalchemy`
- `streamlit`
- `llama-index-callbacks-arize-phoenix`

## Typefully

https://typefully.com/t/Vkxs861

## Contributing
Contributions are welcome! Feel free to fork the repo and submit a pull request.

## License
This project is licensed under the MIT License.

6 changes: 6 additions & 0 deletions Task1-sql-router-assessment/requirements.txt
llama_index==0.12.24
arize-phoenix==8.13.1
nest-asyncio==1.6.0
sqlalchemy==2.0.39
streamlit==1.43.2
llama-index-callbacks-arize-phoenix==0.4.0
142 changes: 142 additions & 0 deletions Task1-sql-router-assessment/sql_rag.py
import os
import nest_asyncio
import asyncio
import streamlit as st
from llama_index.core import SQLDatabase, Settings, set_global_handler
from llama_index.llms.openai import OpenAI
from llama_index.core.query_engine import NLSQLTableQueryEngine
from llama_index.indices.managed.llama_cloud import LlamaCloudIndex
from llama_index.core.tools import QueryEngineTool
from tool_workflow import RouterOutputAgentWorkflow
from sqlalchemy import create_engine, MetaData, Table, Column, String, Integer, insert


def get_api_keys():
"""Retrieve API keys from environment variables."""
keys = {}
required_keys = ["PHOENIX_API_KEY", "OPENAI_API_KEY", "LLAMA_CLOUD_API_KEY", "ORGANIZATION_ID"]
for key in required_keys:
keys[key] = os.getenv(key)
if keys[key] is None:
raise ValueError(f"Missing required environment variable: {key}")
return keys


def setup_logging(api_key):
"""Setup Arize Phoenix logging for tracing."""
os.environ["OTEL_EXPORTER_OTLP_HEADERS"] = f"api_key={api_key}"
set_global_handler("arize_phoenix", endpoint="https://llamatrace.com/v1/traces")


def initialize_llama_index(organization_id):
"""Initialize LlamaCloudIndex and return a query engine."""
try:
index = LlamaCloudIndex(
name="thirsty-finch-2025-03-09",
project_name="Default",
organization_id=organization_id
)
return index.as_query_engine()
except Exception as e:
raise RuntimeError(f"Failed to initialize LlamaCloudIndex: {e}")


def setup_database():
"""Setup SQLite in-memory database and populate it with city data."""
try:
engine = create_engine("sqlite:///:memory:", future=True)
metadata_obj = MetaData()
city_stats_table = Table(
"city_stats",
metadata_obj,
Column("city_name", String(16), primary_key=True),
Column("population", Integer),
Column("state", String(16), nullable=False),
)
metadata_obj.create_all(engine)

rows = [
{"city_name": "New York City", "population": 8336000, "state": "New York"},
{"city_name": "Los Angeles", "population": 3822000, "state": "California"},
{"city_name": "Chicago", "population": 2665000, "state": "Illinois"},
{"city_name": "Houston", "population": 2303000, "state": "Texas"},
{"city_name": "Miami", "population": 449514, "state": "Florida"},
{"city_name": "Seattle", "population": 749256, "state": "Washington"},
]
with engine.begin() as connection:
for row in rows:
connection.execute(insert(city_stats_table).values(**row))

return engine, city_stats_table
except Exception as e:
raise RuntimeError(f"Failed to set up database: {e}")


def setup_query_engines(engine):
"""Setup SQL query engine for querying city data."""
try:
sql_database = SQLDatabase(engine, include_tables=["city_stats"])
sql_query_engine = NLSQLTableQueryEngine(sql_database=sql_database, tables=["city_stats"])
return QueryEngineTool.from_defaults(
query_engine=sql_query_engine,
description="Useful for querying city population and state information in the USA.",
name="sql_tool"
)
except Exception as e:
raise RuntimeError(f"Failed to setup SQL query engine: {e}")


def initialize_workflow(sql_tool, llama_cloud_tool):
"""Initialize the RouterOutputAgentWorkflow with query tools."""
try:
return RouterOutputAgentWorkflow(tools=[sql_tool, llama_cloud_tool], verbose=True, timeout=120)
except Exception as e:
raise RuntimeError(f"Failed to initialize workflow: {e}")


async def run_query(query, workflow):
"""Run a query asynchronously using the workflow."""
try:
return await workflow.run(message=query)
except Exception as e:
raise RuntimeError(f"Query execution failed: {e}")


def main():
"""Main function to run the Streamlit app."""
try:
# Retrieve API keys
api_keys = get_api_keys()
setup_logging(api_keys["PHOENIX_API_KEY"])
nest_asyncio.apply()

# Set default LLM model
Settings.llm = OpenAI("gpt-4o-mini")

# Initialize LlamaCloud Query Engine
llama_cloud_query_engine = initialize_llama_index(api_keys["ORGANIZATION_ID"])
engine, _ = setup_database()
sql_tool = setup_query_engines(engine)
llama_cloud_tool = QueryEngineTool.from_defaults(
query_engine=llama_cloud_query_engine,
description="Useful for answering semantic questions about US cities.",
name="llama_cloud_tool"
)
wf = initialize_workflow(sql_tool, llama_cloud_tool)

# Streamlit UI
st.title("Agentic RAG Integrating SQL Databases & Semantic Search")
query = st.text_input("Enter your query:")
if st.button("Submit") and query:
try:
response = asyncio.run(run_query(query, wf))
st.write("### Response:")
st.write(str(response))
except Exception as e:
st.error(f"An error occurred: {e}")
except Exception as e:
st.error(f"Application failed to start: {e}")


if __name__ == "__main__":
main()
151 changes: 151 additions & 0 deletions Task1-sql-router-assessment/tool_workflow.py
from typing import Dict, List, Any, Optional

from llama_index.core.tools import BaseTool
from llama_index.core.llms import ChatMessage
from llama_index.core.llms.llm import ToolSelection, LLM
from llama_index.core.workflow import (
Workflow,
Event,
StartEvent,
StopEvent,
step,
)
from llama_index.core.workflow import Context
from llama_index.llms.openai import OpenAI

class InputEvent(Event):
"""Input event."""


class GatherToolsEvent(Event):
"""Gather Tools Event"""

tool_calls: Any


class ToolCallEvent(Event):
"""Tool Call event"""

tool_call: ToolSelection


class ToolCallEventResult(Event):
"""Tool call event result."""

msg: ChatMessage


class RouterOutputAgentWorkflow(Workflow):
"""Custom router output agent workflow."""

def __init__(self,
tools: List[BaseTool],
timeout: Optional[float] = 10.0,
disable_validation: bool = False,
verbose: bool = False,
llm: Optional[LLM] = None,
chat_history: Optional[List[ChatMessage]] = None,
):
"""Constructor."""

super().__init__(timeout=timeout, disable_validation=disable_validation, verbose=verbose)

self.tools: List[BaseTool] = tools
self.tools_dict: Optional[Dict[str, BaseTool]] = {tool.metadata.name: tool for tool in self.tools}
self.llm: LLM = llm or OpenAI(temperature=0, model="gpt-3.5-turbo")
self.chat_history: List[ChatMessage] = chat_history or []

def reset(self) -> None:
"""Resets Chat History"""

self.chat_history = []

@step()
async def prepare_chat(self, ev: StartEvent) -> InputEvent:
message = ev.get("message")
if message is None:
raise ValueError("'message' field is required.")

# add msg to chat history
chat_history = self.chat_history
chat_history.append(ChatMessage(role="user", content=message))
return InputEvent()

@step()
async def chat(self, ev: InputEvent) -> GatherToolsEvent | StopEvent:
"""Appends msg to chat history, then gets tool calls."""

# Put msg into LLM with tools included
chat_res = await self.llm.achat_with_tools(
self.tools,
chat_history=self.chat_history,
verbose=self._verbose,
allow_parallel_tool_calls=True
)
tool_calls = self.llm.get_tool_calls_from_response(chat_res, error_on_no_tool_call=False)

ai_message = chat_res.message
self.chat_history.append(ai_message)
if self._verbose:
print(f"Chat message: {ai_message.content}")

# no tool calls, return chat message.
if not tool_calls:
return StopEvent(result=ai_message.content)

return GatherToolsEvent(tool_calls=tool_calls)

@step(pass_context=True)
async def dispatch_calls(self, ctx: Context, ev: GatherToolsEvent) -> ToolCallEvent:
"""Dispatches calls."""

tool_calls = ev.tool_calls
await ctx.set("num_tool_calls", len(tool_calls))

# trigger tool call events
for tool_call in tool_calls:
ctx.send_event(ToolCallEvent(tool_call=tool_call))

return None

Comment on lines +98 to +110
⚠️ Potential issue

Return type annotation mismatch.

The method signature indicates `-> ToolCallEvent`, but the function returns `None`. This mismatch can cause confusion or break type-checking. Either adjust the return type or return a `ToolCallEvent` (or a list of them).

-async def dispatch_calls(self, ctx: Context, ev: GatherToolsEvent) -> ToolCallEvent:
+async def dispatch_calls(self, ctx: Context, ev: GatherToolsEvent) -> None:

@step()
async def call_tool(self, ev: ToolCallEvent) -> ToolCallEventResult:
"""Calls tool."""

tool_call = ev.tool_call

# get tool ID and function call
id_ = tool_call.tool_id

if self._verbose:
print(f"Calling function {tool_call.tool_name} with msg {tool_call.tool_kwargs}")

# call function and put result into a chat message
tool = self.tools_dict[tool_call.tool_name]
output = await tool.acall(**tool_call.tool_kwargs)
msg = ChatMessage(
name=tool_call.tool_name,
content=str(output),
role="tool",
additional_kwargs={
"tool_call_id": id_,
"name": tool_call.tool_name
}
)

return ToolCallEventResult(msg=msg)

Comment on lines +111 to +137
🛠️ Refactor suggestion

Potential KeyError for missing tools.

`tool = self.tools_dict[tool_call.tool_name]` will raise a `KeyError` if no matching tool exists. Consider handling this with a graceful error message or fallback logic.

+if tool_call.tool_name not in self.tools_dict:
+    raise RuntimeError(f"Tool '{tool_call.tool_name}' not found.")
 tool = self.tools_dict[tool_call.tool_name]

@step(pass_context=True)
async def gather(self, ctx: Context, ev: ToolCallEventResult) -> InputEvent | None:
"""Gathers tool calls."""
# wait for all tool call events to finish.
tool_events = ctx.collect_events(ev, [ToolCallEventResult] * await ctx.get("num_tool_calls"))
if not tool_events:
return None

for tool_event in tool_events:
# append tool call chat messages to history
self.chat_history.append(tool_event.msg)

# after all tool calls finish, pass an input event back to restart the agent loop
return InputEvent()