
Commit a137861

Fix session support for agents (#294)
1 parent 2cb4a88 commit a137861

File tree: 4 files changed (+71, -21 lines)

llmstack/apps/apis.py (+1)

@@ -776,6 +776,7 @@ def _run_internal(self, request, app_uuid, input_data, source, app_data, session
         response = app_runner.run_until_complete(
             AppRunnerRequest(client_request_id=str(uuid.uuid4()), session_id=session_id, input=input_data), loop
         )
+        async_to_sync(app_runner.stop)()
         return response

     def run(self, request, uid, session_id=None):
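Note on the change above: the view now stops the runner after the synchronous run finishes, presumably so the agent controller's terminate() (see agent_controller.py below) gets a chance to persist the conversation for the session. A minimal sketch of the run-then-stop pattern; AppRunner below is a stand-in class, only asgiref's async_to_sync is the real helper:

    from asgiref.sync import async_to_sync

    class AppRunner:  # stand-in, not the LLMStack class
        async def stop(self):
            # in LLMStack, stopping the runner is where per-session state
            # (per this commit, the agent's message history) gets saved
            print("runner stopped")

    runner = AppRunner()
    async_to_sync(runner.stop)()  # drive the coroutine to completion from sync view code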

llmstack/apps/runner/agent_actor.py (+26, -5)

@@ -5,6 +5,7 @@
 from typing import Any, Dict, List

 from llmstack.apps.runner.agent_controller import (
+    AgentAssistantMessage,
     AgentController,
     AgentControllerConfig,
     AgentControllerData,
@@ -126,15 +127,29 @@ async def _process_output(self):
                 )

             elif controller_output.type == AgentControllerDataType.AGENT_OUTPUT_END:
+                agent_final_output = (
+                    self._stitched_data["agent"][str(message_index)].data.content[0].data
+                    if str(message_index) in self._stitched_data["agent"]
+                    else ""
+                )
+                self._agent_controller.process(
+                    AgentControllerData(
+                        type=AgentControllerDataType.AGENT_OUTPUT_END,
+                        data=AgentAssistantMessage(
+                            content=[
+                                AgentMessageContent(
+                                    type=AgentMessageContentType.TEXT,
+                                    data=agent_final_output,
+                                )
+                            ]
+                        ),
+                    )
+                )
                 self._content_queue.put_nowait(
                     {
                         "output": {
                             **self._agent_outputs,
-                            "output": (
-                                self._stitched_data["agent"][str(message_index)].data.content[0].data
-                                if str(message_index) in self._stitched_data["agent"]
-                                else ""
-                            ),
+                            "output": agent_final_output,
                         },
                         "chunks": self._stitched_data,
                     }
@@ -174,6 +189,12 @@ async def _process_output(self):

             elif controller_output.type == AgentControllerDataType.TOOL_CALLS_END:
                 tool_calls = self._stitched_data["agent"][str(message_index)].data.tool_calls
+                self._agent_controller.process(
+                    AgentControllerData(
+                        type=AgentControllerDataType.TOOL_CALLS_END,
+                        data=AgentToolCallsMessage(tool_calls=tool_calls),
+                    ),
+                )

                 for tool_call in tool_calls:
                     tool_call_args = tool_call.arguments
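Note on the change above: when an agent turn or a round of tool calls finishes, the actor now reports the stitched result back to the controller via process(), and the controller (see the process() hunk in agent_controller.py below) records it without starting another completion. A minimal sketch of that record-only message handling, with stand-in names in place of AgentControllerDataType and AgentController:

    from enum import Enum
    from typing import Any, List, Tuple

    class DataType(Enum):  # stand-in for AgentControllerDataType
        INPUT = "input"
        AGENT_OUTPUT_END = "agent_output_end"
        TOOL_CALLS_END = "tool_calls_end"

    class Controller:  # stand-in for AgentController
        def __init__(self) -> None:
            self.messages: List[Tuple[DataType, Any]] = []

        def process(self, type_: DataType, payload: Any) -> None:
            self.messages.append((type_, payload))  # always kept for conversation history
            if type_ in (DataType.AGENT_OUTPUT_END, DataType.TOOL_CALLS_END):
                return  # bookkeeping only: no new LLM turn is triggered
            # ...otherwise the next completion would be scheduled here...

    c = Controller()
    c.process(DataType.AGENT_OUTPUT_END, "final answer")  # recorded, nothing scheduled
    assert len(c.messages) == 1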

llmstack/apps/runner/agent_controller.py (+41, -14)

@@ -130,11 +130,43 @@ class AgentControllerData(BaseModel):
     ] = None


+def save_messages_to_session_data(session_id, id, messages: List[AgentMessage]):
+    from llmstack.apps.app_session_utils import save_app_session_data
+
+    logger.info(f"Saving messages to session data: {messages}")
+
+    save_app_session_data(session_id, id, [m.model_dump_json() for m in messages])
+
+
+def load_messages_from_session_data(session_id, id):
+    from llmstack.apps.app_session_utils import get_app_session_data
+
+    messages = []
+
+    session_data = get_app_session_data(session_id, id)
+    if session_data and isinstance(session_data, list):
+        for data in session_data:
+            data_json = json.loads(data)
+            if data_json["role"] == "system":
+                messages.append(AgentSystemMessage(**data_json))
+            elif data_json["role"] == "assistant":
+                messages.append(AgentAssistantMessage(**data_json))
+            elif data_json["role"] == "user":
+                messages.append(AgentUserMessage(**data_json))
+
+    return messages
+
+
 class AgentController:
     def __init__(self, output_queue: asyncio.Queue, config: AgentControllerConfig):
+        self._session_id = config.metadata.get("session_id")
+        self._controller_id = f"{config.metadata.get('app_uuid')}_agent"
+        self._system_message = render_template(config.agent_config.system_message, {})
         self._output_queue = output_queue
         self._config = config
-        self._messages: List[AgentMessage] = []
+        self._messages: List[AgentMessage] = (
+            load_messages_from_session_data(self._session_id, self._controller_id) or []
+        )
         self._llm_client = None
         self._websocket = None
         self._provider_config = None
@@ -254,18 +286,6 @@ def _init_llm_client(self):
             ),
         )

-        self._messages.append(
-            AgentSystemMessage(
-                role=AgentMessageRole.SYSTEM,
-                content=[
-                    AgentMessageContent(
-                        type=AgentMessageContentType.TEXT,
-                        data=render_template(self._config.agent_config.system_message, {}),
-                    )
-                ],
-            )
-        )
-
     async def _process_input_audio_stream(self):
         if self._input_audio_stream:
             async for chunk in self._input_audio_stream.read_async():
@@ -387,6 +407,10 @@ def process(self, data: AgentControllerData):
         # Actor calls this to add a message to the conversation and trigger processing
         self._messages.append(data.data)

+        # This is a message from the assistant to the user, simply add it to the message to maintain state
+        if data.type == AgentControllerDataType.AGENT_OUTPUT_END or data.type == AgentControllerDataType.TOOL_CALLS_END:
+            return
+
         try:
             if len(self._messages) > self._config.agent_config.max_steps:
                 raise Exception(f"Max steps ({self._config.agent_config.max_steps}) exceeded: {len(self._messages)}")
@@ -465,7 +489,7 @@ async def process_messages(self, data: AgentControllerData):
         stream = True if self._config.agent_config.stream is None else self._config.agent_config.stream
         response = self._llm_client.chat.completions.create(
             model=self._config.agent_config.model,
-            messages=client_messages,
+            messages=[{"role": "system", "content": self._system_message}] + client_messages,
             stream=stream,
             tools=self._config.tools,
         )
@@ -703,6 +727,9 @@ async def add_ws_event_to_output_queue(self, event: Any):
             logger.error(f"WebSocket error: {event}")

     def terminate(self):
+        # Save to session data
+        save_messages_to_session_data(self._session_id, self._controller_id, self._messages)
+
         # Create task for graceful websocket closure
         if hasattr(self, "_websocket") and self._websocket:
             asyncio.run_coroutine_threadsafe(self._websocket.close(), self._loop)
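Note on the changes above: this is the heart of the fix. On terminate() the controller serializes its message list with model_dump_json() under a key built from the session id and an "<app_uuid>_agent" controller id; on __init__ it rebuilds the list by dispatching on each message's role. The system message is no longer stored in _messages at all; it is rendered once and prepended to every completions call, so it is not duplicated into the persisted history. A self-contained sketch of the save/load round trip, using an in-memory dict and one stand-in message model in place of app_session_utils and the per-role message classes:

    import json
    from typing import Dict, List, Tuple
    from pydantic import BaseModel

    class Message(BaseModel):  # stand-in for the AgentSystemMessage/AgentAssistantMessage/AgentUserMessage family
        role: str
        content: str

    _store: Dict[Tuple[str, str], List[str]] = {}  # stands in for save/get_app_session_data

    def save_messages(session_id: str, controller_id: str, messages: List[Message]) -> None:
        # what terminate() does: one JSON string per message
        _store[(session_id, controller_id)] = [m.model_dump_json() for m in messages]

    def load_messages(session_id: str, controller_id: str) -> List[Message]:
        # what __init__ does; the real loader picks the message class from data_json["role"]
        return [Message(**json.loads(raw)) for raw in _store.get((session_id, controller_id), [])]

    save_messages("sess-1", "app-1_agent", [Message(role="user", content="hi")])
    assert load_messages("sess-1", "app-1_agent")[0].content == "hi"  # history survives across runs of the same session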

llmstack/apps/runner/app_runner.py (+3, -2)

@@ -539,9 +539,10 @@ async def run(self, request: AppRunnerRequest):
         )

     def run_until_complete(self, request: AppRunnerRequest, event_loop):
+        final_response = None
         for response in iter_over_async(self.run(request), event_loop):
             if isinstance(response.data, AppRunnerResponseErrorsData) or isinstance(
                 response.data, AppRunnerResponseOutputData
             ):
-                break
-        return response
+                final_response = response
+        return final_response
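Note on the change above: run_until_complete() used to break out of the stream at the first output or error response; it now drains the stream to the end and returns the last such response, so the underlying async generator finishes before the caller goes on to stop the runner. A small sketch of the drain-and-keep-last pattern, with a plain async generator standing in for iter_over_async(self.run(request), event_loop):

    import asyncio
    from typing import AsyncIterator, Optional

    async def responses() -> AsyncIterator[dict]:  # stand-in for the app runner's response stream
        for item in ({"kind": "chunk"}, {"kind": "chunk"}, {"kind": "output", "data": "done"}):
            yield item

    async def run_until_complete() -> Optional[dict]:
        final = None
        async for response in responses():
            if response["kind"] in ("output", "errors"):
                final = response  # remember it, but keep consuming so the generator can finish
        return final

    print(asyncio.run(run_until_complete()))  # {'kind': 'output', 'data': 'done'}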
