-
Notifications
You must be signed in to change notification settings - Fork 1.4k
Nishan Jain - Technical Writer Task (RAG and Text-to-SQL) #89
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
# SQL-RAG LlamaIndex Agent | ||
|
||
## Overview | ||
This LlamaIndex-based application is an AI-powered query agent that retrieves data from both a database and a file-based system using LlamaIndex and OpenAI's GPT models. It utilizes an agent router to determine the appropriate source for answering user queries. | ||
|
||
## Features | ||
- **Hybrid Querying**: Supports querying structured SQL databases and unstructured text-based data. | ||
- **Agent Router**: Determines whether to query the database or use a semantic search engine. | ||
- **LlamaIndex Integration**: Uses LlamaCloudIndex for text-based queries. | ||
- **OpenAI-powered LLM**: Uses GPT models for query interpretation and decision-making. | ||
- **Arize Phoenix Logging**: Enables trace logging for query execution. | ||
- **Streamlit UI**: Provides an interactive web interface for users to enter queries and view responses. | ||
|
||
## Installation | ||
|
||
### Prerequisites | ||
- Python 3.8+ | ||
- OpenAI API Key | ||
- Llama Cloud API Key | ||
- Phoenix API Key | ||
|
||
### Steps | ||
```sh | ||
# Clone the repository | ||
git clone <repo-url> | ||
cd sql-rag-streamlit | ||
|
||
# Install dependencies | ||
pip install -r requirements.txt | ||
|
||
# Set up environment variables | ||
export OPENAI_API_KEY="your_openai_api_key" | ||
export LLAMA_CLOUD_API_KEY="your_llama_cloud_api_key" | ||
export PHOENIX_API_KEY="your_phoenix_api_key" | ||
export ORGANIZATION_ID="your_organization_id" | ||
|
||
# Run the application | ||
streamlit run sql_rag.py | ||
``` | ||
|
||
## Components | ||
|
||
### 1. **Database Setup** | ||
- Uses an in-memory SQLite database with a `city_stats` table containing city population and state information. | ||
- Populates the table with sample data. | ||
- Uses SQLAlchemy to manage the database connection. | ||
|
||
### 2. **Query Engines** | ||
- **SQL Query Engine**: Executes SQL queries against the `city_stats` table. | ||
- **LlamaCloud Query Engine**: Handles semantic queries on text data. | ||
- These tools are managed by `QueryEngineTool` instances. | ||
|
||
### 3. **Agent Router Workflow** | ||
- The `RouterOutputAgentWorkflow` (from `tool_workflow.py`) decides which query engine to use. | ||
- It interacts with OpenAI's GPT model to analyze queries and select the appropriate tool. | ||
- Implements multiple workflow steps such as preparing chat input, selecting tools, and dispatching calls. | ||
|
||
### 4. **Streamlit UI** | ||
- A simple interface where users enter queries. | ||
- Displays the response after executing the query via the selected engine. | ||
|
||
## Usage | ||
1. Open the Streamlit UI. | ||
2. Enter a query, e.g., "What is the population of New York City?" | ||
3. The agent determines whether to query the SQL database or use the LlamaCloudIndex. | ||
4. The response is displayed in the UI. | ||
|
||
## Dependencies | ||
The project relies on the following Python packages: | ||
- `llama_index` | ||
- `arize-phoenix` | ||
- `nest-asyncio` | ||
- `sqlalchemy` | ||
- `streamlit` | ||
- `llama-index-callbacks-arize-phoenix` | ||
|
||
## Typefully | ||
|
||
https://typefully.com/t/Vkxs861 | ||
|
||
## Contributing | ||
Contributions are welcome! Feel free to fork the repo and submit a pull request. | ||
|
||
## License | ||
This project is licensed under the MIT License. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
llama_index==0.12.24 | ||
arize-phoenix==8.13.1 | ||
nest-asyncio==1.6.0 | ||
sqlalchemy== 2.0.39 | ||
streamlit==1.43.2 | ||
llama-index-callbacks-arize-phoenix==0.4.0 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
import os | ||
import nest_asyncio | ||
import asyncio | ||
import streamlit as st | ||
from llama_index.core import SQLDatabase, Settings, set_global_handler | ||
from llama_index.llms.openai import OpenAI | ||
from llama_index.core.query_engine import NLSQLTableQueryEngine | ||
from llama_index.indices.managed.llama_cloud import LlamaCloudIndex | ||
from llama_index.core.tools import QueryEngineTool | ||
from tool_workflow import RouterOutputAgentWorkflow | ||
from sqlalchemy import create_engine, MetaData, Table, Column, String, Integer, insert | ||
|
||
|
||
def get_api_keys(): | ||
"""Retrieve API keys from environment variables.""" | ||
keys = {} | ||
required_keys = ["PHOENIX_API_KEY", "OPENAI_API_KEY", "LLAMA_CLOUD_API_KEY", "ORGANIZATION_ID"] | ||
for key in required_keys: | ||
keys[key] = os.getenv(key) | ||
if keys[key] is None: | ||
raise ValueError(f"Missing required environment variable: {key}") | ||
return keys | ||
|
||
|
||
def setup_logging(api_key): | ||
"""Setup Arize Phoenix logging for tracing.""" | ||
os.environ["OTEL_EXPORTER_OTLP_HEADERS"] = f"api_key={api_key}" | ||
set_global_handler("arize_phoenix", endpoint="https://llamatrace.com/v1/traces") | ||
|
||
|
||
def initialize_llama_index(organization_id): | ||
"""Initialize LlamaCloudIndex and return a query engine.""" | ||
try: | ||
index = LlamaCloudIndex( | ||
name="thirsty-finch-2025-03-09", | ||
project_name="Default", | ||
organization_id=organization_id | ||
) | ||
return index.as_query_engine() | ||
except Exception as e: | ||
raise RuntimeError(f"Failed to initialize LlamaCloudIndex: {e}") | ||
|
||
|
||
def setup_database(): | ||
"""Setup SQLite in-memory database and populate it with city data.""" | ||
try: | ||
engine = create_engine("sqlite:///:memory:", future=True) | ||
metadata_obj = MetaData() | ||
city_stats_table = Table( | ||
"city_stats", | ||
metadata_obj, | ||
Column("city_name", String(16), primary_key=True), | ||
Column("population", Integer), | ||
Column("state", String(16), nullable=False), | ||
) | ||
metadata_obj.create_all(engine) | ||
|
||
rows = [ | ||
{"city_name": "New York City", "population": 8336000, "state": "New York"}, | ||
{"city_name": "Los Angeles", "population": 3822000, "state": "California"}, | ||
{"city_name": "Chicago", "population": 2665000, "state": "Illinois"}, | ||
{"city_name": "Houston", "population": 2303000, "state": "Texas"}, | ||
{"city_name": "Miami", "population": 449514, "state": "Florida"}, | ||
{"city_name": "Seattle", "population": 749256, "state": "Washington"}, | ||
] | ||
with engine.begin() as connection: | ||
for row in rows: | ||
connection.execute(insert(city_stats_table).values(**row)) | ||
|
||
return engine, city_stats_table | ||
except Exception as e: | ||
raise RuntimeError(f"Failed to set up database: {e}") | ||
|
||
|
||
def setup_query_engines(engine): | ||
"""Setup SQL query engine for querying city data.""" | ||
try: | ||
sql_database = SQLDatabase(engine, include_tables=["city_stats"]) | ||
sql_query_engine = NLSQLTableQueryEngine(sql_database=sql_database, tables=["city_stats"]) | ||
return QueryEngineTool.from_defaults( | ||
query_engine=sql_query_engine, | ||
description="Useful for querying city population and state information in the USA.", | ||
name="sql_tool" | ||
) | ||
except Exception as e: | ||
raise RuntimeError(f"Failed to setup SQL query engine: {e}") | ||
|
||
|
||
def initialize_workflow(sql_tool, llama_cloud_tool): | ||
"""Initialize the RouterOutputAgentWorkflow with query tools.""" | ||
try: | ||
return RouterOutputAgentWorkflow(tools=[sql_tool, llama_cloud_tool], verbose=True, timeout=120) | ||
except Exception as e: | ||
raise RuntimeError(f"Failed to initialize workflow: {e}") | ||
|
||
|
||
async def run_query(query, workflow): | ||
"""Run a query asynchronously using the workflow.""" | ||
try: | ||
return await workflow.run(message=query) | ||
except Exception as e: | ||
raise RuntimeError(f"Query execution failed: {e}") | ||
|
||
|
||
def main(): | ||
"""Main function to run the Streamlit app.""" | ||
try: | ||
# Retrieve API keys | ||
api_keys = get_api_keys() | ||
setup_logging(api_keys["PHOENIX_API_KEY"]) | ||
nest_asyncio.apply() | ||
|
||
# Set default LLM model | ||
Settings.llm = OpenAI("gpt-4o-mini") | ||
|
||
# Initialize LlamaCloud Query Engine | ||
llama_cloud_query_engine = initialize_llama_index(api_keys["ORGANIZATION_ID"]) | ||
engine, _ = setup_database() | ||
sql_tool = setup_query_engines(engine) | ||
llama_cloud_tool = QueryEngineTool.from_defaults( | ||
query_engine=llama_cloud_query_engine, | ||
description="Useful for answering semantic questions about US cities.", | ||
name="llama_cloud_tool" | ||
) | ||
wf = initialize_workflow(sql_tool, llama_cloud_tool) | ||
|
||
# Streamlit UI | ||
st.title("Agentic RAG Integrating SQL Databases & Semantic Search") | ||
query = st.text_input("Enter your query:") | ||
if st.button("Submit") and query: | ||
try: | ||
response = asyncio.run(run_query(query, wf)) | ||
st.write("### Response:") | ||
st.write(str(response)) | ||
except Exception as e: | ||
st.error(f"An error occurred: {e}") | ||
except Exception as e: | ||
st.error(f"Application failed to start: {e}") | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,151 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
from typing import Dict, List, Any, Optional | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
from llama_index.core.tools import BaseTool | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
from llama_index.core.llms import ChatMessage | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
from llama_index.core.llms.llm import ToolSelection, LLM | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
from llama_index.core.workflow import ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Workflow, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Event, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
StartEvent, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
StopEvent, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
step, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
from llama_index.core.workflow import Context | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
from llama_index.llms.openai import OpenAI | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
class InputEvent(Event): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"""Input event.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
class GatherToolsEvent(Event): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"""Gather Tools Event""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
tool_calls: Any | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
class ToolCallEvent(Event): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"""Tool Call event""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
tool_call: ToolSelection | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
class ToolCallEventResult(Event): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"""Tool call event result.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
msg: ChatMessage | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
class RouterOutputAgentWorkflow(Workflow): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"""Custom router output agent workflow.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def __init__(self, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
tools: List[BaseTool], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
timeout: Optional[float] = 10.0, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
disable_validation: bool = False, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
verbose: bool = False, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
llm: Optional[LLM] = None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
chat_history: Optional[List[ChatMessage]] = None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"""Constructor.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
super().__init__(timeout=timeout, disable_validation=disable_validation, verbose=verbose) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self.tools: List[BaseTool] = tools | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self.tools_dict: Optional[Dict[str, BaseTool]] = {tool.metadata.name: tool for tool in self.tools} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self.llm: LLM = llm or OpenAI(temperature=0, model="gpt-3.5-turbo") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self.chat_history: List[ChatMessage] = chat_history or [] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def reset(self) -> None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"""Resets Chat History""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self.chat_history = [] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
@step() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
async def prepare_chat(self, ev: StartEvent) -> InputEvent: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
message = ev.get("message") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if message is None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
raise ValueError("'message' field is required.") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# add msg to chat history | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
chat_history = self.chat_history | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
chat_history.append(ChatMessage(role="user", content=message)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return InputEvent() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
@step() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
async def chat(self, ev: InputEvent) -> GatherToolsEvent | StopEvent: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"""Appends msg to chat history, then gets tool calls.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Put msg into LLM with tools included | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
chat_res = await self.llm.achat_with_tools( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self.tools, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
chat_history=self.chat_history, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
verbose=self._verbose, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
allow_parallel_tool_calls=True | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
tool_calls = self.llm.get_tool_calls_from_response(chat_res, error_on_no_tool_call=False) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
ai_message = chat_res.message | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self.chat_history.append(ai_message) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if self._verbose: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
print(f"Chat message: {ai_message.content}") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# no tool calls, return chat message. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if not tool_calls: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return StopEvent(result=ai_message.content) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return GatherToolsEvent(tool_calls=tool_calls) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
@step(pass_context=True) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
async def dispatch_calls(self, ctx: Context, ev: GatherToolsEvent) -> ToolCallEvent: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"""Dispatches calls.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
tool_calls = ev.tool_calls | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
await ctx.set("num_tool_calls", len(tool_calls)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# trigger tool call events | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
for tool_call in tool_calls: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
ctx.send_event(ToolCallEvent(tool_call=tool_call)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
@step() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
async def call_tool(self, ev: ToolCallEvent) -> ToolCallEventResult: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"""Calls tool.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
tool_call = ev.tool_call | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# get tool ID and function call | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
id_ = tool_call.tool_id | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if self._verbose: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
print(f"Calling function {tool_call.tool_name} with msg {tool_call.tool_kwargs}") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# call function and put result into a chat message | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
tool = self.tools_dict[tool_call.tool_name] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
output = await tool.acall(**tool_call.tool_kwargs) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
msg = ChatMessage( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
name=tool_call.tool_name, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
content=str(output), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
role="tool", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
additional_kwargs={ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"tool_call_id": id_, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"name": tool_call.tool_name | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return ToolCallEventResult(msg=msg) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+111
to
+137
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Potential KeyError for missing tools.
+if tool_call.tool_name not in self.tools_dict:
+ raise RuntimeError(f"Tool '{tool_call.tool_name}' not found.")
tool = self.tools_dict[tool_call.tool_name] 📝 Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
@step(pass_context=True) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
async def gather(self, ctx: Context, ev: ToolCallEventResult) -> StopEvent | None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"""Gathers tool calls.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# wait for all tool call events to finish. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
tool_events = ctx.collect_events(ev, [ToolCallEventResult] * await ctx.get("num_tool_calls")) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if not tool_events: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
for tool_event in tool_events: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# append tool call chat messages to history | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self.chat_history.append(tool_event.msg) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# # after all tool calls finish, pass input event back, restart agent loop | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return InputEvent() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Return type annotation mismatch.
The method signature indicates
-> ToolCallEvent
, but the function returnsNone
. This mismatch can cause confusion or break type-checking. Either adjust the return type or return aToolCallEvent
(or a list of them).📝 Committable suggestion