Source code for tooluniverse.proteinsplus_tool

"""ProteinsPlus API tool using AsyncPollingTool base class.

Converted to use AsyncPollingTool for cleaner code and automatic polling management.
Maintains all original functionality while reducing boilerplate.
"""

import requests
from typing import Any, Dict, Optional, TYPE_CHECKING
from .async_base import AsyncPollingTool
from .tool_registry import register_tool

if TYPE_CHECKING:
    from .task_progress import TaskProgress

# Root of the ProteinsPlus REST API; each tool's endpoint path is appended to it.
PROTEINSPLUS_BASE_URL = "https://proteins.plus/api"

# Headers for requests that carry no JSON body (status polling and GET calls).
_STATUS_HEADERS = {
    "Accept": "application/json",
    "User-Agent": "ToolUniverse/ProteinsPlus",
}

# Headers for requests that send a JSON body: the status headers plus a
# Content-Type, listed in Accept / Content-Type / User-Agent order.
_JSON_HEADERS = {
    "Accept": _STATUS_HEADERS["Accept"],
    "Content-Type": "application/json",
    "User-Agent": _STATUS_HEADERS["User-Agent"],
}


@register_tool("ProteinsPlusRESTTool")
class ProteinsPlusRESTTool(AsyncPollingTool):
    """ProteinsPlus API tool for protein-ligand docking and binding site analysis.

    Now uses AsyncPollingTool base class for automatic polling, progress
    reporting, and timeout management. Original functionality preserved.

    A tool instance is bound to a single endpoint (``fields["endpoint"]``) and
    is either polling-based (``fields["is_async"]`` true) or a one-shot
    request; ``run()`` picks the path accordingly.
    """

    # Maps user-facing SIENA argument names to the keys sent in the request
    # body. Only arguments actually supplied by the caller are forwarded.
    _SIENA_OPTIONAL_KEYS = {
        "fragment_length": "fragment_length",
        "flexibility_sensitivity": "flexibility_sensitivity",
        "site_radius": "siteRadius",
        "minimal_site_identity": "minimalSiteIdentity",
        "minimal_site_coverage": "minimalSiteCoverage",
        "maximum_mutations": "maximum_mutations",
    }

    # Pseudo job id returned by submit_job() when the submission response
    # already reports completion, so check_status() has nothing to poll.
    _IMMEDIATE_JOB_ID = "COMPLETED_IMMEDIATELY"

    def __init__(self, tool_config):
        """Initialize ProteinsPlus tool with configuration.

        Args:
            tool_config: Mapping read for ``name``, ``description``,
                ``parameter`` (whose ``required`` list is used for
                validation) and ``fields`` (``endpoint``, ``method``,
                ``is_async``, ``poll_interval``, ``max_wait_time``).
        """
        # Extract config before calling super().__init__()
        fields = tool_config.get("fields", {})
        parameter = tool_config.get("parameter", {})

        # Set AsyncPollingTool attributes
        self.name = tool_config.get("name", "ProteinsPlus_Tool")
        self.description = tool_config.get("description", "ProteinsPlus API tool")
        self.parameter = parameter
        self.poll_interval = fields.get("poll_interval", 15)
        self.max_duration = fields.get("max_wait_time", 1800)

        # Initialize AsyncPollingTool (generates return_schema)
        super().__init__()

        # ProteinsPlus-specific config
        self.endpoint = fields.get("endpoint", "")
        self.method = fields.get("method", "POST").upper()
        self.required = parameter.get("required", [])
        self.is_async = fields.get("is_async", False)

        # Store full config for compatibility
        self._tool_config = tool_config

    def _transform_params(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
        """Transform user-facing arguments into the nested format ProteinsPlus expects.

        Known endpoints get their parameters wrapped under an endpoint-specific
        key; any other endpoint receives the arguments unchanged except that
        ``None`` values are dropped.

        Args:
            arguments: User-supplied tool arguments.

        Returns:
            The request payload for ``self.endpoint``.
        """
        endpoint = self.endpoint

        if endpoint == "/dogsite_rest":
            # Raw PDB content, when given, is sent on its own instead of the
            # pdbCode-based "dogsite" payload.
            if "pdb_content" in arguments:
                return {"pdb_content": arguments["pdb_content"]}
            return {
                "dogsite": {
                    "pdbCode": arguments.get("pdb_id", ""),
                    "analysisDetail": "1",
                    "bindingSitePredictionGranularity": "1",
                    "ligand": "",
                    "chain": arguments.get("chain", ""),
                }
            }

        if endpoint == "/dogsite3_rest":
            return {
                "dogsite3": {
                    "pdbCode": arguments.get("pdb_id", ""),
                    "analysisDetail": arguments.get("analysis_detail", "1"),
                    "bindingSitePredictionGranularity": arguments.get(
                        "druggability", "1"
                    ),
                    "ligand": arguments.get("ligand", ""),
                    "chain": arguments.get("chain", ""),
                    "ligandBias": "1" if arguments.get("ligand_bias", False) else "0",
                }
            }

        if endpoint == "/poseview_rest":
            return {
                "poseview": {
                    "pdbCode": arguments.get("pdb_id", ""),
                    "ligand": arguments.get("ligand", ""),
                }
            }

        if endpoint == "/siena_rest":
            siena_params = {
                "pdbCode": arguments.get("pdb_id", ""),
                "mode": arguments.get("mode", "screening"),
                "ligand": arguments.get("ligand", ""),
                "pocket": arguments.get("pocket", ""),
            }
            for arg_key, api_key in self._SIENA_OPTIONAL_KEYS.items():
                if arg_key in arguments:
                    siena_params[api_key] = arguments[arg_key]
            return {"siena": siena_params}

        if endpoint == "/structurechecker_rest":
            return {
                "structurechecker": {
                    "pdbCode": arguments.get("pdb_id", ""),
                    "setting": arguments.get("setting", "combined"),
                }
            }

        return {k: v for k, v in arguments.items() if v is not None}

    # ========================================================================
    # Shared helpers
    # ========================================================================

    def _build_api_url(self, arguments: Dict[str, Any]) -> str:
        """Build API URL by substituting argument placeholders in the endpoint.

        Every ``{name}`` occurrence in the endpoint whose ``name`` is a key of
        ``arguments`` is replaced by ``str(value)``; placeholders with no
        matching argument are left in the URL as-is.
        """
        url = PROTEINSPLUS_BASE_URL + self.endpoint
        for key, value in arguments.items():
            placeholder = f"{{{key}}}"
            if placeholder in url:
                url = url.replace(placeholder, str(value))
        return url

    def _validate_required(self, arguments: Dict[str, Any]) -> None:
        """Raise ValueError if any required parameters are missing.

        A parameter counts as present when its key is in ``arguments``, even
        if the value is ``None``.
        """
        missing = [k for k in self.required if k not in arguments]
        if missing:
            raise ValueError(f"Missing required parameter(s): {', '.join(missing)}")

    # ========================================================================
    # AsyncPollingTool Required Methods
    # ========================================================================

    def submit_job(self, arguments: Dict[str, Any]) -> str:
        """Submit job to ProteinsPlus API and return job location URL.

        This method handles job submission for async tools. For sync tools,
        it's not called (handled by run() override).

        Args:
            arguments: User-supplied tool arguments.

        Returns:
            The status URL to poll, or the ``COMPLETED_IMMEDIATELY`` marker
            when the submission response already reports completion (the
            result is then kept on the instance for ``check_status``).

        Raises:
            ValueError: If a required parameter is missing.
            RuntimeError: On a non-success HTTP status, an unparseable or
                non-object JSON body, or a body with no job location/ID.
        """
        self._validate_required(arguments)

        url = self._build_api_url(arguments)
        request_data = self._transform_params(arguments)

        response = requests.post(
            url,
            json=request_data,
            headers=_JSON_HEADERS,
            timeout=60.0,
        )

        if response.status_code == 404:
            raise RuntimeError(f"Endpoint not found: {url}")
        if response.status_code == 400:
            raise RuntimeError(f"Bad request: {response.text}")
        if response.status_code not in (200, 201, 202):
            raise RuntimeError(f"API returned {response.status_code}: {response.text}")

        # Parse response
        try:
            job_data = response.json()
        except Exception as e:
            # Chain the cause so the original decoding error stays visible.
            raise RuntimeError(f"Failed to parse response: {e}") from e

        # A JSON array/scalar has no .get(); report it as a RuntimeError like
        # every other submission failure instead of leaking AttributeError.
        if not isinstance(job_data, dict):
            raise RuntimeError(
                "Unexpected response format: expected a JSON object, "
                f"got {type(job_data).__name__}"
            )

        # Extract job location URL
        status_url = job_data.get("location")
        if status_url:
            return status_url

        # Try extracting job_id
        job_id = job_data.get("job_id") or job_data.get("id")
        if job_id:
            return f"{PROTEINSPLUS_BASE_URL}/jobs/{job_id}/status"

        # Check if job completed immediately
        if job_data.get("status") in ("completed", "success"):
            # Store for retrieval in check_status
            self._immediate_result = job_data.get("results", job_data)
            return self._IMMEDIATE_JOB_ID

        raise RuntimeError("No job location or ID in response")

    def check_status(self, job_id: str) -> Dict[str, Any]:
        """Check ProteinsPlus job status and return result if complete.

        Args:
            job_id: Job location URL from submit_job()

        Returns:
            Dict with keys:
            - done (bool): True if complete
            - result (any): Results if done
            - progress (int): Progress percentage (0-100)
            - error (str): Error message if failed
        """
        # Handle immediate completion case
        if job_id == self._IMMEDIATE_JOB_ID:
            result = getattr(self, "_immediate_result", {})
            return {"done": True, "result": result, "progress": 100}

        # Check status via HTTP
        try:
            response = requests.get(
                job_id,
                headers=_STATUS_HEADERS,
                timeout=30.0,
            )
        except Exception as e:
            return {"done": False, "error": f"Status check failed: {e}"}

        # Handle HTTP 202 (still processing)
        if response.status_code == 202:
            return {"done": False, "progress": 30}

        # Handle HTTP errors
        if response.status_code not in (200, 201):
            return {
                "done": False,
                "error": f"Status check returned {response.status_code}: {response.text}",
            }

        # Parse response
        try:
            status_data = response.json()
        except Exception as e:
            return {"done": False, "error": f"Failed to parse status: {e}"}

        # Everything below reads fields via .get(); a non-object payload would
        # raise AttributeError, so report it the same way as a parse failure.
        if not isinstance(status_data, dict):
            return {
                "done": False,
                "error": (
                    "Failed to parse status: expected a JSON object, "
                    f"got {type(status_data).__name__}"
                ),
            }

        # Check internal status_code field (ProteinsPlus-specific)
        if status_data.get("status_code") == 202:
            return {"done": False, "progress": 60}

        # Check status field. It may be absent, null, or non-string; coerce so
        # .lower() cannot fail.
        status = str(status_data.get("status") or "").lower()
        if status in ("failed", "error"):
            # Fall back to the generic message when "error" is missing OR null,
            # so a failure is never reported with an empty error.
            error_msg = status_data.get("error") or "Job failed"
            return {"done": False, "error": error_msg}

        # Job complete - extract results
        results = status_data.get("results", status_data)
        return {"done": True, "result": results, "progress": 100}

    def format_result(self, result: Any) -> Dict[str, Any]:
        """Format ProteinsPlus results into standard response format."""
        return {
            "data": result,
            "metadata": {
                "source": "ProteinsPlus",
                "endpoint": self.endpoint,
                "execution_type": "async",
            },
        }

    # ========================================================================
    # Override run() for sync/async branching
    # ========================================================================

    async def run(
        self, arguments: Dict[str, Any], progress: Optional["TaskProgress"] = None
    ) -> Dict[str, Any]:
        """Execute the tool with provided arguments.

        Overrides AsyncPollingTool.run() to support both sync and async tools.

        Args:
            arguments: User-supplied tool arguments.
            progress: Optional progress reporter; receives status messages.

        Returns:
            The base-class result for polling tools, or the dict produced by
            ``_run_sync_request`` for one-shot tools.
        """
        if progress:
            await progress.set_message("Starting ProteinsPlus job")

        # For async tools, use AsyncPollingTool's run()
        if self.is_async:
            return await super().run(arguments, progress)

        # For sync tools, execute directly
        return await self._run_sync_request(arguments, progress)

    async def _run_sync_request(
        self, arguments: Dict[str, Any], progress: Optional["TaskProgress"]
    ) -> Dict[str, Any]:
        """Execute a synchronous (non-polling) request.

        Failures are reported as a dict with an ``error`` key rather than
        raised.

        Args:
            arguments: User-supplied tool arguments.
            progress: Optional progress reporter; receives status messages.

        Returns:
            ``{"data": ..., "metadata": {...}}`` on success, otherwise a dict
            with ``error`` and, where available, ``detail`` and ``query``.
        """
        if progress:
            await progress.set_message("Executing synchronous request")

        # Reuse the shared validator so both execution paths agree on what
        # "missing" means; the message text is unchanged.
        try:
            self._validate_required(arguments)
        except ValueError as exc:
            return {"error": str(exc), "query": arguments}

        url = self._build_api_url(arguments)
        request_data = self._transform_params(arguments)

        # NOTE(review): the requests calls below are blocking and run directly
        # inside this coroutine - confirm that is acceptable for the caller.
        try:
            if self.method == "POST":
                response = requests.post(
                    url,
                    json=request_data,
                    headers=_JSON_HEADERS,
                    timeout=60.0,
                )
            else:
                response = requests.get(
                    url,
                    params=request_data,
                    headers=_STATUS_HEADERS,
                    timeout=60.0,
                )

            if response.status_code == 404:
                return {
                    "error": "Endpoint not found",
                    "detail": response.text,
                    "query": arguments,
                }
            if response.status_code == 400:
                return {
                    "error": "Bad request",
                    "detail": response.text,
                    "query": arguments,
                }
            if response.status_code not in (200, 201):
                return {
                    "error": f"API returned {response.status_code}",
                    "detail": response.text,
                    "query": arguments,
                }

            data = response.json()
            return {
                "data": data,
                "metadata": {
                    "source": "ProteinsPlus",
                    "endpoint": self.endpoint,
                    "query": arguments,
                    "execution_type": "sync",
                },
            }

        except requests.Timeout:
            return {
                "error": "Request timeout",
                "detail": "Request timed out after 60 seconds",
            }
        except Exception as e:
            # Boundary handler: this path reports every failure (including a
            # JSON decode error from response.json()) as an error dict.
            return {"error": "Request failed", "detail": str(e)}