Source code for tooluniverse.llm_clients

from __future__ import annotations
from typing import Any, Dict, List, Optional
import os
import time
import json as _json


class BaseLLMClient:
    """Abstract interface shared by all LLM backends in this module."""

    def test_api(self) -> None:
        raise NotImplementedError

    def infer(
        self,
        messages: List[Dict[str, str]],
        temperature: Optional[float],
        max_tokens: Optional[int],
        return_json: bool,
        custom_format: Any = None,
        max_retries: int = 5,
        retry_delay: int = 5,
    ) -> Optional[str]:
        raise NotImplementedError

    def infer_stream(
        self,
        messages: List[Dict[str, str]],
        temperature: Optional[float],
        max_tokens: Optional[int],
        return_json: bool,
        custom_format: Any = None,
        max_retries: int = 5,
        retry_delay: int = 5,
    ):
        """Default streaming implementation falls back to regular inference."""
        result = self.infer(
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
            return_json=return_json,
            custom_format=custom_format,
            max_retries=max_retries,
            retry_delay=retry_delay,
        )
        if result is not None:
            yield result
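
# Usage sketch (illustrative, not part of the module): any backend is driven
# through this interface, and infer_stream is inherited as a single-chunk
# fallback over infer. The EchoClient name is hypothetical.
#
#     class EchoClient(BaseLLMClient):
#         def test_api(self) -> None:
#             pass
#
#         def infer(self, messages, temperature, max_tokens, return_json,
#                   custom_format=None, max_retries=5, retry_delay=5):
#             return messages[-1]["content"]  # echo the last message back
#
#     for chunk in EchoClient().infer_stream(
#         [{"role": "user", "content": "hello"}],
#         temperature=None, max_tokens=None, return_json=False,
#     ):
#         print(chunk)  # -> "hello"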

class AzureOpenAIClient(BaseLLMClient):
    # Built-in defaults for model families (can be overridden by env)
    DEFAULT_MODEL_LIMITS: Dict[str, Dict[str, int]] = {
        # GPT-4.1 series
        "gpt-4.1": {"max_output": 32768, "context_window": 1_047_576},
        "gpt-4.1-mini": {"max_output": 32768, "context_window": 1_047_576},
        "gpt-4.1-nano": {"max_output": 32768, "context_window": 1_047_576},
        # GPT-4o series
        "gpt-4o-1120": {"max_output": 16384, "context_window": 128_000},
        "gpt-4o-0806": {"max_output": 16384, "context_window": 128_000},
        "gpt-4o-mini-0718": {"max_output": 16384, "context_window": 128_000},
        "gpt-4o": {"max_output": 16384, "context_window": 128_000},  # general prefix
        # O-series
        "o4-mini-0416": {"max_output": 100_000, "context_window": 200_000},
        "o3-mini-0131": {"max_output": 100_000, "context_window": 200_000},
        "o4-mini": {"max_output": 100_000, "context_window": 200_000},
        "o3-mini": {"max_output": 100_000, "context_window": 200_000},
        # Embeddings (for completeness)
        "embedding-ada": {"max_output": 8192, "context_window": 8192},
        "text-embedding-3-small": {"max_output": 8192, "context_window": 8192},
        "text-embedding-3-large": {"max_output": 8192, "context_window": 8192},
    }

    def __init__(self, model_id: str, api_version: Optional[str], logger):
        try:
            from openai import AzureOpenAI as _AzureOpenAI  # type: ignore
            import openai as _openai  # type: ignore
        except Exception as e:  # pragma: no cover
            raise RuntimeError("openai AzureOpenAI client is not available") from e
        self._AzureOpenAI = _AzureOpenAI
        self._openai = _openai
        self.model_name = model_id
        self.logger = logger

        resolved_version = api_version or self._resolve_api_version(model_id)
        self.logger.debug(
            f"Resolved Azure API version for {model_id}: {resolved_version}"
        )

        api_key = os.getenv("AZURE_OPENAI_API_KEY")
        if not api_key:
            raise ValueError("AZURE_OPENAI_API_KEY not set")
        endpoint = os.getenv("AZURE_OPENAI_ENDPOINT", "https://azure-ai.hms.edu")
        self.client = self._AzureOpenAI(
            azure_endpoint=endpoint, api_key=api_key, api_version=resolved_version
        )
        self.api_version = resolved_version

        # Load env overrides for model limits (JSON dict of {prefix: {max_output, context_window}})
        env_limits_raw = os.getenv("AZURE_DEFAULT_MODEL_LIMITS")
        self._default_limits: Dict[str, Dict[str, int]] = (
            self.DEFAULT_MODEL_LIMITS.copy()
        )
        if env_limits_raw:
            try:
                env_limits = _json.loads(env_limits_raw)
                # shallow merge by keys
                for k, v in env_limits.items():
                    if isinstance(v, dict):
                        base = self._default_limits.get(k, {}).copy()
                        base.update(
                            {
                                kk: int(vv)
                                for kk, vv in v.items()
                                if isinstance(vv, (int, float, str))
                            }
                        )
                        self._default_limits[k] = base
            except Exception:
                # ignore bad env format
                pass
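
    # Environment sketch (values are examples; only the variable names come
    # from the constructor above):
    #
    #     os.environ["AZURE_OPENAI_API_KEY"] = "..."  # required
    #     os.environ["AZURE_OPENAI_ENDPOINT"] = "https://my-endpoint.example"  # optional
    #     # Optional per-prefix limits, shallow-merged over DEFAULT_MODEL_LIMITS:
    #     os.environ["AZURE_DEFAULT_MODEL_LIMITS"] = '{"gpt-4o": {"max_output": 8192}}'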

    # --------- helpers (Azure specific) ---------
    def _resolve_api_version(self, model_id: str) -> str:
        mapping_raw = os.getenv("AZURE_OPENAI_API_VERSION_BY_MODEL")
        mapping: Dict[str, str] = {}
        if mapping_raw:
            try:
                mapping = _json.loads(mapping_raw)
            except Exception:
                mapping = {}
        if model_id in mapping:
            return mapping[model_id]
        for k, v in mapping.items():
            try:
                if model_id.startswith(k):
                    return v
            except Exception:
                continue
        try:
            if model_id.startswith("o3-mini") or model_id.startswith("o4-mini"):
                return "2024-12-01-preview"
        except Exception:
            pass
        return os.getenv("AZURE_OPENAI_API_VERSION", "2024-12-01-preview")

    def _resolve_default_max_tokens(self, model_id: str) -> Optional[int]:
        # Highest priority: explicit env per-model tokens mapping
        mapping_raw = os.getenv("AZURE_MAX_TOKENS_BY_MODEL")
        mapping: Dict[str, Any] = {}
        if mapping_raw:
            try:
                mapping = _json.loads(mapping_raw)
            except Exception:
                mapping = {}
        if model_id in mapping:
            try:
                return int(mapping[model_id])
            except Exception:
                pass
        for k, v in mapping.items():
            try:
                if model_id.startswith(k):
                    return int(v)
            except Exception:
                continue
        # Next: built-in/default-limits map (with env merged)
        if model_id in self._default_limits:
            return int(self._default_limits[model_id].get("max_output", 0)) or None
        for k, v in self._default_limits.items():
            try:
                if model_id.startswith(k):
                    return int(v.get("max_output", 0)) or None
            except Exception:
                continue
        return None

    def _normalize_temperature(
        self, model_id: str, temperature: Optional[float]
    ) -> Optional[float]:
        if isinstance(model_id, str) and (
            model_id.startswith("o3-mini") or model_id.startswith("o4-mini")
        ):
            if temperature is not None:
                self.logger.warning(
                    f"Model {model_id} does not support 'temperature'; ignoring provided value."
                )
            return None
        return temperature

    @staticmethod
    def _extract_text_from_chunk(chunk) -> Optional[str]:
        # Pull the incremental delta text out of an OpenAI-style streaming chunk
        # (used by infer_stream below).
        choices = getattr(chunk, "choices", None)
        if not choices:
            return None
        delta = getattr(choices[0], "delta", None)
        return getattr(delta, "content", None) if delta is not None else None

    # --------- public API ---------
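
    # Resolution sketch for the JSON mappings read by the helpers above (ids
    # and values are examples): an exact model-id key wins, otherwise the
    # first key that is a prefix of the model id.
    #
    #     os.environ["AZURE_OPENAI_API_VERSION_BY_MODEL"] = (
    #         '{"gpt-4o-1120": "2024-10-21", "o4-mini": "2024-12-01-preview"}'
    #     )
    #     os.environ["AZURE_MAX_TOKENS_BY_MODEL"] = '{"gpt-4o": 8192, "o3-mini": 100000}'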

    def test_api(self) -> None:
        test_messages = [{"role": "user", "content": "ping"}]
        # Probe with increasing token budgets: some deployments reject a
        # budget too small to finish even a trivial message.
        token_attempts = [1, 4, 16, 32]
        last_error: Optional[Exception] = None
        for tok in token_attempts:
            try:
                try:
                    self.client.chat.completions.create(
                        model=self.model_name,
                        messages=test_messages,
                        max_tokens=tok,
                        temperature=0,
                    )
                    return
                except self._openai.BadRequestError:  # type: ignore[attr-defined]
                    # Newer models expect max_completion_tokens instead of max_tokens.
                    self.client.chat.completions.create(
                        model=self.model_name,
                        messages=test_messages,
                        max_completion_tokens=tok,
                    )
                    return
            except Exception as e:  # noqa: BLE001
                last_error = e
                msg = str(e).lower()
                if (
                    "max_tokens" in msg
                    or "model output limit" in msg
                    or "finish the message" in msg
                ) and tok != token_attempts[-1]:
                    continue
                break
        if last_error:
            raise ValueError(f"ChatGPT API test failed: {last_error}")
        raise ValueError("ChatGPT API test failed: unknown error")

    def infer(
        self,
        messages: List[Dict[str, str]],
        temperature: Optional[float],
        max_tokens: Optional[int],
        return_json: bool,
        custom_format: Any = None,
        max_retries: int = 5,
        retry_delay: int = 5,
    ) -> Optional[str]:
        retries = 0
        call_fn = (
            self.client.chat.completions.parse
            if custom_format is not None
            else self.client.chat.completions.create
        )
        response_format = (
            custom_format
            if custom_format is not None
            else ({"type": "json_object"} if return_json else None)
        )
        eff_temp = self._normalize_temperature(self.model_name, temperature)
        eff_max = (
            max_tokens
            if max_tokens is not None
            else self._resolve_default_max_tokens(self.model_name)
        )
        while retries < max_retries:
            try:
                kwargs: Dict[str, Any] = {
                    "model": self.model_name,
                    "messages": messages,
                }
                if response_format is not None:
                    kwargs["response_format"] = response_format
                if eff_temp is not None:
                    kwargs["temperature"] = eff_temp
                try:
                    if eff_max is not None:
                        resp = call_fn(max_tokens=eff_max, **kwargs)
                    else:
                        resp = call_fn(**kwargs)
                except self._openai.BadRequestError as be:  # type: ignore[attr-defined]
                    if eff_max is not None:
                        # Retry once with the newer parameter name.
                        resp = call_fn(max_completion_tokens=eff_max, **kwargs)
                    else:
                        be_msg = str(be).lower()
                        fallback_limits = [8192, 4096, 2048, 1024, 512, 256, 128, 64, 32]
                        if any(
                            k in be_msg
                            for k in [
                                "max_tokens",
                                "output limit",
                                "finish the message",
                                "max_completion_tokens",
                            ]
                        ):
                            # Walk down explicit limits, trying both parameter names.
                            last_exc: Optional[Exception] = be
                            for lim in fallback_limits:
                                try:
                                    try:
                                        resp = call_fn(
                                            max_completion_tokens=lim, **kwargs
                                        )
                                        last_exc = None
                                        break
                                    except Exception as inner_e:  # noqa: BLE001
                                        last_exc = inner_e
                                        resp = call_fn(max_tokens=lim, **kwargs)
                                        last_exc = None
                                        break
                                except Exception as inner2:  # noqa: BLE001
                                    last_exc = inner2
                                    continue
                            if last_exc is not None:
                                raise last_exc
                        else:
                            raise be
                if custom_format is not None:
                    return resp.choices[0].message.parsed.model_dump()
                return resp.choices[0].message.content
            except self._openai.RateLimitError:  # type: ignore[attr-defined]
                self.logger.warning(
                    f"Rate limit exceeded. Retrying in {retry_delay} seconds..."
                )
                retries += 1
                time.sleep(retry_delay * retries)
            except Exception as e:  # noqa: BLE001
                self.logger.error(f"An error occurred: {e}")
                import traceback

                traceback.print_exc()
                break
        self.logger.error("Max retries exceeded. Unable to complete the request.")
        return None
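
    # Usage sketch (the model id and logger are illustrative):
    #
    #     import logging
    #     azure = AzureOpenAIClient("gpt-4o", api_version=None,
    #                               logger=logging.getLogger("llm"))
    #     reply = azure.infer(
    #         messages=[{"role": "user", "content": "Reply with JSON."}],
    #         temperature=0.2,
    #         max_tokens=None,   # resolved from the per-model defaults
    #         return_json=True,  # sends response_format={"type": "json_object"}
    #     )
    #     # reply is a string (a dict when custom_format is used), or None
    #     # after max_retries failures.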

    def infer_stream(
        self,
        messages: List[Dict[str, str]],
        temperature: Optional[float],
        max_tokens: Optional[int],
        return_json: bool,
        custom_format: Any = None,
        max_retries: int = 5,
        retry_delay: int = 5,
    ):
        if return_json or custom_format is not None:
            yield from super().infer_stream(
                messages,
                temperature,
                max_tokens,
                return_json,
                custom_format,
                max_retries,
                retry_delay,
            )
            return
        retries = 0
        eff_temp = self._normalize_temperature(self.model_name, temperature)
        eff_max = (
            max_tokens
            if max_tokens is not None
            else self._resolve_default_max_tokens(self.model_name)
        )
        while retries < max_retries:
            try:
                kwargs: Dict[str, Any] = {
                    "model": self.model_name,
                    "messages": messages,
                    "stream": True,
                }
                if eff_temp is not None:
                    kwargs["temperature"] = eff_temp
                if eff_max is not None:
                    kwargs["max_tokens"] = eff_max
                stream = self.client.chat.completions.create(**kwargs)
                for chunk in stream:
                    text = self._extract_text_from_chunk(chunk)
                    if text:
                        yield text
                return
            except self._openai.RateLimitError:  # type: ignore[attr-defined]
                self.logger.warning(
                    f"Rate limit exceeded. Retrying in {retry_delay} seconds (streaming)..."
                )
                retries += 1
                time.sleep(retry_delay * retries)
            except Exception as e:  # noqa: BLE001
                self.logger.error(f"Streaming error: {e}")
                break
        # Fallback to non-streaming if streaming fails
        yield from super().infer_stream(
            messages,
            temperature,
            max_tokens,
            return_json,
            custom_format,
            max_retries,
            retry_delay,
        )
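
# Streaming sketch for AzureOpenAIClient (continues the usage example above;
# falls back to non-streaming inference if the stream errors out):
#
#     for piece in azure.infer_stream(
#         messages=[{"role": "user", "content": "Write a haiku."}],
#         temperature=0.7, max_tokens=128, return_json=False,
#     ):
#         print(piece, end="", flush=True)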

class GeminiClient(BaseLLMClient):
    def __init__(self, model_name: str, logger):
        try:
            import google.generativeai as genai  # type: ignore
        except Exception as e:  # pragma: no cover
            raise RuntimeError("google.generativeai not available") from e
        api_key = os.getenv("GEMINI_API_KEY")
        if not api_key:
            raise ValueError("GEMINI_API_KEY not found")
        self._genai = genai
        self._genai.configure(api_key=api_key)
        self.model_name = model_name
        self.logger = logger

    def _build_model(self):
        return self._genai.GenerativeModel(self.model_name)

    def test_api(self) -> None:
        model = self._build_model()
        model.generate_content(
            "ping",
            generation_config={
                "max_output_tokens": 8,
                "temperature": 0,
            },
        )

    def infer(
        self,
        messages: List[Dict[str, str]],
        temperature: Optional[float],
        max_tokens: Optional[int],
        return_json: bool,
        custom_format: Any = None,
        max_retries: int = 5,
        retry_delay: int = 5,
    ) -> Optional[str]:
        if return_json:
            raise ValueError("Gemini JSON mode not supported here")
        contents = ""
        for m in messages:
            if m["role"] in ("user", "system"):
                contents += f"{m['content']}\n"
        retries = 0
        while retries < max_retries:
            try:
                gen_cfg: Dict[str, Any] = {
                    "temperature": (temperature if temperature is not None else 0)
                }
                if max_tokens is not None:
                    gen_cfg["max_output_tokens"] = max_tokens
                model = self._build_model()
                resp = model.generate_content(contents, generation_config=gen_cfg)
                return getattr(resp, "text", None) or getattr(
                    resp, "candidates", [{}]
                )[0].get("content")
            except Exception as e:  # noqa: BLE001
                self.logger.error(f"Gemini error: {e}")
                retries += 1
                time.sleep(retry_delay * retries)
        return None

    @staticmethod
    def _extract_text_from_stream_chunk(chunk) -> Optional[str]:
        if chunk is None:
            return None
        text = getattr(chunk, "text", None)
        if text:
            return text
        candidates = getattr(chunk, "candidates", None)
        if not candidates and isinstance(chunk, dict):
            candidates = chunk.get("candidates")
        if not candidates:
            return None
        candidate = candidates[0]
        content = getattr(candidate, "content", None)
        if content is None and isinstance(candidate, dict):
            content = candidate.get("content")
        if not content:
            return None
        parts = getattr(content, "parts", None)
        if parts is None and isinstance(content, dict):
            parts = content.get("parts")
        if parts and isinstance(parts, list):
            fragments: List[str] = []
            for part in parts:
                piece = getattr(part, "text", None)
                if piece is None and isinstance(part, dict):
                    piece = part.get("text")
                if piece:
                    fragments.append(piece)
            return "".join(fragments) if fragments else None
        final_text = getattr(content, "text", None)
        if final_text is None and isinstance(content, dict):
            final_text = content.get("text")
        return final_text

    def infer_stream(
        self,
        messages: List[Dict[str, str]],
        temperature: Optional[float],
        max_tokens: Optional[int],
        return_json: bool,
        custom_format: Any = None,
        max_retries: int = 5,
        retry_delay: int = 5,
    ):
        if return_json:
            raise ValueError("Gemini JSON mode not supported here")
        contents = ""
        for m in messages:
            if m["role"] in ("user", "system"):
                contents += f"{m['content']}\n"
        retries = 0
        while retries < max_retries:
            try:
                gen_cfg: Dict[str, Any] = {
                    "temperature": (temperature if temperature is not None else 0)
                }
                if max_tokens is not None:
                    gen_cfg["max_output_tokens"] = max_tokens
                model = self._build_model()
                stream = model.generate_content(
                    contents, generation_config=gen_cfg, stream=True
                )
                for chunk in stream:
                    text = self._extract_text_from_stream_chunk(chunk)
                    if text:
                        yield text
                return
            except Exception as e:  # noqa: BLE001
                self.logger.error(f"Gemini streaming error: {e}")
                retries += 1
                time.sleep(retry_delay * retries)
        yield from super().infer_stream(
            messages,
            temperature,
            max_tokens,
            return_json,
            custom_format,
            max_retries,
            retry_delay,
        )
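
# Gemini usage sketch (assumes GEMINI_API_KEY is set; the model name is an
# example). Note that only "user" and "system" messages are folded into the
# prompt, and return_json=True raises ValueError for this backend.
#
#     import logging
#     gemini = GeminiClient("gemini-1.5-flash", logger=logging.getLogger("llm"))
#     text = gemini.infer(
#         messages=[{"role": "user", "content": "Summarize CRISPR in one line."}],
#         temperature=0.0, max_tokens=256, return_json=False,
#     )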

class OpenRouterClient(BaseLLMClient):
    """
    OpenRouter client using the OpenAI SDK with a custom base URL.

    Supports models from OpenAI, Anthropic, Google, Qwen, and many other
    providers.
    """

    # Default model limits based on latest OpenRouter offerings
    DEFAULT_MODEL_LIMITS: Dict[str, Dict[str, int]] = {
        "openai/gpt-5": {"max_output": 128_000, "context_window": 400_000},
        "openai/gpt-5-codex": {"max_output": 128_000, "context_window": 400_000},
        "google/gemini-2.5-flash": {"max_output": 65_536, "context_window": 1_000_000},
        "google/gemini-2.5-pro": {"max_output": 65_536, "context_window": 1_000_000},
        "anthropic/claude-sonnet-4.5": {
            "max_output": 16_384,
            "context_window": 1_000_000,
        },
    }

    def __init__(self, model_id: str, logger):
        try:
            from openai import OpenAI as _OpenAI  # type: ignore
            import openai as _openai  # type: ignore
        except Exception as e:  # pragma: no cover
            raise RuntimeError("openai client is not available") from e
        self._OpenAI = _OpenAI
        self._openai = _openai
        self.model_name = model_id
        self.logger = logger

        api_key = os.getenv("OPENROUTER_API_KEY")
        if not api_key:
            raise ValueError("OPENROUTER_API_KEY not set")

        # Optional headers for OpenRouter
        default_headers = {}
        if site_url := os.getenv("OPENROUTER_SITE_URL"):
            default_headers["HTTP-Referer"] = site_url
        if site_name := os.getenv("OPENROUTER_SITE_NAME"):
            default_headers["X-Title"] = site_name

        self.client = self._OpenAI(
            base_url="https://openrouter.ai/api/v1",
            api_key=api_key,
            default_headers=default_headers if default_headers else None,
        )

        # Load env overrides for model limits
        env_limits_raw = os.getenv("OPENROUTER_DEFAULT_MODEL_LIMITS")
        self._default_limits: Dict[str, Dict[str, int]] = (
            self.DEFAULT_MODEL_LIMITS.copy()
        )
        if env_limits_raw:
            try:
                env_limits = _json.loads(env_limits_raw)
                for k, v in env_limits.items():
                    if isinstance(v, dict):
                        base = self._default_limits.get(k, {}).copy()
                        base.update(
                            {
                                kk: int(vv)
                                for kk, vv in v.items()
                                if isinstance(vv, (int, float, str))
                            }
                        )
                        self._default_limits[k] = base
            except Exception:
                pass
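
    # Environment sketch (placeholder values; the two optional site variables
    # map to the HTTP-Referer and X-Title attribution headers):
    #
    #     os.environ["OPENROUTER_API_KEY"] = "..."                   # required
    #     os.environ["OPENROUTER_SITE_URL"] = "https://example.com"  # optional
    #     os.environ["OPENROUTER_SITE_NAME"] = "My App"              # optional
    #     os.environ["OPENROUTER_DEFAULT_MODEL_LIMITS"] = (
    #         '{"anthropic/claude-sonnet-4.5": {"max_output": 32768}}'
    #     )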

    def _resolve_default_max_tokens(self, model_id: str) -> Optional[int]:
        """Resolve default max tokens for a model."""
        # Highest priority: explicit env per-model tokens mapping
        mapping_raw = os.getenv("OPENROUTER_MAX_TOKENS_BY_MODEL")
        mapping: Dict[str, Any] = {}
        if mapping_raw:
            try:
                mapping = _json.loads(mapping_raw)
            except Exception:
                mapping = {}
        if model_id in mapping:
            try:
                return int(mapping[model_id])
            except Exception:
                pass
        # Check for prefix match
        for k, v in mapping.items():
            try:
                if model_id.startswith(k):
                    return int(v)
            except Exception:
                continue
        # Next: built-in/default-limits map
        if model_id in self._default_limits:
            return int(self._default_limits[model_id].get("max_output", 0)) or None
        # Check for prefix match in default limits
        for k, v in self._default_limits.items():
            try:
                if model_id.startswith(k):
                    return int(v.get("max_output", 0)) or None
            except Exception:
                continue
        return None

    def test_api(self) -> None:
        """Test API connectivity with minimal token usage."""
        test_messages = [{"role": "user", "content": "ping"}]
        token_attempts = [1, 4, 16, 32]
        last_error: Optional[Exception] = None
        for tok in token_attempts:
            try:
                self.client.chat.completions.create(
                    model=self.model_name,
                    messages=test_messages,
                    max_tokens=tok,
                    temperature=0,
                )
                return
            except Exception as e:  # noqa: BLE001
                last_error = e
                msg = str(e).lower()
                if (
                    "max_tokens" in msg
                    or "model output limit" in msg
                    or "finish the message" in msg
                ) and tok != token_attempts[-1]:
                    continue
                break
        if last_error:
            raise ValueError(f"OpenRouter API test failed: {last_error}")
        raise ValueError("OpenRouter API test failed: unknown error")

    def infer(
        self,
        messages: List[Dict[str, str]],
        temperature: Optional[float],
        max_tokens: Optional[int],
        return_json: bool,
        custom_format: Any = None,
        max_retries: int = 5,
        retry_delay: int = 5,
    ) -> Optional[str]:
        """Execute inference using OpenRouter."""
        retries = 0
        call_fn = (
            self.client.chat.completions.parse
            if custom_format is not None
            else self.client.chat.completions.create
        )
        response_format = (
            custom_format
            if custom_format is not None
            else ({"type": "json_object"} if return_json else None)
        )
        eff_max = (
            max_tokens
            if max_tokens is not None
            else self._resolve_default_max_tokens(self.model_name)
        )
        while retries < max_retries:
            try:
                kwargs: Dict[str, Any] = {
                    "model": self.model_name,
                    "messages": messages,
                }
                if response_format is not None:
                    kwargs["response_format"] = response_format
                if temperature is not None:
                    kwargs["temperature"] = temperature
                if eff_max is not None:
                    kwargs["max_tokens"] = eff_max
                resp = call_fn(**kwargs)
                if custom_format is not None:
                    return resp.choices[0].message.parsed.model_dump()
                return resp.choices[0].message.content
            except self._openai.RateLimitError:  # type: ignore[attr-defined]
                self.logger.warning(
                    f"Rate limit exceeded. Retrying in {retry_delay} seconds..."
                )
                retries += 1
                time.sleep(retry_delay * retries)
            except Exception as e:  # noqa: BLE001
                self.logger.error(f"OpenRouter error: {e}")
                import traceback

                traceback.print_exc()
                break
        self.logger.error("Max retries exceeded. Unable to complete the request.")
        return None
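
# OpenRouter usage sketch (the provider/model id is an example):
#
#     import logging
#     router = OpenRouterClient("openai/gpt-5", logger=logging.getLogger("llm"))
#     answer = router.infer(
#         messages=[{"role": "user", "content": "ping"}],
#         temperature=None,  # omitted from the request when None
#         max_tokens=None,   # resolved from env mapping or DEFAULT_MODEL_LIMITS
#         return_json=False,
#     )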

class VLLMClient(BaseLLMClient):
    def __init__(self, model_name: str, server_url: str, logger):
        try:
            from openai import OpenAI
        except Exception as e:
            raise RuntimeError("openai package not available for vLLM client") from e
        if not server_url:
            raise ValueError("VLLM_SERVER_URL must be provided")
        self.model_name = model_name
        # Ensure server_url ends with /v1 for the OpenAI-compatible API
        if not server_url.endswith("/v1"):
            server_url = server_url.rstrip("/") + "/v1"
        self.server_url = server_url
        self.logger = logger
        self.client = OpenAI(
            api_key="EMPTY",
            base_url=self.server_url,
        )

    def test_api(self) -> None:
        test_messages = [{"role": "user", "content": "ping"}]
        try:
            self.client.chat.completions.create(
                model=self.model_name,
                messages=test_messages,
                max_tokens=8,
                temperature=0,
            )
        except Exception as e:
            raise ValueError(f"vLLM API test failed: {e}") from e

    def infer(
        self,
        messages: List[Dict[str, str]],
        temperature: Optional[float],
        max_tokens: Optional[int],
        return_json: bool,
        custom_format: Any = None,
        max_retries: int = 5,
        retry_delay: int = 5,
    ) -> Optional[str]:
        if custom_format is not None:
            self.logger.warning("vLLM does not support custom format, ignoring")
        retries = 0
        while retries < max_retries:
            try:
                kwargs: Dict[str, Any] = {
                    "model": self.model_name,
                    "messages": messages,
                }
                if temperature is not None:
                    kwargs["temperature"] = temperature
                if max_tokens is not None:
                    kwargs["max_tokens"] = max_tokens
                if return_json:
                    kwargs["response_format"] = {"type": "json_object"}
                resp = self.client.chat.completions.create(**kwargs)
                return resp.choices[0].message.content
            except Exception as e:
                self.logger.error(f"vLLM error: {e}")
                retries += 1
                if retries < max_retries:
                    time.sleep(retry_delay * retries)
        self.logger.error("Max retries exceeded for vLLM request")
        return None
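
# vLLM usage sketch (the model name and server URL are illustrative; the URL
# is normalized to end in /v1):
#
#     import logging
#     vllm = VLLMClient(
#         model_name="meta-llama/Llama-3.1-8B-Instruct",
#         server_url="http://localhost:8000",  # becomes http://localhost:8000/v1
#         logger=logging.getLogger("llm"),
#     )
#     out = vllm.infer(
#         messages=[{"role": "user", "content": "ping"}],
#         temperature=0.0, max_tokens=32, return_json=False,
#     )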