tooluniverse.graphql_tool 源代码
from graphql import build_schema
from graphql.language import parse
from graphql.validation import validate
from .base_tool import BaseTool
from .tool_registry import register_tool
import requests
import copy
[文档]
def validate_query(query_str, schema_str):
    """Validate a GraphQL query string against a schema string.

    Args:
        query_str: Text of the GraphQL query to check.
        schema_str: Text of the schema, as accepted by ``build_schema``.

    Returns:
        ``True`` when validation reports no errors. Otherwise a string: either
        the newline-joined validation errors, or a message describing the
        exception raised while building, parsing or validating.
    """
    try:
        # Schema is built before the query is parsed, so a broken schema is
        # reported first when both inputs are bad.
        problems = validate(build_schema(schema_str), parse(query_str))
        if problems:
            error_messages = "\n".join(map(str, problems))
            return f"Query validation errors:\n{error_messages}"
        return True
    except Exception as exc:
        # Any failure is reported as text rather than raised.
        return f"An error occurred during validation: {exc!s}"
[文档]
def remove_none_and_empty_values(json_obj):
    """Recursively drop entries whose value is ``None`` or an empty list.

    Dict items and list elements are tested *before* they are cleaned, so a
    container that only becomes empty as a result of cleaning is kept
    (``{"a": [None]}`` becomes ``{"a": []}``). Empty dicts are never removed.
    Anything that is neither a dict nor a list is returned unchanged.
    """

    def is_kept(value):
        return value is not None and value != []

    if isinstance(json_obj, dict):
        cleaned = {}
        for key, value in json_obj.items():
            if is_kept(value):
                cleaned[key] = remove_none_and_empty_values(value)
        return cleaned

    if isinstance(json_obj, list):
        return [
            remove_none_and_empty_values(element)
            for element in filter(is_kept, json_obj)
        ]

    return json_obj
[文档]
def execute_query(endpoint_url, query, variables=None):
    """POST a GraphQL query and return the cleaned JSON response.

    Args:
        endpoint_url: URL of the GraphQL endpoint.
        query: GraphQL query text.
        variables: Optional mapping sent as the query variables.

    Returns:
        The decoded response with ``None`` / empty-list values stripped by
        ``remove_none_and_empty_values``, or ``None`` when the request fails,
        the HTTP status is not OK, the body is not a JSON object, the response
        carries ``errors``, or it has no ``data`` key. The reason is printed.
    """
    try:
        response = requests.post(
            endpoint_url, json={"query": query, "variables": variables}, timeout=30
        )
    except requests.exceptions.RequestException as exc:
        # Connection failures and timeouts are reported like every other
        # failure mode of this function instead of escaping as exceptions.
        print(f"Request to API failed: {exc}")
        return None

    if not response.ok:
        print(f"HTTP {response.status_code} from API: {response.text[:200]}")
        return None

    try:
        result = response.json()
    except ValueError:
        # requests.exceptions.JSONDecodeError derives from ValueError; catching
        # the base class also covers requests releases that predate it.
        print("JSONDecodeError: Could not decode the response as JSON")
        return None

    result = remove_none_and_empty_values(result)
    if not isinstance(result, dict):
        # A bare JSON scalar or array cannot carry a "data" key.
        print("No data returned")
        return None

    # Check if the response contains errors
    if "errors" in result:
        print("Invalid Query: ", result["errors"])
        return None
    # Feature-94A-002: always return result when data key is present,
    # even if all values are empty/null (e.g. disease not found = {"data": {}}).
    # Callers distinguish empty results from errors via status envelope.
    if "data" not in result:
        print("No data returned")
        return None
    return result
[文档]
class GraphQLTool(BaseTool):
    """Tool that runs one configured GraphQL query against a fixed endpoint.

    The query text comes from ``tool_config["query_schema"]`` and the call
    arguments are sent as the query variables.
    """

    def __init__(self, tool_config, endpoint_url):
        """Store the endpoint plus the query and parameter spec from the config."""
        super().__init__(tool_config)
        self.endpoint_url = endpoint_url
        self.query_schema = tool_config["query_schema"]
        self.parameters = tool_config["parameter"]["properties"]
        # Used for ``size`` when the tool declares it and the caller omits it.
        self.default_size = 5

    def run(self, arguments):
        """Execute the configured query and wrap the outcome in a status dict.

        Returns ``{"status": "success", "data": ...}`` when a response is
        obtained, otherwise ``{"status": "error", "error": ...}``.
        """
        variables = copy.deepcopy(arguments)
        if "size" in self.parameters:
            variables.setdefault("size", self.default_size)

        response = execute_query(
            endpoint_url=self.endpoint_url,
            query=self.query_schema,
            variables=variables,
        )
        if response is not None:
            return {"status": "success", "data": response.get("data", response)}
        return {"status": "error", "error": "No data returned from API"}
