# Source code for tooluniverse.smolagent_tool
from __future__ import annotations
import threading
from typing import Any, Callable, Dict, List, Optional
from .base_tool import BaseTool
from .tool_registry import register_tool
# Global lock held while an agent runs. sys.stdout is process-wide, so the
# temporary replacement of sys.stdout used for streaming must not overlap
# between threads; SmolAgentTool.run acquires this lock around every agent
# execution (streaming or not).
_STREAM_LOCK = threading.Lock()
def _safe_import(module_path: str, symbol: str):
    """Import ``symbol`` from ``module_path``, raising a helpful error if missing.

    Args:
        module_path: Dotted path of the module to import.
        symbol: Name of the attribute to fetch from the imported module.

    Returns:
        The attribute named ``symbol`` of the imported module.

    Raises:
        ImportError: If the module cannot be imported, does not define
            ``symbol``, or raises while being imported. The original exception
            is attached as ``__cause__``.
    """
    try:
        module = __import__(module_path, fromlist=[symbol])
        return getattr(module, symbol)
    except Exception as e:  # noqa: BLE001
        # Every failure mode is reported uniformly as ImportError; explicit
        # chaining keeps the root cause visible in the traceback.
        raise ImportError(
            f"Failed to import '{symbol}' from '{module_path}'. Please install and configure 'smolagents'. Original error: {e}"
        ) from e
