Source code for tooluniverse.esm_tool

"""
ESM3 / ESMC Tool

Provides access to EvolutionaryScale protein language models:
  - ESMC (300m/600m): fast protein sequence embeddings
  - ESM3 (open/small): sequence generation, structure prediction, sequence scoring

Requires an API token from https://forge.evolutionaryscale.ai
Set via environment variable ESM_API_KEY.

Install: pip install esm
"""

import os
from typing import Dict, Any, List, Optional
from .base_tool import BaseTool
from .tool_registry import register_tool


def _get_client(model: str):
    """Return an ESM3ForgeInferenceClient for the given model, using ESM_API_KEY.

    Args:
        model: Model identifier, passed unchanged to the client constructor.

    Returns:
        An ``ESM3ForgeInferenceClient`` built with ``model`` and the token read
        from the ``ESM_API_KEY`` environment variable.

    Raises:
        ImportError: If ``esm.sdk.forge`` cannot be imported. The original
            import failure is attached as ``__cause__``.
        EnvironmentError: If ``ESM_API_KEY`` is unset, empty, or blank.
    """
    # Imported lazily so that merely importing this module does not require
    # the optional ``esm`` dependency.
    try:
        from esm.sdk.forge import ESM3ForgeInferenceClient
    except ImportError as exc:
        # Chain explicitly so the underlying reason (e.g. a missing
        # sub-dependency) stays visible in the traceback.
        raise ImportError(
            "esm package is required. Install with: pip install esm"
        ) from exc

    # Strip surrounding whitespace: a stray newline from a shell export or a
    # secrets file would otherwise be sent as part of the token.
    token = os.environ.get("ESM_API_KEY", "").strip()
    if not token:
        raise EnvironmentError(
            "ESM_API_KEY environment variable is not set. "
            "Obtain a token at https://forge.evolutionaryscale.ai"
        )
    return ESM3ForgeInferenceClient(model=model, token=token)


