# Source code for tooluniverse.structure_annotation_tool

"""
Structure Annotation Tool

Per-residue structural annotation from a PDB structure:
binding interface, ligand pocket, core/surface, secondary structure.

Methodology adapted from an upstream research workflow. Requires:
  pip install biopython freesasa

Secondary structure can optionally be looked up from the PDBe REST API by
setting the `include_secondary_structure` flag; this does NOT require the
DSSP binary.
"""

import os
import tempfile
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple

import requests

from .base_tool import BaseTool
from .tool_registry import register_tool


# Remote endpoints: RCSB serves the PDB-format coordinate file, PDBe serves
# the per-chain secondary-structure annotation. Both take a lower-case ID.
_RCSB_PDB_URL = "https://files.rcsb.org/download/{pdb_id}.pdb"
_PDBE_SS_URL = "https://www.ebi.ac.uk/pdbe/api/pdb/entry/secondary_structure/{pdb_id}"

# Atom names counted as backbone; every other heavy atom is side chain.
_BACKBONE_ATOMS: Set[str] = set("N CA C O OXT".split())

# Three-letter -> one-letter codes for the 20 standard amino acids.
# The two sequences below are index-aligned (ALA->A, ARG->R, ..., VAL->V).
_THREE_TO_ONE: Dict[str, str] = dict(
    zip(
        (
            "ALA ARG ASN ASP CYS GLN GLU GLY HIS ILE "
            "LEU LYS MET PHE PRO SER THR TRP TYR VAL"
        ).split(),
        "ARNDCQEGHILKMFPSTWYV",
    )
)


def _fetch_pdb(pdb_id: str) -> str:
    """Download the PDB-format coordinate file for ``pdb_id`` from RCSB.

    The identifier is lower-cased before it is substituted into the URL and
    the request uses a 30-second timeout. Non-2xx replies raise
    ``requests.HTTPError`` via ``raise_for_status``.

    Returns:
        The response body as text.
    """
    response = requests.get(
        _RCSB_PDB_URL.format(pdb_id=pdb_id.lower()),
        timeout=30,
    )
    response.raise_for_status()
    return response.text


def _sidechain_heavy_coords(residue) -> "Any":
    """Side-chain heavy-atom coords; glycine (no side chain) -> Ca.

    Matches the paper's scHA (side-chain heavy-atom) selection: every atom
    that is neither hydrogen nor deuterium and whose name is not in
    ``_BACKBONE_ATOMS``. If a residue has no such atom (glycine, or a residue
    whose side chain was not modelled), its CA atom is used instead. If CA is
    also absent, an empty ``(0, 3)`` array is returned.

    Args:
        residue: iterable of Bio.PDB-style atoms exposing ``element``,
            ``get_name()`` and ``coord``.

    Returns:
        numpy array of shape ``(n_atoms, 3)``.
    """
    import numpy as np

    # Deuterium ("D", neutron / H-D exchanged models) is hydrogen as well and
    # must not be counted as a heavy atom, or side chains look ~1 A "longer".
    hydrogen_elements = ("H", "D")
    sc = [
        a
        for a in residue
        if a.element not in hydrogen_elements
        and a.get_name() not in _BACKBONE_ATOMS
    ]
    if not sc:
        sc = [a for a in residue if a.get_name() == "CA"]
    if not sc:
        return np.empty((0, 3))
    return np.array([a.coord for a in sc])


def _all_heavy_coords(residue) -> "Any":
    """All heavy atoms (ligands have no canonical 'side chain').

    Args:
        residue: iterable of Bio.PDB-style atoms exposing ``element`` and
            ``coord``.

    Returns:
        numpy array of shape ``(n_atoms, 3)`` holding every atom whose element
        is neither hydrogen ("H") nor deuterium ("D"); ``(0, 3)`` when none
        remain.
    """
    import numpy as np

    # Deuterium is excluded for the same reason as hydrogen: it is not a
    # heavy atom, even though neutron structures label it "D".
    hydrogen_elements = ("H", "D")
    coords = [a.coord for a in residue if a.element not in hydrogen_elements]
    if not coords:
        return np.empty((0, 3))
    return np.array(coords)


