Source code for tooluniverse.rcsb_pdb_tool
from .base_tool import BaseTool
from .tool_registry import register_tool
import requests
[docs]
@register_tool("RCSBTool")
class RCSBTool(BaseTool):
    """Look up RCSB PDB records for one or more identifiers.

    Queries go through ``rcsbapi.data.DataQuery`` when the optional
    ``rcsbapi`` package can be imported; otherwise each identifier is fetched
    with a direct GET request against the REST endpoints rooted at
    ``https://data.rcsb.org/rest/v1/core``.
    """
[docs]
def __init__(self, tool_config):
    """Configure the tool and choose the query backend.

    The optional ``rcsbapi`` package is preferred. When it cannot be
    imported, ``self.DataQuery`` is left as ``None`` so that ``run`` uses
    the direct REST fallback instead.

    Args:
        tool_config: Mapping describing the tool. The keys read here are
            ``name``, ``description``, ``input_type``, ``fields`` (with
            ``search_fields`` and ``return_fields``) and ``parameter``
            (with ``properties`` and ``required``). A section that is
            missing or explicitly ``None`` is treated as empty.

    Raises:
        RuntimeError: If importing ``rcsbapi`` fails with anything other
            than ``ImportError``.
    """
    super().__init__(tool_config)

    # Always define the attribute so it can be inspected safely even when
    # the import succeeded (it stays None in that case).
    self._rcsbapi_import_error = None

    # Prefer optional `rcsbapi` (GraphQL client). Fall back to direct REST calls
    # so these tools still work in minimal environments.
    try:
        from rcsbapi.data import DataQuery
    except ImportError as e:
        self.DataQuery = None
        self._rcsbapi_import_error = e
    except Exception as e:
        # A non-ImportError failure while importing the client is reported
        # explicitly rather than silently switching to the REST fallback.
        raise RuntimeError(
            "Failed to initialize RCSB API client. "
            "This may be due to network issues or API unavailability. "
            f"Original error: {str(e)}"
        ) from e
    else:
        self.DataQuery = DataQuery

    self.name = tool_config.get("name")
    self.description = tool_config.get("description")
    self.input_type = tool_config.get("input_type")

    # `or {}` / `or []` also covers keys that are present but set to None,
    # which `.get(key, default)` alone would pass through unchanged.
    fields = tool_config.get("fields") or {}
    self.search_fields = fields.get("search_fields") or {}
    self.return_fields = fields.get("return_fields") or []

    parameter = tool_config.get("parameter") or {}
    self.parameter_schema = parameter.get("properties") or {}
    self.required_params = parameter.get("required") or []

    self._rest_api_base = "https://data.rcsb.org/rest/v1/core"
    self._timeout = 60  # passed as `timeout=` to requests.get in the REST fallback
[docs]
def validate_params(self, params: dict):
    """Confirm that every required parameter name is present in ``params``.

    Only key presence is checked; the supplied values are not inspected.

    Args:
        params: Mapping of parameter names to caller-supplied values.

    Returns:
        ``True`` when no required name is absent.

    Raises:
        ValueError: Naming the first entry of ``self.required_params``
            that is not a key of ``params``.
    """
    not_found = object()
    # Lazy scan: stops at the first missing name, like an early-exit loop.
    first_absent = next(
        (key for key in self.required_params if key not in params),
        not_found,
    )
    if first_absent is not not_found:
        raise ValueError(f"Missing required parameter: {first_absent}")
    return True
[docs]
def prepare_input_ids(self, params: dict):
    """Extract the identifier list from the first matching search field.

    ``self.search_fields`` is walked in its iteration order and the first
    name that is also a key of ``params`` decides the result.

    Args:
        params: Mapping of parameter names to caller-supplied values.

    Returns:
        The supplied value itself when it is already a list, otherwise a
        one-element list wrapping it.

    Raises:
        ValueError: If none of the search field names is a key of
            ``params``.
    """
    for field in self.search_fields:
        if field not in params:
            continue
        supplied = params[field]
        if isinstance(supplied, list):
            return supplied
        return [supplied]
    raise ValueError("No valid search parameter provided")
[docs]
def _split_composite_id(
    self, value: str, sep: str, expected_parts: int, label: str
) -> list:
    """Split a composite identifier into its whitespace-trimmed parts.

    Args:
        value: Identifier text such as ``"1A8M_1"``.
        sep: Separator between the parts.
        expected_parts: Exact number of parts the identifier must have.
        label: Human-readable name of the identifier, used in error
            messages.

    Returns:
        A list of ``expected_parts`` non-empty strings, each stripped of
        surrounding whitespace.

    Raises:
        ValueError: If ``value`` is not a non-empty string, if the number
            of parts differs from ``expected_parts``, or if any part is
            empty after trimming.
    """
    if not isinstance(value, str) or not value.strip():
        raise ValueError(f"Invalid {label}: must be a non-empty string")

    # Trim every part, not just the whole value, so whitespace around the
    # separator ("1A8M _ 1") neither leaks into the result nor lets a
    # blank part slip past the emptiness check below.
    pieces = [piece.strip() for piece in value.strip().split(sep)]

    if len(pieces) != expected_parts or not all(pieces):
        raise ValueError(
            f"Invalid {label}: expected format with {expected_parts} parts separated by '{sep}'"
        )
    return pieces