@register_tool("ESMTool")
class ESMTool(BaseTool):
    """
    ESM3 / ESMC tool for protein sequence embeddings, generation,
    structure prediction, and sequence scoring via the EvolutionaryScale
    Forge API.

    Every operation returns a plain dict. On success it contains
    ``"status": "success"`` plus operation-specific keys; on failure it
    contains ``"status": "error"`` and an ``"error"`` message. No exception
    is allowed to escape from ``run`` or from the per-operation methods.
    """

    def run(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
        """Execute the tool with given arguments.

        Args:
            arguments: Must contain ``"operation"`` (one of
                ``get_protein_embedding``, ``generate_protein_sequence``,
                ``fold_protein``, ``score_sequence``) plus the keys that the
                selected operation reads.

        Returns:
            The result dict of the selected operation, or an error dict when
            the operation is unknown or anything raises.
        """
        try:
            operation = arguments.get("operation", "")
            handlers = {
                "get_protein_embedding": self._get_protein_embedding,
                "generate_protein_sequence": self._generate_protein_sequence,
                "fold_protein": self._fold_protein,
                "score_sequence": self._score_sequence,
            }
            # Only strings can name an operation; guarding here keeps an
            # unhashable value from raising inside the dict lookup.
            handler = handlers.get(operation) if isinstance(operation, str) else None
            if handler is None:
                return {
                    "status": "error",
                    "error": f"Unknown operation: {operation!r}. Valid operations: "
                    + ", ".join(handlers),
                }
            return handler(arguments)
        except Exception as e:
            # ImportError / EnvironmentError from _get_client are ordinary
            # Exception subclasses, so one handler covers all of them.
            return {"status": "error", "error": str(e)}

    # ------------------------------------------------------------------ #
    # shared helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _option(arguments: Dict[str, Any], key: str, default: Any) -> Any:
        """Return ``arguments[key]``, or ``default`` if it is missing or None."""
        value = arguments.get(key)
        return default if value is None else value

    @staticmethod
    def _forge_failure(result: Any) -> Optional[Dict[str, Any]]:
        """Return an error dict if ``result`` is an error object, else None.

        The ESM SDK clients may hand back an ``ESMProteinError`` (an exception
        instance) instead of raising it, so every client response is checked
        before its attributes are used.
        """
        if not isinstance(result, Exception):
            return None
        detail = getattr(result, "error_msg", None) or str(result) or type(result).__name__
        code = getattr(result, "error_code", None)
        code_part = f" (code {code})" if code is not None else ""
        return {
            "status": "error",
            "error": f"ESM Forge request failed{code_part}: {detail}",
        }

    @staticmethod
    def _drop_batch_dim(values: Any) -> Any:
        """Remove a leading batch dimension of size 1 from a 3-D array/tensor.

        Model outputs may arrive as (1, L+2, X) or (L+2, X) depending on the
        backend; downstream indexing expects the latter.
        """
        if getattr(values, "ndim", None) == 3 and values.shape[0] == 1:
            return values[0]
        return values

    @staticmethod
    def _shape_of(values: Any) -> Optional[List[int]]:
        """Return the dimensions of a tensor/array or a nested list, or None."""
        if values is None:
            return None
        shape = getattr(values, "shape", None)
        if shape is not None:
            return [int(dim) for dim in shape]
        dims: List[int] = []
        level = values
        # Walk down the first element of each nesting level; stop at an empty
        # level instead of indexing into it.
        while isinstance(level, (list, tuple)):
            dims.append(len(level))
            if not level:
                break
            level = level[0]
        return dims

    # ------------------------------------------------------------------ #
    # get_protein_embedding
    # ------------------------------------------------------------------ #

    def _get_protein_embedding(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
        """Get ESMC per-residue and mean-pooled embeddings for a protein sequence.

        Reads from ``arguments``:
            sequence: Amino-acid sequence (required).
            model: Model name; defaults to ``"esmc-300m-2024-12"``.
            return_per_residue: When truthy, also return the per-residue
                embedding matrix.

        Returns:
            On success: ``model``, ``sequence_length``, ``embedding_dim``,
            ``mean_embedding`` and, optionally, ``per_residue_embeddings``.
            Otherwise an error dict.
        """
        try:
            from esm.sdk.api import ESMProtein, LogitsConfig

            sequence = arguments.get("sequence")
            if not sequence:
                return {
                    "status": "error",
                    "error": "sequence is required for get_protein_embedding",
                }

            model = self._option(arguments, "model", "esmc-300m-2024-12")
            return_per_residue = arguments.get("return_per_residue", False)

            client = _get_client(model)

            # The SDK's logits() operates on an encoded protein tensor, so the
            # sequence is tokenized with encode() first.
            protein_tensor = client.encode(ESMProtein(sequence=sequence))
            failure = self._forge_failure(protein_tensor)
            if failure is not None:
                return failure

            logits_output = client.logits(
                protein_tensor, LogitsConfig(sequence=True, return_embeddings=True)
            )
            failure = self._forge_failure(logits_output)
            if failure is not None:
                return failure

            if logits_output.embeddings is None:
                return {
                    "status": "error",
                    "error": "Model did not return embeddings. Ensure LogitsConfig(return_embeddings=True) is supported.",
                }

            import numpy as np

            embeddings = logits_output.embeddings
            if hasattr(embeddings, "detach"):
                tensor = embeddings.detach().cpu()
                # NumPy has no bfloat16, and pooling in half precision loses
                # accuracy, so widen sub-32-bit floats before converting.
                if tensor.is_floating_point() and tensor.element_size() < 4:
                    tensor = tensor.float()
                emb_np = tensor.numpy()
            else:
                emb_np = np.asarray(embeddings)

            # Expected layout after this: (L+2, D), with BOS/EOS at the ends.
            emb_np = self._drop_batch_dim(emb_np)

            # Mean-pool over residue positions only (exclude BOS/EOS tokens).
            residue_emb = emb_np[1:-1]
            mean_emb = residue_emb.mean(axis=0).tolist()

            result = {
                "status": "success",
                "model": model,
                "sequence_length": len(sequence),
                "embedding_dim": len(mean_emb),
                "mean_embedding": mean_emb,
            }
            if return_per_residue:
                result["per_residue_embeddings"] = residue_emb.tolist()
            return result

        except Exception as e:
            return {"status": "error", "error": str(e)}

    # ------------------------------------------------------------------ #
    # generate_protein_sequence
    # ------------------------------------------------------------------ #

    def _generate_protein_sequence(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
        """Generate or complete a protein sequence using ESM3.

        Reads from ``arguments``:
            prompt_sequence: Sequence with ``'_'`` at positions to generate
                (required).
            model: Model name; defaults to ``"esm3-open-2024-03"``.
            num_steps: Number of generation steps; defaults to 8.
            temperature: Sampling temperature; defaults to 1.0.

        Returns:
            On success: ``model``, ``prompt_sequence``, ``generated_sequence``,
            ``sequence_length``, ``num_steps`` and ``temperature``.
            Otherwise an error dict.
        """
        try:
            from esm.sdk.api import ESMProtein, GenerationConfig

            prompt_sequence = arguments.get("prompt_sequence")
            if not prompt_sequence:
                return {
                    "status": "error",
                    "error": "prompt_sequence is required. Use '_' characters to denote masked positions to generate.",
                }

            model = self._option(arguments, "model", "esm3-open-2024-03")
            num_steps = int(self._option(arguments, "num_steps", 8))
            temperature = float(self._option(arguments, "temperature", 1.0))

            client = _get_client(model)
            config = GenerationConfig(
                track="sequence",
                num_steps=num_steps,
                temperature=temperature,
            )
            result_protein = client.generate(
                ESMProtein(sequence=prompt_sequence), config
            )
            failure = self._forge_failure(result_protein)
            if failure is not None:
                return failure

            generated_seq = result_protein.sequence or ""
            return {
                "status": "success",
                "model": model,
                "prompt_sequence": prompt_sequence,
                "generated_sequence": generated_seq,
                "sequence_length": len(generated_seq),
                "num_steps": num_steps,
                "temperature": temperature,
            }

        except Exception as e:
            return {"status": "error", "error": str(e)}

    # ------------------------------------------------------------------ #
    # fold_protein
    # ------------------------------------------------------------------ #

    def _fold_protein(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
        """Predict protein structure with ESM3 and summarize the result.

        Reads from ``arguments``:
            sequence: Amino-acid sequence (required).
            model: Model name; defaults to ``"esm3-open-2024-03"``.
            num_steps: Number of generation steps; defaults to 8.

        Returns:
            On success: ``model``, ``sequence``, ``sequence_length``,
            ``pTM_score``, ``plddt_per_residue``, ``coordinates_shape``,
            ``num_steps`` and ``note``. Only the shape of the predicted
            coordinates is returned, not the coordinates themselves.
            Otherwise an error dict.
        """
        try:
            from esm.sdk.api import ESMProtein, GenerationConfig

            sequence = arguments.get("sequence")
            if not sequence:
                return {
                    "status": "error",
                    "error": "sequence is required for fold_protein",
                }

            model = self._option(arguments, "model", "esm3-open-2024-03")
            num_steps = int(self._option(arguments, "num_steps", 8))

            client = _get_client(model)
            config = GenerationConfig(track="structure", num_steps=num_steps)
            result_protein = client.generate(ESMProtein(sequence=sequence), config)
            failure = self._forge_failure(result_protein)
            if failure is not None:
                return failure

            # Read the shape directly instead of materializing the whole
            # coordinate tensor as nested Python lists.
            coordinates_shape = self._shape_of(result_protein.coordinates)

            raw_ptm = getattr(result_protein, "ptm", None)
            ptm = float(raw_ptm) if raw_ptm is not None else None

            plddt = getattr(result_protein, "plddt", None)
            if plddt is not None and hasattr(plddt, "tolist"):
                plddt = plddt.tolist()

            return {
                "status": "success",
                "model": model,
                "sequence": sequence,
                "sequence_length": len(sequence),
                "pTM_score": ptm,
                "plddt_per_residue": plddt,
                "coordinates_shape": coordinates_shape,
                "num_steps": num_steps,
                "note": "Coordinates are (L, 37, 3) backbone atom positions in Angstroms.",
            }

        except Exception as e:
            return {"status": "error", "error": str(e)}

    # ------------------------------------------------------------------ #
    # score_sequence
    # ------------------------------------------------------------------ #

    def _score_sequence(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
        """Score a protein sequence using ESMC logits (per-residue log-probabilities).

        Reads from ``arguments``:
            sequence: Amino-acid sequence (required).
            model: Model name; defaults to ``"esmc-300m-2024-12"``.

        Returns:
            On success: ``model``, ``sequence``, ``sequence_length``,
            ``mean_log_probability`` (None if nothing could be scored),
            ``per_residue_log_probabilities`` (scored residues only, in
            sequence order) and ``unscored_positions`` (0-based indices of
            residues that have no entry in the sequence vocabulary).
            Otherwise an error dict.
        """
        try:
            from esm.sdk.api import ESMProtein, LogitsConfig

            sequence = arguments.get("sequence")
            if not sequence:
                return {
                    "status": "error",
                    "error": "sequence is required for score_sequence",
                }

            model = self._option(arguments, "model", "esmc-300m-2024-12")

            client = _get_client(model)

            # logits() operates on an encoded protein tensor; tokenize first.
            protein_tensor = client.encode(ESMProtein(sequence=sequence))
            failure = self._forge_failure(protein_tensor)
            if failure is not None:
                return failure

            logits_output = client.logits(
                protein_tensor, LogitsConfig(sequence=True, return_embeddings=False)
            )
            failure = self._forge_failure(logits_output)
            if failure is not None:
                return failure

            if logits_output.logits is None or logits_output.logits.sequence is None:
                return {
                    "status": "error",
                    "error": "Model did not return sequence logits.",
                }

            # Compute mean log-prob (pseudo-likelihood) per residue.
            import torch.nn.functional as F

            # Expected layout after this: (L+2, vocab), BOS/EOS at the ends.
            seq_logits = self._drop_batch_dim(logits_output.logits.sequence)
            log_probs = F.log_softmax(seq_logits, dim=-1)

            # Map each residue letter to its token id. Best effort: if the
            # vocabulary cannot be loaded, no residue is scored and every
            # position is reported in ``unscored_positions``.
            try:
                from esm.utils.constants.esm3 import SEQUENCE_VOCAB

                aa_to_idx = {aa: i for i, aa in enumerate(SEQUENCE_VOCAB)}
            except Exception:
                aa_to_idx = {}

            per_residue_logprobs: List[float] = []
            unscored_positions: List[int] = []
            for i, aa in enumerate(sequence):
                token_id = aa_to_idx.get(aa)
                if token_id is None:
                    unscored_positions.append(i)
                    continue
                # +1 skips the BOS token at position 0.
                per_residue_logprobs.append(log_probs[i + 1, token_id].item())

            mean_logprob = (
                sum(per_residue_logprobs) / len(per_residue_logprobs)
                if per_residue_logprobs
                else None
            )

            return {
                "status": "success",
                "model": model,
                "sequence": sequence,
                "sequence_length": len(sequence),
                "mean_log_probability": mean_logprob,
                "per_residue_log_probabilities": per_residue_logprobs,
                "unscored_positions": unscored_positions,
            }

        except Exception as e:
            return {"status": "error", "error": str(e)}