Source code for tooluniverse.nvidia_nim_tool

"""
NVIDIA NIM Healthcare API Tool.

This module provides a unified interface to NVIDIA's cloud-hosted healthcare AI APIs,
including protein structure prediction, molecular docking, protein design, genomics,
and medical imaging tools.

All APIs require a NVIDIA API key set as the NVIDIA_API_KEY environment variable.
Get your key at: https://build.nvidia.com

Rate limit: 40 requests per minute (enforced internally).
"""

import os
import re
import time
import requests
from typing import Dict, Any, Optional, List
from urllib.parse import urlparse
from .base_tool import BaseTool
from .tool_registry import register_tool


# Rate limiting configuration
_last_request_time = 0.0
_MIN_REQUEST_INTERVAL = 1.5  # 40 RPM = 1.5 seconds between requests


def _enforce_rate_limit():
    """Enforce rate limiting (40 RPM = 1.5s between requests)."""
    global _last_request_time
    current_time = time.time()
    elapsed = current_time - _last_request_time
    if elapsed < _MIN_REQUEST_INTERVAL:
        time.sleep(_MIN_REQUEST_INTERVAL - elapsed)
    _last_request_time = time.time()


[docs] @register_tool("NvidiaNIMTool") class NvidiaNIMTool(BaseTool): """ NVIDIA NIM Healthcare API tool. Provides unified access to 16 NVIDIA cloud-hosted healthcare AI APIs: Structure Prediction: - AlphaFold2, AlphaFold2-Multimer, ESMFold, OpenFold2, OpenFold3, Boltz2 Protein Design: - ProteinMPNN, RFdiffusion Molecular Tools: - DiffDock, GenMol, MolMIM Genomics: - Evo2-40B, MSA-Search, ESM2-650M Medical Imaging: - MAISI, Vista3D Configuration fields: - endpoint: API endpoint path (relative to base URL); may contain {placeholder} segments filled from path_params (e.g. "arc/{model}/generate") - path_params: {name: default} map for templated endpoint segments; the value is taken from the request argument of the same name (or the default) and is not forwarded in the request body (e.g. {"model": "evo2-40b"}) - base_url: Override base URL (default: https://health.api.nvidia.com/v1/biology) - async_expected: Whether 202 async response is expected - poll_seconds: NVCF-POLL-SECONDS header value (default 300) - response_type: Expected response type (json, pdb, mfasta, zip) - timeout: Request timeout in seconds (default 600) """ DEFAULT_BASE_URL = "https://health.api.nvidia.com/v1/biology" # Hosted NVCF functions are polled on the SAME gateway host they are invoked # on. The biology/medical-imaging NIMs live on health.api.nvidia.com, so async # results must be polled there — NOT on integrate.api.nvidia.com, which only # serves the OpenAI-compatible LLM endpoints and has no /v1/status route # (it returns a plain-text "404 page not found"). The poll host is derived # from the invocation base_url at request time; this is the default. STATUS_URL = "https://health.api.nvidia.com/v1/status" ASSETS_URL = "https://api.nvcf.nvidia.com/v2/nvcf/assets" DEFAULT_TIMEOUT = 600 DEFAULT_POLL_SECONDS = 300 MAX_POLL_ATTEMPTS = 120 # 10 minutes with 5s intervals POLL_INTERVAL = 5
[docs] def __init__(self, tool_config: Dict[str, Any]): super().__init__(tool_config) fields = tool_config.get("fields", {}) self.endpoint = fields.get("endpoint", "") # Optional {placeholder} -> default map for templated endpoints (e.g. # selecting a hosted model variant like evo2-40b vs evo2-7b). self.path_params = fields.get("path_params", {}) or {} self.base_url = fields.get("base_url", self.DEFAULT_BASE_URL) self.async_expected = fields.get("async_expected", False) self.poll_seconds = fields.get("poll_seconds", self.DEFAULT_POLL_SECONDS) self.response_type = fields.get("response_type", "json") self.timeout = fields.get("timeout", self.DEFAULT_TIMEOUT) # Get API key from environment self.api_key = os.environ.get("NVIDIA_API_KEY")
[docs] def _get_headers(self) -> Dict[str, str]: """Build request headers with authentication.""" headers = { "Content-Type": "application/json", "Accept": "application/json", } if self.api_key: headers["Authorization"] = f"Bearer {self.api_key}" # Add NVCF-POLL-SECONDS for async operations if self.async_expected: headers["NVCF-POLL-SECONDS"] = str(self.poll_seconds) return headers
[docs] def _build_url(self, arguments: Optional[Dict[str, Any]] = None) -> str: """Build the full API URL, filling any {placeholder} path params. A templated endpoint (e.g. "arc/{model}/generate") lets one tool target several hosted model variants. Each placeholder is filled from the request argument of the same name, falling back to the per-field default in ``path_params``; values are restricted to a safe slug charset. """ endpoint = self.endpoint for key, default in self.path_params.items(): value = (arguments or {}).get(key) or default if not re.fullmatch(r"[A-Za-z0-9._-]+", str(value)): value = default endpoint = endpoint.replace("{" + key + "}", str(value)) # Handle endpoints that include full path if endpoint.startswith("http"): return endpoint # Ensure proper URL construction base = self.base_url.rstrip("/") endpoint = endpoint.lstrip("/") return f"{base}/{endpoint}"
[docs] def _poll_for_result(self, req_id: str, headers: Dict[str, str]) -> Dict[str, Any]: """ Poll the status endpoint for async operation results. Args: req_id: The nvcf-reqid from the 202 response headers: Request headers including auth Returns: Final response from the API """ # Poll on the same gateway host the request was sent to (e.g. # health.api.nvidia.com), falling back to the documented default. host = urlparse(self.base_url).netloc or urlparse(self.STATUS_URL).netloc poll_url = f"https://{host}/v1/status/{req_id}" for attempt in range(self.MAX_POLL_ATTEMPTS): try: _enforce_rate_limit() response = requests.get(poll_url, headers=headers, timeout=self.timeout) if response.status_code != 202: # Operation complete return self._parse_response(response) # Still processing, wait and retry time.sleep(self.POLL_INTERVAL) except requests.exceptions.RequestException as e: return { "status": "error", "error": "Poll request failed", "detail": str(e), "request_id": req_id, } return { "status": "error", "error": "Polling timeout", "detail": f"Operation did not complete within {self.MAX_POLL_ATTEMPTS * self.POLL_INTERVAL} seconds", "request_id": req_id, }
[docs] def _parse_response(self, response: requests.Response) -> Dict[str, Any]: """Parse API response based on response type and status.""" if response.status_code == 401: return { "status": "error", "error": "Authentication failed", "detail": "Invalid or missing NVIDIA_API_KEY. Get your key at https://build.nvidia.com", "status_code": 401, } if response.status_code == 429: return { "status": "error", "error": "Rate limit exceeded", "detail": "NVIDIA NIM API rate limit (40 RPM) exceeded. Please wait and retry.", "status_code": 429, } if response.status_code == 404: # Two distinct 404s: the route is valid but the model function isn't # provisioned for this account ("Not found for account ..." — a JSON # body), versus a genuinely wrong path (plain "404 page not found"). body = response.text or "" if "not found for account" in body.lower(): return { "status": "error", "error": "Model not available for this NVIDIA account", "detail": ( "This NIM model is not provisioned for your NVIDIA_API_KEY " "and may require special/enterprise access. " + body[:300] ), "status_code": 404, } return { "status": "error", "error": "Endpoint not found", "detail": body, "status_code": 404, } if response.status_code >= 500: # 5xx from the gateway often has an empty body; the NVCF headers carry # the real signal (e.g. nvcf-status "errored" on a cold/failed function). nvcf_status = response.headers.get("nvcf-status") detail = response.text.strip() if not detail and nvcf_status: detail = ( f"NVCF function status: {nvcf_status}; request id " f"{response.headers.get('nvcf-reqid')}. The hosted function may be " f"cold-starting or temporarily unavailable — retry shortly." ) elif not detail: detail = "No response body." return { "status": "error", "error": "Server error", "detail": detail, "status_code": response.status_code, } if response.status_code not in [200, 201]: body = response.text or "" # A hosted function NVIDIA has flagged unhealthy returns # "DEGRADED function cannot be invoked" (seen as a 400). It's a # transient server-side state, not a bad request from the caller. if "degraded" in body.lower(): return { "status": "error", "error": "NIM function temporarily degraded", "detail": ( "NVIDIA has flagged this hosted model as degraded and is " "rejecting invocations — a transient server-side state, not " "a problem with your request. Retry later. " + body[:300] ), "status_code": response.status_code, } return { "status": "error", "error": f"API returned status {response.status_code}", "detail": body[:1000] if body else "No details", "status_code": response.status_code, } # Handle different response types content_type = response.headers.get("Content-Type", "") if "application/zip" in content_type: # Return binary content info for ZIP files (medical imaging) return { "status": "success", "content_type": "application/zip", "content_length": len(response.content), "data": f"<ZIP file, {len(response.content)} bytes>", "note": "Use response.content to access the raw ZIP data", "_raw_content": response.content, } if "application/octet-stream" in content_type or self.response_type == "binary": # Return binary content info (e.g., npz files for embeddings) return { "status": "success", "content_type": content_type or "application/octet-stream", "content_length": len(response.content), "data": f"<Binary data, {len(response.content)} bytes>", "note": "Use _raw_content to access the raw binary data", "_raw_content": response.content, } if self.response_type == "pdb" or "text/plain" in content_type: # PDB structure text. Some NIMs (e.g. ESMFold) wrap it in JSON # {"pdbs": ["...ATOM records..."]} even on the pdb path, so unwrap to # the actual PDB string rather than handing back a JSON blob. structure = response.text stripped = structure.lstrip() if stripped.startswith(("{", "[")): try: payload = response.json() except ValueError: payload = None if isinstance(payload, dict): pdbs = payload.get("pdbs") or payload.get("pdb") if isinstance(pdbs, list) and pdbs: structure = pdbs[0] elif isinstance(pdbs, str): structure = pdbs return { "status": "success", "structure": structure, "format": "pdb", } if self.response_type == "mfasta": # Multi-FASTA format try: data = response.json() return { "status": "success", "data": data, "format": "mfasta", } except ValueError: return { "status": "success", "sequences": response.text, "format": "mfasta", } # Default: JSON response try: data = response.json() # Some NIMs (e.g. DiffDock) answer HTTP 200 but report an inner # failure (e.g. {"status": "failed", "detail": ...}). Surface that as # an error rather than a misleading top-level success. if isinstance(data, dict) and str(data.get("status", "")).lower() in ( "failed", "error", "errored", ): detail = str(data.get("detail") or data.get("message") or data)[:300] return { "status": "error", "error": "NIM reported an inner failure", "detail": detail, "data": data, } return { "status": "success", "data": data, } except ValueError as e: return { "status": "error", "error": "Failed to parse JSON response", "detail": str(e), "raw_response": response.text[:500] if response.text else None, }
[docs] def _validate_api_key(self) -> Optional[Dict[str, Any]]: """Validate API key is present.""" if not self.api_key: return { "status": "error", "error": "Missing API key", "detail": ( "NVIDIA_API_KEY environment variable not set. " "Get your API key at https://build.nvidia.com and set it:\n" "export NVIDIA_API_KEY=nvapi-..." ), } return None
[docs] def _upload_asset( self, content: str, description: str = "diffdock-file" ) -> Dict[str, Any]: """ Upload a file to NVIDIA's asset storage for tools that require staged inputs (e.g., DiffDock). This implements the NVCF asset upload pattern: 1. POST to assets API to get an upload URL and asset ID 2. PUT the file content to the upload URL 3. Return the asset ID to be used in the main API request Args: content: File content to upload (string) description: Description for the asset Returns: Dictionary with 'asset_id' on success, or error details on failure """ if not self.api_key: return { "status": "error", "error": "Missing API key", "detail": "NVIDIA_API_KEY required for asset upload", } # Step 1: Request upload URL headers = { "Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json", "Accept": "application/json", } payload = { "contentType": "text/plain", "description": description, } try: _enforce_rate_limit() response = requests.post( self.ASSETS_URL, headers=headers, json=payload, timeout=30 ) response.raise_for_status() result = response.json() upload_url = result.get("uploadUrl") asset_id = result.get("assetId") if not upload_url or not asset_id: return { "status": "error", "error": "Failed to get upload URL", "detail": f"Response: {result}", } # Step 2: Upload content to S3 s3_headers = { "x-amz-meta-nvcf-asset-description": description, "Content-Type": "text/plain", } _enforce_rate_limit() upload_response = requests.put( upload_url, data=content, headers=s3_headers, timeout=300 ) upload_response.raise_for_status() return { "status": "success", "asset_id": asset_id, } except requests.exceptions.RequestException as e: return { "status": "error", "error": "Asset upload failed", "detail": str(e), }
[docs] def _handle_diffdock_staged(self, arguments: Dict[str, Any]) -> Dict[str, Any]: """ Handle DiffDock's staged asset upload workflow. When is_staged=True, protein and ligand should be raw content that needs to be uploaded as assets. The asset IDs are then used in the actual request. Args: arguments: Original arguments with protein and ligand content Returns: Modified arguments with asset IDs, or error dict on failure """ protein_content = arguments.get("protein", "") ligand_content = arguments.get("ligand", "") # Upload protein protein_result = self._upload_asset(protein_content, "protein-pdb") if protein_result.get("status") == "error": return protein_result protein_asset_id = protein_result["asset_id"] # Upload ligand ligand_result = self._upload_asset(ligand_content, "ligand-sdf") if ligand_result.get("status") == "error": return ligand_result ligand_asset_id = ligand_result["asset_id"] # Return asset IDs and update arguments return { "status": "success", "protein_asset_id": protein_asset_id, "ligand_asset_id": ligand_asset_id, "asset_references": f"{protein_asset_id},{ligand_asset_id}", }
[docs] def run(self, arguments: Dict[str, Any]) -> Dict[str, Any]: """ Execute the NVIDIA NIM API call. Args: arguments: Dictionary of API-specific parameters Returns: Dictionary containing: - status: "success" or error information - data: API response data - Additional fields based on response type """ # Validate API key key_error = self._validate_api_key() if key_error: return key_error # Validate required parameters missing = [ k for k in self.get_required_parameters() if k not in arguments or arguments[k] is None ] if missing: return { "status": "error", "error": "Missing required parameters", "detail": f"Required: {', '.join(missing)}", "provided": list(arguments.keys()), } # Build URL (filling any templated path params) and headers url = self._build_url(arguments) headers = self._get_headers() # Handle DiffDock staged asset upload workflow asset_references = None request_arguments = arguments.copy() # Path-param values select the model/variant in the URL, not body fields. for key in self.path_params: request_arguments.pop(key, None) is_diffdock = "diffdock" in self.endpoint.lower() is_staged = arguments.get("is_staged", False) if is_diffdock and is_staged: # Upload assets and get asset IDs staged_result = self._handle_diffdock_staged(arguments) if staged_result.get("status") == "error": return staged_result # Update arguments with asset IDs request_arguments["protein"] = staged_result["protein_asset_id"] request_arguments["ligand"] = staged_result["ligand_asset_id"] request_arguments["is_staged"] = True asset_references = staged_result["asset_references"] # Add asset references header headers["NVCF-INPUT-ASSET-REFERENCES"] = asset_references # Enforce rate limiting _enforce_rate_limit() try: # Make the API request response = requests.post( url, headers=headers, json=request_arguments, timeout=self.timeout ) # Handle async response (202 Accepted) if response.status_code == 202: req_id = response.headers.get("nvcf-reqid") if not req_id: return { "status": "error", "error": "Async operation started but no request ID returned", "detail": "Missing nvcf-reqid header", } # For DiffDock with staged assets, ensure headers are preserved for polling if asset_references: headers["NVCF-INPUT-ASSET-REFERENCES"] = asset_references # Poll for result return self._poll_for_result(req_id, headers) # Handle synchronous response return self._parse_response(response) except requests.exceptions.Timeout: return { "status": "error", "error": "Request timeout", "detail": f"Request timed out after {self.timeout} seconds", } except requests.exceptions.ConnectionError as e: return { "status": "error", "error": "Connection error", "detail": str(e), } except requests.exceptions.RequestException as e: return { "status": "error", "error": "Request failed", "detail": str(e), }