Source code for tooluniverse.pubmed_tool

import os
import time
import threading
import xml.etree.ElementTree as ET
from typing import Any, Dict
from .base_rest_tool import BaseRESTTool
from .http_utils import request_with_retry
from .tool_registry import register_tool


[docs] @register_tool("PubMedRESTTool") class PubMedRESTTool(BaseRESTTool): """Generic REST tool for PubMed E-utilities (efetch, elink). Implements rate limiting per NCBI guidelines: - Without API key: 3 requests/second - With API key: 10 requests/second API key is read from environment variable NCBI_API_KEY. Get your free key at: https://www.ncbi.nlm.nih.gov/account/ """ # Class-level rate limiting (shared across all instances) _last_request_time = 0.0 _rate_limit_lock = threading.Lock()
[docs] def __init__(self, tool_config): super().__init__(tool_config) # Get API key from environment as fallback self.default_api_key = os.environ.get("NCBI_API_KEY", "")
[docs] def _get_param_mapping(self) -> Dict[str, str]: """Map PubMed E-utilities parameter names.""" return { "limit": "retmax", # limit -> retmax for E-utilities }
[docs] def _enforce_rate_limit(self, has_api_key: bool) -> None: """Enforce NCBI E-utilities rate limits. Args: has_api_key: Whether an API key is provided """ # Rate limits per NCBI guidelines # https://www.ncbi.nlm.nih.gov/books/NBK25497/#chapter2.Usage_Guidelines_and_Requiremen # Using conservative intervals to avoid rate limit errors: # - Without API key: 3 req/sec -> 0.4s interval (more conservative than 0.33s) # - With API key: 10 req/sec -> 0.15s interval (more conservative than 0.1s) min_interval = 0.15 if has_api_key else 0.4 with self._rate_limit_lock: current_time = time.time() time_since_last = current_time - PubMedRESTTool._last_request_time if time_since_last < min_interval: sleep_time = min_interval - time_since_last time.sleep(sleep_time) PubMedRESTTool._last_request_time = time.time()
[docs] def _build_params(self, args: Dict[str, Any]) -> Dict[str, Any]: """Build E-utilities parameters with special handling.""" params = {} # Start with default params from config (db, dbfrom, cmd, linkname, retmode, rettype) for key in ["db", "dbfrom", "cmd", "linkname", "retmode", "rettype"]: if key in self.tool_config["fields"]: params[key] = self.tool_config["fields"][key] # Handle PMID as 'id' parameter (for efetch, elink) # PMIDs can be passed as integer or string, coerce to string for API if "pmid" in args: pmid = args["pmid"] if isinstance(pmid, int): pmid = str(pmid) elif not isinstance(pmid, str): raise ValueError( f"pmid must be string or integer, got {type(pmid).__name__}" ) params["id"] = pmid.strip() # Handle query as 'term' parameter (for esearch) if "query" in args: params["term"] = args["query"] # Add API key from environment variable if self.default_api_key: params["api_key"] = self.default_api_key # Handle limit — use `is not None` instead of truthiness so that limit=0 # is honoured (0 is falsy but is a valid, explicit user choice). if "limit" in args and args["limit"] is not None: params["retmax"] = max(0, int(args["limit"])) # Forward date-range parameters for esearch for date_key in ("mindate", "maxdate", "datetype"): if date_key in args and args[date_key]: params[date_key] = args[date_key] # Forward sort parameter for esearch # Valid values: pub_date, Author, JournalName, relevance if "sort" in args and args["sort"]: params["sort"] = args["sort"] # Set retmode to json for elink and esearch (easier parsing) endpoint = self.tool_config["fields"]["endpoint"] if "retmode" not in params and ("elink" in endpoint or "esearch" in endpoint): params["retmode"] = "json" return params
[docs] def _fetch_summaries(self, pmid_list: list) -> Dict[str, Any]: """Fetch article summaries for a list of PMIDs using esummary. Args: pmid_list: List of PubMed IDs Returns: Dict with article summaries or error """ if not pmid_list: return {"status": "success", "data": []} def parse_article(pmid: str, article_data: Dict[str, Any]) -> Dict[str, Any]: # Extract author list authors = [] if "authors" in article_data: authors = [ author.get("name", "") for author in article_data["authors"] ][:5] # Limit to first 5 authors # Extract article info pub_date = article_data.get("pubdate", "") pub_year = pub_date.split()[0] if pub_date else "" elocationid = article_data.get("elocationid", "") # Extract DOI: handle formats like "doi: 10.1234/abc" and mixed # "pii: 2026.02.14.705936. 10.64898/2026.02.14.705936" (preprints). doi = "" if "doi:" in elocationid: # Find the "doi:" token and extract what follows it doi_part = elocationid[elocationid.index("doi:") + 4 :].strip() # Take the first token that contains '/' (valid DOI structure) for token in doi_part.split(): if "/" in token: doi = token.rstrip(".") break doi = doi or None journal = article_data.get( "fulljournalname", article_data.get("source", "") ) journal = journal or None # Check for PMC ID pmcid = "" for aid in article_data.get("articleids", []): if aid.get("idtype") == "pmc": raw_val = aid["value"] # NCBI esummary returns PMC IDs already prefixed (e.g. "PMC12948714"). # Avoid double-prefixing to "PMCPMC12948714". pmcid = raw_val if raw_val.startswith("PMC") else f"PMC{raw_val}" break pmcid = pmcid or None return { "pmid": pmid, "title": article_data.get("title", ""), "authors": authors, "journal": journal, "pub_date": pub_date, "pub_year": pub_year, "doi": doi, "pmcid": pmcid, "article_type": ", ".join(article_data.get("pubtype", [])), "url": f"https://pubmed.ncbi.nlm.nih.gov/{pmid}/", "doi_url": f"https://doi.org/{doi}" if doi else None, "pmc_url": ( f"https://www.ncbi.nlm.nih.gov/pmc/articles/{pmcid}/" if pmcid else None ), } try: has_api_key = bool(self.default_api_key) base = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils" summary_url = f"{base}/esummary.fcgi" articles = [] failed_pmids = [] warnings = [] # Batch fetch first, then isolate failures. chunk_size = 100 for chunk_start in range(0, len(pmid_list), chunk_size): chunk = pmid_list[chunk_start : chunk_start + chunk_size] self._enforce_rate_limit(has_api_key) params = { "db": "pubmed", "id": ",".join(chunk), "retmode": "json", } if self.default_api_key: params["api_key"] = self.default_api_key response = request_with_retry( self.session, "GET", summary_url, params=params, timeout=self.timeout, max_attempts=3, ) if response.status_code != 200: warnings.append( f"Batch summary fetch failed (HTTP {response.status_code}) for {len(chunk)} PMIDs" ) failed_pmids.extend(chunk) continue try: data = response.json() except ValueError: warnings.append( f"Batch summary fetch returned invalid JSON for {len(chunk)} PMIDs" ) failed_pmids.extend(chunk) continue if "error" in data: warnings.append( f"NCBI API error on batch summary fetch: {data.get('error')}" ) failed_pmids.extend(chunk) continue result = data.get("result", {}) for pmid in chunk: article_data = result.get(pmid) if article_data: articles.append(parse_article(pmid, article_data)) else: failed_pmids.append(pmid) # Retry failures one-by-one to isolate transient per-ID issues. retry_failed_pmids = [] for pmid in failed_pmids: self._enforce_rate_limit(has_api_key) params = {"db": "pubmed", "id": pmid, "retmode": "json"} if self.default_api_key: params["api_key"] = self.default_api_key try: response = request_with_retry( self.session, "GET", summary_url, params=params, timeout=self.timeout, max_attempts=2, ) except Exception as error: retry_failed_pmids.append((pmid, str(error))) continue if response.status_code != 200: retry_failed_pmids.append((pmid, f"HTTP {response.status_code}")) continue try: data = response.json() article_data = data.get("result", {}).get(pmid) if article_data: articles.append(parse_article(pmid, article_data)) else: retry_failed_pmids.append((pmid, "summary missing for PMID")) except Exception as error: retry_failed_pmids.append((pmid, str(error))) if retry_failed_pmids: warnings.append( f"Failed to fetch summaries for {len(retry_failed_pmids)} PMIDs after per-ID retry" ) if not articles: error_msg = ( warnings[0] if warnings else "Failed to fetch article summaries" ) return {"status": "error", "error": error_msg} result = {"status": "success", "data": articles} if warnings: result["warning"] = "; ".join(warnings) return result except Exception as e: return { "status": "error", "error": f"Failed to fetch article summaries: {str(e)}", }
[docs] def _parse_efetch_xml(self, response) -> Dict[str, Any]: """Parse PubMed efetch XML into structured article data.""" try: root = ET.fromstring(response.text) except ET.ParseError: return {"status": "success", "data": response.text, "url": response.url} def _text(el, path, default=""): found = el.find(path) if el is not None else None return found.text if found is not None and found.text else default def _itertext(el, path): found = el.find(path) if el is not None else None return "".join(found.itertext()) if found is not None else "" def _parse_article(article_el): cit = article_el.find("MedlineCitation") art = cit.find("Article") if cit is not None else None if cit is None or art is None: return None pmid = _text(cit, "PMID") title = _itertext(art, "ArticleTitle") # Abstract: join labeled sections abstract_parts = [] for at in art.findall("Abstract/AbstractText") or []: label, text = at.get("Label", ""), "".join(at.itertext()).strip() if text: abstract_parts.append(f"{label}: {text}" if label else text) abstract = " ".join(abstract_parts) # Authors (first 10) authors = [] for au in (art.findall("AuthorList/Author") or [])[:10]: last, fore = au.findtext("LastName", ""), au.findtext("ForeName", "") name = f"{last} {fore}".strip() if last else fore if not name: continue entry = {"name": name} aff = _text(au, ".//Affiliation") if aff: entry["affiliation"] = aff authors.append(entry) journal_el = art.find("Journal") journal = _text(journal_el, "Title") or _text(journal_el, "ISOAbbreviation") doi = next( ( eid.text for eid in art.findall("ELocationID") if eid.get("EIdType") == "doi" and eid.text ), "", ) pd = art.find(".//PubDate") pub_year = _text(pd, "Year") pub_date = " ".join( filter(None, [pub_year, _text(pd, "Month"), _text(pd, "Day")]) ) mesh = [ d.text for d in cit.findall("MeshHeadingList/MeshHeading/DescriptorName") if d.text ] pub_types = [pt.text for pt in art.findall(".//PublicationType") if pt.text] return { "pmid": pmid, "title": title, "abstract": abstract or None, "authors": authors, "journal": journal or None, "pub_date": pub_date, "pub_year": pub_year, "doi": doi or None, "doi_url": f"https://doi.org/{doi}" if doi else None, "url": f"https://pubmed.ncbi.nlm.nih.gov/{pmid}/", "mesh_terms": mesh or None, "publication_types": pub_types or None, } articles = [ a for a in (_parse_article(el) for el in root.findall(".//PubmedArticle")) if a ] data = articles[0] if len(articles) == 1 else articles result = {"status": "success", "data": data, "url": response.url} if len(articles) != 1: result["count"] = len(articles) return result
[docs] def _fetch_abstracts(self, pmid_list: list[str]) -> Dict[str, str]: """Best-effort abstract fetch via efetch XML for a list of PMIDs.""" pmids = [str(p).strip() for p in (pmid_list or []) if str(p).strip()] if not pmids: return {} has_api_key = bool(self.default_api_key) base = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils" url = f"{base}/efetch.fcgi" # Batch efetch; keep payload bounded. self._enforce_rate_limit(has_api_key) params: Dict[str, Any] = { "db": "pubmed", "id": ",".join(pmids[:200]), "retmode": "xml", } if self.default_api_key: params["api_key"] = self.default_api_key resp = request_with_retry( self.session, "GET", url, params=params, timeout=self.timeout, max_attempts=3, ) if resp.status_code != 200: return {} try: root = ET.fromstring(resp.text) except ET.ParseError: return {} abstracts: Dict[str, str] = {} for pubmed_article in root.findall(".//PubmedArticle"): pmid_el = pubmed_article.find(".//MedlineCitation/PMID") pmid = (pmid_el.text or "").strip() if pmid_el is not None else "" if not pmid: continue parts: list[str] = [] for at in pubmed_article.findall( ".//MedlineCitation/Article/Abstract/AbstractText" ): # PubMed AbstractText often contains inline tags (e.g. <i/>). # Using itertext() avoids truncating at the first child element. text = " ".join("".join(at.itertext()).split()) if text: parts.append(text) if parts: abstracts[pmid] = "\n".join(parts) return abstracts
[docs] def run(self, arguments: Dict[str, Any]) -> Dict[str, Any]: """ PubMed E-utilities need special handling for direct endpoint URLs. Enforces NCBI rate limits to prevent API errors. """ url = None try: # Enforce rate limiting before making request has_api_key = bool(self.default_api_key) self._enforce_rate_limit(has_api_key) endpoint = self.tool_config["fields"]["endpoint"] params = self._build_params(arguments) response = request_with_retry( self.session, "GET", endpoint, params=params, timeout=self.timeout, max_attempts=3, ) if response.status_code != 200: return { "status": "error", "error": "PubMed API error", "url": response.url, "status_code": response.status_code, "detail": (response.text or "")[:500], } # Try JSON first (elink, esearch) try: data = response.json() # Check for API errors in response if "ERROR" in data: error_msg = data.get("ERROR", "Unknown API error") return { "status": "error", "data": f"NCBI API error: {error_msg[:200]}", "url": response.url, } # For esearch responses, extract ID list and fetch summaries if "esearchresult" in data: esearch_result = data.get("esearchresult", {}) id_list = esearch_result.get("idlist", []) # If this is a search request (has 'query' in arguments), # fetch article summaries and return as list if "query" in arguments and id_list: summary_result = self._fetch_summaries(id_list) if summary_result["status"] == "error": # Preserve stable return type: always return a list of # article-shaped objects. If summaries fail entirely, # return minimal stubs with PMID + URL so downstream # agents can still act on the result. import logging _logger = logging.getLogger(__name__) _logger.warning( f"Failed to fetch article summaries: " f"{summary_result.get('error')}" ) limit = arguments.get("limit") try: limit = int(limit) if limit is not None else None except (TypeError, ValueError): limit = None stub_items = [ { "pmid": str(pmid), "title": None, "authors": [], "journal": None, "pub_date": None, "pub_year": None, "doi": None, "pmcid": None, "article_type": None, "url": f"https://pubmed.ncbi.nlm.nih.gov/{pmid}/", "partial": True, "warning": summary_result.get( "error", "Failed to fetch summaries" ), } for pmid in id_list ] return stub_items[:limit] if limit else stub_items # Return article list directly (not wrapped in dict) articles = summary_result["data"] warning = summary_result.get("warning") if warning: for a in articles: if isinstance(a, dict): a["partial"] = True a["warning"] = warning include_abstract = bool( arguments.get("include_abstract", False) ) if include_abstract and articles: pmids = [ str(a.get("pmid")).strip() for a in articles if isinstance(a, dict) and a.get("pmid") ] abstract_map = self._fetch_abstracts(pmids) if abstract_map: for a in articles: if not isinstance(a, dict): continue pmid = str(a.get("pmid") or "").strip() if pmid and abstract_map.get(pmid): a["abstract"] = abstract_map.get(pmid) a["abstract_source"] = "PubMed" else: for a in articles: if isinstance(a, dict) and "abstract" not in a: a["abstract"] = None a["abstract_source"] = None return articles # Return just IDs for non-search requests (as list) return id_list # For elink responses with LinkOut URLs (llinks command) if "linksets" in data: linksets = data.get("linksets", []) # Check for empty linksets with errors if not linksets or (linksets and len(linksets) == 0): return { "status": "success", "data": [], "count": 0, "url": response.url, } if linksets and len(linksets) > 0: linkset = linksets[0] # Extract linked IDs if "linksetdbs" in linkset: linksetdbs = linkset.get("linksetdbs", []) if linksetdbs and len(linksetdbs) > 0: links = linksetdbs[0].get("links", []) try: limit = ( int(arguments.get("limit")) if arguments.get("limit") is not None else None ) except (TypeError, ValueError): limit = None if limit is not None: links = links[:limit] # Enrich with article metadata pmids = [ str(lk["id"] if isinstance(lk, dict) else lk) for lk in links ] scores = { str(lk["id"]): lk.get("score") for lk in links if isinstance(lk, dict) } summary = self._fetch_summaries(pmids) if summary.get("status") == "success" and summary.get( "data" ): for item in summary["data"]: score = ( scores.get(str(item.get("pmid", ""))) if isinstance(item, dict) else None ) if score is not None: item["relevance_score"] = score links = summary["data"] return { "status": "success", "data": links, "count": len(links), "url": response.url, } # Extract LinkOut URLs (idurllist) elif "idurllist" in linkset: return { "status": "success", "data": linkset.get("idurllist", {}), "url": response.url, } else: # Linkset exists but no linksetdbs or idurllist = no results return { "status": "success", "data": [], "count": 0, "url": response.url, } # For elink responses with LinkOut URLs (llinks returns direct idurllist) if "idurllist" in data: return { "status": "success", "data": data.get("idurllist", []), "url": response.url, } return { "status": "success", "data": data, "url": response.url, } except Exception: # For XML responses (efetch), parse into structured data return self._parse_efetch_xml(response) except Exception as e: return { "status": "error", "error": f"PubMed API error: {str(e)}", "url": url, }