Source code for tooluniverse.agentic_tool
from __future__ import annotations
import os
import json
from datetime import datetime
from typing import Any, Dict, List, Optional
from .base_tool import BaseTool
from .tool_registry import register_tool
from .logging_config import get_logger
from .llm_clients import AzureOpenAIClient, GeminiClient
# Global default fallback configuration
DEFAULT_FALLBACK_CHAIN = [
{"api_type": "CHATGPT", "model_id": "gpt-4o-1120"},
{"api_type": "GEMINI", "model_id": "gemini-2.0-flash"},
]
# API key environment variable mapping
API_KEY_ENV_VARS = {
"CHATGPT": ["AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT"],
"GEMINI": ["GEMINI_API_KEY"],
}
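# Note: "CHATGPT" maps to Azure OpenAI here, so a complete credential set
# requires both the API key and the endpoint; "GEMINI" needs only its key.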
@register_tool("AgenticTool")
class AgenticTool(BaseTool):
"""Generic wrapper around LLM prompting supporting JSON-defined configs with prompts and input arguments."""
@staticmethod
def has_any_api_keys() -> bool:
"""
Check if any API keys are available across all supported API types.
Returns:
bool: True if at least one API type has all required keys, False otherwise
"""
        for required_vars in API_KEY_ENV_VARS.values():
            if all(os.getenv(var) for var in required_vars):
                return True
        return False
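    # Pre-flight sketch (hypothetical caller code): check credentials once
    # before constructing tools so unavailable providers fail fast.
    #   if AgenticTool.has_any_api_keys():
    #       tool = AgenticTool(tool_config)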
    def __init__(self, tool_config: Dict[str, Any]):
        """Parse tool_config, validate it, and initialize an LLM client with fallback."""
        super().__init__(tool_config)
self.logger = get_logger("AgenticTool")
self.name: str = tool_config.get("name", "")
self._prompt_template: str = tool_config.get("prompt", "")
self._input_arguments: List[str] = tool_config.get("input_arguments", [])
# Extract required arguments from parameter schema
parameter_info = tool_config.get("parameter", {})
self._required_arguments: List[str] = parameter_info.get("required", [])
self._argument_defaults: Dict[str, str] = {}
# Set up default values for optional arguments
properties = parameter_info.get("properties", {})
for arg in self._input_arguments:
if arg not in self._required_arguments:
prop_info = properties.get(arg, {})
if "default" in prop_info:
self._argument_defaults[arg] = prop_info["default"]
# Get configuration from nested 'configs' dict or fallback to top-level
configs = tool_config.get("configs", {})
# Helper function to get config values with fallback
def get_config(key: str, default: Any) -> Any:
return configs.get(key, tool_config.get(key, default))
# LLM configuration
self._api_type: str = get_config("api_type", "CHATGPT")
self._model_id: str = get_config("model_id", "o1-mini")
self._temperature: Optional[float] = get_config("temperature", 0.1)
# Ignore configured max_new_tokens; client will resolve per model/env
self._max_new_tokens: Optional[int] = None
self._return_json: bool = get_config("return_json", False)
self._max_retries: int = get_config("max_retries", 5)
self._retry_delay: int = get_config("retry_delay", 5)
self.return_metadata: bool = get_config("return_metadata", True)
self._validate_api_key: bool = get_config("validate_api_key", True)
# API fallback configuration
self._fallback_api_type: Optional[str] = get_config("fallback_api_type", None)
self._fallback_model_id: Optional[str] = get_config("fallback_model_id", None)
# Global fallback configuration
self._use_global_fallback: bool = get_config("use_global_fallback", True)
self._global_fallback_chain: List[Dict[str, str]] = (
self._get_global_fallback_chain()
)
# Gemini model configuration (optional; env override)
        self._gemini_model_id: str = get_config(
            "gemini_model_id",
            os.getenv("GEMINI_MODEL_ID", "gemini-2.0-flash"),
        )
# Validation
if not self._prompt_template:
raise ValueError("AgenticTool requires a 'prompt' in the configuration.")
if not self._input_arguments:
raise ValueError(
"AgenticTool requires 'input_arguments' in the configuration."
)
# Validate temperature range (skip if None)
if (
isinstance(self._temperature, (int, float))
and not 0 <= self._temperature <= 2
):
self.logger.warning(
f"Temperature {self._temperature} is outside recommended range [0, 2]"
)
# Validate model compatibility
self._validate_model_config()
# Initialize the provider client
self._llm_client = None
self._initialization_error = None
self._is_available = False
self._current_api_type = None
self._current_model_id = None
# Try primary API first, then fallback if configured
self._try_initialize_api()
def _get_global_fallback_chain(self) -> List[Dict[str, str]]:
"""Get the global fallback chain from environment or use default."""
# Check environment variable for custom fallback chain
env_chain = os.getenv("AGENTIC_TOOL_FALLBACK_CHAIN")
if env_chain:
try:
chain = json.loads(env_chain)
if isinstance(chain, list) and all(
isinstance(item, dict) and "api_type" in item and "model_id" in item
for item in chain
):
return chain
else:
self.logger.warning(
"Invalid fallback chain format in environment variable"
)
except json.JSONDecodeError:
self.logger.warning(
"Invalid JSON in AGENTIC_TOOL_FALLBACK_CHAIN environment variable"
)
return DEFAULT_FALLBACK_CHAIN.copy()
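    # Example override (hypothetical model ids): deployments can reorder or
    # replace the chain without code changes, e.g.
    #   export AGENTIC_TOOL_FALLBACK_CHAIN='[{"api_type": "GEMINI", "model_id": "gemini-2.0-flash"}]'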
def _try_initialize_api(self):
"""Try to initialize the primary API, fallback to secondary if configured."""
# Try primary API first
if self._try_api(self._api_type, self._model_id):
return
# Try explicit fallback API if configured
if self._fallback_api_type and self._fallback_model_id:
self.logger.info(
f"Primary API {self._api_type} failed, trying explicit fallback {self._fallback_api_type}"
)
if self._try_api(self._fallback_api_type, self._fallback_model_id):
return
# Try global fallback chain if enabled
if self._use_global_fallback:
self.logger.info(
f"Primary API {self._api_type} failed, trying global fallback chain"
)
for fallback_config in self._global_fallback_chain:
fallback_api = fallback_config["api_type"]
fallback_model = fallback_config["model_id"]
# Skip if it's the same as primary or explicit fallback
if (
fallback_api == self._api_type and fallback_model == self._model_id
) or (
fallback_api == self._fallback_api_type
and fallback_model == self._fallback_model_id
):
continue
self.logger.info(
f"Trying global fallback: {fallback_api} ({fallback_model})"
)
if self._try_api(fallback_api, fallback_model):
return
# If we get here, all APIs failed
self.logger.warning(
f"Tool '{self.name}' failed to initialize with all available APIs"
)
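    # Resolution order, as implemented above: primary API -> explicit fallback
    # (if configured) -> each global-chain entry, skipping any entry that
    # duplicates the primary or the explicit fallback.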
def _try_api(self, api_type: str, model_id: str) -> bool:
"""Try to initialize a specific API and model."""
try:
if api_type == "CHATGPT":
self._llm_client = AzureOpenAIClient(model_id, None, self.logger)
elif api_type == "GEMINI":
self._llm_client = GeminiClient(model_id, self.logger)
else:
raise ValueError(f"Unsupported API type: {api_type}")
# Test API key validity after initialization (if enabled)
if self._validate_api_key:
self._llm_client.test_api()
self.logger.debug(
f"Successfully initialized {api_type} model: {model_id}"
)
else:
self.logger.info("API key validation skipped (validate_api_key=False)")
self._is_available = True
self._current_api_type = api_type
self._current_model_id = model_id
self._initialization_error = None
if api_type != self._api_type or model_id != self._model_id:
self.logger.info(
f"Using fallback API: {api_type} with model {model_id} "
f"(originally configured: {self._api_type} with {self._model_id})"
)
return True
except Exception as e:
error_msg = f"Failed to initialize {api_type} model {model_id}: {str(e)}"
self.logger.warning(error_msg)
self._initialization_error = error_msg
return False
# ------------------------------------------------------------------ LLM utilities -----------
    def _validate_model_config(self):
        """Sanity-check the configured API type and max_new_tokens value."""
        supported_api_types = ["CHATGPT", "GEMINI"]
if self._api_type not in supported_api_types:
raise ValueError(
f"Unsupported API type: {self._api_type}. Supported types: {supported_api_types}"
)
if self._max_new_tokens is not None and self._max_new_tokens <= 0:
raise ValueError("max_new_tokens must be positive or None")
# ------------------------------------------------------------------ public API --------------
    def run(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
        """Validate arguments, format the prompt, and query the LLM.

        Returns a metadata-wrapped dict by default; when return_metadata is
        False, returns the raw model response (or an error string on failure).
        """
        start_time = datetime.now()
# Check if tool is available before attempting to run
if not self._is_available:
error_msg = f"Tool '{self.name}' is not available due to initialization error: {self._initialization_error}"
self.logger.error(error_msg)
if self.return_metadata:
return {
"success": False,
"error": error_msg,
"error_type": "ToolUnavailable",
"metadata": {
"prompt_used": "Tool unavailable",
"input_arguments": {
arg: arguments.get(arg) for arg in self._input_arguments
},
"model_info": {
"api_type": self._api_type,
"model_id": self._model_id,
},
"execution_time_seconds": 0,
"timestamp": start_time.isoformat(),
},
}
else:
return f"error: {error_msg} error_type: ToolUnavailable"
try:
# Validate required args
missing_required_args = [
arg for arg in self._required_arguments if arg not in arguments
]
if missing_required_args:
raise ValueError(
f"Missing required input arguments: {missing_required_args}"
)
# Fill defaults for optional args
for arg in self._input_arguments:
if arg not in arguments:
arguments[arg] = self._argument_defaults.get(arg, "")
self._validate_arguments(arguments)
formatted_prompt = self._format_prompt(arguments)
messages = [{"role": "user", "content": formatted_prompt}]
custom_format = arguments.get("response_format", None)
# Delegate to client; client handles provider-specific logic
response = self._llm_client.infer(
messages=messages,
temperature=self._temperature,
max_tokens=None, # client resolves per-model defaults/env
return_json=self._return_json,
custom_format=custom_format,
max_retries=self._max_retries,
retry_delay=self._retry_delay,
)
end_time = datetime.now()
execution_time = (end_time - start_time).total_seconds()
if self.return_metadata:
return {
"success": True,
"result": response,
"metadata": {
"prompt_used": (
formatted_prompt
if len(formatted_prompt) < 1000
else f"{formatted_prompt[:1000]}..."
),
"input_arguments": {
arg: arguments.get(arg) for arg in self._input_arguments
},
"model_info": {
"api_type": self._api_type,
"model_id": self._model_id,
"temperature": self._temperature,
"max_new_tokens": self._max_new_tokens,
},
"execution_time_seconds": execution_time,
"timestamp": start_time.isoformat(),
},
}
else:
return response
except Exception as e:
end_time = datetime.now()
execution_time = (end_time - start_time).total_seconds()
self.logger.error(f"Error executing {self.name}: {str(e)}")
if self.return_metadata:
return {
"success": False,
"error": str(e),
"error_type": type(e).__name__,
"metadata": {
"prompt_used": (
formatted_prompt
if "formatted_prompt" in locals()
else "Failed to format prompt"
),
"input_arguments": {
arg: arguments.get(arg) for arg in self._input_arguments
},
"model_info": {
"api_type": self._api_type,
"model_id": self._model_id,
},
"execution_time_seconds": execution_time,
"timestamp": start_time.isoformat(),
},
}
else:
return "error: " + str(e) + " error_type: " + type(e).__name__
# ------------------------------------------------------------------ helpers -----------------
    def _validate_arguments(self, arguments: Dict[str, Any]):
        """Reject empty required string arguments and warn on oversized inputs."""
for arg_name, value in arguments.items():
if arg_name in self._input_arguments:
if isinstance(value, str) and not value.strip():
if arg_name in self._required_arguments:
raise ValueError(
f"Required argument '{arg_name}' cannot be empty"
)
                if isinstance(value, str) and len(value) > 100000:
                    self.logger.warning(
                        f"Argument '{arg_name}' is very long "
                        f"({len(value)} chars) and may exceed the model's context window"
                    )
    def _format_prompt(self, arguments: Dict[str, Any]) -> str:
        """Substitute argument values into the template's {name} placeholders."""
        prompt = self._prompt_template
for arg_name in self._input_arguments:
placeholder = f"{{{arg_name}}}"
value = arguments.get(arg_name, "")
if placeholder in prompt:
prompt = prompt.replace(placeholder, str(value))
return prompt
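    # Example: with template "Summarize: {text}" and arguments
    # {"text": "some input"}, this yields "Summarize: some input".
    # Placeholders absent from the template are silently ignored.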
    def get_prompt_preview(self, arguments: Dict[str, Any]) -> str:
        """Return the fully formatted prompt without calling the LLM."""
        try:
args_copy = arguments.copy()
missing_required_args = [
arg for arg in self._required_arguments if arg not in args_copy
]
if missing_required_args:
raise ValueError(
f"Missing required input arguments: {missing_required_args}"
)
for arg in self._input_arguments:
if arg not in args_copy:
args_copy[arg] = self._argument_defaults.get(arg, "")
return self._format_prompt(args_copy)
except Exception as e:
return f"Error formatting prompt: {str(e)}"
    def get_model_info(self) -> Dict[str, Any]:
        """Return the tool's model, retry, and fallback configuration."""
        return {
"api_type": self._api_type,
"model_id": self._model_id,
"temperature": self._temperature,
"max_new_tokens": self._max_new_tokens,
"return_json": self._return_json,
"max_retries": self._max_retries,
"retry_delay": self._retry_delay,
"validate_api_key": self._validate_api_key,
"gemini_model_id": getattr(self, "_gemini_model_id", None),
"is_available": self._is_available,
"initialization_error": self._initialization_error,
"current_api_type": self._current_api_type,
"current_model_id": self._current_model_id,
"fallback_api_type": self._fallback_api_type,
"fallback_model_id": self._fallback_model_id,
"use_global_fallback": self._use_global_fallback,
"global_fallback_chain": self._global_fallback_chain,
}
def is_available(self) -> bool:
"""Check if the tool is available for use."""
return self._is_available
def get_availability_status(self) -> Dict[str, Any]:
"""Get detailed availability status of the tool."""
return {
"is_available": self._is_available,
"initialization_error": self._initialization_error,
"api_type": self._api_type,
"model_id": self._model_id,
}
def retry_initialization(self) -> bool:
"""Attempt to reinitialize the tool (useful if API keys were updated)."""
try:
if self._api_type == "CHATGPT":
self._llm_client = AzureOpenAIClient(self._model_id, None, self.logger)
elif self._api_type == "GEMINI":
self._llm_client = GeminiClient(self._gemini_model_id, self.logger)
else:
raise ValueError(f"Unsupported API type: {self._api_type}")
if self._validate_api_key:
self._llm_client.test_api()
self.logger.info(
f"Successfully reinitialized {self._api_type} model: {self._model_id}"
)
self._is_available = True
self._initialization_error = None
return True
except Exception as e:
self._initialization_error = str(e)
self.logger.warning(
f"Retry initialization failed for {self._api_type} model {self._model_id}: {str(e)}"
)
return False
    def validate_configuration(self) -> Dict[str, Any]:
        """Validate the tool configuration, returning any errors and warnings."""
        validation_results = {"valid": True, "warnings": [], "errors": []}
try:
self._validate_model_config()
except ValueError as e:
validation_results["valid"] = False
validation_results["errors"].append(str(e))
if not self._prompt_template:
validation_results["valid"] = False
validation_results["errors"].append("Missing prompt template")
return validation_results
    def estimate_token_usage(self, arguments: Dict[str, Any]) -> Dict[str, int]:
        """Estimate token usage via a crude ~4-characters-per-token heuristic."""
        prompt = self._format_prompt(arguments)
        estimated_input_tokens = len(prompt) // 4
estimated_max_output_tokens = (
self._max_new_tokens if self._max_new_tokens is not None else 2048
)
estimated_total_tokens = estimated_input_tokens + estimated_max_output_tokens
return {
"estimated_input_tokens": estimated_input_tokens,
"max_output_tokens": estimated_max_output_tokens,
"estimated_total_tokens": estimated_total_tokens,
"prompt_length_chars": len(prompt),
}
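    # Illustrative numbers: a 4,000-character prompt yields ~1,000 estimated
    # input tokens (len // 4). Because max_new_tokens is resolved by the
    # client, the 2048-token output budget above is only a planning default.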