Source code for tooluniverse.compound_gene_disease_tool

"""Compound tool: gather gene-disease associations from multiple databases in one call.

Queries DisGeNET, OMIM, OpenTargets, GenCC, and ClinVar for a given gene or disease,
cross-references results, and returns a unified comparison table with concordance scores.
"""

from typing import Any, Dict, List

from .base_tool import BaseTool
from .tool_registry import register_tool


[docs] @register_tool("CompoundGeneDiseaseAssociationTool") class CompoundGeneDiseaseAssociationTool(BaseTool): """Query multiple gene-disease databases in a single call.""" SOURCES = ["DisGeNET", "OMIM", "OpenTargets", "GenCC", "ClinVar"]
[docs] def __init__(self, tool_config: Dict[str, Any], **kwargs): super().__init__(tool_config)
[docs] def run(self, arguments: Dict[str, Any]) -> Dict[str, Any]: gene = arguments.get("gene") or arguments.get("gene_symbol") disease = arguments.get("disease") if not gene and not disease: return { "status": "error", "error": "At least one of 'gene' or 'disease' is required.", } from .execute_function import ToolUniverse tu = ToolUniverse() tu.load_tools() results_by_source: Dict[str, Any] = {} sources_failed: List[str] = [] def _try_tool(source_name, tool_name, args): """Call a tool and handle errors gracefully.""" try: r = tu.run_one_function({"name": tool_name, "arguments": args}) if isinstance(r, dict) and r.get("status") == "error": sources_failed.append( f"{source_name}: {r.get('error', 'unknown error')[:100]}" ) return r return r except Exception as e: sources_failed.append(f"{source_name}: {str(e)[:100]}") return {"status": "error", "error": str(e)} # 1. DisGeNET if disease: r = _try_tool( "DisGeNET", "DisGeNET_search_disease", {"disease": disease, "limit": 20} ) else: r = _try_tool( "DisGeNET", "DisGeNET_search_gene", {"gene": gene, "limit": 20} ) results_by_source["DisGeNET"] = self._extract_genes_or_diseases(r, "DisGeNET") # 2. OMIM r = _try_tool("OMIM", "OMIM_search", {"query": disease or gene, "limit": 10}) results_by_source["OMIM"] = self._extract_genes_or_diseases(r, "OMIM") # 3. OpenTargets r = _try_tool( "OpenTargets", "OpenTargets_get_disease_ids_by_name", {"name": disease or gene}, ) results_by_source["OpenTargets"] = self._extract_genes_or_diseases( r, "OpenTargets" ) # 4. GenCC if gene: r = _try_tool("GenCC", "GenCC_search_gene", {"gene_symbol": gene}) else: r = _try_tool("GenCC", "GenCC_search_disease", {"disease": disease}) results_by_source["GenCC"] = self._extract_genes_or_diseases(r, "GenCC") # 5. ClinVar r = _try_tool( "ClinVar", "ClinVar_search_variants", {"query": gene or disease, "limit": 10}, ) results_by_source["ClinVar"] = self._extract_genes_or_diseases(r, "ClinVar") # Build cross-reference table associations = self._build_concordance(results_by_source) return { "status": "success", "data": { "query": {"gene": gene, "disease": disease}, "sources_queried": list(results_by_source.keys()), "sources_failed": sources_failed, "num_associations": len(associations), "associations": associations[:50], "per_source_results": {k: v[:10] for k, v in results_by_source.items()}, }, }
[docs] def _extract_genes_or_diseases( self, result: Any, source: str ) -> List[Dict[str, Any]]: """Extract gene/disease names from a tool result, handling various formats.""" items = [] if not isinstance(result, dict): return items if result.get("status") == "error": return items data = result.get("data", result) # GenCC: data has "submissions" list if isinstance(data, dict) and "submissions" in data: for sub in data["submissions"][:30]: if isinstance(sub, dict): name = sub.get("disease_title") or sub.get("gene_symbol") or "" items.append( { "name": str(name), "score": None, "source": source, "evidence": sub.get("classification", ""), } ) return items # OpenTargets: data has "search.hits" list if isinstance(data, dict) and "search" in data: hits = data["search"].get("hits", []) for hit in hits[:30]: if isinstance(hit, dict): name = hit.get("name", "") items.append({"name": str(name), "score": None, "source": source}) return items # ClinVar: data has "variants" list if isinstance(data, dict) and "variants" in data: for v in data["variants"][:30]: if isinstance(v, dict): genes = v.get("genes", []) for g in genes: items.append( { "name": str(g), "score": None, "source": source, "classification": v.get("clinical_significance", ""), } ) return items # Generic: try common patterns if isinstance(data, dict): data = data.get( "results", data.get("data", data.get("genes", data.get("entries", []))) ) if isinstance(data, list): for item in data[:30]: if isinstance(item, dict): name = ( item.get("gene_symbol") or item.get("geneName") or item.get("gene") or item.get("symbol") or item.get("title") or item.get("name") or item.get("disease") or item.get("diseaseName") or "" ) score = item.get( "score", item.get("gda_score", item.get("associationScore")) ) if name: items.append( {"name": str(name), "score": score, "source": source} ) return items
[docs] def _build_concordance( self, results_by_source: Dict[str, List[Dict[str, Any]]] ) -> List[Dict[str, Any]]: """Build a concordance table showing which entities appear in which sources.""" entity_sources: Dict[str, Dict[str, Any]] = {} for source, items in results_by_source.items(): for item in items: name = item["name"].upper() if name not in entity_sources: entity_sources[name] = { "name": item["name"], "sources": [], "scores": {}, } entity_sources[name]["sources"].append(source) if item.get("score") is not None: entity_sources[name]["scores"][source] = item["score"] associations = [] for entity_data in entity_sources.values(): sources = entity_data["sources"] associations.append( { "name": entity_data["name"], "sources": sorted(set(sources)), "concordance": len(set(sources)), "total_sources_queried": len(results_by_source), "scores": entity_data["scores"], } ) associations.sort(key=lambda x: (-x["concordance"], x["name"])) return associations