Source code for tooluniverse.esm2_variant_effect_tool

"""Keyless ESM-2 masked-marginal missense variant-effect scoring.

This tool scores a single-amino-acid (missense) protein variant with the
**masked-marginal** log-likelihood ratio of Meier et al. (2021,
"Language models enable zero-shot prediction of the effects of mutations on
protein function", NeurIPS):

    score = log P(mutant | masked context) - log P(wild-type | masked context)

The variant position is replaced with the model's ``<mask>`` token, ESM-2
predicts the residue distribution at that position, and the score is the
log-ratio of the mutant vs. wild-type amino-acid probabilities. A **negative**
score means the model favors the wild-type residue over the mutant — the
variant is disfavored and is a candidate loss-of-function / destabilizing
change; a score near zero or positive means the substitution is tolerated by
the model.

Why this exists alongside the ``ESM_*`` tools
---------------------------------------------
ToolUniverse already exposes richer ESM scoring via EvolutionaryScale's ESMC
API (``ESM_score_sequence``, ``ESM_score_variant_sae_disruption``, …) — prefer
those when you have an ``ESM_API_KEY``. This tool fills a different niche: it
runs **without any API key** over HuggingFace's free ``hf-inference`` provider,
so it works as a zero-setup fallback for a quick single-variant screen.

It composes the generic :class:`HuggingFaceInferenceTool` for the HTTP/fill-mask
plumbing and adds only the masked-marginal method on top, so there is no
duplicated network code. ``run()`` never raises — every path returns a dict
with a ``status`` key.
"""

import math
from typing import Any, Dict, Optional

from .base_tool import BaseTool
from .huggingface_inference_tool import HuggingFaceInferenceTool
from .tool_registry import register_tool

# Default ESM-2 checkpoint. Every ESM-2 size shares one small (~33-token)
# vocabulary, so asking fill-mask for that many candidates returns every
# amino-acid token at the masked position.
_DEFAULT_MODEL = "facebook/esm2_t33_650M_UR50D"
_VOCAB_TOP_K = 33

# Residue budget per request. hf-inference truncates very long inputs, so keep
# 1022 residues; together with <cls> and <eos> that is 1024 tokens. Longer
# sequences are windowed around the variant instead of being truncated.
_MAX_CONTEXT = 1024 - 2

# The 20 standard amino acids -- the only residues a missense call may
# substitute in.
_STANDARD_AA = {*"ACDEFGHIKLMNPQRSTVWY"}


