Source code for tooluniverse.graphql_tool

from graphql import build_schema
from graphql.language import parse
from graphql.validation import validate
from .base_tool import BaseTool
from .tool_registry import register_tool
import requests
import copy


[docs] def validate_query(query_str, schema_str): try: # Build the GraphQL schema object from the provided schema string schema = build_schema(schema_str) # Parse the query string into an AST (Abstract Syntax Tree) query_ast = parse(query_str) # Validate the query AST against the schema validation_errors = validate(schema, query_ast) if not validation_errors: return True else: # Collect and return the validation errors error_messages = "\n".join(str(error) for error in validation_errors) return f"Query validation errors:\n{error_messages}" except Exception as e: return f"An error occurred during validation: {str(e)}"
[docs] def remove_none_and_empty_values(json_obj): """Remove all key-value pairs where the value is None or an empty list""" if isinstance(json_obj, dict): return { k: remove_none_and_empty_values(v) for k, v in json_obj.items() if v is not None and v != [] } elif isinstance(json_obj, list): return [ remove_none_and_empty_values(item) for item in json_obj if item is not None and item != [] ] else: return json_obj
[docs] def execute_query(endpoint_url, query, variables=None): response = requests.post( endpoint_url, json={"query": query, "variables": variables} ) try: result = response.json() # result = json.dumps(result, ensure_ascii=False) result = remove_none_and_empty_values(result) # Check if the response contains errors if "errors" in result: print("Invalid Query: ", result["errors"]) return None # Check if the data field is empty elif not result.get("data") or all(not v for v in result["data"].values()): print("No data returned") return None else: return result except requests.exceptions.JSONDecodeError: print("JSONDecodeError: Could not decode the response as JSON") return None
[docs] class GraphQLTool(BaseTool):
[docs] def __init__(self, tool_config, endpoint_url): super().__init__(tool_config) self.endpoint_url = endpoint_url self.query_schema = tool_config["query_schema"] self.parameters = tool_config["parameter"]["properties"] self.default_size = 5
[docs] def run(self, arguments): arguments = copy.deepcopy(arguments) if "size" in self.parameters and "size" not in arguments: arguments["size"] = 5 return execute_query( endpoint_url=self.endpoint_url, query=self.query_schema, variables=arguments )
[docs] @register_tool("OpenTarget") class OpentargetTool(GraphQLTool):
[docs] def __init__(self, tool_config): endpoint_url = "https://api.platform.opentargets.org/api/v4/graphql" super().__init__(tool_config, endpoint_url)
[docs] def run(self, arguments): # First try without modifying '-' result = super().run(arguments) # If no results, try with '-' replaced by ' ' if result is None: if "drugName" in arguments and isinstance(arguments["drugName"], str): arguments["drugName"] = arguments["drugName"].split("-")[0] modified_arguments = copy.deepcopy(arguments) for each_arg, arg_value in modified_arguments.items(): if isinstance(arg_value, str) and "-" in arg_value: modified_arguments[each_arg] = arg_value.replace("-", " ") result = super().run(modified_arguments) return result return result
[docs] @register_tool("OpentargetToolDrugNameMatch") class OpentargetToolDrugNameMatch(GraphQLTool):
[docs] def __init__(self, tool_config, drug_generic_tool=None): endpoint_url = "https://api.platform.opentargets.org/api/v4/graphql" self.drug_generic_tool = drug_generic_tool self.possible_drug_name_args = ["drugName"] super().__init__(tool_config, endpoint_url)
[docs] def run(self, arguments): arguments = copy.deepcopy(arguments) results = execute_query( endpoint_url=self.endpoint_url, query=self.query_schema, variables=arguments ) if results is None: print( "No results found for the drug brand name. Trying with the generic name." ) name_arguments = {} for each_args in self.possible_drug_name_args: if each_args in arguments: name_arguments["drug_name"] = arguments[each_args] break if len(name_arguments) == 0: print("No drug name found in the arguments.") return None drug_name_results = self.drug_generic_tool.run(name_arguments) if ( drug_name_results is not None and "openfda.generic_name" in drug_name_results ): arguments[each_args] = drug_name_results["openfda.generic_name"] print( "Found generic name. Trying with the generic name: ", arguments[each_args], ) results = execute_query( endpoint_url=self.endpoint_url, query=self.query_schema, variables=arguments, ) return results
[docs] @register_tool("OpenTargetGenetics") class OpentargetGeneticsTool(GraphQLTool):
[docs] def __init__(self, tool_config): endpoint_url = "https://api.genetics.opentargets.org/graphql" super().__init__(tool_config, endpoint_url)
[docs] @register_tool("DiseaseTargetScoreTool") class DiseaseTargetScoreTool(GraphQLTool): """Tool to extract disease-target association scores from specific data sources"""
[docs] def __init__(self, tool_config, datasource_id=None): endpoint_url = "https://api.platform.opentargets.org/api/v4/graphql" # Get datasource_id from config if not provided as parameter self.datasource_id = datasource_id or tool_config.get("datasource_id") super().__init__(tool_config, endpoint_url)
[docs] def run(self, arguments): """ Extract disease-target scores for a specific datasource Arguments should contain: efoId, datasourceId (optional), pageSize (optional) """ arguments = copy.deepcopy(arguments) efo_id = arguments.get("efoId") datasource_id = arguments.get("datasourceId", self.datasource_id) page_size = arguments.get("pageSize", 100) if not efo_id: return {"error": "efoId is required"} if not datasource_id: return {"error": "datasourceId is required"} results = [] page_index = 0 total_fetched = 0 total_count = None disease_info = None while True: variables = {"efoId": efo_id, "index": page_index, "size": page_size} response_data = execute_query( self.endpoint_url, self.query_schema, variables ) if not response_data or "data" not in response_data: break disease_data = response_data["data"]["disease"] if not disease_data: break if disease_info is None: disease_info = { "disease_id": disease_data["id"], "disease_name": disease_data["name"], } rows = disease_data["associatedTargets"]["rows"] if total_count is None: total_count = disease_data["associatedTargets"]["count"] for row in rows: symbol = row["target"]["approvedSymbol"] target_id = row["target"]["id"] score_entry = next( (ds for ds in row["datasourceScores"] if ds["id"] == datasource_id), None, ) if score_entry: results.append( { "target_symbol": symbol, "target_id": target_id, "datasource": datasource_id, "score": score_entry["score"], } ) total_fetched += len(rows) if total_fetched >= total_count or len(rows) == 0: break page_index += 1 return { "disease_info": disease_info, "datasource": datasource_id, "total_targets_with_scores": len(results), "target_scores": results, }