Source code for tooluniverse.pubmed_tool

import os
import time
import threading
import xml.etree.ElementTree as ET
from typing import Any, Dict
from .base_rest_tool import BaseRESTTool
from .http_utils import request_with_retry
from .tool_registry import register_tool


[docs] @register_tool("PubMedRESTTool") class PubMedRESTTool(BaseRESTTool): """Generic REST tool for PubMed E-utilities (efetch, elink). Implements rate limiting per NCBI guidelines: - Without API key: 3 requests/second - With API key: 10 requests/second API key is read from environment variable NCBI_API_KEY. Get your free key at: https://www.ncbi.nlm.nih.gov/account/ """ # Class-level rate limiting (shared across all instances) _last_request_time = 0.0 _rate_limit_lock = threading.Lock()
    def __init__(self, tool_config):
        super().__init__(tool_config)
        # Get API key from environment as fallback
        self.default_api_key = os.environ.get("NCBI_API_KEY", "")
    def _get_param_mapping(self) -> Dict[str, str]:
        """Map PubMed E-utilities parameter names."""
        return {
            "limit": "retmax",  # limit -> retmax for E-utilities
        }
    def _enforce_rate_limit(self, has_api_key: bool) -> None:
        """Enforce NCBI E-utilities rate limits.

        Args:
            has_api_key: Whether an API key is provided
        """
        # Rate limits per NCBI guidelines
        # https://www.ncbi.nlm.nih.gov/books/NBK25497/#chapter2.Usage_Guidelines_and_Requiremen
        # Using conservative intervals to avoid rate limit errors:
        # - Without API key: 3 req/sec -> 0.4s interval (more conservative than 0.33s)
        # - With API key: 10 req/sec -> 0.15s interval (more conservative than 0.1s)
        min_interval = 0.15 if has_api_key else 0.4

        with self._rate_limit_lock:
            current_time = time.time()
            time_since_last = current_time - PubMedRESTTool._last_request_time
            if time_since_last < min_interval:
                sleep_time = min_interval - time_since_last
                time.sleep(sleep_time)
            PubMedRESTTool._last_request_time = time.time()
    def _build_params(self, args: Dict[str, Any]) -> Dict[str, Any]:
        """Build E-utilities parameters with special handling."""
        params = {}

        # Start with default params from config (db, dbfrom, cmd, linkname, retmode, rettype)
        for key in ["db", "dbfrom", "cmd", "linkname", "retmode", "rettype"]:
            if key in self.tool_config["fields"]:
                params[key] = self.tool_config["fields"][key]

        # Handle PMID as 'id' parameter (for efetch, elink)
        # PMIDs can be passed as integer or string, coerce to string for API
        if "pmid" in args:
            pmid = args["pmid"]
            if isinstance(pmid, int):
                pmid = str(pmid)
            elif not isinstance(pmid, str):
                raise ValueError(
                    f"pmid must be string or integer, got {type(pmid).__name__}"
                )
            params["id"] = pmid.strip()

        # Handle query as 'term' parameter (for esearch)
        if "query" in args:
            params["term"] = args["query"]

        # Add API key from environment variable
        if self.default_api_key:
            params["api_key"] = self.default_api_key

        # Handle limit
        if "limit" in args and args["limit"]:
            params["retmax"] = args["limit"]

        # Set retmode to json for elink and esearch (easier parsing)
        endpoint = self.tool_config["fields"]["endpoint"]
        if "retmode" not in params and ("elink" in endpoint or "esearch" in endpoint):
            params["retmode"] = "json"

        return params
    def _fetch_summaries(self, pmid_list: list) -> Dict[str, Any]:
        """Fetch article summaries for a list of PMIDs using esummary.

        Args:
            pmid_list: List of PubMed IDs

        Returns:
            Dict with article summaries or error
        """
        if not pmid_list:
            return {"status": "success", "data": []}

        def parse_article(pmid: str, article_data: Dict[str, Any]) -> Dict[str, Any]:
            # Extract author list
            authors = []
            if "authors" in article_data:
                authors = [
                    author.get("name", "") for author in article_data["authors"]
                ][:5]  # Limit to first 5 authors

            # Extract article info
            pub_date = article_data.get("pubdate", "")
            pub_year = pub_date.split()[0] if pub_date else ""
            elocationid = article_data.get("elocationid", "")
            doi = elocationid.replace("doi: ", "") if "doi:" in elocationid else ""
            doi = doi or None
            journal = article_data.get(
                "fulljournalname", article_data.get("source", "")
            )
            journal = journal or None

            # Check for PMC ID
            pmcid = ""
            for aid in article_data.get("articleids", []):
                if aid.get("idtype") == "pmc":
                    pmcid = f"PMC{aid['value']}"
                    break
            pmcid = pmcid or None

            return {
                "pmid": pmid,
                "title": article_data.get("title", ""),
                "authors": authors,
                "journal": journal,
                "pub_date": pub_date,
                "pub_year": pub_year,
                "doi": doi,
                "pmcid": pmcid,
                "article_type": ", ".join(article_data.get("pubtype", [])),
                "url": f"https://pubmed.ncbi.nlm.nih.gov/{pmid}/",
                "doi_url": f"https://doi.org/{doi}" if doi else None,
                "pmc_url": (
                    f"https://www.ncbi.nlm.nih.gov/pmc/articles/{pmcid}/"
                    if pmcid
                    else None
                ),
            }

        try:
            has_api_key = bool(self.default_api_key)
            base = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"
            summary_url = f"{base}/esummary.fcgi"

            articles = []
            failed_pmids = []
            warnings = []

            # Batch fetch first, then isolate failures.
            chunk_size = 100
            for chunk_start in range(0, len(pmid_list), chunk_size):
                chunk = pmid_list[chunk_start : chunk_start + chunk_size]
                self._enforce_rate_limit(has_api_key)
                params = {
                    "db": "pubmed",
                    "id": ",".join(chunk),
                    "retmode": "json",
                }
                if self.default_api_key:
                    params["api_key"] = self.default_api_key
                response = request_with_retry(
                    self.session,
                    "GET",
                    summary_url,
                    params=params,
                    timeout=self.timeout,
                    max_attempts=3,
                )
                if response.status_code != 200:
                    warnings.append(
                        f"Batch summary fetch failed (HTTP {response.status_code}) for {len(chunk)} PMIDs"
                    )
                    failed_pmids.extend(chunk)
                    continue
                try:
                    data = response.json()
                except ValueError:
                    warnings.append(
                        f"Batch summary fetch returned invalid JSON for {len(chunk)} PMIDs"
                    )
                    failed_pmids.extend(chunk)
                    continue
                if "error" in data:
                    warnings.append(
                        f"NCBI API error on batch summary fetch: {data.get('error')}"
                    )
                    failed_pmids.extend(chunk)
                    continue
                result = data.get("result", {})
                for pmid in chunk:
                    article_data = result.get(pmid)
                    if article_data:
                        articles.append(parse_article(pmid, article_data))
                    else:
                        failed_pmids.append(pmid)

            # Retry failures one-by-one to isolate transient per-ID issues.
            retry_failed_pmids = []
            for pmid in failed_pmids:
                self._enforce_rate_limit(has_api_key)
                params = {"db": "pubmed", "id": pmid, "retmode": "json"}
                if self.default_api_key:
                    params["api_key"] = self.default_api_key
                try:
                    response = request_with_retry(
                        self.session,
                        "GET",
                        summary_url,
                        params=params,
                        timeout=self.timeout,
                        max_attempts=2,
                    )
                except Exception as error:
                    retry_failed_pmids.append((pmid, str(error)))
                    continue
                if response.status_code != 200:
                    retry_failed_pmids.append((pmid, f"HTTP {response.status_code}"))
                    continue
                try:
                    data = response.json()
                    article_data = data.get("result", {}).get(pmid)
                    if article_data:
                        articles.append(parse_article(pmid, article_data))
                    else:
                        retry_failed_pmids.append((pmid, "summary missing for PMID"))
                except Exception as error:
                    retry_failed_pmids.append((pmid, str(error)))

            if retry_failed_pmids:
                warnings.append(
                    f"Failed to fetch summaries for {len(retry_failed_pmids)} PMIDs after per-ID retry"
                )

            if not articles:
                error_msg = (
                    warnings[0] if warnings else "Failed to fetch article summaries"
                )
                return {"status": "error", "error": error_msg}

            result = {"status": "success", "data": articles}
            if warnings:
                result["warning"] = "; ".join(warnings)
            return result

        except Exception as e:
            return {
                "status": "error",
                "error": f"Failed to fetch article summaries: {str(e)}",
            }
    def _fetch_abstracts(self, pmid_list: list[str]) -> Dict[str, str]:
        """Best-effort abstract fetch via efetch XML for a list of PMIDs."""
        pmids = [str(p).strip() for p in (pmid_list or []) if str(p).strip()]
        if not pmids:
            return {}

        has_api_key = bool(self.default_api_key)
        base = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"
        url = f"{base}/efetch.fcgi"

        # Batch efetch; keep payload bounded.
        self._enforce_rate_limit(has_api_key)
        params: Dict[str, Any] = {
            "db": "pubmed",
            "id": ",".join(pmids[:200]),
            "retmode": "xml",
        }
        if self.default_api_key:
            params["api_key"] = self.default_api_key

        resp = request_with_retry(
            self.session,
            "GET",
            url,
            params=params,
            timeout=self.timeout,
            max_attempts=3,
        )
        if resp.status_code != 200:
            return {}

        try:
            root = ET.fromstring(resp.text)
        except ET.ParseError:
            return {}

        abstracts: Dict[str, str] = {}
        for pubmed_article in root.findall(".//PubmedArticle"):
            pmid_el = pubmed_article.find(".//MedlineCitation/PMID")
            pmid = (pmid_el.text or "").strip() if pmid_el is not None else ""
            if not pmid:
                continue
            parts: list[str] = []
            for at in pubmed_article.findall(
                ".//MedlineCitation/Article/Abstract/AbstractText"
            ):
                # PubMed AbstractText often contains inline tags (e.g. <i/>).
                # Using itertext() avoids truncating at the first child element.
                text = " ".join("".join(at.itertext()).split())
                if text:
                    parts.append(text)
            if parts:
                abstracts[pmid] = "\n".join(parts)
        return abstracts
    def run(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
        """
        PubMed E-utilities need special handling for direct endpoint URLs.
        Enforces NCBI rate limits to prevent API errors.
        """
        url = None
        try:
            # Enforce rate limiting before making request
            has_api_key = bool(self.default_api_key)
            self._enforce_rate_limit(has_api_key)

            endpoint = self.tool_config["fields"]["endpoint"]
            params = self._build_params(arguments)

            response = request_with_retry(
                self.session,
                "GET",
                endpoint,
                params=params,
                timeout=self.timeout,
                max_attempts=3,
            )

            if response.status_code != 200:
                return {
                    "status": "error",
                    "error": "PubMed API error",
                    "url": response.url,
                    "status_code": response.status_code,
                    "detail": (response.text or "")[:500],
                }

            # Try JSON first (elink, esearch)
            try:
                data = response.json()

                # Check for API errors in response
                if "ERROR" in data:
                    error_msg = data.get("ERROR", "Unknown API error")
                    return {
                        "status": "error",
                        "data": f"NCBI API error: {error_msg[:200]}",
                        "url": response.url,
                    }

                # For esearch responses, extract ID list and fetch summaries
                if "esearchresult" in data:
                    esearch_result = data.get("esearchresult", {})
                    id_list = esearch_result.get("idlist", [])

                    # If this is a search request (has 'query' in arguments),
                    # fetch article summaries and return as list
                    if "query" in arguments and id_list:
                        summary_result = self._fetch_summaries(id_list)
                        if summary_result["status"] == "error":
                            # Preserve stable return type: always return a list of
                            # article-shaped objects. If summaries fail entirely,
                            # return minimal stubs with PMID + URL so downstream
                            # agents can still act on the result.
                            import logging

                            _logger = logging.getLogger(__name__)
                            _logger.warning(
                                f"Failed to fetch article summaries: "
                                f"{summary_result.get('error')}"
                            )
                            limit = arguments.get("limit")
                            try:
                                limit = int(limit) if limit is not None else None
                            except (TypeError, ValueError):
                                limit = None
                            stub_items = [
                                {
                                    "pmid": str(pmid),
                                    "title": None,
                                    "authors": [],
                                    "journal": None,
                                    "pub_date": None,
                                    "pub_year": None,
                                    "doi": None,
                                    "pmcid": None,
                                    "article_type": None,
                                    "url": f"https://pubmed.ncbi.nlm.nih.gov/{pmid}/",
                                    "partial": True,
                                    "warning": summary_result.get(
                                        "error", "Failed to fetch summaries"
                                    ),
                                }
                                for pmid in id_list
                            ]
                            return stub_items[:limit] if limit else stub_items

                        # Return article list directly (not wrapped in dict)
                        articles = summary_result["data"]
                        warning = summary_result.get("warning")
                        if warning:
                            for a in articles:
                                if isinstance(a, dict):
                                    a["partial"] = True
                                    a["warning"] = warning

                        include_abstract = bool(
                            arguments.get("include_abstract", False)
                        )
                        if include_abstract and articles:
                            pmids = [
                                str(a.get("pmid")).strip()
                                for a in articles
                                if isinstance(a, dict) and a.get("pmid")
                            ]
                            abstract_map = self._fetch_abstracts(pmids)
                            if abstract_map:
                                for a in articles:
                                    if not isinstance(a, dict):
                                        continue
                                    pmid = str(a.get("pmid") or "").strip()
                                    if pmid and abstract_map.get(pmid):
                                        a["abstract"] = abstract_map.get(pmid)
                                        a["abstract_source"] = "PubMed"
                            else:
                                for a in articles:
                                    if isinstance(a, dict) and "abstract" not in a:
                                        a["abstract"] = None
                                        a["abstract_source"] = None
                        return articles

                    # Return just IDs for non-search requests (as list)
                    return id_list

                # For elink responses with LinkOut URLs (llinks command)
                if "linksets" in data:
                    linksets = data.get("linksets", [])

                    # Check for empty linksets with errors
                    if not linksets or (linksets and len(linksets) == 0):
                        return {
                            "status": "success",
                            "data": [],
                            "count": 0,
                            "url": response.url,
                        }

                    if linksets and len(linksets) > 0:
                        linkset = linksets[0]

                        # Extract linked IDs
                        if "linksetdbs" in linkset:
                            linksetdbs = linkset.get("linksetdbs", [])
                            if linksetdbs and len(linksetdbs) > 0:
                                links = linksetdbs[0].get("links", [])
                                try:
                                    limit = (
                                        int(arguments.get("limit"))
                                        if arguments.get("limit") is not None
                                        else None
                                    )
                                except (TypeError, ValueError):
                                    limit = None
                                if limit is not None:
                                    links = links[:limit]
                                return {
                                    "status": "success",
                                    "data": links,
                                    "count": len(links),
                                    "url": response.url,
                                }

                        # Extract LinkOut URLs (idurllist)
                        elif "idurllist" in linkset:
                            return {
                                "status": "success",
                                "data": linkset.get("idurllist", {}),
                                "url": response.url,
                            }

                # For elink responses with LinkOut URLs (llinks returns direct idurllist)
                if "idurllist" in data:
                    return {
                        "status": "success",
                        "data": data.get("idurllist", []),
                        "url": response.url,
                    }

                return {
                    "status": "success",
                    "data": data,
                    "url": response.url,
                }

            except Exception:
                # For XML responses (efetch), return as text
                return {
                    "status": "success",
                    "data": response.text,
                    "url": response.url,
                }

        except Exception as e:
            return {
                "status": "error",
                "error": f"PubMed API error: {str(e)}",
                "url": url,
            }
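For reference, a minimal usage sketch. The exact schema of tool_config comes from BaseRESTTool and the tool's JSON config files, so everything below other than the keys this class actually reads (fields["endpoint"], fields["db"], the query/limit/include_abstract arguments) is an assumption, not the documented configuration format.

# Hypothetical usage sketch -- config keys are inferred from how PubMedRESTTool
# reads tool_config["fields"]; BaseRESTTool may require additional keys.
from tooluniverse.pubmed_tool import PubMedRESTTool

config = {
    "fields": {
        # esearch endpoint; _build_params adds retmode=json for esearch/elink
        "endpoint": "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi",
        "db": "pubmed",
    },
}

tool = PubMedRESTTool(config)

# A search request returns a list of article dicts (pmid, title, authors,
# journal, doi, pmcid, url, ...); abstracts are attached when
# include_abstract=True. Set NCBI_API_KEY in the environment to get the
# higher 10 requests/second rate limit.
articles = tool.run(
    {"query": "CRISPR base editing", "limit": 5, "include_abstract": True}
)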