[docs]
def _rest_url_for_input_id(self, input_id: str) -> str:
    """Build the REST URL for one identifier of the configured input type.

    Entry and chemical-component identifiers are used whole (trimmed and
    upper-cased). The other supported types are composite identifiers,
    split with ``_split_composite_id`` into an entry part (upper-cased)
    and a sub-identifier (kept as given).

    Args:
        input_id: Identifier to look up; it is converted with ``str()``
            before use.

    Returns:
        ``self._rest_api_base`` followed by the endpoint name and the
        identifier parts, each percent-encoded as a single path segment.

    Raises:
        ValueError: If ``self.input_type`` is not one of the supported
            types, or if a composite identifier is malformed.
    """
    # Local import keeps this method self-contained.
    from urllib.parse import quote

    # input_type -> (separator, label used in validation error messages).
    # The REST endpoint name equals the input_type for all of these.
    composite_formats = {
        "polymer_entity": ("_", "polymer entity ID (e.g., '1A8M_1')"),
        "assembly": ("-", "assembly ID (e.g., '1A8M-1')"),
        "branched_entity": ("_", "branched entity ID (e.g., '5FMB_2')"),
        "polymer_entity_instance": (
            ".",
            "polymer entity instance ID (e.g., '1NDO.A')",
        ),
    }

    input_type = (self.input_type or "").strip()
    raw_id = str(input_id)

    if input_type in ("entry", "entries"):
        segments = ["entry", raw_id.strip().upper()]
    elif input_type == "chem_comp":
        segments = ["chem_comp", raw_id.strip().upper()]
    elif input_type in composite_formats:
        sep, label = composite_formats[input_type]
        entry_id, sub_id = self._split_composite_id(raw_id, sep, 2, label)
        segments = [input_type, entry_id.upper(), sub_id]
    else:
        raise ValueError(
            f"Unsupported RCSB input_type for REST fallback: {input_type!r}"
        )

    # Identifiers come from caller-supplied parameters. Encoding each one
    # as a single path segment (safe="") stops characters such as "/",
    # "?" or "#" from altering the path or query of the request. Plain
    # alphanumeric identifiers are unchanged by the encoding.
    encoded_path = "/".join(quote(segment, safe="") for segment in segments)
    return f"{self._rest_api_base}/{encoded_path}"
[docs]
def _run_via_rest(self, input_ids: list):
    """Fetch every identifier over REST and wrap the decoded payloads.

    Identifiers are requested one after another, in order, each with
    ``self._timeout`` as the request timeout. The first response with an
    error status aborts the whole call via ``raise_for_status``.

    Args:
        input_ids: Identifiers to fetch.

    Returns:
        ``{"data": {<plural key>: [payload, ...]}}`` for the recognised
        input types, or ``{"data": [payload, ...]}`` otherwise.
    """

    def fetch(identifier):
        # One GET per identifier; an HTTP error status raises here.
        response = requests.get(
            self._rest_url_for_input_id(identifier), timeout=self._timeout
        )
        response.raise_for_status()
        return response.json()

    payloads = [fetch(identifier) for identifier in input_ids]

    # Keep output shape consistent with existing return_schemas.
    plural_keys = {
        "entry": "entries",
        "entries": "entries",
        "polymer_entity": "polymer_entities",
        "assembly": "assemblies",
        "branched_entity": "branched_entities",
        "polymer_entity_instance": "polymer_entity_instances",
        "chem_comp": "chem_comps",
    }
    kind = (self.input_type or "").strip()
    if kind in plural_keys:
        return {"data": {plural_keys[kind]: payloads}}
    return {"data": payloads}
[docs]
def run(self, params: dict):
    """Validate the parameters and execute the lookup.

    Args:
        params: Mapping of parameter names to caller-supplied values.

    Returns:
        The result of ``DataQuery(...).exec()`` when a ``DataQuery`` class
        is available on the instance, otherwise the result of
        ``_run_via_rest``.

    Raises:
        ValueError: Propagated from ``validate_params`` or
            ``prepare_input_ids`` when the parameters are unusable.
    """
    self.validate_params(params)
    identifiers = self.prepare_input_ids(params)

    query_cls = self.DataQuery
    if query_cls is not None:
        request = query_cls(
            input_type=self.input_type,
            input_ids=identifiers,
            return_data_list=self.return_fields,
        )
        return request.exec()

    # No client class available: use the direct REST fallback.
    return self._run_via_rest(identifiers)