Skip to content

Commit 370d082

Browse files
Text2SQL Limit Clause & Disambiguation Fix (#167)
1 parent a15b88c commit 370d082

File tree

6 files changed

+39
-77
lines changed

6 files changed

+39
-77
lines changed

text_2_sql/autogen/pyproject.toml

+3-3
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@ authors = [
99
requires-python = ">=3.11"
1010
dependencies = [
1111
"aiostream>=0.6.4",
12-
"autogen-agentchat==0.4.5",
13-
"autogen-core==0.4.5",
14-
"autogen-ext[azure,openai]==0.4.5",
12+
"autogen-agentchat==0.4.7",
13+
"autogen-core==0.4.7",
14+
"autogen-ext[azure,openai]==0.4.7",
1515
"grpcio>=1.68.1",
1616
"pyyaml>=6.0.2",
1717
"text_2_sql_core",

text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py

+1-8
Original file line numberDiff line numberDiff line change
@@ -83,14 +83,7 @@ def termination_condition(self):
8383
termination = (
8484
SourceMatchTermination("answer_agent")
8585
| SourceMatchTermination("answer_with_follow_up_suggestions_agent")
86-
# | TextMentionTermination(
87-
# "[]",
88-
# sources=["user_message_rewrite_agent"],
89-
# )
90-
| TextMentionTermination(
91-
"contains_disambiguation_requests",
92-
sources=["parallel_query_solving_agent"],
93-
)
86+
| TextMentionTermination("contains_disambiguation_requests")
9487
| MaxMessageTermination(5)
9588
)
9689
return termination

text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py

+2-14
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from aiostream import stream
1818
from json import JSONDecodeError
1919
import re
20-
import os
2120
from pydantic import BaseModel, Field
2221

2322

@@ -226,23 +225,12 @@ async def consume_inner_messages_from_agentic_flow(
226225
# Create an instance of the InnerAutoGenText2Sql class
227226
inner_autogen_text_2_sql = InnerAutoGenText2Sql(**self.kwargs)
228227

229-
# Add database connection info to injected parameters
230-
query_params = injected_parameters.copy() if injected_parameters else {}
231-
if "Text2Sql__Tsql__ConnectionString" in os.environ:
232-
query_params["database_connection_string"] = os.environ[
233-
"Text2Sql__Tsql__ConnectionString"
234-
]
235-
if "Text2Sql__Tsql__Database" in os.environ:
236-
query_params["database_name"] = os.environ[
237-
"Text2Sql__Tsql__Database"
238-
]
239-
240228
# Launch tasks for each sub-query
241229
inner_solving_generators.append(
242230
consume_inner_messages_from_agentic_flow(
243231
inner_autogen_text_2_sql.process_user_message(
244232
user_message=parallel_message,
245-
injected_parameters=query_params,
233+
injected_parameters=injected_parameters,
246234
database_results=filtered_parallel_messages.database_results,
247235
),
248236
parallel_message,
@@ -294,7 +282,7 @@ async def consume_inner_messages_from_agentic_flow(
294282
),
295283
)
296284

297-
break
285+
return
298286

299287
# Final response
300288
yield Response(

text_2_sql/autogen/src/autogen_text_2_sql/inner_autogen_text_2_sql.py

+7-37
Original file line numberDiff line numberDiff line change
@@ -44,29 +44,6 @@ def __init__(self, **kwargs: dict):
4444
self.kwargs = kwargs
4545
self.set_mode()
4646

47-
# Store original environment variables
48-
self.original_db_conn = os.environ.get("Text2Sql__Tsql__ConnectionString")
49-
self.original_db_name = os.environ.get("Text2Sql__Tsql__Database")
50-
51-
def _update_environment(self, injected_parameters: dict = None):
52-
"""Update environment variables with injected parameters."""
53-
if injected_parameters:
54-
if "database_connection_string" in injected_parameters:
55-
os.environ["Text2Sql__Tsql__ConnectionString"] = injected_parameters[
56-
"database_connection_string"
57-
]
58-
if "database_name" in injected_parameters:
59-
os.environ["Text2Sql__Tsql__Database"] = injected_parameters[
60-
"database_name"
61-
]
62-
63-
def _restore_environment(self):
64-
"""Restore original environment variables."""
65-
if self.original_db_conn:
66-
os.environ["Text2Sql__Tsql__ConnectionString"] = self.original_db_conn
67-
if self.original_db_name:
68-
os.environ["Text2Sql__Tsql__Database"] = self.original_db_name
69-
7047
def set_mode(self):
7148
"""Set the mode of the plugin based on the environment variables."""
7249
self.pre_run_query_cache = (
@@ -195,19 +172,12 @@ def process_user_message(
195172
"""
196173
logging.info("Processing question: %s", user_message)
197174

198-
# Update environment with injected parameters
199-
self._update_environment(injected_parameters)
200-
201-
try:
202-
agent_input = {
203-
"user_message": user_message,
204-
"injected_parameters": injected_parameters,
205-
}
175+
agent_input = {
176+
"user_message": user_message,
177+
"injected_parameters": injected_parameters,
178+
}
206179

207-
if database_results:
208-
agent_input["database_results"] = database_results
180+
if database_results:
181+
agent_input["database_results"] = database_results
209182

210-
return self.agentic_flow.run_stream(task=json.dumps(agent_input))
211-
finally:
212-
# Restore original environment
213-
self._restore_environment()
183+
return self.agentic_flow.run_stream(task=json.dumps(agent_input))

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py

+14-3
Original file line numberDiff line numberDiff line change
@@ -311,11 +311,22 @@ def handle_node(node):
311311
current_limit = parsed_query.args.get("limit")
312312
logging.debug("Current Limit: %s", current_limit)
313313

314-
if current_limit is None or current_limit.value > self.row_limit:
314+
# More defensive check to handle different structures
315+
should_add_limit = True
316+
if current_limit is not None:
317+
try:
318+
if hasattr(current_limit, "expression") and hasattr(
319+
current_limit.expression, "value"
320+
):
321+
if current_limit.expression.value <= self.row_limit:
322+
should_add_limit = False
323+
except AttributeError:
324+
logging.warning("Unexpected limit structure: %s", current_limit)
325+
326+
if should_add_limit:
315327
# Create a new LIMIT expression
316328
limit_expr = Limit(expression=Literal.number(self.row_limit))
317-
318-
# Attach it to the query by setting it on the SELECT expression
329+
# Attach it to the query
319330
parsed_query.set("limit", limit_expr)
320331
updated_parsed_queries.append(
321332
parsed_query.sql(dialect=self.database_engine.value.lower())

uv.lock

+12-12
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)