def _min_distance(coords_a, coords_b) -> float:
    """Minimum pairwise Euclidean distance between two point sets.

    Args:
        coords_a, coords_b: points as ``(n, 3)``-shaped array-likes (NumPy
            arrays or nested sequences). A single point given as a flat
            3-sequence is accepted as well.

    Returns:
        The smallest distance between any point of ``coords_a`` and any point
        of ``coords_b`` as a Python float, or ``inf`` when either set is empty.
    """
    import numpy as np

    # asarray/atleast_2d let plain lists and single points through; an empty
    # input becomes a size-0 array and is caught below.
    a = np.atleast_2d(np.asarray(coords_a))
    b = np.atleast_2d(np.asarray(coords_b))
    if a.size == 0 or b.size == 0:
        return float("inf")
    diffs = a[:, None, :] - b[None, :, :]
    # sqrt is monotonic, so taking it once on the minimum squared distance
    # gives exactly min(norm(diffs)) while avoiding n*m square roots.
    squared = (diffs * diffs).sum(axis=-1)
    return float(np.sqrt(squared.min()))


def _compute_rsa(
    structure,
    target_chain: str,
    keep_residue,
) -> Dict[int, float]:
    """Compute relative SASA per target-chain residue on the isolated chain.

    Critical: SASA must come from the ISOLATED target chain — partner chains
    bury interface residues and bias the result.
    Critical: RSA must be self-consistent — let freesasa compute both the SASA
    and the fully-exposed reference (residueAreas().relativeTotal does this
    internally).

    Args:
        structure: Bio.PDB structure. Only its first model is written out;
            any further models (e.g. an NMR ensemble) are ignored.
        target_chain: chain identifier to keep.
        keep_residue: predicate called with each residue of the target chain;
            residues for which it returns False are dropped before SASA.

    Returns:
        ``{residue_number: relative_total_sasa}``. See
        ``_rsa_by_residue_number`` for how freesasa's string keys are mapped.

    Raises:
        ImportError: if biopython or freesasa is not installed.
    """
    try:
        from Bio.PDB import PDBIO, Select
        import freesasa
    except ImportError as exc:
        raise ImportError(
            "biopython and freesasa are required for RSA computation. "
            "Install with: pip install biopython freesasa"
        ) from exc

    first_model = next(iter(structure), None)

    class _ChainOnly(Select):
        def accept_model(self, m):  # noqa: D401
            # A single model keeps the written chain a single copy.
            return m is first_model

        def accept_chain(self, c):  # noqa: D401
            return c.id == target_chain

        def accept_residue(self, r):  # noqa: D401
            return keep_residue(r)

    tmp = tempfile.NamedTemporaryFile(suffix=".pdb", delete=False)
    tmp.close()
    try:
        io = PDBIO()
        io.set_structure(structure)
        io.save(tmp.name, _ChainOnly())
        fs_struct = freesasa.Structure(tmp.name)
        fs_result = freesasa.calc(fs_struct)
        areas = fs_result.residueAreas().get(target_chain, {})
        return _rsa_by_residue_number(areas)
    finally:
        try:
            os.unlink(tmp.name)
        except OSError:
            pass


def _rsa_by_residue_number(areas) -> Dict[int, float]:
    """Convert freesasa ``{residue_number: ResidueArea}`` into ``{int: rsa}``.

    freesasa reports residue numbers as strings that can carry a PDB
    insertion code (e.g. ``"52A"``). Calling ``int()`` on such a key raises
    ``ValueError``, so the leading signed integer is parsed instead.

    Mapping rules:
      - When several keys share a number (``"52"`` and ``"52A"``), the later
        entry in iteration order wins. This presumably follows file order;
        verify against freesasa if exact pairing matters.
      - Keys without a leading integer are skipped.
      - Entries whose ``relativeTotal`` is not a number, or is NaN, are
        skipped, so the caller sees "no RSA" rather than NaN.
    """
    import math
    import re

    leading_int = re.compile(r"\s*(-?\d+)")
    out: Dict[int, float] = {}
    for key, area in areas.items():
        match = leading_int.match(str(key))
        if match is None:
            continue
        try:
            value = float(area.relativeTotal)
        except (TypeError, ValueError):
            continue
        if math.isnan(value):
            continue
        out[int(match.group(1))] = value
    return out


