# Source code for tooluniverse.tdc_dataset_tool

"""TDC dataset retrieval tool — load Therapeutics Data Commons benchmark
datasets locally via the PyTDC package.

This is distinct from the existing ``TDC_predict_oracle_score`` tool (which
scores SMILES with pretrained oracles). This tool *loads named TDC benchmark
datasets*: it returns a summary (row count, columns, label distribution,
train/valid/test split sizes) plus a small sample of rows, and can list the
available dataset names for a given TDC problem.

TDC datasets are organized by problem. ``single_pred`` problems
(ADME, Tox, HTS, QM, Yields, Epitope, Develop) each take a ``name``::

    from tdc.single_pred import ADME
    data = ADME(name='Caco2_Wang')
    df = data.get_data()          # a pandas DataFrame
    split = data.get_split()      # {'train','valid','test'} DataFrames

``multi_pred`` problems (DTI, DDI, PPI, GDA, DrugRes, DrugSyn, PeptideMHC,
AntibodyAff, MTI, Catalyst, TCREpitopeBinding, TrialOutcome) follow the same
pattern from ``tdc.multi_pred``.

Notes
-----
- Datasets DOWNLOAD on first use (network). The returned sample is capped
  (default 5 rows, max 20) so responses stay small.
- The problem class is imported lazily per call. Some problem modules pull in
  heavy optional dependencies; if such an import fails in the local
  environment, this tool returns a clean error for that problem instead of
  failing to load. ``single_pred`` problems have minimal dependencies and are
  the most reliable.

PyTDC (1.1.x) ships a legacy ``from rdkit.six import iteritems`` import that
fails on modern RDKit. We install a tiny ``rdkit.six`` shim before importing
tdc so the package imports cleanly on current RDKit.
"""

from .base_tool import BaseTool
from .tool_registry import register_tool


def _install_rdkit_six_shim():
    """Register a stand-in ``rdkit.six`` module for PyTDC's legacy import.

    Modern RDKit no longer ships ``rdkit.six``, yet PyTDC still imports
    ``iteritems`` from it (inside a try/except that hides the real failure).
    The stand-in exposes only ``iteritems``, ``itervalues`` and ``iterkeys``.

    Nothing happens when ``rdkit.six`` is already importable, or when RDKit
    itself cannot be imported.
    """
    # Fast path: a usable ``rdkit.six`` already exists.
    try:
        import rdkit.six  # noqa: F401  -- already present, nothing to do
    except Exception:
        pass
    else:
        return

    # The shim is attached to the real package, so RDKit must be importable.
    try:
        import sys
        import types

        import rdkit  # noqa: F401  -- ensure base package exists first
    except Exception:
        # RDKit itself is missing; let the real import error surface later.
        return

    def _iteritems(mapping):
        return iter(mapping.items())

    def _itervalues(mapping):
        return iter(mapping.values())

    def _iterkeys(mapping):
        return iter(mapping.keys())

    shim = types.ModuleType("rdkit.six")
    shim.iteritems = _iteritems
    shim.itervalues = _itervalues
    shim.iterkeys = _iterkeys

    # Make both ``import rdkit.six`` and attribute access ``rdkit.six`` work.
    sys.modules["rdkit.six"] = shim
    try:
        rdkit.six = shim
    except Exception:
        pass


# Probe for PyTDC once, at module load, so that a missing dependency is later
# reported as a clean error result instead of an exception. Only the
# lightweight ``tdc.utils`` helper is imported here; the per-problem dataset
# classes are imported lazily inside run() so that one heavy/broken problem
# module cannot stop the tool from loading or block the other problems.
_IMPORT_ERROR = None
try:
    _install_rdkit_six_shim()
    from tdc.utils import retrieve_dataset_names as _retrieve_dataset_names  # noqa: E402
except Exception as exc:  # ImportError or downstream rdkit-guard ImportError
    TDC_AVAILABLE = False
    _retrieve_dataset_names = None
    _IMPORT_ERROR = str(exc)
else:
    TDC_AVAILABLE = True


# Uppercase problem key -> submodule that defines its dataset class.
_SINGLE_PRED = dict.fromkeys(
    ("ADME", "TOX", "HTS", "QM", "YIELDS", "EPITOPE", "DEVELOP"),
    "tdc.single_pred",
)
_MULTI_PRED = dict.fromkeys(
    (
        "DTI",
        "DDI",
        "PPI",
        "GDA",
        "DRUGRES",
        "DRUGSYN",
        "PEPTIDEMHC",
        "ANTIBODYAFF",
        "MTI",
        "CATALYST",
        "TCREPITOPEBINDING",
        "TRIALOUTCOME",
    ),
    "tdc.multi_pred",
)

