Source code for tooluniverse.nvidia_nim_tool
"""
NVIDIA NIM Healthcare API Tool.
This module provides a unified interface to NVIDIA's cloud-hosted healthcare AI APIs,
including protein structure prediction, molecular docking, protein design, genomics,
and medical imaging tools.
All APIs require a NVIDIA API key set as the NVIDIA_API_KEY environment variable.
Get your key at: https://build.nvidia.com
Rate limit: 40 requests per minute (enforced internally).
"""
import os
import time
import requests
from typing import Dict, Any, Optional, List
from .base_tool import BaseTool
from .tool_registry import register_tool
# Rate limiting configuration
_last_request_time = 0.0
_MIN_REQUEST_INTERVAL = 1.5 # 40 RPM = 1.5 seconds between requests
def _enforce_rate_limit():
    """Enforce rate limiting (40 RPM = 1.5s between requests)."""
    global _last_request_time
    current_time = time.time()
    elapsed = current_time - _last_request_time
    if elapsed < _MIN_REQUEST_INTERVAL:
        time.sleep(_MIN_REQUEST_INTERVAL - elapsed)
    _last_request_time = time.time()
@register_tool("NvidiaNIMTool")
class NvidiaNIMTool(BaseTool):
"""
NVIDIA NIM Healthcare API tool.
Provides unified access to 16 NVIDIA cloud-hosted healthcare AI APIs:
Structure Prediction:
- AlphaFold2, AlphaFold2-Multimer, ESMFold, OpenFold2, OpenFold3, Boltz2
Protein Design:
- ProteinMPNN, RFdiffusion
Molecular Tools:
- DiffDock, GenMol, MolMIM
Genomics:
- Evo2-40B, MSA-Search, ESM2-650M
Medical Imaging:
- MAISI, Vista3D
Configuration fields:
- endpoint: API endpoint path (relative to base URL)
- base_url: Override base URL (default: https://health.api.nvidia.com/v1/biology)
- async_expected: Whether 202 async response is expected
- poll_seconds: NVCF-POLL-SECONDS header value (default 300)
- response_type: Expected response type (json, pdb, mfasta, zip)
- timeout: Request timeout in seconds (default 600)
"""
DEFAULT_BASE_URL = "https://health.api.nvidia.com/v1/biology"
STATUS_URL = "https://integrate.api.nvidia.com/v1/status"
ASSETS_URL = "https://api.nvcf.nvidia.com/v2/nvcf/assets"
DEFAULT_TIMEOUT = 600
DEFAULT_POLL_SECONDS = 300
MAX_POLL_ATTEMPTS = 120 # 10 minutes with 5s intervals
POLL_INTERVAL = 5
    def __init__(self, tool_config: Dict[str, Any]):
        super().__init__(tool_config)
        fields = tool_config.get("fields", {})
        self.endpoint = fields.get("endpoint", "")
        self.base_url = fields.get("base_url", self.DEFAULT_BASE_URL)
        self.async_expected = fields.get("async_expected", False)
        self.poll_seconds = fields.get("poll_seconds", self.DEFAULT_POLL_SECONDS)
        self.response_type = fields.get("response_type", "json")
        self.timeout = fields.get("timeout", self.DEFAULT_TIMEOUT)
        # Get API key from environment
        self.api_key = os.environ.get("NVIDIA_API_KEY")
    def _get_headers(self) -> Dict[str, str]:
        """Build request headers with authentication."""
        headers = {
            "Content-Type": "application/json",
            "Accept": "application/json",
        }
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"
        # Add NVCF-POLL-SECONDS for async operations
        if self.async_expected:
            headers["NVCF-POLL-SECONDS"] = str(self.poll_seconds)
        return headers
    def _build_url(self) -> str:
        """Build the full API URL."""
        # Handle endpoints that include the full path
        if self.endpoint.startswith("http"):
            return self.endpoint
        # Ensure proper URL construction
        base = self.base_url.rstrip("/")
        endpoint = self.endpoint.lstrip("/")
        return f"{base}/{endpoint}"
    def _poll_for_result(self, req_id: str, headers: Dict[str, str]) -> Dict[str, Any]:
        """
        Poll the status endpoint for async operation results.

        Args:
            req_id: The nvcf-reqid from the 202 response
            headers: Request headers including auth

        Returns:
            Final response from the API
        """
        poll_url = f"{self.STATUS_URL}/{req_id}"
        for attempt in range(self.MAX_POLL_ATTEMPTS):
            try:
                _enforce_rate_limit()
                response = requests.get(poll_url, headers=headers, timeout=self.timeout)
                if response.status_code != 202:
                    # Operation complete
                    return self._parse_response(response)
                # Still processing, wait and retry
                time.sleep(self.POLL_INTERVAL)
            except requests.exceptions.RequestException as e:
                return {
                    "status": "error",
                    "error": "Poll request failed",
                    "detail": str(e),
                    "request_id": req_id,
                }
        return {
            "status": "error",
            "error": "Polling timeout",
            "detail": f"Operation did not complete within {self.MAX_POLL_ATTEMPTS * self.POLL_INTERVAL} seconds",
            "request_id": req_id,
        }
    def _parse_response(self, response: requests.Response) -> Dict[str, Any]:
        """Parse API response based on response type and status."""
        if response.status_code == 401:
            return {
                "status": "error",
                "error": "Authentication failed",
                "detail": "Invalid or missing NVIDIA_API_KEY. Get your key at https://build.nvidia.com",
                "status_code": 401,
            }
        if response.status_code == 429:
            return {
                "status": "error",
                "error": "Rate limit exceeded",
                "detail": "NVIDIA NIM API rate limit (40 RPM) exceeded. Please wait and retry.",
                "status_code": 429,
            }
        if response.status_code == 404:
            return {
                "status": "error",
                "error": "Endpoint not found",
                "detail": response.text,
                "status_code": 404,
            }
        if response.status_code >= 500:
            return {
                "status": "error",
                "error": "Server error",
                "detail": response.text,
                "status_code": response.status_code,
            }
        if response.status_code not in [200, 201]:
            return {
                "status": "error",
                "error": f"API returned status {response.status_code}",
                "detail": response.text[:1000] if response.text else "No details",
                "status_code": response.status_code,
            }
        # Handle different response types
        content_type = response.headers.get("Content-Type", "")
        if "application/zip" in content_type:
            # Return binary content info for ZIP files (medical imaging)
            return {
                "status": "success",
                "content_type": "application/zip",
                "content_length": len(response.content),
                "data": f"<ZIP file, {len(response.content)} bytes>",
                "note": "Use _raw_content to access the raw ZIP data",
                "_raw_content": response.content,
            }
        if "application/octet-stream" in content_type or self.response_type == "binary":
            # Return binary content info (e.g., npz files for embeddings)
            return {
                "status": "success",
                "content_type": content_type or "application/octet-stream",
                "content_length": len(response.content),
                "data": f"<Binary data, {len(response.content)} bytes>",
                "note": "Use _raw_content to access the raw binary data",
                "_raw_content": response.content,
            }
        if self.response_type == "pdb" or "text/plain" in content_type:
            # PDB structure text
            return {
                "status": "success",
                "structure": response.text,
                "format": "pdb",
            }
        if self.response_type == "mfasta":
            # Multi-FASTA format
            try:
                data = response.json()
                return {
                    "status": "success",
                    "data": data,
                    "format": "mfasta",
                }
            except ValueError:
                return {
                    "status": "success",
                    "sequences": response.text,
                    "format": "mfasta",
                }
        # Default: JSON response
        try:
            data = response.json()
            return {
                "status": "success",
                "data": data,
            }
        except ValueError as e:
            return {
                "status": "error",
                "error": "Failed to parse JSON response",
                "detail": str(e),
                "raw_response": response.text[:500] if response.text else None,
            }
    def _validate_api_key(self) -> Optional[Dict[str, Any]]:
        """Validate that the API key is present."""
        if not self.api_key:
            return {
                "status": "error",
                "error": "Missing API key",
                "detail": (
                    "NVIDIA_API_KEY environment variable not set. "
                    "Get your API key at https://build.nvidia.com and set it:\n"
                    "export NVIDIA_API_KEY=nvapi-..."
                ),
            }
        return None
    def _upload_asset(
        self, content: str, description: str = "diffdock-file"
    ) -> Dict[str, Any]:
        """
        Upload a file to NVIDIA's asset storage for tools that require staged inputs (e.g., DiffDock).

        This implements the NVCF asset upload pattern:
        1. POST to the assets API to get an upload URL and asset ID
        2. PUT the file content to the upload URL
        3. Return the asset ID to be used in the main API request

        Args:
            content: File content to upload (string)
            description: Description for the asset

        Returns:
            Dictionary with 'asset_id' on success, or error details on failure
        """
        if not self.api_key:
            return {
                "status": "error",
                "error": "Missing API key",
                "detail": "NVIDIA_API_KEY required for asset upload",
            }
        # Step 1: Request an upload URL
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
            "Accept": "application/json",
        }
        payload = {
            "contentType": "text/plain",
            "description": description,
        }
        try:
            _enforce_rate_limit()
            response = requests.post(
                self.ASSETS_URL, headers=headers, json=payload, timeout=30
            )
            response.raise_for_status()
            result = response.json()
            upload_url = result.get("uploadUrl")
            asset_id = result.get("assetId")
            if not upload_url or not asset_id:
                return {
                    "status": "error",
                    "error": "Failed to get upload URL",
                    "detail": f"Response: {result}",
                }
            # Step 2: Upload the content to S3
            s3_headers = {
                "x-amz-meta-nvcf-asset-description": description,
                "Content-Type": "text/plain",
            }
            _enforce_rate_limit()
            upload_response = requests.put(
                upload_url, data=content, headers=s3_headers, timeout=300
            )
            upload_response.raise_for_status()
            return {
                "status": "success",
                "asset_id": asset_id,
            }
        except requests.exceptions.RequestException as e:
            return {
                "status": "error",
                "error": "Asset upload failed",
                "detail": str(e),
            }
    def _handle_diffdock_staged(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
        """
        Handle DiffDock's staged asset upload workflow.

        When is_staged=True, protein and ligand should be raw content that needs
        to be uploaded as assets. The asset IDs are then used in the actual request.

        Args:
            arguments: Original arguments with protein and ligand content

        Returns:
            Dictionary with the uploaded asset IDs and a combined asset_references
            string on success, or an error dict on failure
        """
        protein_content = arguments.get("protein", "")
        ligand_content = arguments.get("ligand", "")
        # Upload the protein
        protein_result = self._upload_asset(protein_content, "protein-pdb")
        if protein_result.get("status") == "error":
            return protein_result
        protein_asset_id = protein_result["asset_id"]
        # Upload the ligand
        ligand_result = self._upload_asset(ligand_content, "ligand-sdf")
        if ligand_result.get("status") == "error":
            return ligand_result
        ligand_asset_id = ligand_result["asset_id"]
        # Return the asset IDs; the caller updates the request arguments
        return {
            "status": "success",
            "protein_asset_id": protein_asset_id,
            "ligand_asset_id": ligand_asset_id,
            "asset_references": f"{protein_asset_id},{ligand_asset_id}",
        }
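    # Illustrative sketch of the staged DiffDock flow (hypothetical file names
    # and values): the caller supplies raw file contents plus is_staged=True, e.g.
    #
    #     arguments = {
    #         "protein": Path("receptor.pdb").read_text(),  # raw PDB text
    #         "ligand": Path("ligand.sdf").read_text(),     # raw SDF text
    #         "is_staged": True,
    #     }
    #
    # run() then swaps "protein"/"ligand" for the uploaded asset IDs and passes
    # them in the NVCF-INPUT-ASSET-REFERENCES header (see run() below).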
    def run(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
        """
        Execute the NVIDIA NIM API call.

        Args:
            arguments: Dictionary of API-specific parameters

        Returns:
            Dictionary containing:
                - status: "success" or error information
                - data: API response data
                - Additional fields based on response type
        """
        # Validate API key
        key_error = self._validate_api_key()
        if key_error:
            return key_error
        # Validate required parameters
        missing = [
            k
            for k in self.get_required_parameters()
            if k not in arguments or arguments[k] is None
        ]
        if missing:
            return {
                "status": "error",
                "error": "Missing required parameters",
                "detail": f"Required: {', '.join(missing)}",
                "provided": list(arguments.keys()),
            }
        # Build URL and headers
        url = self._build_url()
        headers = self._get_headers()
        # Handle DiffDock's staged asset upload workflow
        asset_references = None
        request_arguments = arguments.copy()
        is_diffdock = "diffdock" in self.endpoint.lower()
        is_staged = arguments.get("is_staged", False)
        if is_diffdock and is_staged:
            # Upload assets and get asset IDs
            staged_result = self._handle_diffdock_staged(arguments)
            if staged_result.get("status") == "error":
                return staged_result
            # Update arguments with asset IDs
            request_arguments["protein"] = staged_result["protein_asset_id"]
            request_arguments["ligand"] = staged_result["ligand_asset_id"]
            request_arguments["is_staged"] = True
            asset_references = staged_result["asset_references"]
            # Add the asset references header
            headers["NVCF-INPUT-ASSET-REFERENCES"] = asset_references
        # Enforce rate limiting
        _enforce_rate_limit()
        try:
            # Make the API request
            response = requests.post(
                url, headers=headers, json=request_arguments, timeout=self.timeout
            )
            # Handle async response (202 Accepted)
            if response.status_code == 202:
                req_id = response.headers.get("nvcf-reqid")
                if not req_id:
                    return {
                        "status": "error",
                        "error": "Async operation started but no request ID returned",
                        "detail": "Missing nvcf-reqid header",
                    }
                # For DiffDock with staged assets, ensure headers are preserved for polling
                if asset_references:
                    headers["NVCF-INPUT-ASSET-REFERENCES"] = asset_references
                # Poll for the result
                return self._poll_for_result(req_id, headers)
            # Handle synchronous response
            return self._parse_response(response)
        except requests.exceptions.Timeout:
            return {
                "status": "error",
                "error": "Request timeout",
                "detail": f"Request timed out after {self.timeout} seconds",
            }
        except requests.exceptions.ConnectionError as e:
            return {
                "status": "error",
                "error": "Connection error",
                "detail": str(e),
            }
        except requests.exceptions.RequestException as e:
            return {
                "status": "error",
                "error": "Request failed",
                "detail": str(e),
            }
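# ---------------------------------------------------------------------------
# Illustrative, commented-out usage sketch. The tool_config values below are
# hypothetical (real configurations ship as ToolUniverse JSON tool definitions,
# and parameter names vary per NIM endpoint); NVIDIA_API_KEY must be set in
# the environment for a real call.
#
#     config = {
#         "name": "nim_example_tool",
#         "fields": {
#             "endpoint": "esmfold",       # hypothetical endpoint path
#             "response_type": "json",
#             "timeout": 600,
#         },
#     }
#     tool = NvidiaNIMTool(config)
#     result = tool.run({"sequence": "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"})
#     print(result.get("status"))
# ---------------------------------------------------------------------------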