Source code for tooluniverse.xml_tool

# import xml.etree.ElementTree as ET
from lxml import etree as ET
from typing import List, Dict, Any, Optional, Set
from .base_tool import BaseTool
from .utils import download_from_hf
from .tool_registry import register_tool


[docs] @register_tool("XMLTool") class XMLDatasetTool(BaseTool): """ Tool to search and filter XML datasets that are organized as a collection of searchable records (e.g., dataset of medical subjects or drug descriptions). Supports user-friendly queries without requiring XPath knowledge. """
[docs] def __init__(self, tool_config: Dict[str, Any]): super().__init__(tool_config) self.xml_root: Optional[ET.Element] = None self.records: List[ET.Element] = [] self.record_xpath: str = tool_config.get("settings").get("record_xpath", ".//*") self.namespaces: Dict[str, str] = tool_config.get("settings").get( "namespaces", {} ) self.field_mappings: Dict[str, str] = tool_config.get("settings").get( "field_mappings", {} ) # Dict of fields we're interested in extracting from each record self.filter_field: Optional[str] = tool_config.get("settings").get( "filter_field" ) # Field to filter on, if specified self.search_fields: List[str] = tool_config.get("settings").get( "search_fields", ["_text"] + list(self.field_mappings.keys()) ) self._record_cache: List[Dict[str, Any]] = [] # Cache extracted data self.temporary_record_fields: Set[str] = set() self._load_dataset()
def _load_dataset(self) -> None: """Load and parse the XML dataset.""" try: xml_path = self._get_dataset_path() if not xml_path: return tree = ET.parse(xml_path) self.xml_root = tree.getroot() self.records = self.xml_root.findall( self.record_xpath, namespaces=self.namespaces ) print( f"Loaded XML dataset: {len(self.records)} records from root '{self.xml_root.tag}'" ) except Exception as e: print(f"Error loading XML dataset: {e}") self.records = [] def _get_dataset_path(self) -> Optional[str]: """Get the path to the XML dataset.""" if "hf_dataset_path" in self.tool_config["settings"]: result = download_from_hf(self.tool_config["settings"]) if result.get("success"): return result["local_path"] print(f"Failed to download dataset: {result.get('error')}") return None if "local_dataset_path" in self.tool_config["settings"]: return self.tool_config["settings"]["local_dataset_path"] print("No dataset path provided in tool configuration") return None def _extract_record_data(self, record_element: ET.Element) -> Dict[str, Any]: """Extract data from a record element with caching.""" data = { "_tag": record_element.tag, "_text": (record_element.text or "").strip(), "_attributes": dict(record_element.attrib), } for field_name, xpath_expr in self.field_mappings.items(): # Extract mapped fields if isinstance(xpath_expr, dict) and "parent_path" in xpath_expr: # Handle nested structure parent_xpath = xpath_expr["parent_path"] subfields = xpath_expr.get("subfields", {}) elements = record_element.findall( parent_xpath, namespaces=self.namespaces ) structured_list = [] for el in elements: entry = {} for sf_name, sf_path in subfields.items(): entry[sf_name] = self._extract_field_value(el, sf_path) if any(entry.values()): # Only add entries with non-empty values structured_list.append(entry) data[field_name] = structured_list # Flatten for search for sf_name, _ in subfields.items(): flat_key = f"{field_name}_{sf_name}" # For efficient search, flatten structured data into a single string data[flat_key] = " | ".join( entry.get(sf_name, "") for entry in structured_list ) self.temporary_record_fields.add(flat_key) else: # Regular flat field extraction data[field_name] = self._extract_field_value(record_element, xpath_expr) return data def _extract_field_value(self, element: ET.Element, xpath_expr: str) -> str: """Extract field value using XPath expression.""" try: # Handle attribute extraction with /@ if "/@" in xpath_expr: elem_path, attr_name = xpath_expr.rsplit("/@", 1) found_elements = element.findall(elem_path, namespaces=self.namespaces) if not found_elements: return "" # Use generator expression for memory efficiency values = ( el.get(attr_name, "").strip() for el in found_elements if el.get(attr_name) ) return " | ".join(values) # Handle direct attribute on current element if xpath_expr.startswith("@"): return element.get(xpath_expr[1:], "").strip() # Handle text content extraction found_elements = element.findall(xpath_expr, namespaces=self.namespaces) if not found_elements: return "" # Use generator expression and filter out empty text values = ((elem.text or "").strip() for elem in found_elements) non_empty_values = (v for v in values if v) return " | ".join(non_empty_values) except Exception: return "" def _get_all_records_data(self) -> List[Dict[str, Any]]: """Get all records data with caching.""" if not self._record_cache: self._record_cache = [ self._extract_record_data(record) for record in self.records ] return self._record_cache
[docs] def run(self, arguments: Dict[str, Any]) -> Dict[str, Any]: """Main entry point for the tool.""" if not self.records: return {"error": "XML dataset not loaded or contains no records"} # Route to appropriate function based on arguments if "query" in arguments: return self._search(arguments) elif "condition" in arguments: return self._filter(arguments) else: return { "error": "Provide either 'query' for search or 'condition' for filtering" }
def _search(self, arguments: Dict[str, Any]) -> Dict[str, Any]: """Search records by text content across multiple fields.""" query = arguments.get("query", "").strip() if not query: return {"error": "Query parameter is required"} # Parse search parameters with sensible defaults case_sensitive = arguments.get("case_sensitive", False) exact_match = arguments.get("exact_match", False) limit = min(arguments.get("limit", 50), 1000) # Cap at 1000 search_query = query if case_sensitive else query.lower() results = [] all_records = self._get_all_records_data() total_matches = 0 for record_data in all_records: matched_fields = self._find_matches( record_data, search_query, self.search_fields, case_sensitive, exact_match, ) if matched_fields: total_matches += 1 if len(results) < limit: result_record = record_data.copy() for temp in self.temporary_record_fields: result_record.pop(temp, None) result_record["matched_fields"] = matched_fields results.append(result_record) return { "query": query, "total_matches": total_matches, "total_returned_results": len(results), "results": results, "search_parameters": { "case_sensitive": case_sensitive, "exact_match": exact_match, "limit": limit, }, } def _find_matches( self, record_data: Dict[str, Any], search_query: str, search_fields: List[str], case_sensitive: bool, exact_match: bool, ) -> List[str]: """Find matching fields in a record.""" matched_fields = [] for field in search_fields: if field not in record_data: continue field_value = self._get_searchable_value(record_data, field, case_sensitive) if self._is_match(field_value, search_query, exact_match): matched_fields.append(field) return matched_fields def _get_searchable_value( self, record_data: Dict[str, Any], field: str, case_sensitive: bool ) -> str: """Get searchable string value for a field.""" if field == "_attributes": value = " ".join(record_data["_attributes"].values()) else: value = str(record_data.get(field, "")) return value if case_sensitive else value.lower() def _is_match(self, field_value: str, search_query: str, exact_match: bool) -> bool: """Check if field value matches search query.""" if exact_match: if "|" in field_value: # Handle multiple values return search_query in [v.strip() for v in field_value.split("|")] return search_query == field_value.strip() return search_query in field_value def _filter(self, arguments: Dict[str, Any]) -> Dict[str, Any]: """Filter records based on field criteria.""" field = self.filter_field condition = arguments.get("condition") value = arguments.get("value", "") limit = min(arguments.get("limit", 100), 1000) # Cap at 1000 if not field or not condition: return {"error": "Both 'field' and 'condition' are required"} # Validate condition requirements if condition not in ["not_empty", "has_attribute"] and not value: return {"error": f"'value' parameter required for condition '{condition}'"} all_records = self._get_all_records_data() # Check if field exists if all_records and field not in all_records[0]: available_fields = sorted(all_records[0].keys()) return { "error": f"Field '{field}' not found. Available: {available_fields}" } filtered_records = [] filter_func = self._get_filter_function(condition, value) if not filter_func: return { "error": f"Unknown condition '{condition}'. Supported: contains, starts_with, ends_with, exact, not_empty, has_attribute" } total_matches = 0 for record_data in all_records: if field in record_data and filter_func(record_data, field): total_matches += 1 if len(filtered_records) < limit: result_record = record_data.copy() for temp in self.temporary_record_fields: result_record.pop(temp, None) filtered_records.append(result_record) return { "total_matches": total_matches, "total_returned_results": len(filtered_records), "results": filtered_records, "applied_filter": self._get_filter_description(field, condition, value), "filter_parameters": { "field": field, "condition": condition, "value": ( value if condition not in ["not_empty", "has_attribute"] else None ), "limit": limit, }, } def _get_filter_function(self, condition: str, value: str): """Get the appropriate filter function for the condition.""" filter_functions = { "contains": lambda data, field: value.lower() in str(data[field]).lower(), "starts_with": lambda data, field: str(data[field]) .lower() .startswith(value.lower()), "ends_with": lambda data, field: str(data[field]) .lower() .endswith(value.lower()), "exact": lambda data, field: str(data[field]).lower() == value.lower(), "not_empty": lambda data, field: str(data[field]).strip() != "", "has_attribute": lambda data, field: field == "_attributes" and value in data["_attributes"], } return filter_functions.get(condition) def _get_filter_description(self, field: str, condition: str, value: str) -> str: """Get human-readable filter description.""" descriptions = { "contains": f"{field} contains '{value}'", "starts_with": f"{field} starts with '{value}'", "ends_with": f"{field} ends with '{value}'", "exact": f"{field} equals '{value}'", "not_empty": f"{field} is not empty", "has_attribute": f"has attribute '{value}'", } return descriptions.get(condition, f"{field} {condition} {value}")
[docs] def get_dataset_info(self) -> Dict[str, Any]: """Get comprehensive information about the loaded XML dataset.""" if not self.records: return {"error": "XML dataset not loaded or contains no records"} # Get field information from sample records sample_data = self._get_all_records_data()[:5] all_fields = set() for record_data in sample_data: all_fields.update(record_data.keys()) info = { "total_records": len(self.records), "root_element": self.xml_root.tag if self.xml_root else None, "record_xpath": self.record_xpath, "field_mappings": self.field_mappings, "available_fields": sorted(all_fields), } if sample_data: info["sample_record"] = sample_data[0] return info