Source code for tooluniverse.agentic_tool

from __future__ import annotations

import os
import json
from datetime import datetime
from typing import Any, Dict, List, Optional

from .base_tool import BaseTool
from .tool_registry import register_tool
from .logging_config import get_logger
from .llm_clients import AzureOpenAIClient, GeminiClient


# Global default fallback configuration
DEFAULT_FALLBACK_CHAIN = [
    {"api_type": "CHATGPT", "model_id": "gpt-4o-1120"},
    {"api_type": "GEMINI", "model_id": "gemini-2.0-flash"},
]

# API key environment variable mapping
API_KEY_ENV_VARS = {
    "CHATGPT": ["AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT"],
    "GEMINI": ["GEMINI_API_KEY"],
}


[docs] @register_tool("AgenticTool") class AgenticTool(BaseTool): """Generic wrapper around LLM prompting supporting JSON-defined configs with prompts and input arguments."""
    @staticmethod
    def has_any_api_keys() -> bool:
        """
        Check if any API keys are available across all supported API types.

        Returns:
            bool: True if at least one API type has all required keys, False otherwise
        """
        for _api_type, required_vars in API_KEY_ENV_VARS.items():
            all_keys_present = True
            for var in required_vars:
                if not os.getenv(var):
                    all_keys_present = False
                    break
            if all_keys_present:
                return True
        return False
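    # Illustrative note (not part of the original module): callers can use this
    # check to guard tool construction; the error message below is hypothetical.
    #
    #     if not AgenticTool.has_any_api_keys():
    #         raise RuntimeError("No LLM API keys found for any supported api_type")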
    def __init__(self, tool_config: Dict[str, Any]):
        super().__init__(tool_config)
        self.logger = get_logger("AgenticTool")
        self.name: str = tool_config.get("name", "")
        self._prompt_template: str = tool_config.get("prompt", "")
        self._input_arguments: List[str] = tool_config.get("input_arguments", [])

        # Extract required arguments from parameter schema
        parameter_info = tool_config.get("parameter", {})
        self._required_arguments: List[str] = parameter_info.get("required", [])
        self._argument_defaults: Dict[str, str] = {}

        # Set up default values for optional arguments
        properties = parameter_info.get("properties", {})
        for arg in self._input_arguments:
            if arg not in self._required_arguments:
                prop_info = properties.get(arg, {})
                if "default" in prop_info:
                    self._argument_defaults[arg] = prop_info["default"]

        # Get configuration from nested 'configs' dict or fall back to top-level keys
        configs = tool_config.get("configs", {})

        # Helper function to get config values with fallback
        def get_config(key: str, default: Any) -> Any:
            return configs.get(key, tool_config.get(key, default))

        # LLM configuration
        self._api_type: str = get_config("api_type", "CHATGPT")
        self._model_id: str = get_config("model_id", "o1-mini")
        self._temperature: Optional[float] = get_config("temperature", 0.1)
        # Ignore configured max_new_tokens; the client resolves limits per model/env
        self._max_new_tokens: Optional[int] = None
        self._return_json: bool = get_config("return_json", False)
        self._max_retries: int = get_config("max_retries", 5)
        self._retry_delay: int = get_config("retry_delay", 5)
        self.return_metadata: bool = get_config("return_metadata", True)
        self._validate_api_key: bool = get_config("validate_api_key", True)

        # API fallback configuration
        self._fallback_api_type: Optional[str] = get_config("fallback_api_type", None)
        self._fallback_model_id: Optional[str] = get_config("fallback_model_id", None)

        # Global fallback configuration
        self._use_global_fallback: bool = get_config("use_global_fallback", True)
        self._global_fallback_chain: List[Dict[str, str]] = (
            self._get_global_fallback_chain()
        )

        # Gemini model configuration (optional; env override)
        self._gemini_model_id: str = get_config(
            "gemini_model_id",
            os.getenv("GEMINI_MODEL_ID", "gemini-2.0-flash"),
        )

        # Validation
        if not self._prompt_template:
            raise ValueError("AgenticTool requires a 'prompt' in the configuration.")
        if not self._input_arguments:
            raise ValueError(
                "AgenticTool requires 'input_arguments' in the configuration."
            )

        # Validate temperature range (skip if None)
        if (
            isinstance(self._temperature, (int, float))
            and not 0 <= self._temperature <= 2
        ):
            self.logger.warning(
                f"Temperature {self._temperature} is outside recommended range [0, 2]"
            )

        # Validate model compatibility
        self._validate_model_config()

        # Initialize the provider client
        self._llm_client = None
        self._initialization_error = None
        self._is_available = False
        self._current_api_type = None
        self._current_model_id = None

        # Try primary API first, then fallback if configured
        self._try_initialize_api()
    def _get_global_fallback_chain(self) -> List[Dict[str, str]]:
        """Get the global fallback chain from environment or use the default."""
        # Check environment variable for a custom fallback chain
        env_chain = os.getenv("AGENTIC_TOOL_FALLBACK_CHAIN")
        if env_chain:
            try:
                chain = json.loads(env_chain)
                if isinstance(chain, list) and all(
                    isinstance(item, dict)
                    and "api_type" in item
                    and "model_id" in item
                    for item in chain
                ):
                    return chain
                else:
                    self.logger.warning(
                        "Invalid fallback chain format in environment variable"
                    )
            except json.JSONDecodeError:
                self.logger.warning(
                    "Invalid JSON in AGENTIC_TOOL_FALLBACK_CHAIN environment variable"
                )
        return DEFAULT_FALLBACK_CHAIN.copy()
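    # Illustrative override of the global fallback chain (shell syntax assumed,
    # not part of the original module). Per the validation above, the value must
    # be a JSON list of objects that each contain "api_type" and "model_id":
    #
    #     export AGENTIC_TOOL_FALLBACK_CHAIN='[{"api_type": "GEMINI", "model_id": "gemini-2.0-flash"}]'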
    def _try_initialize_api(self):
        """Try to initialize the primary API, falling back to secondary APIs if configured."""
        # Try primary API first
        if self._try_api(self._api_type, self._model_id):
            return

        # Try explicit fallback API if configured
        if self._fallback_api_type and self._fallback_model_id:
            self.logger.info(
                f"Primary API {self._api_type} failed, trying explicit fallback {self._fallback_api_type}"
            )
            if self._try_api(self._fallback_api_type, self._fallback_model_id):
                return

        # Try global fallback chain if enabled
        if self._use_global_fallback:
            self.logger.info(
                f"Primary API {self._api_type} failed, trying global fallback chain"
            )
            for fallback_config in self._global_fallback_chain:
                fallback_api = fallback_config["api_type"]
                fallback_model = fallback_config["model_id"]

                # Skip if it's the same as the primary or the explicit fallback
                if (
                    fallback_api == self._api_type
                    and fallback_model == self._model_id
                ) or (
                    fallback_api == self._fallback_api_type
                    and fallback_model == self._fallback_model_id
                ):
                    continue

                self.logger.info(
                    f"Trying global fallback: {fallback_api} ({fallback_model})"
                )
                if self._try_api(fallback_api, fallback_model):
                    return

        # If we get here, all APIs failed
        self.logger.warning(
            f"Tool '{self.name}' failed to initialize with all available APIs"
        )
    def _try_api(self, api_type: str, model_id: str) -> bool:
        """Try to initialize a specific API and model."""
        try:
            if api_type == "CHATGPT":
                self._llm_client = AzureOpenAIClient(model_id, None, self.logger)
            elif api_type == "GEMINI":
                self._llm_client = GeminiClient(model_id, self.logger)
            else:
                raise ValueError(f"Unsupported API type: {api_type}")

            # Test API key validity after initialization (if enabled)
            if self._validate_api_key:
                self._llm_client.test_api()
                self.logger.debug(
                    f"Successfully initialized {api_type} model: {model_id}"
                )
            else:
                self.logger.info("API key validation skipped (validate_api_key=False)")

            self._is_available = True
            self._current_api_type = api_type
            self._current_model_id = model_id
            self._initialization_error = None

            if api_type != self._api_type or model_id != self._model_id:
                self.logger.info(
                    f"Using fallback API: {api_type} with model {model_id} "
                    f"(originally configured: {self._api_type} with {self._model_id})"
                )
            return True

        except Exception as e:
            error_msg = f"Failed to initialize {api_type} model {model_id}: {str(e)}"
            self.logger.warning(error_msg)
            self._initialization_error = error_msg
            return False
    # ------------------------------------------------------------------ LLM utilities -----------
    def _validate_model_config(self):
        supported_api_types = ["CHATGPT", "GEMINI"]
        if self._api_type not in supported_api_types:
            raise ValueError(
                f"Unsupported API type: {self._api_type}. Supported types: {supported_api_types}"
            )
        if self._max_new_tokens is not None and self._max_new_tokens <= 0:
            raise ValueError("max_new_tokens must be positive or None")
    # ------------------------------------------------------------------ public API --------------
    def run(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
        start_time = datetime.now()

        # Check if the tool is available before attempting to run
        if not self._is_available:
            error_msg = (
                f"Tool '{self.name}' is not available due to initialization "
                f"error: {self._initialization_error}"
            )
            self.logger.error(error_msg)
            if self.return_metadata:
                return {
                    "success": False,
                    "error": error_msg,
                    "error_type": "ToolUnavailable",
                    "metadata": {
                        "prompt_used": "Tool unavailable",
                        "input_arguments": {
                            arg: arguments.get(arg) for arg in self._input_arguments
                        },
                        "model_info": {
                            "api_type": self._api_type,
                            "model_id": self._model_id,
                        },
                        "execution_time_seconds": 0,
                        "timestamp": start_time.isoformat(),
                    },
                }
            else:
                return f"error: {error_msg} error_type: ToolUnavailable"

        try:
            # Validate required args
            missing_required_args = [
                arg for arg in self._required_arguments if arg not in arguments
            ]
            if missing_required_args:
                raise ValueError(
                    f"Missing required input arguments: {missing_required_args}"
                )

            # Fill defaults for optional args
            for arg in self._input_arguments:
                if arg not in arguments:
                    arguments[arg] = self._argument_defaults.get(arg, "")

            self._validate_arguments(arguments)
            formatted_prompt = self._format_prompt(arguments)
            messages = [{"role": "user", "content": formatted_prompt}]
            custom_format = arguments.get("response_format", None)

            # Delegate to the client; the client handles provider-specific logic
            response = self._llm_client.infer(
                messages=messages,
                temperature=self._temperature,
                max_tokens=None,  # client resolves per-model defaults/env
                return_json=self._return_json,
                custom_format=custom_format,
                max_retries=self._max_retries,
                retry_delay=self._retry_delay,
            )

            end_time = datetime.now()
            execution_time = (end_time - start_time).total_seconds()

            if self.return_metadata:
                return {
                    "success": True,
                    "result": response,
                    "metadata": {
                        "prompt_used": (
                            formatted_prompt
                            if len(formatted_prompt) < 1000
                            else f"{formatted_prompt[:1000]}..."
                        ),
                        "input_arguments": {
                            arg: arguments.get(arg) for arg in self._input_arguments
                        },
                        "model_info": {
                            "api_type": self._api_type,
                            "model_id": self._model_id,
                            "temperature": self._temperature,
                            "max_new_tokens": self._max_new_tokens,
                        },
                        "execution_time_seconds": execution_time,
                        "timestamp": start_time.isoformat(),
                    },
                }
            else:
                return response

        except Exception as e:
            end_time = datetime.now()
            execution_time = (end_time - start_time).total_seconds()
            self.logger.error(f"Error executing {self.name}: {str(e)}")

            if self.return_metadata:
                return {
                    "success": False,
                    "error": str(e),
                    "error_type": type(e).__name__,
                    "metadata": {
                        "prompt_used": (
                            formatted_prompt
                            if "formatted_prompt" in locals()
                            else "Failed to format prompt"
                        ),
                        "input_arguments": {
                            arg: arguments.get(arg) for arg in self._input_arguments
                        },
                        "model_info": {
                            "api_type": self._api_type,
                            "model_id": self._model_id,
                        },
                        "execution_time_seconds": execution_time,
                        "timestamp": start_time.isoformat(),
                    },
                }
            else:
                return "error: " + str(e) + " error_type: " + type(e).__name__
    # ------------------------------------------------------------------ helpers -----------------
    def _validate_arguments(self, arguments: Dict[str, Any]):
        for arg_name, value in arguments.items():
            if arg_name in self._input_arguments:
                # Reject empty strings only for required arguments
                if isinstance(value, str) and not value.strip():
                    if arg_name in self._required_arguments:
                        raise ValueError(
                            f"Required argument '{arg_name}' cannot be empty"
                        )
                # Very long string values are currently accepted without modification
                if isinstance(value, str) and len(value) > 100000:
                    pass
    def _format_prompt(self, arguments: Dict[str, Any]) -> str:
        prompt = self._prompt_template
        for arg_name in self._input_arguments:
            placeholder = f"{{{arg_name}}}"
            value = arguments.get(arg_name, "")
            if placeholder in prompt:
                prompt = prompt.replace(placeholder, str(value))
        return prompt
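    # Illustrative example (hypothetical template and arguments, not part of the
    # original module): with prompt = "Summarize the following abstract:\n{abstract}"
    # and arguments = {"abstract": "We present ..."}, _format_prompt returns
    # "Summarize the following abstract:\nWe present ...". Placeholders whose names
    # are not listed in input_arguments are left untouched.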
    def get_prompt_preview(self, arguments: Dict[str, Any]) -> str:
        try:
            args_copy = arguments.copy()
            missing_required_args = [
                arg for arg in self._required_arguments if arg not in args_copy
            ]
            if missing_required_args:
                raise ValueError(
                    f"Missing required input arguments: {missing_required_args}"
                )
            for arg in self._input_arguments:
                if arg not in args_copy:
                    args_copy[arg] = self._argument_defaults.get(arg, "")
            return self._format_prompt(args_copy)
        except Exception as e:
            return f"Error formatting prompt: {str(e)}"
    def get_model_info(self) -> Dict[str, Any]:
        return {
            "api_type": self._api_type,
            "model_id": self._model_id,
            "temperature": self._temperature,
            "max_new_tokens": self._max_new_tokens,
            "return_json": self._return_json,
            "max_retries": self._max_retries,
            "retry_delay": self._retry_delay,
            "validate_api_key": self._validate_api_key,
            "gemini_model_id": getattr(self, "_gemini_model_id", None),
            "is_available": self._is_available,
            "initialization_error": self._initialization_error,
            "current_api_type": self._current_api_type,
            "current_model_id": self._current_model_id,
            "fallback_api_type": self._fallback_api_type,
            "fallback_model_id": self._fallback_model_id,
            "use_global_fallback": self._use_global_fallback,
            "global_fallback_chain": self._global_fallback_chain,
        }
    def is_available(self) -> bool:
        """Check if the tool is available for use."""
        return self._is_available
    def get_availability_status(self) -> Dict[str, Any]:
        """Get detailed availability status of the tool."""
        return {
            "is_available": self._is_available,
            "initialization_error": self._initialization_error,
            "api_type": self._api_type,
            "model_id": self._model_id,
        }
    def retry_initialization(self) -> bool:
        """Attempt to reinitialize the tool (useful if API keys were updated)."""
        try:
            if self._api_type == "CHATGPT":
                self._llm_client = AzureOpenAIClient(self._model_id, None, self.logger)
            elif self._api_type == "GEMINI":
                self._llm_client = GeminiClient(self._gemini_model_id, self.logger)
            else:
                raise ValueError(f"Unsupported API type: {self._api_type}")

            if self._validate_api_key:
                self._llm_client.test_api()

            self.logger.info(
                f"Successfully reinitialized {self._api_type} model: {self._model_id}"
            )
            self._is_available = True
            self._initialization_error = None
            return True

        except Exception as e:
            self._initialization_error = str(e)
            self.logger.warning(
                f"Retry initialization failed for {self._api_type} model {self._model_id}: {str(e)}"
            )
            return False
    def get_prompt_template(self) -> str:
        return self._prompt_template
    def get_input_arguments(self) -> List[str]:
        return self._input_arguments.copy()
    def validate_configuration(self) -> Dict[str, Any]:
        validation_results = {"valid": True, "warnings": [], "errors": []}
        try:
            self._validate_model_config()
        except ValueError as e:
            validation_results["valid"] = False
            validation_results["errors"].append(str(e))
        if not self._prompt_template:
            validation_results["valid"] = False
            validation_results["errors"].append("Missing prompt template")
        return validation_results
    def estimate_token_usage(self, arguments: Dict[str, Any]) -> Dict[str, int]:
        prompt = self._format_prompt(arguments)
        estimated_input_tokens = len(prompt) // 4
        estimated_max_output_tokens = (
            self._max_new_tokens if self._max_new_tokens is not None else 2048
        )
        estimated_total_tokens = estimated_input_tokens + estimated_max_output_tokens
        return {
            "estimated_input_tokens": estimated_input_tokens,
            "max_output_tokens": estimated_max_output_tokens,
            "estimated_total_tokens": estimated_total_tokens,
            "prompt_length_chars": len(prompt),
        }
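

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# The tool name, prompt, and argument names below are hypothetical; the config
# keys mirror what __init__ reads above, and the run is skipped when no API
# keys are configured.
if __name__ == "__main__":
    example_config = {
        "name": "AbstractSummarizer",  # hypothetical tool name
        "prompt": "Summarize the following abstract:\n{abstract}",
        "input_arguments": ["abstract"],
        "parameter": {
            "required": ["abstract"],
            "properties": {"abstract": {"type": "string"}},
        },
        "configs": {
            "api_type": "CHATGPT",
            "model_id": "gpt-4o-1120",
            "return_metadata": True,
        },
    }
    if AgenticTool.has_any_api_keys():
        tool = AgenticTool(example_config)
        if tool.is_available():
            print(tool.run({"abstract": "We present a generic LLM tool wrapper..."}))
        else:
            print(tool.get_availability_status())
    else:
        print("No API keys configured; skipping example run.")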