# GraphQL query used by _ot_resolve_id: searches for `$q` restricted to the
# entity names in `$entity` and requests only the first hit (page index 0,
# size 1), returning that hit's id and name.
_OT_SEARCH_QUERY = """
query otSearch($q: String!, $entity: [String!]!) {
search(queryString: $q, entityNames: $entity, page: {index: 0, size: 1}) {
hits { id name }
}
}
"""
[文档]
def _ot_resolve_id(endpoint_url: str, query_string: str, entity: str) -> str | None:
    """Return the ID of the top search hit for a name, or ``None``.

    Runs ``_OT_SEARCH_QUERY`` against ``endpoint_url`` with ``query_string``
    as the search text and ``entity`` as the single entity name to search in.
    ``None`` is returned when the query yields nothing or there are no hits.
    """
    search_variables = {"q": query_string, "entity": [entity]}
    response = execute_query(endpoint_url, _OT_SEARCH_QUERY, search_variables)
    if not response:
        return None

    search_section = response.get("data", {}).get("search", {})
    matches = search_section.get("hits", [])
    if not matches:
        return None
    return matches[0]["id"]
[文档]
@register_tool("OpenTarget")
class OpentargetTool(GraphQLTool):
    """GraphQL tool bound to the OpenTargets Platform endpoint.

    Before the configured query runs, gene and disease names supplied under
    ``gene_symbol`` / ``disease_name`` (or one of their aliases) are resolved
    to IDs through the platform search query.
    """

    def __init__(self, tool_config):
        self.endpoint_url = "https://api.platform.opentargets.org/api/v4/graphql"
        super().__init__(tool_config, self.endpoint_url)

    def run(self, arguments):
        """Resolve names to IDs, run the query, and retry once without hyphens.

        Returns the status dict produced by ``GraphQLTool.run``, or an error
        dict when a gene symbol or disease name cannot be resolved. A
        ``metadata.note`` is attached when the returned evidence count is 0.
        """
        arguments = copy.deepcopy(arguments)
        # Normalize common aliases before resolution
        if "ensemblId" not in arguments and "gene_symbol" not in arguments:
            for alias in ("target", "gene", "gene_name"):
                if arguments.get(alias):
                    arguments["gene_symbol"] = arguments.pop(alias)
                    break
        if "efoId" not in arguments and "disease_name" not in arguments:
            for alias in ("disease", "disease_id", "trait"):
                if arguments.get(alias):
                    arguments["disease_name"] = arguments.pop(alias)
                    break

        # Resolve gene_symbol -> ensemblId if ensemblId not provided
        if "ensemblId" not in arguments and "gene_symbol" in arguments:
            resolved = _ot_resolve_id(
                self.endpoint_url, arguments.pop("gene_symbol"), "target"
            )
            if not resolved:
                return {
                    "status": "error",
                    "error": "Could not resolve gene symbol to Ensembl ID. "
                    "Try passing ensemblId directly (e.g. ENSG00000141510 for TP53).",
                }
            arguments["ensemblId"] = resolved

        # Resolve disease_name -> efoId (or diseaseIds) if not provided.
        # Queries that mention "diseaseIds" take a list instead of a single ID.
        needs_disease_ids = "diseaseIds" in self.query_schema
        if (
            "efoId" not in arguments
            and "diseaseIds" not in arguments
            and "disease_name" in arguments
        ):
            resolved = _ot_resolve_id(
                self.endpoint_url, arguments.pop("disease_name"), "disease"
            )
            if not resolved:
                return {
                    "status": "error",
                    "error": "Could not resolve disease name to EFO ID. "
                    "Try passing efoId directly (e.g. EFO_0000384 for Crohn's disease).",
                }
            if needs_disease_ids:
                arguments["diseaseIds"] = [resolved]
            else:
                arguments["efoId"] = resolved

        result = super().run(arguments)

        # If no results, retry with '-' replaced by ' '. The retry is skipped
        # when stripping hyphens changes nothing, since it would only repeat
        # the request that has just failed.
        if result.get("status") != "success":
            retry_arguments = self._without_hyphens(arguments)
            if retry_arguments != arguments:
                result = super().run(retry_arguments)

        # Applied after the retry so a result obtained by the fallback is
        # annotated the same way as a first-attempt result.
        self._note_empty_evidence(result)
        return result

    @staticmethod
    def _without_hyphens(arguments):
        """Return a copy of *arguments* with hyphens removed from string values.

        ``drugName`` keeps only the text before its first hyphen; every other
        string value has each hyphen replaced by a space.
        """
        cleaned = copy.deepcopy(arguments)
        for key, value in cleaned.items():
            if not isinstance(value, str):
                continue
            if key == "drugName":
                value = value.split("-")[0]
            cleaned[key] = value.replace("-", " ")
        return cleaned

    @staticmethod
    def _note_empty_evidence(result):
        """Attach the IntOGen note when a successful result reports 0 evidences."""
        # Add note when IntOGen evidence count is 0 (Feature-122B-002)
        if result.get("status") != "success":
            return
        evidences = result.get("data", {}).get("disease", {}).get("evidences", {})
        if isinstance(evidences, dict) and evidences.get("count") == 0:
            result.setdefault("metadata", {})["note"] = (
                "IntOGen returns 0 evidence rows for this query. "
                "IntOGen only covers somatic tumor driver mutations — "
                "it has no data for non-cancer diseases or non-driver genes. "
                "For non-oncology phenotypes, use OpenTargets_get_evidence_by_datasource instead."
            )
