
Dev branch for the ToolUseAgent #239


Draft · wants to merge 6 commits into base: main
4 changes: 3 additions & 1 deletion src/agentlab/agents/agent_args.py
@@ -1,6 +1,8 @@
import bgym
from bgym import AbstractAgentArgs

from agentlab.experiments.benchmark import Benchmark


class AgentArgs(AbstractAgentArgs):
"""Base class for agent arguments for instantiating an agent.
@@ -14,7 +16,7 @@ class MyAgentArgs(AgentArgs):
    Note: to work properly with AgentXRay, the arguments need to be serializable and hashable.
"""

def set_benchmark(self, benchmark: bgym.Benchmark, demo_mode: bool):
def set_benchmark(self, benchmark: Benchmark, demo_mode: bool):
"""Optional method to set benchmark specific flags.

This allows the agent to have minor adjustments based on the benchmark.
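For orientation, a minimal sketch of the new override pattern (the MyAgentArgs class and its use_html flag are hypothetical, not part of this PR):

from dataclasses import dataclass

from agentlab.agents.agent_args import AgentArgs
from agentlab.experiments.benchmark import Benchmark  # new home of Benchmark in this PR


@dataclass
class MyAgentArgs(AgentArgs):
    use_html: bool = False

    def set_benchmark(self, benchmark: Benchmark, demo_mode: bool):
        # The benchmark type now comes from agentlab.experiments.benchmark
        # rather than bgym, decoupling agent args from bgym's API.
        if benchmark.name.startswith("miniwob"):
            self.use_html = True
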
10 changes: 3 additions & 7 deletions src/agentlab/agents/dynamic_prompting.py
@@ -10,13 +10,9 @@

import bgym
from browsergym.core.action.base import AbstractActionSet
from browsergym.utils.obs import (
flatten_axtree_to_str,
flatten_dom_to_str,
overlay_som,
prune_html,
)
from browsergym.utils.obs import flatten_axtree_to_str, flatten_dom_to_str, overlay_som, prune_html

from agentlab.experiments.benchmark import HighLevelActionSetArgs
from agentlab.llm.llm_utils import (
BaseMessage,
ParseError,
@@ -99,7 +95,7 @@ class ObsFlags(Flags):

@dataclass
class ActionFlags(Flags):
action_set: bgym.HighLevelActionSetArgs = None # should be set by the set_benchmark method
action_set: HighLevelActionSetArgs = None # should be set by the set_benchmark method
long_description: bool = True
individual_examples: bool = False

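The construction pattern after this change, as a short sketch (it mirrors the agent configs below; the flag values are illustrative):

from agentlab.agents import dynamic_prompting as dp
from agentlab.experiments.benchmark import HighLevelActionSetArgs

# action_set is now an agentlab type instead of bgym.HighLevelActionSetArgs.
action_flags = dp.ActionFlags(
    action_set=HighLevelActionSetArgs(subsets=["bid"], multiaction=False),
    long_description=True,
    individual_examples=False,
)
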
13 changes: 7 additions & 6 deletions src/agentlab/agents/generic_agent/agent_configs.py
@@ -6,6 +6,7 @@

from agentlab.agents import dynamic_prompting as dp
from agentlab.experiments import args
from agentlab.experiments.benchmark import HighLevelActionSetArgs
from agentlab.llm.llm_configs import CHAT_MODEL_ARGS_DICT

from .generic_agent import GenericAgentArgs
@@ -31,7 +32,7 @@
filter_visible_elements_only=False,
),
action=dp.ActionFlags(
action_set=bgym.HighLevelActionSetArgs(
action_set=HighLevelActionSetArgs(
subsets=["bid"],
multiaction=False,
),
@@ -79,7 +80,7 @@
filter_visible_elements_only=False,
),
action=dp.ActionFlags(
action_set=bgym.HighLevelActionSetArgs(
action_set=HighLevelActionSetArgs(
subsets=["bid"],
multiaction=False,
),
@@ -126,7 +127,7 @@
filter_visible_elements_only=False,
),
action=dp.ActionFlags(
action_set=bgym.HighLevelActionSetArgs(
action_set=HighLevelActionSetArgs(
subsets=["bid"],
multiaction=False,
),
@@ -176,7 +177,7 @@
filter_visible_elements_only=False,
),
action=dp.ActionFlags(
action_set=bgym.HighLevelActionSetArgs(
action_set=HighLevelActionSetArgs(
subsets=["bid"],
multiaction=True,
),
@@ -231,7 +232,7 @@
filter_visible_elements_only=False,
),
action=dp.ActionFlags(
action_set=bgym.HighLevelActionSetArgs(
action_set=HighLevelActionSetArgs(
subsets=["bid"],
multiaction=False,
),
@@ -319,7 +320,7 @@
filter_visible_elements_only=args.Choice([True, False], p=[0.3, 0.7]),
),
action=dp.ActionFlags(
action_set=bgym.HighLevelActionSetArgs(
action_set=HighLevelActionSetArgs(
subsets=args.Choice([["bid"], ["bid", "coord"]]),
multiaction=args.Choice([True, False], p=[0.7, 0.3]),
),
5 changes: 3 additions & 2 deletions src/agentlab/agents/generic_agent/generic_agent.py
@@ -10,19 +10,20 @@

from copy import deepcopy
from dataclasses import asdict, dataclass
from functools import partial
from warnings import warn

import bgym
from browsergym.experiments.agent import Agent, AgentInfo

from agentlab.agents import dynamic_prompting as dp
from agentlab.agents.agent_args import AgentArgs
from agentlab.experiments.benchmark import Benchmark
from agentlab.llm.chat_api import BaseModelArgs
from agentlab.llm.llm_utils import Discussion, ParseError, SystemMessage, retry
from agentlab.llm.tracking import cost_tracker_decorator

from .generic_agent_prompt import GenericPromptFlags, MainPrompt
from functools import partial


@dataclass
@@ -37,7 +38,7 @@ def __post_init__(self):
except AttributeError:
pass

def set_benchmark(self, benchmark: bgym.Benchmark, demo_mode):
def set_benchmark(self, benchmark: Benchmark, demo_mode):
"""Override Some flags based on the benchmark."""
if benchmark.name.startswith("miniwob"):
self.flags.obs.use_html = True
3 changes: 2 additions & 1 deletion src/agentlab/agents/generic_agent/reproducibility_agent.py
@@ -23,6 +23,7 @@
from bs4 import BeautifulSoup

from agentlab.agents.agent_args import AgentArgs
from agentlab.experiments.benchmark import HighLevelActionSetArgs
from agentlab.experiments.loop import ExpArgs, ExpResult, yield_all_exp_results
from agentlab.experiments.study import Study
from agentlab.llm.chat_api import make_assistant_message
@@ -144,7 +145,7 @@ def _make_backward_compatible(agent_args: GenericAgentArgs):
if isinstance(action_set, str):
action_set = action_set.split("+")

agent_args.flags.action.action_set = bgym.HighLevelActionSetArgs(
agent_args.flags.action.action_set = HighLevelActionSetArgs(
subsets=action_set,
multiaction=agent_args.flags.action.multi_actions,
)
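The shim above handles legacy configs that serialized the action set as a string; a hedged sketch of the conversion (the legacy value is illustrative):

# Old serialized configs stored e.g. "bid+coord"; split into subsets
# and rebuild the new dataclass form.
legacy_action_set = "bid+coord"
subsets = legacy_action_set.split("+")  # -> ["bid", "coord"]
new_action_set = HighLevelActionSetArgs(subsets=subsets, multiaction=False)
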
Empty file.
184 changes: 184 additions & 0 deletions src/agentlab/agents/tool_use_agent/agent.py
@@ -0,0 +1,184 @@
import json
import logging
from copy import deepcopy as copy
from dataclasses import asdict, dataclass
from typing import TYPE_CHECKING, Any

import bgym
from browsergym.core.observation import extract_screenshot

from agentlab.agents.agent_args import AgentArgs
from agentlab.llm.llm_utils import image_to_png_base64_url
from agentlab.llm.response_api import OpenAIResponseModelArgs
from agentlab.llm.tracking import cost_tracker_decorator

if TYPE_CHECKING:
from openai.types.responses import Response


@dataclass
class ToolUseAgentArgs(AgentArgs):
temperature: float = 0.1
model_args: OpenAIResponseModelArgs = None

def __post_init__(self):
try:
self.agent_name = f"ToolUse-{self.model_args.model_name}".replace("/", "_")
except AttributeError:
pass

def make_agent(self) -> bgym.Agent:
return ToolUseAgent(
temperature=self.temperature,
model_args=self.model_args,
)

def set_reproducibility_mode(self):
self.temperature = 0

def prepare(self):
return self.model_args.prepare_server()

def close(self):
return self.model_args.close_server()


class ToolUseAgent(bgym.Agent):
def __init__(
self,
temperature: float,
model_args: OpenAIResponseModelArgs,
):
self.temperature = temperature
self.chat = model_args.make_model()
self.model_args = model_args

self.action_set = bgym.HighLevelActionSet(["coord"], multiaction=False)

self.tools = self.action_set.to_tool_description()

# self.tools.append(
# {
# "type": "function",
# "name": "chain_of_thought",
        #         "description": "A tool that allows the agent to think step by step. Every other action must ALWAYS be preceded by a call to this tool.",
# "parameters": {
# "type": "object",
# "properties": {
# "thoughts": {
# "type": "string",
# "description": "The agent's reasoning process.",
# },
# },
# "required": ["thoughts"],
# },
# }
# )

self.llm = model_args.make_model(extra_kwargs={"tools": self.tools})

        self.messages = []
        self.previous_call_id = None  # set once the model makes its first function call

def obs_preprocessor(self, obs):
page = obs.pop("page", None)
if page is not None:
obs["screenshot"] = extract_screenshot(page)
else:
raise ValueError("No page found in the observation.")

return obs

@cost_tracker_decorator
    def get_action(self, obs: Any) -> tuple[str, bgym.AgentInfo]:

if len(self.messages) == 0:
system_message = {
"role": "system",
"content": "You are an agent. Based on the observation, you will decide which action to take to accomplish your goal.",
}
            goal_object = copy(obs["goal_object"])  # deepcopy: the items are mutated below
for content in goal_object:
if content["type"] == "text":
content["type"] = "input_text"
elif content["type"] == "image_url":
content["type"] = "input_image"
goal_message = {"role": "user", "content": goal_object}
goal_message["content"].append(
{
"type": "input_image",
"image_url": image_to_png_base64_url(obs["screenshot"]),
}
)
self.messages.append(system_message)
self.messages.append(goal_message)
else:
if obs["last_action_error"] == "":
self.messages.append(
{
"type": "function_call_output",
"call_id": self.previous_call_id,
"output": "Function call executed, see next observation.",
}
)
self.messages.append(
{
"role": "user",
"content": [
{
"type": "input_image",
"image_url": image_to_png_base64_url(obs["screenshot"]),
}
],
}
)
else:
self.messages.append(
{
"type": "function_call_output",
"call_id": self.previous_call_id,
"output": f"Function call failed: {obs['last_action_error']}",
}
)

response: "Response" = self.llm(
messages=self.messages,
temperature=self.temperature,
)

action = "noop()"
think = ""
for output in response.output:
if output.type == "function_call":
arguments = json.loads(output.arguments)
action = f"{output.name}({", ".join([f"{k}={v}" for k, v in arguments.items()])})"
self.previous_call_id = output.call_id
self.messages.append(output)
break
elif output.type == "reasoning":
if len(output.summary) > 0:
think += output.summary[0].text + "\n"
self.messages.append(output)

return (
action,
bgym.AgentInfo(
think=think,
chat_messages=[],
stats={},
),
)


MODEL_CONFIG = OpenAIResponseModelArgs(
model_name="o4-mini-2025-04-16",
max_total_tokens=200_000,
max_input_tokens=200_000,
max_new_tokens=100_000,
temperature=0.1,
vision_support=True,
)

AGENT_CONFIG = ToolUseAgentArgs(
temperature=0.1,
model_args=MODEL_CONFIG,
)
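Not part of the diff, but for reviewers: a hedged smoke-test sketch of the agent's surface as defined above (raw_obs is a placeholder for an observation dict carrying a Playwright "page"; a real run would go through the bgym/agentlab experiment loop):

agent = AGENT_CONFIG.make_agent()
obs = agent.obs_preprocessor(raw_obs)       # replaces "page" with a screenshot
action, agent_info = agent.get_action(obs)
print(action)            # e.g. mouse_click(x=102, y=244)
print(agent_info.think)  # reasoning summary, if the model emitted one
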
8 changes: 5 additions & 3 deletions src/agentlab/agents/visual_agent/agent_configs.py
@@ -1,9 +1,11 @@
import bgym

import agentlab.agents.dynamic_prompting as dp
from agentlab.experiments.benchmark import HighLevelActionSetArgs
from agentlab.llm.llm_configs import CHAT_MODEL_ARGS_DICT

from .visual_agent import VisualAgentArgs
from .visual_agent_prompts import PromptFlags
import agentlab.agents.dynamic_prompting as dp
import bgym

# the other flags are ignored for this agent.
DEFAULT_OBS_FLAGS = dp.ObsFlags(
@@ -16,7 +18,7 @@
)

DEFAULT_ACTION_FLAGS = dp.ActionFlags(
action_set=bgym.HighLevelActionSetArgs(subsets=["coord"]),
action_set=HighLevelActionSetArgs(subsets=["coord"]),
long_description=True,
individual_examples=False,
)
5 changes: 3 additions & 2 deletions src/agentlab/agents/visual_agent/visual_agent.py
@@ -15,11 +15,12 @@

from agentlab.agents import dynamic_prompting as dp
from agentlab.agents.agent_args import AgentArgs
from agentlab.experiments.benchmark import Benchmark
from agentlab.llm.chat_api import BaseModelArgs
from agentlab.llm.llm_utils import Discussion, ParseError, SystemMessage, retry
from agentlab.llm.tracking import cost_tracker_decorator

from .visual_agent_prompts import PromptFlags, MainPrompt
from .visual_agent_prompts import MainPrompt, PromptFlags


@dataclass
@@ -34,7 +35,7 @@ def __post_init__(self):
except AttributeError:
pass

def set_benchmark(self, benchmark: bgym.Benchmark, demo_mode):
def set_benchmark(self, benchmark: Benchmark, demo_mode):
"""Override Some flags based on the benchmark."""
self.flags.obs.use_tabs = benchmark.is_multi_tab

2 changes: 2 additions & 0 deletions src/agentlab/experiments/benchmark/__init__.py
@@ -0,0 +1,2 @@
from .base import Benchmark, HighLevelActionSetArgs
from .configs import DEFAULT_BENCHMARKS
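Net effect for downstream code: a single import site for the benchmark types, e.g.:

# Canonical imports after this PR:
from agentlab.experiments.benchmark import (
    DEFAULT_BENCHMARKS,
    Benchmark,
    HighLevelActionSetArgs,
)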