# Source code for tooluniverse.huggingface_inference_tool
"""HuggingFace serverless Inference API wrapper for ToolUniverse.

A single tool class (``HuggingFaceInferenceTool``) that runs inference on any
model hosted by HuggingFace's serverless ``hf-inference`` provider, exposed as a
few high-value task wrappers:

* ``classify_text`` — text-classification (e.g. sentiment) -> labels + scores
* ``embed_text`` — feature-extraction / embeddings -> a single vector
* ``fill_mask`` — masked-language-model fill-mask -> ranked candidate tokens
  (works for protein LMs such as ESM-2 too)
* ``summarize`` — summarization -> condensed summary text
* ``zero_shot_classify`` — zero-shot-classification against caller-supplied labels
* ``ner`` — token-classification / named-entity recognition
* ``question_answering`` — extractive QA over a question + context
* ``translate`` — machine translation -> translated text
* ``classify_image`` — image-classification -> top labels + scores
* ``detect_objects`` — object-detection -> objects with bounding boxes

Endpoint note
-------------
The historical base ``https://api-inference.huggingface.co/models/{id}`` no
longer resolves — HuggingFace migrated serverless inference to the unified
router. This tool targets the current endpoint::

    https://router.huggingface.co/hf-inference/models/{model_id}

For embeddings the request is sent to the explicit feature-extraction
sub-pipeline (``.../{model_id}/pipeline/feature-extraction``) so that
sentence-transformers models — which otherwise default to a
sentence-similarity pipeline — return a raw embedding vector.

Authentication
--------------
An optional bearer token is read from the ``HF_TOKEN`` environment variable
(never a tool parameter). Many models work token-less with stricter rate
limits; a token raises those limits and unlocks gated models.

The tool never raises: every path returns a ``{"status": ...}`` dict. A model
that is still warming up returns HTTP 503; that is surfaced as a clear,
retryable status rather than an exception.
"""
import os
from typing import Any, Dict, List, Optional
import requests
from .base_tool import BaseTool
from .tool_registry import register_tool
_BASE_URL = "https://router.huggingface.co/hf-inference/models"
_TIMEOUT = 30
_MAX_EMBED_PREVIEW = 8 # vector entries shown in the preview field
def _err(msg: str, **extra: Any) -> Dict[str, Any]:
out: Dict[str, Any] = {"status": "error", "error": msg}
out.update(extra)
return out
def _ok(data: Any, **metadata: Any) -> Dict[str, Any]:
meta = {"provider": "hf-inference"}
meta.update(metadata)
return {"status": "success", "data": data, "metadata": meta}
@register_tool("HuggingFaceInferenceTool")
class HuggingFaceInferenceTool(BaseTool):
    """Run inference on HuggingFace-hosted models (serverless hf-inference).

    ``run()`` dispatches on ``arguments["operation"]`` to one task handler.
    Every path returns a ``{"status": ...}`` dict: ``"success"`` (with
    ``data`` and ``metadata``), ``"error"``, or ``"loading"`` (the model is
    still warming up on the HF servers; retry).
    """

    # Timeout (seconds) used when the caller sets ``wait_for_model``. The
    # router then holds the request open while the model loads, which may
    # take longer than the default ``_TIMEOUT``; keeping the short timeout
    # would cut that wait off and defeat the option.
    _WAIT_TIMEOUT = 120

    def __init__(self, tool_config):
        super().__init__(tool_config)
        self.parameter = tool_config.get("parameter", {})

    # ------------------------------------------------------------------ #
    # dispatch
    # ------------------------------------------------------------------ #

    def run(self, arguments: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """Dispatch ``arguments["operation"]`` to the matching task handler.

        Args:
            arguments: Tool arguments; must contain ``operation`` (one of the
                handler names below) plus that operation's own inputs.

        Returns:
            A ``{"status": "success"|"error"|"loading", ...}`` dict. Never
            raises: malformed arguments and unexpected handler exceptions are
            reported as ``{"status": "error", ...}``.
        """
        arguments = arguments or {}
        # Everything below calls ``arguments.get``; reject non-mapping input
        # here instead of letting AttributeError escape run().
        if not callable(getattr(arguments, "get", None)):
            return _err(
                "Arguments must be an object/dict, got "
                f"{type(arguments).__name__}."
            )
        handlers = {
            "classify_text": self._classify_text,
            "embed_text": self._embed_text,
            "fill_mask": self._fill_mask,
            "summarize": self._summarize,
            "zero_shot_classify": self._zero_shot_classify,
            "ner": self._ner,
            "question_answering": self._question_answering,
            "translate": self._translate,
            "classify_image": self._classify_image,
            "detect_objects": self._detect_objects,
        }
        operation = arguments.get("operation")
        # Only look up strings: an unhashable value (e.g. a list) would make
        # dict.get() raise TypeError outside the try block below.
        handler = handlers.get(operation) if isinstance(operation, str) else None
        if handler is None:
            return _err(
                f"Unknown or missing operation: {operation!r}. "
                f"Expected one of {sorted(handlers)}."
            )
        try:
            return handler(arguments)
        except Exception as exc:  # never raise out of run()
            return _err(f"Unexpected error: {type(exc).__name__}: {exc}")

    # ------------------------------------------------------------------ #
    # shared input validation
    # ------------------------------------------------------------------ #

    @staticmethod
    def _require_text_and_model(args: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Return an error dict if ``text``/``model_id`` are missing, else None.

        Empty and whitespace-only values count as missing.
        """
        text = args.get("text")
        if not text or not str(text).strip():
            return _err("Missing required parameter: text")
        model_id = args.get("model_id")
        if not model_id or not str(model_id).strip():
            return _err("Missing required parameter: model_id")
        return None

    @staticmethod
    def _positive_int(value: Any) -> bool:
        """True for a positive ``int``; ``bool`` is rejected although it subclasses int."""
        return isinstance(value, int) and not isinstance(value, bool) and value > 0

    def _post_text(
        self, args: Dict[str, Any], payload: Dict[str, Any], path_suffix: str = ""
    ) -> Dict[str, Any]:
        """Validate text+model, POST ``payload`` as JSON, and unwrap the response.

        Returns ``{"_json": <body>}`` on success or a ready-made
        ``{"status": "error"|"loading", ...}`` dict on any failure — callers
        branch on the presence of the ``"_json"`` key.
        """
        invalid = self._require_text_and_model(args)
        if invalid is not None:
            return invalid
        return self._post(
            args.get("model_id"),
            payload,
            wait_for_model=bool(args.get("wait_for_model", False)),
            path_suffix=path_suffix,
        )

    # ------------------------------------------------------------------ #
    # shared HTTP helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _model_url(model_id: Any, path_suffix: str = "") -> str:
        """Build the router URL for ``model_id`` (outer whitespace/slashes removed)."""
        return f"{_BASE_URL}/{str(model_id).strip().strip('/')}{path_suffix}"

    def _timeout(self, wait_for_model: bool) -> int:
        """Client timeout in seconds; longer when the router is asked to wait."""
        return max(_TIMEOUT, self._WAIT_TIMEOUT) if wait_for_model else _TIMEOUT

    @staticmethod
    def _build_headers(content_type: str, wait_for_model: bool) -> Dict[str, str]:
        """Request headers: content type, optional bearer token, optional wait flag.

        The token comes from the ``HF_TOKEN`` environment variable only.
        """
        headers = {"Content-Type": content_type}
        # Strip so a token pasted with a trailing newline still forms a
        # valid header value.
        token = os.environ.get("HF_TOKEN", "").strip()
        if token:
            headers["Authorization"] = f"Bearer {token}"
        if wait_for_model:
            # HF holds the request open until the model finishes loading.
            headers["x-wait-for-model"] = "true"
        return headers

    @staticmethod
    def _timeout_error(
        model_id: Any, timeout: int, wait_for_model: bool, noun: str = "model"
    ) -> Dict[str, Any]:
        """Error dict for a client-side timeout.

        Only suggests ``wait_for_model=true`` when it was not already set.
        """
        if wait_for_model:
            return _err(
                f"Timeout after {timeout}s contacting {model_id} while waiting "
                f"for the {noun} to load. Retry later."
            )
        return _err(
            f"Timeout after {timeout}s contacting {model_id}. The {noun} may "
            "be loading — retry, optionally with wait_for_model=true."
        )

    def _post(
        self,
        model_id: str,
        payload: Dict[str, Any],
        wait_for_model: bool,
        path_suffix: str = "",
    ) -> Dict[str, Any]:
        """POST a JSON ``payload`` to the inference router.

        Returns either ``{"_json": <decoded body>}`` on success or a ready-made
        ``{"status": "error"|"loading", ...}`` dict on any failure.
        """
        timeout = self._timeout(wait_for_model)
        try:
            resp = requests.post(
                self._model_url(model_id, path_suffix),
                json=payload,
                headers=self._build_headers("application/json", wait_for_model),
                timeout=timeout,
            )
        except requests.exceptions.Timeout:
            return self._timeout_error(model_id, timeout, wait_for_model)
        except requests.exceptions.RequestException as exc:
            return _err(f"Network error contacting {model_id}: {exc}")
        # Text models get a gated/private-specific 401 message.
        return self._interpret_status(
            resp,
            model_id,
            unauthorized_hint=(
                "The model may be gated/private; set a valid HF_TOKEN with access."
            ),
        )

    # ------------------------------------------------------------------ #
    # shared image input helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _content_type_for(source: str, default: str = "image/jpeg") -> str:
        """Guess an image Content-Type from a URL or file path extension.

        A URL query string is ignored; unknown extensions map to ``default``.
        """
        ext = os.path.splitext(source.split("?")[0])[1].lower()
        return {
            ".jpg": "image/jpeg",
            ".jpeg": "image/jpeg",
            ".png": "image/png",
            ".gif": "image/gif",
            ".bmp": "image/bmp",
            ".webp": "image/webp",
            ".tif": "image/tiff",
            ".tiff": "image/tiff",
        }.get(ext, default)

    def _load_image_bytes(self, args: Dict[str, Any]) -> Dict[str, Any]:
        """Resolve image bytes from exactly one of ``image_url`` or ``image_path``.

        Returns ``{"_bytes": <data>, "_content_type": <ct>}`` on success or a
        ready-made ``{"status": "error", ...}`` dict on any failure. For URLs
        the server's ``content-type`` is used when it is an ``image/*`` type,
        otherwise the type is guessed from the URL's extension.
        """
        image_url = args.get("image_url")
        image_path = args.get("image_path")
        if not image_url and not image_path:
            return _err(
                "Missing image input: provide exactly one of image_url "
                "(a public http(s) URL) or image_path (a local file path)."
            )
        if image_url and image_path:
            return _err("Provide only one of image_url or image_path, not both.")
        if image_path:
            try:
                with open(image_path, "rb") as fh:
                    data = fh.read()
            except OSError as exc:
                return _err(f"Could not read image_path {image_path!r}: {exc}")
            if not data:
                return _err(f"Image file is empty: {image_path}")
            return {
                "_bytes": data,
                "_content_type": self._content_type_for(image_path),
            }
        try:
            resp = requests.get(image_url, timeout=_TIMEOUT)
        except requests.exceptions.Timeout:
            return _err(f"Timeout after {_TIMEOUT}s fetching image_url {image_url}.")
        except requests.exceptions.RequestException as exc:
            return _err(f"Could not fetch image_url {image_url}: {exc}")
        if resp.status_code != 200:
            return _err(
                f"HTTP {resp.status_code} fetching image_url {image_url}. "
                "Ensure it is a public, directly-downloadable image."
            )
        if not resp.content:
            return _err(f"Empty image body from image_url {image_url}.")
        ct = resp.headers.get("content-type", "").split(";")[0].strip()
        if not ct.startswith("image/"):
            ct = self._content_type_for(image_url)
        return {"_bytes": resp.content, "_content_type": ct}

    def _post_image(self, model_id: str, args: Dict[str, Any]) -> Dict[str, Any]:
        """Validate model + load image, then POST the raw image bytes.

        Image models expect the binary image as the request body with the
        image's Content-Type (not JSON). Returns ``{"_json": <body>}`` on
        success or a ready-made ``{"status": ...}`` dict on any failure.
        """
        if not model_id or not str(model_id).strip():
            return _err("Missing required parameter: model_id")
        loaded = self._load_image_bytes(args)
        if "_bytes" not in loaded:
            return loaded  # error dict
        wait_for_model = bool(args.get("wait_for_model", False))
        timeout = self._timeout(wait_for_model)
        try:
            resp = requests.post(
                self._model_url(model_id),
                data=loaded["_bytes"],
                headers=self._build_headers(loaded["_content_type"], wait_for_model),
                timeout=timeout,
            )
        except requests.exceptions.Timeout:
            return self._timeout_error(
                model_id, timeout, wait_for_model, noun="image model"
            )
        except requests.exceptions.RequestException as exc:
            return _err(f"Network error contacting {model_id}: {exc}")
        return self._interpret_status(resp, model_id)

    @staticmethod
    def _safe_json(resp) -> Any:
        """Decoded JSON body of ``resp``, or None when the body is not JSON."""
        try:
            return resp.json()
        except ValueError:
            return None

    def _interpret_status(
        self,
        resp,
        model_id: str,
        unauthorized_hint: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Map an HTTP response to ``{"_json": ...}`` or a status dict.

        Shared by JSON and image POSTs: 503 -> ``"loading"`` (with the
        server's ``estimated_time`` when it sends one), 401/404/429 -> tailored
        errors, any other non-200 -> error with the server's detail, and a 200
        whose body is not JSON -> error.

        Args:
            resp: The ``requests`` response.
            model_id: Model repo id, used in messages.
            unauthorized_hint: Sentence appended to the 401 message; defaults
                to the image-inference wording.
        """
        status = resp.status_code
        if status == 503:
            body = self._safe_json(resp)
            est = body.get("estimated_time") if isinstance(body, dict) else None
            eta = f" (~{est:.0f}s estimated)" if isinstance(est, (int, float)) else ""
            return {
                "status": "loading",
                "error": (
                    f"Model {model_id} is loading on the HF inference servers{eta}"
                    ". Retry shortly, or pass wait_for_model=true to block "
                    "until it is ready."
                ),
                "estimated_time": est,
            }
        if status == 401:
            hint = unauthorized_hint or (
                "Serverless image inference now requires a token: "
                "set a valid HF_TOKEN with access."
            )
            return _err(f"Unauthorized for {model_id}. {hint}")
        if status == 404:
            return _err(
                f"Model not found: {model_id}. Check the exact repo id "
                "(e.g. 'org/name')."
            )
        if status == 429:
            return _err(
                f"Rate limited for {model_id}. Set HF_TOKEN to raise free-tier "
                "limits, or retry later."
            )
        if status != 200:
            body = self._safe_json(resp)
            # Prefer the server's {"error": ...} message; otherwise show the
            # raw body. Either way, truncate to keep the message readable.
            if isinstance(body, dict):
                detail = str(body.get("error") or body)
            else:
                detail = resp.text
            return _err(f"HTTP {status} from {model_id}: {detail[:300]}")
        try:
            return {"_json": resp.json()}
        except ValueError:
            return _err(f"Non-JSON response from {model_id}: {resp.text[:200]}")

    # ------------------------------------------------------------------ #
    # shared result parsing
    # ------------------------------------------------------------------ #

    @staticmethod
    def _score_key(record: Dict[str, Any]) -> float:
        """Sort key: the record's numeric ``score``, or -1.0 if absent/non-numeric."""
        score = record.get("score")
        return score if isinstance(score, (int, float)) else -1.0

    @classmethod
    def _parse_labels(cls, body: Any) -> Optional[List[Dict[str, Any]]]:
        """Parse a (text|image)-classification body into sorted label dicts.

        Both tasks return ``[{label, score}, ...]`` or, batched, a nested
        ``[[{label, score}, ...]]`` (only the first inner list is used).
        Returns the score-descending label list, or ``None`` if the body is
        not a list.
        """
        labels = (
            body[0]
            if (isinstance(body, list) and body and isinstance(body[0], list))
            else body
        )
        if not isinstance(labels, list):
            return None
        return sorted(
            (
                {"label": d.get("label"), "score": d.get("score")}
                for d in labels
                if isinstance(d, dict)
            ),
            key=cls._score_key,
            reverse=True,
        )

    # ------------------------------------------------------------------ #
    # text-classification
    # ------------------------------------------------------------------ #

    def _classify_text(self, args: Dict[str, Any]) -> Dict[str, Any]:
        """text-classification: labels sorted by score, plus the top label."""
        model_id = args.get("model_id")
        result = self._post_text(args, {"inputs": args.get("text")})
        if "_json" not in result:
            return result  # error / loading dict
        body = result["_json"]
        labels = self._parse_labels(body)
        if labels is None:
            return _err(f"Unexpected classification response: {str(body)[:200]}")
        top = labels[0]["label"] if labels else None
        return _ok(
            {"model_id": model_id, "top_label": top, "labels": labels},
            task="text-classification",
        )

    # ------------------------------------------------------------------ #
    # feature-extraction / embeddings
    # ------------------------------------------------------------------ #

    def _embed_text(self, args: Dict[str, Any]) -> Dict[str, Any]:
        """feature-extraction: return one embedding vector for ``text``."""
        model_id = args.get("model_id")
        # Force the feature-extraction pipeline so sentence-transformers models
        # return a raw vector instead of routing to sentence-similarity.
        result = self._post_text(
            args,
            {"inputs": args.get("text")},
            path_suffix="/pipeline/feature-extraction",
        )
        if "_json" not in result:
            return result
        body = result["_json"]
        vector = self._flatten_embedding(body)
        if vector is None:
            return _err(
                f"Could not parse embedding from response for {model_id}: "
                f"{str(body)[:200]}"
            )
        return _ok(
            {
                "model_id": model_id,
                "dimension": len(vector),
                "embedding": vector,
                "preview": vector[:_MAX_EMBED_PREVIEW],
            },
            task="feature-extraction",
        )

    @staticmethod
    def _flatten_embedding(body: Any) -> Optional[List[float]]:
        """Reduce an HF feature-extraction response to one 1-D vector.

        Accepted shapes:

        * ``[float, ...]`` — already pooled; returned as floats.
        * ``[[float, ...], ...]`` — the first row is returned (batch of one).
        * ``[[[float, ...], ...], ...]`` — the first input's token vectors,
          mean-pooled over the token axis.

        Returns ``None`` for any other shape, including token vectors of
        unequal length.
        """
        if not isinstance(body, list) or not body:
            return None
        first = body[0]
        if isinstance(first, (int, float)):
            return [float(x) for x in body]
        if not isinstance(first, list) or not first:
            return None
        if isinstance(first[0], (int, float)):
            return [float(x) for x in first]
        if isinstance(first[0], list):
            tokens = [t for t in first if isinstance(t, list) and t]
            if not tokens:
                return None
            dim = len(tokens[0])
            if any(len(t) != dim for t in tokens):
                return None  # ragged token vectors cannot be mean-pooled
            count = len(tokens)
            return [sum(float(v) for v in column) / count for column in zip(*tokens)]
        return None

    # ------------------------------------------------------------------ #
    # fill-mask
    # ------------------------------------------------------------------ #

    def _fill_mask(self, args: Dict[str, Any]) -> Dict[str, Any]:
        """fill-mask: candidate tokens for the (first) mask in ``text``.

        An optional positive-int ``top_k`` is forwarded to the model.
        """
        model_id = args.get("model_id")
        top_k = args.get("top_k")
        payload: Dict[str, Any] = {"inputs": args.get("text")}
        if self._positive_int(top_k):
            payload["parameters"] = {"top_k": top_k}
        result = self._post_text(args, payload)
        if "_json" not in result:
            return result
        body = result["_json"]
        # fill-mask returns [{score, token, token_str, sequence}, ...]; with
        # multiple masks it nests one list per mask (only the first is used).
        if isinstance(body, list) and body and isinstance(body[0], list):
            body = body[0]
        if not isinstance(body, list):
            return _err(f"Unexpected fill-mask response: {str(body)[:200]}")
        predictions = [
            {
                "token_str": (d.get("token_str") or "").strip(),
                "score": d.get("score"),
                "sequence": d.get("sequence"),
            }
            for d in body
            if isinstance(d, dict)
        ]
        if not predictions:
            return _err(
                f"No predictions returned for {model_id}. Ensure the input "
                "contains the model's mask token (e.g. [MASK] for BERT, "
                "<mask> for RoBERTa/ESM)."
            )
        return _ok(
            {
                "model_id": model_id,
                "top_token": predictions[0]["token_str"],
                "predictions": predictions,
            },
            task="fill-mask",
        )

    # ------------------------------------------------------------------ #
    # summarization
    # ------------------------------------------------------------------ #

    def _summarize(self, args: Dict[str, Any]) -> Dict[str, Any]:
        """summarization: condensed text; optional positive-int length bounds."""
        model_id = args.get("model_id")
        parameters: Dict[str, Any] = {}
        max_length = args.get("max_length")
        min_length = args.get("min_length")
        if self._positive_int(max_length):
            parameters["max_length"] = max_length
        if self._positive_int(min_length):
            parameters["min_length"] = min_length
        payload: Dict[str, Any] = {"inputs": args.get("text")}
        if parameters:
            payload["parameters"] = parameters
        result = self._post_text(args, payload)
        if "_json" not in result:
            return result
        body = result["_json"]
        # summarization returns [{"summary_text": "..."}].
        item = body[0] if isinstance(body, list) and body else body
        summary = item.get("summary_text") if isinstance(item, dict) else None
        if not summary:
            return _err(f"Unexpected summarization response: {str(body)[:200]}")
        return _ok(
            {"model_id": model_id, "summary_text": summary},
            task="summarization",
        )

    # ------------------------------------------------------------------ #
    # zero-shot-classification
    # ------------------------------------------------------------------ #

    def _zero_shot_classify(self, args: Dict[str, Any]) -> Dict[str, Any]:
        """zero-shot-classification of ``text`` against ``candidate_labels``.

        ``multi_label`` (if given) is forwarded as a bool.
        """
        invalid = self._require_text_and_model(args)
        if invalid is not None:
            return invalid
        model_id = args.get("model_id")
        candidate_labels = args.get("candidate_labels")
        if not isinstance(candidate_labels, list) or not candidate_labels:
            return _err(
                "Missing required parameter: candidate_labels (a non-empty list "
                "of strings to classify the text against)."
            )
        parameters: Dict[str, Any] = {
            "candidate_labels": [str(label) for label in candidate_labels]
        }
        if args.get("multi_label") is not None:
            parameters["multi_label"] = bool(args.get("multi_label"))
        # Inputs were validated above; post directly rather than re-validating
        # through _post_text.
        result = self._post(
            model_id,
            {"inputs": args.get("text"), "parameters": parameters},
            wait_for_model=bool(args.get("wait_for_model", False)),
        )
        if "_json" not in result:
            return result
        body = result["_json"]
        # The router returns either a flat sorted [{"label","score"}, ...] list
        # or the classic {"labels":[...], "scores":[...]} dict — normalise both.
        labels: List[Dict[str, Any]] = []
        if isinstance(body, list):
            labels = [
                {"label": d.get("label"), "score": d.get("score")}
                for d in body
                if isinstance(d, dict)
            ]
        elif isinstance(body, dict):
            names = body.get("labels")
            scores = body.get("scores")
            if isinstance(names, list) and isinstance(scores, list):
                labels = [{"label": n, "score": s} for n, s in zip(names, scores)]
        if not labels:
            return _err(
                f"Unexpected zero-shot response from {model_id}: {str(body)[:200]}"
            )
        labels.sort(key=self._score_key, reverse=True)
        return _ok(
            {
                "model_id": model_id,
                "top_label": labels[0]["label"],
                "labels": labels,
            },
            task="zero-shot-classification",
        )

    # ------------------------------------------------------------------ #
    # token-classification / NER
    # ------------------------------------------------------------------ #

    def _ner(self, args: Dict[str, Any]) -> Dict[str, Any]:
        """token-classification: entities with group/label, span and score."""
        model_id = args.get("model_id")
        result = self._post_text(args, {"inputs": args.get("text")})
        if "_json" not in result:
            return result
        body = result["_json"]
        # token-classification returns [{entity_group|entity, score, word,
        # start, end}, ...]; with batching it may nest one list per input.
        if isinstance(body, list) and body and isinstance(body[0], list):
            body = body[0]
        if not isinstance(body, list):
            return _err(f"Unexpected NER response: {str(body)[:200]}")
        entities = [
            {
                "entity_group": d.get("entity_group") or d.get("entity"),
                "word": d.get("word"),
                "score": d.get("score"),
                "start": d.get("start"),
                "end": d.get("end"),
            }
            for d in body
            if isinstance(d, dict)
        ]
        return _ok(
            {
                "model_id": model_id,
                "entity_count": len(entities),
                "entities": entities,
            },
            task="token-classification",
        )

    # ------------------------------------------------------------------ #
    # extractive question-answering
    # ------------------------------------------------------------------ #

    def _question_answering(self, args: Dict[str, Any]) -> Dict[str, Any]:
        """Extractive QA: answer span for ``question`` within ``context``."""
        model_id = args.get("model_id")
        question = args.get("question")
        context = args.get("context")
        if not model_id or not str(model_id).strip():
            return _err("Missing required parameter: model_id")
        if not question or not str(question).strip():
            return _err("Missing required parameter: question")
        if not context or not str(context).strip():
            return _err("Missing required parameter: context")
        result = self._post(
            model_id,
            {"inputs": {"question": question, "context": context}},
            wait_for_model=bool(args.get("wait_for_model", False)),
        )
        if "_json" not in result:
            return result
        body = result["_json"]
        # QA returns a single {"answer","score","start","end"} object (or a
        # one-element list of the same).
        item = body[0] if isinstance(body, list) and body else body
        if not isinstance(item, dict) or "answer" not in item:
            return _err(f"Unexpected QA response from {model_id}: {str(body)[:200]}")
        return _ok(
            {
                "model_id": model_id,
                "answer": item.get("answer"),
                "score": item.get("score"),
                "start": item.get("start"),
                "end": item.get("end"),
            },
            task="question-answering",
        )

    # ------------------------------------------------------------------ #
    # translation
    # ------------------------------------------------------------------ #

    def _translate(self, args: Dict[str, Any]) -> Dict[str, Any]:
        """translation: the model's translation of ``text``."""
        model_id = args.get("model_id")
        result = self._post_text(args, {"inputs": args.get("text")})
        if "_json" not in result:
            return result
        body = result["_json"]
        # translation returns [{"translation_text": "..."}].
        item = body[0] if isinstance(body, list) and body else body
        translation = item.get("translation_text") if isinstance(item, dict) else None
        if not translation:
            return _err(f"Unexpected translation response: {str(body)[:200]}")
        return _ok(
            {"model_id": model_id, "translation_text": translation},
            task="translation",
        )

    # ------------------------------------------------------------------ #
    # image-classification
    # ------------------------------------------------------------------ #

    def _classify_image(self, args: Dict[str, Any]) -> Dict[str, Any]:
        """image-classification: labels sorted by score, plus the top label."""
        model_id = args.get("model_id")
        result = self._post_image(model_id, args)
        if "_json" not in result:
            return result  # error / loading dict
        body = result["_json"]
        labels = self._parse_labels(body)
        if labels is None:
            return _err(f"Unexpected image-classification response: {str(body)[:200]}")
        top = labels[0]["label"] if labels else None
        return _ok(
            {"model_id": model_id, "top_label": top, "labels": labels},
            task="image-classification",
        )

    # ------------------------------------------------------------------ #
    # object-detection
    # ------------------------------------------------------------------ #

    def _detect_objects(self, args: Dict[str, Any]) -> Dict[str, Any]:
        """object-detection: labelled boxes sorted by descending score."""
        model_id = args.get("model_id")
        result = self._post_image(model_id, args)
        if "_json" not in result:
            return result  # error / loading dict
        body = result["_json"]
        # object-detection returns [{score, label, box:{xmin,ymin,xmax,ymax}},
        # ...]; with batching it may nest one list per input.
        if isinstance(body, list) and body and isinstance(body[0], list):
            body = body[0]
        if not isinstance(body, list):
            return _err(f"Unexpected object-detection response: {str(body)[:200]}")
        objects = []
        for d in body:
            if not isinstance(d, dict):
                continue
            box = d.get("box") if isinstance(d.get("box"), dict) else {}
            objects.append(
                {
                    "label": d.get("label"),
                    "score": d.get("score"),
                    "box": {
                        "xmin": box.get("xmin"),
                        "ymin": box.get("ymin"),
                        "xmax": box.get("xmax"),
                        "ymax": box.get("ymax"),
                    },
                }
            )
        objects.sort(key=self._score_key, reverse=True)
        return _ok(
            {
                "model_id": model_id,
                "object_count": len(objects),
                "objects": objects,
            },
            task="object-detection",
        )