Source code for tooluniverse.rxn_chemistry_tool
"""
IBM RXN for Chemistry API tool for ToolUniverse.
IBM RXN for Chemistry provides ML-based chemical reaction prediction:
* Forward reaction prediction: given reactant SMILES (dot-separated),
predict the most likely product SMILES.
* Retrosynthesis: given a target product SMILES, predict precursor
routes (disconnections back to purchasable building blocks).
API: Direct REST against https://rxn.app.accelerate.science
(formerly https://rxn.res.ibm.com). This tool uses plain `requests`;
it does NOT depend on the optional `rxn4chemistry` Python package.
Auth: Requires a free API key. Read ONLY from the environment variable
RXN4CHEMISTRY_API_KEY. Register at https://rxn.app.accelerate.science
(Profile -> "My profile" -> API key). The key is sent verbatim in the
`Authorization` header (NOT a Bearer token).
Documented request shape (verified against the official rxn4chemistry
wrapper route definitions, api_version "v1"):
Base path: {base}/rxn/api/api/v1
Headers: {"Authorization": <api_key>, "Content-Type": "application/json"}
Project context (a project id is required for predictions):
GET {base}/rxn/api/api/v1/projects
-> {"payload": [{"id": "<project_id>", "name": "...", ...}, ...]}
The tool uses an explicit `project_id` arg if given, otherwise the
first project returned, otherwise it creates one:
POST {base}/rxn/api/api/v1/projects body {"name": "<name>"}
Forward reaction prediction (async submit -> poll):
POST {base}/rxn/api/api/v1/predictions/pr?projectId=<pid>&aiModel=<model>
body {"reactants": "<smiles.smiles>", "aiModel": "<model>"}
-> {"payload": {"id": "<prediction_id>"}}
GET {base}/rxn/api/api/v1/predictions/<prediction_id>
-> poll until payload.status == "SUCCESS";
product at payload.attempts[0].smiles (+ confidence).
Retrosynthesis (async submit -> poll):
POST {base}/rxn/api/api/v1/retrosynthesis/rs?projectId=<pid>&aiModel=<model>
body {"product": "<smiles>", "aiModel": "<model>",
"isInteractive": false, "parameters": {...}}
-> {"payload": {"id": "<prediction_id>"}}
GET {base}/rxn/api/api/v1/retrosynthesis/<prediction_id>
-> poll until payload.status == "SUCCESS";
routes at payload.retrosyntheticPaths (each with sequences of
reactant SMILES and a confidence score).
Polling is bounded: each HTTP request uses a 30s timeout, and the poll loop
re-checks the job every `poll_interval` seconds (default 5s) until a
`max_wait_time` budget (default 60s) is used up; both settings are clamped
to a minimum of 1s.
On timeout the tool returns a clean error (status="error"), never raises.
"""
import os
import time
from typing import Any, Dict, List, Optional
import requests
from .base_tool import BaseTool
from .tool_registry import register_tool
# Service root used when the RXN4CHEMISTRY_BASE_URL environment variable is
# not set (see RXNChemistryTool.__init__).
DEFAULT_BASE_URL = "https://rxn.app.accelerate.science"
# Version segment appended to the "/rxn/api/api/" path to form the API root.
API_VERSION = "v1"
# Model identifier sent as `aiModel` when the caller supplies no `ai_model`.
DEFAULT_AI_MODEL = "2020-08-10"
# Name given to the project the tool creates when none can be reused.
DEFAULT_PROJECT_NAME = "tooluniverse"
# Name of the environment variable quoted in the missing-key error message.
ENV_KEY = "RXN4CHEMISTRY_API_KEY"
REQUEST_TIMEOUT = 30  # seconds, per HTTP request
DEFAULT_POLL_INTERVAL = 5  # seconds between polls
DEFAULT_MAX_WAIT_TIME = 60  # seconds total polling budget
[docs]
@register_tool("RXNChemistryTool")
class RXNChemistryTool(BaseTool):
    """
    Wrap IBM RXN for Chemistry ML reaction-prediction endpoints.

    Operations (selected via the fixed `operation` parameter per tool config):

    * predict_reaction       -- forward prediction: reactants -> product
    * predict_retrosynthesis -- retrosynthesis: product -> precursor routes

    Every operation returns a dict with a "status" key: "success" together
    with a "data" dict, or "error" together with an "error" message.

    The API key is read ONLY from os.environ[RXN4CHEMISTRY_API_KEY]; it is
    never accepted as a parameter. If the key is missing the tool returns a
    structured error rather than raising.

    NOTE(review): the key is obtained through `self._api_key()`, which is not
    defined in the visible part of this class -- presumably inherited or
    defined elsewhere; confirm that it reads the environment variable above.
    """
