@@ -130,11 +130,43 @@ class AgentControllerData(BaseModel):
130
130
] = None
131
131
132
132
133
+ def save_messages_to_session_data (session_id , id , messages : List [AgentMessage ]):
134
+ from llmstack .apps .app_session_utils import save_app_session_data
135
+
136
+ logger .info (f"Saving messages to session data: { messages } " )
137
+
138
+ save_app_session_data (session_id , id , [m .model_dump_json () for m in messages ])
139
+
140
+
141
+ def load_messages_from_session_data (session_id , id ):
142
+ from llmstack .apps .app_session_utils import get_app_session_data
143
+
144
+ messages = []
145
+
146
+ session_data = get_app_session_data (session_id , id )
147
+ if session_data and isinstance (session_data , list ):
148
+ for data in session_data :
149
+ data_json = json .loads (data )
150
+ if data_json ["role" ] == "system" :
151
+ messages .append (AgentSystemMessage (** data_json ))
152
+ elif data_json ["role" ] == "assistant" :
153
+ messages .append (AgentAssistantMessage (** data_json ))
154
+ elif data_json ["role" ] == "user" :
155
+ messages .append (AgentUserMessage (** data_json ))
156
+
157
+ return messages
158
+
159
+
133
160
class AgentController :
134
161
def __init__ (self , output_queue : asyncio .Queue , config : AgentControllerConfig ):
162
+ self ._session_id = config .metadata .get ("session_id" )
163
+ self ._controller_id = f"{ config .metadata .get ('app_uuid' )} _agent"
164
+ self ._system_message = render_template (config .agent_config .system_message , {})
135
165
self ._output_queue = output_queue
136
166
self ._config = config
137
- self ._messages : List [AgentMessage ] = []
167
+ self ._messages : List [AgentMessage ] = (
168
+ load_messages_from_session_data (self ._session_id , self ._controller_id ) or []
169
+ )
138
170
self ._llm_client = None
139
171
self ._websocket = None
140
172
self ._provider_config = None
@@ -254,18 +286,6 @@ def _init_llm_client(self):
254
286
),
255
287
)
256
288
257
- self ._messages .append (
258
- AgentSystemMessage (
259
- role = AgentMessageRole .SYSTEM ,
260
- content = [
261
- AgentMessageContent (
262
- type = AgentMessageContentType .TEXT ,
263
- data = render_template (self ._config .agent_config .system_message , {}),
264
- )
265
- ],
266
- )
267
- )
268
-
269
289
async def _process_input_audio_stream (self ):
270
290
if self ._input_audio_stream :
271
291
async for chunk in self ._input_audio_stream .read_async ():
@@ -387,6 +407,10 @@ def process(self, data: AgentControllerData):
387
407
# Actor calls this to add a message to the conversation and trigger processing
388
408
self ._messages .append (data .data )
389
409
410
+ # This is a message from the assistant to the user, simply add it to the message to maintain state
411
+ if data .type == AgentControllerDataType .AGENT_OUTPUT_END or data .type == AgentControllerDataType .TOOL_CALLS_END :
412
+ return
413
+
390
414
try :
391
415
if len (self ._messages ) > self ._config .agent_config .max_steps :
392
416
raise Exception (f"Max steps ({ self ._config .agent_config .max_steps } ) exceeded: { len (self ._messages )} " )
@@ -465,7 +489,7 @@ async def process_messages(self, data: AgentControllerData):
465
489
stream = True if self ._config .agent_config .stream is None else self ._config .agent_config .stream
466
490
response = self ._llm_client .chat .completions .create (
467
491
model = self ._config .agent_config .model ,
468
- messages = client_messages ,
492
+ messages = [{ "role" : "system" , "content" : self . _system_message }] + client_messages ,
469
493
stream = stream ,
470
494
tools = self ._config .tools ,
471
495
)
@@ -703,6 +727,9 @@ async def add_ws_event_to_output_queue(self, event: Any):
703
727
logger .error (f"WebSocket error: { event } " )
704
728
705
729
def terminate (self ):
730
+ # Save to session data
731
+ save_messages_to_session_data (self ._session_id , self ._controller_id , self ._messages )
732
+
706
733
# Create task for graceful websocket closure
707
734
if hasattr (self , "_websocket" ) and self ._websocket :
708
735
asyncio .run_coroutine_threadsafe (self ._websocket .close (), self ._loop )
0 commit comments