Source code for tooluniverse.rxclass_tool

"""
RxClass Tool

Drug classification tools using the NLM RxClass API (part of RxNav):
  - get_drug_classes:    Look up ATC/EPC/MoA/VA drug classes for a drug name or RXCUI
  - get_class_members:   List drugs that belong to a given class ID
  - find_classes:        Search for drug classes by name keyword
  - get_class_hierarchy: Traverse the ATC ancestor chain / class tree for a class
  - get_disease_relations: MED-RT drug<->disease relations (may_treat, may_prevent,
                         CI_with, induces, has_PE), forward (by drug) and reverse
                         (all drugs for a disease class)

API base: https://rxnav.nlm.nih.gov/REST/rxclass
No authentication required. Free public NLM API.
"""

import requests
from typing import Dict, Any
from .base_tool import BaseTool
from .tool_registry import register_tool

# Root of the RxNav REST services; the RxClass endpoints are served beneath it.
RXNORM_BASE = "https://rxnav.nlm.nih.gov/REST"
RXCLASS_BASE = RXNORM_BASE + "/rxclass"

# relaSource codes accepted for the byDrugName endpoint, each mapped to a
# short human-readable description. Key order is significant: the keys are
# echoed back to callers as the list of available sources.
RELA_SOURCES = dict(
    ATC="WHO Anatomical Therapeutic Chemical classification",
    FDASPL="FDA Pharmacologic Class (EPC, MoA, PE)",
    MESH="MeSH pharmacological actions",
    VA="VA Drug Classification",
    DAILYMED="DailyMed drug classification",
)