[docs]
def __init__(self, tool_config: Dict[str, Any]):
    """Initialise the tool and derive the API root URL.

    Args:
        tool_config: Tool configuration, forwarded unchanged to the base
            class constructor.

    The service root comes from the RXN4CHEMISTRY_BASE_URL environment
    variable. An unset, empty or whitespace-only value falls back to
    DEFAULT_BASE_URL, so a blank override can no longer produce schemeless
    request URLs such as "/rxn/api/api/v1/projects".
    """
    super().__init__(tool_config)
    override = (os.environ.get("RXN4CHEMISTRY_BASE_URL") or "").strip()
    # Trailing slashes are dropped so the joined paths never contain "//".
    self.base_url = (override or DEFAULT_BASE_URL).rstrip("/")
    self.api_url = f"{self.base_url}/rxn/api/api/{API_VERSION}"
# ------------------------------------------------------------------ #
# Key / header helpers
# ------------------------------------------------------------------ #
[docs]
def _headers(self, api_key: str) -> Dict[str, str]:
return {"Authorization": api_key, "Content-Type": "application/json"}
[docs]
@staticmethod
def _missing_key_error() -> Dict[str, Any]:
    """Return the structured error used when no API key is available.

    The message names the environment variable (ENV_KEY) to set and says
    where a key can be obtained.
    """
    sentences = (
        "IBM RXN for Chemistry requires an API key.",
        f"Set the {ENV_KEY} environment variable.",
        "Register for a free key at https://rxn.app.accelerate.science",
        "(Profile -> My profile -> API key).",
    )
    return {"status": "error", "error": " ".join(sentences)}