# Uppercase problem key -> canonical class name (preserves TDC's expected
# casing). Each key is simply the uppercased class name.
_PROBLEM_CANONICAL = {
    class_name.upper(): class_name
    for class_name in (
        "ADME",
        "Tox",
        "HTS",
        "QM",
        "Yields",
        "Epitope",
        "Develop",
        "DTI",
        "DDI",
        "PPI",
        "GDA",
        "DrugRes",
        "DrugSyn",
        "PeptideMHC",
        "AntibodyAff",
        "MTI",
        "Catalyst",
        "TCREpitopeBinding",
        "TrialOutcome",
    )
}

# Upper bound (inclusive) on the number of distinct label values for which
# the label summary reports a value-count distribution instead of summary
# statistics.
_MAX_CATEGORICAL_VALUES = 20

# Number of preview rows returned when the caller does not give a usable
# ``sample_rows`` value, and the hard cap applied to any requested value.
_DEFAULT_SAMPLE = 5
_MAX_SAMPLE = 20


@register_tool("TDCDatasetTool")
class TDCDatasetTool(BaseTool):
    """Load a Therapeutics Data Commons (TDC) benchmark dataset locally.

    Operations (selected via the ``operation`` argument)
    ----------------------------------------------------
    ``load_dataset`` (default)
        Load the dataset named by ``name`` within ``problem`` and return a
        summary (n_rows, columns, label distribution, split sizes) plus a
        small sample of rows.
    ``list_datasets``
        Return the available dataset names for ``problem``.

    Parameters (in ``arguments``)
    -----------------------------
    problem : str
        TDC problem, e.g. 'ADME', 'Tox', 'HTS', 'QM', 'Yields', 'Epitope',
        'Develop' (single_pred) or 'DTI', 'DDI', 'PPI', 'GDA', etc.
        (multi_pred). Case-insensitive.
    name : str
        Dataset name within the problem, e.g. 'Caco2_Wang' (ADME), 'hERG'
        (Tox). Case-insensitive. Required for ``load_dataset``.
    sample_rows : int, optional
        Number of rows to include in the head sample (default 5, max 20).

    Returns
    -------
    dict
        ``{"status": "success", "data": {...}}`` on success, otherwise
        ``{"status": "error", "error": "<message>"}``.
    """

    # Loaded dataset objects are cached per (problem class, normalized dataset
    # name) so repeated calls in a session reuse the already-loaded data. The
    # dict lives on the class and is therefore shared by all instances.
    _dataset_cache: dict = {}

    @staticmethod
    def _problem_key(problem):
        """Normalize a user problem string to its uppercase lookup key.

        Returns ``None`` when ``problem`` is not a string.
        """
        if not isinstance(problem, str):
            return None
        return problem.strip().upper()

    @classmethod
    def _resolve_problem(cls, problem):
        """Return ``(module_path, class_name)`` for a problem.

        Returns ``(None, None)`` when the problem is not a string or is not a
        known single_pred / multi_pred problem.
        """
        key = cls._problem_key(problem)
        if key is None:
            return None, None
        for table in (_SINGLE_PRED, _MULTI_PRED):
            if key in table:
                return table[key], _PROBLEM_CANONICAL[key]
        return None, None

    @classmethod
    def _all_problem_names(cls):
        """Return the canonical names of every supported problem."""
        return list(_PROBLEM_CANONICAL.values())

    def _unknown_problem_error(self, problem):
        """Standard error dict for an unrecognized problem name."""
        return {
            "status": "error",
            "error": (
                f"Unknown problem '{problem}'. Supported problems: "
                + ", ".join(self._all_problem_names())
            ),
        }

    @classmethod
    def _summarize_labels(cls, series):
        """Build a label summary dict for the 'Y' column.

        Labels with at most ``_MAX_CATEGORICAL_VALUES`` distinct values are
        reported as a value-count distribution (``label_type`` "categorical").
        Otherwise numeric labels are reported as ``describe()`` summary
        statistics (``label_type`` "continuous"). Anything else, or any
        failure along the way, falls back to ``label_type`` "other" carrying
        only the distinct count.

        Non-finite statistics (NaN / infinity) are reported as ``None`` so the
        result stays JSON-representable.
        """
        import math

        try:
            n_unique = int(series.nunique(dropna=True))
        except Exception:
            n_unique = None

        try:
            import pandas as pd

            is_numeric = bool(pd.api.types.is_numeric_dtype(series))
        except Exception:
            is_numeric = False

        # Classification-style: few distinct values -> distribution.
        if n_unique is not None and n_unique <= _MAX_CATEGORICAL_VALUES:
            try:
                counts = series.value_counts(dropna=True)
                distribution = {str(idx): int(cnt) for idx, cnt in counts.items()}
            except Exception:
                pass  # best-effort: fall through to the next strategy
            else:
                return {
                    "label_type": "categorical",
                    "n_unique": n_unique,
                    "distribution": distribution,
                    "statistics": None,
                }

        # Regression-style: numeric with many distinct values -> stats.
        if is_numeric:
            try:
                statistics = {}
                for key, val in series.describe().to_dict().items():
                    num = None if val is None else float(val)
                    if num is not None and not math.isfinite(num):
                        num = None
                    statistics[key] = num
            except Exception:
                pass  # best-effort: fall through to the generic fallback
            else:
                return {
                    "label_type": "continuous",
                    "n_unique": n_unique,
                    "distribution": None,
                    "statistics": statistics,
                }

        # Fallback: report distinct count only.
        return {
            "label_type": "other",
            "n_unique": n_unique,
            "distribution": None,
            "statistics": None,
        }

    @classmethod
    def _get_dataset(cls, module_path, class_name, name):
        """Import the problem class lazily and load the named dataset.

        The loaded dataset object is cached per ``(class_name, name)``. The
        name part of the cache key is stripped and lowercased so differently
        cased spellings of one dataset share a cache entry; the dataset class
        itself still receives ``name`` exactly as given.

        Raises whatever the import, attribute lookup or dataset constructor
        raises; callers turn that into an error result.
        """
        import importlib

        normalized = name.strip().lower() if isinstance(name, str) else name
        cache_key = (class_name, normalized)
        try:
            return cls._dataset_cache[cache_key]
        except KeyError:
            pass

        problem_cls = getattr(importlib.import_module(module_path), class_name)
        dataset = problem_cls(name=name)
        cls._dataset_cache[cache_key] = dataset
        return dataset

    def _handle_list_datasets(self, problem):
        """Handle ``list_datasets``: return the dataset names for ``problem``."""
        _, class_name = self._resolve_problem(problem)
        if class_name is None:
            return self._unknown_problem_error(problem)

        # NOTE(review): assumes the canonical class name is a valid key for
        # ``retrieve_dataset_names`` for every problem (e.g. 'Tox') -- confirm
        # against the installed PyTDC metadata.
        try:
            names = list(_retrieve_dataset_names(class_name))
        except Exception as exc:
            return {
                "status": "error",
                "error": f"Could not list datasets for problem '{class_name}': {exc}",
            }

        return {
            "status": "success",
            "data": {
                "problem": class_name,
                "n_datasets": len(names),
                "dataset_names": names,
            },
        }

    def _handle_load_dataset(self, problem, name, sample_rows):
        """Handle ``load_dataset``: summarize and sample one named dataset.

        Returns an error result when the problem is unknown, ``name`` is
        missing / not a string / blank, or the dataset cannot be loaded or
        read. Split sizes are best-effort and reported as ``None`` when
        ``get_split()`` fails.
        """
        module_path, class_name = self._resolve_problem(problem)
        if class_name is None:
            return self._unknown_problem_error(problem)

        if not isinstance(name, str) or not name.strip():
            return {
                "status": "error",
                "error": "Parameter 'name' is required for load_dataset and must be a dataset name string.",
            }
        # Surrounding whitespace is never part of a dataset name.
        name = name.strip()

        try:
            dataset = self._get_dataset(module_path, class_name, name)
        except Exception as exc:
            # Try to surface the valid dataset names to help the caller.
            hint = ""
            try:
                valid = _retrieve_dataset_names(class_name)
                hint = " Available dataset names: " + ", ".join(valid)
            except Exception:
                pass  # the hint is optional; the primary error is still reported
            return {
                "status": "error",
                "error": (
                    f"Could not load dataset '{name}' for problem '{class_name}': {exc}."
                    + hint
                ),
            }

        try:
            df = dataset.get_data()
        except Exception as exc:
            return {
                "status": "error",
                "error": f"Could not retrieve data for '{name}' ({class_name}): {exc}",
            }

        # Label summary on the standard TDC label column 'Y' when present.
        label_summary = None
        if "Y" in df.columns:
            label_summary = self._summarize_labels(df["Y"])

        # Train/valid/test split sizes (best-effort; some datasets may differ).
        try:
            split = dataset.get_split()
            split_sizes = {str(k): int(v.shape[0]) for k, v in split.items()}
        except Exception:
            split_sizes = None

        # Small head() sample, JSON-safe.
        sample = self._build_sample(df, sample_rows)

        return {
            "status": "success",
            "data": {
                "problem": class_name,
                "name": name,
                "n_rows": int(df.shape[0]),
                "columns": [str(c) for c in df.columns],
                "label_summary": label_summary,
                "split_sizes": split_sizes,
                "sample_rows": len(sample),
                "sample": sample,
            },
        }

    @staticmethod
    def _build_sample(df, sample_rows):
        """Return up to ``sample_rows`` head rows as JSON-safe dicts.

        ``sample_rows`` may be an int, an integral float (``10.0``) or an
        int-like string (``"10"``); anything else, including ``bool``, falls
        back to ``_DEFAULT_SAMPLE``. The count is clamped to
        ``[1, _MAX_SAMPLE]``.

        Rows are read with ``itertuples`` rather than ``iterrows`` so each
        value keeps its own column's dtype instead of being upcast to a
        common per-row dtype.
        """
        n = _DEFAULT_SAMPLE
        if isinstance(sample_rows, bool):
            pass  # bool is an int subclass but is not a meaningful row count
        elif isinstance(sample_rows, int):
            n = sample_rows
        elif isinstance(sample_rows, float):
            if sample_rows.is_integer():
                n = int(sample_rows)
        elif isinstance(sample_rows, str):
            try:
                n = int(sample_rows.strip())
            except ValueError:
                pass  # not int-like; keep the default
        n = max(1, min(n, _MAX_SAMPLE))

        column_names = [str(col) for col in df.columns]
        to_json = TDCDatasetTool._jsonable
        return [
            {col: to_json(value) for col, value in zip(column_names, row)}
            for row in df.head(n).itertuples(index=False, name=None)
        ]

    @staticmethod
    def _jsonable(value):
        """Coerce a pandas/numpy scalar to a JSON-serializable value.

        Missing values (per ``pandas.isna``) and non-finite floats become
        ``None``. Objects exposing ``.item()`` are unwrapped first. The result
        is returned as-is only if it is ``None``, ``bool``, ``int``, finite
        ``float`` or ``str``; everything else is stringified.
        """
        import math

        try:
            import pandas as pd

            if pd.isna(value):
                return None
        except Exception:
            # pandas missing, or ``isna`` returned an array whose truth value
            # is ambiguous; keep going with the generic handling below.
            pass

        # Numpy / pandas numeric scalars expose .item(). The unwrapped value
        # is validated below rather than trusted, since .item() can return
        # non-JSON types.
        if hasattr(value, "item"):
            try:
                value = value.item()
            except Exception:
                pass

        if value is None or isinstance(value, (bool, int, str)):
            return value
        if isinstance(value, float):
            return value if math.isfinite(value) else None
        return str(value)

    def run(self, arguments=None):
        """Dispatch to the requested operation and return its result dict.

        ``arguments`` is expected to be a mapping with ``problem`` and
        optionally ``operation`` (default ``load_dataset``), ``name`` and
        ``sample_rows``. Failures are reported as
        ``{"status": "error", ...}`` dicts rather than raised.
        """
        arguments = arguments or {}
        if not callable(getattr(arguments, "get", None)):
            return {
                "status": "error",
                "error": "Parameter 'arguments' must be a dict of tool arguments.",
            }

        if not TDC_AVAILABLE:
            return {
                "status": "error",
                "error": (
                    "PyTDC is not available. Install it with 'pip install PyTDC' "
                    "(requires rdkit/pandas). Underlying import error: "
                    f"{_IMPORT_ERROR}"
                ),
            }

        operation = arguments.get("operation") or "load_dataset"
        operation = str(operation).strip().lower()

        problem = arguments.get("problem")
        if not problem:
            return {
                "status": "error",
                "error": (
                    "Parameter 'problem' is required. Supported problems: "
                    + ", ".join(self._all_problem_names())
                ),
            }

        if operation == "list_datasets":
            return self._handle_list_datasets(problem)

        if operation == "load_dataset":
            return self._handle_load_dataset(
                problem, arguments.get("name"), arguments.get("sample_rows")
            )

        return {
            "status": "error",
            "error": (
                f"Unknown operation '{operation}'. Supported operations: "
                "load_dataset, list_datasets."
            ),
        }