# Source code for tooluniverse.alphagenome_tool

"""DeepMind AlphaGenome regulatory-genomics prediction tool.

AlphaGenome (Avsec et al., *Nature* 2026) is the hosted successor to Enformer /
Borzoi: a single DNA-sequence model that predicts multimodal genomic tracks
(RNA-seq, CAGE, ATAC, DNase, histone/TF ChIP, splicing, contact maps) over up to
1 Mb at single-base resolution, and scores regulatory variant effects.

Unlike Enformer/Borzoi (local weights), AlphaGenome is a **hosted API**: requests
go over gRPC through the official ``alphagenome`` Python SDK to DeepMind's
servers, so this is integrated as a normal key-gated tool rather than a remote
MCP server. It is free for non-commercial use; obtain a key at
https://deepmind.google.com/science/alphagenome and set ``ALPHA_GENOME_API_KEY``.

Operations (selected via the ``operation`` field):
  * score_variant    -> recommended ref-vs-alt variant-effect scores per track
  * predict_interval -> a compact summary of predicted tracks for an interval

The SDK (``pip install alphagenome``) is an optional dependency; ``run()`` returns
a clear error dict if it or the API key is missing, and never raises.

Reference
---------
Avsec Z, Latysheva N, Cheng J, et al. "Advancing regulatory variant effect
prediction with AlphaGenome." Nature 649, 1206-1218 (2026).
doi:10.1038/s41586-025-10014-0.
"""

import os
from typing import Any, Dict, List, Optional

from .base_tool import BaseTool
from .tool_registry import register_tool

_ORGANISMS = {"human": "HOMO_SAPIENS", "mouse": "MUS_MUSCULUS"}
_SEQ_LENGTHS = {
    "16KB": "SEQUENCE_LENGTH_16KB",
    "100KB": "SEQUENCE_LENGTH_100KB",
    "500KB": "SEQUENCE_LENGTH_500KB",
    "1MB": "SEQUENCE_LENGTH_1MB",
}