[文档]
@register_tool("OpentargetToolDrugNameMatch")
class OpentargetToolDrugNameMatch(GraphQLTool):
    """OpenTargets Platform query that falls back to a drug's generic name.

    When the query yields nothing for the supplied drug name, the optional
    ``drug_generic_tool`` is asked for the generic name and the query is
    repeated with it.
    """

    def __init__(self, tool_config, drug_generic_tool=None):
        endpoint_url = "https://api.platform.opentargets.org/api/v4/graphql"
        # Optional tool whose ``run({"drug_name": ...})`` result may contain
        # an "openfda.generic_name" entry.
        self.drug_generic_tool = drug_generic_tool
        self.possible_drug_name_args = ["drugName"]
        super().__init__(tool_config, endpoint_url)

    def run(self, arguments):
        """Run the query, retrying once with the generic drug name on failure.

        Returns ``{"status": "success", "data": ...}`` or
        ``{"status": "error", "error": ...}``.
        """
        arguments = copy.deepcopy(arguments)
        results = execute_query(
            endpoint_url=self.endpoint_url, query=self.query_schema, variables=arguments
        )
        if results is None:
            print(
                "No results found for the drug brand name. Trying with the generic name."
            )
            # Find which drug name argument was provided
            matched_arg = next(
                (name for name in self.possible_drug_name_args if name in arguments),
                None,
            )
            if matched_arg is None:
                print("No drug name found in the arguments.")
                return {"status": "error", "error": "No drug name found in arguments"}

            if self.drug_generic_tool is None:
                # The lookup tool is optional; without it there is no fallback
                # to attempt, so report the failed query instead of crashing
                # on ``None.run``.
                print("No generic-name lookup tool configured.")
                return {"status": "error", "error": "No data returned from API"}

            drug_name_results = self.drug_generic_tool.run(
                {"drug_name": arguments[matched_arg]}
            )
            if (
                drug_name_results is not None
                and "openfda.generic_name" in drug_name_results
            ):
                arguments[matched_arg] = drug_name_results["openfda.generic_name"]
                print(
                    "Found generic name. Trying with the generic name: ",
                    arguments[matched_arg],
                )
                results = execute_query(
                    endpoint_url=self.endpoint_url,
                    query=self.query_schema,
                    variables=arguments,
                )
        if results is None:
            return {"status": "error", "error": "No data returned from API"}
        return {"status": "success", "data": results.get("data", results)}
