Source code for tooluniverse.compound_gene_disease_tool
"""Compound tool: gather gene-disease associations from multiple databases in one call.
Queries DisGeNET, OMIM, OpenTargets, GenCC, and ClinVar for a given gene or disease,
cross-references results, and returns a unified comparison table with concordance scores.
"""
from typing import Any, Dict, List
from .base_tool import BaseTool
from .tool_registry import register_tool
[docs]
@register_tool("CompoundGeneDiseaseAssociationTool")
class CompoundGeneDiseaseAssociationTool(BaseTool):
"""Query multiple gene-disease databases in a single call."""
SOURCES = ["DisGeNET", "OMIM", "OpenTargets", "GenCC", "ClinVar"]
[docs]
def run(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
    """Query each association source for a gene and/or disease and merge the hits.

    Args:
        arguments: Tool arguments. Reads ``gene`` (falling back to the alias
            ``gene_symbol``) and ``disease``. String values are stripped of
            surrounding whitespace; at least one must be non-blank.

    Returns:
        ``{"status": "error", "error": ...}`` when neither a gene nor a
        disease was supplied. Otherwise ``{"status": "success", "data": ...}``
        where ``data`` holds the echoed query, the list of sources queried,
        one message per failed source, the total number of merged
        associations, the first 50 associations, and the first 10 extracted
        items per source.
    """

    def _clean(value):
        # Blank strings must count as "not provided", otherwise an empty
        # query is sent to every downstream source.
        return value.strip() if isinstance(value, str) else value

    gene = _clean(arguments.get("gene")) or _clean(arguments.get("gene_symbol"))
    disease = _clean(arguments.get("disease"))
    if not gene and not disease:
        return {
            "status": "error",
            "error": "At least one of 'gene' or 'disease' is required.",
        }

    # Imported here rather than at module level — presumably to avoid a
    # circular import with the tool registry; confirm before hoisting.
    from .execute_function import ToolUniverse

    tu = ToolUniverse()
    tu.load_tools()

    results_by_source: Dict[str, Any] = {}
    sources_failed: List[str] = []

    def _try_tool(source_name, tool_name, args):
        """Run one sub-tool; record a short failure message instead of raising."""
        try:
            response = tu.run_one_function({"name": tool_name, "arguments": args})
        except Exception as exc:
            # Deliberate best-effort: one failing source must not abort the
            # whole compound query.
            sources_failed.append(f"{source_name}: {str(exc)[:100]}")
            return {"status": "error", "error": str(exc)}
        if isinstance(response, dict) and response.get("status") == "error":
            # The error payload may be None or a non-string object, so it is
            # coerced to text before truncating.
            message = str(response.get("error") or "unknown error")[:100]
            sources_failed.append(f"{source_name}: {message}")
        return response

    # One (source, tool name, arguments) entry per database, in query order.
    if disease:
        disgenet_call = (
            "DisGeNET_search_disease",
            {"disease": disease, "limit": 20},
        )
    else:
        disgenet_call = ("DisGeNET_search_gene", {"gene": gene, "limit": 20})

    if gene:
        gencc_call = ("GenCC_search_gene", {"gene_symbol": gene})
    else:
        gencc_call = ("GenCC_search_disease", {"disease": disease})

    query_plan = [
        ("DisGeNET",) + disgenet_call,
        ("OMIM", "OMIM_search", {"query": disease or gene, "limit": 10}),
        # NOTE(review): when only a gene is supplied, the gene symbol is sent
        # to a tool whose name suggests a disease-name lookup — confirm that
        # this is intended.
        (
            "OpenTargets",
            "OpenTargets_get_disease_ids_by_name",
            {"name": disease or gene},
        ),
        ("GenCC",) + gencc_call,
        (
            "ClinVar",
            "ClinVar_search_variants",
            {"query": gene or disease, "limit": 10},
        ),
    ]

    for source_name, tool_name, tool_args in query_plan:
        raw = _try_tool(source_name, tool_name, tool_args)
        results_by_source[source_name] = self._extract_genes_or_diseases(
            raw, source_name
        )

    # Build cross-reference table
    associations = self._build_concordance(results_by_source)
    return {
        "status": "success",
        "data": {
            "query": {"gene": gene, "disease": disease},
            "sources_queried": list(results_by_source.keys()),
            "sources_failed": sources_failed,
            "num_associations": len(associations),
            "associations": associations[:50],
            "per_source_results": {
                name: found[:10] for name, found in results_by_source.items()
            },
        },
    }
[docs]
def _extract_genes_or_diseases(
self, result: Any, source: str
) -> List[Dict[str, Any]]:
"""Extract gene/disease names from a tool result, handling various formats."""
items = []
if not isinstance(result, dict):
return items
if result.get("status") == "error":
return items
data = result.get("data", result)
# GenCC: data has "submissions" list
if isinstance(data, dict) and "submissions" in data:
for sub in data["submissions"][:30]:
if isinstance(sub, dict):
name = sub.get("disease_title") or sub.get("gene_symbol") or ""
items.append(
{
"name": str(name),
"score": None,
"source": source,
"evidence": sub.get("classification", ""),
}
)
return items
# OpenTargets: data has "search.hits" list
if isinstance(data, dict) and "search" in data:
hits = data["search"].get("hits", [])
for hit in hits[:30]:
if isinstance(hit, dict):
name = hit.get("name", "")
items.append({"name": str(name), "score": None, "source": source})
return items
# ClinVar: data has "variants" list
if isinstance(data, dict) and "variants" in data:
for v in data["variants"][:30]:
if isinstance(v, dict):
genes = v.get("genes", [])
for g in genes:
items.append(
{
"name": str(g),
"score": None,
"source": source,
"classification": v.get("clinical_significance", ""),
}
)
return items
# Generic: try common patterns
if isinstance(data, dict):
data = data.get(
"results", data.get("data", data.get("genes", data.get("entries", [])))
)
if isinstance(data, list):
for item in data[:30]:
if isinstance(item, dict):
name = (
item.get("gene_symbol")
or item.get("geneName")
or item.get("gene")
or item.get("symbol")
or item.get("title")
or item.get("name")
or item.get("disease")
or item.get("diseaseName")
or ""
)
score = item.get(
"score", item.get("gda_score", item.get("associationScore"))
)
if name:
items.append(
{"name": str(name), "score": score, "source": source}
)
return items
[docs]
def _build_concordance(
self, results_by_source: Dict[str, List[Dict[str, Any]]]
) -> List[Dict[str, Any]]:
"""Build a concordance table showing which entities appear in which sources."""
entity_sources: Dict[str, Dict[str, Any]] = {}
for source, items in results_by_source.items():
for item in items:
name = item["name"].upper()
if name not in entity_sources:
entity_sources[name] = {
"name": item["name"],
"sources": [],
"scores": {},
}
entity_sources[name]["sources"].append(source)
if item.get("score") is not None:
entity_sources[name]["scores"][source] = item["score"]
associations = []
for entity_data in entity_sources.values():
sources = entity_data["sources"]
associations.append(
{
"name": entity_data["name"],
"sources": sorted(set(sources)),
"concordance": len(set(sources)),
"total_sources_queried": len(results_by_source),
"scores": entity_data["scores"],
}
)
associations.sort(key=lambda x: (-x["concordance"], x["name"]))
return associations