Source code for tooluniverse.tdc_oracle_tool

"""TDC Oracle tool — local ML property/bioactivity scoring via PyTDC.

Wraps Therapeutics Data Commons (TDC) pretrained ``Oracle`` scorers. Given a
SMILES string (or list of SMILES) and an oracle name, returns molecular
property / drug-likeness / bioactivity scores computed locally.

This is distinct from the existing RDKit/ADMET tools: TDC oracles include
ML-trained bioactivity classifiers (GSK3B, JNK3, DRD2) and the standard
medicinal-chemistry scorers (QED, SA, LogP) used as goal-directed molecular
optimization objectives.

Notes
-----
- ``QED`` and ``LogP`` are pure RDKit and run fully offline/instantly.
- ``SA`` downloads a tiny fragment-score table on first use, then runs locally.
- ML bioactivity oracles (``GSK3B``, ``JNK3``, ``DRD2``) download a pretrained
  model file (a few MB) on first use, then score locally.

PyTDC (1.1.x) ships a legacy ``from rdkit.six import iteritems`` import that
fails on modern RDKit (``rdkit.six`` was removed). That import sits behind a
bare ``except`` that re-raises a misleading "install rdkit" error, even though
the actual QED/SA/LogP scoring functions never touch it. We install a tiny
``rdkit.six`` shim before importing tdc so the real oracles work on current
RDKit. The shim is a no-op for environments where ``rdkit.six`` already exists.
"""

from .base_tool import BaseTool
from .tool_registry import register_tool


def _install_rdkit_six_shim():
    """Provide a minimal ``rdkit.six`` so PyTDC's legacy import succeeds.

    ``rdkit.six`` was removed from modern RDKit. PyTDC still imports
    ``iteritems`` from it inside a try/except that masks the real error.

    The shim module exposes ``iteritems`` (the symbol PyTDC references) plus
    the sibling helpers ``itervalues`` and ``iterkeys``. Each returns an
    iterator over the matching view (``items()``/``values()``/``keys()``) of
    the mapping it is given. The shim is registered in ``sys.modules`` and
    also attached as the ``six`` attribute of the ``rdkit`` package.

    Best-effort and idempotent: this does nothing when ``rdkit.six`` already
    imports, or when ``rdkit`` itself cannot be imported (in that case the
    real import error surfaces later, when ``tdc`` is imported).
    """
    import sys
    import types

    try:
        import rdkit.six  # noqa: F401  -- already present, nothing to do
    except Exception:
        # Expected on modern RDKit (module removed); fall through to the shim.
        pass
    else:
        return

    try:
        import rdkit  # noqa: F401  -- ensure base package exists first
    except Exception:
        # RDKit itself is missing or broken; let the real import error
        # surface later instead of masking it here.
        return

    six_mod = types.ModuleType("rdkit.six")
    six_mod.iteritems = lambda d: iter(d.items())
    six_mod.itervalues = lambda d: iter(d.values())
    six_mod.iterkeys = lambda d: iter(d.keys())
    # Registering in sys.modules is what makes ``from rdkit.six import ...``
    # resolve; the attribute below covers plain ``rdkit.six`` access.
    sys.modules["rdkit.six"] = six_mod
    try:
        rdkit.six = six_mod
    except Exception:
        # Attribute assignment is a convenience only; the sys.modules entry
        # above is sufficient for imports.
        pass


# Import PyTDC once, at module load, so a missing or broken dependency is
# recorded here and later reported by the tool as an ordinary error payload
# instead of raising. This follows the framework's optional-dependency pattern
# (try/except -> AVAILABLE flag).
TDC_AVAILABLE = False
_IMPORT_ERROR = None
try:
    # The rdkit.six shim has to be in place before tdc is imported.
    _install_rdkit_six_shim()
    from tdc import Oracle as _TDCOracle  # noqa: E402
except Exception as exc:  # ImportError, or tdc's own rdkit-guard ImportError
    _TDCOracle = None
    _IMPORT_ERROR = str(exc)
else:
    TDC_AVAILABLE = True


# Oracles that are fast and either fully offline or download only a tiny table.
# NOTE(review): "LOGP" is upper-case here while the canonical name in
# _SUPPORTED_ORACLES is "LogP"; confirm how callers compare against this set
# before relying on membership tests with canonical names.
_FAST_ORACLES = {"QED", "SA", "LOGP"}

# Curated oracle names (canonical TDC casing, in display order) mapped to a
# one-line description. TDC recognizes many more; these are the stable
# single-SMILES scorers most useful for property prediction and goal-directed
# optimization.
_SUPPORTED_ORACLES = dict(
    QED=(
        "Quantitative Estimate of Drug-likeness (0-1, higher = more drug-like). "
        "RDKit, offline."
    ),
    SA=(
        "Synthetic Accessibility score (1=easy to 10=hard to synthesize). "
        "Downloads a small table once."
    ),
    LogP=(
        "Octanol-water partition coefficient (lipophilicity). "
        "RDKit, offline."
    ),
    GSK3B=(
        "ML bioactivity oracle: probability of GSK3-beta inhibition (0-1). "
        "Downloads model once."
    ),
    JNK3=(
        "ML bioactivity oracle: probability of JNK3 inhibition (0-1). "
        "Downloads model once."
    ),
    DRD2=(
        "ML bioactivity oracle: probability of DRD2 activity (0-1). "
        "Downloads model once."
    ),
)


