# embedding_database.py
import os
import hashlib
import numpy as np
from pathlib import Path
from typing import Dict, List, Any, Tuple, Optional
from ..base_tool import BaseTool
from ..tool_registry import register_tool
from ..logging_config import get_logger
from tooluniverse.database_setup.sqlite_store import SQLiteStore
from tooluniverse.database_setup.vector_store import VectorStore
from tooluniverse.database_setup.embedder import Embedder
from tooluniverse.utils import get_user_cache_dir


# ---------------------------
# Resolver helpers (provider/model) — no Azure bias
# ---------------------------
def _resolve_provider(explicit: Optional[str] = None) -> str:
    """
    Resolution order:
    1) explicit argument
    2) EMBED_PROVIDER environment variable
    3) heuristics by available credentials: azure > openai > huggingface > local
    """
    if explicit:
        return explicit
    env = os.getenv("EMBED_PROVIDER")
    if env:
        return env
    if os.getenv("AZURE_OPENAI_API_KEY"):
        return "azure"
    if os.getenv("OPENAI_API_KEY"):
        return "openai"
    if os.getenv("HF_TOKEN"):
        return "huggingface"
    return "local"


def _resolve_model(provider: str, explicit: Optional[str] = None) -> str:
    """
    Resolution order:
    1) explicit argument
    2) EMBED_MODEL environment variable
    3) provider-specific sensible default
    """
    if explicit:
        return explicit
    if os.getenv("EMBED_MODEL"):
        return os.getenv("EMBED_MODEL")
    if provider == "azure":
        # prefer deployment name for Azure
        return os.getenv("AZURE_OPENAI_DEPLOYMENT", "text-embedding-3-small")
    if provider == "huggingface":
        return os.getenv("HF_EMBED_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
    if provider == "local":
        return os.getenv("LOCAL_EMBED_MODEL", "all-MiniLM-L6-v2")
    # openai + any other: modern default
    return "text-embedding-3-small"


# ---------------------------
# Misc helpers
# ---------------------------
def _l2_normalize(mat: np.ndarray) -> np.ndarray:
    norms = np.linalg.norm(mat, axis=1, keepdims=True)
    return mat / (norms + 1e-12)
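
# Worked example (illustrative, not part of the tool's API): rows come out
# unit-length, so inner-product search over an IndexFlatIP index behaves like
# cosine similarity.
#   _l2_normalize(np.array([[3.0, 4.0]], dtype="float32"))  ~= [[0.6, 0.8]]
#   (the 1e-12 term guards against division by zero for all-zero rows)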


def _matches_filters(metadata: Dict, filters: Dict) -> bool:
    if not filters:
        return True
    for key, filter_value in filters.items():
        if key not in metadata:
            return False
        meta_value = metadata[key]
        if isinstance(filter_value, dict):
            if "$gte" in filter_value and meta_value < filter_value["$gte"]:
                return False
            if "$gt" in filter_value and meta_value <= filter_value["$gt"]:
                return False
            if "$lte" in filter_value and meta_value > filter_value["$lte"]:
                return False
            if "$lt" in filter_value and meta_value >= filter_value["$lt"]:
                return False
            if "$in" in filter_value and meta_value not in filter_value["$in"]:
                return False
            if "$contains" in filter_value:
                needle = filter_value["$contains"]
                if isinstance(meta_value, list):
                    if needle not in meta_value:
                        return False
                else:
                    if needle not in str(meta_value):
                        return False
        else:
            if meta_value != filter_value:
                return False
    return True
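
# Illustrative filter semantics (hypothetical metadata, not stored by the tool):
#   _matches_filters(
#       {"year": 2021, "tags": ["rct", "oncology"]},
#       {"year": {"$gte": 2020}, "tags": {"$contains": "rct"}},
#   )  # -> True
#   _matches_filters({"year": 2018}, {"year": {"$gte": 2020}})        # -> False
#   _matches_filters({"source": "pubmed"}, {"source": "pubmed"})      # -> True (plain equality)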


# ---------------------------
# Tool
# ---------------------------
@register_tool("EmbeddingDatabase")
class EmbeddingDatabase(BaseTool):
"""
Exposes actions:
- create_from_docs
- add_docs
- search
Backed by SQLiteStore + VectorStore + Embedder.
"""

    def __init__(self, tool_config):
        super().__init__(tool_config)
        self.logger = get_logger("EmbeddingDatabase")
        storage_config = tool_config.get("configs", {}).get("storage_config", {})
        self.data_dir = Path(
            storage_config.get(
                "data_dir", os.path.join(get_user_cache_dir(), "embeddings")
            )
        )
        self.faiss_index_type = storage_config.get("faiss_index_type", "IndexFlatIP")
        self.data_dir.mkdir(parents=True, exist_ok=True)

    # ---------- infra helpers (per collection) ----------
    def _paths(self, name: str) -> Tuple[Path, Path]:
        db_path = self.data_dir / f"{name}.db"
        index_path = self.data_dir / f"{name}.faiss"
        return db_path, index_path

    def _stores(self, name: str) -> Tuple[SQLiteStore, VectorStore, Path, Path]:
        db_path, index_path = self._paths(name)
        sqlite_store = SQLiteStore(db_path.as_posix())
        vector_store = VectorStore(
            db_path.as_posix(), data_dir=self.data_dir.as_posix()
        )
        return sqlite_store, vector_store, db_path, index_path

    def _embedder(self, provider: str, model: str) -> Embedder:
        return Embedder(
            provider=provider,
            model=model,
            batch_size=100 if provider in ("openai", "azure") else 32,
            max_retries=5,
        )

    def _existing_vector_doc_ids(
        self, vs: VectorStore, collection: str, doc_ids: List[int]
    ) -> set:
        if not doc_ids:
            return set()
        placeholders = ",".join("?" for _ in doc_ids)
        q = f"SELECT doc_id FROM vectors WHERE collection=? AND doc_id IN ({placeholders})"
        args = [collection] + doc_ids
        cur = vs.db.execute(q, args)
        return {r[0] for r in cur.fetchall()}

    # ---------------- entry point ----------------
    def run(self, arguments):
        action = arguments.get("action")
        if action == "create_from_docs":
            return self._create_from_documents(arguments)
        elif action == "add_docs":
            return self._add_documents(arguments)
        elif action == "search":
            return self._search(arguments)
        else:
            return {"error": f"Unknown action: {action}"}

    # ---------------- actions ----------------
    def _create_from_documents(self, args: Dict[str, Any]):
        name = args.get("database_name")
        docs: List[str] = args.get("documents", [])
        metas: List[Dict[str, Any]] = args.get("metadata", [])
        provider = _resolve_provider(args.get("provider"))
        model = _resolve_model(provider, args.get("model"))
        description = args.get("description", "")
        if not name:
            return {"error": "database_name is required"}
        if not docs:
            return {"error": "documents list cannot be empty"}
        if metas and len(metas) != len(docs):
            return {
                "error": "metadata length must match documents length (or omit 'metadata')"
            }
        sqlite_store, vector_store, db_path, index_path = self._stores(name)
        if index_path.exists():
            return {
                "error": f"Database '{name}' already exists. Use 'add_docs' to add more documents."
            }
        # Insert docs (dedupe by (collection, doc_key) and (collection, text_hash))
        if not metas:
            metas = [{} for _ in docs]
        rows = []
        doc_keys: List[str] = []
        for text, meta in zip(docs, metas):
            # stable key: sha256 prefix, not md5
            text_hash = hashlib.sha256(text.encode("utf-8")).hexdigest()[:16]
            doc_key = text_hash
            doc_keys.append(doc_key)
            rows.append((doc_key, text, meta, text_hash))
        sqlite_store.upsert_collection(
            name,
            description=description,
            embedding_model="precomputed",  # placeholder until we write vectors
            embedding_dimensions=None,
            index_type=self.faiss_index_type,
        )
        sqlite_store.insert_docs(name, rows)
        # Map keys -> ids
        inserted = sqlite_store.fetch_docs(name, doc_keys=doc_keys, limit=len(rows))
        # Keep ids in the same order as `docs` so each embedding row pairs with
        # the right doc id (mirrors the key -> id mapping used in _add_documents).
        key_to_id = {r["doc_key"]: r["id"] for r in inserted}
        doc_ids = [key_to_id[k] for k in doc_keys]
        # Embed + add to FAISS
        vecs = self._embedder(provider, model).embed(docs)
        vecs = _l2_normalize(np.asarray(vecs, dtype="float32"))
        vector_store.load_index(name, dim=vecs.shape[1])
        vector_store.add_embeddings(name, doc_ids, vecs)
        # Update collection with the real model + dimension
        sqlite_store.upsert_collection(
            name,
            description=description,
            embedding_model=model,
            embedding_dimensions=int(vecs.shape[1]),
            index_type=self.faiss_index_type,
        )
        self.logger.info(f"Created collection '{name}' with {len(docs)} docs")
        return {
            "status": "success",
            "database_name": name,
            "documents_added": len(docs),
            "embedding_model": model,
            "dimensions": int(vecs.shape[1]),
            "db_path": str(db_path),
            "index_path": str(index_path),
        }

    def _add_documents(self, args: Dict[str, Any]):
        name = args.get("database_name")
        docs: List[str] = args.get("documents", [])
        metas: List[Dict[str, Any]] = args.get("metadata", [])
        # optional overrides; will be validated against collection meta
        provider = _resolve_provider(args.get("provider"))
        model_override = args.get("model")
        if not name:
            return {"error": "database_name is required"}
        if not docs:
            return {"error": "documents list cannot be empty"}
        if metas and len(metas) != len(docs):
            return {
                "error": "metadata length must match documents length (or omit 'metadata')"
            }
        sqlite_store, vector_store, db_path, index_path = self._stores(name)
        if not index_path.exists() or not db_path.exists():
            return {
                "error": f"Database '{name}' does not exist. Use 'create_from_docs' first."
            }
        col_model, col_dim = self._get_collection_meta(sqlite_store, name)
        if col_model in (None, "precomputed"):
            # if collection didn't store a model, resolve one
            col_model = _resolve_model(provider, model_override)
        elif model_override and model_override != col_model:
            return {
                "error": f"Embedding model mismatch: collection uses '{col_model}', request uses '{model_override}'"
            }
        emb = self._embedder(provider, col_model)
        if not metas:
            metas = [{} for _ in docs]
        rows = []
        doc_keys: List[str] = []
        for text, meta in zip(docs, metas):
            text_hash = hashlib.sha256(text.encode("utf-8")).hexdigest()[:16]
            doc_key = text_hash
            doc_keys.append(doc_key)
            rows.append((doc_key, text, meta, text_hash))
        # Insert (duplicates ignored by UNIQUE constraints)
        sqlite_store.insert_docs(name, rows)
        # Map keys -> ids
        inserted = sqlite_store.fetch_docs(name, doc_keys=doc_keys, limit=len(rows))
        key_to_id = {r["doc_key"]: r["id"] for r in inserted}
        doc_ids_all = [key_to_id[k] for k in doc_keys if k in key_to_id]
        # Compute embeddings once
        vecs_all = emb.embed(docs)
        vecs_all = _l2_normalize(np.asarray(vecs_all, dtype="float32"))
        if col_dim and col_dim != vecs_all.shape[1]:
            return {
                "error": f"Embedding dimension mismatch: {col_dim} vs {vecs_all.shape[1]}"
            }
        # Filter out doc_ids that already have vectors
        existing = self._existing_vector_doc_ids(vector_store, name, doc_ids_all)
        doc_ids_to_add: List[int] = []
        vecs_to_add: List[np.ndarray] = []
        for i, k in enumerate(doc_keys):
            did = key_to_id.get(k)
            if did is not None and did not in existing:
                doc_ids_to_add.append(did)
                vecs_to_add.append(vecs_all[i])
        # Load index, add only the missing ones
        index = vector_store.load_index(name, dim=col_dim or vecs_all.shape[1])
        before = index.ntotal
        if doc_ids_to_add:
            vecs_to_add_arr = np.vstack(vecs_to_add).astype("float32")
            vector_store.add_embeddings(name, doc_ids_to_add, vecs_to_add_arr)
        after = before + len(doc_ids_to_add)
        return {
            "status": "success",
            "database_name": name,
            "documents_added": len(doc_ids_to_add),
            "skipped_existing": len(docs) - len(doc_ids_to_add),
            "total_vectors": after,
            "db_path": str(db_path),
            "index_path": str(index_path),
        }

    def _search(self, args: Dict[str, Any]):
        name = args.get("database_name")
        query = args.get("query")
        top_k = int(args.get("top_k", 5))
        filters = args.get("filters", args.get("metadata_filter", {}))
        provider = _resolve_provider(args.get("provider"))
        model_override = args.get("model")
        if not name:
            return {"error": "database_name is required"}
        if not query:
            return {"error": "query is required"}
        sqlite_store, vector_store, db_path, index_path = self._stores(name)
        if not index_path.exists() or not db_path.exists():
            return {"error": f"Database '{name}' does not exist"}
        col_model, col_dim = self._get_collection_meta(sqlite_store, name)
        # pick model for query embedding
        model = (
            model_override or (None if col_model == "precomputed" else col_model)
        ) or _resolve_model(provider, None)
        emb = self._embedder(provider, model)
        # Embed query
        q = emb.embed([query])
        q = _l2_normalize(np.asarray(q, dtype="float32"))
        qdim = int(q.shape[1])
        if col_dim and col_dim != qdim:
            return {"error": f"Embedding dimension mismatch: {col_dim} vs {qdim}"}
        # Search
        vector_store.load_index(name, dim=col_dim or qdim)
        results = vector_store.search_embeddings(name, q[0], top_k=top_k)
        # Hydrate + filter
        doc_ids = [doc_id for doc_id, _ in results]
        docs = sqlite_store.fetch_docs_by_ids(name, doc_ids)
        doc_map = {d["id"]: d for d in docs}
        out = []
        for doc_id, score in results:
            d = doc_map.get(doc_id)
            if not d:
                continue
            md = d.get("metadata") or {}
            if _matches_filters(md, filters):
                out.append(
                    {
                        "text": d["text"],
                        "metadata": md,
                        "similarity_score": float(score),
                    }
                )
        return {
            "status": "success",
            "database_name": name,
            "query": query,
            "results": out[:top_k],
            "total_found": len(out),
        }
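

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the tool's public surface. It assumes
    # BaseTool accepts a plain dict config, that embedding credentials (or a local
    # model) are available, and uses a hypothetical ./demo_embeddings directory.
    # In normal use the class is instantiated through the tool registry instead.
    tool = EmbeddingDatabase(
        {"configs": {"storage_config": {"data_dir": "./demo_embeddings"}}}
    )
    # Build a tiny collection, then query it with a metadata filter.
    print(
        tool.run(
            {
                "action": "create_from_docs",
                "database_name": "demo",
                "documents": ["aspirin reduces fever", "metformin lowers glucose"],
                "metadata": [{"topic": "analgesic"}, {"topic": "antidiabetic"}],
            }
        )
    )
    print(
        tool.run(
            {
                "action": "search",
                "database_name": "demo",
                "query": "drug for blood sugar",
                "top_k": 1,
                "filters": {"topic": "antidiabetic"},
            }
        )
    )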