@register_tool("AlphaGenomeTool")
class AlphaGenomeTool(BaseTool):
    """Predict genomic tracks / score variant effects via the AlphaGenome API.

    The operation is either fixed by ``tool_config["fields"]["operation"]`` or
    supplied per call in ``arguments["operation"]``:

    * ``score_variant``    -- ref-vs-alt effect scores for a single variant.
    * ``predict_interval`` -- compact per-modality summary for an interval.

    Every public entry point returns a plain dict: ``_ok`` payloads on success,
    ``_err`` payloads on failure. ``run()`` never raises.
    """

    def __init__(self, tool_config: Optional[Dict[str, Any]] = None):
        """Store the tool config and the (optional) pre-configured operation."""
        super().__init__(tool_config)
        self.tool_config = tool_config or {}
        # "fields" may be missing or explicitly None in the config.
        fields = self.tool_config.get("fields", {}) or {}
        self.operation = fields.get("operation", "")

    # ------------------------------------------------------------------ run
    def run(self, arguments: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """Dispatch to the requested operation and return a result dict.

        Args:
            arguments: Call parameters. ``operation`` is read from here only
                when the tool config did not already fix one.

        Returns:
            A ``_ok`` dict on success or a ``_err`` dict on any failure
            (missing SDK / API key, bad parameters, unknown operation, or an
            exception raised while building the client or calling the API).
        """
        args = arguments or {}
        operation = self.operation or args.get("operation")
        try:
            # Client construction sits inside the try block: creating the
            # client can itself fail, and run() must never raise.
            client = self._make_client()
            if isinstance(client, dict):  # error dict from setup
                return client
            model, mods = client
            if operation == "score_variant":
                return self._score_variant(model, mods, args)
            if operation == "predict_interval":
                return self._predict_interval(model, mods, args)
            return self._err(
                f"Unknown operation {operation!r}. Use 'score_variant' or "
                "'predict_interval'."
            )
        except Exception as exc:  # never raise out of run()
            return self._err(f"AlphaGenome request failed: {type(exc).__name__}: {exc}")

    # -------------------------------------------------------------- helpers
    def _make_client(self):
        """Import the SDK, read the key, and build a client — or return an error dict.

        Returns:
            ``(model, (genome, dna_client, variant_scorers))`` on success, or
            an ``_err`` dict when the ``alphagenome`` package cannot be
            imported or ``ALPHA_GENOME_API_KEY`` is unset/empty.
        """
        try:
            # Imported lazily so the SDK stays an optional dependency.
            from alphagenome.data import genome
            from alphagenome.models import dna_client, variant_scorers
        except ImportError:
            return self._err(
                "The 'alphagenome' package is required: pip install alphagenome."
            )
        api_key = os.environ.get("ALPHA_GENOME_API_KEY", "")
        if not api_key:
            return self._err(
                "Set ALPHA_GENOME_API_KEY (free non-commercial key at "
                "https://deepmind.google.com/science/alphagenome)."
            )
        model = dna_client.create(api_key)
        return model, (genome, dna_client, variant_scorers)

    @staticmethod
    def _organism(mods, name: str):
        """Map 'human'/'mouse' (case-insensitive) to ``dna_client.Organism``.

        A missing or unrecognised name falls back to ``HOMO_SAPIENS``.
        """
        _, dna_client, _ = mods
        return getattr(
            dna_client.Organism,
            _ORGANISMS.get((name or "human").lower(), "HOMO_SAPIENS"),
        )

    @staticmethod
    def _seq_length(mods, name: str):
        """Map a label such as '16KB' or '1MB' to the ``dna_client`` constant.

        A missing or unrecognised label falls back to ``SEQUENCE_LENGTH_1MB``.
        """
        _, dna_client, _ = mods
        return getattr(
            dna_client, _SEQ_LENGTHS.get((name or "1MB").upper(), "SEQUENCE_LENGTH_1MB")
        )

    @staticmethod
    def _output_types(mods, names: List[str]):
        """Resolve output-type names to ``dna_client.OutputType`` members.

        Args:
            mods: ``(genome, dna_client, variant_scorers)`` from ``_make_client``.
            names: Output-type names (case-insensitive). A single bare string
                is accepted and treated as a one-element list.

        Returns:
            The matching members in input order. Unknown names are skipped;
            if nothing matches (or ``names`` is empty) ``[RNA_SEQ]`` is used.
        """
        _, dna_client, _ = mods
        if isinstance(names, str):
            # Iterating a bare string would look up each character separately
            # and silently fall back to RNA_SEQ.
            names = [names]
        resolved = []
        for n in names or ["RNA_SEQ"]:
            ot = getattr(dna_client.OutputType, str(n).upper(), None)
            if ot is not None:
                resolved.append(ot)
        return resolved or [dna_client.OutputType.RNA_SEQ]

    # ------------------------------------------------------------- operations
    def _score_variant(self, model, mods, args: Dict[str, Any]) -> Dict[str, Any]:
        """Score one variant with the recommended scorer for ``output_type``.

        Required ``args``: ``chromosome``, ``position``, ``reference_bases``,
        ``alternate_bases``. Optional: ``sequence_length`` (default 1MB),
        ``output_type`` (default RNA_SEQ), ``organism`` (default human),
        ``top_n`` (default 20).

        Returns:
            ``_ok`` dict with the variant label, output type and the top-|score|
            entries, or an ``_err`` dict for missing parameters or an
            unsupported ``output_type``.
        """
        genome, _, variant_scorers = mods
        required = ["chromosome", "position", "reference_bases", "alternate_bases"]
        missing = [k for k in required if not args.get(k)]
        if missing:
            return self._err(f"Missing required parameter(s): {', '.join(missing)}")

        out_type = str(args.get("output_type") or "RNA_SEQ").upper()
        recommended = variant_scorers.RECOMMENDED_VARIANT_SCORERS
        if out_type not in recommended:
            # Report the valid choices instead of surfacing a bare KeyError.
            supported = ", ".join(sorted(str(k) for k in recommended))
            return self._err(
                f"Unsupported output_type {out_type!r} for score_variant. "
                f"Supported: {supported}."
            )
        scorer = recommended[out_type]

        variant = genome.Variant(
            chromosome=str(args["chromosome"]),
            position=int(args["position"]),
            reference_bases=str(args["reference_bases"]),
            alternate_bases=str(args["alternate_bases"]),
        )
        # Window of the requested length derived from the variant's reference
        # interval.
        interval = variant.reference_interval.resize(
            self._seq_length(mods, args.get("sequence_length"))
        )
        scores = model.score_variant(
            interval=interval,
            variant=variant,
            variant_scorers=[scorer],
            organism=self._organism(mods, args.get("organism")),
        )
        top_n = int(args.get("top_n") or 20)
        variant_label = (
            f"{variant.chromosome}:{variant.position}"
            f"{variant.reference_bases}>{variant.alternate_bases}"
        )
        return self._ok(
            {
                "variant": variant_label,
                "output_type": out_type,
                "scores": self._summarize_scores(scores, top_n),
            },
            task="score_variant",
        )

    def _predict_interval(self, model, mods, args: Dict[str, Any]) -> Dict[str, Any]:
        """Predict tracks over an interval and return a compact summary.

        Required ``args``: ``chromosome``, ``start``, ``end`` (``0`` is a valid
        value, so only ``None``/absent counts as missing). Optional:
        ``sequence_length`` (default 1MB), ``output_types`` (default RNA_SEQ),
        ``ontology_terms``, ``organism`` (default human).

        Returns:
            ``_ok`` dict with the resized interval label and per-modality
            summaries, or an ``_err`` dict for missing parameters.
        """
        genome, _, _ = mods
        required = ["chromosome", "start", "end"]
        missing = [k for k in required if args.get(k) is None]
        if missing:
            return self._err(f"Missing required parameter(s): {', '.join(missing)}")
        interval = genome.Interval(
            chromosome=str(args["chromosome"]),
            start=int(args["start"]),
            end=int(args["end"]),
        ).resize(self._seq_length(mods, args.get("sequence_length")))
        output = model.predict_interval(
            interval=interval,
            requested_outputs=self._output_types(mods, args.get("output_types")),
            ontology_terms=args.get("ontology_terms") or None,
            organism=self._organism(mods, args.get("organism")),
        )
        return self._ok(
            {
                "interval": f"{interval.chromosome}:{interval.start}-{interval.end}",
                "tracks": self._summarize_outputs(output),
            },
            task="predict_interval",
        )

    # ------------------------------------------------------------- formatting
    @staticmethod
    def _summarize_scores(scores, top_n: int) -> List[Dict[str, Any]]:
        """Flatten the AnnData score objects to the top |score| per-track entries.

        Args:
            scores: Iterable of objects exposing ``X`` and ``var_names`` (and
                optionally ``obs_names``); ``None`` is treated as empty.
            top_n: Maximum number of entries to return; negative values are
                treated as 0.

        Returns:
            ``{"track", "score"}`` dicts sorted by descending ``abs(score)``.
            When ``X`` has more than one row, each entry also carries ``"obs"``
            (the row's ``obs_names`` label, when available) so rows stay
            distinguishable.
        """
        rows: List[Dict[str, Any]] = []
        for adata in scores or []:
            values = adata.X
            track_names = [str(n) for n in getattr(adata, "var_names", [])]
            obs_names = [str(n) for n in getattr(adata, "obs_names", [])]

            matrix = values.tolist() if hasattr(values, "tolist") else list(values)
            if not isinstance(matrix, list):
                matrix = [matrix]  # 0-d array: tolist() yields a bare scalar
            if matrix and not isinstance(matrix[0], (list, tuple)):
                matrix = [matrix]  # 1-D input: treat as a single row

            # X is (n_obs, n_vars); pair every row with the track names rather
            # than flattening, which would keep only the first row.
            label_rows = len(matrix) > 1
            for idx, row in enumerate(matrix):
                for name, val in zip(track_names, row):
                    entry: Dict[str, Any] = {"track": name, "score": float(val)}
                    if label_rows and idx < len(obs_names):
                        entry["obs"] = obs_names[idx]
                    rows.append(entry)
        rows.sort(key=lambda r: abs(r["score"]), reverse=True)
        return rows[: max(top_n, 0)]

    @staticmethod
    def _summarize_outputs(output) -> List[Dict[str, Any]]:
        """Per requested modality: track count + shape (the raw tensors are huge).

        Modalities whose attribute is absent or ``None`` on ``output`` are
        skipped.
        """
        summary = []
        for attr in (
            "rna_seq",
            "atac",
            "dnase",
            "cage",
            "chip_histone",
            "chip_tf",
            "splice_sites",
            "contact_maps",
            # Also requestable through _output_types; the getattr below makes
            # these harmless if the SDK output lacks them.
            # NOTE(review): attribute names assumed to mirror OutputType — confirm.
            "procap",
            "splice_site_usage",
            "splice_junctions",
        ):
            td = getattr(output, attr, None)
            if td is None:
                continue
            values = getattr(td, "values", None)
            meta = getattr(td, "metadata", None)
            summary.append(
                {
                    "modality": attr,
                    "shape": list(getattr(values, "shape", []) or []),
                    "n_tracks": int(len(meta)) if meta is not None else None,
                }
            )
        return summary

    @staticmethod
    def _ok(data: Any, **meta: Any) -> Dict[str, Any]:
        """Build a success payload; ``meta`` extends/overrides the metadata."""
        m = {"source": "AlphaGenome", "provider": "Google DeepMind (hosted API)"}
        m.update(meta)
        return {"status": "success", "data": data, "metadata": m}

    @staticmethod
    def _err(message: str) -> Dict[str, Any]:
        """Build an error payload carrying ``message``."""
        return {"status": "error", "error": message, "source": "AlphaGenome"}