@register_tool("TDCOracleTool")
class TDCOracleTool(BaseTool):
    """Score SMILES with a Therapeutics Data Commons pretrained oracle.

    Parameters (in ``arguments``)
    -----------------------------
    smiles : str | list[str]
        A single SMILES string or a list (or other iterable) of SMILES
        strings to score.
    oracle : str
        Oracle name. One of: QED, SA, LogP, GSK3B, JNK3, DRD2
        (case-insensitive for the canonical names above). Other names are
        passed through to TDC as given (whitespace-stripped).

    Returns
    -------
    dict
        On success ``{"status": "success", "data": {...}}`` where ``data`` has
        ``oracle`` (resolved name), ``oracle_description`` (``None`` for names
        outside the curated list) and ``results``: one
        ``{"smiles", "score", "error"}`` entry per input, in input order. This
        is a list even when a single SMILES string was given. A failure to
        score one molecule is reported in that entry's ``error`` and does not
        fail the whole call.
        On failure ``{"status": "error", "error": <message>}``.
    """

    # Oracle instances are expensive to construct (ML oracles load a model);
    # cache them per-class so repeated calls in a session reuse the loaded model.
    _oracle_cache: dict = {}

    @staticmethod
    def _error(message):
        """Build the tool's standard error payload."""
        return {"status": "error", "error": message}

    @classmethod
    def _resolve_oracle_name(cls, oracle):
        """Map a user-provided oracle string to its canonical TDC name.

        Returns ``None`` for non-string input. A name that matches a supported
        oracle case-insensitively is returned in TDC's expected casing (e.g.
        ``"logp"`` -> ``"LogP"``). Any other name is returned whitespace-
        stripped so TDC may still recognize it.
        """
        if not isinstance(oracle, str):
            return None
        key = oracle.strip()
        lowered = key.lower()
        return next(
            (canonical for canonical in _SUPPORTED_ORACLES if canonical.lower() == lowered),
            key,
        )

    @classmethod
    def _get_oracle(cls, name):
        """Return a cached or newly constructed TDC Oracle for ``name``.

        Construction errors propagate to the caller and nothing is cached for
        that name, so a later call retries.
        """
        if name not in cls._oracle_cache:
            cls._oracle_cache[name] = _TDCOracle(name=name)
        return cls._oracle_cache[name]

    @staticmethod
    def _score_one(oracle, smi):
        """Score one SMILES, capturing any failure in the entry's ``error``."""
        try:
            score = oracle(smi)
            # TDC may return numpy float types; coerce to native float.
            score_val = float(score) if score is not None else None
        except Exception as exc:
            return {"smiles": smi, "score": None, "error": str(exc)}
        return {"smiles": smi, "score": score_val, "error": None}

    def run(self, arguments=None):
        """Validate ``arguments`` and score every SMILES with the chosen oracle."""
        arguments = arguments or {}
        if not TDC_AVAILABLE:
            return self._error(
                "PyTDC is not available. Install it with 'pip install PyTDC' "
                "(requires rdkit). Underlying import error: "
                f"{_IMPORT_ERROR}"
            )

        supported = ", ".join(_SUPPORTED_ORACLES)

        # Normalize SMILES input to a list. A bare string is one SMILES (not an
        # iterable of characters); any other iterable is materialized. A
        # non-iterable (e.g. a number) previously escaped as an uncaught
        # TypeError from list().
        smiles = arguments.get("smiles")
        if smiles is None:
            smiles_list = []
        elif isinstance(smiles, str):
            smiles_list = [smiles] if smiles else []
        else:
            try:
                smiles_list = list(smiles)
            except TypeError:
                return self._error(
                    "Parameter 'smiles' must be a SMILES string or a list of "
                    "SMILES strings."
                )
        # Rejects every empty input uniformly (previously an empty tuple/set
        # slipped through and "succeeded" with no results).
        if not smiles_list:
            return self._error("Parameter 'smiles' is required and cannot be empty.")

        oracle_arg = arguments.get("oracle")
        if not oracle_arg or (isinstance(oracle_arg, str) and not oracle_arg.strip()):
            return self._error(
                "Parameter 'oracle' is required. Supported oracles: " + supported
            )
        oracle_name = self._resolve_oracle_name(oracle_arg)
        if oracle_name is None:
            # Non-string oracle: fail clearly instead of calling Oracle(name=None).
            return self._error(
                "Parameter 'oracle' must be a string. Supported oracles: " + supported
            )

        if not all(isinstance(s, str) and s.strip() for s in smiles_list):
            return self._error("All SMILES entries must be non-empty strings.")

        try:
            oracle = self._get_oracle(oracle_name)
        except Exception as exc:
            return self._error(
                f"Could not load oracle '{oracle_name}': {exc}. "
                "Supported oracles: " + supported
            )

        data = {
            "oracle": oracle_name,
            "oracle_description": _SUPPORTED_ORACLES.get(oracle_name),
            "results": [self._score_one(oracle, smi) for smi in smiles_list],
        }
        return {"status": "success", "data": data}