tooluniverse.esm_tool 源代码
"""
ESM3 / ESMC Tool
Provides access to EvolutionaryScale protein language models:
- ESMC (300m/600m): fast protein sequence embeddings
- ESM3 (open/small): sequence generation, structure prediction, sequence scoring
Requires an API token from https://forge.evolutionaryscale.ai
Set via environment variable ESM_API_KEY.
Install: pip install esm
"""
import os
from typing import Dict, Any, List, Optional
from .base_tool import BaseTool
from .tool_registry import register_tool
def _get_client(model: str):
"""Return an ESM3ForgeInferenceClient for the given model, using ESM_API_KEY."""
try:
from esm.sdk.forge import ESM3ForgeInferenceClient
except ImportError:
raise ImportError("esm package is required. Install with: pip install esm")
token = os.environ.get("ESM_API_KEY", "")
if not token:
raise EnvironmentError(
"ESM_API_KEY environment variable is not set. "
"Obtain a token at https://forge.evolutionaryscale.ai"
)
return ESM3ForgeInferenceClient(model=model, token=token)
def _get_esmc_client(model: str):
"""Return an ESMCForgeInferenceClient (needed for SAE inference).
ESMC SAE features require the dedicated ESMCForgeInferenceClient (not
ESM3ForgeInferenceClient) because SAE outputs are exposed through ESMC's
logits endpoint via LogitsConfig(sae_config=...).
"""
try:
from esm.sdk.forge import ESMCForgeInferenceClient
except ImportError:
raise ImportError(
"esm package with SAE support is required. The current PyPI "
"release (esm 3.2.x) does NOT include SAEConfig — SAE features "
"live on an unmerged feature branch. Install from there:\n"
" pip install 'esm @ git+https://github.com/evolutionaryscale/esm@ee891c52'"
)
token = os.environ.get("ESM_API_KEY", "")
if not token:
raise EnvironmentError(
"ESM_API_KEY environment variable is not set. "
"Obtain a token at https://forge.evolutionaryscale.ai"
)
return ESMCForgeInferenceClient(model=model, token=token)
[文档]
@register_tool("ESMTool")
class ESMTool(BaseTool):
"""
ESM3 / ESMC tool for protein sequence embeddings, generation,
structure prediction, and sequence scoring via the EvolutionaryScale Forge API.
"""
[文档]
def __init__(self, tool_config):
    """Initialise the tool and bind it to one operation when configured.

    Args:
        tool_config: Tool configuration mapping. ``fields.operation``, when
            present, pins this instance to a single operation.
    """
    super().__init__(tool_config)
    # Bind this instance to one operation via fields.operation, matching
    # the AlphaMissense / MaveDB multi-op pattern. Callers no longer need
    # to pass operation= explicitly via tu.tools.X() / tu.run_one_function().
    # Falls back to arguments["operation"] when fields.operation isn't set,
    # for backward compatibility with tests / direct instantiation.
    # ``or`` also covers an explicit null for "fields" or "operation", which
    # would otherwise raise AttributeError or leave None in self.operation.
    fields = tool_config.get("fields") or {}
    self.operation = fields.get("operation") or ""
