# Source code for tooluniverse.pubtator_tool
from __future__ import annotations
import json
import re
from pathlib import Path
from typing import Any, Dict, Optional
import requests
from .base_tool import BaseTool
from .tool_registry import register_tool
# Official REST root (cf. NIH “entity autocomplete” & “search” examples)
BASE_URL = "https://www.ncbi.nlm.nih.gov/research/pubtator3-api"
CONFIG_FILE = Path(__file__).with_name("pubtator_tool_config.json")
[docs]
@register_tool("PubTatorTool")
class PubTatorTool(BaseTool):
"""Generic wrapper around a single PubTator 3 endpoint supporting JSON-defined configs."""
[docs]
def __init__(self, tool_config: Dict[str, Any]):
    """Capture the endpoint settings described by ``tool_config``.

    Keys read from ``tool_config``:
        method: HTTP verb, upper-cased; ``"GET"`` when absent.
        endpoint_path: path template (required; ``KeyError`` if missing).
        param_map: caller-name -> API-name mapping; empty when absent.
        body_param: name of the argument sent as the POST body, if any.
        id_in_path: name of the argument appended to the path, if any.
        fields: optional mapping; its ``body_param`` takes precedence over
            the top-level one, and its ``tool_subtype`` is stored for the
            subtype-specific handling in ``run``.
    """
    super().__init__(tool_config)

    # Transport-level settings.
    self._method: str = tool_config.get("method", "GET").upper()
    self._path: str = tool_config["endpoint_path"]
    self._param_map: Dict[str, str] = tool_config.get("param_map", {})

    # Optional argument routing (request body / trailing path segment).
    top_level_body: Optional[str] = tool_config.get("body_param")
    self._id_in_path_key: Optional[str] = tool_config.get("id_in_path")

    # A "body_param" nested under "fields" wins over the top-level key.
    overrides = tool_config.get("fields", {})
    self._body_param: Optional[str] = (
        overrides["body_param"] if "body_param" in overrides else top_level_body
    )
    self._tool_subtype: str = overrides.get("tool_subtype", "")
# ------------------------------------------------------------------ public API --------------
[docs]
def run(self, arguments: Dict[str, Any]):
    """Call the configured PubTator 3 endpoint and return the decoded reply.

    Subtype handling:
        * ``PubTatorRelation`` packs ``subject_id``, ``object`` and the
          optional ``relation_type`` into one ``text`` query parameter
          (``relations:<subject>,<object>[,<type>]``) and queries
          ``/search/``; any remaining arguments are sent alongside it.
        * ``PubTatorAnnotate`` ignores the configured path and uses
          ``/annotations/annotate`` for POST, ``/annotations/retrieve``
          otherwise.
        * ``PubTatorSearch`` JSON objects are passed through
          ``_filter_search_results`` before being returned.

    Args:
        arguments: Caller-supplied parameters. The mapping is copied, so
            the caller's object is never mutated. ``None`` is treated as
            an empty mapping.

    Returns:
        Parsed JSON when the response ``Content-Type`` mentions ``json``,
        the body as ``str`` when it mentions ``text`` or ``xml``, and raw
        ``bytes`` otherwise. A non-success HTTP status yields
        ``{"error": "..."}`` for every subtype, including relation
        searches.

    Raises:
        ValueError: If a relation search lacks ``subject_id`` or
            ``object``, if the configured POST body parameter is missing,
            or if ``_compose_url`` cannot fill a path placeholder.
    """
    # Private copy: the pops below must not leak back to the caller.
    args: Dict[str, Any] = dict(arguments or {})
    data: Optional[bytes] = None
    headers: Dict[str, str] = {}
    params: Dict[str, str] = {}

    if self._tool_subtype == "PubTatorRelation":
        subject = args.pop("subject_id", None)
        obj = args.pop("object", None)
        rel_type = args.pop("relation_type", None)
        if not subject or not obj:
            raise ValueError(
                "Missing required parameters 'subject_id' or 'object' for relation search."
            )
        text_value = f"relations:{subject},{obj}"
        if rel_type:
            text_value += f",{rel_type}"
        # Leftover caller arguments are layered on top of "text", so an
        # explicit "text" argument still overrides the generated one.
        params = self._query_params({"text": text_value, **args})
        url = f"{BASE_URL.rstrip('/')}/search/"
    else:
        if self._tool_subtype == "PubTatorAnnotate":
            action = "annotate" if self._method == "POST" else "retrieve"
            url = f"{BASE_URL.rstrip('/')}/annotations/{action}"
        else:
            # Consumes path-related keys from ``args``.
            url = self._compose_url(args)

        if self._method != "POST":
            params = self._query_params(args)
        elif self._body_param:
            if self._body_param not in args:
                raise ValueError(
                    f"Missing required body parameter '{self._body_param}'."
                )
            data = str(args.pop(self._body_param)).encode("utf-8")
            headers["Content-Type"] = "text/plain; charset=utf-8"
            # NOTE(review): any arguments still left in ``args`` are not
            # sent at all for POST calls -- confirm this is intended.
        else:
            data = json.dumps(args).encode()
            headers["Content-Type"] = "application/json"

    response = requests.request(
        self._method,
        url,
        params=params,
        data=data,
        headers=headers,
        timeout=30,
    )
    if not response.ok:
        return {
            "error": f"Request failed with status code {response.status_code}: {response.text}"
        }

    # Pick the return representation from the declared content type.
    ctype = response.headers.get("Content-Type", "").lower()
    if "json" in ctype:
        result = response.json()
        if self._tool_subtype == "PubTatorSearch" and isinstance(result, dict):
            result = self._filter_search_results(result)
        return result
    if "text" in ctype or "xml" in ctype:
        return response.text
    return response.content
# ------------------------------------------------------------------ helpers -----------------
[docs]
def _compose_url(self, args: Dict[str, Any]) -> str:
"""Substitute template vars & build full URL."""
path = self._path
for placeholder in re.findall(r"{(.*?)}", path):
if placeholder not in args:
raise ValueError(f"Missing URL placeholder argument '{placeholder}'.")
path = path.replace(f"{{{placeholder}}}", str(args.pop(placeholder)))
if self._id_in_path_key and self._id_in_path_key in args:
ids_val = args.pop(self._id_in_path_key)
if isinstance(ids_val, (list, tuple)):
ids_val = ",".join(map(str, ids_val))
path = f"{path}/{ids_val}"
return f"{BASE_URL.rstrip('/')}/{path.lstrip('/')}"
[docs]
def _query_params(self, args: Dict[str, Any]) -> Dict[str, str]:
"""Translate caller arg names → API param names, drop Nones, serialise lists."""
q: Dict[str, str] = {}
for user_key, val in args.items():
if val is None:
continue
api_key = self._param_map.get(user_key, user_key)
if isinstance(val, (list, tuple)):
val = ",".join(map(str, val))
q[api_key] = str(val)
return q
[docs]
def _filter_search_results(self, result: Dict[str, Any]) -> Dict[str, Any]:
"""Filter PubTatorSearch results by score threshold and remove facet items that only have 'name', 'type', and 'value'."""
# Filter result items based on score threshold.
threshold = 230 # Adjust threshold as needed
if "results" in result and isinstance(result["results"], list):
filtered_results = []
for item in result["results"]:
score = item.get("score")
# If there's a numeric score and it's below threshold, skip the item.
if isinstance(score, (int, float)) and score < threshold:
continue
filtered_results.append(item)
result["results"] = filtered_results
# Also filter facets as before.
if "facets" in result and isinstance(result["facets"], dict):
del result["facets"]
return result