@register_tool("ESM2VariantEffectTool")
class ESM2VariantEffectTool(BaseTool):
    """Score a missense protein variant with ESM-2 masked-marginal LLR (no key).

    Arguments read by :meth:`run`:

    * ``sequence`` (required) -- wild-type protein sequence in 1-letter code;
      whitespace is removed and letters are upper-cased.
    * ``position`` (required) -- 1-based index of the substituted residue.
    * ``mutant`` (required) -- one of the 20 standard amino acids.
    * ``wild_type`` (optional) -- expected residue at ``position``; a mismatch
      with ``sequence`` is reported as an error.
    * ``model_id`` (optional) -- HuggingFace model id, default
      ``_DEFAULT_MODEL``.
    * ``wait_for_model`` (optional) -- passed through ``bool()`` to the HF
      inference call.
    """

    def __init__(self, tool_config: Optional[Dict[str, Any]] = None):
        super().__init__(tool_config)
        self.tool_config = tool_config or {}
        # Compose the generic HF inference tool for the actual fill-mask call.
        self._hf = HuggingFaceInferenceTool({"name": "esm2-variant-hf"})

    # ------------------------------------------------------------------ run

    def run(self, arguments: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """Score one missense variant; never raises.

        Returns ``{"status": "success", "data": ..., "metadata": ...}`` on
        success. Invalid input yields an error dict from ``_err``; a
        non-success response from the HF inference tool is returned unchanged.
        """
        args = arguments or {}

        # ``or ""`` so an explicit ``None`` counts as missing instead of being
        # stringified into the (alphabetic) sequence "NONE".
        sequence = "".join(str(args.get("sequence") or "").split()).upper()
        # Only ASCII letters: str.isalpha() would also accept letters such as
        # "É" that are not 1-letter amino-acid codes.
        if not sequence or not all("A" <= aa <= "Z" for aa in sequence):
            return _err(
                "sequence is required (a wild-type protein sequence in "
                "1-letter amino-acid code)."
            )

        position = self._parse_position(args.get("position"))
        if position is None:
            return _err("position is required and must be a 1-based integer.")
        if not 1 <= position <= len(sequence):
            return _err(
                f"position {position} is out of range for a sequence of "
                f"length {len(sequence)} (use 1-based coordinates)."
            )

        mutant = str(args.get("mutant") or "").strip().upper()
        if mutant not in _STANDARD_AA:
            return _err(
                f"mutant must be one standard amino acid (one of {sorted(_STANDARD_AA)}), "
                f"got {args.get('mutant')!r}."
            )

        wild_type = sequence[position - 1]
        declared_wt = args.get("wild_type")
        if declared_wt:
            # Optional cross-check that catches coordinate / sequence mix-ups.
            declared_wt = str(declared_wt).strip().upper()
            if declared_wt != wild_type:
                return _err(
                    f"wild_type {declared_wt!r} does not match residue "
                    f"{wild_type!r} at position {position} of the supplied "
                    "sequence — check the coordinates / sequence."
                )
        if mutant == wild_type:
            return _err(
                f"mutant equals the wild-type residue ({wild_type} at position "
                f"{position}); a missense variant must change the amino acid."
            )

        # Window long sequences around the variant so the input fits the model.
        window, local_idx, win_start, windowed = self._window(sequence, position)
        # Residues are joined with spaces; the variant residue becomes <mask>.
        masked = " ".join(
            "<mask>" if i == local_idx else aa for i, aa in enumerate(window)
        )
        model_id = args.get("model_id") or _DEFAULT_MODEL
        fill = self._hf.run(
            {
                "operation": "fill_mask",
                "model_id": model_id,
                "text": masked,
                "top_k": _VOCAB_TOP_K,
                "wait_for_model": bool(args.get("wait_for_model", False)),
            }
        )
        if fill.get("status") != "success":
            return fill  # propagate the loading/error dict unchanged

        probs = self._residue_probs(fill)
        p_wt, p_mut = probs.get(wild_type), probs.get(mutant)
        # Wild-type is checked first so the reported residue matches the
        # original precedence. log() needs a strictly positive probability.
        for residue, prob in ((wild_type, p_wt), (mutant, p_mut)):
            if prob is None or prob <= 0:
                return _err(
                    f"Model {model_id} did not return a probability for residue "
                    f"{residue!r}; cannot compute the log-likelihood ratio."
                )

        # Masked-marginal score: log P(mut | ctx) - log P(wt | ctx).
        llr = math.log(p_mut) - math.log(p_wt)
        if llr < 0:
            direction = "mutant disfavored vs wild-type (candidate deleterious)"
        else:
            direction = "mutant tolerated or favored (likely neutral)"
        # 1-based inclusive span of the residues actually sent to the model.
        window_span = [win_start + 1, win_start + len(window)] if windowed else None
        return {
            "status": "success",
            "data": {
                "model_id": model_id,
                "variant": f"{wild_type}{position}{mutant}",
                "position": position,
                "wild_type": wild_type,
                "mutant": mutant,
                "p_wild_type": p_wt,
                "p_mutant": p_mut,
                "log_likelihood_ratio": llr,
                "direction": direction,
            },
            "metadata": {
                "method": "ESM-2 masked-marginal LLR (Meier et al. 2021)",
                "source": "HuggingFace hf-inference (no API key required)",
                "windowed": windowed,
                "window": window_span,
                "note": (
                    "Negative = mutant less likely than wild-type. Magnitude is "
                    "not a calibrated pathogenicity probability; rank variants or "
                    "calibrate against a reference set. For key-based ESMC scoring "
                    "use ESM_score_sequence / ESM_score_variant_sae_disruption."
                ),
            },
        }

    # -------------------------------------------------------------- helpers

    @staticmethod
    def _parse_position(raw: Any) -> Optional[int]:
        """Coerce ``raw`` to an integer position, or return None if it is not one.

        Booleans (an ``int`` subclass) and non-integral floats are rejected:
        ``int(3.7)`` would silently truncate to 3 and score the wrong residue,
        and ``int(float("inf"))`` raises ``OverflowError``, which the
        ``(TypeError, ValueError)`` handler below would not catch.
        """
        if isinstance(raw, bool):
            return None
        if isinstance(raw, float) and not raw.is_integer():
            return None
        try:
            return int(raw)
        except (TypeError, ValueError):
            return None

    @staticmethod
    def _residue_probs(fill: Dict[str, Any]) -> Dict[str, float]:
        """Map each single-character token of a fill-mask result to its score.

        Reads ``fill["data"]["predictions"]``, a list of dicts with
        ``token_str`` and ``score``. Multi-character tokens (e.g. ``<mask>``)
        and entries whose score is not a finite number are skipped. A malformed
        payload yields an empty or partial dict instead of raising; ``run()``
        then reports the missing residue as an error.
        """
        data = fill.get("data")
        predictions = data.get("predictions") if isinstance(data, dict) else None
        if not isinstance(predictions, list):
            predictions = []
        probs: Dict[str, float] = {}
        for pred in predictions:
            if not isinstance(pred, dict):
                continue
            token = str(pred.get("token_str") or "").strip()
            score = pred.get("score")
            if (
                len(token) == 1
                and isinstance(score, (int, float))
                and not isinstance(score, bool)
                and math.isfinite(score)
            ):
                probs[token] = float(score)
        return probs

    @staticmethod
    def _window(sequence: str, position: int):
        """Return (window, local_index, window_start, windowed?) for the model.

        Sequences within the context budget are returned whole. Longer ones are
        clipped to a ``_MAX_CONTEXT`` window centered on the variant so the
        masked position keeps its real sequence neighborhood. ``window_start``
        is the 0-based offset of the window in ``sequence``; ``local_index`` is
        the 0-based index of the variant inside the window.
        """
        if len(sequence) <= _MAX_CONTEXT:
            return sequence, position - 1, 0, False
        half = _MAX_CONTEXT // 2
        start = max(0, (position - 1) - half)
        # Shift the window left near the C-terminus so it stays full-length.
        start = min(start, len(sequence) - _MAX_CONTEXT)
        window = sequence[start : start + _MAX_CONTEXT]
        return window, (position - 1) - start, start, True
def _err(message: str) -> Dict[str, Any]:
    """Build the error payload that ``ESM2VariantEffectTool.run`` returns.

    Args:
        message: Human-readable description of what went wrong.

    Returns:
        A dict with ``status`` set to ``"error"``, the ``error`` message, and a
        ``source`` tag naming this tool.
    """
    payload: Dict[str, Any] = dict(status="error", error=message)
    payload["source"] = "ESM2VariantEffectTool"
    return payload