[文档]
def run(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
"""Execute the tool with given arguments."""
try:
operation = self.operation or arguments.get("operation", "")
if operation == "get_protein_embedding":
return self._get_protein_embedding(arguments)
elif operation == "generate_protein_sequence":
return self._generate_protein_sequence(arguments)
elif operation == "fold_protein":
return self._fold_protein(arguments)
elif operation == "score_sequence":
return self._score_sequence(arguments)
elif operation == "get_sae_features":
return self._get_sae_features(arguments)
elif operation == "score_variant_sae_disruption":
return self._score_variant_sae_disruption(arguments)
elif operation == "describe_sae_feature":
return self._describe_sae_feature(arguments)
elif operation == "score_variant_sae_batch":
return self._score_variant_sae_batch(arguments)
elif operation == "get_region_sae_features":
return self._get_region_sae_features(arguments)
elif operation == "explain_variant_mechanism":
return self._explain_variant_mechanism(arguments)
else:
return {
"status": "error",
"error": f"Unknown operation: {operation!r}. Valid operations: "
"get_protein_embedding, generate_protein_sequence, "
"fold_protein, score_sequence, get_sae_features, "
"score_variant_sae_disruption, describe_sae_feature, "
"score_variant_sae_batch, get_region_sae_features, "
"explain_variant_mechanism",
}
except ImportError as e:
return {"status": "error", "error": str(e)}
except EnvironmentError as e:
return {"status": "error", "error": str(e)}
except Exception as e:
return {"status": "error", "error": str(e)}
# ------------------------------------------------------------------ #
# get_protein_embedding
# ------------------------------------------------------------------ #
[文档]
def _get_protein_embedding(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
"""Get ESMC per-residue and mean-pooled embeddings for a protein sequence."""
try:
from esm.sdk.api import ESMProtein, LogitsConfig
sequence = arguments.get("sequence")
if not sequence:
return {
"status": "error",
"error": "sequence is required for get_protein_embedding",
}
model = arguments.get("model", "esmc-300m-2024-12")
return_per_residue = arguments.get("return_per_residue", False)
client = _get_client(model)
protein = ESMProtein(sequence=sequence)
logits_output = client.logits(
protein, LogitsConfig(sequence=True, return_embeddings=True)
)
if logits_output.embeddings is None:
return {
"status": "error",
"error": "Model did not return embeddings. Ensure LogitsConfig(return_embeddings=True) is supported.",
}
embeddings = logits_output.embeddings # shape: (L+2, D) including BOS/EOS
# mean pool over residue positions (exclude BOS/EOS tokens)
import numpy as np
emb_np = (
embeddings.detach().cpu().numpy()
if hasattr(embeddings, "detach")
else embeddings
)
mean_emb = emb_np[1:-1].mean(axis=0).tolist()
result = {
"status": "success",
"model": model,
"sequence_length": len(sequence),
"embedding_dim": len(mean_emb),
"mean_embedding": mean_emb,
}
if return_per_residue:
result["per_residue_embeddings"] = emb_np[1:-1].tolist()
return result
except Exception as e:
return {"status": "error", "error": str(e)}
# ------------------------------------------------------------------ #
# generate_protein_sequence
# ------------------------------------------------------------------ #
[文档]
def _generate_protein_sequence(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
"""Generate or complete a protein sequence using ESM3."""
try:
from esm.sdk.api import ESMProtein, GenerationConfig
prompt_sequence = arguments.get("prompt_sequence")
if not prompt_sequence:
return {
"status": "error",
"error": "prompt_sequence is required. Use '_' characters to denote masked positions to generate.",
}
model = arguments.get("model", "esm3-open-2024-03")
num_steps = int(arguments.get("num_steps", 8))
temperature = float(arguments.get("temperature", 1.0))
client = _get_client(model)
protein = ESMProtein(sequence=prompt_sequence)
config = GenerationConfig(
track="sequence",
num_steps=num_steps,
temperature=temperature,
)
result_protein = client.generate(protein, config)
generated_seq = result_protein.sequence if result_protein.sequence else ""
return {
"status": "success",
"model": model,
"prompt_sequence": prompt_sequence,
"generated_sequence": generated_seq,
"sequence_length": len(generated_seq),
"num_steps": num_steps,
"temperature": temperature,
}
except Exception as e:
return {"status": "error", "error": str(e)}
# ------------------------------------------------------------------ #
# fold_protein
# ------------------------------------------------------------------ #
[文档]
def _fold_protein(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
"""Predict protein structure with ESM3, returning pTM score and coordinates."""
try:
from esm.sdk.api import ESMProtein, GenerationConfig
sequence = arguments.get("sequence")
if not sequence:
return {
"status": "error",
"error": "sequence is required for fold_protein",
}
model = arguments.get("model", "esm3-open-2024-03")
num_steps = int(arguments.get("num_steps", 8))
client = _get_client(model)
protein = ESMProtein(sequence=sequence)
config = GenerationConfig(track="structure", num_steps=num_steps)
result_protein = client.generate(protein, config)
coordinates = None
if result_protein.coordinates is not None:
coords = result_protein.coordinates
if hasattr(coords, "tolist"):
coordinates = coords.tolist()
else:
coordinates = coords
ptm = None
if hasattr(result_protein, "ptm") and result_protein.ptm is not None:
ptm = float(result_protein.ptm)
plddt = None
if hasattr(result_protein, "plddt") and result_protein.plddt is not None:
p = result_protein.plddt
if hasattr(p, "tolist"):
plddt = p.tolist()
else:
plddt = p
return {
"status": "success",
"model": model,
"sequence": sequence,
"sequence_length": len(sequence),
"pTM_score": ptm,
"plddt_per_residue": plddt,
"coordinates_shape": (
[len(coordinates), len(coordinates[0]), len(coordinates[0][0])]
if coordinates is not None
else None
),
"num_steps": num_steps,
"note": "Coordinates are (L, 37, 3) backbone atom positions in Angstroms.",
}
except Exception as e:
return {"status": "error", "error": str(e)}
# ------------------------------------------------------------------ #
# score_sequence
# ------------------------------------------------------------------ #
[文档]
def _score_sequence(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
"""Score a protein sequence using ESMC logits (per-residue log-probabilities)."""
try:
from esm.sdk.api import ESMProtein, LogitsConfig
import math
sequence = arguments.get("sequence")
if not sequence:
return {
"status": "error",
"error": "sequence is required for score_sequence",
}
model = arguments.get("model", "esmc-300m-2024-12")
client = _get_client(model)
protein = ESMProtein(sequence=sequence)
logits_output = client.logits(
protein, LogitsConfig(sequence=True, return_embeddings=False)
)
if logits_output.logits is None or logits_output.logits.sequence is None:
return {
"status": "error",
"error": "Model did not return sequence logits.",
}
# Compute mean log-prob (pseudo-likelihood) per residue
import torch
import torch.nn.functional as F
seq_logits = logits_output.logits.sequence # (L+2, vocab)
log_probs = F.log_softmax(seq_logits, dim=-1)
# ESM tokenizer: map each residue to its token id
try:
from esm.utils.constants.esm3 import SEQUENCE_VOCAB
aa_to_idx = {aa: i for i, aa in enumerate(SEQUENCE_VOCAB)}
except Exception:
aa_to_idx = {}
per_residue_logprobs = []
if aa_to_idx:
for i, aa in enumerate(sequence):
token_id = aa_to_idx.get(aa)
if token_id is not None:
lp = log_probs[i + 1, token_id].item()
per_residue_logprobs.append(lp)
mean_logprob = (
sum(per_residue_logprobs) / len(per_residue_logprobs)
if per_residue_logprobs
else None
)
return {
"status": "success",
"model": model,
"sequence": sequence,
"sequence_length": len(sequence),
"mean_log_probability": mean_logprob,
"per_residue_log_probabilities": per_residue_logprobs,
}
except Exception as e:
return {"status": "error", "error": str(e)}
# ------------------------------------------------------------------ #
# get_sae_features
# ------------------------------------------------------------------ #
[文档]
def _get_sae_features(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
"""Run a protein sequence through an ESMC Sparse Autoencoder (SAE) and
return sparse feature activations per residue.
ESMC SAEs decompose the model's hidden states into a 16,384-feature
sparse codebook with top-k=64 sparsity per residue. Each feature is an
interpretable latent dimension (catalytic site, binding region, PTM
sequon, etc., once labelled separately via ESM_describe_sae_feature).
License note: SAE outputs are governed by the EvolutionaryScale
Cambrian Inference Clickthrough License Agreement — non-commercial use
only unless covered by a separate commercial agreement.
"""
# Validate input arguments first — fail fast with clear errors before
# checking environment / imports, so a user calling with bad args sees
# the input problem (not "install esm").
sequence = arguments.get("sequence")
if not sequence:
return {
"status": "error",
"error": "sequence is required for get_sae_features",
}
model = arguments.get("model", "esmc-6b-2024-12")
sae_model = arguments.get(
"sae_model", "esmc-6b-2024-12_k64_codebook16384_layer60"
)
position = arguments.get("position") # 1-indexed, optional
window = arguments.get("window", 8)
top_k_per_residue = arguments.get("top_k_per_residue", 64)
seq_len = len(sequence)
# Sequence length cap — ESMC-6B Forge handles up to ~2,700 AA in
# practice (per EvolutionaryScale docs). Longer sequences
# fail with opaque server errors; catch up front with a clear message.
MAX_SEQ_LEN = 2700
if seq_len > MAX_SEQ_LEN:
return {
"status": "error",
"error": (
f"sequence length {seq_len} exceeds practical Forge SAE "
f"limit (~{MAX_SEQ_LEN} AA). Either truncate to a region "
f"of interest or split the protein and run separately."
),
}
if position is not None and (position < 1 or position > seq_len):
return {
"status": "error",
"error": (
f"position {position} out of range [1, {seq_len}] "
f"for sequence of length {seq_len}"
),
}
# Now check for SDK + API key (env-side problems)
try:
from esm.sdk.api import ESMProtein, SAEConfig, LogitsConfig
except ImportError:
return {
"status": "error",
"error": (
"esm package with SAE support is required. The PyPI "
"release does NOT include SAEConfig. Install from the "
"feature branch: pip install 'esm @ "
"git+https://github.com/evolutionaryscale/esm@ee891c52'"
),
}
try:
client = _get_esmc_client(model)
except (ImportError, EnvironmentError) as e:
return {"status": "error", "error": str(e)}
try:
protein = ESMProtein(sequence=sequence)
protein_tensor = client.encode(protein)
output = client.logits(
protein_tensor,
config=LogitsConfig(
sae_config=SAEConfig(model=sae_model, normalize_features=True)
),
)
except Exception as e:
return {
"status": "error",
"error": f"Forge SAE inference failed: {type(e).__name__}: {e}",
}
sae_outputs = output.sae_outputs
if not sae_outputs or sae_model not in sae_outputs:
return {
"status": "error",
"error": (
f"Forge response did not include sae_outputs for "
f"{sae_model}. Available: {list(sae_outputs.keys()) if sae_outputs else None}"
),
}
sae_tensor = sae_outputs[sae_model]
# sae_tensor is torch.sparse_coo_tensor with shape (L+2, 16384):
# row 0 = BOS, row L+1 = EOS, rows 1..L correspond to residues 1..L
try:
indices = sae_tensor.coalesce().indices() # shape (2, nnz)
values = sae_tensor.coalesce().values() # shape (nnz,)
except Exception:
# Fallback if tensor is dense
indices = sae_tensor.nonzero(as_tuple=False).t()
values = sae_tensor[indices[0], indices[1]]
rows = indices[0].tolist()
cols = indices[1].tolist()
vals = values.tolist()
# Determine residue rows to return (1-indexed positions → tensor row idx)
if position is not None:
lo_1idx = max(1, position - window)
hi_1idx = min(seq_len, position + window)
wanted_rows = set(range(lo_1idx, hi_1idx + 1))
else:
wanted_rows = set(range(1, seq_len + 1))
# Bucket non-zero entries by residue row (skip BOS row 0 and EOS row L+1)
per_residue: Dict[int, List[tuple]] = {}
for r, c, v in zip(rows, cols, vals):
if r in wanted_rows:
per_residue.setdefault(r, []).append((c, v))
# Sort each residue's features by |activation| descending, take top-K
residues_out: List[Dict[str, Any]] = []
for r in sorted(per_residue.keys()):
feats = sorted(per_residue[r], key=lambda x: -abs(x[1]))[:top_k_per_residue]
residues_out.append(
{
"residue_idx_1based": r,
"active_features": [
{"feature_id": int(c), "activation": float(v)} for c, v in feats
],
}
)
return {
"status": "success",
"data": {
"sequence_length": seq_len,
"model": model,
"sae_model": sae_model,
"residues_returned": len(residues_out),
"position": position,
"window": window if position is not None else None,
"top_k_per_residue": top_k_per_residue,
"activations": residues_out,
},
"metadata": {
"total_features_in_codebook": 16384,
"sparsity_k": 64,
"license": (
"Outputs are governed by EvolutionaryScale Cambrian "
"Inference License — non-commercial use only"
),
},
}
# ------------------------------------------------------------------ #
# score_variant_sae_disruption
# ------------------------------------------------------------------ #
[文档]
def _score_variant_sae_disruption(
self, arguments: Dict[str, Any]
) -> Dict[str, Any]:
"""Composite tool: compare SAE features for reference vs mutant sequence
and return per-feature delta scores ranked by absolute magnitude.
This is the convenience layer that variant-interpretation skills use to
avoid two manual ESM_get_sae_features calls plus manual delta
computation. It builds the mutant sequence, runs both ref + mut through
the SAE, sums activations over a residue window, and ranks features by
gain or loss.
For each ranked feature, the response includes the ref/mut activation
sums so the agent can sanity-check whether a delta reflects a strong
feature flipping off versus a weak feature appearing.
"""
sequence = arguments.get("sequence")
position = arguments.get("position")
ref_aa = arguments.get("ref_aa")
alt_aa = arguments.get("alt_aa")
window = arguments.get("window", 8)
top_k_features = arguments.get("top_k_features", 10)
# Validate inputs
if not sequence:
return {
"status": "error",
"error": "sequence is required for score_variant_sae_disruption",
}
if position is None or not isinstance(position, int):
return {
"status": "error",
"error": "position (1-indexed int) is required",
}
if not ref_aa or len(ref_aa) != 1:
return {
"status": "error",
"error": "ref_aa must be a single amino-acid letter",
}
if not alt_aa or len(alt_aa) != 1:
return {
"status": "error",
"error": "alt_aa must be a single amino-acid letter",
}
seq_len = len(sequence)
if position < 1 or position > seq_len:
return {
"status": "error",
"error": (
f"position {position} out of range [1, {seq_len}] "
f"for sequence of length {seq_len}"
),
}
actual_aa = sequence[position - 1]
if actual_aa != ref_aa:
return {
"status": "error",
"error": (
f"ref_aa mismatch: position {position} in sequence is "
f"{actual_aa!r}, but ref_aa was given as {ref_aa!r}. "
f"Double-check the sequence is the canonical reference for "
f"this variant."
),
}
# Build mutant
mutant = sequence[: position - 1] + alt_aa + sequence[position:]
# Get SAE features for ref + mut (reuse the existing operation)
common_args = {
"operation": "get_sae_features",
"position": position,
"window": window,
"top_k_per_residue": 64,
"model": arguments.get("model", "esmc-6b-2024-12"),
"sae_model": arguments.get(
"sae_model", "esmc-6b-2024-12_k64_codebook16384_layer60"
),
}
ref_result = self._get_sae_features({**common_args, "sequence": sequence})
if ref_result["status"] != "success":
return {**ref_result, "stage": "ref_sae"}
mut_result = self._get_sae_features({**common_args, "sequence": mutant})
if mut_result["status"] != "success":
return {**mut_result, "stage": "mut_sae"}
# Aggregate feature activations by feature_id (summed over window)
def sum_features(activations_list):
sums: Dict[int, float] = {}
for r in activations_list:
for f in r["active_features"]:
sums[f["feature_id"]] = (
sums.get(f["feature_id"], 0.0) + f["activation"]
)
return sums
ref_sums = sum_features(ref_result["data"]["activations"])
mut_sums = sum_features(mut_result["data"]["activations"])
all_feats = set(ref_sums) | set(mut_sums)
deltas = [(f, mut_sums.get(f, 0.0) - ref_sums.get(f, 0.0)) for f in all_feats]
top_lost = sorted(deltas, key=lambda x: x[1])[:top_k_features]
top_gained = sorted(deltas, key=lambda x: -x[1])[:top_k_features]
def feat_row(fid: int, delta: float) -> Dict[str, Any]:
return {
"feature_id": int(fid),
"delta": float(delta),
"ref_activation_sum": float(ref_sums.get(fid, 0.0)),
"mut_activation_sum": float(mut_sums.get(fid, 0.0)),
}
return {
"status": "success",
"data": {
"variant": f"{ref_aa}{position}{alt_aa}",
"position": position,
"window": window,
"n_unique_features_touched": len(all_feats),
"top_features_lost": [feat_row(f, d) for f, d in top_lost],
"top_features_gained": [feat_row(f, d) for f, d in top_gained],
},
"metadata": {
"method": (
"ESMC-6B SAE per-feature activation delta, summed over "
f"+/-{window} residue window centered on the mutation site"
),
"ref_residues_analyzed": ref_result["data"]["residues_returned"],
"mut_residues_analyzed": mut_result["data"]["residues_returned"],
"forge_calls_made": 2,
"license": (
"Outputs are governed by EvolutionaryScale Cambrian "
"Inference License — non-commercial use only"
),
},
}
# ------------------------------------------------------------------ #
# describe_sae_feature — on-demand SAE feature labeling
# ------------------------------------------------------------------ #
# Curated panel of well-annotated human proteins, listed as UniProt
# accessions. _describe_sae_feature labels an SAE feature by running panel
# proteins through the SAE and tallying which UniProt feature types sit
# under the feature's strongest activations.
# Order matters: _describe_sae_feature takes the first n_proteins entries.
# Selected for category diversity: transcription factor, receptor tyrosine
# kinase, small GTPase, serine protease, P450 enzyme, processed hormone,
# oxygen carrier, fibrinolytic protease, transport protein, and
# serine/threonine kinase.
_SAE_LABELING_PANEL = [
"P04637", # TP53 — tumor suppressor / DNA binding
"P00533", # EGFR — receptor tyrosine kinase
"P01116", # KRAS — small GTPase
"P00734", # F2 thrombin — serine protease, catalytic triad
"P08684", # CYP3A4 — cytochrome P450, heme binding
"P01308", # INS — insulin (signal + disulfide + processing)
"P68871", # HBB — hemoglobin beta, oxygen / heme binding
"P00750", # PLAT/tPA — fibrinolytic protease
"P02768", # ALB — serum albumin, ligand binding
"P31749", # AKT1 — serine/threonine kinase
]
# Map raw UniProt feature "type" strings to the high-level categories that
# _describe_sae_feature votes over. A value of None marks a type that is
# dropped from the vote (too generic or uninformative for variant
# interpretation). Types that do not appear in this mapping are dropped as
# well, because the lookup uses .get() and treats a missing key like None.
_UNIPROT_TYPE_TO_CATEGORY = {
"Active site": "catalytic",
"Site": "catalytic",
"Binding site": "ligand-binding",
"Metal binding": "ligand-binding",
"DNA binding": "ligand-binding",
"Modified residue": "ptm",
"Cross-link": "ptm",
"Glycosylation": "ptm",
"Lipidation": "ptm",
"Domain": "domain",
"Motif": "motif",
"Repeat": "domain",
"Zinc finger": "domain",
"Disulfide bond": "structural-stability",
"Helix": "secondary-structure",
"Beta strand": "secondary-structure",
"Turn": "secondary-structure",
"Transmembrane": "transmembrane",
"Intramembrane": "transmembrane",
"Signal": "signal-peptide",
"Propeptide": "propeptide",
"Coiled coil": "structural-stability",
"Region": None,
"Compositional bias": None,
"Natural variant": None,
"Mutagenesis": None,
"Alternative sequence": None,
"Chain": None,
"Peptide": None,
"Initiator methionine": None,
}
[文档]
def _fetch_uniprot_entry(self, accession: str) -> Optional[Dict[str, Any]]:
    """Fetch the full UniProt entry (sequence + features) for an accession.

    Args:
        accession: UniProt accession, e.g. ``"P04637"``.

    Returns:
        The decoded JSON document, or ``None`` when the request fails or the
        body cannot be decoded/parsed — the caller then skips this protein.
    """
    import http.client
    import json as _json
    import urllib.parse
    import urllib.request

    try:
        # Quote the accession so stray characters cannot alter the URL path.
        url = (
            "https://rest.uniprot.org/uniprotkb/"
            f"{urllib.parse.quote(str(accession), safe='')}.json"
        )
        req = urllib.request.Request(
            url, headers={"User-Agent": "tooluniverse/esm_tool"}
        )
        with urllib.request.urlopen(req, timeout=30) as resp:
            return _json.loads(resp.read().decode("utf-8"))
    except (OSError, http.client.HTTPException, ValueError):
        # OSError covers URLError/HTTPError plus socket timeouts and
        # connection resets raised while reading the body; HTTPException
        # covers truncated or malformed responses; ValueError covers bad
        # URLs, undecodable bytes and invalid JSON.
        return None
[文档]
def _uniprot_features_at_position(
self, features: List[Dict[str, Any]], position_1idx: int
) -> List[Dict[str, Any]]:
"""Return UniProt features whose annotated position/range covers
the 1-indexed residue position."""
hits = []
for f in features:
loc = f.get("location", {})
start = loc.get("start", {}).get("value")
end = loc.get("end", {}).get("value", start)
if start is None or end is None:
continue
if start <= position_1idx <= end:
hits.append(f)
return hits
[文档]
def _cache_path_for_feature(self, sae_model: str, feature_id: int):
"""Per-feature label cache path under ~/.cache/tooluniverse/."""
from pathlib import Path
import re
safe_model = re.sub(r"[^A-Za-z0-9._-]", "_", sae_model)
cache_dir = Path.home() / ".cache" / "tooluniverse" / "sae_labels" / safe_model
return cache_dir / f"feature_{feature_id}.json"
[文档]
def _describe_sae_feature(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
"""On-demand labeling for a single SAE feature_id.
Runs SAE inference across a curated panel of well-annotated human
proteins, finds where the target feature activates most strongly,
and aggregates the UniProt feature types at those positions to
infer a biological category. Caches results under
~/.cache/tooluniverse/sae_labels/{sae_model}/feature_{id}.json.
Cost: 1 Forge credit per protein in the panel (default 10, so 10
credits per first-time label). UniProt fetches are free.
Subsequent calls for the same feature_id hit the cache (free).
Categories returned: catalytic, ligand-binding, ptm, domain,
motif, structural-stability, secondary-structure, transmembrane,
signal-peptide, propeptide, or 'uncategorized'.
"""
feature_id = arguments.get("feature_id")
sae_model = arguments.get(
"sae_model", "esmc-6b-2024-12_k64_codebook16384_layer60"
)
model = arguments.get("model", "esmc-6b-2024-12")
n_proteins = arguments.get("n_proteins", 10)
top_residues_per_protein = arguments.get("top_residues_per_protein", 3)
use_cache = arguments.get("use_cache", True)
# Validate input
if feature_id is None or not isinstance(feature_id, int):
return {
"status": "error",
"error": "feature_id (int 0-16383) is required",
}
if feature_id < 0 or feature_id >= 16384:
return {
"status": "error",
"error": f"feature_id {feature_id} out of range [0, 16383]",
}
if n_proteins < 1 or n_proteins > len(self._SAE_LABELING_PANEL):
return {
"status": "error",
"error": (
f"n_proteins must be in [1, {len(self._SAE_LABELING_PANEL)}], "
f"got {n_proteins}"
),
}
# Cache check
cache_path = self._cache_path_for_feature(sae_model, feature_id)
if use_cache and cache_path.exists():
try:
import json as _json
cached = _json.loads(cache_path.read_text())
cached.setdefault("metadata", {})["from_cache"] = True
return cached
except Exception:
pass # fall through and recompute
# Run SAE labeling pipeline across the panel
panel = self._SAE_LABELING_PANEL[:n_proteins]
evidence: List[Dict[str, Any]] = []
for accession in panel:
entry = self._fetch_uniprot_entry(accession)
if entry is None:
continue
seq = entry.get("sequence", {}).get("value")
uniprot_features = entry.get("features", []) or []
if not seq:
continue
sae_response = self._get_sae_features(
{
"operation": "get_sae_features",
"sequence": seq,
"model": model,
"sae_model": sae_model,
"top_k_per_residue": 64,
}
)
if sae_response.get("status") != "success":
# Likely a panel protein too long, Forge errored, etc.; skip
continue
# Find residues where target feature_id activates
activating: List[Dict[str, Any]] = []
for residue in sae_response["data"]["activations"]:
for feat in residue["active_features"]:
if feat["feature_id"] == feature_id:
activating.append(
{
"position": residue["residue_idx_1based"],
"activation": feat["activation"],
}
)
break
if not activating:
continue
activating.sort(key=lambda x: -abs(x["activation"]))
top = activating[:top_residues_per_protein]
for hit in top:
pos = hit["position"]
overlapping = self._uniprot_features_at_position(uniprot_features, pos)
informative_types = [
f["type"]
for f in overlapping
if self._UNIPROT_TYPE_TO_CATEGORY.get(f["type"]) is not None
]
evidence.append(
{
"protein": accession,
"position_1based": pos,
"activation": float(hit["activation"]),
"uniprot_types": informative_types,
"uniprot_categories": [
self._UNIPROT_TYPE_TO_CATEGORY[t] for t in informative_types
],
}
)
# Aggregate categories across all evidence rows
category_counts: Dict[str, int] = {}
for e in evidence:
for cat in e["uniprot_categories"]:
category_counts[cat] = category_counts.get(cat, 0) + 1
if category_counts:
dominant = max(category_counts.items(), key=lambda x: x[1])
category = dominant[0]
total_votes = sum(category_counts.values())
confidence = dominant[1] / total_votes
else:
category = "uncategorized"
confidence = 0.0
result = {
"status": "success",
"data": {
"feature_id": feature_id,
"sae_model": sae_model,
"category": category,
"confidence": round(confidence, 3),
"n_proteins_with_activation": len(set(e["protein"] for e in evidence)),
"n_proteins_analyzed": len(panel),
"category_vote_counts": category_counts,
"supporting_evidence": evidence,
},
"metadata": {
"from_cache": False,
"method": (
"Aggregated UniProt feature-type overlap at SAE-activating "
"residues across a curated 10-protein panel"
),
"forge_credits_used_first_call": n_proteins,
"license": (
"Outputs are governed by EvolutionaryScale Cambrian "
"Inference License — non-commercial use only"
),
},
}
# Write cache (best-effort)
try:
import json as _json
cache_path.parent.mkdir(parents=True, exist_ok=True)
cache_path.write_text(_json.dumps(result, indent=2))
except Exception:
pass
return result
[文档]
@staticmethod
def _build_per_pos_map(
activations: List[Dict[str, Any]],
) -> Dict[int, Dict[int, float]]:
"""Index SAE activations as {residue_idx_1based: {feature_id: activation}}."""
return {
r["residue_idx_1based"]: {
f["feature_id"]: f["activation"] for f in r["active_features"]
}
for r in activations
}
[文档]
@staticmethod
def _validate_batch_variant(
v: Dict[str, Any], i: int, sequence: str, seq_len: int
) -> Optional[str]:
"""Return None if the variant dict is valid, else an error string."""
for key in ("position", "ref_aa", "alt_aa"):
if key not in v:
return f"variants[{i}] missing key {key!r}"
pos = v["position"]
if not isinstance(pos, int) or pos < 1 or pos > seq_len:
return f"variants[{i}] position {pos} out of range [1, {seq_len}]"
if not isinstance(v["ref_aa"], str) or len(v["ref_aa"]) != 1:
return f"variants[{i}] ref_aa must be a single letter"
if not isinstance(v["alt_aa"], str) or len(v["alt_aa"]) != 1:
return f"variants[{i}] alt_aa must be a single letter"
actual = sequence[pos - 1]
if actual != v["ref_aa"]:
return (
f"variants[{i}]: position {pos} in sequence is "
f"{actual!r}, but ref_aa was {v['ref_aa']!r}"
)
return None
# ------------------------------------------------------------------ #
# score_variant_sae_batch — N variants, N+1 Forge calls (not 2N)
# ------------------------------------------------------------------ #
[文档]
def _score_variant_sae_batch(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
"""Score many missense variants against one reference sequence.
Standard score_variant_sae_disruption makes 2 Forge calls per variant
(ref + mut). For N variants on the same reference this tool runs the
reference SAE once and one mutant SAE per variant — N+1 calls total.
Use for saturation mutagenesis (all 19 alts at one position),
DMS-style sweeps, or scoring a clinical-variant panel.
"""
sequence = arguments.get("sequence")
variants = arguments.get("variants")
window = arguments.get("window", 8)
top_k_features = arguments.get("top_k_features", 10)
model = arguments.get("model", "esmc-6b-2024-12")
sae_model = arguments.get(
"sae_model", "esmc-6b-2024-12_k64_codebook16384_layer60"
)
if not sequence:
return {"status": "error", "error": "sequence is required"}
if not variants or not isinstance(variants, list):
return {
"status": "error",
"error": "variants must be a non-empty list of {position, ref_aa, alt_aa}",
}
# Cap batch size — Forge cost scales linearly and a runaway list
# of 1000 variants would silently burn the user's credit budget.
MAX_VARIANTS = 100
if len(variants) > MAX_VARIANTS:
return {
"status": "error",
"error": (
f"too many variants ({len(variants)}); cap is {MAX_VARIANTS} "
f"per call. Split into multiple calls."
),
}
seq_len = len(sequence)
for i, v in enumerate(variants):
err = self._validate_batch_variant(v, i, sequence, seq_len)
if err is not None:
return {"status": "error", "error": err}
# One reference SAE call over the full sequence — reused for every
# variant's per-window delta computation.
ref_full = self._get_sae_features(
{
"operation": "get_sae_features",
"sequence": sequence,
"model": model,
"sae_model": sae_model,
"top_k_per_residue": 64,
}
)
if ref_full["status"] != "success":
return {**ref_full, "stage": "ref_sae"}
ref_by_pos = self._build_per_pos_map(ref_full["data"]["activations"])
def sum_in_window(per_pos_map, pos, window):
lo = max(1, pos - window)
hi = min(seq_len, pos + window)
sums: Dict[int, float] = {}
for p in range(lo, hi + 1):
for fid, act in per_pos_map.get(p, {}).items():
sums[fid] = sums.get(fid, 0.0) + act
return sums
per_variant: List[Dict[str, Any]] = []
forge_calls = 1 # the ref_full call above
for v in variants:
pos = v["position"]
mutant = sequence[: pos - 1] + v["alt_aa"] + sequence[pos:]
variant_label = f"{v['ref_aa']}{pos}{v['alt_aa']}"
mut_result = self._get_sae_features(
{
"operation": "get_sae_features",
"sequence": mutant,
"model": model,
"sae_model": sae_model,
"position": pos,
"window": window,
"top_k_per_residue": 64,
}
)
forge_calls += 1
if mut_result["status"] != "success":
per_variant.append(
{
"variant": variant_label,
"status": "error",
"error": mut_result.get("error", "mut SAE failed"),
}
)
continue
mut_by_pos = self._build_per_pos_map(mut_result["data"]["activations"])
ref_sums = sum_in_window(ref_by_pos, pos, window)
mut_sums = sum_in_window(mut_by_pos, pos, window)
all_feats = set(ref_sums) | set(mut_sums)
deltas = [
(f, mut_sums.get(f, 0.0) - ref_sums.get(f, 0.0)) for f in all_feats
]
top_lost = sorted(deltas, key=lambda x: x[1])[:top_k_features]
top_gained = sorted(deltas, key=lambda x: -x[1])[:top_k_features]
def feat_row(fid, delta):
return {
"feature_id": int(fid),
"delta": float(delta),
"ref_activation_sum": float(ref_sums.get(fid, 0.0)),
"mut_activation_sum": float(mut_sums.get(fid, 0.0)),
}
per_variant.append(
{
"variant": variant_label,
"status": "success",
"position": pos,
"n_unique_features_touched": len(all_feats),
"top_features_lost": [feat_row(f, d) for f, d in top_lost],
"top_features_gained": [feat_row(f, d) for f, d in top_gained],
}
)
return {
"status": "success",
"data": {
"sequence_length": seq_len,
"n_variants": len(variants),
"n_succeeded": sum(1 for v in per_variant if v["status"] == "success"),
"window": window,
"results": per_variant,
},
"metadata": {
"method": (
f"ESMC-6B SAE per-feature activation delta, summed over "
f"+/-{window} residue window. Reference SAE computed once "
f"and reused across all variants."
),
"forge_calls_made": forge_calls,
"forge_calls_saved_vs_per_variant": len(variants) - 1,
"license": (
"Outputs are governed by EvolutionaryScale Cambrian "
"Inference License — non-commercial use only"
),
},
}
# ------------------------------------------------------------------ #
# get_region_sae_features — domain/epitope-level feature signature
# ------------------------------------------------------------------ #
[文档]
def _get_region_sae_features(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
"""Aggregate SAE features over a residue range.
Characterizes a contiguous region (a domain, epitope, binding pocket,
signal peptide, etc.) by its dominant SAE features. Each returned
feature has its total |activation| over the region, mean activation,
and which residues in the region activate it. Feed feature_ids into
ESM_describe_sae_feature to get biological category labels.
"""
sequence = arguments.get("sequence")
start = arguments.get("start_position")
end = arguments.get("end_position")
top_k_features = arguments.get("top_k_features", 20)
model = arguments.get("model", "esmc-6b-2024-12")
sae_model = arguments.get(
"sae_model", "esmc-6b-2024-12_k64_codebook16384_layer60"
)
if not sequence:
return {"status": "error", "error": "sequence is required"}
if not isinstance(start, int) or not isinstance(end, int):
return {
"status": "error",
"error": "start_position and end_position must be integers (1-indexed)",
}
seq_len = len(sequence)
if start < 1 or end > seq_len or start > end:
return {
"status": "error",
"error": (
f"region [{start}, {end}] out of range [1, {seq_len}] "
f"or start > end"
),
}
region_len = end - start + 1
# Region cap — beyond this, top-K aggregation becomes a coarse summary
# and the user is better off making multiple smaller calls.
MAX_REGION_LEN = 500
if region_len > MAX_REGION_LEN:
return {
"status": "error",
"error": (
f"region length {region_len} exceeds {MAX_REGION_LEN}; "
f"split into smaller windows for meaningful aggregation"
),
}
sae_result = self._get_sae_features(
{
"operation": "get_sae_features",
"sequence": sequence,
"model": model,
"sae_model": sae_model,
"top_k_per_residue": 64,
}
)
if sae_result["status"] != "success":
return sae_result
abs_sums: Dict[int, float] = {}
signed_sums: Dict[int, float] = {}
hit_residues: Dict[int, List[int]] = {}
for residue in sae_result["data"]["activations"]:
pos = residue["residue_idx_1based"]
if pos < start or pos > end:
continue
for feat in residue["active_features"]:
fid = feat["feature_id"]
act = feat["activation"]
abs_sums[fid] = abs_sums.get(fid, 0.0) + abs(act)
signed_sums[fid] = signed_sums.get(fid, 0.0) + act
hit_residues.setdefault(fid, []).append(pos)
ranked = sorted(abs_sums.items(), key=lambda x: -x[1])[:top_k_features]
features_out: List[Dict[str, Any]] = []
for fid, total_abs in ranked:
positions = hit_residues[fid]
features_out.append(
{
"feature_id": int(fid),
"total_abs_activation": float(total_abs),
"mean_activation": float(signed_sums[fid] / len(positions)),
"n_residues_active": len(positions),
"fraction_residues_active": round(len(positions) / region_len, 3),
"active_positions": sorted(positions),
}
)
return {
"status": "success",
"data": {
"sequence_length": seq_len,
"region": [start, end],
"region_length": region_len,
"n_features_active_in_region": len(abs_sums),
"top_features": features_out,
},
"metadata": {
"method": (
"SAE features summed over a residue range, ranked by "
"total |activation|. Feed top feature_ids into "
"ESM_describe_sae_feature for biological category labels."
),
"forge_calls_made": 1,
"license": (
"Outputs are governed by EvolutionaryScale Cambrian "
"Inference License — non-commercial use only"
),
},
}
# ------------------------------------------------------------------ #
# explain_variant_mechanism — disruption + describe + summary in one call
# ------------------------------------------------------------------ #
[文档]
def _explain_variant_mechanism(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
"""Composite tool: variant disruption + describe top affected features.
Runs score_variant_sae_disruption to find the most-changed features,
then describe_sae_feature on each (cached after first call) to get
biological category labels, then composes a 1-line mechanism summary
(e.g. "Disrupted feature categories (lost): catalytic=2, ligand-binding=1").
Use when the calling skill wants mechanism in one call rather than
orchestrating two tool calls and parsing both results.
"""
sequence = arguments.get("sequence")
position = arguments.get("position")
ref_aa = arguments.get("ref_aa")
alt_aa = arguments.get("alt_aa")
window = arguments.get("window", 8)
top_k_features = arguments.get("top_k_features", 5)
include_descriptions = arguments.get("include_descriptions", True)
model = arguments.get("model", "esmc-6b-2024-12")
sae_model = arguments.get(
"sae_model", "esmc-6b-2024-12_k64_codebook16384_layer60"
)
# Delegate input validation to disruption sub-call
disruption = self._score_variant_sae_disruption(
{
"operation": "score_variant_sae_disruption",
"sequence": sequence,
"position": position,
"ref_aa": ref_aa,
"alt_aa": alt_aa,
"window": window,
"top_k_features": top_k_features,
"model": model,
"sae_model": sae_model,
}
)
if disruption["status"] != "success":
return disruption
top_lost = disruption["data"]["top_features_lost"]
top_gained = disruption["data"]["top_features_gained"]
described_lost: List[Dict[str, Any]] = []
described_gained: List[Dict[str, Any]] = []
describe_calls = 0
describe_credits = 0
if include_descriptions:
label_cache: Dict[int, Dict[str, Any]] = {}
def label_for(fid: int) -> Dict[str, Any]:
nonlocal describe_calls, describe_credits
if fid in label_cache:
return label_cache[fid]
result = self._describe_sae_feature(
{
"operation": "describe_sae_feature",
"feature_id": fid,
"sae_model": sae_model,
"model": model,
}
)
describe_calls += 1
meta = result.get("metadata", {}) if isinstance(result, dict) else {}
if meta.get("from_cache") is False:
describe_credits += meta.get("forge_credits_used_first_call", 0)
if result.get("status") == "success":
cat_label = {
"category": result["data"]["category"],
"confidence": result["data"]["confidence"],
}
else:
cat_label = {"category": "unknown", "confidence": 0.0}
label_cache[fid] = cat_label
return cat_label
for feat in top_lost:
described_lost.append({**feat, **label_for(feat["feature_id"])})
for feat in top_gained:
described_gained.append({**feat, **label_for(feat["feature_id"])})
else:
described_lost = list(top_lost)
described_gained = list(top_gained)
def categorize(items: List[Dict[str, Any]]) -> List[tuple]:
cats: Dict[str, int] = {}
for it in items:
cat = it.get("category", "unknown")
cats[cat] = cats.get(cat, 0) + 1
return sorted(cats.items(), key=lambda x: -x[1])
lost_cats = categorize(described_lost)
gained_cats = categorize(described_gained)
if not include_descriptions:
summary = (
"Descriptions skipped (include_descriptions=False); "
"see top_features_lost/gained for raw feature_ids."
)
else:
summary_parts: List[str] = []
if lost_cats:
summary_parts.append(
"Disrupted feature categories (lost): "
+ ", ".join(f"{c}={n}" for c, n in lost_cats[:3])
)
if gained_cats:
summary_parts.append(
"Induced feature categories (gained): "
+ ", ".join(f"{c}={n}" for c, n in gained_cats[:3])
)
summary = (
"; ".join(summary_parts)
if summary_parts
else "No interpretable feature changes detected."
)
return {
"status": "success",
"data": {
"variant": disruption["data"]["variant"],
"position": position,
"window": window,
"mechanism_summary": summary,
"lost_feature_categories": dict(lost_cats),
"gained_feature_categories": dict(gained_cats),
"top_features_lost": described_lost,
"top_features_gained": described_gained,
},
"metadata": {
"method": (
"ESMC-6B SAE variant disruption + per-feature biological "
"labeling, composed into a 1-line mechanism category summary."
),
"disruption_forge_calls": 2,
"describe_feature_calls": describe_calls,
"describe_forge_credits_used": describe_credits,
"license": (
"Outputs are governed by EvolutionaryScale Cambrian "
"Inference License — non-commercial use only"
),
},
}