# Source code for tooluniverse.cellpose_tool
"""Cellpose tool — local deep-learning cell/nucleus segmentation.
Wraps the `cellpose <https://github.com/MouseLand/cellpose>`_ Python package to
segment cells or nuclei from a single microscopy image *locally* (no API, no key).
Given an image file path it returns the number of segmented objects, per-object
areas and centroids, and (optionally) the path to a saved label-mask image.
This is a true local-compute tool that fills a genuine capability gap in
ToolUniverse: no existing tool performs instance segmentation of microscopy
images. (The "segment"/"segmentation" hits in other tool configs refer to
genomic segments or to remote BioImage-Archive/CryoET dataset metadata, not
local image segmentation.)

Cellpose API note
-----------------
Cellpose 3.x exposed ``models.Cellpose(model_type='cyto3'|'nuclei'|'cyto')`` and a
zoo of named models. Cellpose 4.x removed that class and replaced the zoo with a
single unified model (CPSAM); there, ``model_type`` is accepted but **ignored**.
This tool supports both generations: it uses ``models.Cellpose`` when available
(3.x, honoring ``model_type``) and otherwise falls back to ``models.CellposeModel``
(4.x). The model actually used is reported back in the result so callers are never
misled about whether ``model_type`` took effect.

Runtime / dependency notes
--------------------------
- ``pip install cellpose`` (torch is a dependency). Declared as an optional
  ``required_packages`` entry; ``run()`` returns a clean error if unavailable.
- On first use, cellpose downloads model weights (small in 3.x; ~1 GB CPSAM in
  4.x) and caches them. The first call is therefore slow; later calls reuse the
  cached weights and a per-process cached model instance. CPU inference on a
  256x256 image takes on the order of a minute or two; GPU is much faster.
"""
import os
from .base_tool import BaseTool
from .tool_registry import register_tool
# Optional dependency import at module load so a missing package becomes a clean
# error rather than an exception (framework optional-dependency pattern).
#
# CELLPOSE_AVAILABLE -- True only when both numpy and cellpose imported.
# _IMPORT_ERROR      -- text of the import failure, or None when imports worked.
# np / _cp_models    -- the imported modules, or None when unavailable.
CELLPOSE_AVAILABLE = False
_IMPORT_ERROR = None
try:
    import numpy as np  # noqa: E402
    from cellpose import models as _cp_models  # noqa: E402

    CELLPOSE_AVAILABLE = True
except Exception as exc:  # ImportError, or a downstream import failure
    # Deliberately broad: any failure while importing the optional stack is
    # recorded here and reported later by ``run()`` instead of breaking import
    # of this module.
    np = None
    _cp_models = None
    _IMPORT_ERROR = str(exc)