[docs]
class ToolUniverseTool:  # Lazy base; will subclass smolagents.Tool at runtime
    """Adapter that wraps a ToolUniverse tool and exposes it as a smolagents Tool.

    Instantiating this class never yields a ``ToolUniverseTool`` object:
    ``__new__`` builds a subclass of ``smolagents.Tool`` on the fly and returns
    an instance of that subclass. ``smolagents`` is therefore imported only
    when a tool is actually created, not when this module is imported.
    """

    # JSON-schema return type -> smolagents ``output_type``. Structured
    # results ("object"/"array") are advertised as strings.
    _OUTPUT_TYPE_MAP = {
        "object": "string",
        "array": "string",
        "string": "string",
        "integer": "integer",
        "number": "number",
        "boolean": "boolean",
    }

    def __new__(
        cls,
        tool_name,
        tooluniverse_instance,
        tool_config=None,
        *_args,
        **_kwargs,
    ):  # pragma: no cover - construct dynamic subclass
        """Build and return an instance of a dynamic ``smolagents.Tool`` subclass.

        Args:
            tool_name: Name of the ToolUniverse tool to wrap; used to look up
                its configuration in ``tooluniverse_instance.all_tool_dict``.
            tooluniverse_instance: Object providing ``all_tool_dict`` and
                ``run_one_function``.
            tool_config: Optional dict stored on the generated class as
                ``tool_config`` (an empty dict when omitted).
            *_args, **_kwargs: Accepted and ignored, so existing callers that
                passed extra arguments keep working.

        Returns:
            An instance of ``ToolUniverseTool_<tool_name>``, a subclass of
            ``smolagents.Tool``.

        Raises:
            ImportError: If ``smolagents`` cannot be imported.
            ValueError: If a parameter name in the tool schema cannot be used
                as a Python parameter name.
        """
        # Import here to avoid import-time dependency when not used
        Tool = _safe_import("smolagents", "Tool")

        # Missing instance/attribute/entry all degrade to an empty config.
        all_tools = getattr(tooluniverse_instance, "all_tool_dict", None) or {}
        tu_config = all_tools.get(tool_name) or {}
        parameter_schema = tu_config.get("parameter") or {}

        inputs_schema = cls._convert_parameter_schema(parameter_schema)
        output_type = cls._infer_output_type(tu_config.get("return_schema") or {})

        # NOTE: deliberately no double-underscore names here; inside a class
        # body those are name-mangled, which makes generated code fragile.
        def _call_tool(
            self, arguments, _tool_name=tool_name, _tu=tooluniverse_instance
        ):
            """Run the wrapped tool; errors are returned as a string, not raised."""
            try:
                result = _tu.run_one_function(
                    {"name": _tool_name, "arguments": arguments}
                )
                # NOTE(review): only dict results are serialised. List results
                # are returned unchanged although "array" schemas are
                # advertised as "string" - confirm this is intended.
                if isinstance(result, dict):
                    import json

                    return json.dumps(result, ensure_ascii=False)
                return result
            except Exception as e:  # noqa: BLE001
                return f"Error executing tool {_tool_name}: {e}"

        required = set(parameter_schema.get("required", []) or [])
        forward = cls._build_forward(list(inputs_schema), required, _call_tool)

        attrs = {
            "name": tool_name,
            "description": tu_config.get("description", ""),
            "inputs": inputs_schema,
            "output_type": output_type,
            "forward": forward,
            "tool_config": tool_config or {},
        }
        DynamicToolCls = type(f"ToolUniverseTool_{tool_name}", (Tool,), attrs)  # type: ignore[misc]
        return DynamicToolCls()

    @staticmethod
    def _convert_parameter_schema(
        parameter_schema: Dict[str, Any],
    ) -> Dict[str, Dict[str, Any]]:
        """Translate a JSON-schema style parameter block into smolagents ``inputs``.

        Each property becomes ``{"type", "description"}`` (defaulting to
        ``"string"`` and ``""``); properties not listed under ``required``
        additionally get ``"nullable": True``.
        """
        properties = parameter_schema.get("properties") or {}
        required = set(parameter_schema.get("required", []) or [])
        inputs: Dict[str, Dict[str, Any]] = {}
        for param_name, info in properties.items():
            entry: Dict[str, Any] = {
                "type": info.get("type", "string"),
                "description": info.get("description", ""),
            }
            # All optional parameters (not in required list) should be nullable
            if param_name not in required:
                entry["nullable"] = True
            inputs[param_name] = entry
        return inputs

    @staticmethod
    def _infer_output_type(return_schema: Dict[str, Any]) -> str:
        """Map the schema's ``type`` to an ``output_type``; default ``"string"``."""
        schema_type = return_schema.get("type", "string")
        # A non-string type (e.g. a list of alternatives) is unhashable and
        # has no mapping; fall back to the default instead of raising.
        if not isinstance(schema_type, str):
            return "string"
        return ToolUniverseTool._OUTPUT_TYPE_MAP.get(schema_type, "string")

    @staticmethod
    def _build_forward(param_names, required, call_tool):
        """Create a ``forward`` method whose signature lists every tool parameter.

        The generated function has the signature ``(self, req1, ..., opt1=None,
        ...)`` - required parameters first because Python requires parameters
        without defaults to precede those with defaults - and calls
        ``call_tool(self, {name: value, ...})`` with the entries in
        ``param_names`` order.

        Args:
            param_names: Parameter names, in schema order.
            required: Collection of names that must not get a default.
            call_tool: Callable invoked as ``call_tool(self, arguments)``.

        Returns:
            The generated function.

        Raises:
            ValueError: If a name is not a valid identifier, is a Python
                keyword, or is ``"self"``. Names are checked before any source
                is generated because they are interpolated into code that is
                passed to ``exec``.
        """
        import keyword

        if not param_names:
            # No inputs -> 0-arg forward
            def _forward(self):  # type: ignore[override]
                return call_tool(self, {})

            return _forward

        for name in param_names:
            is_valid = (
                isinstance(name, str)
                and name.isidentifier()
                and not keyword.iskeyword(name)
                and name != "self"
            )
            if not is_valid:
                raise ValueError(
                    f"Tool parameter name {name!r} cannot be used as a Python "
                    f"parameter name"
                )

        # Pick a global name for the callback that no parameter can shadow.
        helper_name = "_call_tool"
        while helper_name in param_names:
            helper_name = "_" + helper_name

        required_params = [p for p in param_names if p in required]
        optional_params = [f"{p}=None" for p in param_names if p not in required]
        signature = ", ".join(required_params + optional_params)
        items = ", ".join(f"{p!r}: {p}" for p in param_names)
        source = (
            f"def _forward(self, {signature}):\n"
            f"    return {helper_name}(self, {{{items}}})\n"
        )
        namespace: Dict[str, Any] = {helper_name: call_tool}
        exec(source, namespace)  # names validated above
        return namespace["_forward"]

    @classmethod
    def from_tooluniverse(
        cls,
        tool_name: str,
        tooluniverse_instance,
        tool_config: Optional[Dict[str, Any]] = None,
    ):
        """Factory to create a smolagents-compatible Tool from a ToolUniverse tool.

        This mirrors common factory patterns (e.g., from_langchain) and returns
        an instance of the dynamically constructed Tool subclass.
        """
        return cls(tool_name, tooluniverse_instance, tool_config or {})
