Source code for tooluniverse.llm_clients

from __future__ import annotations
from typing import Any, Dict, List, Optional
import os
import time
import json as _json


class BaseLLMClient:
    """Minimal interface implemented by all LLM backends."""

    def test_api(self) -> None:
        raise NotImplementedError

    def infer(
        self,
        messages: List[Dict[str, str]],
        temperature: Optional[float],
        max_tokens: Optional[int],
        return_json: bool,
        custom_format: Any = None,
        max_retries: int = 5,
        retry_delay: int = 5,
    ) -> Optional[str]:
        raise NotImplementedError
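
# --- Illustrative sketch (not part of the module): a new backend only needs to
# implement test_api() and infer() with the signature above. "EchoClient" is a
# hypothetical example, not an existing tooluniverse class. ---
#
# class EchoClient(BaseLLMClient):
#     def test_api(self) -> None:
#         pass  # nothing external to verify
#
#     def infer(self, messages, temperature, max_tokens, return_json,
#               custom_format=None, max_retries=5, retry_delay=5):
#         return messages[-1]["content"]  # echo the last message back
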

class AzureOpenAIClient(BaseLLMClient):
    """Azure OpenAI chat client with per-model API-version and token-limit resolution."""

    # Built-in defaults for model families (can be overridden by env)
    DEFAULT_MODEL_LIMITS: Dict[str, Dict[str, int]] = {
        # GPT-4.1 series
        "gpt-4.1": {"max_output": 32768, "context_window": 1_047_576},
        "gpt-4.1-mini": {"max_output": 32768, "context_window": 1_047_576},
        "gpt-4.1-nano": {"max_output": 32768, "context_window": 1_047_576},
        # GPT-4o series
        "gpt-4o-1120": {"max_output": 16384, "context_window": 128_000},
        "gpt-4o-0806": {"max_output": 16384, "context_window": 128_000},
        "gpt-4o-mini-0718": {"max_output": 16384, "context_window": 128_000},
        "gpt-4o": {"max_output": 16384, "context_window": 128_000},  # general prefix
        # O-series
        "o4-mini-0416": {"max_output": 100_000, "context_window": 200_000},
        "o3-mini-0131": {"max_output": 100_000, "context_window": 200_000},
        "o4-mini": {"max_output": 100_000, "context_window": 200_000},
        "o3-mini": {"max_output": 100_000, "context_window": 200_000},
        # Embeddings (for completeness)
        "embedding-ada": {"max_output": 8192, "context_window": 8192},
        "text-embedding-3-small": {"max_output": 8192, "context_window": 8192},
        "text-embedding-3-large": {"max_output": 8192, "context_window": 8192},
    }

    def __init__(self, model_id: str, api_version: Optional[str], logger):
        try:
            from openai import AzureOpenAI as _AzureOpenAI  # type: ignore
            import openai as _openai  # type: ignore
        except Exception as e:  # pragma: no cover
            raise RuntimeError("openai AzureOpenAI client is not available") from e
        self._AzureOpenAI = _AzureOpenAI
        self._openai = _openai
        self.model_name = model_id
        self.logger = logger

        resolved_version = api_version or self._resolve_api_version(model_id)
        self.logger.debug(
            f"Resolved Azure API version for {model_id}: {resolved_version}"
        )

        api_key = os.getenv("AZURE_OPENAI_API_KEY")
        if not api_key:
            raise ValueError("AZURE_OPENAI_API_KEY not set")
        endpoint = os.getenv("AZURE_OPENAI_ENDPOINT", "https://azure-ai.hms.edu")
        self.client = self._AzureOpenAI(
            azure_endpoint=endpoint, api_key=api_key, api_version=resolved_version
        )
        self.api_version = resolved_version

        # Load env overrides for model limits
        # (JSON dict of {prefix: {max_output, context_window}})
        env_limits_raw = os.getenv("AZURE_DEFAULT_MODEL_LIMITS")
        self._default_limits: Dict[str, Dict[str, int]] = (
            self.DEFAULT_MODEL_LIMITS.copy()
        )
        if env_limits_raw:
            try:
                env_limits = _json.loads(env_limits_raw)
                # shallow merge by keys
                for k, v in env_limits.items():
                    if isinstance(v, dict):
                        base = self._default_limits.get(k, {}).copy()
                        base.update(
                            {
                                kk: int(vv)
                                for kk, vv in v.items()
                                if isinstance(vv, (int, float, str))
                            }
                        )
                        self._default_limits[k] = base
            except Exception:
                # ignore bad env format
                pass

    # --------- helpers (Azure specific) ---------
    def _resolve_api_version(self, model_id: str) -> str:
        mapping_raw = os.getenv("AZURE_OPENAI_API_VERSION_BY_MODEL")
        mapping: Dict[str, str] = {}
        if mapping_raw:
            try:
                mapping = _json.loads(mapping_raw)
            except Exception:
                mapping = {}
        if model_id in mapping:
            return mapping[model_id]
        for k, v in mapping.items():
            try:
                if model_id.startswith(k):
                    return v
            except Exception:
                continue
        try:
            if model_id.startswith("o3-mini") or model_id.startswith("o4-mini"):
                return "2024-12-01-preview"
        except Exception:
            pass
        return os.getenv("AZURE_OPENAI_API_VERSION", "2024-12-01-preview")

    def _resolve_default_max_tokens(self, model_id: str) -> Optional[int]:
        # Highest priority: explicit env per-model tokens mapping
        mapping_raw = os.getenv("AZURE_MAX_TOKENS_BY_MODEL")
        mapping: Dict[str, Any] = {}
        if mapping_raw:
            try:
                mapping = _json.loads(mapping_raw)
            except Exception:
                mapping = {}
        if model_id in mapping:
            try:
                return int(mapping[model_id])
            except Exception:
                pass
        for k, v in mapping.items():
            try:
                if model_id.startswith(k):
                    return int(v)
            except Exception:
                continue
        # Next: built-in/default-limits map (with env merged)
        if model_id in self._default_limits:
            return int(self._default_limits[model_id].get("max_output", 0)) or None
        for k, v in self._default_limits.items():
            try:
                if model_id.startswith(k):
                    return int(v.get("max_output", 0)) or None
            except Exception:
                continue
        return None

    def _normalize_temperature(
        self, model_id: str, temperature: Optional[float]
    ) -> Optional[float]:
        if isinstance(model_id, str) and (
            model_id.startswith("o3-mini") or model_id.startswith("o4-mini")
        ):
            if temperature is not None:
                self.logger.warning(
                    f"Model {model_id} does not support 'temperature'; ignoring provided value."
                )
            return None
        return temperature

    # --------- public API ---------
    def test_api(self) -> None:
        # Probe the deployment with a minimal prompt, escalating the token budget
        # if the service rejects very small output limits.
        test_messages = [{"role": "user", "content": "ping"}]
        token_attempts = [1, 4, 16, 32]
        last_error: Optional[Exception] = None
        for tok in token_attempts:
            try:
                try:
                    self.client.chat.completions.create(
                        model=self.model_name,
                        messages=test_messages,
                        max_tokens=tok,
                        temperature=0,
                    )
                    return
                except self._openai.BadRequestError:  # type: ignore[attr-defined]
                    self.client.chat.completions.create(
                        model=self.model_name,
                        messages=test_messages,
                        max_completion_tokens=tok,
                    )
                    return
            except Exception as e:  # noqa: BLE001
                last_error = e
                msg = str(e).lower()
                if (
                    "max_tokens" in msg
                    or "model output limit" in msg
                    or "finish the message" in msg
                ) and tok != token_attempts[-1]:
                    continue
                break
        if last_error:
            raise ValueError(f"ChatGPT API test failed: {last_error}")
        raise ValueError("ChatGPT API test failed: unknown error")

    def infer(
        self,
        messages: List[Dict[str, str]],
        temperature: Optional[float],
        max_tokens: Optional[int],
        return_json: bool,
        custom_format: Any = None,
        max_retries: int = 5,
        retry_delay: int = 5,
    ) -> Optional[str]:
        retries = 0
        call_fn = (
            self.client.chat.completions.parse
            if custom_format is not None
            else self.client.chat.completions.create
        )
        response_format = (
            custom_format
            if custom_format is not None
            else ({"type": "json_object"} if return_json else None)
        )
        eff_temp = self._normalize_temperature(self.model_name, temperature)
        eff_max = (
            max_tokens
            if max_tokens is not None
            else self._resolve_default_max_tokens(self.model_name)
        )
        while retries < max_retries:
            try:
                kwargs: Dict[str, Any] = {
                    "model": self.model_name,
                    "messages": messages,
                }
                if response_format is not None:
                    kwargs["response_format"] = response_format
                if eff_temp is not None:
                    kwargs["temperature"] = eff_temp
                # Some deployments accept 'max_tokens'; newer ones require
                # 'max_completion_tokens'. Try the first, then fall back.
                try:
                    if eff_max is not None:
                        resp = call_fn(max_tokens=eff_max, **kwargs)
                    else:
                        resp = call_fn(**kwargs)
                except self._openai.BadRequestError as be:  # type: ignore[attr-defined]
                    if eff_max is not None:
                        resp = call_fn(max_completion_tokens=eff_max, **kwargs)
                    else:
                        be_msg = str(be).lower()
                        fallback_limits = [8192, 4096, 2048, 1024, 512, 256, 128, 64, 32]
                        if any(
                            k in be_msg
                            for k in [
                                "max_tokens",
                                "output limit",
                                "finish the message",
                                "max_completion_tokens",
                            ]
                        ):
                            # Retry with progressively smaller output limits,
                            # trying both parameter names at each step.
                            last_exc: Optional[Exception] = be
                            for lim in fallback_limits:
                                try:
                                    try:
                                        resp = call_fn(max_completion_tokens=lim, **kwargs)
                                        last_exc = None
                                        break
                                    except Exception as inner_e:  # noqa: BLE001
                                        last_exc = inner_e
                                        resp = call_fn(max_tokens=lim, **kwargs)
                                        last_exc = None
                                        break
                                except Exception as inner2:  # noqa: BLE001
                                    last_exc = inner2
                                    continue
                            if last_exc is not None:
                                raise last_exc
                        else:
                            raise be
                if custom_format is not None:
                    return resp.choices[0].message.parsed.model_dump()
                return resp.choices[0].message.content
            except self._openai.RateLimitError:  # type: ignore[attr-defined]
                self.logger.warning(
                    f"Rate limit exceeded. Retrying in {retry_delay} seconds..."
                )
                retries += 1
                time.sleep(retry_delay * retries)
            except Exception as e:  # noqa: BLE001
                self.logger.error(f"An error occurred: {e}")
                import traceback

                traceback.print_exc()
                break
        self.logger.error("Max retries exceeded. Unable to complete the request.")
        return None
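
# --- Illustrative usage sketch (not part of the module). Model names, the API
# version, and the token values below are placeholders; the environment variable
# names and call signatures come from the class above. ---
#
# import logging, os
# os.environ["AZURE_OPENAI_API_KEY"] = "..."  # required
# os.environ["AZURE_OPENAI_ENDPOINT"] = "https://my-resource.openai.azure.com"
# # Optional overrides, each a JSON dict keyed by model id or prefix:
# os.environ["AZURE_OPENAI_API_VERSION_BY_MODEL"] = '{"gpt-4o": "2024-12-01-preview"}'
# os.environ["AZURE_MAX_TOKENS_BY_MODEL"] = '{"gpt-4o": 4096}'
# os.environ["AZURE_DEFAULT_MODEL_LIMITS"] = (
#     '{"gpt-4o": {"max_output": 8192, "context_window": 128000}}'
# )
#
# client = AzureOpenAIClient("gpt-4o", api_version=None, logger=logging.getLogger(__name__))
# client.test_api()
# text = client.infer(
#     messages=[{"role": "user", "content": "Say hello"}],
#     temperature=0.2,
#     max_tokens=None,      # falls back to the resolved per-model default
#     return_json=False,
# )
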

class GeminiClient(BaseLLMClient):
    """Google Gemini client backed by the google.generativeai SDK."""

    def __init__(self, model_name: str, logger):
        try:
            import google.generativeai as genai  # type: ignore
        except Exception as e:  # pragma: no cover
            raise RuntimeError("google.generativeai not available") from e
        api_key = os.getenv("GEMINI_API_KEY")
        if not api_key:
            raise ValueError("GEMINI_API_KEY not found")
        self._genai = genai
        self._genai.configure(api_key=api_key)
        self.model_name = model_name
        self.logger = logger

    def _build_model(self):
        return self._genai.GenerativeModel(self.model_name)

    def test_api(self) -> None:
        model = self._build_model()
        model.generate_content(
            "ping",
            generation_config={
                "max_output_tokens": 8,
                "temperature": 0,
            },
        )

    def infer(
        self,
        messages: List[Dict[str, str]],
        temperature: Optional[float],
        max_tokens: Optional[int],
        return_json: bool,
        custom_format: Any = None,
        max_retries: int = 5,
        retry_delay: int = 5,
    ) -> Optional[str]:
        if return_json:
            raise ValueError("Gemini JSON mode not supported here")
        # Gemini receives a single prompt string here; concatenate the
        # user and system messages in order.
        contents = ""
        for m in messages:
            if m["role"] in ("user", "system"):
                contents += f"{m['content']}\n"
        retries = 0
        while retries < max_retries:
            try:
                gen_cfg: Dict[str, Any] = {
                    "temperature": (temperature if temperature is not None else 0)
                }
                if max_tokens is not None:
                    gen_cfg["max_output_tokens"] = max_tokens
                model = self._build_model()
                resp = model.generate_content(contents, generation_config=gen_cfg)
                return getattr(resp, "text", None) or getattr(
                    resp, "candidates", [{}]
                )[0].get("content")
            except Exception as e:  # noqa: BLE001
                self.logger.error(f"Gemini error: {e}")
                retries += 1
                time.sleep(retry_delay * retries)
        return None
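
# --- Illustrative usage sketch (not part of the module). The model name is a
# placeholder; GEMINI_API_KEY and the infer() signature come from the class above. ---
#
# import logging, os
# os.environ["GEMINI_API_KEY"] = "..."
# gem = GeminiClient("gemini-1.5-flash", logger=logging.getLogger(__name__))
# gem.test_api()
# answer = gem.infer(
#     messages=[
#         {"role": "system", "content": "Be brief."},
#         {"role": "user", "content": "What is ToolUniverse?"},
#     ],
#     temperature=0.0,
#     max_tokens=256,
#     return_json=False,   # JSON mode raises ValueError for this backend
# )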