[文档]
@register_tool("OpenTargetGenetics")
class OpentargetGeneticsTool(GraphQLTool):
    """GraphQL tool bound to the OpenTargets Genetics endpoint.

    A disease given by name is first resolved to an ID through the
    OpenTargets Platform search and passed on as ``diseaseIds``.
    """

    def __init__(self, tool_config):
        super().__init__(tool_config, "https://api.genetics.opentargets.org/graphql")

    def run(self, arguments):
        """Resolve a disease name to ``diseaseIds`` if needed, then run the query."""
        call_args = copy.deepcopy(arguments)

        if "diseaseIds" not in call_args:
            # First alias carrying a non-empty value wins.
            alias = next(
                (
                    name
                    for name in ("disease_name", "disease", "trait")
                    if call_args.get(name)
                ),
                None,
            )
            if alias is not None:
                disease_name = call_args.pop(alias)
                disease_id = _ot_resolve_id(
                    "https://api.platform.opentargets.org/api/v4/graphql",
                    disease_name,
                    "disease",
                )
                if not disease_id:
                    return {
                        "status": "error",
                        "error": (
                            f"Could not resolve '{disease_name}' to a disease ID. "
                            "Try passing diseaseIds directly (e.g. ['MONDO_0005148'] for type 2 diabetes)."
                        ),
                    }
                call_args["diseaseIds"] = [disease_id]

        return super().run(call_args)
[文档]
@register_tool("DiseaseTargetScoreTool")
class DiseaseTargetScoreTool(GraphQLTool):
    """Tool to extract disease-target association scores from specific data sources"""

    def __init__(self, tool_config, datasource_id=None):
        endpoint_url = "https://api.platform.opentargets.org/api/v4/graphql"
        # Get datasource_id from config if not provided as parameter
        self.datasource_id = datasource_id or tool_config.get("datasource_id")
        super().__init__(tool_config, endpoint_url)

    def run(self, arguments):
        """
        Extract disease-target scores for a specific datasource.

        Arguments should contain: efoId, datasourceId (optional), pageSize (optional)

        Pages through the disease's associated targets and keeps, for each
        target, the score whose datasource id equals the requested one.

        Returns an error dict when ``efoId`` or the datasource id is missing,
        otherwise a success dict with ``disease_info`` (``None`` when the
        disease was not returned), the datasource id and the collected scores.
        """
        arguments = copy.deepcopy(arguments)
        efo_id = arguments.get("efoId")
        datasource_id = arguments.get("datasourceId", self.datasource_id)
        page_size = arguments.get("pageSize", 100)
        if not efo_id:
            return {"status": "error", "error": "efoId is required"}
        if not datasource_id:
            return {"status": "error", "error": "datasourceId is required"}

        results = []
        page_index = 0
        total_fetched = 0
        total_count = None
        disease_info = None
        while True:
            variables = {"efoId": efo_id, "index": page_index, "size": page_size}
            response_data = execute_query(
                self.endpoint_url, self.query_schema, variables
            )
            if not response_data or "data" not in response_data:
                break

            # execute_query strips None and [] values, so a null disease, an
            # empty rows list or empty datasourceScores arrive as *missing*
            # keys; read them with .get() rather than indexing.
            disease_data = response_data["data"].get("disease")
            if not disease_data:
                break
            if disease_info is None:
                disease_info = {
                    "disease_id": disease_data.get("id"),
                    "disease_name": disease_data.get("name"),
                }

            associated = disease_data.get("associatedTargets") or {}
            rows = associated.get("rows") or []
            if total_count is None:
                total_count = associated.get("count")

            for row in rows:
                target = row.get("target") or {}
                score_entry = next(
                    (
                        ds
                        for ds in row.get("datasourceScores") or []
                        if ds.get("id") == datasource_id
                    ),
                    None,
                )
                if score_entry:
                    results.append(
                        {
                            "target_symbol": target.get("approvedSymbol"),
                            "target_id": target.get("id"),
                            "datasource": datasource_id,
                            "score": score_entry.get("score"),
                        }
                    )

            total_fetched += len(rows)
            # Stop on an empty page, or once the reported total is reached.
            if not rows or (total_count is not None and total_fetched >= total_count):
                break
            page_index += 1

        return {
            "status": "success",
            "data": {
                "disease_info": disease_info,
                "datasource": datasource_id,
                "total_targets_with_scores": len(results),
                "target_scores": results,
            },
        }