def _fetch_pdbe_secondary_structure(pdb_id: str, target_chain: str) -> Dict[int, str]:
    """Lookup per-residue SS for a chain via PDBe REST.

    Returns ``{residue_number: ss_element}``. ``ss_element`` is "helix" or
    "strand"; any other PDBe range kind is passed through under its own name.
    Residues without an annotation are simply absent, and the caller decides
    the default (e.g. "coil").

    Residue numbers are PDBe's *author* numbers (``author_residue_number``),
    i.e. the numbering stored in the PDB file that Bio.PDB exposes as
    ``residue.id[1]``. This is best-effort: any network, HTTP or JSON
    failure yields ``{}``.
    """
    pdb_key = pdb_id.strip().lower()
    url = _PDBE_SS_URL.format(pdb_id=pdb_key)
    try:
        resp = requests.get(url, timeout=30)
        resp.raise_for_status()
        payload = resp.json()
    except (requests.RequestException, ValueError):
        return {}
    return _parse_pdbe_secondary_structure(payload, pdb_key, target_chain)


def _parse_pdbe_secondary_structure(
    payload: Any, pdb_id: str, target_chain: str
) -> Dict[int, str]:
    """Flatten a PDBe secondary_structure payload into ``{residue_number: label}``.

    Only chains whose ``chain_id`` equals ``target_chain`` are read. Each
    range is expanded to every residue number from start to end, inclusive.
    A payload that is not a dict, or that lacks the entry for ``pdb_id``,
    yields ``{}``.
    """
    if not isinstance(payload, dict):
        return {}
    entry = payload.get(pdb_id.strip().lower())
    if not isinstance(entry, dict):
        return {}

    ss_by_pos: Dict[int, str] = {}
    for molecule in entry.get("molecules", []) or []:
        for chain in molecule.get("chains", []) or []:
            if chain.get("chain_id") != target_chain:
                continue
            ss_data = chain.get("secondary_structure") or {}
            for kind, ranges in ss_data.items():
                # kind is e.g. "helices" or "strands"; unknown kinds pass through.
                if "helic" in kind:
                    label = "helix"
                elif "strand" in kind:
                    label = "strand"
                else:
                    label = kind
                for rng in ranges or []:
                    bounds = _pdbe_range_bounds(rng)
                    if bounds is None:
                        continue
                    for pos in range(bounds[0], bounds[1] + 1):
                        ss_by_pos[pos] = label
    return ss_by_pos


def _pdbe_range_bounds(rng: Dict[str, Any]) -> Optional[Tuple[int, int]]:
    """Return the inclusive ``(start, end)`` residue numbers of one PDBe range.

    ``author_residue_number`` is used when both ends carry it. Otherwise both
    ends fall back to ``residue_number``, which per the PDBe API is the
    entity-sequence index and may differ from the author numbering. Returns
    None when neither field is present on both ends.
    """
    start = rng.get("start") or {}
    end = rng.get("end") or {}
    for field in ("author_residue_number", "residue_number"):
        lo, hi = start.get(field), end.get(field)
        if lo is not None and hi is not None:
            return int(lo), int(hi)
    return None