@register_tool("RxClassTool")
class RxClassTool(BaseTool):
    """
    Drug classification via the NLM RxClass API.

    The operation is fixed per tool instance: it is read once from
    ``tool_config["fields"]["operation"]`` and ``run`` dispatches on it.

    Operations:
      - get_drug_classes:      Get ATC, EPC, MoA, VA drug classes for a drug
      - get_class_members:     List drugs in a specified drug class
      - find_classes:          Search drug classes by keyword
      - get_class_hierarchy:   Traverse the ATC ancestor chain for a class
      - get_disease_relations: MED-RT drug<->disease relations (forward + reverse)

    Every operation returns a dict with a ``"status"`` key (``"success"`` or
    ``"error"``). Successful results carry ``"data"`` and ``"metadata"``;
    errors carry ``"error"`` and, for HTTP-level failures, ``"retryable"``.
    """

    # HTTP status codes for which repeating the request may succeed.
    _RETRYABLE_STATUS = (408, 429, 500, 502, 503, 504)

    def __init__(self, tool_config: Dict[str, Any]):
        """Store the request timeout and the operation this instance performs.

        Args:
            tool_config: Tool configuration. ``timeout`` (seconds) defaults to
                30; ``fields.operation`` defaults to ``"get_drug_classes"``.
        """
        super().__init__(tool_config)
        self.timeout = tool_config.get("timeout", 30)
        self.operation = tool_config.get("fields", {}).get(
            "operation", "get_drug_classes"
        )

    def run(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
        """Execute the configured operation with ``arguments``.

        Returns:
            The operation's result dict, or an error dict when the configured
            operation name is not one of the supported operations.
        """
        handlers = {
            "get_drug_classes": self._get_drug_classes,
            "get_class_members": self._get_class_members,
            "find_classes": self._find_classes,
            "get_class_hierarchy": self._get_class_hierarchy,
            "get_disease_relations": self._get_disease_relations,
        }
        op = self.operation
        # Guard the lookup so a mis-typed (unhashable) config value still
        # yields the "Unknown operation" error instead of a TypeError.
        handler = handlers.get(op) if isinstance(op, str) else None
        if handler is None:
            return {"status": "error", "error": f"Unknown operation: {op}"}
        return handler(arguments)

    # ------------------------------------------------------------------
    # shared helpers
    # ------------------------------------------------------------------
    def _api_get(self, url: str, params: Dict[str, Any]) -> Dict[str, Any]:
        """Shared HTTP GET with consistent RxClass error handling.

        Args:
            url: Full endpoint URL.
            params: Query-string parameters.

        Returns:
            ``{"ok": True, "data": <parsed JSON>}`` on success, otherwise
            ``{"ok": False, "error": ..., "retryable": ...}`` (plus diagnostic
            keys when the body was not JSON).
        """
        try:
            resp = requests.get(url, params=params, timeout=self.timeout)
            resp.raise_for_status()
        except requests.exceptions.Timeout:
            return {"ok": False, "error": "RxClass API timeout", "retryable": True}
        except requests.exceptions.HTTPError as e:
            sc = getattr(e.response, "status_code", None)
            return {
                "ok": False,
                "error": f"RxClass HTTP {sc}",
                "retryable": sc in self._RETRYABLE_STATUS,
            }
        except Exception as e:
            # Includes request-construction errors such as InvalidURL, which
            # are ValueError subclasses raised before any response exists.
            return {"ok": False, "error": str(e), "retryable": False}

        # JSON decoding is handled separately so that this handler can rely
        # on ``resp`` being bound.
        try:
            data = resp.json()
        except ValueError:
            ct = resp.headers.get("content-type", "")
            body = resp.text
            return {
                "ok": False,
                "error": "RxClass returned non-JSON response",
                "content_type": ct,
                "response_snippet": body[:200],
                "retryable": "text/html" in ct or body.lstrip().startswith("<"),
                "suggestion": "RxClass may be under maintenance. Retry in a few minutes.",
            }
        return {"ok": True, "data": data}

    @staticmethod
    def _error_payload(result: Dict[str, Any]) -> Dict[str, Any]:
        """Turn a failed ``_api_get`` result into the public error dict."""
        details = {k: v for k, v in result.items() if k != "ok"}
        return {"status": "error", **details}

    @staticmethod
    def _coerce_limit(value: Any, default: int) -> Any:
        """Normalise a caller-supplied ``limit`` into a safe slice bound.

        ``None`` is returned unchanged (no truncation, as before). Values that
        cannot be converted to ``int`` fall back to ``default``; negative
        values are clamped to 0 so they cannot silently drop trailing results.
        """
        if value is None:
            return None
        try:
            number = int(value)
        except (TypeError, ValueError):
            return default
        return max(number, 0)

    @staticmethod
    def _nested_list(payload: Any, outer_key: str, inner_key: str) -> list:
        """Return ``payload[outer_key][inner_key]`` if it is a list, else ``[]``.

        Tolerates missing keys, explicit nulls and unexpected container types
        at either level.
        """
        if not isinstance(payload, dict):
            return []
        outer = payload.get(outer_key)
        if not isinstance(outer, dict):
            return []
        inner = outer.get(inner_key)
        return inner if isinstance(inner, list) else []

    # ------------------------------------------------------------------
    # operation: get_drug_classes
    # ------------------------------------------------------------------
    def _get_drug_classes(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
        """Look up the drug classes of a drug, by name or by RXCUI.

        Args:
            arguments: ``drug_name`` (alias ``name``) or ``rxcui`` (takes
                precedence); ``rela_source`` (a key of ``RELA_SOURCES`` or
                ``"ALL"``, default ``"ATC"``; unrecognised values fall back to
                ``"ATC"``); ``limit`` (default 20).

        Returns:
            Result dict whose ``data`` is a list of class records, one per
            distinct (classId, drug rxcui) pair, truncated to ``limit``.
        """
        drug_name = arguments.get("drug_name") or arguments.get("name")
        rxcui = arguments.get("rxcui")
        rela_source = arguments.get("rela_source", "ATC")
        limit = self._coerce_limit(arguments.get("limit", 20), 20)

        if not drug_name and not rxcui:
            return {"status": "error", "error": "Provide 'drug_name' or 'rxcui'."}

        if rela_source not in RELA_SOURCES and rela_source != "ALL":
            rela_source = "ATC"

        if rxcui:
            url = f"{RXCLASS_BASE}/class/byRxcui.json"
            params: Dict[str, Any] = {"rxcui": str(rxcui).strip()}
        else:
            url = f"{RXCLASS_BASE}/class/byDrugName.json"
            params = {"drugName": str(drug_name).strip()}
        # "ALL" means: do not restrict the lookup to a single source.
        if rela_source != "ALL":
            params["relaSource"] = rela_source

        result = self._api_get(url, params)
        if not result["ok"]:
            return self._error_payload(result)

        items = self._nested_list(
            result["data"], "rxclassDrugInfoList", "rxclassDrugInfo"
        )
        query_str = rxcui if rxcui else drug_name
        if not items:
            return {
                "status": "success",
                "data": [],
                "metadata": {
                    "query": query_str,
                    "rela_source": rela_source,
                    "count": 0,
                    "note": f"No drug classes found for '{query_str}' in source '{rela_source}'. Try rela_source='ALL' or a different source.",
                },
            }

        classes = []
        seen = set()
        for item in items:
            mc = item.get("rxclassMinConceptItem") or {}
            drug_mc = item.get("minConcept") or {}
            class_id = mc.get("classId", "")
            # De-duplicate on (class, drug concept).
            class_key = (class_id, drug_mc.get("rxcui", ""))
            if class_key in seen:
                continue
            seen.add(class_key)
            classes.append(
                {
                    "classId": class_id,
                    "className": mc.get("className", ""),
                    "classType": mc.get("classType", ""),
                    "rxcui": drug_mc.get("rxcui", ""),
                    "drugName": drug_mc.get("name", ""),
                    "tty": drug_mc.get("tty", ""),
                    "rela": item.get("rela", ""),
                    "relaSource": item.get("relaSource", rela_source),
                }
            )

        classes = classes[:limit]
        return {
            "status": "success",
            "data": classes,
            "metadata": {
                "query": query_str,
                "rela_source": rela_source,
                "count": len(classes),
                "available_sources": list(RELA_SOURCES.keys()),
            },
        }

    # ------------------------------------------------------------------
    # operation: get_class_members
    # ------------------------------------------------------------------
    def _get_class_members(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
        """List the drugs that belong to a class.

        Args:
            arguments: ``class_id`` (alias ``classId``, required);
                ``rela_source`` (default ``"ATC"``); ``ttys`` (default
                ``"IN"``); ``limit`` (default 50).

        Returns:
            Result dict whose ``data`` is a list of
            ``{"rxcui", "name", "tty"}`` records, truncated to ``limit``.
        """
        class_id = arguments.get("class_id") or arguments.get("classId")
        rela_source = arguments.get("rela_source", "ATC")
        ttys = arguments.get("ttys", "IN")
        limit = self._coerce_limit(arguments.get("limit", 50), 50)

        if not class_id:
            return {
                "status": "error",
                "error": "Provide 'class_id' (e.g., 'M01AE', 'N02BA').",
            }

        params: Dict[str, Any] = {
            "classId": str(class_id).strip(),
            "relaSource": rela_source,
            "ttys": ttys,
        }
        result = self._api_get(f"{RXCLASS_BASE}/classMembers.json", params)
        if not result["ok"]:
            return self._error_payload(result)

        members = self._nested_list(result["data"], "drugMemberGroup", "drugMember")
        if not members:
            return {
                "status": "success",
                "data": [],
                "metadata": {
                    "class_id": class_id,
                    "rela_source": rela_source,
                    "count": 0,
                    "note": f"No drug members found for class '{class_id}'. Verify class ID and rela_source.",
                },
            }

        drugs = []
        for m in members[:limit]:
            mc = m.get("minConcept") or {}
            drugs.append(
                {
                    "rxcui": mc.get("rxcui", ""),
                    "name": mc.get("name", ""),
                    "tty": mc.get("tty", ""),
                }
            )

        return {
            "status": "success",
            "data": drugs,
            "metadata": {
                "class_id": class_id,
                "rela_source": rela_source,
                "ttys": ttys,
                "count": len(drugs),
            },
        }

    # ------------------------------------------------------------------
    # operation: find_classes
    # ------------------------------------------------------------------
    def _find_classes(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
        """Search drug classes whose name contains a keyword (case-insensitive).

        Args:
            arguments: ``query`` (alias ``keyword``, required); ``class_type``
                (optional filter forwarded as ``classTypes``); ``limit``
                (default 20).

        Returns:
            Result dict whose ``data`` is a list of
            ``{"classId", "className", "classType"}`` records.
        """
        query = arguments.get("query") or arguments.get("keyword")
        class_type = arguments.get("class_type", "")
        limit = self._coerce_limit(arguments.get("limit", 20), 20)

        if not query:
            return {
                "status": "error",
                "error": "Provide 'query' keyword to search drug classes.",
            }

        # classSearch.json is not available in current RxClass API version.
        # Use allClasses.json and filter client-side by class name keyword.
        params: Dict[str, Any] = {}
        if class_type:
            params["classTypes"] = class_type

        result = self._api_get(f"{RXCLASS_BASE}/allClasses.json", params)
        if not result["ok"]:
            return self._error_payload(result)

        all_classes = self._nested_list(
            result["data"], "rxclassMinConceptList", "rxclassMinConcept"
        )
        kw = str(query).strip().lower()
        matches = [
            {
                "classId": c.get("classId", ""),
                "className": c.get("className", ""),
                "classType": c.get("classType", ""),
            }
            for c in all_classes
            # ``or ""`` keeps a null className from breaking the search.
            if kw in (c.get("className") or "").lower()
        ][:limit]

        return {
            "status": "success",
            "data": matches,
            "metadata": {
                "query": query,
                "class_type": class_type or "all",
                "count": len(matches),
            },
        }

    # ------------------------------------------------------------------
    # operation: get_class_hierarchy
    # ------------------------------------------------------------------
    def _get_class_hierarchy(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
        """Return the ancestor chain / class tree for a class via classGraph.

        Args:
            arguments: ``class_id`` (alias ``classId``, required); ``source``
                (optional, forwarded to the API when non-empty).

        Returns:
            Result dict whose ``data`` holds ``nodes``, ``edges`` and
            ``ancestor_path`` (the requested class first, then each successive
            parent found in the returned edges).
        """
        class_id = arguments.get("class_id") or arguments.get("classId")
        source = arguments.get("source", "")

        if not class_id:
            return {
                "status": "error",
                "error": "Provide 'class_id' (e.g., 'N02BA' for salicylic acid derivatives).",
            }

        params: Dict[str, Any] = {"classId": str(class_id).strip()}
        if source:
            params["source"] = str(source).strip()

        result = self._api_get(f"{RXCLASS_BASE}/classGraph.json", params)
        if not result["ok"]:
            return self._error_payload(result)

        raw_nodes = self._nested_list(
            result["data"], "rxclassGraph", "rxclassMinConceptItem"
        )
        raw_edges = self._nested_list(result["data"], "rxclassGraph", "rxclassEdge")

        if not raw_nodes:
            return {
                "status": "success",
                "data": {"nodes": [], "edges": [], "ancestor_path": []},
                "metadata": {
                    "class_id": class_id,
                    "node_count": 0,
                    "edge_count": 0,
                    "note": f"No class hierarchy found for class '{class_id}'. Verify the class ID.",
                },
            }

        nodes = [
            {
                "classId": n.get("classId", ""),
                "className": n.get("className", ""),
                "classType": n.get("classType", ""),
            }
            for n in raw_nodes
        ]
        edges = [
            {
                "child": e.get("classId1", ""),
                "rela": e.get("rela", ""),
                "parent": e.get("classId2", ""),
            }
            for e in raw_edges
        ]

        # Build an ordered ancestor path from the requested leaf up to the root
        # by following parent links from each child. If a child appears in
        # several edges, the last edge listed determines its parent.
        parent_of = {e["child"]: e["parent"] for e in edges}
        name_of = {n["classId"]: n["className"] for n in nodes}
        type_of = {n["classId"]: n["classType"] for n in nodes}

        ancestor_path = []
        current = str(class_id).strip()
        visited = set()  # protects against cycles in the returned edges
        while current and current not in visited:
            visited.add(current)
            ancestor_path.append(
                {
                    "classId": current,
                    "className": name_of.get(current, ""),
                    "classType": type_of.get(current, ""),
                }
            )
            current = parent_of.get(current)

        return {
            "status": "success",
            "data": {
                "nodes": nodes,
                "edges": edges,
                "ancestor_path": ancestor_path,
            },
            "metadata": {
                "class_id": class_id,
                "source": source or None,
                "node_count": len(nodes),
                "edge_count": len(edges),
                "depth": len(ancestor_path),
            },
        }

    # ------------------------------------------------------------------
    # operation: get_disease_relations
    # ------------------------------------------------------------------
    def _get_disease_relations(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
        """MED-RT drug<->disease relations (may_treat, may_prevent, CI_with, ...).

        Two directions:
          * forward (drug_name or rxcui given): which disease classes a drug
            relates to under MED-RT (via class/byDrugName or class/byRxcui).
          * reverse (class_id given, no drug): which drugs relate to a disease
            class (via classMembers with relaSource=MEDRT).

        Args:
            arguments: ``drug_name`` (alias ``name``) / ``rxcui`` for forward
                lookups, or ``class_id`` (alias ``classId``) for reverse
                lookups; optional ``rela`` (alias ``relas``) relation filter;
                ``ttys`` (default ``"IN"``, reverse direction only).

        Returns:
            Result dict whose ``metadata["direction"]`` is ``"forward"`` or
            ``"reverse"`` and whose ``data`` lists the matching relations.
        """
        drug_name = arguments.get("drug_name") or arguments.get("name")
        rxcui = arguments.get("rxcui")
        class_id = arguments.get("class_id") or arguments.get("classId")
        rela = arguments.get("rela") or arguments.get("relas")
        ttys = arguments.get("ttys", "IN")

        if not (drug_name or rxcui or class_id):
            return {
                "status": "error",
                "error": "Provide 'drug_name'/'rxcui' (forward) or 'class_id' (reverse disease-class lookup).",
            }

        # Reverse: all drugs for a disease class
        if class_id and not (drug_name or rxcui):
            params: Dict[str, Any] = {
                "classId": str(class_id).strip(),
                "relaSource": "MEDRT",
                "ttys": ttys,
            }
            if rela:
                params["rela"] = str(rela).strip()

            result = self._api_get(f"{RXCLASS_BASE}/classMembers.json", params)
            if not result["ok"]:
                return self._error_payload(result)

            members = self._nested_list(
                result["data"], "drugMemberGroup", "drugMember"
            )
            drugs = []
            for m in members:
                mc = m.get("minConcept", {}) or {}
                drugs.append(
                    {
                        "rxcui": mc.get("rxcui", ""),
                        "name": mc.get("name", ""),
                        "tty": mc.get("tty", ""),
                        # Fall back to the requested relation when the member
                        # record does not carry one.
                        "rela": m.get("rela", "") or (rela or ""),
                    }
                )
            return {
                "status": "success",
                "data": drugs,
                "metadata": {
                    "direction": "reverse",
                    "class_id": class_id,
                    "rela_source": "MEDRT",
                    "rela": rela or None,
                    "ttys": ttys,
                    "count": len(drugs),
                },
            }

        # Forward: which disease classes a drug relates to
        if rxcui:
            url = f"{RXCLASS_BASE}/class/byRxcui.json"
            query_label = str(rxcui).strip()
            params = {"rxcui": query_label, "relaSource": "MEDRT"}
        else:
            url = f"{RXCLASS_BASE}/class/byDrugName.json"
            # str() so a non-string drug name cannot raise out of run().
            query_label = str(drug_name).strip()
            params = {"drugName": query_label, "relaSource": "MEDRT"}
        if rela:
            params["relas"] = str(rela).strip()

        result = self._api_get(url, params)
        if not result["ok"]:
            return self._error_payload(result)

        items = self._nested_list(
            result["data"], "rxclassDrugInfoList", "rxclassDrugInfo"
        )
        relations = []
        seen = set()
        for it in items:
            mc = it.get("rxclassMinConceptItem", {}) or {}
            dmc = it.get("minConcept", {}) or {}
            key = (mc.get("classId", ""), it.get("rela", ""), dmc.get("rxcui", ""))
            if key in seen:
                continue
            seen.add(key)
            relations.append(
                {
                    "classId": mc.get("classId", ""),
                    "className": mc.get("className", ""),
                    "classType": mc.get("classType", ""),
                    "rela": it.get("rela", ""),
                    "rxcui": dmc.get("rxcui", ""),
                    "drugName": dmc.get("name", ""),
                }
            )

        return {
            "status": "success",
            "data": relations,
            "metadata": {
                "direction": "forward",
                "query": query_label,
                "rela_source": "MEDRT",
                "rela": rela or None,
                "count": len(relations),
            },
        }