[docs]
@register_tool("SmolAgentTool")
class SmolAgentTool(BaseTool):
    """Wrap smolagents agents so they can be used as ToolUniverse tools.

    Supports:
    - CodeAgent, ToolCallingAgent, Agent, ManagedAgent
    - Mixed tools: ToolUniverse tools and smolagents-native tools
    - Streaming integration with ToolUniverse stream callbacks

    "ManagedAgent" is emulated: each configured sub-agent is wrapped as a
    smolagents Tool and handed to a top-level CodeAgent.
    """

    # Constructor keyword arguments forwarded to every agent type.
    _COMMON_AGENT_KWARGS = frozenset(
        {
            "tools",
            "model",
            "name",
            "prompt_templates",
            "planning_interval",
            "stream_outputs",
            "max_steps",
        }
    )
    # CodeAgent additionally accepts these.
    _CODE_AGENT_KWARGS = _COMMON_AGENT_KWARGS | frozenset(
        {
            "add_base_tools",
            "additional_authorized_imports",
            "verbosity_level",
            "executor_type",
            "executor_kwargs",
        }
    )

    class _StreamProxy:
        """Write-only stand-in for ``sys.stdout`` that forwards lines to a callback.

        Text is buffered until a newline arrives. Blank lines and a line equal
        to the previously forwarded one are dropped.
        """

        def __init__(self, cb):
            self._cb = cb
            self._buf = ""
            self._last_line = None

        def write(self, s: str):
            """Buffer ``s`` and forward every completed, non-duplicate line."""
            if not s:
                return
            self._buf += s
            while "\n" in self._buf:
                line, self._buf = self._buf.split("\n", 1)
                if not line.strip():
                    continue
                # Deduplicate consecutive identical lines
                if line == self._last_line:
                    continue
                self._last_line = line
                self._cb(line + "\n")

        def flush(self):
            """Forward any buffered partial line, then clear the buffer."""
            if self._buf.strip():
                if self._buf != self._last_line:
                    self._cb(self._buf)
                    self._last_line = self._buf
            self._buf = ""

    def __init__(self, tool_config: Dict[str, Any], tooluniverse=None):
        """Read agent settings from ``tool_config["settings"]``.

        Args:
            tool_config: Tool configuration; its optional ``"settings"`` dict
                may contain ``agent_type``, ``available_tools``, ``model``,
                ``agent_init_params`` and ``sub_agents``.
            tooluniverse: ToolUniverse instance used to resolve ToolUniverse
                tools; may also be assigned to ``self.tooluniverse`` later.
        """
        super().__init__(tool_config)
        settings = tool_config.get("settings", {})
        self.agent_type: str = settings.get("agent_type", "CodeAgent")
        self.available_tools: List[Any] = settings.get("available_tools", [])
        self.model_config: Dict[str, Any] = settings.get("model", {})
        self.agent_init_params: Dict[str, Any] = settings.get("agent_init_params", {})
        self.sub_agents_config: List[Dict[str, Any]] = settings.get("sub_agents", [])
        # Set by ToolUniverse runtime or passed as parameter
        self.tooluniverse = tooluniverse
        # Created lazily by _init_agent().
        self.agent = None

    # -------------------------
    # Initialization helpers
    # -------------------------
    def _get_api_key(self) -> Optional[str]:
        """Return the configured API key.

        A string value of the form ``"env:NAME"`` is resolved to the value of
        environment variable ``NAME`` (``None`` if unset); any other value is
        returned unchanged.
        """
        api_key = self.model_config.get("api_key")
        if isinstance(api_key, str) and api_key.startswith("env:"):
            import os

            return os.environ.get(api_key[4:])
        return api_key

    def _init_model(self):
        """Instantiate the smolagents model selected by ``model_config["provider"]``.

        The provider defaults to ``"HfApiModel"``. Only the class for the
        selected provider is imported.

        Raises:
            ImportError: If the provider class cannot be imported.
            ValueError: If the provider name is not supported.
        """
        provider = self.model_config.get("provider", "HfApiModel")
        model_id = self.model_config.get("model_id")
        api_key = self._get_api_key()
        if provider == "HfApiModel":
            HfApiModel = _safe_import("smolagents", "HfApiModel")
            return HfApiModel(model_id, token=api_key)
        if provider == "OpenAIModel":
            OpenAIModel = _safe_import("smolagents", "OpenAIModel")
            return OpenAIModel(
                model_id=model_id,
                api_key=api_key,
                api_base=self.model_config.get("api_base"),
            )
        if provider == "LiteLLMModel":
            LiteLLMModel = _safe_import("smolagents", "LiteLLMModel")
            return LiteLLMModel(model_id=model_id, api_key=api_key)
        if provider == "InferenceClientModel":
            InferenceClientModel = _safe_import("smolagents", "InferenceClientModel")
            return InferenceClientModel(
                model_id=model_id,
                provider=self.model_config.get("provider_name"),
                token=api_key,
            )
        if provider == "TransformersModel":
            TransformersModel = _safe_import("smolagents", "TransformersModel")
            return TransformersModel(
                model_id=model_id,
            )
        if provider == "AzureOpenAIModel":
            AzureOpenAIModel = _safe_import("smolagents", "AzureOpenAIModel")
            return AzureOpenAIModel(
                model_id=model_id,
                azure_endpoint=self.model_config.get("azure_endpoint"),
                api_key=api_key,
                api_version=self.model_config.get("api_version"),
            )
        if provider == "AmazonBedrockModel":
            AmazonBedrockModel = _safe_import("smolagents", "AmazonBedrockModel")
            return AmazonBedrockModel(model_id=model_id)
        raise ValueError(f"Unsupported model provider: {provider}")

    def _import_smolagents_tool(self, class_name: str, import_path: str):
        """Dynamically import smolagents tool class with helpful error messages.

        Raises:
            ImportError: If ``import_path`` cannot be imported.
            AttributeError: If the module has no attribute ``class_name``; the
                message lists up to ten public names of the module.
        """
        import importlib

        try:
            module = importlib.import_module(import_path)
        except ImportError as e:
            raise ImportError(
                f"Failed to import module '{import_path}' for smolagents tool '{class_name}'. "
                f"Please ensure the module path is correct. "
                f"Common paths include 'smolagents.tools' or 'smolagents.default_tools'. "
                f"Original error: {e}"
            ) from e
        try:
            tool_class = getattr(module, class_name)
        except AttributeError as e:
            available_attrs = [attr for attr in dir(module) if not attr.startswith("_")]
            raise AttributeError(
                f"Class '{class_name}' not found in module '{import_path}'. "
                f"Available classes in the module: {', '.join(available_attrs[:10])}"
                f"{'...' if len(available_attrs) > 10 else ''}. "
                f"Please check the class name spelling and ensure it exists in the module. "
                f"Original error: {e}"
            ) from e
        return tool_class

    def _convert_tools(self) -> List[Any]:
        """Convert mixed tool definitions to smolagents Tool instances.

        Accepted entries of ``available_tools``:
        - a string: name of a ToolUniverse tool;
        - a dict with ``"type": "smolagents"``: ``"class"`` is imported from
          ``"import_path"`` (default ``"smolagents.tools"``) and instantiated
          with ``"kwargs"``;
        - any other dict: ``"name"`` is treated as a ToolUniverse tool name.

        Entries of another type, or dicts without the needed key, are skipped.
        """
        converted: List[Any] = []
        for spec in self.available_tools:
            if isinstance(spec, str):
                converted.append(
                    ToolUniverseTool.from_tooluniverse(spec, self.tooluniverse)
                )
                continue
            if not isinstance(spec, dict):
                continue
            spec_type = spec.get("type", "tooluniverse")
            if spec_type == "smolagents":
                cls_name = spec.get("class")
                if not cls_name:
                    continue
                import_path = spec.get("import_path", "smolagents.tools")
                kwargs = spec.get("kwargs", {})
                tool_cls = self._import_smolagents_tool(cls_name, import_path)
                converted.append(tool_cls(**kwargs))
            else:
                name = spec.get("name")
                if name:
                    converted.append(
                        ToolUniverseTool.from_tooluniverse(name, self.tooluniverse)
                    )
        return converted

    def _create_sub_agents(self, sub_configs: List[Dict[str, Any]]) -> List[Any]:
        """Recursively create sub-agent instances (for ManagedAgent).

        Each config dict is used as the ``settings`` of a nested
        ``SmolAgentTool``, which is initialised immediately; the resulting
        smolagents agents are returned.
        """
        sub_agents: List[Any] = []
        for cfg in sub_configs:
            sub_tool_config: Dict[str, Any] = {
                "name": cfg.get("name", "sub_agent"),
                "type": "SmolAgentTool",
                "description": cfg.get("description", ""),
                "settings": cfg,
            }
            sub_tool = SmolAgentTool(sub_tool_config, tooluniverse=self.tooluniverse)
            sub_tool._init_agent()
            if sub_tool.agent is not None:
                sub_agents.append(sub_tool.agent)
        return sub_agents

    @classmethod
    def _sanitize_agent_kwargs(
        cls, agent_type: str, params: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Drop constructor kwargs not on the allow-list for ``agent_type``.

        Only "CodeAgent" has an extended allow-list; every other type gets the
        common one. Unsupported keys (e.g., max_tool_threads) are dropped.
        """
        if agent_type == "CodeAgent":
            allowed = cls._CODE_AGENT_KWARGS
        else:
            allowed = cls._COMMON_AGENT_KWARGS
        return {k: v for k, v in params.items() if k in allowed}

    def _base_agent_kwargs(self, tools: List[Any], model: Any) -> Dict[str, Any]:
        """Assemble unsanitised agent constructor kwargs.

        ``agent_init_params`` is applied last, so it can override ``tools``,
        ``model`` and ``name``.
        """
        kwargs: Dict[str, Any] = {"tools": tools, "model": model}
        # Give the agent an explicit name if supported
        if isinstance(self.tool_config.get("name", None), str):
            kwargs["name"] = self.tool_config["name"]
        kwargs.update(self.agent_init_params or {})
        return kwargs

    @staticmethod
    def _wrap_agent_as_tool(agent_obj, tool_name: str):
        """Return a smolagents Tool whose ``forward(task)`` runs ``agent_obj``."""
        Tool = _safe_import("smolagents", "Tool")

        # smolagents expects class attributes on Tool subclasses
        def _forward(self, task: str):  # type: ignore[override]
            return agent_obj.run(task)

        attrs = {
            "name": tool_name,
            "description": f"Agent tool wrapper for {tool_name}",
            "inputs": {
                "task": {
                    "type": "string",
                    "description": "Task for sub-agent",
                }
            },
            "output_type": "string",
            "forward": _forward,
        }
        AgentToolCls = type(f"AgentTool_{tool_name}", (Tool,), attrs)  # type: ignore[misc]
        return AgentToolCls()

    def _init_agent(self) -> None:
        """Create ``self.agent`` on first use; later calls are no-ops.

        Raises:
            ImportError: If a required smolagents class cannot be imported.
            ValueError: If the model provider or agent type is unsupported.
        """
        if self.agent is not None:
            return
        model = self._init_model()
        tools = self._convert_tools()

        if self.agent_type == "ManagedAgent":
            # Emulate a managed multi-agent system by wrapping sub-agents
            # as smolagents Tools and composing a top-level CodeAgent.
            CodeAgent = _safe_import("smolagents", "CodeAgent")
            top_tools = list(tools)
            sub_agents = self._create_sub_agents(self.sub_agents_config)
            for idx, sub in enumerate(sub_agents):
                # An agent may expose ``name`` but leave it unset; fall back
                # to a positional name so the wrapper always has one.
                name = getattr(sub, "name", None) or f"sub_agent_{idx + 1}"
                top_tools.append(self._wrap_agent_as_tool(sub, name))
            orchestrator_kwargs = self._sanitize_agent_kwargs(
                "CodeAgent", self._base_agent_kwargs(top_tools, model)
            )
            self.agent = CodeAgent(**orchestrator_kwargs)
            return

        if self.agent_type in ("CodeAgent", "ToolCallingAgent", "Agent"):
            init_kwargs = self._sanitize_agent_kwargs(
                self.agent_type, self._base_agent_kwargs(tools, model)
            )
            agent_cls = _safe_import("smolagents", self.agent_type)
            self.agent = agent_cls(**init_kwargs)
            return

        raise ValueError(f"Unsupported agent type: {self.agent_type}")

    # -------------------------
    # Execution
    # -------------------------
    def _execute_task(
        self, task: str, stream_callback: Optional[Callable[[str], None]]
    ):
        """Run the agent on ``task`` while holding the global stream lock.

        When a callback is given and the agent has a truthy ``stream_outputs``
        attribute, ``sys.stdout`` is replaced by a ``_StreamProxy`` for the
        duration of the run so printed output reaches the callback.
        """
        import sys

        wants_stream = bool(stream_callback) and bool(
            getattr(self.agent, "stream_outputs", False)
        )
        # Held on the non-streaming path too, so a run never overlaps with
        # another run's stdout redirection.
        with _STREAM_LOCK:
            if not wants_stream:
                return self.agent.run(task)

            old_stdout, old_stderr = sys.stdout, sys.stderr
            proxy = self._StreamProxy(stream_callback)
            sys.stdout = proxy  # forward stdout only to avoid dupes
            try:
                return self.agent.run(task)
            finally:
                # stderr is never redirected here; restoring it undoes any
                # replacement made during the run.
                sys.stdout = old_stdout
                sys.stderr = old_stderr
                # Emit output that did not end with a newline; without this
                # the last partial line would be silently dropped.
                proxy.flush()

    def _run_with_timeout(
        self,
        task: str,
        stream_callback: Optional[Callable[[str], None]],
        timeout: float,
    ):
        """Run the agent in a daemon thread and wait at most ``timeout`` seconds.

        Returns:
            ``(True, result)`` if the run finished in time, ``(False, None)``
            otherwise. On timeout the worker thread is not stopped; it keeps
            running in the background and still holds the stream lock until
            it finishes.

        Raises:
            Exception: Whatever the agent run raised, re-raised in the caller's
                thread.
        """
        outcome: Dict[str, Any] = {}

        def _worker():
            try:
                outcome["result"] = self._execute_task(task, stream_callback)
            except Exception as e:  # noqa: BLE001
                outcome["error"] = e

        worker_thread = threading.Thread(target=_worker, daemon=True)
        worker_thread.start()
        worker_thread.join(timeout=timeout)
        if worker_thread.is_alive():
            return False, None
        if "error" in outcome:
            raise outcome["error"]
        return True, outcome.get("result")

    def run(
        self,
        arguments: Dict[str, Any],
        stream_callback: Optional[Callable[[str], None]] = None,
        **_: Any,
    ) -> Dict[str, Any]:
        """Execute the agent with optional streaming back into ToolUniverse.

        Supports:
        - Streaming output (when stream_callback is provided and agent.stream_outputs=True)
        - Execution timeout (via agent_init_params.max_execution_time)
        - Thread-safe stdout/stderr redirection

        Args:
            arguments: Must carry the task text under ``"task"``; ``"query"``
                is used as a fallback when ``"task"`` is missing or empty.
            stream_callback: Optional callable receiving streamed text as well
                as ``[TIMEOUT]`` / ``[ERROR]`` notices.

        Returns:
            On success ``{"output", "success": True, "execution_time"}``;
            on failure ``{"output": None, "success": False, "error",
            "error_type"}`` where ``error_type`` is ``"timeout"`` or the
            exception class name.

        Raises:
            ImportError, ValueError: From agent initialisation, which happens
                before error capturing starts.
        """
        import time

        self._init_agent()
        # Fallback to 'query' for agents whose parameter is named 'query'
        task = arguments.get("task", "") or arguments.get("query", "")
        # Get max_execution_time from config (default: None = unlimited)
        max_execution_time = self.agent_init_params.get("max_execution_time")

        try:
            start_time = time.monotonic()
            if max_execution_time is not None and max_execution_time > 0:
                finished, result = self._run_with_timeout(
                    task, stream_callback, max_execution_time
                )
                if not finished:
                    timeout_error = TimeoutError(
                        f"Agent execution exceeded maximum time limit of {max_execution_time} seconds. "
                        f"Task: {task[:100]}..."
                    )
                    if stream_callback:
                        stream_callback(f"\n[TIMEOUT] {timeout_error}\n")
                    return {
                        "output": None,
                        "success": False,
                        "error": str(timeout_error),
                        "error_type": "timeout",
                    }
            else:
                # No timeout - direct execution
                result = self._execute_task(task, stream_callback)
            elapsed_time = time.monotonic() - start_time
            return {
                "output": result,
                "success": True,
                "execution_time": elapsed_time,
            }
        except Exception as e:  # noqa: BLE001
            if stream_callback:
                stream_callback(f"\n[ERROR] {e}\n")
            return {
                "output": None,
                "success": False,
                "error": str(e),
                "error_type": type(e).__name__,
            }