@register_tool("StructureAnnotationTool")
class StructureAnnotationTool(BaseTool):
    """Per-residue structural annotation from a PDB.

    For each residue of the target chain, classify:
      - binding interface : min scHA distance to partner chain(s) < cutoff
      - ligand pocket     : min scHA distance to ligand heavy atoms < cutoff
      - core vs surface   : relative SASA < core_rsa_cutoff
      - region label      : {interface, ligand, both, other}
      - secondary structure (optional, from PDBe)
    """

    def run(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
        """Dispatch ``arguments["operation"]`` and turn failures into error dicts.

        ``operation`` defaults to ``"annotate_per_residue"``, currently the
        only one. ``None`` arguments are treated as an empty dict. Import
        errors, ``requests`` errors and any other exception are returned as
        ``{"status": "error", "error": ...}`` rather than raised.
        """
        arguments = arguments or {}
        try:
            operation = arguments.get("operation", "annotate_per_residue")
            if operation == "annotate_per_residue":
                return self._annotate_per_residue(arguments)
            return {
                "status": "error",
                "error": f"Unknown operation: {operation!r}. Valid: annotate_per_residue",
            }
        except ImportError as exc:
            return {"status": "error", "error": str(exc)}
        except requests.RequestException as exc:
            return {"status": "error", "error": f"PDB fetch failed: {exc}"}
        except Exception as exc:  # noqa: BLE001 - tool boundary: never raise
            return {"status": "error", "error": str(exc)}

    # ------------------------------------------------------------------ #
    # annotate_per_residue
    # ------------------------------------------------------------------ #

    def _annotate_per_residue(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
        """Annotate every non-HETATM residue of ``target_chain``.

        Recognised ``arguments`` keys:
            pdb_id: PDB code. It is fetched from RCSB when ``pdb_content`` is
                missing or empty, and is required for the optional PDBe
                secondary-structure lookup.
            pdb_content: raw PDB-format text, used instead of downloading.
            target_chain: chain to annotate (default ``"A"``).
            partner_chains: chains defining the binding interface.
            ligand_resnames: ligand residue names, either as a list or as one
                comma/whitespace-separated string. They are upper-cased and
                matched against upper-cased residue names.
            distance_cutoff: contact cutoff in PDB coordinate units, i.e. Å
                (default 5.0).
            core_rsa_cutoff: residues with RSA below this are "core"
                (default 0.25).
            include_secondary_structure: add ``ss_element`` to each row.

        Returns:
            A ``{"status": "success", ..., "annotations": [...]}`` dict, or
            ``{"status": "error", "error": ...}`` for unusable input.
        """
        pdb_id: Optional[str] = arguments.get("pdb_id")
        pdb_content: Optional[str] = arguments.get("pdb_content")
        if not pdb_id and not pdb_content:
            return {
                "status": "error",
                "error": "Either pdb_id or pdb_content must be provided",
            }
        target_chain: str = arguments.get("target_chain", "A")
        partner_chains: List[str] = list(arguments.get("partner_chains", []) or [])
        ligand_resnames: List[str] = [
            name.strip().upper()
            for name in self._as_name_list(arguments.get("ligand_resnames"))
            if name.strip()
        ]
        distance_cutoff: float = float(arguments.get("distance_cutoff", 5.0))
        core_rsa_cutoff: float = float(arguments.get("core_rsa_cutoff", 0.25))
        include_ss: bool = bool(arguments.get("include_secondary_structure", False))

        try:
            from Bio.PDB import PDBParser  # noqa: F401  (availability check)
        except ImportError:
            return {
                "status": "error",
                "error": (
                    "biopython is required. Install with: pip install biopython freesasa"
                ),
            }

        # Load structure. Download only when no usable inline content was
        # given; an empty string counts as missing.
        if not pdb_content:
            pdb_content = _fetch_pdb(pdb_id)
        structure = self._parse_structure(pdb_content, pdb_id or "input")

        model = next(iter(structure), None)
        if model is None:
            return {
                "status": "error",
                "error": "No models found in the provided structure",
            }
        chain_ids = {c.id for c in model}
        if target_chain not in chain_ids:
            return {
                "status": "error",
                "error": (
                    f"target_chain {target_chain!r} not in structure. "
                    f"Available: {sorted(chain_ids)}"
                ),
            }
        chain_target = model[target_chain]

        # Per-residue side-chain heavy coords for the target chain.
        # NOTE: keyed by residue number only, so insertion-code variants of
        # one number (e.g. 52 and 52A) collide and the later one is kept.
        target_residues: Dict[int, Tuple[str, Any]] = {}
        for residue in chain_target:
            if residue.id[0].strip():  # skip HETATM and waters in the target chain
                continue
            aa = _THREE_TO_ONE.get(residue.resname.strip(), "X")
            target_residues[residue.id[1]] = (aa, _sidechain_heavy_coords(residue))

        # Pool partner-chain scHA coords
        partner_coords = self._gather_partner_coords(
            model, partner_chains, target_chain
        )
        # Pool ligand all-heavy coords (scan whole model, not just one chain)
        ligand_coords = self._gather_ligand_coords(model, set(ligand_resnames))

        # Compute RSA on the isolated target chain
        try:
            rsa_map = _compute_rsa(
                structure,
                target_chain,
                keep_residue=lambda r: not r.id[0].strip(),
            )
        except ImportError as exc:
            return {"status": "error", "error": str(exc)}

        # Optional secondary structure from PDBe (only meaningful if pdb_id provided)
        ss_map: Dict[int, str] = {}
        if include_ss and pdb_id:
            ss_map = _fetch_pdbe_secondary_structure(pdb_id, target_chain)

        # Assemble per-residue rows
        rows: List[Dict[str, Any]] = []
        for resnum in sorted(target_residues):
            aa, sc_coords = target_residues[resnum]
            d_partner = _min_distance(sc_coords, partner_coords)
            d_ligand = _min_distance(sc_coords, ligand_coords)
            rsa = rsa_map.get(resnum)
            row: Dict[str, Any] = {
                "position": resnum,
                "aa": aa,
                "dist_partner": self._rounded_distance(d_partner),
                "dist_ligand": self._rounded_distance(d_ligand),
                "rsa": (None if rsa is None else round(rsa, 3)),
                "region": self._region_label(
                    d_partner < distance_cutoff, d_ligand < distance_cutoff
                ),
                "is_core": rsa is not None and rsa < core_rsa_cutoff,
            }
            if include_ss:
                row["ss_element"] = ss_map.get(resnum, "coil")
            rows.append(row)

        return {
            "status": "success",
            "pdb_id": pdb_id,
            "target_chain": target_chain,
            "partner_chains": partner_chains,
            "ligand_resnames": ligand_resnames,
            "distance_cutoff": distance_cutoff,
            "core_rsa_cutoff": core_rsa_cutoff,
            "n_residues": len(rows),
            "annotations": rows,
            "method": {
                "interface_metric": "sidechain_heavy_atom_min_distance",
                "ligand_metric": "sidechain_to_all_heavy_atom_min_distance",
                "rsa_source": "freesasa.residueAreas.relativeTotal (isolated target chain)",
                "ss_source": ("pdbe_rest" if include_ss else None),
            },
            "provenance": (
                "Methodology adapted from an upstream research workflow — "
                "the original script"
            ),
        }

    @staticmethod
    def _as_name_list(value: Any) -> List[str]:
        """Normalise a list-or-string argument into a list.

        ``None`` or an empty value gives ``[]``. A single string is split on
        commas and whitespace, so ``"ATP"`` stays one name instead of
        ``["A", "T", "P"]``. Any other iterable is returned as a list.
        """
        if not value:
            return []
        if isinstance(value, str):
            return value.replace(",", " ").split()
        return list(value)

    @staticmethod
    def _parse_structure(pdb_content: str, structure_id: str):
        """Parse PDB-format text into a Bio.PDB structure (quiet mode).

        The text is passed to ``PDBParser`` as an in-memory handle, so no
        temporary file needs to be written or cleaned up.
        """
        import io

        from Bio.PDB import PDBParser

        return PDBParser(QUIET=True).get_structure(
            structure_id, io.StringIO(pdb_content)
        )

    @staticmethod
    def _rounded_distance(value: float) -> Optional[float]:
        """Round a distance to 3 decimals; the "no atoms" sentinel ``inf`` becomes None."""
        return None if value == float("inf") else round(value, 3)

    @staticmethod
    def _region_label(is_interface: bool, is_ligand: bool) -> str:
        """Combine the two contact flags into interface/ligand/both/other."""
        if is_interface and is_ligand:
            return "both"
        if is_interface:
            return "interface"
        if is_ligand:
            return "ligand"
        return "other"

    @staticmethod
    def _gather_partner_coords(model, partner_chains: Iterable[str], target_chain: str):
        """Stack side-chain heavy-atom coords of all non-HETATM partner residues.

        The target chain and chain IDs absent from ``model`` are ignored.
        Returns a ``(0, 3)`` array when nothing is selected.
        """
        import numpy as np

        partner_chains = [c for c in partner_chains if c != target_chain]
        if not partner_chains:
            return np.empty((0, 3))
        arrays = []
        chain_ids = {c.id for c in model}
        for chain_id in partner_chains:
            if chain_id not in chain_ids:
                continue
            for residue in model[chain_id]:
                if residue.id[0].strip():
                    continue
                arrays.append(_sidechain_heavy_coords(residue))
        if not arrays:
            return np.empty((0, 3))
        return np.vstack(arrays)

    @staticmethod
    def _gather_ligand_coords(model, ligand_resnames: Set[str]):
        """Stack heavy-atom coords of every residue whose name is in ``ligand_resnames``.

        All chains of ``model`` are scanned, and residue names are compared
        upper-cased. Returns a ``(0, 3)`` array when nothing matches.
        """
        import numpy as np

        if not ligand_resnames:
            return np.empty((0, 3))
        arrays = []
        for chain in model:
            for residue in chain:
                if residue.resname.strip().upper() not in ligand_resnames:
                    continue
                arrays.append(_all_heavy_coords(residue))
        if not arrays:
            return np.empty((0, 3))
        return np.vstack(arrays)