# Legacy cellpose-3.x model types accepted from the caller. In 4.x these are
# ignored by the unified model but still accepted so callers/tests are stable.
# ``run()`` rejects any ``model_type`` that is not listed here.
_LEGACY_MODEL_TYPES = ("cyto3", "cyto2", "cyto", "nuclei")
# Image extensions we know how to load.
# TIFF files are read with ``tifffile`` first, falling back to PIL.
_TIFF_EXTS = (".tif", ".tiff")
# Extensions read directly with PIL.
_PIL_EXTS = (".png", ".jpg", ".jpeg", ".bmp")
@register_tool("CellposeTool")
class CellposeTool(BaseTool):
    """Segment cells/nuclei in a microscopy image with Cellpose (local compute).

    Arguments (in ``arguments``)
    ----------------------------
    image_path : str
        Path to a local microscopy image (.tif/.tiff/.png/.jpg/.bmp).
    model_type : str, optional
        One of ``cyto3``, ``cyto``, ``cyto2``, ``nuclei`` (default ``cyto3``).
        Honored on cellpose 3.x; ignored by the unified model on 4.x (the
        ``model_used`` field reports what actually ran).
    diameter : float, optional
        Expected object diameter in pixels. ``None``/0 lets cellpose estimate
        it. Numeric strings are accepted; negative, non-finite or non-numeric
        values produce an error result.
    channels : list[int], optional
        Two-element ``[cytoplasm, nucleus]`` channel spec, e.g. ``[0, 0]`` for a
        grayscale image (default). Used by cellpose 3.x; 4.x infers channels.
    save_mask : bool, optional
        If true, save the integer label mask next to the image (or to
        ``mask_output_path``) and return its path. Default false.
    mask_output_path : str, optional
        Where to write the mask image when ``save_mask`` is true. Defaults to
        the image path with its extension replaced by ``_cp_masks.png``.

    Result
    ------
    ``run`` always returns a dict. On failure it is
    ``{"status": "error", "error": <message>}``; on success it is
    ``{"status": "success", "data": {...}}`` where ``data`` holds
    ``image_path``, ``model_type_requested``, ``model_used``, ``num_objects``,
    ``image_shape`` (shape of the label mask), ``objects``, ``mask_path`` and,
    when applicable, ``note`` and ``mask_save_error``.
    """

    # Cache loaded models per (class, key) — model construction (and the 4.x
    # weight load) is expensive, so reuse within a process.
    _model_cache: dict = {}

    @classmethod
    def _get_model(cls, model_type):
        """Return a cached/new cellpose model and the actual model identifier used.

        Parameters
        ----------
        model_type : str
            Legacy model name requested by the caller.

        Returns
        -------
        tuple
            ``(model, model_used)``. With the 3.x ``Cellpose`` class
            ``model_used`` equals ``model_type``; with the ``CellposeModel``
            fallback it is the base name of the model's ``pretrained_model``
            attribute, or ``"cellpose-default"`` when that is missing/empty.
        """
        # Prefer the 3.x ``Cellpose`` wrapper class (honors model_type); fall
        # back to 4.x ``CellposeModel`` (single CPSAM model, model_type ignored).
        if hasattr(_cp_models, "Cellpose"):
            cache_key = ("Cellpose", model_type)
            if cache_key not in cls._model_cache:
                cls._model_cache[cache_key] = (
                    _cp_models.Cellpose(gpu=False, model_type=model_type),
                    model_type,
                )
            return cls._model_cache[cache_key]

        # One shared entry: the fallback model does not depend on model_type.
        cache_key = ("CellposeModel", None)
        if cache_key not in cls._model_cache:
            model = _cp_models.CellposeModel(gpu=False)
            model_used = getattr(model, "pretrained_model", "cellpose-default")
            if isinstance(model_used, (list, tuple)):
                model_used = model_used[0] if model_used else "cellpose-default"
            # Report only the file name, not the full cache path.
            model_used = os.path.basename(str(model_used))
            cls._model_cache[cache_key] = (model, model_used)
        return cls._model_cache[cache_key]

    @staticmethod
    def _load_image(image_path):
        """Load an image into a numpy array. Returns (array, error_or_None).

        TIFF files are read with ``tifffile`` and fall back to PIL if that
        import or read fails. Known PIL extensions are read with PIL. Any other
        extension gets one PIL attempt and is reported as unsupported if that
        attempt fails. Exactly one element of the returned pair is ``None``.
        """

        def _read_with_pil(path):
            # Imported lazily so PIL is only required when actually used.
            from PIL import Image

            # The context manager releases the file handle once the pixel
            # data has been copied into the array.
            with Image.open(path) as handle:
                return np.array(handle)

        ext = os.path.splitext(image_path)[1].lower()
        try:
            if ext in _TIFF_EXTS:
                try:
                    import tifffile

                    arr = tifffile.imread(image_path)
                except Exception:
                    # Best effort: tifffile missing or unable to read the file.
                    arr = _read_with_pil(image_path)
            elif ext in _PIL_EXTS:
                arr = _read_with_pil(image_path)
            else:
                # Last-ditch attempt via PIL; report unsupported extension if
                # it fails.
                try:
                    arr = _read_with_pil(image_path)
                except Exception:
                    return None, (
                        f"Unsupported image extension '{ext}'. Supported: "
                        + ", ".join(_TIFF_EXTS + _PIL_EXTS)
                    )
        except Exception as exc:
            return None, f"Failed to read image '{image_path}': {exc}"

        if arr is None or getattr(arr, "size", 0) == 0:
            return None, f"Image '{image_path}' is empty or could not be decoded."
        return arr, None

    @staticmethod
    def _object_stats(masks):
        """Compute per-object area and centroid from a 2D integer label mask.

        Parameters
        ----------
        masks : array-like
            2D label image; 0 is background, every other value is an object id.

        Returns
        -------
        list[dict]
            One dict per non-zero label, in ascending label order, with keys
            ``label``, ``area`` (pixel count), ``centroid_y`` and
            ``centroid_x`` (mean row / column index).

        Raises
        ------
        ValueError
            If ``masks`` is not two-dimensional.
        """
        masks = np.asarray(masks)
        if masks.ndim != 2:
            raise ValueError(
                f"Expected a 2D label mask, got an array with shape {masks.shape}."
            )
        if not np.issubdtype(masks.dtype, np.integer):
            # Labels are compared as integers; truncate anything else once.
            masks = masks.astype(np.int64)

        ys, xs = np.nonzero(masks)
        if ys.size == 0:
            return []

        # Single pass over the foreground pixels: group them by label, then
        # accumulate counts and coordinate sums per group.
        labels, inverse, areas = np.unique(
            masks[ys, xs], return_inverse=True, return_counts=True
        )
        inverse = inverse.ravel()
        sum_y = np.bincount(inverse, weights=ys, minlength=labels.size)
        sum_x = np.bincount(inverse, weights=xs, minlength=labels.size)

        return [
            {
                "label": int(label),
                "area": int(area),
                "centroid_y": float(total_y / area),
                "centroid_x": float(total_x / area),
            }
            for label, area, total_y, total_x in zip(labels, areas, sum_y, sum_x)
        ]

    def _save_mask(self, masks, image_path, mask_output_path):
        """Write the label mask to disk. Returns (path_or_None, error_or_None).

        The mask is written as a 16-bit grayscale image to ``mask_output_path``
        or, when that is empty, to the image path with its extension replaced
        by ``_cp_masks.png``. Masks whose largest label does not fit in 16 bits
        are refused rather than silently wrapped.
        """
        out_path = mask_output_path or (
            os.path.splitext(image_path)[0] + "_cp_masks.png"
        )
        try:
            from PIL import Image

            masks = np.asarray(masks)
            max_label = int(masks.max()) if masks.size else 0
            limit = int(np.iinfo(np.uint16).max)
            if max_label > limit:
                # astype(uint16) would wrap these ids and merge distinct objects.
                return None, (
                    f"Failed to save mask to '{out_path}': largest label "
                    f"{max_label} exceeds the 16-bit limit of {limit}."
                )
            # 16-bit grayscale preserves up to 65535 distinct object labels.
            mask16 = masks.astype(np.uint16)
            Image.fromarray(mask16).save(out_path)
            return out_path, None
        except Exception as exc:
            return None, f"Failed to save mask to '{out_path}': {exc}"

    def run(self, arguments=None):
        """Segment one image and return a status dict (never raises on bad input).

        Validates the arguments, loads the image, obtains a (cached) model,
        runs segmentation, computes per-object statistics and optionally saves
        the label mask. See the class docstring for the accepted arguments and
        the shape of the result.
        """
        arguments = arguments or {}

        if not CELLPOSE_AVAILABLE:
            return {
                "status": "error",
                "error": (
                    "The 'cellpose' package is not available. Install it with "
                    "'pip install cellpose' (requires torch). Underlying import "
                    f"error: {_IMPORT_ERROR}"
                ),
            }

        image_path = arguments.get("image_path")
        if not image_path or not isinstance(image_path, str):
            return {
                "status": "error",
                "error": "Parameter 'image_path' is required and must be a string.",
            }
        # isfile (not exists) so that a directory is rejected here as well.
        if not os.path.isfile(image_path):
            return {
                "status": "error",
                "error": f"Image file not found: '{image_path}'.",
            }

        model_type = arguments.get("model_type") or "cyto3"
        if model_type not in _LEGACY_MODEL_TYPES:
            return {
                "status": "error",
                "error": (
                    f"Unsupported model_type '{model_type}'. One of: "
                    + ", ".join(_LEGACY_MODEL_TYPES)
                ),
            }

        diameter = arguments.get("diameter")
        if diameter is not None:
            try:
                diameter = float(diameter)
            except (TypeError, ValueError):
                diameter = float("nan")  # rejected by the range check below
            if not 0 <= diameter < float("inf"):
                return {
                    "status": "error",
                    "error": (
                        "Parameter 'diameter' must be a non-negative number "
                        "(omit it or pass 0 to let cellpose estimate it)."
                    ),
                }
            if diameter == 0:
                diameter = None  # 0 means "estimate automatically"

        channels = arguments.get("channels") or [0, 0]
        save_mask = bool(arguments.get("save_mask", False))
        mask_output_path = arguments.get("mask_output_path")

        img, load_err = self._load_image(image_path)
        if load_err is not None:
            return {"status": "error", "error": load_err}

        try:
            model, model_used = self._get_model(model_type)
        except Exception as exc:
            return {
                "status": "error",
                "error": f"Failed to load cellpose model '{model_type}': {exc}",
            }

        try:
            # eval returns (masks, flows, styles, diams) in 3.x and
            # (masks, flows, styles) in 4.x — only the first element is needed.
            masks = np.asarray(
                model.eval(img, diameter=diameter, channels=channels)[0]
            )
        except Exception as exc:
            return {
                "status": "error",
                "error": f"Cellpose segmentation failed: {exc}",
            }

        try:
            objects = self._object_stats(masks)
        except Exception as exc:
            # e.g. a non-2D mask; report it instead of letting it escape run().
            return {
                "status": "error",
                "error": f"Failed to compute object statistics: {exc}",
            }

        data = {
            "image_path": image_path,
            "model_type_requested": model_type,
            "model_used": model_used,
            "num_objects": len(objects),
            "image_shape": [int(d) for d in masks.shape],
            "objects": objects,
            "mask_path": None,
        }

        # If the reported model name contains none of the legacy names, the
        # requested model_type cannot have been applied — tell the caller.
        if model_used and not any(
            mt in str(model_used).lower() for mt in _LEGACY_MODEL_TYPES
        ):
            data["note"] = (
                "This cellpose build uses a single unified model; the requested "
                "'model_type' was not applied. Segmentation was produced by "
                f"'{model_used}'."
            )

        if save_mask:
            mask_path, save_err = self._save_mask(masks, image_path, mask_output_path)
            if save_err is not None:
                # Segmentation succeeded; surface the save failure non-fatally.
                data["mask_save_error"] = save_err
            else:
                data["mask_path"] = mask_path

        return {"status": "success", "data": data}