# ------------------------------------------------------------------ #
# Dispatch
# ------------------------------------------------------------------ #
[docs]
def run(self, arguments: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """Dispatch one tool call and always return a structured result.

    Args:
        arguments: Call arguments. `operation` selects the handler; when it
            is absent or empty the value from
            `self.get_schema_const_operation()` is used. All other keys are
            passed through to the handler.

    Returns:
        The handler's result dict, or {"status": "error", "error": <msg>}
        when the operation is unknown, the API key is missing, or anything
        fails along the way.

    The whole body, including operation resolution and the key lookup, runs
    inside the `try`, so no `Exception` escapes this method.
    """
    arguments = arguments or {}
    supported = ("predict_reaction", "predict_retrosynthesis")
    try:
        operation = (
            arguments.get("operation", "") or self.get_schema_const_operation()
        )
        if operation not in supported:
            return {
                "status": "error",
                "error": (
                    f"Unknown operation: {operation!r}. "
                    f"Supported: {', '.join(supported)}"
                ),
            }
        # NOTE(review): `_api_key` is not defined in the visible part of this
        # class -- presumably inherited or defined elsewhere; confirm.
        api_key = self._api_key()
        if not api_key:
            return self._missing_key_error()
        # Handlers are named after their operation with a leading underscore.
        handler = getattr(self, f"_{operation}")
        return handler(arguments, api_key)
    except requests.exceptions.Timeout:
        return {
            "status": "error",
            "error": f"IBM RXN request timed out after {REQUEST_TIMEOUT}s.",
        }
    except requests.exceptions.RequestException as exc:
        return {"status": "error", "error": f"IBM RXN request failed: {exc}"}
    except (TimeoutError, RuntimeError) as exc:
        # Raised deliberately by the polling / project helpers with a
        # ready-to-show message; these are expected outcomes, so they are
        # not labelled "Unexpected error".
        return {"status": "error", "error": str(exc)}
    except Exception as exc:  # never raise out of run()
        return {"status": "error", "error": f"Unexpected error: {exc}"}
# ------------------------------------------------------------------ #
# Project resolution
# ------------------------------------------------------------------ #
[docs]
def _resolve_project_id(
self, arguments: Dict[str, Any], headers: Dict[str, str]
) -> str:
"""Return a usable project id (explicit arg, first existing, or new)."""
explicit = arguments.get("project_id")
if explicit:
return str(explicit)
resp = requests.get(
f"{self.api_url}/projects", headers=headers, timeout=REQUEST_TIMEOUT
)
resp.raise_for_status()
payload = resp.json().get("payload", [])
if isinstance(payload, list) and payload:
pid = payload[0].get("id") or payload[0].get("_id")
if pid:
return str(pid)
# No project yet -> create one.
create = requests.post(
f"{self.api_url}/projects",
headers=headers,
json={"name": DEFAULT_PROJECT_NAME},
timeout=REQUEST_TIMEOUT,
)
create.raise_for_status()
created = create.json().get("payload", {})
pid = created.get("id") or created.get("_id")
if not pid:
raise RuntimeError("Could not resolve or create an IBM RXN project id.")
return str(pid)
# ------------------------------------------------------------------ #
# Async polling helper
# ------------------------------------------------------------------ #
[docs]
def _poll(
    self,
    results_url: str,
    headers: Dict[str, str],
    poll_interval: float,
    max_wait_time: float,
) -> Dict[str, Any]:
    """Poll a results URL until status == SUCCESS or budget exhausted.

    Args:
        results_url: URL of the prediction whose status is polled.
        headers: Request headers, including the authorization key.
        poll_interval: Seconds to wait between two polls.
        max_wait_time: Total polling budget in seconds.

    Returns:
        The `payload` dict of the first response whose status is SUCCESS.

    Raises:
        RuntimeError: On a terminal failure status (ERROR or FAILED).
        TimeoutError: When the budget runs out first.
        requests.exceptions.RequestException: On transport or HTTP errors.

    The deadline uses the monotonic clock, so system clock adjustments
    cannot stretch or shrink the budget. At least one request is always
    made. Each sleep is capped at the time remaining, so the last status
    check happens at the deadline instead of a full interval past it.
    """
    deadline = time.monotonic() + max_wait_time
    last_status = "NEW"
    while True:
        resp = requests.get(results_url, headers=headers, timeout=REQUEST_TIMEOUT)
        resp.raise_for_status()
        payload = resp.json().get("payload", {}) or {}
        last_status = str(payload.get("status", "")).upper()
        if last_status == "SUCCESS":
            return payload
        if last_status in {"ERROR", "FAILED"}:
            raise RuntimeError(
                f"IBM RXN prediction failed with status {last_status}."
            )
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        # Never sleep past the deadline.
        time.sleep(min(poll_interval, remaining))
    raise TimeoutError(
        f"IBM RXN prediction did not finish within {max_wait_time}s "
        f"(last status: {last_status})."
    )
[docs]
@staticmethod
def _polling_settings(arguments: Dict[str, Any]) -> tuple:
    """Read `(poll_interval, max_wait_time)` from the call arguments.

    A missing or falsy value, or one that `float()` rejects, is replaced by
    the matching module default. Both results are clamped to at least 1.0.
    """

    def _as_seconds(key: str, fallback: Any) -> Any:
        # Falsy inputs (None, 0, "") deliberately select the fallback.
        try:
            return float(arguments.get(key) or fallback)
        except (TypeError, ValueError):
            return fallback

    interval = _as_seconds("poll_interval", DEFAULT_POLL_INTERVAL)
    budget = _as_seconds("max_wait_time", DEFAULT_MAX_WAIT_TIME)
    return max(1.0, interval), max(1.0, budget)
# ------------------------------------------------------------------ #
# Operation: forward reaction prediction
# ------------------------------------------------------------------ #
[docs]
def _predict_reaction(
self, arguments: Dict[str, Any], api_key: str
) -> Dict[str, Any]:
reactants = (arguments.get("reactants") or "").strip()
if not reactants:
return {
"status": "error",
"error": (
"Missing required parameter: reactants (dot-separated SMILES, "
"e.g. 'BrBr.c1ccc2cc3ccccc3cc2c1')."
),
}
ai_model = arguments.get("ai_model") or DEFAULT_AI_MODEL
headers = self._headers(api_key)
poll_interval, max_wait = self._polling_settings(arguments)
project_id = self._resolve_project_id(arguments, headers)
submit = requests.post(
f"{self.api_url}/predictions/pr",
headers=headers,
params={"projectId": project_id, "aiModel": ai_model},
json={"reactants": reactants, "aiModel": ai_model},
timeout=REQUEST_TIMEOUT,
)
submit.raise_for_status()
prediction_id = (submit.json().get("payload", {}) or {}).get("id")
if not prediction_id:
return {
"status": "error",
"error": "IBM RXN did not return a prediction id for forward prediction.",
}
results_url = f"{self.api_url}/predictions/{prediction_id}"
payload = self._poll(results_url, headers, poll_interval, max_wait)
attempts = payload.get("attempts") or []
products: List[Dict[str, Any]] = [
{
"smiles": att.get("smiles"),
"confidence": att.get("confidence"),
}
for att in attempts
if isinstance(att, dict)
]
top = products[0] if products else {}
return {
"status": "success",
"data": {
"reactants": reactants,
"product_smiles": top.get("smiles"),
"confidence": top.get("confidence"),
"attempts": products,
"prediction_id": prediction_id,
"ai_model": ai_model,
},
}
# ------------------------------------------------------------------ #
# Operation: retrosynthesis
# ------------------------------------------------------------------ #
[docs]
def _predict_retrosynthesis(
self, arguments: Dict[str, Any], api_key: str
) -> Dict[str, Any]:
product = (arguments.get("product") or "").strip()
if not product:
return {
"status": "error",
"error": (
"Missing required parameter: product (target SMILES, "
"e.g. 'CC(=O)Oc1ccccc1C(=O)O' for aspirin)."
),
}
ai_model = arguments.get("ai_model") or DEFAULT_AI_MODEL
max_steps = arguments.get("max_steps")
headers = self._headers(api_key)
poll_interval, max_wait = self._polling_settings(arguments)
parameters: Dict[str, Any] = {}
if max_steps is not None:
parameters["maxSteps"] = max_steps
project_id = self._resolve_project_id(arguments, headers)
body: Dict[str, Any] = {
"product": product,
"aiModel": ai_model,
"isInteractive": False,
}
if parameters:
body["parameters"] = parameters
submit = requests.post(
f"{self.api_url}/retrosynthesis/rs",
headers=headers,
params={"projectId": project_id, "aiModel": ai_model},
json=body,
timeout=REQUEST_TIMEOUT,
)
submit.raise_for_status()
prediction_id = (submit.json().get("payload", {}) or {}).get("id")
if not prediction_id:
return {
"status": "error",
"error": "IBM RXN did not return a prediction id for retrosynthesis.",
}
results_url = f"{self.api_url}/retrosynthesis/{prediction_id}"
payload = self._poll(results_url, headers, poll_interval, max_wait)
raw_paths = payload.get("retrosyntheticPaths") or payload.get("paths") or []
routes: List[Dict[str, Any]] = []
for path in raw_paths:
if not isinstance(path, dict):
continue
routes.append(
{
"smiles": path.get("smiles"),
"confidence": path.get("confidence"),
"count": path.get("count"),
}
)
return {
"status": "success",
"data": {
"product": product,
"routes": routes,
"route_count": len(routes),
"prediction_id": prediction_id,
"ai_model": ai_model,
},
}