Source code for tooluniverse.database_setup.generic_embedding_search_tool
"""
EmbeddingCollectionSearchTool — search any datastore collection by name.
Configuration (tool_config.fields)
----------------------------------
- collection : str (required) e.g., "my_collection"
- db_path : str (optional) e.g., "<user_cache_dir>/embeddings/my_collection.db"
If omitted, defaults to: <user_cache_dir>/embeddings/<collection>.db
"""
from typing import Any, Dict
from tooluniverse.base_tool import BaseTool
from tooluniverse.tool_registry import register_tool
from tooluniverse.database_setup.search import SearchEngine
from tooluniverse.utils import get_user_cache_dir
import os
[docs]
@register_tool("EmbeddingCollectionSearchTool")
class EmbeddingCollectionSearchTool(BaseTool):
    """
    Generic search tool for any embedding datastore collection.

    Configuration (``tool_config["fields"]``)
    -----------------------------------------
    collection : str (required)
        Name of the collection to search.
    db_path : str (optional)
        Path of the datastore file. When omitted, the path is
        ``<user_cache_dir>/embeddings/<collection>.db``.

    Runtime arguments
    -----------------
    query : str (required)
        Search query text.
    method : str = "hybrid"
        One of: "keyword", "embedding", "hybrid".
    top_k : int = 10
        Number of results to return.
    alpha : float = 0.5
        Balance for hybrid search (0=keyword only, 1=embedding only).

    Returns
    -------
    List[dict] with keys:
        - doc_id
        - doc_key
        - text
        - metadata
        - score
        - snippet (first 280 characters of ``text``)

    On any failure a single ``dict`` with an ``"error"`` key is returned
    instead of the list.
    """

    def run(self, arguments: Dict[str, Any]) -> Any:
        """
        Search the configured collection and return the matching documents.

        Parameters
        ----------
        arguments : Dict[str, Any]
            Runtime arguments; see the class docstring. ``None`` is treated
            like an empty mapping, and ``method`` / ``top_k`` / ``alpha``
            fall back to their defaults when missing or ``None``.

        Returns
        -------
        Any
            The list returned by ``SearchEngine.search_collection`` with a
            ``"snippet"`` key added to every item, or an ``{"error": ...}``
            dict when the configuration or arguments are invalid or the
            search itself fails.
        """
        fields = self.tool_config.get("fields") or {}
        coll = fields.get("collection")
        if not coll:
            return {"error": "Missing fields.collection in tool config"}

        # Tolerate a missing arguments mapping instead of failing on `.get`.
        arguments = arguments or {}
        q = arguments.get("query")
        if not q:
            return {"error": "Missing 'query' argument"}

        method = arguments.get("method") or "hybrid"

        # Coerce the numeric options up front and report bad values through
        # the same error-dict channel as every other failure, rather than
        # letting TypeError/ValueError escape from run().
        raw_top_k = arguments.get("top_k")
        raw_alpha = arguments.get("alpha")
        try:
            top_k = 10 if raw_top_k is None else int(raw_top_k)
            alpha = 0.5 if raw_alpha is None else float(raw_alpha)
        except (TypeError, ValueError, OverflowError) as e:
            return {"error": f"Invalid 'top_k' or 'alpha' argument: {e}"}

        # An explicit db_path from the tool config wins; otherwise use
        # <user_cache_dir>/embeddings/<collection>.db.
        db_path = fields.get("db_path") or os.path.join(
            get_user_cache_dir(), "embeddings", f"{coll}.db"
        )

        try:
            # Reuse an engine already attached to the instance as `_se`
            # (this class never sets it itself); otherwise open one for
            # db_path. Construction happens inside the try so that a
            # datastore that cannot be opened is reported like any other
            # search failure.
            se = getattr(self, "_se", None) or SearchEngine(db_path=db_path)
            res = se.search_collection(
                coll, q, method=method, top_k=top_k, alpha=alpha
            )
            for r in res:
                # A missing or None text yields an empty snippet.
                r["snippet"] = (r.get("text") or "")[:280]
            return res
        except Exception as e:
            # Tool boundary: surface the failure as data, with enough context
            # (collection, db_path) to diagnose it.
            return {
                "error": f"search failed: {e}",
                "collection": coll,
                "db_path": db_path,
            }